tp: add config objects for trace processor and batch trace processor

This CL removes ad-hoc parameters being passed to contructor to instead
passing a config object which can override various functionality: this
will replace loader_vendor once G3 migrates to using just this.

Change-Id: Ibbf3960106726a00cb332bb3043b75e00681a324
Bug: 180499808
diff --git a/CHANGELOG b/CHANGELOG
index 19f75ba..c931c67 100644
--- a/CHANGELOG
+++ b/CHANGELOG
@@ -2,7 +2,13 @@
   Tracing service and probes:
     *
   Trace Processor:
-    *
+    * The argument for the trace path in constructor of TraceProcessor
+      in the Python API taks |trace| is renamed from |file_path| to |trace|.
+      |file_path| is deprecated and may be removed in the future.
+    * The Python API now takes a TraceProcessorConfig in the constructor
+      instead of passing parameters directly. This may break existing code
+      but migration should be trivial (all current options are still
+      supported).
   UI:
     *
   SDK:
diff --git a/docs/analysis/trace-processor.md b/docs/analysis/trace-processor.md
index 3606c00..2385175 100644
--- a/docs/analysis/trace-processor.md
+++ b/docs/analysis/trace-processor.md
@@ -583,17 +583,17 @@
 ```python
 from perfetto.trace_processor import TraceProcessor
 # Initialise TraceProcessor with a trace file
-tp = TraceProcessor(file_path='trace.perfetto-trace')
+tp = TraceProcessor(trace='trace.perfetto-trace')
 ```
 
 NOTE: The TraceProcessor can be initialized in a combination of ways including:
       <br> - An address at which there exists a running instance of `trace_processor` with a
-      loaded trace (e.g. `TraceProcessor(addr='localhost:9001')`)
+      loaded trace (e.g.`TraceProcessor(addr='localhost:9001')`)
       <br> - An address at which there exists a running instance of `trace_processor` and
       needs a trace to be loaded in
-      (e.g. `TraceProcessor(addr='localhost:9001', file_path='trace.perfetto-trace')`)
+      (e.g. `TraceProcessor(trace='trace.perfetto-trace', addr='localhost:9001')`)
       <br> - A path to a `trace_processor` binary and the trace to be loaded in
-      (e.g. `TraceProcessor(bin_path='./trace_processor', file_path='trace.perfetto-trace')`)
+      (e.g. `TraceProcessor(trace='trace.perfetto-trace', config=TraceProcessorConfig(bin_path='./trace_processor'))`)
 
 
 ### API
@@ -608,7 +608,7 @@
 
 ```python
 from perfetto.trace_processor import TraceProcessor
-tp = TraceProcessor(file_path='trace.perfetto-trace')
+tp = TraceProcessor(trace='trace.perfetto-trace')
 
 qr_it = tp.query('SELECT ts, dur, name FROM slice')
 for row in qr_it:
@@ -627,7 +627,7 @@
 requires you to have both the `NumPy` and `Pandas` modules installed.
 ```python
 from perfetto.trace_processor import TraceProcessor
-tp = TraceProcessor(file_path='trace.perfetto-trace')
+tp = TraceProcessor(trace='trace.perfetto-trace')
 
 qr_it = tp.query('SELECT ts, dur, name FROM slice')
 qr_df = qr_it.as_pandas_dataframe()
@@ -648,7 +648,7 @@
 make visualisations from the trace data.
 ```python
 from perfetto.trace_processor import TraceProcessor
-tp = TraceProcessor(file_path='trace.perfetto-trace')
+tp = TraceProcessor(trace='trace.perfetto-trace')
 
 qr_it = tp.query('SELECT ts, value FROM counter WHERE track_id=50')
 qr_df = qr_it.as_pandas_dataframe()
@@ -665,7 +665,7 @@
 
 ```python
 from perfetto.trace_processor import TraceProcessor
-tp = TraceProcessor(file_path='trace.perfetto-trace')
+tp = TraceProcessor(trace='trace.perfetto-trace')
 
 ad_cpu_metrics = tp.metric(['android_cpu'])
 print(ad_cpu_metrics)
diff --git a/docs/quickstart/trace-analysis.md b/docs/quickstart/trace-analysis.md
index 7141394..62269dc 100644
--- a/docs/quickstart/trace-analysis.md
+++ b/docs/quickstart/trace-analysis.md
@@ -334,7 +334,7 @@
 #### Query
 ```python
 from perfetto.trace_processor import TraceProcessor
-tp = TraceProcessor(file_path='trace.perfetto-trace')
+tp = TraceProcessor(trace='trace.perfetto-trace')
 
 qr_it = tp.query('SELECT name FROM slice')
 for row in qr_it:
@@ -352,7 +352,7 @@
 #### Query as Pandas DataFrame
 ```python
 from perfetto.trace_processor import TraceProcessor
