btp: optionally allow failures when loading traces into BTP

This is needed so that pipelines don't hard-fail when encountering infra
errors and can instead increment error counters for loading failures.

Change-Id: Id5a456ab879be1b189fb5ee1cf6e9c826df3f696
Bug: 225399802
diff --git a/python/perfetto/batch_trace_processor/api.py b/python/perfetto/batch_trace_processor/api.py
index c54056e..0c84296 100644
--- a/python/perfetto/batch_trace_processor/api.py
+++ b/python/perfetto/batch_trace_processor/api.py
@@ -16,8 +16,9 @@
 
 import concurrent.futures as cf
 import dataclasses as dc
+from enum import Enum
 import multiprocessing
-from typing import Any, Callable, Dict, Tuple, List
+from typing import Any, Callable, Dict, Tuple, List, Optional
 
 import pandas as pd
 
@@ -39,12 +40,42 @@
 TraceListReference = registry.TraceListReference
 
 
+# Enum encoding how errors while loading traces in BatchTraceProcessor should
+# be handled.
+class LoadFailureHandling(Enum):
+  # If any trace fails to load, raises an exception causing the entire batch
+  # trace processor to fail.
+  # This is the default behaviour and the method which should be preferred for
+  # any interactive use of BatchTraceProcessor.
+  RAISE_EXCEPTION = 0
+
+  # If a trace fails to load, the trace processor for that trace is dropped but
+  # loading of other traces is unaffected. |load_failures| is incremented in the
+  # Stats class for the batch trace processor instance.
+  INCREMENT_STAT = 1
+
+
 @dc.dataclass
 class BatchTraceProcessorConfig:
   tp_config: TraceProcessorConfig
+  load_failure_handling: LoadFailureHandling
 
-  def __init__(self, tp_config: TraceProcessorConfig = TraceProcessorConfig()):
+  def __init__(self,
+               tp_config: TraceProcessorConfig = TraceProcessorConfig(),
+               load_failure_handling: LoadFailureHandling = LoadFailureHandling
+               .RAISE_EXCEPTION):
     self.tp_config = tp_config
+    self.load_failure_handling = load_failure_handling
+
+
+# Contains stats about the events which happened during the use of
+# BatchTraceProcessor.
+@dc.dataclass
+class Stats:
+  # The number of traces which failed to load; only non-zero if
+  # LoadFailureHandling.INCREMENT_STAT is chosen as the handling type.
+  load_failures: int = 0
+
 
 class BatchTraceProcessor:
   """Run ad-hoc SQL queries across many Perfetto traces.
@@ -92,6 +123,7 @@
 
     self.tps = None
     self.closed = False
+    self._stats = Stats()
 
     self.platform_delegate = PLATFORM_DELEGATE()
     self.tp_platform_delegate = TP_PLATFORM_DELEGATE()
@@ -117,7 +149,9 @@
 
     self.query_executor = query_executor
     self.metadata = [t.metadata for t in resolved]
-    self.tps = list(load_exectuor.map(self._create_tp, resolved))
+    self.tps = [
+        x for x in load_exectuor.map(self._create_tp, resolved) if x is not None
+    ]
 
   def metric(self, metrics: List[str]):
     """Computes the provided metrics.
@@ -279,8 +313,22 @@
       for tp in self.tps:
         tp.close()
 
-  def _create_tp(self, trace: ResolverRegistry.Result) -> TraceProcessor:
-    return TraceProcessor(trace=trace.generator, config=self.config.tp_config)
+  def stats(self):
+    """Statistics about the operation of this batch trace processor instance.
+    
+    See |Stats| class definition for the list of the statistics available."""
+    return self._stats
+
+  def _create_tp(self,
+                 trace: ResolverRegistry.Result) -> Optional[TraceProcessor]:
+    try:
+      return TraceProcessor(trace=trace.generator, config=self.config.tp_config)
+    except TraceProcessorException as ex:
+      if self.config.load_failure_handling == \
+        LoadFailureHandling.RAISE_EXCEPTION:
+        raise ex
+      self._stats.load_failures += 1
+      return None
 
   def __enter__(self):
     return self
diff --git a/python/test/api_integrationtest.py b/python/test/api_integrationtest.py
index d4755b7..9fde5fb 100644
--- a/python/test/api_integrationtest.py
+++ b/python/test/api_integrationtest.py
@@ -20,6 +20,7 @@
 import pandas as pd
 
 from perfetto.batch_trace_processor.api import BatchTraceProcessor
+from perfetto.batch_trace_processor.api import LoadFailureHandling
 from perfetto.batch_trace_processor.api import BatchTraceProcessorConfig
 from perfetto.batch_trace_processor.api import TraceListReference
 from perfetto.trace_processor.api import PLATFORM_DELEGATE
@@ -87,14 +88,18 @@
     ]
 
 
-def create_batch_tp(traces: TraceListReference):
+def create_batch_tp(
+    traces: TraceListReference,
+    failure_handling: LoadFailureHandling = LoadFailureHandling.RAISE_EXCEPTION
+):
   default = PLATFORM_DELEGATE().default_resolver_registry()
   default.register(SimpleResolver)
   default.register(RecursiveResolver)
   return BatchTraceProcessor(
       traces=traces,
       config=BatchTraceProcessorConfig(
-          TraceProcessorConfig(
+          load_failure_handling=failure_handling,
+          tp_config=TraceProcessorConfig(
               bin_path=os.environ["SHELL_PATH"], resolver_registry=default)))
 
 
@@ -223,3 +228,14 @@
             path=example_android_trace_path(), skip_resolve_file=True)) as btp:
       df = btp.query_and_flatten('select dur from slice limit 1')
       pd.testing.assert_frame_equal(df, expected, check_dtype=False)
+
+  def test_btp_failure(self):
+    f = io.BytesIO(b'<foo></foo>')
+    with self.assertRaises(TraceProcessorException):
+      _ = create_batch_tp(traces=f)
+
+  def test_btp_failure_increment_stat(self):
+    f = io.BytesIO(b'<foo></foo>')
+    btp = create_batch_tp(
+        traces=f, failure_handling=LoadFailureHandling.INCREMENT_STAT)
+    self.assertEqual(btp.stats().load_failures, 1)