Merge "tp: introduces SqliteEngine to broker all interactions with SQLite"
diff --git a/Android.bp b/Android.bp
index 2077923..2efade0 100644
--- a/Android.bp
+++ b/Android.bp
@@ -10185,6 +10185,7 @@
     srcs: [
         "src/trace_processor/sqlite/db_sqlite_table.cc",
         "src/trace_processor/sqlite/sql_stats_table.cc",
+        "src/trace_processor/sqlite/sqlite_engine.cc",
         "src/trace_processor/sqlite/sqlite_utils.cc",
         "src/trace_processor/sqlite/stats_table.cc",
     ],
diff --git a/BUILD b/BUILD
index 6481aa5..c6cf29d 100644
--- a/BUILD
+++ b/BUILD
@@ -1924,6 +1924,8 @@
         "src/trace_processor/sqlite/query_cache.h",
         "src/trace_processor/sqlite/sql_stats_table.cc",
         "src/trace_processor/sqlite/sql_stats_table.h",
+        "src/trace_processor/sqlite/sqlite_engine.cc",
+        "src/trace_processor/sqlite/sqlite_engine.h",
         "src/trace_processor/sqlite/sqlite_utils.cc",
         "src/trace_processor/sqlite/sqlite_utils.h",
         "src/trace_processor/sqlite/stats_table.cc",
diff --git a/src/trace_processor/sqlite/BUILD.gn b/src/trace_processor/sqlite/BUILD.gn
index 8f16e17..25024b1 100644
--- a/src/trace_processor/sqlite/BUILD.gn
+++ b/src/trace_processor/sqlite/BUILD.gn
@@ -23,6 +23,8 @@
     "query_cache.h",
     "sql_stats_table.cc",
     "sql_stats_table.h",
+    "sqlite_engine.cc",
+    "sqlite_engine.h",
     "sqlite_utils.cc",
     "sqlite_utils.h",
     "stats_table.cc",
