Function memoization

Add support for `experimental_memoize` function, which enables memoization
for f(int) => int Perfetto SQL functions.

Combined with support for recursive SQL functions, it allows us to write
efficient operations over recursive trees.

Change-Id: I1593330ce1950b502fef5540e4355dcf8081edb4
diff --git a/src/trace_processor/prelude/functions/create_function.cc b/src/trace_processor/prelude/functions/create_function.cc
index 29800fe..1d2f352 100644
--- a/src/trace_processor/prelude/functions/create_function.cc
+++ b/src/trace_processor/prelude/functions/create_function.cc
@@ -59,6 +59,83 @@
   static void Cleanup(Context*);
 };
 
+class Memoizer {
+ public:
+  // Turns on result caching. Fails unless |prototype| and |return_type|
+  // describe a function of the shape f(int) -> int.
+  base::Status EnableMemoization(const Prototype& prototype,
+                                 sql_argument::Type return_type) {
+    const char* name = prototype.function_name.c_str();
+    // The size check comes first as it guards the access to arguments[0].
+    const bool takes_one_int =
+        prototype.arguments.size() == 1 &&
+        TypeToSqlValueType(prototype.arguments[0].type()) ==
+            SqlValue::Type::kLong;
+    if (!takes_one_int) {
+      return base::ErrStatus(
+          "EXPERIMENTAL_MEMOIZE: Function %s should take one int argument",
+          name);
+    }
+    if (TypeToSqlValueType(return_type) != SqlValue::Type::kLong) {
+      return base::ErrStatus(
+          "EXPERIMENTAL_MEMOIZE: Function %s should return an int", name);
+    }
+    enabled_ = true;
+    return base::OkStatus();
+  }
+
+  // Looks up the cached result for this invocation's argument. A hit also
+  // marks the invocation as one that returns a memoized value.
+  std::optional<SqlValue> GetMemoizedValue(size_t argc, sqlite3_value** argv) {
+    std::optional<int64_t> key = ExtractArgForMemoization(argc, argv);
+    int64_t* cached = key ? memoized_values_.Find(*key) : nullptr;
+    if (!cached) {
+      return std::nullopt;
+    }
+    is_returning_memoized_value_ = true;
+    return SqlValue::Long(*cached);
+  }
+
+  // Caches |value| as the result for this invocation's argument. No-op unless
+  // memoization is enabled and both the argument and |value| are ints.
+  void Memoize(size_t argc, sqlite3_value** argv, SqlValue value) {
+    if (enabled_ && value.type == SqlValue::Type::kLong) {
+      std::optional<int64_t> key = ExtractArgForMemoization(argc, argv);
+      if (key) {
+        memoized_values_.Insert(*key, value.AsLong());
+      }
+    }
+  }
+
+  // True if the current invocation returned a cached value, i.e. no statement
+  // ran and post-conditions can't be checked. Clears the pending-hit flag.
+  bool ShouldBypassPostConditions() {
+    const bool had_hit = is_returning_memoized_value_;
+    is_returning_memoized_value_ = false;
+    return had_hit && enabled_;
+  }
+
+ private:
+  // Yields the cache key for this invocation, or nullopt if memoization is
+  // off or the call does not have exactly one int argument.
+  std::optional<int64_t> ExtractArgForMemoization(size_t argc,
+                                                  sqlite3_value** argv) {
+    if (enabled_ && argc == 1) {
+      SqlValue first = sqlite_utils::SqliteValueToSqlValue(argv[0]);
+      if (first.type == SqlValue::Type::kLong) {
+        return first.AsLong();
+      }
+    }
+    return std::nullopt;
+  }
+
+  bool enabled_ = false;
+  base::FlatHashMap<int64_t, int64_t> memoized_values_;
+  // Set by a successful GetMemoizedValue and cleared by the next call to
+  // ShouldBypassPostConditions (which VerifyPostConditions makes).
+  bool is_returning_memoized_value_ = false;
+};
+
 // This class is used to store the state of a CREATE_FUNCTION call.
 // It is used to store the state of the function across multiple invocations
 // of the function (e.g. when the function is called recursively).
@@ -123,6 +200,10 @@
     --current_recursion_level_;
   }
 
+  base::Status EnableMemoization() {  // Fails unless this is f(int) -> int.
+    return memoizer_.EnableMemoization(prototype_, return_type_);
+  }
+
   PerfettoSqlEngine* engine() const { return engine_; }
 
   const Prototype& prototype() const { return prototype_; }
@@ -133,6 +214,8 @@
 
   bool is_valid() const { return is_valid_; }
 
+  Memoizer& memoizer() { return memoizer_; }  // Non-const: lookups mutate it.
+
  private:
   PerfettoSqlEngine* engine_;
   Prototype prototype_;
@@ -150,6 +233,7 @@
   // by tracking whether the current function definition is valid (in which case
   // re-registration is not allowed).
   bool is_valid_ = false;
