tp: implement CREATE_FUNCTION SQL function

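This allows SQL functions to be defined directly from SQL, which makes
writing reusable snippets of SQL (e.g. in metrics) a lot more lightweight.
As a first user, android_startup now uses it to deduplicate its
main-process slice lookups.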

Change-Id: If6207fac3815b9b584da9648a0fdc525f2f63670
Bug: 190219056
diff --git a/src/trace_processor/metrics/sql/android/android_startup.sql b/src/trace_processor/metrics/sql/android/android_startup.sql
index f561d52..20eaefe 100644
--- a/src/trace_processor/metrics/sql/android/android_startup.sql
+++ b/src/trace_processor/metrics/sql/android/android_startup.sql
@@ -301,6 +301,17 @@
 FROM long_binder_transactions s
 LEFT JOIN binder_to_destination_process bdp USING(slice_id);
 
+-- Returns the slice_proto of a main process slice of launch $launch_id whose
+-- name matches the GLOB pattern $name (an arbitrary one if several match).
+SELECT CREATE_FUNCTION(
+  'MAIN_PROCESS_SLICE_PROTO(launch_id LONG, name STRING)',
+  'PROTO', '
+    SELECT slice_proto
+    FROM main_process_slice s
+    WHERE s.launch_id = $launch_id AND s.name GLOB $name
+    LIMIT 1
+  ');
+
 DROP VIEW IF EXISTS startup_view;
 CREATE VIEW startup_view AS
 SELECT
@@ -430,41 +441,18 @@
         FROM launching_events l
         WHERE l.ts BETWEEN launches.ts AND launches.ts + launches.dur
       ),
-      'time_post_fork', (
-        SELECT slice_proto
-        FROM main_process_slice s
-        WHERE s.launch_id = launches.id AND name = 'PostFork'
-      ),
-      'time_activity_thread_main', (
-        SELECT slice_proto
-        FROM main_process_slice s
-        WHERE s.launch_id = launches.id AND name = 'ActivityThreadMain'
-      ),
-      'time_bind_application', (
-        SELECT slice_proto
-        FROM main_process_slice s
-        WHERE s.launch_id = launches.id AND name = 'bindApplication'
-      ),
-      'time_activity_start', (
-        SELECT slice_proto
-        FROM main_process_slice s
-        WHERE s.launch_id = launches.id AND name = 'activityStart'
-      ),
-      'time_activity_resume', (
-        SELECT slice_proto
-        FROM main_process_slice s
-        WHERE s.launch_id = launches.id AND name = 'activityResume'
-      ),
-      'time_activity_restart', (
-        SELECT slice_proto
-        FROM main_process_slice s
-        WHERE s.launch_id = launches.id AND name = 'activityRestart'
-      ),
-      'time_choreographer', (
-        SELECT slice_proto
-        FROM main_process_slice s
-        WHERE s.launch_id = launches.id AND name LIKE 'Choreographer#doFrame%'
-      ),
+      'time_post_fork', MAIN_PROCESS_SLICE_PROTO(launches.id, 'PostFork'),
+      'time_activity_thread_main', MAIN_PROCESS_SLICE_PROTO(launches.id, 'ActivityThreadMain'),
+      'time_bind_application', MAIN_PROCESS_SLICE_PROTO(launches.id, 'bindApplication'),
+      'time_activity_start', MAIN_PROCESS_SLICE_PROTO(launches.id, 'activityStart'),
+      'time_activity_resume', MAIN_PROCESS_SLICE_PROTO(launches.id, 'activityResume'),
+      'time_activity_restart', MAIN_PROCESS_SLICE_PROTO(launches.id, 'activityRestart'),
+      'time_choreographer', MAIN_PROCESS_SLICE_PROTO(launches.id, 'Choreographer#doFrame*'),
+      'time_inflate', MAIN_PROCESS_SLICE_PROTO(launches.id, 'inflate'),
+      'time_get_resources', MAIN_PROCESS_SLICE_PROTO(launches.id, 'ResourcesManager#getResources'),
+      'time_dex_open', MAIN_PROCESS_SLICE_PROTO(launches.id, 'OpenDexFilesFromOat'),
+      'time_verify_class', MAIN_PROCESS_SLICE_PROTO(launches.id, 'VerifyClass'),
+      'time_gc_total', MAIN_PROCESS_SLICE_PROTO(launches.id, 'GC'),
       'time_before_start_process', (
         SELECT AndroidStartupMetric_Slice(
           'dur_ns', ts - launches.ts,
@@ -481,27 +469,6 @@
         FROM zygote_forks_by_id z
         WHERE z.id = launches.id
       ),
-      'time_inflate', (
-        SELECT slice_proto
-        FROM main_process_slice s
-        WHERE s.launch_id = launches.id AND name = 'inflate'
-      ),
-      'time_get_resources', (
-        SELECT slice_proto
-        FROM main_process_slice s
-        WHERE s.launch_id = launches.id
-        AND name = 'ResourcesManager#getResources'
-      ),
-      'time_dex_open', (
-        SELECT slice_proto
-        FROM main_process_slice s
-        WHERE s.launch_id = launches.id AND name = 'OpenDexFilesFromOat'
-      ),
-      'time_verify_class', (
-        SELECT slice_proto
-        FROM main_process_slice s
-        WHERE s.launch_id = launches.id AND name = 'VerifyClass'
-      ),
       'jit_compiled_methods', (
         SELECT count
         FROM jit_compiled_methods_materialized s
@@ -516,11 +483,6 @@
         FROM launch_threads_cpu_materialized
         WHERE launch_id = launches.id
       ),
-      'time_gc_total', (
-        SELECT slice_proto
-        FROM main_process_slice s
-        WHERE s.launch_id = launches.id AND name = 'GC'
-      ),
       'time_gc_on_cpu', (
         SELECT
           NULL_IF_EMPTY(AndroidStartupMetric_Slice(
diff --git a/src/trace_processor/sqlite/BUILD.gn b/src/trace_processor/sqlite/BUILD.gn
index af1adbe..7863cc5 100644
--- a/src/trace_processor/sqlite/BUILD.gn
+++ b/src/trace_processor/sqlite/BUILD.gn
@@ -17,6 +17,8 @@
 if (enable_perfetto_trace_processor_sqlite) {
   source_set("sqlite") {
     sources = [
+      "create_function.cc",
+      "create_function.h",
       "db_sqlite_table.cc",
       "db_sqlite_table.h",
       "query_cache.h",
diff --git a/src/trace_processor/sqlite/create_function.cc b/src/trace_processor/sqlite/create_function.cc
new file mode 100644
index 0000000..692018e
--- /dev/null
+++ b/src/trace_processor/sqlite/create_function.cc
@@ -0,0 +1,487 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/trace_processor/sqlite/create_function.h"
+
+#include <algorithm>
+#include <cctype>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "perfetto/base/logging.h"
+#include "perfetto/base/status.h"
+#include "perfetto/ext/base/hash.h"
+#include "perfetto/ext/base/optional.h"
+#include "perfetto/ext/base/string_utils.h"
+#include "perfetto/ext/base/string_view.h"
+#include "perfetto/trace_processor/basic_types.h"
+#include "src/trace_processor/sqlite/scoped_db.h"
+#include "src/trace_processor/sqlite/sqlite_utils.h"
+#include "src/trace_processor/util/status_macros.h"
+
+namespace perfetto {
+namespace trace_processor {
+
+namespace {
+
+// Returns true if |str| is non-empty and contains only alphanumeric characters
+// and underscores.
+bool IsValidName(base::StringView str) {
+  // isalnum is undefined for negative values other than EOF, so pass the
+  // character as unsigned char.
+  auto is_invalid = [](char c) {
+    return !(isalnum(static_cast<unsigned char>(c)) || c == '_');
+  };
+  return !str.empty() &&
+         std::find_if(str.begin(), str.end(), is_invalid) == str.end();
+}
+
+// Maps a type name used in prototypes and return types to its SqlValue type,
+// or nullopt if the name is unknown.
+base::Optional<SqlValue::Type> ParseType(base::StringView str) {
+  if (str == "INT" || str == "LONG" || str == "BOOL") {
+    return SqlValue::Type::kLong;
+  } else if (str == "DOUBLE" || str == "FLOAT") {
+    return SqlValue::Type::kDouble;
+  } else if (str == "STRING") {
+    return SqlValue::Type::kString;
+  } else if (str == "PROTO" || str == "BYTES") {
+    return SqlValue::Type::kBytes;
+  }
+  return base::nullopt;
+}
+
+// Returns a user-facing name for |type|, for use in error messages.
+const char* SqliteTypeToFriendlyString(SqlValue::Type type) {
+  switch (type) {
+    case SqlValue::Type::kNull:
+      return "NULL";
+    case SqlValue::Type::kLong:
+      return "INT/LONG/BOOL";
+    case SqlValue::Type::kDouble:
+      return "FLOAT/DOUBLE";
+    case SqlValue::Type::kString:
+      return "STRING";
+    case SqlValue::Type::kBytes:
+      return "BYTES/PROTO";
+  }
+  PERFETTO_FATAL("For GCC");
+}
+
+// Checks that |value| has |expected_type|. NULL always passes, so callers which
+// cannot handle NULL must reject it separately.
+base::Status TypeCheckSqliteValue(sqlite3_value* value,
+                                  SqlValue::Type expected_type) {
+  SqlValue::Type actual_type =
+      sqlite_utils::SqliteTypeToSqlValueType(sqlite3_value_type(value));
+  if (actual_type != SqlValue::Type::kNull && actual_type != expected_type) {
+    return base::ErrStatus(
+        "does not have expected type: expected %s, actual %s",
+        SqliteTypeToFriendlyString(expected_type),
+        SqliteTypeToFriendlyString(actual_type));
+  }
+  return base::OkStatus();
+}
+
+// Parsed form of a prototype such as `FOO(a INT, b STRING)`.
+struct Prototype {
+  struct Argument {
+    // Argument name prefixed with '$', as referenced in the SQL body.
+    std::string dollar_name;
+    SqlValue::Type type;
+
+    bool operator==(const Argument& other) const {
+      return dollar_name == other.dollar_name && type == other.type;
+    }
+  };
+  std::string function_name;
+  std::vector<Argument> arguments;
+
+  bool operator==(const Prototype& other) const {
+    return function_name == other.function_name && arguments == other.arguments;
+  }
+  bool operator!=(const Prototype& other) const { return !(*this == other); }
+};
+
+// Parses |raw| into |out|; on error, |out| may be partially populated.
+base::Status ParsePrototype(base::StringView raw, Prototype& out) {
+  // Examples of function prototypes:
+  // ANDROID_SDK_LEVEL()
+  // STARTUP_SLICE(dur_ns INT)
+  // FIND_NEXT_SLICE_WITH_NAME(ts INT, name STRING)
+
+  size_t function_name_end = raw.find('(');
+  if (function_name_end == base::StringView::npos) {
+    return base::ErrStatus(
+        "CREATE_FUNCTION[prototype=%s]: unable to find bracket starting "
+        "argument list",
+        raw.ToStdString().c_str());
+  }
+
+  base::StringView function_name = raw.substr(0, function_name_end);
+  if (!IsValidName(function_name)) {
+    return base::ErrStatus(
+        "CREATE_FUNCTION[prototype=%s]: function name %s is not alphanumeric",
+        raw.ToStdString().c_str(), function_name.ToStdString().c_str());
+  }
+
+  size_t args_start = function_name_end + 1;
+  size_t args_end = raw.find(')', function_name_end);
+  if (args_end == base::StringView::npos) {
+    return base::ErrStatus(
+        "CREATE_FUNCTION[prototype=%s]: unable to find bracket ending "
+        "argument list",
+        raw.ToStdString().c_str());
+  }
+
+  base::StringView args_str = raw.substr(args_start, args_end - args_start);
+  for (const auto& arg : base::SplitString(args_str.ToStdString(), ",")) {
+    const auto& arg_name_and_type = base::SplitString(arg, " ");
+    if (arg_name_and_type.size() != 2) {
+      return base::ErrStatus(
+          "CREATE_FUNCTION[prototype=%s, arg=%s]: argument in function "
+          "prototype should be of the form `name type`",
+          raw.ToStdString().c_str(), arg.c_str());
+    }
+
+    const auto& arg_name = arg_name_and_type[0];
+    const auto& arg_type_str = arg_name_and_type[1];
+    if (!IsValidName(base::StringView(arg_name))) {
+      return base::ErrStatus(
+          "CREATE_FUNCTION[prototype=%s, arg=%s]: argument is not alphanumeric",
+          raw.ToStdString().c_str(), arg.c_str());
+    }
+
+    auto opt_arg_type = ParseType(base::StringView(arg_type_str));
+    if (!opt_arg_type) {
+      return base::ErrStatus(
+          "CREATE_FUNCTION[prototype=%s, arg=%s]: unknown arg type",
+          raw.ToStdString().c_str(), arg.c_str());
+    }
+
+    SqlValue::Type arg_type = *opt_arg_type;
+    PERFETTO_DCHECK(arg_type != SqlValue::Type::kNull);
+    out.arguments.push_back({"$" + arg_name, arg_type});
+  }
+
+  out.function_name = function_name.ToStdString();
+  return base::OkStatus();
+}
+
+// A function registered through CREATE_FUNCTION: each call binds its arguments
+// to the `$name` parameters of the prepared SQL body and runs it.
+struct CreatedFunction : public SqlFunction {
+  struct Context {
+    sqlite3* db;
+    Prototype prototype;
+    SqlValue::Type return_type;
+    std::string sql;
+    // Borrowed: the owning ScopedStmt is stored in CreateFunction::State.
+    sqlite3_stmt* stmt;
+  };
+
+  static base::Status Run(Context* ctx,
+                          size_t argc,
+                          sqlite3_value** argv,
+                          SqlValue& out,
+                          Destructors&);
+  static base::Status Cleanup(Context*);
+};
+
+// Maps a sqlite3_step return code for the function body to a status:
+// SQLITE_ROW and SQLITE_DONE are success, anything else is an error.
+base::Status SqliteRetToStatus(CreatedFunction::Context* ctx, int ret) {
+  if (ret != SQLITE_ROW && ret != SQLITE_DONE) {
+    return base::ErrStatus("%s: SQLite error while executing function body: %s",
+                           ctx->prototype.function_name.c_str(),
+                           sqlite3_errmsg(ctx->db));
+  }
+  return base::OkStatus();
+}
+
+// Resets the function body after a failed invocation and returns |status|.
+// Without this, the statement could be left mid-execution (or halted on an
+// error) and later calls would fail, as SQLite refuses to bind values to a
+// statement which has not been reset.
+base::Status ResetAfterError(CreatedFunction::Context* ctx,
+                             base::Status status) {
+  // The return value is deliberately ignored: after a failed step it only
+  // repeats that step's error, which |status| already describes.
+  sqlite3_reset(ctx->stmt);
+  return status;
+}
+
+// Runs the function body with this call's arguments bound and stores the
+// first column of the first row (if any) in |out|. On success the statement
+// is not reset here; Cleanup resets it.
+base::Status CreatedFunction::Run(CreatedFunction::Context* ctx,
+                                  size_t argc,
+                                  sqlite3_value** argv,
+                                  SqlValue& out,
+                                  Destructors&) {
+  if (argc != ctx->prototype.arguments.size()) {
+    return base::ErrStatus(
+        "%s: invalid number of args; expected %zu, received %zu",
+        ctx->prototype.function_name.c_str(), ctx->prototype.arguments.size(),
+        argc);
+  }
+
+  // Type check all the arguments.
+  for (size_t i = 0; i < argc; ++i) {
+    sqlite3_value* arg = argv[i];
+    base::Status status =
+        TypeCheckSqliteValue(arg, ctx->prototype.arguments[i].type);
+    if (!status.ok()) {
+      return base::ErrStatus("%s[arg=%s]: argument %zu %s",
+                             ctx->prototype.function_name.c_str(),
+                             sqlite3_value_text(arg), i, status.c_message());
+    }
+  }
+
+  // Bind all the arguments to the appropriate places in the function.
+  for (size_t i = 0; i < argc; ++i) {
+    const auto& arg = ctx->prototype.arguments[i];
+    int index =
+        sqlite3_bind_parameter_index(ctx->stmt, arg.dollar_name.c_str());
+
+    // If the argument is not in the query, this just means it's an unused
+    // argument which we can just ignore.
+    if (index == 0)
+      continue;
+
+    int ret = sqlite3_bind_value(ctx->stmt, index, argv[i]);
+    if (ret != SQLITE_OK) {
+      return base::ErrStatus(
+          "%s: SQLite error while binding value to argument %zu: %s",
+          ctx->prototype.function_name.c_str(), i, sqlite3_errmsg(ctx->db));
+    }
+  }
+
+  int ret = sqlite3_step(ctx->stmt);
+  base::Status step_status = SqliteRetToStatus(ctx, ret);
+  if (!step_status.ok())
+    return ResetAfterError(ctx, std::move(step_status));
+  if (ret == SQLITE_DONE) {
+    // The body returned no rows: there is no return value, so |out| is left
+    // unset.
+    return base::OkStatus();
+  }
+
+  PERFETTO_DCHECK(ret == SQLITE_ROW);
+  size_t col_count = static_cast<size_t>(sqlite3_column_count(ctx->stmt));
+  if (col_count != 1) {
+    return ResetAfterError(
+        ctx, base::ErrStatus("%s: SQL definition should only return one "
+                             "column: returned %zu columns",
+                             ctx->prototype.function_name.c_str(), col_count));
+  }
+
+  // Enforce the return type declared in CREATE_FUNCTION (NULL is allowed).
+  sqlite3_value* result = sqlite3_column_value(ctx->stmt, 0);
+  base::Status type_status = TypeCheckSqliteValue(result, ctx->return_type);
+  if (!type_status.ok()) {
+    return ResetAfterError(
+        ctx, base::ErrStatus("%s: return value %s",
+                             ctx->prototype.function_name.c_str(),
+                             type_status.c_message()));
+  }
+
+  out = sqlite_utils::SqliteValueToSqlValue(result);
+  return base::OkStatus();
+}
+
+// Finishes an invocation started by Run: checks that the body produced at
+// most one row, then resets the statement for the next call.
+base::Status CreatedFunction::Cleanup(CreatedFunction::Context* ctx) {
+  // Only step if Run left a row pending: when the body already returned
+  // SQLITE_DONE, another sqlite3_step would make SQLite automatically reset
+  // the statement and execute the whole body a second time.
+  if (sqlite3_stmt_busy(ctx->stmt)) {
+    int ret = sqlite3_step(ctx->stmt);
+    base::Status step_status = SqliteRetToStatus(ctx, ret);
+    if (!step_status.ok())
+      return ResetAfterError(ctx, std::move(step_status));
+    if (ret == SQLITE_ROW) {
+      return ResetAfterError(
+          ctx, base::ErrStatus("%s: multiple values were returned when "
+                               "executing function body",
+                               ctx->prototype.function_name.c_str()));
+    }
+    PERFETTO_DCHECK(ret == SQLITE_DONE);
+  }
+
+  // sqlite3_reset keeps bound values, so also clear them to avoid holding on
+  // to this call's arguments until the next call.
+  int ret = sqlite3_reset(ctx->stmt);
+  if (ret != SQLITE_OK) {
+    return base::ErrStatus("%s: error while resetting function body",
+                           ctx->prototype.function_name.c_str());
+  }
+  sqlite3_clear_bindings(ctx->stmt);
+  return base::OkStatus();
+}
+
+}  // namespace
+
+size_t CreateFunction::NameAndArgc::Hasher::operator()(
+    const NameAndArgc& s) const noexcept {
+  base::Hash hash;
+  hash.Update(s.name.data(), s.name.size());
+  hash.Update(s.argc);
+  return static_cast<size_t>(hash.digest());
+}
+
+// Implements CREATE_FUNCTION(prototype, return_type, sql_definition): parses
+// the prototype and return type, prepares the SQL definition and registers it
+// as a new SQL function. Re-running an identical definition is a no-op; a
+// conflicting definition with the same name and arity is an error.
+base::Status CreateFunction::Run(CreateFunction::Context* ctx,
+                                 size_t argc,
+                                 sqlite3_value** argv,
+                                 SqlValue&,
+                                 Destructors&) {
+  if (argc != 3) {
+    return base::ErrStatus(
+        "CREATE_FUNCTION: invalid number of args; expected %u, received %zu",
+        3u, argc);
+  }
+
+  sqlite3_value* prototype_value = argv[0];
+  sqlite3_value* return_type_value = argv[1];
+  sqlite3_value* sql_defn_value = argv[2];
+
+  // Type check all the arguments. NULL must be rejected explicitly:
+  // TypeCheckSqliteValue accepts it, but every argument is read as a string
+  // below and sqlite3_value_text returns nullptr for NULL values.
+  {
+    auto type_check = [prototype_value](sqlite3_value* value,
+                                        SqlValue::Type type, const char* desc) {
+      if (sqlite3_value_type(value) == SQLITE_NULL) {
+        return base::ErrStatus("CREATE_FUNCTION: %s cannot be NULL", desc);
+      }
+      // The prototype is checked first, so it is never NULL when printed.
+      base::Status status = TypeCheckSqliteValue(value, type);
+      if (!status.ok()) {
+        return base::ErrStatus("CREATE_FUNCTION[prototype=%s]: %s %s",
+                               sqlite3_value_text(prototype_value), desc,
+                               status.c_message());
+      }
+      return base::OkStatus();
+    };
+
+    RETURN_IF_ERROR(type_check(prototype_value, SqlValue::Type::kString,
+                               "function prototype (first argument)"));
+    RETURN_IF_ERROR(type_check(return_type_value, SqlValue::Type::kString,
+                               "return type (second argument)"));
+    RETURN_IF_ERROR(type_check(sql_defn_value, SqlValue::Type::kString,
+                               "SQL definition (third argument)"));
+  }
+
+  // Extract the arguments from the value wrappers.
+  auto extract_string = [](sqlite3_value* value) -> base::StringView {
+    return reinterpret_cast<const char*>(sqlite3_value_text(value));
+  };
+  base::StringView prototype_str = extract_string(prototype_value);
+  base::StringView return_type_str = extract_string(return_type_value);
+  std::string sql_defn_str = extract_string(sql_defn_value).ToStdString();
+
+  // Parse all the arguments into a more friendly form.
+  Prototype prototype;
+  RETURN_IF_ERROR(ParsePrototype(prototype_str, prototype));
+
+  // Parse the return type into an enum format.
+  auto opt_return_type = ParseType(return_type_str);
+  if (!opt_return_type) {
+    return base::ErrStatus(
+        "CREATE_FUNCTION[prototype=%s, return=%s]: unknown return type "
+        "specified",
+        prototype_str.ToStdString().c_str(),
+        return_type_str.ToStdString().c_str());
+  }
+  SqlValue::Type return_type = *opt_return_type;
+
+  int created_argc = static_cast<int>(prototype.arguments.size());
+  NameAndArgc key{prototype.function_name, created_argc};
+  auto it = ctx->state->find(key);
+  if (it != ctx->state->end()) {
+    // If the function already exists, just verify that the prototype, return
+    // type and SQL matches exactly with what we already had registered. By
+    // doing this, we can avoid the problem plaguing C++ macros where macro
+    // ordering determines which one gets run.
+    auto* created_ctx = static_cast<CreatedFunction::Context*>(
+        it->second.created_functon_context);
+
+    if (created_ctx->prototype != prototype) {
+      return base::ErrStatus(
+          "CREATE_FUNCTION[prototype=%s]: function prototype changed",
+          prototype_str.ToStdString().c_str());
+    }
+
+    if (created_ctx->return_type != return_type) {
+      return base::ErrStatus(
+          "CREATE_FUNCTION[prototype=%s]: return type changed from %s to %s",
+          prototype_str.ToStdString().c_str(),
+          SqliteTypeToFriendlyString(created_ctx->return_type),
+          return_type_str.ToStdString().c_str());
+    }
+
+    if (created_ctx->sql != sql_defn_str) {
+      return base::ErrStatus(
+          "CREATE_FUNCTION[prototype=%s]: function SQL changed from %s to %s",
+          prototype_str.ToStdString().c_str(), created_ctx->sql.c_str(),
+          sql_defn_str.c_str());
+    }
+    return base::OkStatus();
+  }
+
+  // Prepare the SQL definition as a statement using SQLite.
+  ScopedStmt stmt;
+  sqlite3_stmt* stmt_raw = nullptr;
+  int ret = sqlite3_prepare_v2(ctx->db, sql_defn_str.data(),
+                               static_cast<int>(sql_defn_str.size()), &stmt_raw,
+                               nullptr);
+  if (ret != SQLITE_OK) {
+    return base::ErrStatus(
+        "CREATE_FUNCTION[prototype=%s]: SQLite error when preparing "
+        "statement %s",
+        prototype_str.ToStdString().c_str(), sqlite3_errmsg(ctx->db));
+  }
+  stmt.reset(stmt_raw);
+
+  // sqlite3_prepare_v2 succeeds without producing a statement when the input
+  // contains no SQL (e.g. only whitespace or comments).
+  if (!stmt_raw) {
+    return base::ErrStatus(
+        "CREATE_FUNCTION[prototype=%s]: SQL definition is empty",
+        prototype_str.ToStdString().c_str());
+  }
+
+  std::unique_ptr<CreatedFunction::Context> created(
+      new CreatedFunction::Context{ctx->db, std::move(prototype), return_type,
+                                   std::move(sql_defn_str), stmt.get()});
+  CreatedFunction::Context* created_ptr = created.get();
+  RETURN_IF_ERROR(RegisterSqlFunction<CreatedFunction>(
+      ctx->db, key.name.c_str(), created_argc, std::move(created)));
+  ctx->state->emplace(key, PerFunctionState{std::move(stmt), created_ptr});
+
+  // CREATE_FUNCTION doesn't have a return value so just don't set |out|.
+  return base::OkStatus();
+}
+
+}  // namespace trace_processor
+}  // namespace perfetto
diff --git a/src/trace_processor/sqlite/create_function.h b/src/trace_processor/sqlite/create_function.h
new file mode 100644
index 0000000..c6778c2
--- /dev/null
+++ b/src/trace_processor/sqlite/create_function.h
@@ -0,0 +1,78 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef SRC_TRACE_PROCESSOR_SQLITE_CREATE_FUNCTION_H_
+#define SRC_TRACE_PROCESSOR_SQLITE_CREATE_FUNCTION_H_
+
+#include <sqlite3.h>
+#include <cstddef>
+#include <string>
+#include <unordered_map>
+
+#include "perfetto/base/status.h"
+#include "perfetto/trace_processor/basic_types.h"
+#include "src/trace_processor/sqlite/register_function.h"
+#include "src/trace_processor/sqlite/scoped_db.h"
+
+namespace perfetto {
+namespace trace_processor {
+
+// Implementation of CREATE_FUNCTION SQL function.
+// See https://perfetto.dev/docs/analysis/metrics#metric-helper-functions for
+// usage of this function.
+struct CreateFunction : public SqlFunction {
+  // Book-keeping for one function created through CREATE_FUNCTION.
+  struct PerFunctionState {
+    // Prepared statement for the function's SQL definition.
+    ScopedStmt stmt;
+    // The CreatedFunction::Context registered with SQLite for this function.
+    // Stored as void* as that type is private to create_function.cc.
+    void* created_functon_context;
+  };
+  // Lookup key for created functions: the same name may be registered with
+  // different argument counts.
+  struct NameAndArgc {
+    std::string name;
+    int argc;
+
+    struct Hasher {
+      std::size_t operator()(const NameAndArgc& s) const noexcept;
+    };
+    bool operator==(const NameAndArgc& other) const {
+      return name == other.name && argc == other.argc;
+    }
+  };
+  using State = std::unordered_map<NameAndArgc,
+                                   CreateFunction::PerFunctionState,
+                                   NameAndArgc::Hasher>;
+
+  struct Context {
+    sqlite3* db;
+    // Functions created so far; owned by the caller.
+    State* state;
+  };
+
+  static base::Status Run(Context* ctx,
+                          size_t argc,
+                          sqlite3_value** argv,
+                          SqlValue& out,
+                          Destructors&);
+};
+
+}  // namespace trace_processor
+}  // namespace perfetto
+
+#endif  // SRC_TRACE_PROCESSOR_SQLITE_CREATE_FUNCTION_H_
diff --git a/src/trace_processor/sqlite/register_function.h b/src/trace_processor/sqlite/register_function.h
index 0df7dcc..1761af4 100644
--- a/src/trace_processor/sqlite/register_function.h
+++ b/src/trace_processor/sqlite/register_function.h
@@ -18,13 +18,8 @@
 #define SRC_TRACE_PROCESSOR_SQLITE_REGISTER_FUNCTION_H_
 
 #include <sqlite3.h>
-#include <cstddef>
 #include <memory>
-#include <set>
 
-#include "perfetto/base/status.h"
-#include "perfetto/trace_processor/basic_types.h"
-#include "src/trace_processor/sqlite/scoped_db.h"
 #include "src/trace_processor/sqlite/sqlite_utils.h"
 
 namespace perfetto {
diff --git a/src/trace_processor/sqlite/sqlite_utils.h b/src/trace_processor/sqlite/sqlite_utils.h
index 66a3ed6..9bb26d7 100644
--- a/src/trace_processor/sqlite/sqlite_utils.h
+++ b/src/trace_processor/sqlite/sqlite_utils.h
@@ -20,12 +20,6 @@
 #include <math.h>
 #include <sqlite3.h>
 
-#include <functional>
-#include <limits>
-#include <string>
-
-#include "perfetto/base/logging.h"
-#include "perfetto/ext/base/optional.h"
 #include "perfetto/ext/base/string_utils.h"
 #include "perfetto/trace_processor/basic_types.h"
 #include "src/trace_processor/sqlite/scoped_db.h"
diff --git a/src/trace_processor/trace_processor_impl.cc b/src/trace_processor/trace_processor_impl.cc
index 9f265ec..6fd4579 100644
--- a/src/trace_processor/trace_processor_impl.cc
+++ b/src/trace_processor/trace_processor_impl.cc
@@ -47,6 +47,7 @@
 #include "src/trace_processor/importers/proto/metadata_tracker.h"
 #include "src/trace_processor/importers/systrace/systrace_trace_parser.h"
 #include "src/trace_processor/iterator_impl.h"
+#include "src/trace_processor/sqlite/create_function.h"
 #include "src/trace_processor/sqlite/register_function.h"
 #include "src/trace_processor/sqlite/span_join_operator_table.h"
 #include "src/trace_processor/sqlite/sql_stats_table.h"
@@ -742,6 +743,12 @@
   RegisterFunction<ExportJson>(db, "EXPORT_JSON", 1, context_.storage.get(),
                                false);
   RegisterFunction<ExtractArg>(db, "EXTRACT_ARG", 2, context_.storage.get());
+  // Functions defined through CREATE_FUNCTION are tracked in
+  // |create_function_state_|; use the same handle as the registrations above.
+  RegisterFunction<CreateFunction>(
+      db, "CREATE_FUNCTION", 3,
+      std::unique_ptr<CreateFunction::Context>(
+          new CreateFunction::Context{db, &create_function_state_}));
 
   // Old style function registration.
   // TODO(lalitm): migrate this over to using RegisterFunction once aggregate
diff --git a/src/trace_processor/trace_processor_impl.h b/src/trace_processor/trace_processor_impl.h
index f94e6d5..90a4393 100644
--- a/src/trace_processor/trace_processor_impl.h
+++ b/src/trace_processor/trace_processor_impl.h
@@ -21,6 +21,7 @@
 
 #include <atomic>
 #include <functional>
+#include <map>
 #include <string>
 #include <vector>
 
@@ -28,6 +29,7 @@
 #include "perfetto/trace_processor/basic_types.h"
 #include "perfetto/trace_processor/status.h"
 #include "perfetto/trace_processor/trace_processor.h"
+#include "src/trace_processor/sqlite/create_function.h"
 #include "src/trace_processor/sqlite/db_sqlite_table.h"
 #include "src/trace_processor/sqlite/query_cache.h"
 #include "src/trace_processor/sqlite/scoped_db.h"
@@ -46,6 +48,12 @@
  public:
   explicit TraceProcessorImpl(const Config&);
 
+  TraceProcessorImpl(const TraceProcessorImpl&) = delete;
+  TraceProcessorImpl& operator=(const TraceProcessorImpl&) = delete;
+
+  TraceProcessorImpl(TraceProcessorImpl&&) = delete;
+  TraceProcessorImpl& operator=(TraceProcessorImpl&&) = delete;
+
   ~TraceProcessorImpl() override;
 
   // TraceProcessorStorage implementation:
@@ -104,7 +112,15 @@
   }
 
   bool IsRootMetricField(const std::string& metric_name);
+
+  // Keep this first: we need this to be destroyed after we clean up
+  // everything else.
   ScopedDb db_;
+
+  // State necessary for CREATE_FUNCTION invocations. We store this here as we
+  // need to finalize any prepared statements *before* we destroy the database.
+  CreateFunction::State create_function_state_;
+
   std::unique_ptr<QueryCache> query_cache_;
 
   DescriptorPool pool_;