diff --git a/src/trace_processor/sqlite/sqlite_engine.cc b/src/trace_processor/sqlite/sqlite_engine.cc
new file mode 100644
index 0000000..429318a
--- /dev/null
+++ b/src/trace_processor/sqlite/sqlite_engine.cc
@@ -0,0 +1,78 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/trace_processor/sqlite/sqlite_engine.h"
+#include "src/trace_processor/sqlite/db_sqlite_table.h"
+#include "src/trace_processor/sqlite/query_cache.h"
+
+// In Android and Chromium tree builds, we don't have the percentile module.
+// Just don't include it.
+#if PERFETTO_BUILDFLAG(PERFETTO_TP_PERCENTILE)
+// defined in sqlite_src/ext/misc/percentile.c
+extern "C" int sqlite3_percentile_init(sqlite3* db,
+                                       char** error,
+                                       const sqlite3_api_routines* api);
+#endif  // PERFETTO_BUILDFLAG(PERFETTO_TP_PERCENTILE)
+
+namespace perfetto {
+namespace trace_processor {
+namespace {
+
+void EnsureSqliteInitialized() {
+  // Although documented as thread-safe, sqlite3_initialize is not in practice.
+  // A function-local static guarantees that it runs exactly once per process,
+  // even when several TraceProcessorImpl instances are created concurrently.
+  static const bool initialized = sqlite3_initialize() == SQLITE_OK;
+  PERFETTO_CHECK(initialized);
+}
+
+void InitializeSqlite(sqlite3* db) {
+  // temp_store=2 selects in-memory storage for temporary tables and indices.
+  char* err_msg = nullptr;
+  sqlite3_exec(db, "PRAGMA temp_store=2", nullptr, nullptr, &err_msg);
+  if (err_msg != nullptr) {
+    PERFETTO_FATAL("Error setting pragma temp_store: %s", err_msg);
+  }
+// The percentile module is missing from Android/Chromium tree builds.
+#if PERFETTO_BUILDFLAG(PERFETTO_TP_PERCENTILE)
+  sqlite3_percentile_init(db, &err_msg, nullptr);
+  if (err_msg != nullptr) {
+    PERFETTO_ELOG("Error initializing: %s", err_msg);
+    sqlite3_free(err_msg);
+  }
+#endif
+}
+
+}  // namespace
+
+SqliteEngine::SqliteEngine() : query_cache_(new QueryCache()) {
+  EnsureSqliteInitialized();  // One-time, process-wide SQLite setup.
+  sqlite3* db = nullptr;
+  PERFETTO_CHECK(sqlite3_open(":memory:", &db) == SQLITE_OK);
+  InitializeSqlite(db);  // Sets pragmas, loads optional percentile module.
+  db_.reset(db);         // |db_| now holds the connection.
+}
+
+void SqliteEngine::RegisterTable(const Table& table, const std::string& name) {
+  DbSqliteTable::RegisterTable(db(), query_cache_.get(), &table, name);
+}
+
+void SqliteEngine::RegisterTableFunction(std::unique_ptr<TableFunction> fn) {
+  DbSqliteTable::RegisterTable(db(), query_cache_.get(), std::move(fn));
+}
+
+}  // namespace trace_processor
+}  // namespace perfetto
diff --git a/src/trace_processor/sqlite/sqlite_engine.h b/src/trace_processor/sqlite/sqlite_engine.h
new file mode 100644
index 0000000..f1ebdd3
--- /dev/null
+++ b/src/trace_processor/sqlite/sqlite_engine.h
@@ -0,0 +1,61 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef SRC_TRACE_PROCESSOR_SQLITE_SQLITE_ENGINE_H_
+#define SRC_TRACE_PROCESSOR_SQLITE_SQLITE_ENGINE_H_
+
+#include <sqlite3.h>
+
+#include "src/trace_processor/db/table.h"
+#include "src/trace_processor/prelude/table_functions/table_function.h"
+#include "src/trace_processor/sqlite/query_cache.h"
+#include "src/trace_processor/sqlite/scoped_db.h"
+
+namespace perfetto {
+namespace trace_processor {
+
+// Wrapper class around the SQLite C API.
+//
+// The goal of this class is to provide a one-stop shop for using SQLite.
+// Benefits of this include:
+// 1) It allows us to add code which intercepts registration of functions
+//    and tables and keeps track of this for later lookup.
+// 2) It makes it easy to audit which SQLite APIs we use and therefore to
+//    determine what functionality we rely on.
+class SqliteEngine {
+ public:
+  SqliteEngine();
+
+  // Registers a trace processor C++ table with SQLite with an SQL name of
+  // |name|.
+  void RegisterTable(const Table& table, const std::string& name);
+
+  // Registers a trace processor C++ table function with SQLite.
+  void RegisterTableFunction(std::unique_ptr<TableFunction> fn);
+
+  sqlite3* db() const { return db_.get(); }
+
+ private:
+  // Keep this first: members are destroyed in reverse declaration order, so
+  // |db_| is destroyed only after every other member has been cleaned up.
+  ScopedDb db_;
+  std::unique_ptr<QueryCache> query_cache_;
+};
+
+}  // namespace trace_processor
+}  // namespace perfetto
+
+#endif  // SRC_TRACE_PROCESSOR_SQLITE_SQLITE_ENGINE_H_
diff --git a/src/trace_processor/trace_processor_impl.cc b/src/trace_processor/trace_processor_impl.cc
index e92942d..5e10fc1 100644
--- a/src/trace_processor/trace_processor_impl.cc
+++ b/src/trace_processor/trace_processor_impl.cc
@@ -97,15 +97,6 @@
 #include "src/trace_processor/metrics/sql/amalgamated_sql_metrics.h"
 #include "src/trace_processor/stdlib/amalgamated_stdlib.h"
 