+  Memoizer memoizer_;  // Inactive until EnableMemoization() succeeds.
 };
 
 base::Status CreatedFunction::Run(CreatedFunction::Context* ctx,
@@ -181,6 +265,13 @@
     }
   }
 
+  std::optional<SqlValue> memoized_value =
+      ctx->memoizer().GetMemoizedValue(argc, argv);
+  if (memoized_value) {
+    out = *memoized_value;
+    return base::OkStatus();
+  }
+
   PERFETTO_TP_TRACE(
       metatrace::Category::FUNCTION, "CREATE_FUNCTION",
       [ctx, argv](metatrace::Record* r) {
@@ -220,6 +311,7 @@
   }
   out = sqlite_utils::SqliteValueToSqlValue(
       sqlite3_column_value(ctx->CurrentStatement(), 0));
+  ctx->memoizer().Memoize(argc, argv, out);
 
   // If we return a bytes type but have a null pointer, SQLite will convert this
   // to an SQL null. However, for proto build functions, we actively want to
@@ -233,6 +325,11 @@
 }
 
 base::Status CreatedFunction::VerifyPostConditions(Context* ctx) {
+  // If we returned a memoized value, we don't need to verify post-conditions as
+  // we didn't run a statement.
+  if (ctx->memoizer().ShouldBypassPostConditions()) {
+    return base::OkStatus();
+  }
   int ret = sqlite3_step(ctx->CurrentStatement());
   RETURN_IF_ERROR(SqliteRetToStatus(ctx->engine()->sqlite_engine()->db(),
                                     ctx->prototype().function_name, ret));
@@ -260,11 +357,7 @@
                                  sqlite3_value** argv,
                                  SqlValue&,
                                  Destructors&) {
-  if (argc != 3) {
-    return base::ErrStatus(
-        "CREATE_FUNCTION: invalid number of args; expected %u, received %zu",
-        3u, argc);
-  }
+  RETURN_IF_ERROR(sqlite_utils::CheckArgCount("CREATE_FUNCTION", argc, 3u));
 
   sqlite3_value* prototype_value = argv[0];
   sqlite3_value* return_type_value = argv[1];
@@ -372,5 +465,27 @@
   return ctx->PrepareStatement();
 }
 
+base::Status ExperimentalMemoize::Run(PerfettoSqlEngine* engine,
+                                      size_t argc,
+                                      sqlite3_value** argv,
+                                      SqlValue&,
+                                      Destructors&) {
+  RETURN_IF_ERROR(sqlite_utils::CheckArgCount("EXPERIMENTAL_MEMOIZE", argc, 1));
+  base::StatusOr<std::string> function_name = sqlite_utils::ExtractStringArg(
+      "EXPERIMENTAL_MEMOIZE", "function_name", argv[0]);
+  RETURN_IF_ERROR(function_name.status());
+  // NOTE(review): unchecked cast; confirm built-ins can never be returned.
+  constexpr size_t kSupportedArgCount = 1;
+  CreatedFunction::Context* ctx = static_cast<CreatedFunction::Context*>(
+      engine->sqlite_engine()->GetFunctionContext(function_name->c_str(),
+                                                  kSupportedArgCount));
+  if (!ctx) {
+    return base::ErrStatus(
+        "EXPERIMENTAL_MEMOIZE: Function %s(INT) does not exist",
+        function_name->c_str());
+  }
+  return ctx->EnableMemoization();
+}
+
 }  // namespace trace_processor
 }  // namespace perfetto
diff --git a/src/trace_processor/prelude/functions/create_function.h b/src/trace_processor/prelude/functions/create_function.h
index ed8c06c..612abf1 100644
--- a/src/trace_processor/prelude/functions/create_function.h
+++ b/src/trace_processor/prelude/functions/create_function.h
@@ -44,6 +44,23 @@
                           Destructors&);
 };
 
+// Implementation of the EXPERIMENTAL_MEMOIZE SQL function.
+// SELECT EXPERIMENTAL_MEMOIZE('my_func') enables memoization for the results of
+// the calls to `my_func`. `my_func` must be a Perfetto SQL function created
+// through CREATE_FUNCTION that takes a single integer argument and returns an
+// int.
+struct ExperimentalMemoize : public SqlFunction {
+  using Context = PerfettoSqlEngine;
+
+  static constexpr bool kVoidReturn = true;
+
+  static base::Status Run(Context* ctx,
+                          size_t argc,
+                          sqlite3_value** argv,
+                          SqlValue& out,
+                          Destructors&);
+};
+
 }  // namespace trace_processor
 }  // namespace perfetto
 
diff --git a/src/trace_processor/sqlite/sqlite_utils.cc b/src/trace_processor/sqlite/sqlite_utils.cc
index c4077a5..3afb3ea 100644
--- a/src/trace_processor/sqlite/sqlite_utils.cc
+++ b/src/trace_processor/sqlite/sqlite_utils.cc
@@ -175,6 +175,61 @@
   PERFETTO_FATAL("For GCC");
 }
 
