tp: decouple function creation code from CREATE_FUNCTION function
This CL moves all the relevant code in CREATE_FUNCTION which actually
involves creating a function to the engine instead. This will allow
implementation of the CREATE PERFETTO FUNCTION syntax in a followup CL
without needing to execute any SQL.
Change-Id: Ic684ee2f6363e5e4578c4f94956b3a056b19d202
diff --git a/Android.bp b/Android.bp
index ca76452..0b72856 100644
--- a/Android.bp
+++ b/Android.bp
@@ -10232,6 +10232,8 @@
filegroup {
name: "perfetto_src_trace_processor_perfetto_sql_engine_engine",
srcs: [
+ "src/trace_processor/perfetto_sql/engine/created_function.cc",
+ "src/trace_processor/perfetto_sql/engine/function_util.cc",
"src/trace_processor/perfetto_sql/engine/perfetto_sql_engine.cc",
"src/trace_processor/perfetto_sql/engine/perfetto_sql_parser.cc",
],
@@ -10242,7 +10244,6 @@
name: "perfetto_src_trace_processor_perfetto_sql_intrinsics_functions_functions",
srcs: [
"src/trace_processor/perfetto_sql/intrinsics/functions/create_function.cc",
- "src/trace_processor/perfetto_sql/intrinsics/functions/create_function_internal.cc",
"src/trace_processor/perfetto_sql/intrinsics/functions/create_view_function.cc",
"src/trace_processor/perfetto_sql/intrinsics/functions/import.cc",
"src/trace_processor/perfetto_sql/intrinsics/functions/layout_functions.cc",
diff --git a/BUILD b/BUILD
index 9c369ca..17fc83d 100644
--- a/BUILD
+++ b/BUILD
@@ -2038,6 +2038,10 @@
perfetto_filegroup(
name = "src_trace_processor_perfetto_sql_engine_engine",
srcs = [
+ "src/trace_processor/perfetto_sql/engine/created_function.cc",
+ "src/trace_processor/perfetto_sql/engine/created_function.h",
+ "src/trace_processor/perfetto_sql/engine/function_util.cc",
+ "src/trace_processor/perfetto_sql/engine/function_util.h",
"src/trace_processor/perfetto_sql/engine/perfetto_sql_engine.cc",
"src/trace_processor/perfetto_sql/engine/perfetto_sql_engine.h",
"src/trace_processor/perfetto_sql/engine/perfetto_sql_parser.cc",
@@ -2052,8 +2056,6 @@
"src/trace_processor/perfetto_sql/intrinsics/functions/clock_functions.h",
"src/trace_processor/perfetto_sql/intrinsics/functions/create_function.cc",
"src/trace_processor/perfetto_sql/intrinsics/functions/create_function.h",
- "src/trace_processor/perfetto_sql/intrinsics/functions/create_function_internal.cc",
- "src/trace_processor/perfetto_sql/intrinsics/functions/create_function_internal.h",
"src/trace_processor/perfetto_sql/intrinsics/functions/create_view_function.cc",
"src/trace_processor/perfetto_sql/intrinsics/functions/create_view_function.h",
"src/trace_processor/perfetto_sql/intrinsics/functions/import.cc",
diff --git a/src/trace_processor/perfetto_sql/engine/BUILD.gn b/src/trace_processor/perfetto_sql/engine/BUILD.gn
index d3c926a..20a7110 100644
--- a/src/trace_processor/perfetto_sql/engine/BUILD.gn
+++ b/src/trace_processor/perfetto_sql/engine/BUILD.gn
@@ -18,6 +18,10 @@
source_set("engine") {
sources = [
+ "created_function.cc",
+ "created_function.h",
+ "function_util.cc",
+ "function_util.h",
"perfetto_sql_engine.cc",
"perfetto_sql_engine.h",
"perfetto_sql_parser.cc",
@@ -31,7 +35,9 @@
"../../perfetto_sql/intrinsics/functions:interface",
"../../perfetto_sql/intrinsics/table_functions:interface",
"../../sqlite",
+ "../../types",
"../../util",
+ "../../util:sql_argument",
]
}
diff --git a/src/trace_processor/perfetto_sql/engine/created_function.cc b/src/trace_processor/perfetto_sql/engine/created_function.cc
new file mode 100644
index 0000000..afce971
--- /dev/null
+++ b/src/trace_processor/perfetto_sql/engine/created_function.cc
@@ -0,0 +1,738 @@
+/*
+ * Copyright (C) 2019 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/trace_processor/perfetto_sql/engine/created_function.h"
+
+#include <queue>
+#include <stack>
+
+#include "perfetto/base/status.h"
+#include "src/trace_processor/perfetto_sql/engine/function_util.h"
+#include "src/trace_processor/perfetto_sql/engine/perfetto_sql_engine.h"
+#include "src/trace_processor/sqlite/scoped_db.h"
+#include "src/trace_processor/sqlite/sql_source.h"
+#include "src/trace_processor/sqlite/sqlite_engine.h"
+#include "src/trace_processor/sqlite/sqlite_utils.h"
+#include "src/trace_processor/tp_metatrace.h"
+#include "src/trace_processor/util/status_macros.h"
+
+namespace perfetto {
+namespace trace_processor {
+
+namespace {
+
+base::StatusOr<SqliteEngine::PreparedStatement> CreateStatement(
+ PerfettoSqlEngine* engine,
+ const std::string& sql,
+ const std::string& prototype) {
+ auto res = engine->sqlite_engine()->PrepareStatement(
+ SqlSource::FromFunction(sql.c_str(), prototype));
+ RETURN_IF_ERROR(res.status());
+ return std::move(res.value());
+}
+
+base::Status CheckNoMoreRows(sqlite3_stmt* stmt,
+ sqlite3* db,
+ const Prototype& prototype) {
+ int ret = sqlite3_step(stmt);
+ RETURN_IF_ERROR(SqliteRetToStatus(db, prototype.function_name, ret));
+ if (ret == SQLITE_ROW) {
+ auto expanded_sql = sqlite_utils::ExpandedSqlForStmt(stmt);
+ return base::ErrStatus(
+ "%s: multiple values were returned when executing function body. "
+ "Executed SQL was %s",
+ prototype.function_name.c_str(), expanded_sql.get());
+ }
+ PERFETTO_DCHECK(ret == SQLITE_DONE);
+ return base::OkStatus();
+}
+
+// Note: if the returned type is string / bytes, it will be invalidated by the
+// next call to SQLite, so the caller must take care to either copy or use the
+// value before calling SQLite again.
+base::StatusOr<SqlValue> EvaluateScalarStatement(sqlite3_stmt* stmt,
+ sqlite3* db,
+ const Prototype& prototype) {
+ int ret = sqlite3_step(stmt);
+ RETURN_IF_ERROR(SqliteRetToStatus(db, prototype.function_name, ret));
+ if (ret == SQLITE_DONE) {
+ // No return value means we just return don't set |out|.
+ return SqlValue();
+ }
+
+ PERFETTO_DCHECK(ret == SQLITE_ROW);
+ size_t col_count = static_cast<size_t>(sqlite3_column_count(stmt));
+ if (col_count != 1) {
+ return base::ErrStatus(
+ "%s: SQL definition should only return one column: returned %zu "
+ "columns",
+ prototype.function_name.c_str(), col_count);
+ }
+
+ SqlValue result =
+ sqlite_utils::SqliteValueToSqlValue(sqlite3_column_value(stmt, 0));
+
+ // If we return a bytes type but have a null pointer, SQLite will convert this
+ // to an SQL null. However, for proto build functions, we actively want to
+ // distinguish between nulls and 0 byte strings. Therefore, change the value
+ // to an empty string.
+ if (result.type == SqlValue::kBytes && result.bytes_value == nullptr) {
+ PERFETTO_DCHECK(result.bytes_count == 0);
+ result.bytes_value = "";
+ }
+
+ return result;
+}
+
+base::Status BindArguments(sqlite3_stmt* stmt,
+ const Prototype& prototype,
+ size_t argc,
+ sqlite3_value** argv) {
+ // Bind all the arguments to the appropriate places in the function.
+ for (size_t i = 0; i < argc; ++i) {
+ RETURN_IF_ERROR(MaybeBindArgument(stmt, prototype.function_name,
+ prototype.arguments[i], argv[i]));
+ }
+ return base::OkStatus();
+}
+
+struct StoredSqlValue {
+ // unique_ptr to ensure that the pointers to these values are long-lived.
+ using OwnedString = std::unique_ptr<std::string>;
+ using OwnedBytes = std::unique_ptr<std::vector<uint8_t>>;
+ // variant is a pain to use, but it's the simplest way to ensure that
+ // the destructors run correctly for non-trivial members of the
+ // union.
+ using Data =
+ std::variant<int64_t, double, OwnedString, OwnedBytes, nullptr_t>;
+
+ StoredSqlValue(SqlValue value) {
+ switch (value.type) {
+ case SqlValue::Type::kNull:
+ data = nullptr;
+ break;
+ case SqlValue::Type::kLong:
+ data = value.long_value;
+ break;
+ case SqlValue::Type::kDouble:
+ data = value.double_value;
+ break;
+ case SqlValue::Type::kString:
+ data = std::make_unique<std::string>(value.string_value);
+ break;
+ case SqlValue::Type::kBytes:
+ const uint8_t* ptr = static_cast<const uint8_t*>(value.bytes_value);
+ data = std::make_unique<std::vector<uint8_t>>(ptr,
+ ptr + value.bytes_count);
+ break;
+ }
+ }
+
+ SqlValue AsSqlValue() {
+ if (std::holds_alternative<nullptr_t>(data)) {
+ return SqlValue();
+ } else if (std::holds_alternative<int64_t>(data)) {
+ return SqlValue::Long(std::get<int64_t>(data));
+ } else if (std::holds_alternative<double>(data)) {
+ return SqlValue::Double(std::get<double>(data));
+ } else if (std::holds_alternative<OwnedString>(data)) {
+ const auto& str_ptr = std::get<OwnedString>(data);
+ return SqlValue::String(str_ptr->c_str());
+ } else if (std::holds_alternative<OwnedBytes>(data)) {
+ const auto& bytes_ptr = std::get<OwnedBytes>(data);
+ return SqlValue::Bytes(bytes_ptr->data(), bytes_ptr->size());
+ }
+ // GCC doesn't realize that the switch is exhaustive.
+ PERFETTO_CHECK(false);
+ return SqlValue();
+ }
+
+ Data data = nullptr;
+};
+
+class Memoizer {
+ public:
+ // Supported arguments. For now, only functions with a single int argument are
+ // supported.
+ using MemoizedArgs = int64_t;
+
+ // Enables memoization.
+ // Only functions with a single int argument returning ints are supported.
+ base::Status EnableMemoization(const Prototype& prototype) {
+ if (prototype.arguments.size() != 1 ||
+ TypeToSqlValueType(prototype.arguments[0].type()) !=
+ SqlValue::Type::kLong) {
+ return base::ErrStatus(
+ "EXPERIMENTAL_MEMOIZE: Function %s should take one int argument",
+ prototype.function_name.c_str());
+ }
+ enabled_ = true;
+ return base::OkStatus();
+ }
+
+ // Returns the memoized value for the current invocation if it exists.
+ std::optional<SqlValue> GetMemoizedValue(MemoizedArgs args) {
+ if (!enabled_) {
+ return std::nullopt;
+ }
+ StoredSqlValue* value = memoized_values_.Find(args);
+ if (!value) {
+ return std::nullopt;
+ }
+ return value->AsSqlValue();
+ }
+
+ bool HasMemoizedValue(MemoizedArgs args) {
+ return GetMemoizedValue(args).has_value();
+ }
+
+ // Saves the return value of the current invocation for memoization.
+ void Memoize(MemoizedArgs args, SqlValue value) {
+ if (!enabled_) {
+ return;
+ }
+ memoized_values_.Insert(args, StoredSqlValue(value));
+ }
+
+ // Checks that the function has a single int argument and returns it.
+ static std::optional<MemoizedArgs> AsMemoizedArgs(size_t argc,
+ sqlite3_value** argv) {
+ if (argc != 1) {
+ return std::nullopt;
+ }
+ SqlValue arg = sqlite_utils::SqliteValueToSqlValue(argv[0]);
+ if (arg.type != SqlValue::Type::kLong) {
+ return std::nullopt;
+ }
+ return arg.AsLong();
+ }
+
+ bool enabled() const { return enabled_; }
+
+ private:
+ bool enabled_ = false;
+ base::FlatHashMap<MemoizedArgs, StoredSqlValue> memoized_values_;
+};
+
+// A helper to unroll recursive calls: to minimise the amount of stack space
+// used, memoized recursive calls are evaluated using an on-heap queue.
+//
+// We compute the function in two passes:
+// - In the first pass, we evaluate the statement to discover which recursive
+// calls it makes, returning null from recursive calls and ignoring the
+// result.
+// - In the second pass, we evaluate the statement again, but this time we
+// memoize the result of each recursive call.
+//
+// We maintain a queue for scheduled "first pass" calls and a stack for the
+// scheduled "second pass" calls, evaluating available first pass calls, then
+// second pass calls. When we evaluate a first pass call, the further calls to
+// CreatedFunction::Run will just add it to the "first pass" queue. The second
+// pass, however, will evaluate the function normally, typically just using the
+// memoized result for the dependent calls. However, if the recursive calls
+// depend on the return value of the function, we will proceed with normal
+// recursion.
+//
+// To make it more concrete, consider an following example.
+// We have a function computing factorial (f) and we want to compute f(3).
+//
+// SELECT create_function('f(x INT)', 'INT',
+// 'SELECT IIF($x = 0, 1, $x * f($x - 1))');
+// SELECT experimental_memoize('f');
+// SELECT f(3);
+//
+// - We start with a call to f(3). It executes the statement as normal, which
+// recursively calls f(2).
+// - When f(2) is called, we detect that it is a recursive call and we start
+// unrolling it, entering RecursiveCallUnroller::Run.
+// - We schedule first pass for 2 and the state of the unroller
+// is first_pass: [2], second_pass: [].
+// - Then we compute the first pass for f(2). It calls f(1), which is ignored
+// due to OnFunctionCall returning kIgnoreDueToFirstPass and 1 is added to the
+// first pass queue. 2 is taked out of the first pass queue and moved to the
+// second pass stack. State: first_pass: [1], second_pass: [2].
+// - Then we compute the first pass for 1. The similar thing happens: f(0) is
+// called and ignored, 0 is added to first_pass, 1 is added to second_pass.
+// State: first_pass: [0], second_pass: [2, 1].
+// - Then we compute the first pass for 0. It doesn't make further calls, so
+// 0 is moved to the second pass stack.
+// State: first_pass: [], second_pass: [2, 1, 0].
+// - Then we compute the second pass for 0. It just returns 1.
+// State: first_pass: [], second_pass: [2, 1], results: {0: 1}.
+// - Then we compute the second pass for 1. It calls f(0), which is memoized.
+// State: first_pass: [], second_pass: [2], results: {0: 1, 1: 1}.
+// - Then we compute the second pass for 1. It calls f(1), which is memoized.
+// State: first_pass: [], second_pass: [], results: {0: 1, 1: 1, 2: 2}.
+// - As both first_pass and second_pass are empty, we return from
+// RecursiveCallUnroller::Run.
+// - Control is returned to CreatedFunction::Run for f(2), which returns
+// memoized value.
+// - Then control is returned to CreatedFunction::Run for f(3), which completes
+// the computation.
+class RecursiveCallUnroller {
+ public:
+ RecursiveCallUnroller(PerfettoSqlEngine* engine,
+ sqlite3_stmt* stmt,
+ const Prototype& prototype,
+ Memoizer& memoizer)
+ : engine_(engine),
+ stmt_(stmt),
+ prototype_(prototype),
+ memoizer_(memoizer) {}
+
+ // Whether we should just return null due to us being in the "first pass".
+ enum class FunctionCallState {
+ kIgnoreDueToFirstPass,
+ kEvaluate,
+ };
+
+ base::StatusOr<FunctionCallState> OnFunctionCall(
+ Memoizer::MemoizedArgs args) {
+ // If we are in the second pass, we just continue the function execution,
+ // including checking if a memoized value is available and returning it.
+ //
+ // We generally expect a memoized value to be available, but there are
+ // cases when it might not be the case, e.g. when which recursive calls are
+ // made depends on the return value of the function, e.g. for the following
+ // function, the first pass will not detect f(y) calls, so they will
+ // be computed recursively.
+ // f(x): SELECT max(f(y)) FROM y WHERE y < f($x - 1);
+ if (state_ == State::kComputingSecondPass) {
+ return FunctionCallState::kEvaluate;
+ }
+ if (!memoizer_.HasMemoizedValue(args)) {
+ ArgState* state = visited_.Find(args);
+ if (state) {
+ // Detect recursive loops, e.g. f(1) calling f(2) calling f(1).
+ if (*state == ArgState::kEvaluating) {
+ return base::ErrStatus("Infinite recursion detected");
+ }
+ } else {
+ visited_.Insert(args, ArgState::kScheduled);
+ first_pass_.push(args);
+ }
+ }
+ return FunctionCallState::kIgnoreDueToFirstPass;
+ }
+
+ base::Status Run(Memoizer::MemoizedArgs initial_args) {
+ PERFETTO_TP_TRACE(metatrace::Category::FUNCTION,
+ "UNROLL_RECURSIVE_FUNCTION_CALL",
+ [&](metatrace::Record* r) {
+ r->AddArg("Function", prototype_.function_name);
+ r->AddArg("Arg 0", std::to_string(initial_args));
+ });
+
+ first_pass_.push(initial_args);
+ visited_.Insert(initial_args, ArgState::kScheduled);
+
+ while (!first_pass_.empty() || !second_pass_.empty()) {
+ // If we have scheduled first pass calls, we evaluate them first.
+ if (!first_pass_.empty()) {
+ state_ = State::kComputingFirstPass;
+ Memoizer::MemoizedArgs args = first_pass_.front();
+
+ PERFETTO_TP_TRACE(metatrace::Category::FUNCTION, "SQL_FUNCTION_CALL",
+ [&](metatrace::Record* r) {
+ r->AddArg("Function", prototype_.function_name);
+ r->AddArg("Type", "UnrollRecursiveCall_FirstPass");
+ r->AddArg("Arg 0", std::to_string(args));
+ });
+
+ first_pass_.pop();
+ second_pass_.push(args);
+ Evaluate(args).status();
+ continue;
+ }
+
+ state_ = State::kComputingSecondPass;
+ Memoizer::MemoizedArgs args = second_pass_.top();
+
+ PERFETTO_TP_TRACE(metatrace::Category::FUNCTION, "SQL_FUNCTION_CALL",
+ [&](metatrace::Record* r) {
+ r->AddArg("Function", prototype_.function_name);
+ r->AddArg("Type", "UnrollRecursiveCall_SecondPass");
+ r->AddArg("Arg 0", std::to_string(args));
+ });
+
+ visited_.Insert(args, ArgState::kEvaluating);
+ second_pass_.pop();
+ base::StatusOr<std::optional<int64_t>> result = Evaluate(args);
+ RETURN_IF_ERROR(result.status());
+ std::optional<int64_t> maybe_int_result = result.value();
+ if (!maybe_int_result.has_value()) {
+ continue;
+ }
+ visited_.Insert(args, ArgState::kEvaluated);
+ memoizer_.Memoize(args, SqlValue::Long(*maybe_int_result));
+ }
+ return base::OkStatus();
+ }
+
+ private:
+ // This function returns:
+ // - base::ErrStatus if the evaluation of the function failed.
+ // - std::nullopt if the function returned a non-integer value.
+ // - the result of the function otherwise.
+ base::StatusOr<std::optional<int64_t>> Evaluate(Memoizer::MemoizedArgs args) {
+ RETURN_IF_ERROR(MaybeBindIntArgument(stmt_, prototype_.function_name,
+ prototype_.arguments[0], args));
+ base::StatusOr<SqlValue> result = EvaluateScalarStatement(
+ stmt_, engine_->sqlite_engine()->db(), prototype_);
+ sqlite3_reset(stmt_);
+ sqlite3_clear_bindings(stmt_);
+ RETURN_IF_ERROR(result.status());
+ if (result->type != SqlValue::Type::kLong) {
+ return std::optional<int64_t>(std::nullopt);
+ }
+ return std::optional<int64_t>(result->long_value);
+ }
+
+ PerfettoSqlEngine* engine_;
+ sqlite3_stmt* stmt_;
+ const Prototype& prototype_;
+ Memoizer& memoizer_;
+
+ // Current state of the evaluation.
+ enum class State {
+ kComputingFirstPass,
+ kComputingSecondPass,
+ };
+ State state_ = State::kComputingFirstPass;
+
+ // A state of evaluation of a given argument.
+ enum class ArgState {
+ kScheduled,
+ kEvaluating,
+ kEvaluated,
+ };
+
+ // See the class-level comment for the explanation of the two passes.
+ std::queue<Memoizer::MemoizedArgs> first_pass_;
+ base::FlatHashMap<Memoizer::MemoizedArgs, ArgState> visited_;
+ std::stack<Memoizer::MemoizedArgs> second_pass_;
+};
+
+} // namespace
+
+// This class is used to store the state of a CREATE_FUNCTION call.
+// It is used to store the state of the function across multiple invocations
+// of the function (e.g. when the function is called recursively).
+class State : public CreatedFunction::Context {
+ public:
+ explicit State(PerfettoSqlEngine* engine) : engine_(engine) {}
+ ~State() override;
+
+ // Prepare a statement and push it into the stack of allocated statements
+ // for this function.
+ base::Status PrepareStatement() {
+ base::StatusOr<SqliteEngine::PreparedStatement> stmt =
+ CreateStatement(engine_, sql_, prototype_str_);
+ RETURN_IF_ERROR(stmt.status());
+ is_valid_ = true;
+ stmts_.push_back(std::move(stmt.value()));
+ return base::OkStatus();
+ }
+
+ // Sets the state of the function. Should be called only when the function
+ // is invalid (i.e. when it is first created or when the previous statement
+ // failed to prepare).
+ void Reset(Prototype prototype,
+ std::string prototype_str,
+ sql_argument::Type return_type,
+ std::string sql) {
+ // Re-registration of valid functions is not allowed.
+ PERFETTO_DCHECK(!is_valid_);
+ PERFETTO_DCHECK(stmts_.empty());
+
+ prototype_ = std::move(prototype);
+ prototype_str_ = std::move(prototype_str);
+ return_type_ = return_type;
+ sql_ = std::move(sql);
+ }
+
+ // This function is called each time the function is called.
+ // It ensures that we have a statement for the current recursion level,
+ // allocating a new one if needed.
+ base::Status PushStackEntry() {
+ ++current_recursion_level_;
+ if (current_recursion_level_ > stmts_.size()) {
+ return PrepareStatement();
+ }
+ return base::OkStatus();
+ }
+
+ // Returns the statement that is used for the current invocation.
+ sqlite3_stmt* CurrentStatement() {
+ return stmts_[current_recursion_level_ - 1].sqlite_stmt();
+ }
+
+ // This function is called each time the function returns and resets the
+ // statement that this invocation used.
+ void PopStackEntry() {
+ if (current_recursion_level_ > stmts_.size()) {
+ // This is possible if we didn't prepare the statement and returned
+ // an error.
+ return;
+ }
+ sqlite3_reset(CurrentStatement());
+ sqlite3_clear_bindings(CurrentStatement());
+ --current_recursion_level_;
+ }
+
+ base::StatusOr<RecursiveCallUnroller::FunctionCallState> OnFunctionCall(
+ Memoizer::MemoizedArgs args) {
+ if (!recursive_call_unroller_) {
+ return RecursiveCallUnroller::FunctionCallState::kEvaluate;
+ }
+ return recursive_call_unroller_->OnFunctionCall(args);
+ }
+
+ // Called before checking the function for memoization.
+ base::Status UnrollRecursiveCallIfNeeded(Memoizer::MemoizedArgs args) {
+ if (!memoizer_.enabled() || !is_in_recursive_call() ||
+ recursive_call_unroller_) {
+ return base::OkStatus();
+ }
+ // If we are in a recursive call, we need to check if we have already
+ // computed the result for the current arguments.
+ if (memoizer_.HasMemoizedValue(args)) {
+ return base::OkStatus();
+ }
+
+ // If we are in a beginning of a function call:
+ // - is a recursive,
+ // - can be memoized,
+ // - hasn't been memoized already, and
+ // - hasn't start unrolling yet;
+ // start the unrolling and run the unrolling loop.
+ recursive_call_unroller_ = std::make_unique<RecursiveCallUnroller>(
+ engine_, CurrentStatement(), prototype_, memoizer_);
+ auto status = recursive_call_unroller_->Run(args);
+ recursive_call_unroller_.reset();
+ return status;
+ }
+
+ // Schedule a statement to be validated that it is indeed doesn't have any
+ // more rows.
+ void ScheduleEmptyStatementValidation(sqlite3_stmt* stmt) {
+ empty_stmts_to_validate_.push_back(stmt);
+ }
+
+ base::Status ValidateEmptyStatements() {
+ while (!empty_stmts_to_validate_.empty()) {
+ sqlite3_stmt* stmt = empty_stmts_to_validate_.back();
+ empty_stmts_to_validate_.pop_back();
+ RETURN_IF_ERROR(
+ CheckNoMoreRows(stmt, engine_->sqlite_engine()->db(), prototype_));
+ }
+ return base::OkStatus();
+ }
+
+ bool is_in_recursive_call() const { return current_recursion_level_ > 1; }
+
+ base::Status EnableMemoization() {
+ return memoizer_.EnableMemoization(prototype_);
+ }
+
+ PerfettoSqlEngine* engine() const { return engine_; }
+
+ const Prototype& prototype() const { return prototype_; }
+
+ sql_argument::Type return_type() const { return return_type_; }
+
+ const std::string& sql() const { return sql_; }
+
+ bool is_valid() const { return is_valid_; }
+
+ Memoizer& memoizer() { return memoizer_; }
+
+ private:
+ PerfettoSqlEngine* engine_;
+ Prototype prototype_;
+ std::string prototype_str_;
+ sql_argument::Type return_type_;
+ std::string sql_;
+ // Perfetto SQL functions support recursion. Given that each function call in
+ // the stack requires a dedicated statement, we maintain a stack of prepared
+ // statements and use the top one for each new call (allocating a new one if
+ // needed).
+ std::vector<SqliteEngine::PreparedStatement> stmts_;
+ // A list of statements to verify to ensure that they don't have more rows
+ // in VerifyPostConditions.
+ std::vector<sqlite3_stmt*> empty_stmts_to_validate_;
+ size_t current_recursion_level_ = 0;
+ // Function re-registration is not allowed, but the user is allowed to define
+ // the function again if the first call failed. |is_valid_| flag helps that
+ // by tracking whether the current function definition is valid (in which case
+ // re-registration is not allowed).
+ bool is_valid_ = false;
+ Memoizer memoizer_;
+ // Set if we are in a middle of unrolling a recursive call.
+ std::unique_ptr<RecursiveCallUnroller> recursive_call_unroller_;
+};
+
+State::~State() = default;
+
+std::unique_ptr<CreatedFunction::Context> CreatedFunction::MakeContext(
+ PerfettoSqlEngine* engine) {
+ return std::make_unique<State>(engine);
+}
+
+base::Status CreatedFunction::Run(CreatedFunction::Context* ctx,
+ size_t argc,
+ sqlite3_value** argv,
+ SqlValue& out,
+ Destructors&) {
+ State* state = static_cast<State*>(ctx);
+ if (argc != state->prototype().arguments.size()) {
+ return base::ErrStatus(
+ "%s: invalid number of args; expected %zu, received %zu",
+ state->prototype().function_name.c_str(),
+ state->prototype().arguments.size(), argc);
+ }
+
+ // Type check all the arguments.
+ for (size_t i = 0; i < argc; ++i) {
+ sqlite3_value* arg = argv[i];
+ sql_argument::Type type = state->prototype().arguments[i].type();
+ base::Status status = sqlite_utils::TypeCheckSqliteValue(
+ arg, sql_argument::TypeToSqlValueType(type),
+ sql_argument::TypeToHumanFriendlyString(type));
+ if (!status.ok()) {
+ return base::ErrStatus("%s[arg=%s]: argument %zu %s",
+ state->prototype().function_name.c_str(),
+ sqlite3_value_text(arg), i, status.c_message());
+ }
+ }
+
+ // Enter the function and ensure that we have a statement allocated.
+ RETURN_IF_ERROR(state->PushStackEntry());
+
+ std::optional<Memoizer::MemoizedArgs> memoized_args =
+ Memoizer::AsMemoizedArgs(argc, argv);
+
+ if (memoized_args) {
+ // If we are in the middle of an recursive calls unrolling, we might want to
+ // ignore the function invocation. See the comment in RecursiveCallUnroller
+ // for more details.
+ base::StatusOr<RecursiveCallUnroller::FunctionCallState> unroll_state =
+ state->OnFunctionCall(*memoized_args);
+ RETURN_IF_ERROR(unroll_state.status());
+ if (*unroll_state ==
+ RecursiveCallUnroller::FunctionCallState::kIgnoreDueToFirstPass) {
+ // Return NULL.
+ return base::OkStatus();
+ }
+
+ RETURN_IF_ERROR(state->UnrollRecursiveCallIfNeeded(*memoized_args));
+
+ std::optional<SqlValue> memoized_value =
+ state->memoizer().GetMemoizedValue(*memoized_args);
+ if (memoized_value) {
+ out = *memoized_value;
+ return base::OkStatus();
+ }
+ }
+
+ PERFETTO_TP_TRACE(
+ metatrace::Category::FUNCTION, "SQL_FUNCTION_CALL",
+ [state, argv](metatrace::Record* r) {
+ r->AddArg("Function", state->prototype().function_name.c_str());
+ for (uint32_t i = 0; i < state->prototype().arguments.size(); ++i) {
+ std::string key = "Arg " + std::to_string(i);
+ const char* value =
+ reinterpret_cast<const char*>(sqlite3_value_text(argv[i]));
+ r->AddArg(base::StringView(key),
+ value ? base::StringView(value) : base::StringView("NULL"));
+ }
+ });
+
+ RETURN_IF_ERROR(
+ BindArguments(state->CurrentStatement(), state->prototype(), argc, argv));
+ auto result = EvaluateScalarStatement(state->CurrentStatement(),
+ state->engine()->sqlite_engine()->db(),
+ state->prototype());
+ RETURN_IF_ERROR(result.status());
+ out = result.value();
+ state->ScheduleEmptyStatementValidation(state->CurrentStatement());
+
+ if (memoized_args) {
+ state->memoizer().Memoize(*memoized_args, out);
+ }
+
+ return base::OkStatus();
+}
+
+void CreatedFunction::Cleanup(CreatedFunction::Context* ctx) {
+ // Clear the statement.
+ static_cast<State*>(ctx)->PopStackEntry();
+}
+
+base::Status CreatedFunction::VerifyPostConditions(
+ CreatedFunction::Context* ctx) {
+ return static_cast<State*>(ctx)->ValidateEmptyStatements();
+}
+
+base::Status CreatedFunction::ValidateOrPrepare(CreatedFunction::Context* ctx,
+ Prototype prototype,
+ std::string prototype_str,
+ sql_argument::Type return_type,
+ std::string return_type_str,
+ std::string sql_str) {
+ State* state = static_cast<State*>(ctx);
+ if (state->is_valid()) {
+ // If the function already exists, just verify that the prototype, return
+ // type and SQL matches exactly with what we already had registered. By
+ // doing this, we can avoid the problem plaguing C++ macros where macro
+ // ordering determines which one gets run.
+ if (state->prototype() != prototype) {
+ return base::ErrStatus(
+ "CREATE_FUNCTION[prototype=%s]: function prototype changed",
+ prototype_str.c_str());
+ }
+ if (state->return_type() != return_type) {
+ return base::ErrStatus(
+ "CREATE_FUNCTION[prototype=%s]: return type changed from %s to %s",
+ prototype_str.c_str(),
+ sql_argument::TypeToHumanFriendlyString(state->return_type()),
+ return_type_str.c_str());
+ }
+
+ if (state->sql() != sql_str) {
+ return base::ErrStatus(
+ "CREATE_FUNCTION[prototype=%s]: function SQL changed from %s to %s",
+ prototype_str.c_str(), state->sql().c_str(), sql_str.c_str());
+ }
+ return base::OkStatus();
+ }
+
+ state->Reset(std::move(prototype), std::move(prototype_str), return_type,
+ std::move(sql_str));
+
+ // Ideally, we would unregister the function here if the statement prep
+ // failed, but SQLite doesn't allow unregistering functions inside active
+ // statements. So instead we'll just try to prepare the statement when calling
+ // this function, which will return an error.
+ return state->PrepareStatement();
+}
+
+base::Status CreatedFunction::EnableMemoization(Context* ctx) {
+ return static_cast<State*>(ctx)->EnableMemoization();
+}
+
+} // namespace trace_processor
+} // namespace perfetto
diff --git a/src/trace_processor/perfetto_sql/engine/created_function.h b/src/trace_processor/perfetto_sql/engine/created_function.h
new file mode 100644
index 0000000..7dcd70c
--- /dev/null
+++ b/src/trace_processor/perfetto_sql/engine/created_function.h
@@ -0,0 +1,64 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef SRC_TRACE_PROCESSOR_PERFETTO_SQL_ENGINE_CREATED_FUNCTION_H_
+#define SRC_TRACE_PROCESSOR_PERFETTO_SQL_ENGINE_CREATED_FUNCTION_H_
+
+#include <sqlite3.h>
+#include <memory>
+#include <unordered_map>
+
+#include "perfetto/base/status.h"
+#include "src/trace_processor/perfetto_sql/engine/function_util.h"
+#include "src/trace_processor/perfetto_sql/intrinsics/functions/sql_function.h"
+#include "src/trace_processor/sqlite/scoped_db.h"
+#include "src/trace_processor/sqlite/sqlite_table.h"
+#include "src/trace_processor/types/destructible.h"
+#include "src/trace_processor/util/sql_argument.h"
+
+namespace perfetto {
+namespace trace_processor {
+
+class PerfettoSqlEngine;
+
+struct CreatedFunction : public SqlFunction {
+ // Expose a do-nothing context
+ using Context = Destructible;
+
+ // SqlFunction implementation.
+ static base::Status Run(Context* ctx,
+ size_t argc,
+ sqlite3_value** argv,
+ SqlValue& out,
+ Destructors&);
+ static base::Status VerifyPostConditions(Context*);
+ static void Cleanup(Context*);
+
+ // Glue code for PerfettoSqlEngine.
+ static std::unique_ptr<Context> MakeContext(PerfettoSqlEngine*);
+ static base::Status ValidateOrPrepare(Context*,
+ Prototype,
+ std::string prototype_str,
+ sql_argument::Type return_type,
+ std::string return_type_str,
+ std::string sql_str);
+ static base::Status EnableMemoization(Context*);
+};
+
+} // namespace trace_processor
+} // namespace perfetto
+
+#endif // SRC_TRACE_PROCESSOR_PERFETTO_SQL_ENGINE_CREATED_FUNCTION_H_
diff --git a/src/trace_processor/perfetto_sql/intrinsics/functions/create_function_internal.cc b/src/trace_processor/perfetto_sql/engine/function_util.cc
similarity index 97%
rename from src/trace_processor/perfetto_sql/intrinsics/functions/create_function_internal.cc
rename to src/trace_processor/perfetto_sql/engine/function_util.cc
index e500d9b..dfba021 100644
--- a/src/trace_processor/perfetto_sql/intrinsics/functions/create_function_internal.cc
+++ b/src/trace_processor/perfetto_sql/engine/function_util.cc
@@ -14,7 +14,7 @@
* limitations under the License.
*/
-#include "src/trace_processor/perfetto_sql/intrinsics/functions/create_function_internal.h"
+#include "src/trace_processor/perfetto_sql/engine/function_util.h"
#include "perfetto/base/status.h"
#include "perfetto/ext/base/string_view.h"
diff --git a/src/trace_processor/perfetto_sql/intrinsics/functions/create_function_internal.h b/src/trace_processor/perfetto_sql/engine/function_util.h
similarity index 85%
rename from src/trace_processor/perfetto_sql/intrinsics/functions/create_function_internal.h
rename to src/trace_processor/perfetto_sql/engine/function_util.h
index 0f841a3..aa3515f 100644
--- a/src/trace_processor/perfetto_sql/intrinsics/functions/create_function_internal.h
+++ b/src/trace_processor/perfetto_sql/engine/function_util.h
@@ -14,8 +14,8 @@
* limitations under the License.
*/
-#ifndef SRC_TRACE_PROCESSOR_PERFETTO_SQL_INTRINSICS_FUNCTIONS_CREATE_FUNCTION_INTERNAL_H_
-#define SRC_TRACE_PROCESSOR_PERFETTO_SQL_INTRINSICS_FUNCTIONS_CREATE_FUNCTION_INTERNAL_H_
+#ifndef SRC_TRACE_PROCESSOR_PERFETTO_SQL_ENGINE_FUNCTION_UTIL_H_
+#define SRC_TRACE_PROCESSOR_PERFETTO_SQL_ENGINE_FUNCTION_UTIL_H_
#include <sqlite3.h>
#include <optional>
@@ -23,7 +23,6 @@
#include "perfetto/base/status.h"
#include "perfetto/ext/base/string_view.h"
-#include "perfetto/trace_processor/basic_types.h"
#include "src/trace_processor/util/sql_argument.h"
namespace perfetto {
@@ -61,4 +60,4 @@
} // namespace trace_processor
} // namespace perfetto
-#endif // SRC_TRACE_PROCESSOR_PERFETTO_SQL_INTRINSICS_FUNCTIONS_CREATE_FUNCTION_INTERNAL_H_
+#endif // SRC_TRACE_PROCESSOR_PERFETTO_SQL_ENGINE_FUNCTION_UTIL_H_
diff --git a/src/trace_processor/perfetto_sql/engine/perfetto_sql_engine.cc b/src/trace_processor/perfetto_sql/engine/perfetto_sql_engine.cc
index e0a0243..933299d 100644
--- a/src/trace_processor/perfetto_sql/engine/perfetto_sql_engine.cc
+++ b/src/trace_processor/perfetto_sql/engine/perfetto_sql_engine.cc
@@ -22,6 +22,9 @@
#include "perfetto/base/status.h"
#include "perfetto/ext/base/string_utils.h"
+#include "perfetto/ext/base/string_view.h"
+#include "src/trace_processor/perfetto_sql/engine/created_function.h"
+#include "src/trace_processor/perfetto_sql/engine/function_util.h"
#include "src/trace_processor/perfetto_sql/engine/perfetto_sql_parser.h"
#include "src/trace_processor/sqlite/db_sqlite_table.h"
#include "src/trace_processor/sqlite/scoped_db.h"
@@ -203,5 +206,60 @@
return ExecutionResult{std::move(*res), stats};
}
+base::Status PerfettoSqlEngine::RegisterSqlFunction(std::string prototype_str,
+ std::string return_type_str,
+ std::string sql_str) {
+ // Parse all the arguments into a more friendly form.
+ Prototype prototype;
+ base::Status status =
+ ParsePrototype(base::StringView(prototype_str), prototype);
+ if (!status.ok()) {
+ return base::ErrStatus("CREATE_FUNCTION[prototype=%s]: %s",
+ prototype_str.c_str(), status.c_message());
+ }
+
+ // Parse the return type into a enum format.
+ auto opt_return_type =
+ sql_argument::ParseType(base::StringView(return_type_str));
+ if (!opt_return_type) {
+ return base::ErrStatus(
+ "CREATE_FUNCTION[prototype=%s, return=%s]: unknown return type "
+ "specified",
+ prototype_str.c_str(), return_type_str.c_str());
+ }
+
+ int created_argc = static_cast<int>(prototype.arguments.size());
+ auto* ctx = static_cast<CreatedFunction::Context*>(
+ sqlite_engine()->GetFunctionContext(prototype.function_name,
+ created_argc));
+ if (!ctx) {
+ // We register the function with SQLite before we prepare the statement so
+ // the statement can reference the function itself, enabling recursive
+ // calls.
+ std::unique_ptr<CreatedFunction::Context> created_fn_ctx =
+ CreatedFunction::MakeContext(this);
+ ctx = created_fn_ctx.get();
+ RETURN_IF_ERROR(RegisterCppFunction<CreatedFunction>(
+ prototype.function_name.c_str(), created_argc,
+ std::move(created_fn_ctx)));
+ }
+ return CreatedFunction::ValidateOrPrepare(
+ ctx, std::move(prototype), std::move(prototype_str),
+ std::move(*opt_return_type), std::move(return_type_str),
+ std::move(sql_str));
+}
+
+base::Status PerfettoSqlEngine::EnableSqlFunctionMemoization(
+ const std::string& name) {
+ constexpr size_t kSupportedArgCount = 1;
+ CreatedFunction::Context* ctx = static_cast<CreatedFunction::Context*>(
+ sqlite_engine()->GetFunctionContext(name.c_str(), kSupportedArgCount));
+ if (!ctx) {
+ return base::ErrStatus(
+ "EXPERIMENTAL_MEMOIZE: Function %s(INT) does not exist", name.c_str());
+ }
+ return CreatedFunction::EnableMemoization(ctx);
+}
+
} // namespace trace_processor
} // namespace perfetto
diff --git a/src/trace_processor/perfetto_sql/engine/perfetto_sql_engine.h b/src/trace_processor/perfetto_sql/engine/perfetto_sql_engine.h
index 598a01d..be80897 100644
--- a/src/trace_processor/perfetto_sql/engine/perfetto_sql_engine.h
+++ b/src/trace_processor/perfetto_sql/engine/perfetto_sql_engine.h
@@ -17,6 +17,7 @@
#ifndef SRC_TRACE_PROCESSOR_PERFETTO_SQL_ENGINE_PERFETTO_SQL_ENGINE_H_
#define SRC_TRACE_PROCESSOR_PERFETTO_SQL_ENGINE_PERFETTO_SQL_ENGINE_H_
+#include "perfetto/base/status.h"
#include "perfetto/ext/base/status_or.h"
#include "src/trace_processor/perfetto_sql/intrinsics/functions/sql_function.h"
#include "src/trace_processor/perfetto_sql/intrinsics/table_functions/table_function.h"
@@ -75,7 +76,7 @@
// |determistic|: whether this function has deterministic output given the
// same set of arguments.
template <typename Function = SqlFunction>
- base::Status RegisterSqlFunction(const char* name,
+ base::Status RegisterCppFunction(const char* name,
int argc,
typename Function::Context* ctx,
bool deterministic = true);
@@ -87,12 +88,21 @@
// this pointer instead of the essentially static requirement of the context
// pointer above.
template <typename Function>
- base::Status RegisterSqlFunction(
+ base::Status RegisterCppFunction(
const char* name,
int argc,
std::unique_ptr<typename Function::Context> ctx,
bool deterministic = true);
+ // Registers a function with the prototype |prototype| which returns a value
+ // of |return_type| and is implemented by executing the SQL statement |sql|.
+ base::Status RegisterSqlFunction(std::string prototype,
+ std::string return_type,
+ std::string sql);
+
+ // Enables memoization for the given SQL function.
+ base::Status EnableSqlFunctionMemoization(const std::string& name);
+
// Registers a trace processor C++ table with SQLite with an SQL name of
// |name|.
void RegisterTable(const Table& table, const std::string& name);
@@ -168,7 +178,7 @@
} // namespace perfetto_sql_internal
template <typename Function>
-base::Status PerfettoSqlEngine::RegisterSqlFunction(
+base::Status PerfettoSqlEngine::RegisterCppFunction(
const char* name,
int argc,
typename Function::Context* ctx,
@@ -179,7 +189,7 @@
}
template <typename Function>
-base::Status PerfettoSqlEngine::RegisterSqlFunction(
+base::Status PerfettoSqlEngine::RegisterCppFunction(
const char* name,
int argc,
std::unique_ptr<typename Function::Context> user_data,
diff --git a/src/trace_processor/perfetto_sql/intrinsics/functions/BUILD.gn b/src/trace_processor/perfetto_sql/intrinsics/functions/BUILD.gn
index 825e4a6..fb15c53 100644
--- a/src/trace_processor/perfetto_sql/intrinsics/functions/BUILD.gn
+++ b/src/trace_processor/perfetto_sql/intrinsics/functions/BUILD.gn
@@ -21,8 +21,6 @@
"clock_functions.h",
"create_function.cc",
"create_function.h",
- "create_function_internal.cc",
- "create_function_internal.h",
"create_view_function.cc",
"create_view_function.h",
"import.cc",
diff --git a/src/trace_processor/perfetto_sql/intrinsics/functions/clock_functions.h b/src/trace_processor/perfetto_sql/intrinsics/functions/clock_functions.h
index 2179097..605cea5 100644
--- a/src/trace_processor/perfetto_sql/intrinsics/functions/clock_functions.h
+++ b/src/trace_processor/perfetto_sql/intrinsics/functions/clock_functions.h
@@ -22,10 +22,8 @@
#include "perfetto/ext/base/base64.h"
#include "protos/perfetto/common/builtin_clock.pbzero.h"
#include "src/trace_processor/importers/common/clock_converter.h"
-#include "src/trace_processor/perfetto_sql/intrinsics/functions/create_function_internal.h"
-#include "src/trace_processor/util/status_macros.h"
-
#include "src/trace_processor/perfetto_sql/intrinsics/functions/sql_function.h"
+#include "src/trace_processor/util/status_macros.h"
namespace perfetto {
namespace trace_processor {
diff --git a/src/trace_processor/perfetto_sql/intrinsics/functions/create_function.cc b/src/trace_processor/perfetto_sql/intrinsics/functions/create_function.cc
index 2e6c6a2..6e282b9 100644
--- a/src/trace_processor/perfetto_sql/intrinsics/functions/create_function.cc
+++ b/src/trace_processor/perfetto_sql/intrinsics/functions/create_function.cc
@@ -21,8 +21,8 @@
#include "perfetto/base/status.h"
#include "perfetto/trace_processor/basic_types.h"
+#include "src/trace_processor/perfetto_sql/engine/function_util.h"
#include "src/trace_processor/perfetto_sql/engine/perfetto_sql_engine.h"
-#include "src/trace_processor/perfetto_sql/intrinsics/functions/create_function_internal.h"
#include "src/trace_processor/sqlite/scoped_db.h"
#include "src/trace_processor/sqlite/sql_source.h"
#include "src/trace_processor/sqlite/sqlite_engine.h"
@@ -33,664 +33,6 @@
namespace perfetto {
namespace trace_processor {
-namespace {
-
-base::StatusOr<SqliteEngine::PreparedStatement> CreateStatement(
- PerfettoSqlEngine* engine,
- const std::string& sql,
- const std::string& prototype) {
- auto res = engine->sqlite_engine()->PrepareStatement(
- SqlSource::FromFunction(sql.c_str(), prototype));
- RETURN_IF_ERROR(res.status());
- return std::move(res.value());
-}
-
-base::Status CheckNoMoreRows(sqlite3_stmt* stmt,
- sqlite3* db,
- const Prototype& prototype) {
- int ret = sqlite3_step(stmt);
- RETURN_IF_ERROR(SqliteRetToStatus(db, prototype.function_name, ret));
- if (ret == SQLITE_ROW) {
- auto expanded_sql = sqlite_utils::ExpandedSqlForStmt(stmt);
- return base::ErrStatus(
- "%s: multiple values were returned when executing function body. "
- "Executed SQL was %s",
- prototype.function_name.c_str(), expanded_sql.get());
- }
- PERFETTO_DCHECK(ret == SQLITE_DONE);
- return base::OkStatus();
-}
-
-// Note: if the returned type is string / bytes, it will be invalidated by the
-// next call to SQLite, so the caller must take care to either copy or use the
-// value before calling SQLite again.
-base::StatusOr<SqlValue> EvaluateScalarStatement(sqlite3_stmt* stmt,
- sqlite3* db,
- const Prototype& prototype) {
- int ret = sqlite3_step(stmt);
- RETURN_IF_ERROR(SqliteRetToStatus(db, prototype.function_name, ret));
- if (ret == SQLITE_DONE) {
- // No return value means we just return don't set |out|.
- return SqlValue();
- }
-
- PERFETTO_DCHECK(ret == SQLITE_ROW);
- size_t col_count = static_cast<size_t>(sqlite3_column_count(stmt));
- if (col_count != 1) {
- return base::ErrStatus(
- "%s: SQL definition should only return one column: returned %zu "
- "columns",
- prototype.function_name.c_str(), col_count);
- }
-
- SqlValue result =
- sqlite_utils::SqliteValueToSqlValue(sqlite3_column_value(stmt, 0));
-
- // If we return a bytes type but have a null pointer, SQLite will convert this
- // to an SQL null. However, for proto build functions, we actively want to
- // distinguish between nulls and 0 byte strings. Therefore, change the value
- // to an empty string.
- if (result.type == SqlValue::kBytes && result.bytes_value == nullptr) {
- PERFETTO_DCHECK(result.bytes_count == 0);
- result.bytes_value = "";
- }
-
- return result;
-}
-
-base::Status BindArguments(sqlite3_stmt* stmt,
- const Prototype& prototype,
- size_t argc,
- sqlite3_value** argv) {
- // Bind all the arguments to the appropriate places in the function.
- for (size_t i = 0; i < argc; ++i) {
- RETURN_IF_ERROR(MaybeBindArgument(stmt, prototype.function_name,
- prototype.arguments[i], argv[i]));
- }
- return base::OkStatus();
-}
-
-struct CreatedFunction : public SqlFunction {
- class Context;
-
- static base::Status Run(Context* ctx,
- size_t argc,
- sqlite3_value** argv,
- SqlValue& out,
- Destructors&);
- static base::Status VerifyPostConditions(Context*);
- static void Cleanup(Context*);
-};
-
-struct StoredSqlValue {
- // unique_ptr to ensure that the pointers to these values are long-lived.
- using OwnedString = std::unique_ptr<std::string>;
- using OwnedBytes = std::unique_ptr<std::vector<uint8_t>>;
- // variant is a pain to use, but it's the simplest way to ensure that
- // the destructors run correctly for non-trivial members of the
- // union.
- using Data =
- std::variant<int64_t, double, OwnedString, OwnedBytes, nullptr_t>;
-
- StoredSqlValue(SqlValue value) {
- switch (value.type) {
- case SqlValue::Type::kNull:
- data = nullptr;
- break;
- case SqlValue::Type::kLong:
- data = value.long_value;
- break;
- case SqlValue::Type::kDouble:
- data = value.double_value;
- break;
- case SqlValue::Type::kString:
- data = std::make_unique<std::string>(value.string_value);
- break;
- case SqlValue::Type::kBytes:
- const uint8_t* ptr = static_cast<const uint8_t*>(value.bytes_value);
- data = std::make_unique<std::vector<uint8_t>>(ptr,
- ptr + value.bytes_count);
- break;
- }
- }
-
- SqlValue AsSqlValue() {
- if (std::holds_alternative<nullptr_t>(data)) {
- return SqlValue();
- } else if (std::holds_alternative<int64_t>(data)) {
- return SqlValue::Long(std::get<int64_t>(data));
- } else if (std::holds_alternative<double>(data)) {
- return SqlValue::Double(std::get<double>(data));
- } else if (std::holds_alternative<OwnedString>(data)) {
- const auto& str_ptr = std::get<OwnedString>(data);
- return SqlValue::String(str_ptr->c_str());
- } else if (std::holds_alternative<OwnedBytes>(data)) {
- const auto& bytes_ptr = std::get<OwnedBytes>(data);
- return SqlValue::Bytes(bytes_ptr->data(), bytes_ptr->size());
- }
- // GCC doesn't realize that the switch is exhaustive.
- PERFETTO_CHECK(false);
- return SqlValue();
- }
-
- Data data = nullptr;
-};
-
-class Memoizer {
- public:
- // Supported arguments. For now, only functions with a single int argument are
- // supported.
- using MemoizedArgs = int64_t;
-
- // Enables memoization.
- // Only functions with a single int argument returning ints are supported.
- base::Status EnableMemoization(const Prototype& prototype) {
- if (prototype.arguments.size() != 1 ||
- TypeToSqlValueType(prototype.arguments[0].type()) !=
- SqlValue::Type::kLong) {
- return base::ErrStatus(
- "EXPERIMENTAL_MEMOIZE: Function %s should take one int argument",
- prototype.function_name.c_str());
- }
- enabled_ = true;
- return base::OkStatus();
- }
-
- // Returns the memoized value for the current invocation if it exists.
- std::optional<SqlValue> GetMemoizedValue(MemoizedArgs args) {
- if (!enabled_) {
- return std::nullopt;
- }
- StoredSqlValue* value = memoized_values_.Find(args);
- if (!value) {
- return std::nullopt;
- }
- return value->AsSqlValue();
- }
-
- bool HasMemoizedValue(MemoizedArgs args) {
- return GetMemoizedValue(args).has_value();
- }
-
- // Saves the return value of the current invocation for memoization.
- void Memoize(MemoizedArgs args, SqlValue value) {
- if (!enabled_) {
- return;
- }
- memoized_values_.Insert(args, StoredSqlValue(value));
- }
-
- // Checks that the function has a single int argument and returns it.
- static std::optional<MemoizedArgs> AsMemoizedArgs(size_t argc,
- sqlite3_value** argv) {
- if (argc != 1) {
- return std::nullopt;
- }
- SqlValue arg = sqlite_utils::SqliteValueToSqlValue(argv[0]);
- if (arg.type != SqlValue::Type::kLong) {
- return std::nullopt;
- }
- return arg.AsLong();
- }
-
- bool enabled() const { return enabled_; }
-
- private:
- bool enabled_ = false;
- base::FlatHashMap<MemoizedArgs, StoredSqlValue> memoized_values_;
-};
-
-// A helper to unroll recursive calls: to minimise the amount of stack space
-// used, memoized recursive calls are evaluated using an on-heap queue.
-//
-// We compute the function in two passes:
-// - In the first pass, we evaluate the statement to discover which recursive
-// calls it makes, returning null from recursive calls and ignoring the
-// result.
-// - In the second pass, we evaluate the statement again, but this time we
-// memoize the result of each recursive call.
-//
-// We maintain a queue for scheduled "first pass" calls and a stack for the
-// scheduled "second pass" calls, evaluating available first pass calls, then
-// second pass calls. When we evaluate a first pass call, the further calls to
-// CreatedFunction::Run will just add it to the "first pass" queue. The second
-// pass, however, will evaluate the function normally, typically just using the
-// memoized result for the dependent calls. However, if the recursive calls
-// depend on the return value of the function, we will proceed with normal
-// recursion.
-//
-// To make it more concrete, consider an following example.
-// We have a function computing factorial (f) and we want to compute f(3).
-//
-// SELECT create_function('f(x INT)', 'INT',
-// 'SELECT IIF($x = 0, 1, $x * f($x - 1))');
-// SELECT experimental_memoize('f');
-// SELECT f(3);
-//
-// - We start with a call to f(3). It executes the statement as normal, which
-// recursively calls f(2).
-// - When f(2) is called, we detect that it is a recursive call and we start
-// unrolling it, entering RecursiveCallUnroller::Run.
-// - We schedule first pass for 2 and the state of the unroller
-// is first_pass: [2], second_pass: [].
-// - Then we compute the first pass for f(2). It calls f(1), which is ignored
-// due to OnFunctionCall returning kIgnoreDueToFirstPass and 1 is added to the
-// first pass queue. 2 is taked out of the first pass queue and moved to the
-// second pass stack. State: first_pass: [1], second_pass: [2].
-// - Then we compute the first pass for 1. The similar thing happens: f(0) is
-// called and ignored, 0 is added to first_pass, 1 is added to second_pass.
-// State: first_pass: [0], second_pass: [2, 1].
-// - Then we compute the first pass for 0. It doesn't make further calls, so
-// 0 is moved to the second pass stack.
-// State: first_pass: [], second_pass: [2, 1, 0].
-// - Then we compute the second pass for 0. It just returns 1.
-// State: first_pass: [], second_pass: [2, 1], results: {0: 1}.
-// - Then we compute the second pass for 1. It calls f(0), which is memoized.
-// State: first_pass: [], second_pass: [2], results: {0: 1, 1: 1}.
-// - Then we compute the second pass for 1. It calls f(1), which is memoized.
-// State: first_pass: [], second_pass: [], results: {0: 1, 1: 1, 2: 2}.
-// - As both first_pass and second_pass are empty, we return from
-// RecursiveCallUnroller::Run.
-// - Control is returned to CreatedFunction::Run for f(2), which returns
-// memoized value.
-// - Then control is returned to CreatedFunction::Run for f(3), which completes
-// the computation.
-class RecursiveCallUnroller {
- public:
- RecursiveCallUnroller(PerfettoSqlEngine* engine,
- sqlite3_stmt* stmt,
- const Prototype& prototype,
- Memoizer& memoizer)
- : engine_(engine),
- stmt_(stmt),
- prototype_(prototype),
- memoizer_(memoizer) {}
-
- // Whether we should just return null due to us being in the "first pass".
- enum class FunctionCallState {
- kIgnoreDueToFirstPass,
- kEvaluate,
- };
-
- base::StatusOr<FunctionCallState> OnFunctionCall(
- Memoizer::MemoizedArgs args) {
- // If we are in the second pass, we just continue the function execution,
- // including checking if a memoized value is available and returning it.
- //
- // We generally expect a memoized value to be available, but there are
- // cases when it might not be the case, e.g. when which recursive calls are
- // made depends on the return value of the function, e.g. for the following
- // function, the first pass will not detect f(y) calls, so they will
- // be computed recursively.
- // f(x): SELECT max(f(y)) FROM y WHERE y < f($x - 1);
- if (state_ == State::kComputingSecondPass) {
- return FunctionCallState::kEvaluate;
- }
- if (!memoizer_.HasMemoizedValue(args)) {
- ArgState* state = visited_.Find(args);
- if (state) {
- // Detect recursive loops, e.g. f(1) calling f(2) calling f(1).
- if (*state == ArgState::kEvaluating) {
- return base::ErrStatus("Infinite recursion detected");
- }
- } else {
- visited_.Insert(args, ArgState::kScheduled);
- first_pass_.push(args);
- }
- }
- return FunctionCallState::kIgnoreDueToFirstPass;
- }
-
- base::Status Run(Memoizer::MemoizedArgs initial_args) {
- PERFETTO_TP_TRACE(metatrace::Category::FUNCTION,
- "UNROLL_RECURSIVE_FUNCTION_CALL",
- [&](metatrace::Record* r) {
- r->AddArg("Function", prototype_.function_name);
- r->AddArg("Arg 0", std::to_string(initial_args));
- });
-
- first_pass_.push(initial_args);
- visited_.Insert(initial_args, ArgState::kScheduled);
-
- while (!first_pass_.empty() || !second_pass_.empty()) {
- // If we have scheduled first pass calls, we evaluate them first.
- if (!first_pass_.empty()) {
- state_ = State::kComputingFirstPass;
- Memoizer::MemoizedArgs args = first_pass_.front();
-
- PERFETTO_TP_TRACE(metatrace::Category::FUNCTION, "SQL_FUNCTION_CALL",
- [&](metatrace::Record* r) {
- r->AddArg("Function", prototype_.function_name);
- r->AddArg("Type", "UnrollRecursiveCall_FirstPass");
- r->AddArg("Arg 0", std::to_string(args));
- });
-
- first_pass_.pop();
- second_pass_.push(args);
- Evaluate(args).status();
- continue;
- }
-
- state_ = State::kComputingSecondPass;
- Memoizer::MemoizedArgs args = second_pass_.top();
-
- PERFETTO_TP_TRACE(metatrace::Category::FUNCTION, "SQL_FUNCTION_CALL",
- [&](metatrace::Record* r) {
- r->AddArg("Function", prototype_.function_name);
- r->AddArg("Type", "UnrollRecursiveCall_SecondPass");
- r->AddArg("Arg 0", std::to_string(args));
- });
-
- visited_.Insert(args, ArgState::kEvaluating);
- second_pass_.pop();
- base::StatusOr<std::optional<int64_t>> result = Evaluate(args);
- RETURN_IF_ERROR(result.status());
- std::optional<int64_t> maybe_int_result = result.value();
- if (!maybe_int_result.has_value()) {
- continue;
- }
- visited_.Insert(args, ArgState::kEvaluated);
- memoizer_.Memoize(args, SqlValue::Long(*maybe_int_result));
- }
- return base::OkStatus();
- }
-
- private:
- // This function returns:
- // - base::ErrStatus if the evaluation of the function failed.
- // - std::nullopt if the function returned a non-integer value.
- // - the result of the function otherwise.
- base::StatusOr<std::optional<int64_t>> Evaluate(Memoizer::MemoizedArgs args) {
- RETURN_IF_ERROR(MaybeBindIntArgument(stmt_, prototype_.function_name,
- prototype_.arguments[0], args));
- base::StatusOr<SqlValue> result = EvaluateScalarStatement(
- stmt_, engine_->sqlite_engine()->db(), prototype_);
- sqlite3_reset(stmt_);
- sqlite3_clear_bindings(stmt_);
- RETURN_IF_ERROR(result.status());
- if (result->type != SqlValue::Type::kLong) {
- return std::optional<int64_t>(std::nullopt);
- }
- return std::optional<int64_t>(result->long_value);
- }
-
- PerfettoSqlEngine* engine_;
- sqlite3_stmt* stmt_;
- const Prototype& prototype_;
- Memoizer& memoizer_;
-
- // Current state of the evaluation.
- enum class State {
- kComputingFirstPass,
- kComputingSecondPass,
- };
- State state_ = State::kComputingFirstPass;
-
- // A state of evaluation of a given argument.
- enum class ArgState {
- kScheduled,
- kEvaluating,
- kEvaluated,
- };
-
- // See the class-level comment for the explanation of the two passes.
- std::queue<Memoizer::MemoizedArgs> first_pass_;
- base::FlatHashMap<Memoizer::MemoizedArgs, ArgState> visited_;
- std::stack<Memoizer::MemoizedArgs> second_pass_;
-};
-
-// This class is used to store the state of a CREATE_FUNCTION call.
-// It is used to store the state of the function across multiple invocations
-// of the function (e.g. when the function is called recursively).
-class CreatedFunction::Context {
- public:
- explicit Context(PerfettoSqlEngine* engine) : engine_(engine) {}
-
- // Prepare a statement and push it into the stack of allocated statements
- // for this function.
- base::Status PrepareStatement() {
- base::StatusOr<SqliteEngine::PreparedStatement> stmt =
- CreateStatement(engine_, sql_, prototype_str_);
- RETURN_IF_ERROR(stmt.status());
- is_valid_ = true;
- stmts_.push_back(std::move(stmt.value()));
- return base::OkStatus();
- }
-
- // Sets the state of the function. Should be called only when the function
- // is invalid (i.e. when it is first created or when the previous statement
- // failed to prepare).
- void Reset(Prototype prototype,
- std::string prototype_str,
- sql_argument::Type return_type,
- std::string sql) {
- // Re-registration of valid functions is not allowed.
- PERFETTO_DCHECK(!is_valid_);
- PERFETTO_DCHECK(stmts_.empty());
-
- prototype_ = std::move(prototype);
- prototype_str_ = std::move(prototype_str);
- return_type_ = return_type;
- sql_ = std::move(sql);
- }
-
- // This function is called each time the function is called.
- // It ensures that we have a statement for the current recursion level,
- // allocating a new one if needed.
- base::Status PushStackEntry() {
- ++current_recursion_level_;
- if (current_recursion_level_ > stmts_.size()) {
- return PrepareStatement();
- }
- return base::OkStatus();
- }
-
- // Returns the statement that is used for the current invocation.
- sqlite3_stmt* CurrentStatement() {
- return stmts_[current_recursion_level_ - 1].sqlite_stmt();
- }
-
- // This function is called each time the function returns and resets the
- // statement that this invocation used.
- void PopStackEntry() {
- if (current_recursion_level_ > stmts_.size()) {
- // This is possible if we didn't prepare the statement and returned
- // an error.
- return;
- }
- sqlite3_reset(CurrentStatement());
- sqlite3_clear_bindings(CurrentStatement());
- --current_recursion_level_;
- }
-
- base::StatusOr<RecursiveCallUnroller::FunctionCallState> OnFunctionCall(
- Memoizer::MemoizedArgs args) {
- if (!recursive_call_unroller_) {
- return RecursiveCallUnroller::FunctionCallState::kEvaluate;
- }
- return recursive_call_unroller_->OnFunctionCall(args);
- }
-
- // Called before checking the function for memoization.
- base::Status UnrollRecursiveCallIfNeeded(Memoizer::MemoizedArgs args) {
- if (!memoizer_.enabled() || !is_in_recursive_call() ||
- recursive_call_unroller_) {
- return base::OkStatus();
- }
- // If we are in a recursive call, we need to check if we have already
- // computed the result for the current arguments.
- if (memoizer_.HasMemoizedValue(args)) {
- return base::OkStatus();
- }
-
- // If we are in a beginning of a function call:
- // - is a recursive,
- // - can be memoized,
- // - hasn't been memoized already, and
- // - hasn't start unrolling yet;
- // start the unrolling and run the unrolling loop.
- recursive_call_unroller_ = std::make_unique<RecursiveCallUnroller>(
- engine_, CurrentStatement(), prototype_, memoizer_);
- auto status = recursive_call_unroller_->Run(args);
- recursive_call_unroller_.reset();
- return status;
- }
-
- // Schedule a statement to be validated that it is indeed doesn't have any
- // more rows.
- void ScheduleEmptyStatementValidation(sqlite3_stmt* stmt) {
- empty_stmts_to_validate_.push_back(stmt);
- }
-
- base::Status ValidateEmptyStatements() {
- while (!empty_stmts_to_validate_.empty()) {
- sqlite3_stmt* stmt = empty_stmts_to_validate_.back();
- empty_stmts_to_validate_.pop_back();
- RETURN_IF_ERROR(
- CheckNoMoreRows(stmt, engine_->sqlite_engine()->db(), prototype_));
- }
- return base::OkStatus();
- }
-
- bool is_in_recursive_call() const { return current_recursion_level_ > 1; }
-
- base::Status EnableMemoization() {
- return memoizer_.EnableMemoization(prototype_);
- }
-
- PerfettoSqlEngine* engine() const { return engine_; }
-
- const Prototype& prototype() const { return prototype_; }
-
- sql_argument::Type return_type() const { return return_type_; }
-
- const std::string& sql() const { return sql_; }
-
- bool is_valid() const { return is_valid_; }
-
- Memoizer& memoizer() { return memoizer_; }
-
- private:
- PerfettoSqlEngine* engine_;
- Prototype prototype_;
- std::string prototype_str_;
- sql_argument::Type return_type_;
- std::string sql_;
- // Perfetto SQL functions support recursion. Given that each function call in
- // the stack requires a dedicated statement, we maintain a stack of prepared
- // statements and use the top one for each new call (allocating a new one if
- // needed).
- std::vector<SqliteEngine::PreparedStatement> stmts_;
- // A list of statements to verify to ensure that they don't have more rows
- // in VerifyPostConditions.
- std::vector<sqlite3_stmt*> empty_stmts_to_validate_;
- size_t current_recursion_level_ = 0;
- // Function re-registration is not allowed, but the user is allowed to define
- // the function again if the first call failed. |is_valid_| flag helps that
- // by tracking whether the current function definition is valid (in which case
- // re-registration is not allowed).
- bool is_valid_ = false;
- Memoizer memoizer_;
- // Set if we are in a middle of unrolling a recursive call.
- std::unique_ptr<RecursiveCallUnroller> recursive_call_unroller_;
-};
-
-base::Status CreatedFunction::Run(CreatedFunction::Context* ctx,
- size_t argc,
- sqlite3_value** argv,
- SqlValue& out,
- Destructors&) {
- if (argc != ctx->prototype().arguments.size()) {
- return base::ErrStatus(
- "%s: invalid number of args; expected %zu, received %zu",
- ctx->prototype().function_name.c_str(),
- ctx->prototype().arguments.size(), argc);
- }
-
- // Type check all the arguments.
- for (size_t i = 0; i < argc; ++i) {
- sqlite3_value* arg = argv[i];
- sql_argument::Type type = ctx->prototype().arguments[i].type();
- base::Status status = sqlite_utils::TypeCheckSqliteValue(
- arg, sql_argument::TypeToSqlValueType(type),
- sql_argument::TypeToHumanFriendlyString(type));
- if (!status.ok()) {
- return base::ErrStatus("%s[arg=%s]: argument %zu %s",
- ctx->prototype().function_name.c_str(),
- sqlite3_value_text(arg), i, status.c_message());
- }
- }
-
- // Enter the function and ensure that we have a statement allocated.
- RETURN_IF_ERROR(ctx->PushStackEntry());
-
- std::optional<Memoizer::MemoizedArgs> memoized_args =
- Memoizer::AsMemoizedArgs(argc, argv);
-
- if (memoized_args) {
- // If we are in the middle of an recursive calls unrolling, we might want to
- // ignore the function invocation. See the comment in RecursiveCallUnroller
- // for more details.
- base::StatusOr<RecursiveCallUnroller::FunctionCallState> unroll_state =
- ctx->OnFunctionCall(*memoized_args);
- RETURN_IF_ERROR(unroll_state.status());
- if (*unroll_state ==
- RecursiveCallUnroller::FunctionCallState::kIgnoreDueToFirstPass) {
- // Return NULL.
- return base::OkStatus();
- }
-
- RETURN_IF_ERROR(ctx->UnrollRecursiveCallIfNeeded(*memoized_args));
-
- std::optional<SqlValue> memoized_value =
- ctx->memoizer().GetMemoizedValue(*memoized_args);
- if (memoized_value) {
- out = *memoized_value;
- return base::OkStatus();
- }
- }
-
- PERFETTO_TP_TRACE(
- metatrace::Category::FUNCTION, "SQL_FUNCTION_CALL",
- [ctx, argv](metatrace::Record* r) {
- r->AddArg("Function", ctx->prototype().function_name.c_str());
- for (uint32_t i = 0; i < ctx->prototype().arguments.size(); ++i) {
- std::string key = "Arg " + std::to_string(i);
- const char* value =
- reinterpret_cast<const char*>(sqlite3_value_text(argv[i]));
- r->AddArg(base::StringView(key),
- value ? base::StringView(value) : base::StringView("NULL"));
- }
- });
-
- RETURN_IF_ERROR(
- BindArguments(ctx->CurrentStatement(), ctx->prototype(), argc, argv));
- auto result = EvaluateScalarStatement(ctx->CurrentStatement(),
- ctx->engine()->sqlite_engine()->db(),
- ctx->prototype());
- RETURN_IF_ERROR(result.status());
- out = result.value();
- ctx->ScheduleEmptyStatementValidation(ctx->CurrentStatement());
-
- if (memoized_args) {
- ctx->memoizer().Memoize(*memoized_args, out);
- }
-
- return base::OkStatus();
-}
-
-void CreatedFunction::Cleanup(CreatedFunction::Context* ctx) {
- // Clear the statement.
- ctx->PopStackEntry();
-}
-
-base::Status CreatedFunction::VerifyPostConditions(
- CreatedFunction::Context* ctx) {
- return ctx->ValidateEmptyStatements();
-}
-
-} // namespace
-
base::Status CreateFunction::Run(PerfettoSqlEngine* engine,
size_t argc,
sqlite3_value** argv,
@@ -727,81 +69,11 @@
auto extract_string = [](sqlite3_value* value) -> base::StringView {
return reinterpret_cast<const char*>(sqlite3_value_text(value));
};
- base::StringView prototype_str = extract_string(prototype_value);
- base::StringView return_type_str = extract_string(return_type_value);
+ std::string prototype_str = extract_string(prototype_value).ToStdString();
+ std::string return_type_str = extract_string(return_type_value).ToStdString();
std::string sql_defn_str = extract_string(sql_defn_value).ToStdString();
-
- // Parse all the arguments into a more friendly form.
- Prototype prototype;
- base::Status status = ParsePrototype(prototype_str, prototype);
- if (!status.ok()) {
- return base::ErrStatus("CREATE_FUNCTION[prototype=%s]: %s",
- prototype_str.ToStdString().c_str(),
- status.c_message());
- }
-
- // Parse the return type into a enum format.
- auto opt_return_type = sql_argument::ParseType(return_type_str);
- if (!opt_return_type) {
- return base::ErrStatus(
- "CREATE_FUNCTION[prototype=%s, return=%s]: unknown return type "
- "specified",
- prototype_str.ToStdString().c_str(),
- return_type_str.ToStdString().c_str());
- }
-
- std::string function_name = prototype.function_name;
- int created_argc = static_cast<int>(prototype.arguments.size());
- auto* ctx = static_cast<CreatedFunction::Context*>(
- engine->sqlite_engine()->GetFunctionContext(prototype.function_name,
- created_argc));
- if (!ctx) {
- // We register the function with SQLite before we prepare the statement so
- // the statement can reference the function itself, enabling recursive
- // calls.
- std::unique_ptr<CreatedFunction::Context> created_fn_ctx =
- std::make_unique<CreatedFunction::Context>(engine);
- ctx = created_fn_ctx.get();
- RETURN_IF_ERROR(engine->RegisterSqlFunction<CreatedFunction>(
- function_name.c_str(), created_argc, std::move(created_fn_ctx)));
- }
- if (ctx->is_valid()) {
- // If the function already exists, just verify that the prototype, return
- // type and SQL matches exactly with what we already had registered. By
- // doing this, we can avoid the problem plaguing C++ macros where macro
- // ordering determines which one gets run.
- if (ctx->prototype() != prototype) {
- return base::ErrStatus(
- "CREATE_FUNCTION[prototype=%s]: function prototype changed",
- prototype_str.ToStdString().c_str());
- }
-
- if (ctx->return_type() != *opt_return_type) {
- return base::ErrStatus(
- "CREATE_FUNCTION[prototype=%s]: return type changed from %s to %s",
- prototype_str.ToStdString().c_str(),
- sql_argument::TypeToHumanFriendlyString(ctx->return_type()),
- return_type_str.ToStdString().c_str());
- }
-
- if (ctx->sql() != sql_defn_str) {
- return base::ErrStatus(
- "CREATE_FUNCTION[prototype=%s]: function SQL changed from %s to %s",
- prototype_str.ToStdString().c_str(), ctx->sql().c_str(),
- sql_defn_str.c_str());
- }
-
- return base::OkStatus();
- }
-
- ctx->Reset(std::move(prototype), prototype_str.ToStdString(),
- *opt_return_type, std::move(sql_defn_str));
-
- // Ideally, we would unregister the function here if the statement prep
- // failed, but SQLite doesn't allow unregistering functions inside active
- // statements. So instead we'll just try to prepare the statement when calling
- // this function, which will return an error.
- return ctx->PrepareStatement();
+ return engine->RegisterSqlFunction(prototype_str, return_type_str,
+ sql_defn_str);
}
base::Status ExperimentalMemoize::Run(PerfettoSqlEngine* engine,
@@ -813,17 +85,7 @@
base::StatusOr<std::string> function_name =
sqlite_utils::ExtractStringArg("MEMOIZE", "function_name", argv[0]);
RETURN_IF_ERROR(function_name.status());
-
- constexpr size_t kSupportedArgCount = 1;
- CreatedFunction::Context* ctx = static_cast<CreatedFunction::Context*>(
- engine->sqlite_engine()->GetFunctionContext(function_name->c_str(),
- kSupportedArgCount));
- if (!ctx) {
- return base::ErrStatus(
- "EXPERIMENTAL_MEMOIZE: Function %s(INT) does not exist",
- function_name->c_str());
- }
- return ctx->EnableMemoization();
+ return engine->EnableSqlFunctionMemoization(*function_name);
}
} // namespace trace_processor
diff --git a/src/trace_processor/perfetto_sql/intrinsics/functions/create_view_function.cc b/src/trace_processor/perfetto_sql/intrinsics/functions/create_view_function.cc
index 4fbf8ad..2c08cf6 100644
--- a/src/trace_processor/perfetto_sql/intrinsics/functions/create_view_function.cc
+++ b/src/trace_processor/perfetto_sql/intrinsics/functions/create_view_function.cc
@@ -23,8 +23,8 @@
#include "perfetto/ext/base/string_utils.h"
#include "perfetto/ext/base/string_view.h"
#include "perfetto/trace_processor/basic_types.h"
+#include "src/trace_processor/perfetto_sql/engine/function_util.h"
#include "src/trace_processor/perfetto_sql/engine/perfetto_sql_engine.h"
-#include "src/trace_processor/perfetto_sql/intrinsics/functions/create_function_internal.h"
#include "src/trace_processor/sqlite/scoped_db.h"
#include "src/trace_processor/sqlite/sqlite_table.h"
#include "src/trace_processor/sqlite/sqlite_utils.h"
diff --git a/src/trace_processor/perfetto_sql/intrinsics/functions/import.cc b/src/trace_processor/perfetto_sql/intrinsics/functions/import.cc
index 48ace71..d4dd9e8 100644
--- a/src/trace_processor/perfetto_sql/intrinsics/functions/import.cc
+++ b/src/trace_processor/perfetto_sql/intrinsics/functions/import.cc
@@ -22,7 +22,6 @@
#include "perfetto/ext/base/string_utils.h"
#include "perfetto/ext/base/string_view.h"
#include "perfetto/trace_processor/basic_types.h"
-#include "src/trace_processor/perfetto_sql/intrinsics/functions/create_function_internal.h"
#include "src/trace_processor/sqlite/scoped_db.h"
#include "src/trace_processor/sqlite/sql_source.h"
#include "src/trace_processor/sqlite/sqlite_table.h"
diff --git a/src/trace_processor/perfetto_sql/intrinsics/functions/math.cc b/src/trace_processor/perfetto_sql/intrinsics/functions/math.cc
index 0a0e00a..0671a79 100644
--- a/src/trace_processor/perfetto_sql/intrinsics/functions/math.cc
+++ b/src/trace_processor/perfetto_sql/intrinsics/functions/math.cc
@@ -79,8 +79,8 @@
} // namespace
base::Status RegisterMathFunctions(PerfettoSqlEngine& engine) {
- RETURN_IF_ERROR(engine.RegisterSqlFunction<Ln>("ln", 1, nullptr, true));
- return engine.RegisterSqlFunction<Exp>("exp", 1, nullptr, true);
+ RETURN_IF_ERROR(engine.RegisterCppFunction<Ln>("ln", 1, nullptr, true));
+ return engine.RegisterCppFunction<Exp>("exp", 1, nullptr, true);
}
} // namespace perfetto::trace_processor
diff --git a/src/trace_processor/perfetto_sql/intrinsics/functions/stack_functions.cc b/src/trace_processor/perfetto_sql/intrinsics/functions/stack_functions.cc
index 86d3c74..c78e6c9 100644
--- a/src/trace_processor/perfetto_sql/intrinsics/functions/stack_functions.cc
+++ b/src/trace_processor/perfetto_sql/intrinsics/functions/stack_functions.cc
@@ -246,13 +246,13 @@
base::Status RegisterStackFunctions(PerfettoSqlEngine* engine,
TraceProcessorContext* context) {
- RETURN_IF_ERROR(engine->RegisterSqlFunction<CatStacksFunction>(
+ RETURN_IF_ERROR(engine->RegisterCppFunction<CatStacksFunction>(
CatStacksFunction::kFunctionName, -1, context->storage.get()));
RETURN_IF_ERROR(
- engine->RegisterSqlFunction<StackFromStackProfileFrameFunction>(
+ engine->RegisterCppFunction<StackFromStackProfileFrameFunction>(
StackFromStackProfileFrameFunction::kFunctionName, 1,
context->storage.get()));
- return engine->RegisterSqlFunction<StackFromStackProfileCallsiteFunction>(
+ return engine->RegisterCppFunction<StackFromStackProfileCallsiteFunction>(
StackFromStackProfileCallsiteFunction::kFunctionName, -1,
context->storage.get());
}
diff --git a/src/trace_processor/perfetto_sql/intrinsics/functions/utils.h b/src/trace_processor/perfetto_sql/intrinsics/functions/utils.h
index f859c62..be2b91f 100644
--- a/src/trace_processor/perfetto_sql/intrinsics/functions/utils.h
+++ b/src/trace_processor/perfetto_sql/intrinsics/functions/utils.h
@@ -26,7 +26,6 @@
#include "protos/perfetto/common/builtin_clock.pbzero.h"
#include "src/trace_processor/export_json.h"
#include "src/trace_processor/importers/common/clock_tracker.h"
-#include "src/trace_processor/perfetto_sql/intrinsics/functions/create_function_internal.h"
#include "src/trace_processor/perfetto_sql/intrinsics/functions/sql_function.h"
#include "src/trace_processor/sqlite/sqlite_utils.h"
#include "src/trace_processor/util/status_macros.h"
diff --git a/src/trace_processor/perfetto_sql/intrinsics/functions/window_functions.h b/src/trace_processor/perfetto_sql/intrinsics/functions/window_functions.h
index ebcb844..3b90896 100644
--- a/src/trace_processor/perfetto_sql/intrinsics/functions/window_functions.h
+++ b/src/trace_processor/perfetto_sql/intrinsics/functions/window_functions.h
@@ -25,10 +25,8 @@
#include "protos/perfetto/common/builtin_clock.pbzero.h"
#include "src/trace_processor/export_json.h"
#include "src/trace_processor/importers/common/clock_tracker.h"
-#include "src/trace_processor/perfetto_sql/intrinsics/functions/create_function_internal.h"
-#include "src/trace_processor/util/status_macros.h"
-
#include "src/trace_processor/perfetto_sql/intrinsics/functions/sql_function.h"
+#include "src/trace_processor/util/status_macros.h"
namespace perfetto {
namespace trace_processor {
diff --git a/src/trace_processor/trace_processor_impl.cc b/src/trace_processor/trace_processor_impl.cc
index 00d1ef5..a59757a 100644
--- a/src/trace_processor/trace_processor_impl.cc
+++ b/src/trace_processor/trace_processor_impl.cc
@@ -114,7 +114,7 @@
int argc,
Ptr context = nullptr,
bool deterministic = true) {
- auto status = engine->RegisterSqlFunction<SqlFunction>(
+ auto status = engine->RegisterCppFunction<SqlFunction>(
name, argc, std::move(context), deterministic);
if (!status.ok())
PERFETTO_ELOG("%s", status.c_message());