-// In Android and Chromium tree builds, we don't have the percentile module.
-// Just don't include it.
-#if PERFETTO_BUILDFLAG(PERFETTO_TP_PERCENTILE)
-// defined in sqlite_src/ext/misc/percentile.c
-extern "C" int sqlite3_percentile_init(sqlite3* db,
-                                       char** error,
-                                       const sqlite3_api_routines* api);
-#endif  // PERFETTO_BUILDFLAG(PERFETTO_TP_PERCENTILE)
-
 namespace perfetto {
 namespace trace_processor {
 namespace {
@@ -126,24 +117,6 @@
     PERFETTO_ELOG("%s", status.c_message());
 }
 
-void InitializeSqlite(sqlite3* db) {
-  char* error = nullptr;
-  sqlite3_exec(db, "PRAGMA temp_store=2", nullptr, nullptr, &error);
-  if (error) {
-    PERFETTO_FATAL("Error setting pragma temp_store: %s", error);
-  }
-  sqlite3_str_split_init(db);
-// In Android tree builds, we don't have the percentile module.
-// Just don't include it.
-#if PERFETTO_BUILDFLAG(PERFETTO_TP_PERCENTILE)
-  sqlite3_percentile_init(db, &error, nullptr);
-  if (error) {
-    PERFETTO_ELOG("Error initializing: %s", error);
-    sqlite3_free(error);
-  }
-#endif
-}
-
 void BuildBoundsTable(sqlite3* db, std::pair<int64_t, int64_t> bounds) {
   char* error = nullptr;
   sqlite3_exec(db, "DELETE FROM trace_bounds", nullptr, nullptr, &error);
@@ -313,14 +286,6 @@
   }
 }
 
-void EnsureSqliteInitialized() {
-  // sqlite3_initialize isn't actually thread-safe despite being documented
-  // as such; we need to make sure multiple TraceProcessorImpl instances don't
-  // call it concurrently and only gets called once per process, instead.
-  static bool init_once = [] { return sqlite3_initialize() == SQLITE_OK; }();
-  PERFETTO_CHECK(init_once);
-}
-
 void InsertIntoTraceMetricsTable(sqlite3* db, const std::string& metric_name) {
   char* insert_sql = sqlite3_mprintf(
       "INSERT INTO trace_metrics(name) VALUES('%q')", metric_name.c_str());
@@ -546,64 +511,60 @@
     context_.content_analyzer.reset(new ProtoContentAnalyzer(&context_));
   }
 
+  sqlite3_str_split_init(engine_.db());
   RegisterAdditionalModules(&context_);
-
-  sqlite3* db = nullptr;
-  EnsureSqliteInitialized();
-  PERFETTO_CHECK(sqlite3_open(":memory:", &db) == SQLITE_OK);
-  InitializeSqlite(db);
-  InitializePreludeTablesViews(db);
-  db_.reset(std::move(db));
+  InitializePreludeTablesViews(engine_.db());
 
   // New style function registration.
   if (cfg.enable_dev_features) {
-    RegisterDevFunctions(db);
+    RegisterDevFunctions(engine_.db());
   }
-  RegisterFunction<Glob>(db, "glob", 2);
-  RegisterFunction<Hash>(db, "HASH", -1);
-  RegisterFunction<Base64Encode>(db, "BASE64_ENCODE", 1);
-  RegisterFunction<Demangle>(db, "DEMANGLE", 1);
-  RegisterFunction<SourceGeq>(db, "SOURCE_GEQ", -1);
-  RegisterFunction<ExportJson>(db, "EXPORT_JSON", 1, context_.storage.get(),
-                               false);
-  RegisterFunction<ExtractArg>(db, "EXTRACT_ARG", 2, context_.storage.get());
-  RegisterFunction<AbsTimeStr>(db, "ABS_TIME_STR", 1,
+  RegisterFunction<Glob>(engine_.db(), "glob", 2);
+  RegisterFunction<Hash>(engine_.db(), "HASH", -1);
+  RegisterFunction<Base64Encode>(engine_.db(), "BASE64_ENCODE", 1);
+  RegisterFunction<Demangle>(engine_.db(), "DEMANGLE", 1);
+  RegisterFunction<SourceGeq>(engine_.db(), "SOURCE_GEQ", -1);
+  RegisterFunction<ExportJson>(engine_.db(), "EXPORT_JSON", 1,
+                               context_.storage.get(), false);
+  RegisterFunction<ExtractArg>(engine_.db(), "EXTRACT_ARG", 2,
+                               context_.storage.get());
+  RegisterFunction<AbsTimeStr>(engine_.db(), "ABS_TIME_STR", 1,
                                context_.clock_converter.get());
-  RegisterFunction<ToMonotonic>(db, "TO_MONOTONIC", 1,
+  RegisterFunction<ToMonotonic>(engine_.db(), "TO_MONOTONIC", 1,
                                 context_.clock_converter.get());
   RegisterFunction<CreateFunction>(
-      db, "CREATE_FUNCTION", 3,
+      engine_.db(), "CREATE_FUNCTION", 3,
       std::unique_ptr<CreateFunction::Context>(
-          new CreateFunction::Context{db_.get(), &create_function_state_}));
+          new CreateFunction::Context{engine_.db(), &create_function_state_}));
   RegisterFunction<CreateViewFunction>(
-      db, "CREATE_VIEW_FUNCTION", 3,
+      engine_.db(), "CREATE_VIEW_FUNCTION", 3,
       std::unique_ptr<CreateViewFunction::Context>(
-          new CreateViewFunction::Context{db_.get()}));
-  RegisterFunction<Import>(db, "IMPORT", 1,
+          new CreateViewFunction::Context{engine_.db()}));
+  RegisterFunction<Import>(engine_.db(), "IMPORT", 1,
                            std::unique_ptr<Import::Context>(new Import::Context{
-                               db_.get(), this, &sql_modules_}));
+                               engine_.db(), this, &sql_modules_}));
   RegisterFunction<ToFtrace>(
-      db, "TO_FTRACE", 1,
+      engine_.db(), "TO_FTRACE", 1,
       std::unique_ptr<ToFtrace::Context>(new ToFtrace::Context{
           context_.storage.get(), SystraceSerializer(&context_)}));
 
   // Old style function registration.
   // TODO(lalitm): migrate this over to using RegisterFunction once aggregate
   // functions are supported.
