Merge "tp: implement wrapper around sqlite window functions" into main
diff --git a/BUILD b/BUILD
index 51b4e5f..8eedd8d 100644
--- a/BUILD
+++ b/BUILD
@@ -2665,6 +2665,7 @@
         "src/trace_processor/sqlite/sqlite_tokenizer.h",
         "src/trace_processor/sqlite/sqlite_utils.cc",
         "src/trace_processor/sqlite/sqlite_utils.h",
+        "src/trace_processor/sqlite/sqlite_window_function.h",
         "src/trace_processor/sqlite/stats_table.cc",
         "src/trace_processor/sqlite/stats_table.h",
     ],
diff --git a/src/trace_processor/perfetto_sql/engine/perfetto_sql_engine.h b/src/trace_processor/perfetto_sql/engine/perfetto_sql_engine.h
index e180027..f11d04e 100644
--- a/src/trace_processor/perfetto_sql/engine/perfetto_sql_engine.h
+++ b/src/trace_processor/perfetto_sql/engine/perfetto_sql_engine.h
@@ -43,6 +43,7 @@
 #include "src/trace_processor/sqlite/sqlite_engine.h"
 #include "src/trace_processor/sqlite/sqlite_result.h"
 #include "src/trace_processor/sqlite/sqlite_utils.h"
+#include "src/trace_processor/sqlite/sqlite_window_function.h"
 #include "src/trace_processor/util/sql_argument.h"
 #include "src/trace_processor/util/sql_modules.h"
 
@@ -87,7 +88,7 @@
   //
   // The format of the function is given by the |SqlFunction|.
   //
-  // |name|:        name of the function in SQL
+  // |name|:        name of the function in SQL.
   // |argc|:        number of arguments for this function. This can be -1 if
   //                the number of arguments is variable.
   // |ctx|:         context object for the function (see SqlFunction::Run);
@@ -114,6 +115,24 @@
       std::unique_ptr<typename Function::Context> ctx,
       bool deterministic = true);
 
+  // Registers a trace processor C++ window function to be runnable from SQL.
+  //
+  // The format of the function is given by the |SqliteWindowFunction|.
+  //
+  // |name|:        name of the function in SQL.
+  // |argc|:        number of arguments for this function. This can be -1 if
+  //                the number of arguments is variable.
+  // |ctx|:         context object for the function (see sqlite3_user_data);
+  //                this object *must* outlive the function so should likely be
+  //                either static or scoped to the lifetime of TraceProcessor.
+  // |deterministic|: whether this function has deterministic output given the
+  //                same set of arguments.
+  template <typename Function = SqliteWindowFunction>
+  base::Status RegisterSqliteWindowFunction(const char* name,
+                                            int argc,
+                                            typename Function::Context* ctx,
+                                            bool deterministic = true);
+
   // Registers a function with the prototype |prototype| which returns a value
   // of |return_type| and is implemented by executing the SQL statement |sql|.
   base::Status RegisterRuntimeFunction(bool replace,
@@ -173,7 +192,7 @@
     // The missing objects from the above query are static functions, runtime
     // functions and macros. Add those in now.
     return query_count + static_function_count_ + runtime_function_count_ +
-           macros_.size();
+           static_window_function_count_ + macros_.size();
   }
 
   // Find RuntimeTable registered with engine with provided name.
@@ -234,6 +253,7 @@
 
   uint64_t static_function_count_ = 0;
   uint64_t runtime_function_count_ = 0;
+  uint64_t static_window_function_count_ = 0;
 
   base::FlatHashMap<std::string, std::unique_ptr<RuntimeTableFunction::State>>
       runtime_table_fn_states_;
@@ -316,6 +336,21 @@
 }
 
 template <typename Function>
