Function memoization
Add support for the `experimental_memoize` function, which enables memoization
for Perfetto SQL functions that take a single int argument and return an int.
Combined with support for recursive SQL functions, this allows us to write
efficient operations over recursive trees.
Change-Id: I1593330ce1950b502fef5540e4355dcf8081edb4
diff --git a/src/trace_processor/prelude/functions/create_function.cc b/src/trace_processor/prelude/functions/create_function.cc
index 29800fe..1d2f352 100644
--- a/src/trace_processor/prelude/functions/create_function.cc
+++ b/src/trace_processor/prelude/functions/create_function.cc
@@ -59,6 +59,83 @@
static void Cleanup(Context*);
};
+// Remembers the int results of a function keyed by its single int argument, so
+// that a repeated call can be answered from the cache. Nothing is cached or
+// returned until EnableMemoization() has succeeded.
+class Memoizer {
+ public:
+  // Switches memoization on. Fails, without enabling anything, unless the
+  // function takes exactly one int argument and returns an int.
+  base::Status EnableMemoization(const Prototype& prototype,
+                                 sql_argument::Type return_type) {
+    const char* name = prototype.function_name.c_str();
+
+    // The size check must come first: it guards the arguments[0] access.
+    const bool takes_single_int =
+        prototype.arguments.size() == 1 &&
+        TypeToSqlValueType(prototype.arguments[0].type()) ==
+            SqlValue::Type::kLong;
+    if (!takes_single_int) {
+      return base::ErrStatus(
+          "EXPERIMENTAL_MEMOIZE: Function %s should take one int argument",
+          name);
+    }
+
+    const bool returns_int =
+        TypeToSqlValueType(return_type) == SqlValue::Type::kLong;
+    if (!returns_int) {
+      return base::ErrStatus(
+          "EXPERIMENTAL_MEMOIZE: Function %s should return an int", name);
+    }
+
+    enabled_ = true;
+    return base::OkStatus();
+  }
+
+  // Looks up the cached result for the current invocation. Returns nullopt if
+  // the invocation is not memoizable or nothing has been cached for it yet.
+  // On a hit, also records that a memoized value is being returned.
+  std::optional<SqlValue> GetMemoizedValue(size_t argc, sqlite3_value** argv) {
+    const std::optional<int64_t> key = ExtractArgForMemoization(argc, argv);
+    int64_t* cached = key ? memoized_values_.Find(*key) : nullptr;
+    if (cached == nullptr) {
+      return std::nullopt;
+    }
+    is_returning_memoized_value_ = true;
+    return SqlValue::Long(*cached);
+  }
+
+  // Caches |value| as the result of the current invocation. Does nothing if
+  // memoization is disabled, |value| is not an int, or the invocation's
+  // argument is not a single int.
+  void Memoize(size_t argc, sqlite3_value** argv, SqlValue value) {
+    const bool cacheable = enabled_ && value.type == SqlValue::Type::kLong;
+    if (!cacheable) {
+      return;
+    }
+    const std::optional<int64_t> key = ExtractArgForMemoization(argc, argv);
+    if (key.has_value()) {
+      memoized_values_.Insert(*key, value.AsLong());
+    }
+  }
+
+  // Returns true if memoization is enabled and the current invocation was
+  // answered from the cache, in which case there is no statement on which
+  // post-conditions could be checked. Calling this consumes the "returning a
+  // memoized value" marker, so a second call returns false.
+  bool ShouldBypassPostConditions() {
+    const bool bypass = enabled_ && is_returning_memoized_value_;
+    is_returning_memoized_value_ = false;
+    return bypass;
+  }
+
+ private:
+  // Returns the invocation's argument if memoization is enabled and the
+  // invocation has exactly one argument of int type; nullopt otherwise.
+  std::optional<int64_t> ExtractArgForMemoization(size_t argc,
+                                                  sqlite3_value** argv) {
+    if (enabled_ && argc == 1) {
+      SqlValue only_arg = sqlite_utils::SqliteValueToSqlValue(argv[0]);
+      if (only_arg.type == SqlValue::Type::kLong) {
+        return only_arg.AsLong();
+      }
+    }
+    return std::nullopt;
+  }
+
+  // Set by a successful EnableMemoization() call; never cleared afterwards.
+  bool enabled_ = false;
+  // Maps an argument value to the result cached for it.
+  base::FlatHashMap<int64_t, int64_t> memoized_values_;
+  // Used to skip post-conditions when a memoized value is being returned. Set
+  // by a cache hit in GetMemoizedValue() and cleared again by the next call to
+  // ShouldBypassPostConditions().
+  bool is_returning_memoized_value_ = false;
+};
+
// This class is used to store the state of a CREATE_FUNCTION call.
// It is used to store the state of the function across multiple invocations
// of the function (e.g. when the function is called recursively).
@@ -123,6 +200,10 @@
--current_recursion_level_;
}
+  // Enables memoization for this function. Returns the Memoizer's error if
+  // this function's prototype or return type is not supported.
+  base::Status EnableMemoization() {
+    return memoizer_.EnableMemoization(prototype_, return_type_);
+  }
+
PerfettoSqlEngine* engine() const { return engine_; }
const Prototype& prototype() const { return prototype_; }
@@ -133,6 +214,8 @@
bool is_valid() const { return is_valid_; }
+  // Mutable access to the Memoizer held by this context.
+  Memoizer& memoizer() { return memoizer_; }
+
private:
PerfettoSqlEngine* engine_;
Prototype prototype_;
@@ -150,6 +233,7 @@
// by tracking whether the current function definition is valid (in which case
// re-registration is not allowed).
bool is_valid_ = false;
+ Memoizer memoizer_;
};
base::Status CreatedFunction::Run(CreatedFunction::Context* ctx,
@@ -181,6 +265,13 @@
}
}
+ std::optional<SqlValue> memoized_value =
+ ctx->memoizer().GetMemoizedValue(argc, argv);
+ if (memoized_value) {
+ out = *memoized_value;
+ return base::OkStatus();
+ }
+
PERFETTO_TP_TRACE(
metatrace::Category::FUNCTION, "CREATE_FUNCTION",
[ctx, argv](metatrace::Record* r) {
@@ -220,6 +311,7 @@
}
out = sqlite_utils::SqliteValueToSqlValue(
sqlite3_column_value(ctx->CurrentStatement(), 0));
+ ctx->memoizer().Memoize(argc, argv, out);
// If we return a bytes type but have a null pointer, SQLite will convert this
// to an SQL null. However, for proto build functions, we actively want to
@@ -233,6 +325,11 @@
}
base::Status CreatedFunction::VerifyPostConditions(Context* ctx) {
+ // If we returned a memoized value, we don't need to verify post-conditions as
+ // we didn't run a statement.
+ if (ctx->memoizer().ShouldBypassPostConditions()) {
+ return base::OkStatus();
+ }
int ret = sqlite3_step(ctx->CurrentStatement());
RETURN_IF_ERROR(SqliteRetToStatus(ctx->engine()->sqlite_engine()->db(),
ctx->prototype().function_name, ret));
@@ -260,11 +357,7 @@
sqlite3_value** argv,
SqlValue&,
Destructors&) {
- if (argc != 3) {
- return base::ErrStatus(
- "CREATE_FUNCTION: invalid number of args; expected %u, received %zu",
- 3u, argc);
- }
+ RETURN_IF_ERROR(sqlite_utils::CheckArgCount("CREATE_FUNCTION", argc, 3u));
sqlite3_value* prototype_value = argv[0];
sqlite3_value* return_type_value = argv[1];
@@ -372,5 +465,27 @@
return ctx->PrepareStatement();
}
+// Implements EXPERIMENTAL_MEMOIZE(function_name): looks up the one-argument
+// function registered under |function_name| and enables memoization on it.
+// Returns an error if the argument count is wrong, the name cannot be
+// extracted as a string, no such function is registered, or enabling
+// memoization on the function fails.
+base::Status ExperimentalMemoize::Run(PerfettoSqlEngine* engine,
+                                      size_t argc,
+                                      sqlite3_value** argv,
+                                      SqlValue&,
+                                      Destructors&) {
+  RETURN_IF_ERROR(sqlite_utils::CheckArgCount("EXPERIMENTAL_MEMOIZE", argc, 1));
+  // Report extraction failures under the full SQL function name, matching the
+  // other errors produced by this function.
+  base::StatusOr<std::string> function_name = sqlite_utils::ExtractStringArg(
+      "EXPERIMENTAL_MEMOIZE", "function_name", argv[0]);
+  RETURN_IF_ERROR(function_name.status());
+
+  // Only single-argument functions can be memoized, so only look for those.
+  constexpr size_t kSupportedArgCount = 1;
+  // NOTE(review): the cast assumes that whatever is registered under this name
+  // was created through CREATE_FUNCTION - confirm that GetFunctionContext
+  // cannot hand back the context of an unrelated one-argument function.
+  CreatedFunction::Context* ctx = static_cast<CreatedFunction::Context*>(
+      engine->sqlite_engine()->GetFunctionContext(function_name->c_str(),
+                                                  kSupportedArgCount));
+  if (!ctx) {
+    return base::ErrStatus(
+        "EXPERIMENTAL_MEMOIZE: Function %s(INT) does not exist",
+        function_name->c_str());
+  }
+  return ctx->EnableMemoization();
+}
+
} // namespace trace_processor
} // namespace perfetto
diff --git a/src/trace_processor/prelude/functions/create_function.h b/src/trace_processor/prelude/functions/create_function.h
index ed8c06c..612abf1 100644
--- a/src/trace_processor/prelude/functions/create_function.h
+++ b/src/trace_processor/prelude/functions/create_function.h
@@ -44,6 +44,23 @@
Destructors&);
};
+// Implementation of the EXPERIMENTAL_MEMOIZE SQL function.
+// SELECT EXPERIMENTAL_MEMOIZE('my_func') enables memoization for the results of
+// the calls to `my_func`. `my_func` must be a Perfetto SQL function created
+// through CREATE_FUNCTION that takes a single integer argument and returns an
+// int.
+struct ExperimentalMemoize : public SqlFunction {
+  using Context = PerfettoSqlEngine;
+
+  static constexpr bool kVoidReturn = true;
+
+  static base::Status Run(Context* ctx,
+                          size_t argc,
+                          sqlite3_value** argv,
+                          SqlValue& out,
+                          Destructors&);
+};
+
} // namespace trace_processor
} // namespace perfetto
diff --git a/src/trace_processor/sqlite/sqlite_utils.cc b/src/trace_processor/sqlite/sqlite_utils.cc
index c4077a5..3afb3ea 100644
--- a/src/trace_processor/sqlite/sqlite_utils.cc
+++ b/src/trace_processor/sqlite/sqlite_utils.cc
@@ -175,6 +175,61 @@
PERFETTO_FATAL("For GCC");
}
+// Returns OK when |argc| equals |expected_argc|; otherwise returns an error
+// that names |function_name| together with the expected and actual counts.
+base::Status CheckArgCount(const char* function_name,
+                           size_t argc,
+                           size_t expected_argc) {
+  if (argc != expected_argc) {
+    return base::ErrStatus("%s: expected %zu arguments, got %zu",
+                           function_name, expected_argc, argc);
+  }
+  return base::OkStatus();
+}
+
+// Extracts an int64 from |sql_value|, the argument |arg_name| of
+// |function_name|. Returns an error prefixed with "function_name(arg_name): "
+// if extraction fails or produces no value.
+base::StatusOr<int64_t> ExtractIntArg(const char* function_name,
+                                      const char* arg_name,
+                                      sqlite3_value* sql_value) {
+  SqlValue value = SqliteValueToSqlValue(sql_value);
+  std::optional<int64_t> result;
+
+  base::Status status = ExtractFromSqlValue(value, result);
+  if (!status.ok()) {
+    return base::ErrStatus("%s(%s): %s", function_name, arg_name,
+                           status.message().c_str());
+  }
+  // The value originates from a SQL call, so an empty optional is reported as
+  // an error to the caller rather than asserted on.
+  // NOTE(review): presumably ExtractFromSqlValue leaves |result| empty (with
+  // an OK status) for a SQL NULL - confirm.
+  if (!result) {
+    return base::ErrStatus("%s(%s): argument must not be NULL", function_name,
+                           arg_name);
+  }
+  return *result;
+}
+
+// Extracts a double from |sql_value|, the argument |arg_name| of
+// |function_name|. Returns an error prefixed with "function_name(arg_name): "
+// if extraction fails or produces no value.
+base::StatusOr<double> ExtractDoubleArg(const char* function_name,
+                                        const char* arg_name,
+                                        sqlite3_value* sql_value) {
+  SqlValue value = SqliteValueToSqlValue(sql_value);
+  std::optional<double> result;
+
+  base::Status status = ExtractFromSqlValue(value, result);
+  if (!status.ok()) {
+    return base::ErrStatus("%s(%s): %s", function_name, arg_name,
+                           status.message().c_str());
+  }
+  // The value originates from a SQL call, so an empty optional is reported as
+  // an error to the caller rather than asserted on.
+  // NOTE(review): presumably ExtractFromSqlValue leaves |result| empty (with
+  // an OK status) for a SQL NULL - confirm.
+  if (!result) {
+    return base::ErrStatus("%s(%s): argument must not be NULL", function_name,
+                           arg_name);
+  }
+  return *result;
+}
+
+// Extracts a string from |sql_value|, the argument |arg_name| of
+// |function_name|, and returns it as an owned std::string. Returns an error
+// prefixed with "function_name(arg_name): " if extraction fails or produces
+// no value.
+base::StatusOr<std::string> ExtractStringArg(const char* function_name,
+                                             const char* arg_name,
+                                             sqlite3_value* sql_value) {
+  SqlValue value = SqliteValueToSqlValue(sql_value);
+  std::optional<const char*> result;
+
+  base::Status status = ExtractFromSqlValue(value, result);
+  if (!status.ok()) {
+    return base::ErrStatus("%s(%s): %s", function_name, arg_name,
+                           status.message().c_str());
+  }
+  // The value originates from a SQL call, so an empty optional is reported as
+  // an error to the caller rather than asserted on.
+  // NOTE(review): presumably ExtractFromSqlValue leaves |result| empty (with
+  // an OK status) for a SQL NULL - confirm.
+  if (!result) {
+    return base::ErrStatus("%s(%s): argument must not be NULL", function_name,
+                           arg_name);
+  }
+  return std::string(*result);
+}
+
base::Status TypeCheckSqliteValue(sqlite3_value* value,
SqlValue::Type expected_type) {
return TypeCheckSqliteValue(value, expected_type,
diff --git a/src/trace_processor/sqlite/sqlite_utils.h b/src/trace_processor/sqlite/sqlite_utils.h
index ef44c45..d3ccfc5 100644
--- a/src/trace_processor/sqlite/sqlite_utils.h
+++ b/src/trace_processor/sqlite/sqlite_utils.h
@@ -256,6 +256,24 @@
// This should really only be used for debugging messages.
const char* SqliteTypeToFriendlyString(SqlValue::Type type);
+// Verifies that |argc| matches |expected_argc|. Returns base::OkStatus if it
+// does, or a base::ErrStatus naming |function_name| and both counts if not.
+base::Status CheckArgCount(const char* function_name,
+                           size_t argc,
+                           size_t expected_argc);
+
+// Type-safe helpers to extract an arg value from a sqlite3_value*. On failure
+// they return an error message prefixed with "function_name(arg_name): ".
+base::StatusOr<int64_t> ExtractIntArg(const char* function_name,
+                                      const char* arg_name,
+                                      sqlite3_value* value);
+base::StatusOr<double> ExtractDoubleArg(const char* function_name,
+                                        const char* arg_name,
+                                        sqlite3_value* value);
+base::StatusOr<std::string> ExtractStringArg(const char* function_name,
+                                             const char* arg_name,
+                                             sqlite3_value* value);
+
// Verifies if |value| has the type represented by |expected_type|.
// Returns base::OkStatus if it does or an base::ErrStatus with an
// appropriate error mesage (incorporating |expected_type_str| if specified).
diff --git a/src/trace_processor/trace_processor_impl.cc b/src/trace_processor/trace_processor_impl.cc
index 800d6d6..7b329e5 100644
--- a/src/trace_processor/trace_processor_impl.cc
+++ b/src/trace_processor/trace_processor_impl.cc
@@ -405,6 +405,8 @@
RegisterFunction<ToMonotonic>(&engine_, "TO_MONOTONIC", 1,
context_.clock_converter.get());
RegisterFunction<CreateFunction>(&engine_, "CREATE_FUNCTION", 3, &engine_);
+ RegisterFunction<ExperimentalMemoize>(&engine_, "EXPERIMENTAL_MEMOIZE", 1,
+ &engine_);
RegisterFunction<CreateViewFunction>(
&engine_, "CREATE_VIEW_FUNCTION", 3,
std::unique_ptr<CreateViewFunction::Context>(