+base::Status CheckArgCount(const char* function_name,
+                           size_t argc,
+                           size_t expected_argc) {
+  if (argc == expected_argc) {
+    return base::OkStatus();
+  }
+  return base::ErrStatus("%s: expected %zu arguments, got %zu", function_name,
+                         expected_argc, argc);
+}
+
+namespace {
+
+// Shared body of the Extract*Arg helpers: reads |sql_value| as a |Raw| and
+// returns it converted to |Out|. Errors start with "function_name(arg_name)".
+template <typename Out, typename Raw = Out>
+base::StatusOr<Out> ExtractTypedArg(const char* function_name,
+                                    const char* arg_name,
+                                    sqlite3_value* sql_value) {
+  SqlValue value = SqliteValueToSqlValue(sql_value);
+  std::optional<Raw> result;
+  base::Status status = ExtractFromSqlValue(value, result);
+  if (!status.ok()) {
+    return base::ErrStatus("%s(%s): %s", function_name, arg_name,
+                           status.message().c_str());
+  }
+  // The argument comes from user SQL, so report a success that produced no
+  // value (presumably SQL NULL -- TODO confirm) instead of crashing on it.
+  if (!result) {
+    return base::ErrStatus("%s(%s): no value provided", function_name,
+                           arg_name);
+  }
+  return Out(*result);
+}
+
+}  // namespace
+
+base::StatusOr<int64_t> ExtractIntArg(const char* function_name,
+                                      const char* arg_name,
+                                      sqlite3_value* sql_value) {
+  return ExtractTypedArg<int64_t>(function_name, arg_name, sql_value);
+}
+
+base::StatusOr<double> ExtractDoubleArg(const char* function_name,
+                                        const char* arg_name,
+                                        sqlite3_value* sql_value) {
+  return ExtractTypedArg<double>(function_name, arg_name, sql_value);
+}
+
+base::StatusOr<std::string> ExtractStringArg(const char* function_name,
+                                             const char* arg_name,
+                                             sqlite3_value* sql_value) {
+  return ExtractTypedArg<std::string, const char*>(function_name, arg_name,
+                                                   sql_value);
+}
+
 base::Status TypeCheckSqliteValue(sqlite3_value* value,
                                   SqlValue::Type expected_type) {
   return TypeCheckSqliteValue(value, expected_type,
diff --git a/src/trace_processor/sqlite/sqlite_utils.h b/src/trace_processor/sqlite/sqlite_utils.h
index ef44c45..d3ccfc5 100644
--- a/src/trace_processor/sqlite/sqlite_utils.h
+++ b/src/trace_processor/sqlite/sqlite_utils.h
@@ -256,6 +256,24 @@
 // This should really only be used for debugging messages.
 const char* SqliteTypeToFriendlyString(SqlValue::Type type);
 
+// Verifies that |argc| equals |expected_argc|; otherwise returns an error
+// naming |function_name| and both the expected and the received count.
+base::Status CheckArgCount(const char* function_name,
+                           size_t argc,
+                           size_t expected_argc);
+
+// Type-safe helpers to extract an arg value from a sqlite3_value*. On failure
+// the returned error message starts with "|function_name|(|arg_name|): ".
+base::StatusOr<int64_t> ExtractIntArg(const char* function_name,
+                                      const char* arg_name,
+                                      sqlite3_value* value);
+base::StatusOr<double> ExtractDoubleArg(const char* function_name,
+                                        const char* arg_name,
+                                        sqlite3_value* value);
+base::StatusOr<std::string> ExtractStringArg(const char* function_name,
+                                             const char* arg_name,
+                                             sqlite3_value* value);
+
 // Verifies if |value| has the type represented by |expected_type|.
 // Returns base::OkStatus if it does or an base::ErrStatus with an
 // appropriate error mesage (incorporating |expected_type_str| if specified).
diff --git a/src/trace_processor/trace_processor_impl.cc b/src/trace_processor/trace_processor_impl.cc
index 800d6d6..7b329e5 100644
--- a/src/trace_processor/trace_processor_impl.cc
+++ b/src/trace_processor/trace_processor_impl.cc
@@ -405,6 +405,8 @@
   RegisterFunction<ToMonotonic>(&engine_, "TO_MONOTONIC", 1,
                                 context_.clock_converter.get());
   RegisterFunction<CreateFunction>(&engine_, "CREATE_FUNCTION", 3, &engine_);
+  RegisterFunction<ExperimentalMemoize>(&engine_, "EXPERIMENTAL_MEMOIZE", 1,
+                                        &engine_);
   RegisterFunction<CreateViewFunction>(
       &engine_, "CREATE_VIEW_FUNCTION", 3,
       std::unique_ptr<CreateViewFunction::Context>(