+base::Status PerfettoSqlEngine::RegisterSqliteWindowFunction(
+    const char* name,
+    int argc,
+    typename Function::Context* ctx,
+    bool deterministic) {
+  base::Status status = engine_->RegisterWindowFunction(
+      name, argc, Function::Step, Function::Inverse, Function::Value,
+      Function::Final, ctx, nullptr, deterministic);
+  // Only count functions which were actually registered with SQLite.
+  if (status.ok())
+    static_window_function_count_++;
+  return status;
+}
+
+template <typename Function>
 base::Status PerfettoSqlEngine::RegisterStaticFunction(
     const char* name,
     int argc,
diff --git a/src/trace_processor/perfetto_sql/intrinsics/functions/layout_functions.cc b/src/trace_processor/perfetto_sql/intrinsics/functions/layout_functions.cc
index 2f37c86..70b3bce 100644
--- a/src/trace_processor/perfetto_sql/intrinsics/functions/layout_functions.cc
+++ b/src/trace_processor/perfetto_sql/intrinsics/functions/layout_functions.cc
@@ -14,11 +14,18 @@
 
 #include "src/trace_processor/perfetto_sql/intrinsics/functions/layout_functions.h"
 
+#include <cstddef>
+#include <cstdint>
 #include <queue>
 #include <vector>
+#include "perfetto/base/logging.h"
+#include "perfetto/base/status.h"
 #include "perfetto/ext/base/status_or.h"
 #include "perfetto/trace_processor/basic_types.h"
+#include "src/trace_processor/perfetto_sql/engine/perfetto_sql_engine.h"
+#include "src/trace_processor/sqlite/sqlite_result.h"
 #include "src/trace_processor/sqlite/sqlite_utils.h"
+#include "src/trace_processor/sqlite/sqlite_window_function.h"
 #include "src/trace_processor/util/status_macros.h"
 
 namespace perfetto::trace_processor {
@@ -118,7 +125,7 @@
 
 base::StatusOr<SlicePacker*> GetOrCreateAggregationContext(
     sqlite3_context* ctx) {
-  SlicePacker** packer = static_cast<SlicePacker**>(
+  auto** packer = static_cast<SlicePacker**>(
       sqlite3_aggregate_context(ctx, sizeof(SlicePacker*)));
   if (!packer) {
     return base::ErrStatus("Failed to allocate aggregate context");
@@ -130,7 +137,9 @@
   return *packer;
 }
 
-base::Status Step(sqlite3_context* ctx, size_t argc, sqlite3_value** argv) {
+base::Status StepStatus(sqlite3_context* ctx,
+                        size_t argc,
+                        sqlite3_value** argv) {
   base::StatusOr<SlicePacker*> slice_packer =
       GetOrCreateAggregationContext(ctx);
   RETURN_IF_ERROR(slice_packer.status());
@@ -146,58 +155,59 @@
   return slice_packer.value()->AddSlice(ts->AsLong(), dur.value().AsLong());
 }
 
-void StepWrapper(sqlite3_context* ctx, int argc, sqlite3_value** argv) {
-  PERFETTO_CHECK(argc >= 0);
+// SQLite window function callbacks implementing INTERNAL_LAYOUT(ts, dur). The
+// SlicePacker lives in the aggregate context and is deleted by Final.
+struct InternalLayout : public SqliteWindowFunction {
+  static void Step(sqlite3_context* ctx, int argc, sqlite3_value** argv) {
+    PERFETTO_CHECK(argc >= 0);
 
-  base::Status status = Step(ctx, static_cast<size_t>(argc), argv);
-  if (!status.ok()) {
-    sqlite::utils::SetError(ctx, kFunctionName, status);
-    return;
+    base::Status status = StepStatus(ctx, static_cast<size_t>(argc), argv);
+    if (!status.ok()) {
+      return sqlite::utils::SetError(ctx, kFunctionName, status);
+    }
   }
-}
 
-void FinalWrapper(sqlite3_context* ctx) {
-  SlicePacker** slice_packer = static_cast<SlicePacker**>(
-      sqlite3_aggregate_context(ctx, sizeof(SlicePacker*)));
-  if (!slice_packer || !*slice_packer) {
-    return;
-  }
-  sqlite::result::Long(ctx,
-                       static_cast<int64_t>((*slice_packer)->GetLastDepth()));
-  delete *slice_packer;
-}
-
-void ValueWrapper(sqlite3_context* ctx) {
-  base::StatusOr<SlicePacker*> slice_packer =
-      GetOrCreateAggregationContext(ctx);
-  if (!slice_packer.ok()) {
-    sqlite::utils::SetError(ctx, kFunctionName, slice_packer.status());
-    return;
-  }
-  sqlite::result::Long(
-      ctx, static_cast<int64_t>(slice_packer.value()->GetLastDepth()));
-}
-
-void InverseWrapper(sqlite3_context* ctx, int, sqlite3_value**) {
-  sqlite::utils::SetError(ctx, kFunctionName, base::ErrStatus(R"(
+  // Shrinking the window is not supported: SQLite calls this when a row leaves
+  // the frame, so report an error describing the frame which must be used.
+  static void Inverse(sqlite3_context* ctx, int, sqlite3_value**) {
+    sqlite::utils::SetError(ctx, kFunctionName, base::ErrStatus(R"(
 The inverse step is not supported: the window clause should be
 "BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW".
 )"));
-}
+  }
+
+  // Returns SlicePacker::GetLastDepth() for the rows added so far, creating
+  // the SlicePacker if it does not exist yet.
+  static void Value(sqlite3_context* ctx) {
+    base::StatusOr<SlicePacker*> slice_packer =
+        GetOrCreateAggregationContext(ctx);
+    if (!slice_packer.ok()) {
+      return sqlite::utils::SetError(ctx, kFunctionName, slice_packer.status());
+    }
+    return sqlite::result::Long(
+        ctx, static_cast<int64_t>(slice_packer.value()->GetLastDepth()));
+  }
+
+  // Returns the last depth and frees the SlicePacker. N=0 is passed to
+  // sqlite3_aggregate_context so that no context is allocated merely to find
+  // it empty; if no SlicePacker exists the result is left unset (NULL).
+  static void Final(sqlite3_context* ctx) {
+    auto** slice_packer =
+        static_cast<SlicePacker**>(sqlite3_aggregate_context(ctx, 0));
+    if (!slice_packer || !*slice_packer) {
+      return;
+    }
+    sqlite::result::Long(ctx,
+                         static_cast<int64_t>((*slice_packer)->GetLastDepth()));
+    delete *slice_packer;
+  }
+};
 
 }  // namespace
 
-base::Status LayoutFunctions::Register(sqlite3* db,
-                                       TraceProcessorContext* context) {
-  int flags = SQLITE_UTF8 | SQLITE_DETERMINISTIC;
-  int ret = sqlite3_create_window_function(
-      db, kFunctionName, 2, flags, context, StepWrapper, FinalWrapper,
-      ValueWrapper, InverseWrapper, nullptr);
-  if (ret != SQLITE_OK) {
-    return base::ErrStatus("Unable to register function with name %s",
-                           kFunctionName);
-  }
-  return base::OkStatus();
+base::Status RegisterLayoutFunctions(PerfettoSqlEngine& engine) {
+  return engine.RegisterSqliteWindowFunction<InternalLayout>(kFunctionName, 2,
+                                                             nullptr);
 }
 
 }  // namespace perfetto::trace_processor
diff --git a/src/trace_processor/perfetto_sql/intrinsics/functions/layout_functions.h b/src/trace_processor/perfetto_sql/intrinsics/functions/layout_functions.h
index 8d3543c..3b1dc4c 100644
--- a/src/trace_processor/perfetto_sql/intrinsics/functions/layout_functions.h
+++ b/src/trace_processor/perfetto_sql/intrinsics/functions/layout_functions.h
@@ -18,6 +18,7 @@
 #include <sqlite3.h>
 
 #include "perfetto/base/status.h"
+#include "src/trace_processor/perfetto_sql/engine/perfetto_sql_engine.h"
 
 namespace perfetto::trace_processor {
 
@@ -26,14 +27,12 @@
 // Implements INTERNAL_LAYOUT(ts, dur) window aggregate function.
 // This function takes a set of slices (ordered by ts) and computes depths
 // allowing them to be displayed on a single track in a non-overlapping manner,
-// while trying to minimising total height.
+// while trying to minimise the total height.
 //
 // TODO(altimin): this should support grouping sets of sets of slices (aka
 // "tracks") by passing 'track_id' parameter. The complication is that we will
 // need to know the max depth for each "track", so it's punted for now.
-struct LayoutFunctions {
-  static base::Status Register(sqlite3* db, TraceProcessorContext* context);
-};
+base::Status RegisterLayoutFunctions(PerfettoSqlEngine& engine);
 
 }  // namespace perfetto::trace_processor
 
diff --git a/src/trace_processor/perfetto_sql/intrinsics/functions/window_functions.h b/src/trace_processor/perfetto_sql/intrinsics/functions/window_functions.h
index 2da34fc..00de395 100644
--- a/src/trace_processor/perfetto_sql/intrinsics/functions/window_functions.h
+++ b/src/trace_processor/perfetto_sql/intrinsics/functions/window_functions.h
@@ -22,10 +22,12 @@
 #include <type_traits>
 
 #include "perfetto/base/logging.h"
+#include "src/trace_processor/perfetto_sql/engine/perfetto_sql_engine.h"
 #include "src/trace_processor/sqlite/sqlite_result.h"
-#include "src/trace_processor/sqlite/sqlite_utils.h"
+#include "src/trace_processor/sqlite/sqlite_window_function.h"
 
 namespace perfetto::trace_processor {
+
 // Keeps track of the latest non null value and its position withing the
 // window. Every time the window shrinks (`xInverse` is called) the window size
 // is reduced by one and the position of the value moves one back, if it gets
@@ -92,55 +94,55 @@
               "Must be able to be destroyed by just calling free (i.e. no "
               "destructor called)");
 
-inline void LastNonNullStep(sqlite3_context* ctx,
-                            int argc,
-                            sqlite3_value** argv) {
-  if (argc != 1) {
-    return sqlite::result::Error(
-        ctx, "Unsupported number of args passed to LAST_NON_NULL");
+// Implements the LAST_NON_NULL(value) window function: evaluates to the most
+// recent non-null value in the current window frame, or NULL if there is none.
+class LastNonNull : public SqliteWindowFunction {
+ public:
+  static void Step(sqlite3_context* ctx, int argc, sqlite3_value** argv) {
+    if (argc != 1) {
+      return sqlite::result::Error(
+          ctx, "Unsupported number of args passed to LAST_NON_NULL");
+    }
+
+    auto* ptr = LastNonNullAggregateContext::GetOrCreate(ctx);
+    if (!ptr) {
+      return sqlite::result::Error(ctx,
+                                   "LAST_NON_NULL: Failed to allocate context");
+    }
+
+    ptr->PushBack(argv[0]);
   }
 
-  auto* ptr = LastNonNullAggregateContext::GetOrCreate(ctx);
-  if (!ptr) {
-    return sqlite::result::Error(ctx,
-                                 "LAST_NON_NULL: Failed to allocate context");
+  static void Inverse(sqlite3_context* ctx, int, sqlite3_value**) {
+    auto* ptr = LastNonNullAggregateContext::GetOrCreate(ctx);
+    PERFETTO_CHECK(ptr != nullptr);
+    ptr->PopFront();
   }
 
-  ptr->PushBack(argv[0]);
-}
-
-inline void LastNonNullInverse(sqlite3_context* ctx, int, sqlite3_value**) {
-  auto* ptr = LastNonNullAggregateContext::GetOrCreate(ctx);
-  PERFETTO_CHECK(ptr != nullptr);
-  ptr->PopFront();
-}
-
-inline void LastNonNullValue(sqlite3_context* ctx) {
-  auto* ptr = LastNonNullAggregateContext::GetOrCreate(ctx);
-  if (!ptr || !ptr->last_non_null_value()) {
-    return sqlite::result::Null(ctx);
+  static void Value(sqlite3_context* ctx) {
+    auto* ptr = LastNonNullAggregateContext::GetOrCreate(ctx);
+    if (!ptr || !ptr->last_non_null_value()) {
+      return sqlite::result::Null(ctx);
+    }
+    sqlite::result::Value(ctx, ptr->last_non_null_value());
   }
-  sqlite::result::Value(ctx, ptr->last_non_null_value());
+
+  static void Final(sqlite3_context* ctx) {
+    auto* ptr = LastNonNullAggregateContext::Get(ctx);
+    if (!ptr || !ptr->last_non_null_value()) {
+      return sqlite::result::Null(ctx);
+    }
+    sqlite::result::Value(ctx, ptr->last_non_null_value());
+    ptr->Destroy();
+  }
+};
+
+// Registers LAST_NON_NULL(value) as a deterministic window function.
+inline base::Status RegisterLastNonNullFunction(PerfettoSqlEngine& engine) {
+  return engine.RegisterSqliteWindowFunction<LastNonNull>("LAST_NON_NULL", 1,
+                                                          nullptr);
 }
 
-inline void LastNonNullFinal(sqlite3_context* ctx) {
-  auto* ptr = LastNonNullAggregateContext::Get(ctx);
-  if (!ptr || !ptr->last_non_null_value()) {
-    return sqlite::result::Null(ctx);
-  }
-  sqlite::result::Value(ctx, ptr->last_non_null_value());
-  ptr->Destroy();
-}
-
-inline void RegisterLastNonNullFunction(sqlite3* db) {
-  auto ret = sqlite3_create_window_function(
-      db, "LAST_NON_NULL", 1, SQLITE_UTF8 | SQLITE_DETERMINISTIC, nullptr,
-      &LastNonNullStep, &LastNonNullFinal, &LastNonNullValue,
-      &LastNonNullInverse, nullptr);
-  if (ret) {
-    PERFETTO_ELOG("Error initializing LAST_NON_NULL");
-  }
-}
 }  // namespace perfetto::trace_processor
 
 #endif  // SRC_TRACE_PROCESSOR_PERFETTO_SQL_INTRINSICS_FUNCTIONS_WINDOW_FUNCTIONS_H_
diff --git a/src/trace_processor/sqlite/BUILD.gn b/src/trace_processor/sqlite/BUILD.gn
index 37682fa..e82712d 100644
--- a/src/trace_processor/sqlite/BUILD.gn
+++ b/src/trace_processor/sqlite/BUILD.gn
@@ -36,6 +36,7 @@
     "sqlite_utils.cc",
     "sqlite_utils.h",
     "sqlite_utils.h",
+    "sqlite_window_function.h",
     "stats_table.cc",
     "stats_table.h",
   ]
diff --git a/src/trace_processor/sqlite/sqlite_engine.cc b/src/trace_processor/sqlite/sqlite_engine.cc
index 9c5da35..ec9c01d 100644
--- a/src/trace_processor/sqlite/sqlite_engine.cc
+++ b/src/trace_processor/sqlite/sqlite_engine.cc
@@ -199,6 +199,28 @@
   return base::OkStatus();
 }
 
+base::Status SqliteEngine::RegisterWindowFunction(
+    const char* name,
+    int argc,
+    WindowFnStep* step,
+    WindowFnInverse* inverse,
+    WindowFnValue* value,
+    WindowFnFinal* final,
+    void* ctx,
+    FnCtxDestructor* ctx_destructor,
+    bool deterministic) {
+  int flags = SQLITE_UTF8 | (deterministic ? SQLITE_DETERMINISTIC : 0);
+  // sqlite3_create_window_function takes the callbacks in xStep, xFinal,
+  // xValue, xInverse order, which differs from the order of our parameters.
+  int ret = sqlite3_create_window_function(
+      db_.get(), name, argc, flags, ctx, step, final, value, inverse,
+      ctx_destructor);
+  if (ret != SQLITE_OK) {
+    return base::ErrStatus("Unable to register function with name %s", name);
+  }
+  return base::OkStatus();
+}
+
 base::Status SqliteEngine::UnregisterFunction(const char* name, int argc) {
   int ret = sqlite3_create_function_v2(db_.get(), name, static_cast<int>(argc),
                                        SQLITE_UTF8, nullptr, nullptr, nullptr,
diff --git a/src/trace_processor/sqlite/sqlite_engine.h b/src/trace_processor/sqlite/sqlite_engine.h
index 9af23fe..e46b9ee 100644
--- a/src/trace_processor/sqlite/sqlite_engine.h
+++ b/src/trace_processor/sqlite/sqlite_engine.h
@@ -51,6 +51,14 @@
 class SqliteEngine {
  public:
   using Fn = void(sqlite3_context* ctx, int argc, sqlite3_value** argv);
+  using WindowFnStep = void(sqlite3_context* ctx,
+                            int argc,
+                            sqlite3_value** argv);
+  using WindowFnInverse = void(sqlite3_context* ctx,
+                               int argc,
+                               sqlite3_value** argv);
+  using WindowFnValue = void(sqlite3_context* ctx);
+  using WindowFnFinal = void(sqlite3_context* ctx);
   using FnCtxDestructor = void(void*);
 
   // Wrapper class for SQLite's |sqlite3_stmt| struct and associated functions.
@@ -90,6 +98,17 @@
                                 FnCtxDestructor* ctx_destructor,
                                 bool deterministic);
 
+  // Registers a C++ window function (sqlite3_create_window_function).
+  base::Status RegisterWindowFunction(const char* name,
+                                      int argc,
+                                      WindowFnStep* step,
+                                      WindowFnInverse* inverse,
+                                      WindowFnValue* value,
+                                      WindowFnFinal* final,
+                                      void* ctx,
+                                      FnCtxDestructor* ctx_destructor,
+                                      bool deterministic);
+
   // Unregisters a C++ function from SQL.
   base::Status UnregisterFunction(const char* name, int argc);
 
diff --git a/src/trace_processor/sqlite/sqlite_window_function.h b/src/trace_processor/sqlite/sqlite_window_function.h
new file mode 100644
index 0000000..ac61c4f
--- /dev/null
+++ b/src/trace_processor/sqlite/sqlite_window_function.h
@@ -0,0 +1,67 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef SRC_TRACE_PROCESSOR_SQLITE_SQLITE_WINDOW_FUNCTION_H_
+#define SRC_TRACE_PROCESSOR_SQLITE_SQLITE_WINDOW_FUNCTION_H_
+
+struct sqlite3_context;
+struct sqlite3_value;
+
+namespace perfetto::trace_processor {
+
+// Prototype for a window function which can be registered with SQLite.
+//
+// See https://www.sqlite.org/windowfunctions.html#udfwinfunc for details on how
+// to implement the methods of this class.
+class SqliteWindowFunction {
+ public:
+  // The type of the context object which will be passed to the function.
+  // Can be redefined in any sub-classes to override the context.
+  using Context = void;
+
+  // The xStep function which will be executed by SQLite to add a row of values
+  // to the current window.
+  //
+  // Implementations MUST define this function themselves; this function is
+  // declared but *not* defined so linker errors will be thrown if not defined.
+  static void Step(sqlite3_context*, int argc, sqlite3_value** argv);
+
+  // The xInverse function which will be executed by SQLite to remove a row of
+  // values from the current window.
+  //
+  // Implementations MUST define this function themselves; this function is
+  // declared but *not* defined so linker errors will be thrown if not defined.
+  static void Inverse(sqlite3_context* ctx, int argc, sqlite3_value** argv);
+
+  // The xValue function which will be executed by SQLite to obtain the current
+  // value of the aggregate.
+  //
+  // Implementations MUST define this function themselves; this function is
+  // declared but *not* defined so linker errors will be thrown if not defined.
+  static void Value(sqlite3_context* ctx);
+
+  // The xFinal function which will be executed by SQLite to obtain the
+  // current value of the aggregate *and* free all resources allocated by
+  // previous calls to Step, Inverse and Value.
+  //
+  // Implementations MUST define this function themselves; this function is
+  // declared but *not* defined so linker errors will be thrown if not defined.
+  static void Final(sqlite3_context* ctx);
+};
+
+}  // namespace perfetto::trace_processor
+
+#endif  // SRC_TRACE_PROCESSOR_SQLITE_SQLITE_WINDOW_FUNCTION_H_
diff --git a/src/trace_processor/trace_processor_impl.cc b/src/trace_processor/trace_processor_impl.cc
index 20bee47..9ac4a5c 100644
--- a/src/trace_processor/trace_processor_impl.cc
+++ b/src/trace_processor/trace_processor_impl.cc
@@ -693,9 +693,13 @@
   // Old style function registration.
   // TODO(lalitm): migrate this over to using RegisterFunction once aggregate
   // functions are supported.
-  RegisterLastNonNullFunction(db);
   RegisterValueAtMaxTsFunction(db);
   {
+    base::Status status = RegisterLastNonNullFunction(*engine_);
+    if (!status.ok())
+      PERFETTO_ELOG("%s", status.c_message());
+  }
+  {
     base::Status status = RegisterStackFunctions(engine_.get(), &context_);
     if (!status.ok())
       PERFETTO_ELOG("%s", status.c_message());
@@ -706,7 +710,7 @@
       PERFETTO_ELOG("%s", status.c_message());
   }
   {
-    base::Status status = LayoutFunctions::Register(db, &context_);
+    base::Status status = RegisterLayoutFunctions(*engine_);
     if (!status.ok())
       PERFETTO_ELOG("%s", status.c_message());
   }