-  RegisterLastNonNullFunction(db);
-  RegisterValueAtMaxTsFunction(db);
+  RegisterLastNonNullFunction(engine_.db());
+  RegisterValueAtMaxTsFunction(engine_.db());
   {
-    base::Status status = RegisterStackFunctions(db, &context_);
+    base::Status status = RegisterStackFunctions(engine_.db(), &context_);
     if (!status.ok())
       PERFETTO_ELOG("%s", status.c_message());
   }
   {
-    base::Status status = PprofFunctions::Register(db, &context_);
+    base::Status status = PprofFunctions::Register(engine_.db(), &context_);
     if (!status.ok())
       PERFETTO_ELOG("%s", status.c_message());
   }
   {
-    base::Status status = LayoutFunctions::Register(db, &context_);
+    base::Status status = LayoutFunctions::Register(engine_.db(), &context_);
     if (!status.ok())
       PERFETTO_ELOG("%s", status.c_message());
   }
@@ -616,20 +577,18 @@
       PERFETTO_ELOG("%s", status.c_message());
   }
 
-  SetupMetrics(this, *db_, &sql_metrics_, cfg.skip_builtin_metric_paths);
-
-  // Setup the query cache.
-  query_cache_.reset(new QueryCache());
+  SetupMetrics(this, engine_.db(), &sql_metrics_,
+               cfg.skip_builtin_metric_paths);
 
   const TraceStorage* storage = context_.storage.get();
 
-  SqlStatsTable::RegisterTable(*db_, storage);
-  StatsTable::RegisterTable(*db_, storage);
+  SqlStatsTable::RegisterTable(engine_.db(), storage);
+  StatsTable::RegisterTable(engine_.db(), storage);
 
   // Operator tables.