-tp = TraceProcessor(file_path='trace.perfetto-trace')
+tp = TraceProcessor(trace='trace.perfetto-trace')
 
 qr_it = tp.query('SELECT ts, name FROM slice')
 qr_df = qr_it.as_pandas_dataframe()
@@ -372,7 +372,7 @@
 #### Metric
 ```python
 from perfetto.trace_processor import TraceProcessor
-tp = TraceProcessor(file_path='trace.perfetto-trace')
+tp = TraceProcessor(trace='trace.perfetto-trace')
 
 cpu_metrics = tp.metric(['android_cpu'])
 print(cpu_metrics)
diff --git a/src/trace_processor/python/example.py b/src/trace_processor/python/example.py
index b3925a4..ffdd9ac 100644
--- a/src/trace_processor/python/example.py
+++ b/src/trace_processor/python/example.py
@@ -15,7 +15,7 @@
 
 import argparse
 
-from perfetto.trace_processor import TraceProcessor
+from perfetto.trace_processor import TraceProcessor, TraceProcessorConfig
 
 
 def main():
@@ -34,16 +34,17 @@
   parser.add_argument("-f", "--file", help="Absolute path to trace", type=str)
   args = parser.parse_args()
 
+  config = TraceProcessorConfig(bin_path=args.binary)
+
   # Pass arguments into api to construct the trace processor and load the trace
   if args.address is None and args.file is None:
     raise Exception("You must specify an address or a file path to trace")
   elif args.address is None:
-    tp = TraceProcessor(file_path=args.file, bin_path=args.binary)
+    tp = TraceProcessor(trace=args.file, config=config)
   elif args.file is None:
-    tp = TraceProcessor(addr=args.address)
+    tp = TraceProcessor(addr=args.address, config=config)
   else:
-    tp = TraceProcessor(
-        addr=args.address, file_path=args.file, bin_path=args.binary)
+    tp = TraceProcessor(trace=args.file, addr=args.address, config=config)
 
   # Iterate through QueryResultIterator
   res_it = tp.query('select * from slice limit 10')
diff --git a/src/trace_processor/python/perfetto/trace_processor/__init__.py b/src/trace_processor/python/perfetto/trace_processor/__init__.py
index 7106a6c..ad09ce6 100644
--- a/src/trace_processor/python/perfetto/trace_processor/__init__.py
+++ b/src/trace_processor/python/perfetto/trace_processor/__init__.py
@@ -13,5 +13,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .api import TraceProcessor, TraceProcessorException
+from .api import LoadableTrace, TraceProcessor, TraceProcessorConfig, TraceProcessorException
 from .http import TraceProcessorHttp
diff --git a/src/trace_processor/python/perfetto/trace_processor/api.py b/src/trace_processor/python/perfetto/trace_processor/api.py
index c52e628..0b67632 100644
--- a/src/trace_processor/python/perfetto/trace_processor/api.py
+++ b/src/trace_processor/python/perfetto/trace_processor/api.py
@@ -12,8 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import dataclasses as dc
+from enum import unique
 from urllib.parse import urlparse
-from typing import BinaryIO, Generator, List, Optional, Union
+from typing import BinaryIO, Callable, Generator, List, Optional, Tuple, Union
 
 from .http import TraceProcessorHttp
 from .loader import get_loader
@@ -31,6 +33,53 @@
     super().__init__(message)
 
 
+@dc.dataclass
+class TraceProcessorConfig:
+  bin_path: Optional[str]
+  unique_port: bool
+  verbose: bool
+
+  read_tp_descriptor: Callable[[], bytes]
+  read_metrics_descriptor: Callable[[], bytes]
+  parse_file: Callable[[TraceProcessorHttp, str], TraceProcessorHttp]
+  get_shell_path: Callable[[str], None]
+  get_free_port: Callable[[bool], Tuple[str, str]]
+
+  def __init__(
+      self,
+      bin_path: Optional[str] = None,
+      unique_port: bool = True,
+      verbose: bool = False,
+      read_tp_descriptor: Callable[[], bytes] = get_loader().read_tp_descriptor,
+      read_metrics_descriptor: Callable[[], bytes] = get_loader(
+      ).read_metrics_descriptor,
+      parse_file: Callable[[TraceProcessorHttp, str],
+                           TraceProcessorHttp] = get_loader().parse_file,
+      get_shell_path: Callable[[str], None] = get_loader().get_shell_path,
+      get_free_port: Callable[[bool], Tuple[str, str]] = get_loader(
+      ).get_free_port):
+    self.bin_path = bin_path
+    self.unique_port = unique_port
+    self.verbose = verbose
+
+    self.read_tp_descriptor = read_tp_descriptor
+    self.read_metrics_descriptor = read_metrics_descriptor
+    self.parse_file = parse_file
+    self.get_shell_path = get_shell_path
+    self.get_free_port = get_free_port
+
+    try:
+      # This is the only place in trace processor which should import
+      # from a "vendor" namespace - the purpose of this code is to allow
+      # for users to set their own "default" config for trace processor
+      # without needing to specify the config in every place when trace
+      # processor is used.
+      from .vendor import override_default_tp_config
+      return override_default_tp_config(self)
+    except ModuleNotFoundError:
+      pass
+
+
 class TraceProcessor:
 
   # Values of these constants correspond to the QueryResponse message at
@@ -181,52 +230,54 @@
   def __init__(self,
                trace: LoadableTrace = None,
                addr: Optional[str] = None,
-               bin_path: Optional[str] = None,
-               unique_port: bool = True,
-               verbose: bool = False,
+               config: TraceProcessorConfig = TraceProcessorConfig(),
                file_path: Optional[str] = None):
     """Create a trace processor instance.
 
     Args:
-      trace: Trace to be loaded into the trace processor instance. One of
+      trace: trace to be loaded into the trace processor instance. One of
         three types of argument is supported:
         1) path to a trace file to open and read
         2) a file like object (file, io.BytesIO or similar) to read
         3) a generator yielding bytes
-      addr: address of a running trace processor instance. For advanced
-        use only.
-      bin_path: path to a trace processor shell binary. For advanced use
-        only.
-      unique_port: whether the trace processor shell instance should be
-        be started on a unique port. Only used when |addr| is not set.
-        For advanced use only.
-      verbose: whether trace processor shell should emit verbose logs;
-        can be very spammy. For advanced use only.
-      file_path (deprecated): path to a trace file to load. Please use
+        4) a custom string format which can be understood by
+           TraceProcessorConfig.parse_file function. The default
+           implementation of this function only supports file paths (i.e. option
+           1) but callers can choose to change the implementation to parse
+           a custom string format and use that to retrieve a race.
+      addr: address of a running trace processor instance. Useful to query an
+        already loaded trace.
+      config: configuration options which customize functionality of trace
+        processor and the Python binding.
+      file_path (deprecated): path to a trace file to load. Use
         |trace| instead of this field: specifying both will cause
         an exception to be thrown.
     """
 
-    def create_tp_http():
+    def create_tp_http(protos: ProtoFactory) -> TraceProcessorHttp:
       if addr:
         p = urlparse(addr)
-        return TraceProcessorHttp(p.netloc if p.netloc else p.path)
+        return TraceProcessorHttp(
+            p.netloc if p.netloc else p.path, protos=protos)
 
       url, self.subprocess = load_shell(
-          bin_path=bin_path, unique_port=unique_port, verbose=verbose)
-      return TraceProcessorHttp(url)
+          bin_path=config.bin_path,
+          unique_port=config.unique_port,
+          verbose=config.verbose)
+      return TraceProcessorHttp(url, protos=protos)
 
     if trace and file_path:
       raise TraceProcessorException(
           "trace and file_path cannot both be specified.")
 
-    self.http = create_tp_http()
-    self.protos = ProtoFactory()
+    self.protos = ProtoFactory(config.read_tp_descriptor(),
+                               config.read_metrics_descriptor())
+    self.http = create_tp_http(self.protos)
 
     if file_path:
-      get_loader().parse_file(self.http, file_path)
+      config.parse_file(self.http, file_path)
     elif isinstance(trace, str):
-      get_loader().parse_file(self.http, trace)
+      config.parse_file(self.http, trace)
     elif hasattr(trace, 'read'):
       while True:
         chunk = trace.read(32 * 1024 * 1024)
diff --git a/src/trace_processor/python/perfetto/trace_processor/http.py b/src/trace_processor/python/perfetto/trace_processor/http.py
index bf751f9..f3cbfb5 100644
--- a/src/trace_processor/python/perfetto/trace_processor/http.py
+++ b/src/trace_processor/python/perfetto/trace_processor/http.py
@@ -14,17 +14,18 @@
 # limitations under the License.
 
 import http.client
+from typing import List
 
 from .protos import ProtoFactory
 
 
 class TraceProcessorHttp:
 
-  def __init__(self, url):
-    self.protos = ProtoFactory()
+  def __init__(self, url: str, protos: ProtoFactory):
+    self.protos = protos
     self.conn = http.client.HTTPConnection(url)
 
-  def execute_query(self, query):
+  def execute_query(self, query: str):
     args = self.protos.RawQueryArgs()
     args.sql_query = query
     byte_data = args.SerializeToString()
@@ -34,7 +35,7 @@
       result.ParseFromString(f.read())
       return result
 
-  def compute_metric(self, metrics):
+  def compute_metric(self, metrics: List[str]):
     args = self.protos.ComputeMetricArgs()
     args.metric_names.extend(metrics)
     byte_data = args.SerializeToString()
@@ -44,7 +45,7 @@
       result.ParseFromString(f.read())
       return result
 
-  def parse(self, chunk):
+  def parse(self, chunk: bytes):
     self.conn.request('POST', '/parse', body=chunk)
     with self.conn.getresponse() as f:
       return f.read()
diff --git a/src/trace_processor/python/perfetto/trace_processor/loader.py b/src/trace_processor/python/perfetto/trace_processor/loader.py
index e57145f..0a1b16b 100644
--- a/src/trace_processor/python/perfetto/trace_processor/loader.py
+++ b/src/trace_processor/python/perfetto/trace_processor/loader.py
@@ -53,7 +53,7 @@
     tp_http.notify_eof()
     return tp_http
 
-  def get_shell_path(bin_path=None):
+  def get_shell_path(bin_path):
     # Try to use preexisting binary before attempting to download
     # trace_processor
     if bin_path is None:
@@ -68,7 +68,7 @@
         raise Exception('Path to binary is not valid')
       return bin_path
 
-  def get_free_port(unique_port=False):
+  def get_free_port(unique_port):
     if not unique_port:
       return LoaderStandalone.TP_PORT, f'localhost:{LoaderStandalone.TP_PORT}'
     free_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
@@ -80,6 +80,8 @@
 
 
 # Return vendor class if it exists before falling back on LoaderStandalone
+# TODO(lalitm): remove this after migrating all consumers to
+# TraceProcessorConfig.
 def get_loader():
   try:
     from .loader_vendor import LoaderVendor
diff --git a/src/trace_processor/python/perfetto/trace_processor/protos.py b/src/trace_processor/python/perfetto/trace_processor/protos.py
index b5e3700..37be4f2 100644
--- a/src/trace_processor/python/perfetto/trace_processor/protos.py
+++ b/src/trace_processor/python/perfetto/trace_processor/protos.py
@@ -16,27 +16,24 @@
 from google.protobuf import message_factory
 from google.protobuf.descriptor_pool import DescriptorPool
 
-from .loader import get_loader
-
 
 class ProtoFactory:
 
-  def __init__(self):
+  def __init__(self, tp_descriptor: bytes,
+               metrics_descriptor: bytes):
     # Declare descriptor pool
     self.descriptor_pool = DescriptorPool()
 
     # Load trace processor descriptor and add to descriptor pool
-    tp_descriptor_bytes = get_loader().read_tp_descriptor()
     tp_file_desc_set_pb2 = descriptor_pb2.FileDescriptorSet()
-    tp_file_desc_set_pb2.MergeFromString(tp_descriptor_bytes)
+    tp_file_desc_set_pb2.MergeFromString(tp_descriptor)
 
     for f_desc_pb2 in tp_file_desc_set_pb2.file:
       self.descriptor_pool.Add(f_desc_pb2)
 
     # Load metrics descriptor and add to descriptor pool
-    metrics_descriptor_bytes = get_loader().read_metrics_descriptor()
     metrics_file_desc_set_pb2 = descriptor_pb2.FileDescriptorSet()
-    metrics_file_desc_set_pb2.MergeFromString(metrics_descriptor_bytes)
+    metrics_file_desc_set_pb2.MergeFromString(metrics_descriptor)
 
     for f_desc_pb2 in metrics_file_desc_set_pb2.file:
       self.descriptor_pool.Add(f_desc_pb2)
diff --git a/test/trace_processor/python/api_integrationtest.py b/test/trace_processor/python/api_integrationtest.py
index 9e7bef3..03c8d9b 100644
--- a/test/trace_processor/python/api_integrationtest.py
+++ b/test/trace_processor/python/api_integrationtest.py
@@ -15,19 +15,30 @@
 
 import io
 import os
+from typing import Optional
 import unittest
 
 from trace_processor.api import TraceProcessor
+from trace_processor.api import TraceProcessorConfig
+from trace_processor.api import LoadableTrace
+
+
+def create_tp(trace: LoadableTrace):
+  return TraceProcessor(
+      trace=trace,
+      config=TraceProcessorConfig(bin_path=os.environ["SHELL_PATH"]))
+
+
+def example_android_trace_path():
+  return os.path.join(os.environ["ROOT_DIR"], 'test', 'data',
+                      'example_android_trace_30s.pb')
 
 
 class TestApi(unittest.TestCase):
 
   def test_trace_path(self):
     # Get path to trace_processor_shell and construct TraceProcessor
-    tp = TraceProcessor(
-        trace=os.path.join(os.environ["ROOT_DIR"], 'test', 'data',
-                           'example_android_trace_30s.pb'),
-        bin_path=os.environ["SHELL_PATH"])
+    tp = create_tp(trace=example_android_trace_path())
     qr_iterator = tp.query('select * from slice limit 10')
     dur_result = [
         178646, 119740, 58073, 155000, 173177, 20209377, 3589167, 90104, 275312,
@@ -54,7 +65,7 @@
     f = io.BytesIO(
         b'\n(\n&\x08\x00\x12\x12\x08\x01\x10\xc8\x01\x1a\x0b\x12\t'
         b'B|200|foo\x12\x0e\x08\x02\x10\xc8\x01\x1a\x07\x12\x05E|200')
-    with TraceProcessor(trace=f, bin_path=os.environ["SHELL_PATH"]) as tp:
+    with create_tp(trace=f) as tp:
       qr_iterator = tp.query('select * from slice limit 10')
       res = list(qr_iterator)
 
@@ -66,10 +77,8 @@
       self.assertEqual(row.name, 'foo')
 
   def test_trace_file(self):
-    path = os.path.join(os.environ["ROOT_DIR"], 'test', 'data',
-                        'example_android_trace_30s.pb')
-    with open(path, 'rb') as file:
-      with TraceProcessor(trace=file, bin_path=os.environ["SHELL_PATH"]) as tp:
+    with open(example_android_trace_path(), 'rb') as file:
+      with create_tp(trace=file) as tp:
         qr_iterator = tp.query('select * from slice limit 10')
         dur_result = [
             178646, 119740, 58073, 155000, 173177, 20209377, 3589167, 90104,
@@ -82,13 +91,10 @@
   def test_trace_generator(self):
 
     def reader_generator():
-      path = os.path.join(os.environ["ROOT_DIR"], 'test', 'data',
-                          'example_android_trace_30s.pb')
-      with open(path, 'rb') as file:
+      with open(example_android_trace_path(), 'rb') as file:
         yield file.read(1024)
 
-    with TraceProcessor(
-        trace=reader_generator(), bin_path=os.environ["SHELL_PATH"]) as tp:
+    with create_tp(trace=reader_generator()) as tp:
       qr_iterator = tp.query('select * from slice limit 10')
       dur_result = [
           178646, 119740, 58073, 155000, 173177, 20209377, 3589167, 90104,
diff --git a/test/trace_processor/python/api_unittest.py b/test/trace_processor/python/api_unittest.py
index 29abd5f..2306e26 100755
--- a/test/trace_processor/python/api_unittest.py
+++ b/test/trace_processor/python/api_unittest.py
@@ -15,23 +15,30 @@
 
 import unittest
 
-from trace_processor.api import TraceProcessor, TraceProcessorException
+from trace_processor.api import TraceProcessor
+from trace_processor.api import TraceProcessorException
+from trace_processor.api import TraceProcessorConfig
 from trace_processor.protos import ProtoFactory
 
+TP_CONFIG = TraceProcessorConfig()
+PROTO_FACTORY = ProtoFactory(
+    tp_descriptor=TP_CONFIG.read_tp_descriptor(),
+    metrics_descriptor=TP_CONFIG.read_metrics_descriptor())
+
 
 class TestQueryResultIterator(unittest.TestCase):
   # The numbers input into cells correspond the CellType enum values
   # defined under trace_processor.proto
-  CELL_VARINT = ProtoFactory().CellsBatch().CELL_VARINT
-  CELL_STRING = ProtoFactory().CellsBatch().CELL_STRING
-  CELL_INVALID = ProtoFactory().CellsBatch().CELL_INVALID
-  CELL_NULL = ProtoFactory().CellsBatch().CELL_NULL
+  CELL_VARINT = PROTO_FACTORY.CellsBatch().CELL_VARINT
+  CELL_STRING = PROTO_FACTORY.CellsBatch().CELL_STRING
+  CELL_INVALID = PROTO_FACTORY.CellsBatch().CELL_INVALID
+  CELL_NULL = PROTO_FACTORY.CellsBatch().CELL_NULL
 
   def test_one_batch(self):
     int_values = [100, 200]
     str_values = ['bar1', 'bar2']
 
-    batch = ProtoFactory().CellsBatch()
+    batch = PROTO_FACTORY.CellsBatch()
     batch.cells.extend([
         TestQueryResultIterator.CELL_STRING,
         TestQueryResultIterator.CELL_VARINT,
@@ -56,7 +63,7 @@
     int_values = [100, 200, 300, 400]
     str_values = ['bar1', 'bar2', 'bar3', 'bar4']
 
-    batch_1 = ProtoFactory().CellsBatch()
+    batch_1 = PROTO_FACTORY.CellsBatch()
     batch_1.cells.extend([
         TestQueryResultIterator.CELL_STRING,
         TestQueryResultIterator.CELL_VARINT,
@@ -69,7 +76,7 @@
     batch_1.string_cells = "\0".join(str_values[:2]) + "\0"
     batch_1.is_last_batch = False
 
-    batch_2 = ProtoFactory().CellsBatch()
+    batch_2 = PROTO_FACTORY.CellsBatch()
     batch_2.cells.extend([
         TestQueryResultIterator.CELL_STRING,
         TestQueryResultIterator.CELL_VARINT,
@@ -91,7 +98,7 @@
       self.assertEqual(row.foo_null, None)
 
   def test_empty_batch(self):
-    batch = ProtoFactory().CellsBatch()
+    batch = PROTO_FACTORY.CellsBatch()
     batch.is_last_batch = True
 
     qr_iterator = TraceProcessor.QueryResultIterator([], [batch])
@@ -101,7 +108,7 @@
       self.assertIsNone(row.foo_num)
 
   def test_invalid_batch(self):
-    batch = ProtoFactory().CellsBatch()
+    batch = PROTO_FACTORY.CellsBatch()
 
     # Since the batch isn't defined as the last batch, the QueryResultsIterator
     # expects another batch and thus raises IndexError as no next batch exists.
@@ -112,7 +119,7 @@
     int_values = [100, 200, 300, 500, 600]
     str_values = ['bar1', 'bar2', 'bar3']
 
-    batch = ProtoFactory().CellsBatch()
+    batch = PROTO_FACTORY.CellsBatch()
     batch.cells.extend([
         TestQueryResultIterator.CELL_STRING,
         TestQueryResultIterator.CELL_VARINT,
@@ -143,7 +150,7 @@
   def test_incorrect_cells_batch(self):
     str_values = ['bar1', 'bar2']
 
-    batch = ProtoFactory().CellsBatch()
+    batch = PROTO_FACTORY.CellsBatch()
     batch.cells.extend([
         TestQueryResultIterator.CELL_STRING,
         TestQueryResultIterator.CELL_VARINT,
@@ -163,7 +170,7 @@
         pass
 
   def test_incorrect_columns_batch(self):
-    batch = ProtoFactory().CellsBatch()
+    batch = PROTO_FACTORY.CellsBatch()
     batch.cells.extend([
         TestQueryResultIterator.CELL_VARINT, TestQueryResultIterator.CELL_VARINT
     ])
@@ -178,7 +185,7 @@
           ['foo_id', 'foo_num', 'foo_dur', 'foo_ms'], [batch])
 
   def test_invalid_cell_type(self):
-    batch = ProtoFactory().CellsBatch()
+    batch = PROTO_FACTORY.CellsBatch()
     batch.cells.extend([
         TestQueryResultIterator.CELL_INVALID,
         TestQueryResultIterator.CELL_VARINT
@@ -200,7 +207,7 @@
     int_values = [100, 200]
     str_values = ['bar1', 'bar2']
 
-    batch = ProtoFactory().CellsBatch()
+    batch = PROTO_FACTORY.CellsBatch()
     batch.cells.extend([
         TestQueryResultIterator.CELL_STRING,
         TestQueryResultIterator.CELL_VARINT,
@@ -226,7 +233,7 @@
     int_values = [100, 200, 300, 400]
     str_values = ['bar1', 'bar2', 'bar3', 'bar4']
 
-    batch_1 = ProtoFactory().CellsBatch()
+    batch_1 = PROTO_FACTORY.CellsBatch()
     batch_1.cells.extend([
         TestQueryResultIterator.CELL_STRING,
         TestQueryResultIterator.CELL_VARINT,
@@ -239,7 +246,7 @@
     batch_1.string_cells = "\0".join(str_values[:2]) + "\0"
     batch_1.is_last_batch = False
 
-    batch_2 = ProtoFactory().CellsBatch()
+    batch_2 = PROTO_FACTORY.CellsBatch()
     batch_2.cells.extend([
         TestQueryResultIterator.CELL_STRING,
         TestQueryResultIterator.CELL_VARINT,
@@ -262,7 +269,7 @@
       self.assertEqual(row['foo_null'], None)
 
   def test_empty_batch_as_pandas(self):
-    batch = ProtoFactory().CellsBatch()
+    batch = PROTO_FACTORY.CellsBatch()
     batch.is_last_batch = True
 
     qr_iterator = TraceProcessor.QueryResultIterator([], [batch])
@@ -276,7 +283,7 @@
     int_values = [100, 200, 300, 500, 600]
     str_values = ['bar1', 'bar2', 'bar3']
 
-    batch = ProtoFactory().CellsBatch()
+    batch = PROTO_FACTORY.CellsBatch()
     batch.cells.extend([
         TestQueryResultIterator.CELL_STRING,
         TestQueryResultIterator.CELL_VARINT,
@@ -308,7 +315,7 @@
   def test_incorrect_cells_batch_as_pandas(self):
     str_values = ['bar1', 'bar2']
 
-    batch = ProtoFactory().CellsBatch()
+    batch = PROTO_FACTORY.CellsBatch()
     batch.cells.extend([
         TestQueryResultIterator.CELL_STRING,
         TestQueryResultIterator.CELL_VARINT,
@@ -327,7 +334,7 @@
       qr_df = qr_iterator.as_pandas_dataframe()
 
   def test_invalid_cell_type_as_pandas(self):
-    batch = ProtoFactory().CellsBatch()
+    batch = PROTO_FACTORY.CellsBatch()
     batch.cells.extend([
         TestQueryResultIterator.CELL_INVALID,
         TestQueryResultIterator.CELL_VARINT
diff --git a/tools/batch_trace_processor/main.py b/tools/batch_trace_processor/main.py
index b73f7d9..651247d 100644
--- a/tools/batch_trace_processor/main.py
+++ b/tools/batch_trace_processor/main.py
@@ -23,8 +23,8 @@
 import pandas as pd
 import plotille
 
-from perfetto.batch_trace_processor.api import BatchTraceProcessor
-from perfetto.trace_processor import TraceProcessorException
+from perfetto.batch_trace_processor.api import BatchTraceProcessor, BatchTraceProcessorConfig
+from perfetto.trace_processor import TraceProcessorException, TraceProcessorConfig
 from typing import List
 
 
@@ -107,8 +107,13 @@
     logging.info("At least one file must be specified in files or file list")
 
   logging.info('Loading traces...')
-  with BatchTraceProcessor(
-      files, bin_path=args.shell_path, verbose=args.verbose) as batch_tp:
+  config = BatchTraceProcessorConfig(
+      tp_config=TraceProcessorConfig(
+          bin_path=args.shell_path,
+          verbose=args.verbose,
+      ))
+
+  with BatchTraceProcessor(files, config) as batch_tp:
     if args.query_file:
       logging.info('Running query file...')
 
diff --git a/tools/batch_trace_processor/perfetto/batch_trace_processor/api.py b/tools/batch_trace_processor/perfetto/batch_trace_processor/api.py
index 4ae5f14..dcd640c 100644
--- a/tools/batch_trace_processor/perfetto/batch_trace_processor/api.py
+++ b/tools/batch_trace_processor/perfetto/batch_trace_processor/api.py
@@ -16,22 +16,18 @@
 
 """Contains classes for BatchTraceProcessor API."""
 
-import concurrent.futures as cf
+from concurrent.futures.thread import ThreadPoolExecutor
 import dataclasses as dc
-from typing import Any, Callable, Dict, Tuple, Union, List
+import multiprocessing
+from typing import Any, Callable, Dict, Optional, Tuple, Union, List
+from numpy.lib.npyio import load
 
 import pandas as pd
 
 from perfetto.trace_processor import LoadableTrace
 from perfetto.trace_processor import TraceProcessor
 from perfetto.trace_processor import TraceProcessorException
-
-
-@dc.dataclass
-class _TpArg:
-  bin_path: str
-  verbose: bool
-  trace: LoadableTrace
+from perfetto.trace_processor import TraceProcessorConfig
 
 
 @dc.dataclass
@@ -40,6 +36,48 @@
   args: Dict[str, str]
 
 
+@dc.dataclass
+class BatchTraceProcessorConfig:
+  TraceProvider = Callable[[str], List[
+      Union[LoadableTrace, BatchLoadableTrace]]]
+
+  tp_config: TraceProcessorConfig
+
+  query_executor: Optional[ThreadPoolExecutor]
+  load_executor: Optional[ThreadPoolExecutor]
+
+  trace_provider: TraceProvider
+
+  def __default_trace_provider(custom_string: str):
+    del custom_string
+    raise TraceProcessorException(
+        'Passed a string to batch trace processor constructor without '
+        'a trace provider being registered.')
+
+  def __init__(self,
+               tp_config: TraceProcessorConfig = TraceProcessorConfig(),
+               query_executor: Optional[ThreadPoolExecutor] = None,
+               load_executor: Optional[ThreadPoolExecutor] = None,
+               trace_provider: TraceProvider = __default_trace_provider):
+    self.tp_config = tp_config
+
+    self.query_executor = query_executor
+    self.load_executor = load_executor
+
+    self.trace_provider = trace_provider
+
+    try:
+      # This is the only place in batch trace processor which should import
+      # from a "vendor" namespace - the purpose of this code is to allow
+      # for users to set their own "default" config for batch trace processor
+      # without needing to specify the config in every place when batch
+      # trace processor is used.
+      from .vendor import override_batch_tp_config
+      override_batch_tp_config(self)
+    except ModuleNotFoundError:
+      pass
+
+
 class BatchTraceProcessor:
   """Run ad-hoc SQL queries across many Perfetto traces.
 
@@ -50,28 +88,33 @@
         print(df)
   """
 
-  def __init__(self,
-               traces: List[Union[LoadableTrace, BatchLoadableTrace]],
-               bin_path: str = None,
-               verbose: bool = False):
+  def __init__(
+      self,
+      traces: Union[str, List[Union[LoadableTrace, BatchLoadableTrace]]],
+      config: BatchTraceProcessorConfig = BatchTraceProcessorConfig()):
     """Creates a batch trace processor instance.
 
     BatchTraceProcessor is the blessed way of running ad-hoc queries in
     Python across many traces.
 
     Args:
-      traces: A list of traces to load into this instance. Each object in
-        the list can be one of the following types:
+      traces: Either a list of traces or a custom string which will be
+        converted to a list of traces.
+
+        If a list, each item can be one of the following types:
         1) path to a trace file to open and read
         2) a file like object (file, io.BytesIO or similar) to read
         3) a generator yielding bytes
         4) a BatchLoadableTrace object; this is basically a wrapper around
            one of the above types plus an args field; see |query_and_flatten|
            for the motivation for the args field.
-      bin_path: Optional path to a trace processor shell binary to use to
-        load the traces.
-      verbose: Optional flag indiciating whether verbose trace processor
-        output should be printed to stderr.
+
+        If a string, it is passed to BatchTraceProcessorConfig.trace_provider to
+        convert to a list of traces; the default implementation of this
+        function just throws an exception so an implementation must be provided
+        if strings will be passed.
+      config: configuration options which customize functionality of batch
+        trace processor and underlying trace processors.
     """
 
     def _create_batch_trace(x: Union[LoadableTrace, BatchLoadableTrace]
@@ -80,19 +123,27 @@
         return x
       return BatchLoadableTrace(trace=x, args={})
 
-    def create_tp(arg: _TpArg) -> TraceProcessor:
-      return TraceProcessor(
-          trace=arg.trace, bin_path=arg.bin_path, verbose=arg.verbose)
+    def create_tp(trace: BatchLoadableTrace) -> TraceProcessor:
+      return TraceProcessor(trace=trace.trace, config=config.tp_config)
+
+    if isinstance(traces, str):
+      trace_list = config.trace_provider(traces)
+    else:
+      trace_list = traces
+
+    batch_traces = [_create_batch_trace(t) for t in trace_list]
+
+    # As trace processor is completely CPU bound, it makes sense to just
+    # max out the CPUs available.
+    query_executor = config.query_executor or ThreadPoolExecutor(
+        max_workers=multiprocessing.cpu_count())
+    load_exectuor = config.load_executor or query_executor
 
     self.tps = None
     self.closed = False
-    self.executor = cf.ThreadPoolExecutor()
-
-    batch_traces = [_create_batch_trace(t) for t in traces]
+    self.query_executor = query_executor
     self.args = [t.args for t in batch_traces]
-
-    tp_args = [_TpArg(bin_path, verbose, t.trace) for t in batch_traces]
-    self.tps = list(self.executor.map(create_tp, tp_args))
+    self.tps = list(load_exectuor.map(create_tp, batch_traces))
 
   def metric(self, metrics: List[str]):
     """Computes the provided metrics.
@@ -199,7 +250,7 @@
       A list of values with the result of executing the fucntion (one per
       trace).
     """
-    return list(self.executor.map(fn, self.tps))
+    return list(self.query_executor.map(fn, self.tps))
 
   def execute_and_flatten(self, fn: Callable[[TraceProcessor], pd.DataFrame]
                          ) -> pd.DataFrame:
@@ -225,7 +276,8 @@
         df[key] = value
       return df
 
-    df = pd.concat(list(self.executor.map(wrapped, zip(self.tps, self.args))))
+    df = pd.concat(
+        list(self.query_executor.map(wrapped, zip(self.tps, self.args))))
     return df.reset_index(drop=True)
 
   def close(self):
@@ -240,7 +292,6 @@
     if self.closed:
       return
     self.closed = True
-    self.executor.shutdown()
 
     if self.tps:
       for tp in self.tps:
diff --git a/tools/slice_breakdown/main.py b/tools/slice_breakdown/main.py
index 960e44b..f6a4c4c 100644
--- a/tools/slice_breakdown/main.py
+++ b/tools/slice_breakdown/main.py
@@ -19,21 +19,21 @@
 import argparse
 import sys
 
-from perfetto.slice_breakdown import compute_breakdown, compute_breakdown_for_startup
+from perfetto.slice_breakdown import compute_breakdown
+from perfetto.slice_breakdown import compute_breakdown_for_startup
 from perfetto.trace_processor import TraceProcessor
+from perfetto.trace_processor import TraceProcessorConfig
 
 
 def compute_breakdown_wrapper(args):
-  tp = TraceProcessor(
-      file_path=args.file, bin_path=args.shell_path, verbose=args.verbose)
-  if args.startup_bounds:
-    breakdown = compute_breakdown_for_startup(tp, args.startup_package,
-                                              args.process_name)
-  else:
-    breakdown = compute_breakdown(tp, args.start_ts, args.end_ts,
-                                  args.process_name)
-  tp.close()
-
+  config = TraceProcessorConfig(bin_path=args.shell_path, verbose=args.verbose)
+  with TraceProcessor(trace=args.file, config=config) as tp:
+    if args.startup_bounds:
+      breakdown = compute_breakdown_for_startup(tp, args.startup_package,
+                                                args.process_name)
+    else:
+      breakdown = compute_breakdown(tp, args.start_ts, args.end_ts,
+                                    args.process_name)
   return breakdown