-  SpanJoinOperatorTable::RegisterTable(*db_, storage);
-  WindowOperatorTable::RegisterTable(*db_, storage);
-  CreateViewFunction::RegisterTable(*db_);
+  SpanJoinOperatorTable::RegisterTable(engine_.db(), storage);
+  WindowOperatorTable::RegisterTable(engine_.db(), storage);
+  CreateViewFunction::RegisterTable(engine_.db());
 
   // Tables dynamically generated at query time.
   RegisterTableFunction(std::unique_ptr<ExperimentalFlamegraph>(
@@ -774,7 +733,7 @@
       context_.storage->InternString(TraceTypeToString(context_.trace_type));
   context_.metadata_tracker->SetMetadata(metadata::trace_type,
                                          Variadic::String(trace_type_id));
-  BuildBoundsTable(*db_, context_.storage->GetTraceTimestampBoundsNs());
+  BuildBoundsTable(engine_.db(), context_.storage->GetTraceTimestampBoundsNs());
 }
 
 void TraceProcessorImpl::NotifyEndOfFile() {
@@ -812,7 +771,7 @@
   // TraceProcessorStorageImpl::NotifyEndOfFile, this will be counted in
   // trace bounds: this is important for parsers like ninja which wait until
   // the end to flush all their data.
-  BuildBoundsTable(*db_, context_.storage->GetTraceTimestampBoundsNs());
+  BuildBoundsTable(engine_.db(), context_.storage->GetTraceTimestampBoundsNs());
 
   TraceProcessorStorageImpl::DestroyContext();
 }
@@ -859,19 +818,20 @@
   ScopedStmt stmt;
   IteratorImpl::StmtMetadata metadata;
   base::Status status =
-      PrepareAndStepUntilLastValidStmt(*db_, sql, &stmt, &metadata);
+      PrepareAndStepUntilLastValidStmt(engine_.db(), sql, &stmt, &metadata);
   PERFETTO_DCHECK((status.ok() && stmt) || (!status.ok() && !stmt));
 
-  std::unique_ptr<IteratorImpl> impl(new IteratorImpl(
-      this, *db_, status, std::move(stmt), std::move(metadata), sql_stats_row));
+  std::unique_ptr<IteratorImpl> impl(
+      new IteratorImpl(this, engine_.db(), status, std::move(stmt),
+                       std::move(metadata), sql_stats_row));
   return Iterator(std::move(impl));
 }
 
 void TraceProcessorImpl::InterruptQuery() {
-  if (!db_)
+  if (!engine_.db())
     return;
   query_interrupted_.store(true);
-  sqlite3_interrupt(db_.get());
+  sqlite3_interrupt(engine_.db());
 }
 
 bool TraceProcessorImpl::IsRootMetricField(const std::string& metric_name) {
@@ -963,7 +923,7 @@
           prev_path.c_str(), path.c_str(), metric.proto_field_name->c_str());
     }
 
-    InsertIntoTraceMetricsTable(*db_, no_ext_name);
+    InsertIntoTraceMetricsTable(engine_.db(), no_ext_name);
   }
 
   sql_metrics_.emplace_back(metric);
@@ -991,7 +951,7 @@
     auto fn_name = desc.full_name().substr(desc.package_name().size() + 1);
     std::replace(fn_name.begin(), fn_name.end(), '.', '_');
     RegisterFunction<metrics::BuildProto>(
-        db_.get(), fn_name.c_str(), -1,
+        engine_.db(), fn_name.c_str(), -1,
         std::unique_ptr<metrics::BuildProto::Context>(
             new metrics::BuildProto::Context{this, &pool_, i}));
   }
diff --git a/src/trace_processor/trace_processor_impl.h b/src/trace_processor/trace_processor_impl.h
index 37e7a2c..88b4b6e 100644
--- a/src/trace_processor/trace_processor_impl.h
+++ b/src/trace_processor/trace_processor_impl.h
@@ -36,6 +36,7 @@
 #include "src/trace_processor/sqlite/db_sqlite_table.h"
 #include "src/trace_processor/sqlite/query_cache.h"
 #include "src/trace_processor/sqlite/scoped_db.h"
+#include "src/trace_processor/sqlite/sqlite_engine.h"
 #include "src/trace_processor/trace_processor_storage_impl.h"
 #include "src/trace_processor/util/sql_modules.h"
 
@@ -107,13 +108,11 @@
 
   template <typename Table>
   void RegisterDbTable(const Table& table) {
-    DbSqliteTable::RegisterTable(*db_, query_cache_.get(), &table,
-                                 Table::Name());
+    engine_.RegisterTable(table, Table::Name());
   }
 
-  void RegisterTableFunction(std::unique_ptr<TableFunction> generator) {
-    DbSqliteTable::RegisterTable(*db_, query_cache_.get(),
-                                 std::move(generator));
+  void RegisterTableFunction(std::unique_ptr<TableFunction> fn) {
+    engine_.RegisterTableFunction(std::move(fn));
   }
 
   template <typename View>
@@ -121,16 +120,12 @@
 
   bool IsRootMetricField(const std::string& metric_name);
 
-  // Keep this first: we need this to be destroyed after we clean up
-  // everything else.
-  ScopedDb db_;
+  SqliteEngine engine_;
 
   // State necessary for CREATE_FUNCTION invocations. We store this here as we
   // need to finalize any prepared statements *before* we destroy the database.
   CreateFunction::State create_function_state_;
 
-  std::unique_ptr<QueryCache> query_cache_;
-
   DescriptorPool pool_;
 
   // Map from module name to module contents. Used for IMPORT function.