blob: a16f949ff2fe6959b050497194910772df34916a [file] [log] [blame]
/*
* Copyright (C) 2019 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/trace_processor/prelude/functions/create_function.h"

#include <cstddef>
#include <cstdint>
#include <memory>
#include <optional>
#include <queue>
#include <stack>
#include <string>
#include <variant>
#include <vector>

#include "perfetto/base/status.h"
#include "perfetto/trace_processor/basic_types.h"
#include "src/trace_processor/prelude/functions/create_function_internal.h"
#include "src/trace_processor/sqlite/perfetto_sql_engine.h"
#include "src/trace_processor/sqlite/scoped_db.h"
#include "src/trace_processor/sqlite/sql_source.h"
#include "src/trace_processor/sqlite/sqlite_engine.h"
#include "src/trace_processor/sqlite/sqlite_utils.h"
#include "src/trace_processor/tp_metatrace.h"
#include "src/trace_processor/util/status_macros.h"
namespace perfetto {
namespace trace_processor {
namespace {
// Prepares a new SQLite statement for the function body |sql|, using a
// SqlSource built with SqlSource::FromFunction(|sql|, |prototype|).
// Returns the preparation error if SQLite rejects the statement.
base::StatusOr<SqliteEngine::PreparedStatement> CreateStatement(
    PerfettoSqlEngine* engine,
    const std::string& sql,
    const std::string& prototype) {
  auto source = SqlSource::FromFunction(sql.c_str(), prototype);
  base::StatusOr<SqliteEngine::PreparedStatement> prepared =
      engine->sqlite_engine()->PrepareStatement(std::move(source));
  RETURN_IF_ERROR(prepared.status());
  return std::move(prepared.value());
}
// Steps |stmt| one more time and checks that the function body did not
// produce another row. Returns the SQLite error if stepping fails, or an
// error naming the expanded SQL if a further row is available.
base::Status CheckNoMoreRows(sqlite3_stmt* stmt,
                             sqlite3* db,
                             const Prototype& prototype) {
  const int step_ret = sqlite3_step(stmt);
  RETURN_IF_ERROR(SqliteRetToStatus(db, prototype.function_name, step_ret));
  if (step_ret != SQLITE_ROW) {
    PERFETTO_DCHECK(step_ret == SQLITE_DONE);
    return base::OkStatus();
  }
  auto expanded_sql = sqlite_utils::ExpandedSqlForStmt(stmt);
  return base::ErrStatus(
      "%s: multiple values were returned when executing function body. "
      "Executed SQL was %s",
      prototype.function_name.c_str(), expanded_sql.get());
}
// Steps |stmt| once and returns the value of the single column of the first
// row. Returns a null SqlValue if the statement produced no rows, and an
// error if stepping failed or the row does not have exactly one column.
//
// Note: if the returned type is string / bytes, it will be invalidated by the
// next call to SQLite, so the caller must take care to either copy or use the
// value before calling SQLite again.
base::StatusOr<SqlValue> EvaluateScalarStatement(sqlite3_stmt* stmt,
                                                 sqlite3* db,
                                                 const Prototype& prototype) {
  const int step_ret = sqlite3_step(stmt);
  RETURN_IF_ERROR(SqliteRetToStatus(db, prototype.function_name, step_ret));

  // No rows: the function evaluates to NULL.
  if (step_ret == SQLITE_DONE)
    return SqlValue();
  PERFETTO_DCHECK(step_ret == SQLITE_ROW);

  const size_t num_columns = static_cast<size_t>(sqlite3_column_count(stmt));
  if (num_columns != 1) {
    return base::ErrStatus(
        "%s: SQL definition should only return one column: returned %zu "
        "columns",
        prototype.function_name.c_str(), num_columns);
  }

  SqlValue value =
      sqlite_utils::SqliteValueToSqlValue(sqlite3_column_value(stmt, 0));

  // If we return a bytes type but have a null pointer, SQLite will convert
  // this to an SQL null. However, for proto build functions, we actively want
  // to distinguish between nulls and 0 byte strings, so point the value at an
  // empty string instead.
  const bool is_null_blob =
      value.type == SqlValue::kBytes && value.bytes_value == nullptr;
  if (is_null_blob) {
    PERFETTO_DCHECK(value.bytes_count == 0);
    value.bytes_value = "";
  }
  return value;
}
// Calls MaybeBindArgument for each of the first |argc| values in |argv|,
// pairing argv[idx] with prototype.arguments[idx]. Stops at, and returns,
// the first binding error.
base::Status BindArguments(sqlite3_stmt* stmt,
                           const Prototype& prototype,
                           size_t argc,
                           sqlite3_value** argv) {
  for (size_t idx = 0; idx != argc; ++idx) {
    const auto& arg_definition = prototype.arguments[idx];
    RETURN_IF_ERROR(MaybeBindArgument(stmt, prototype.function_name,
                                      arg_definition, argv[idx]));
  }
  return base::OkStatus();
}
// SqlFunction implementation backing every function defined with
// CREATE_FUNCTION. Per-function state lives in |Context| (defined below).
struct CreatedFunction : public SqlFunction {
  class Context;

  // Evaluates one call of the created function, writing the result to |out|.
  static base::Status Run(Context* ctx,
                          size_t argc,
                          sqlite3_value** argv,
                          SqlValue& out,
                          Destructors& destructors);

  // Checks that the statements evaluated by Run produced no extra rows.
  static base::Status VerifyPostConditions(Context* ctx);

  // Pops the statement stack entry pushed by Run.
  static void Cleanup(Context* ctx);
};
// Owning copy of an SqlValue. It lets a value be kept (e.g. memoized) after
// the SQLite-owned memory that the original string/bytes pointed into has
// been invalidated.
struct StoredSqlValue {
  // Heap-allocated via unique_ptr so that the pointers returned by
  // AsSqlValue() stay valid even if this object itself is moved (e.g. when
  // its owning container grows).
  using OwnedString = std::unique_ptr<std::string>;
  using OwnedBytes = std::unique_ptr<std::vector<uint8_t>>;

  // std::variant (rather than a raw union) ensures the destructors of the
  // non-trivial alternatives run correctly.
  using Data = std::
      variant<int64_t, double, OwnedString, OwnedBytes, std::nullptr_t>;

  // Deep-copies |value|; string and bytes payloads are copied into owned
  // buffers.
  explicit StoredSqlValue(SqlValue value) {
    switch (value.type) {
      case SqlValue::Type::kNull:
        data = nullptr;
        break;
      case SqlValue::Type::kLong:
        data = value.long_value;
        break;
      case SqlValue::Type::kDouble:
        data = value.double_value;
        break;
      case SqlValue::Type::kString:
        data = std::make_unique<std::string>(value.string_value);
        break;
      case SqlValue::Type::kBytes: {
        // Braced so the local's scope ends here and later case labels can be
        // added safely.
        const auto* begin = static_cast<const uint8_t*>(value.bytes_value);
        data = std::make_unique<std::vector<uint8_t>>(
            begin, begin + value.bytes_count);
        break;
      }
    }
  }

  // Returns a non-owning SqlValue view of the stored value. String and bytes
  // pointers remain valid for as long as this object holds the value.
  SqlValue AsSqlValue() const {
    if (std::holds_alternative<std::nullptr_t>(data))
      return SqlValue();
    if (const auto* long_value = std::get_if<int64_t>(&data))
      return SqlValue::Long(*long_value);
    if (const auto* double_value = std::get_if<double>(&data))
      return SqlValue::Double(*double_value);
    if (const auto* owned_str = std::get_if<OwnedString>(&data))
      return SqlValue::String((*owned_str)->c_str());
    if (const auto* owned_bytes = std::get_if<OwnedBytes>(&data))
      return SqlValue::Bytes((*owned_bytes)->data(), (*owned_bytes)->size());
    // GCC doesn't realize that the checks above are exhaustive.
    PERFETTO_CHECK(false);
    return SqlValue();
  }

  Data data = nullptr;
};
class Memoizer {
public:
// Supported arguments. For now, only functions with a single int argument are
// supported.
using MemoizedArgs = int64_t;
// Enables memoization.
// Only functions with a single int argument returning ints are supported.
base::Status EnableMemoization(const Prototype& prototype) {
if (prototype.arguments.size() != 1 ||
TypeToSqlValueType(prototype.arguments[0].type()) !=
SqlValue::Type::kLong) {
return base::ErrStatus(
"EXPERIMENTAL_MEMOIZE: Function %s should take one int argument",
prototype.function_name.c_str());
}
enabled_ = true;
return base::OkStatus();
}
// Returns the memoized value for the current invocation if it exists.
std::optional<SqlValue> GetMemoizedValue(MemoizedArgs args) {
if (!enabled_) {
return std::nullopt;
}
StoredSqlValue* value = memoized_values_.Find(args);
if (!value) {
return std::nullopt;
}
return value->AsSqlValue();
}
bool HasMemoizedValue(MemoizedArgs args) {
return GetMemoizedValue(args).has_value();
}
// Saves the return value of the current invocation for memoization.
void Memoize(MemoizedArgs args, SqlValue value) {
if (!enabled_) {
return;
}
memoized_values_.Insert(args, StoredSqlValue(value));
}
// Checks that the function has a single int argument and returns it.
static std::optional<MemoizedArgs> AsMemoizedArgs(size_t argc,
sqlite3_value** argv) {
if (argc != 1) {
return std::nullopt;
}
SqlValue arg = sqlite_utils::SqliteValueToSqlValue(argv[0]);
if (arg.type != SqlValue::Type::kLong) {
return std::nullopt;
}
return arg.AsLong();
}
bool enabled() const { return enabled_; }
private:
bool enabled_ = false;
base::FlatHashMap<MemoizedArgs, StoredSqlValue> memoized_values_;
};
// A helper to unroll recursive calls: to minimise the amount of stack space
// used, memoized recursive calls are evaluated using an on-heap queue.
//
// We compute the function in two passes:
// - In the first pass, we evaluate the statement to discover which recursive
// calls it makes, returning null from recursive calls and ignoring the
// result.
// - In the second pass, we evaluate the statement again, but this time we
// memoize the result of each recursive call.
//
// We maintain a queue for scheduled "first pass" calls and a stack for the
// scheduled "second pass" calls, evaluating available first pass calls, then
// second pass calls. When we evaluate a first pass call, the further calls to
// CreatedFunction::Run will just add it to the "first pass" queue. The second
// pass, however, will evaluate the function normally, typically just using the
// memoized result for the dependent calls. However, if the recursive calls
// depend on the return value of the function, we will proceed with normal
// recursion.
//
// To make it more concrete, consider the following example.
// We have a function computing factorial (f) and we want to compute f(3).
//
// SELECT create_function('f(x INT)', 'INT',
// 'SELECT IIF($x = 0, 1, $x * f($x - 1))');
// SELECT experimental_memoize('f');
// SELECT f(3);
//
// - We start with a call to f(3). It executes the statement as normal, which
// recursively calls f(2).
// - When f(2) is called, we detect that it is a recursive call and we start
// unrolling it, entering RecursiveCallUnroller::Run.
// - We schedule first pass for 2 and the state of the unroller
// is first_pass: [2], second_pass: [].
// - Then we compute the first pass for f(2). It calls f(1), which is ignored
// due to OnFunctionCall returning kIgnoreDueToFirstPass and 1 is added to the
// first pass queue. 2 is taken out of the first pass queue and moved to the
// second pass stack. State: first_pass: [1], second_pass: [2].
// - Then we compute the first pass for 1. A similar thing happens: f(0) is
// called and ignored, 0 is added to first_pass, 1 is added to second_pass.
// State: first_pass: [0], second_pass: [2, 1].
// - Then we compute the first pass for 0. It doesn't make further calls, so
// 0 is moved to the second pass stack.
// State: first_pass: [], second_pass: [2, 1, 0].
// - Then we compute the second pass for 0. It just returns 1.
// State: first_pass: [], second_pass: [2, 1], results: {0: 1}.
// - Then we compute the second pass for 1. It calls f(0), which is memoized.
// State: first_pass: [], second_pass: [2], results: {0: 1, 1: 1}.
// - Then we compute the second pass for 2. It calls f(1), which is memoized.
// State: first_pass: [], second_pass: [], results: {0: 1, 1: 1, 2: 2}.
// - As both first_pass and second_pass are empty, we return from
// RecursiveCallUnroller::Run.
// - Control is returned to CreatedFunction::Run for f(2), which returns
// memoized value.
// - Then control is returned to CreatedFunction::Run for f(3), which completes
// the computation.
class RecursiveCallUnroller {
public:
RecursiveCallUnroller(PerfettoSqlEngine* engine,
sqlite3_stmt* stmt,
const Prototype& prototype,
Memoizer& memoizer)
: engine_(engine),
stmt_(stmt),
prototype_(prototype),
memoizer_(memoizer) {}
// Whether we should just return null due to us being in the "first pass".
enum class FunctionCallState {
kIgnoreDueToFirstPass,
kEvaluate,
};
base::StatusOr<FunctionCallState> OnFunctionCall(
Memoizer::MemoizedArgs args) {
// If we are in the second pass, we just continue the function execution,
// including checking if a memoized value is available and returning it.
//
// We generally expect a memoized value to be available, but there are
// cases when it might not be the case, e.g. when which recursive calls are
// made depends on the return value of the function, e.g. for the following
// function, the first pass will not detect f(y) calls, so they will
// be computed recursively.
// f(x): SELECT max(f(y)) FROM y WHERE y < f($x - 1);
if (state_ == State::kComputingSecondPass) {
return FunctionCallState::kEvaluate;
}
if (!memoizer_.HasMemoizedValue(args)) {
ArgState* state = visited_.Find(args);
if (state) {
// Detect recursive loops, e.g. f(1) calling f(2) calling f(1).
if (*state == ArgState::kEvaluating) {
return base::ErrStatus("Infinite recursion detected");
}
} else {
visited_.Insert(args, ArgState::kScheduled);
first_pass_.push(args);
}
}
return FunctionCallState::kIgnoreDueToFirstPass;
}
base::Status Run(Memoizer::MemoizedArgs initial_args) {
PERFETTO_TP_TRACE(metatrace::Category::FUNCTION,
"UNROLL_RECURSIVE_FUNCTION_CALL",
[&](metatrace::Record* r) {
r->AddArg("Function", prototype_.function_name);
r->AddArg("Arg 0", std::to_string(initial_args));
});
first_pass_.push(initial_args);
visited_.Insert(initial_args, ArgState::kScheduled);
while (!first_pass_.empty() || !second_pass_.empty()) {
// If we have scheduled first pass calls, we evaluate them first.
if (!first_pass_.empty()) {
state_ = State::kComputingFirstPass;
Memoizer::MemoizedArgs args = first_pass_.front();
PERFETTO_TP_TRACE(metatrace::Category::FUNCTION, "SQL_FUNCTION_CALL",
[&](metatrace::Record* r) {
r->AddArg("Function", prototype_.function_name);
r->AddArg("Type", "UnrollRecursiveCall_FirstPass");
r->AddArg("Arg 0", std::to_string(args));
});
first_pass_.pop();
second_pass_.push(args);
Evaluate(args).status();
continue;
}
state_ = State::kComputingSecondPass;
Memoizer::MemoizedArgs args = second_pass_.top();
PERFETTO_TP_TRACE(metatrace::Category::FUNCTION, "SQL_FUNCTION_CALL",
[&](metatrace::Record* r) {
r->AddArg("Function", prototype_.function_name);
r->AddArg("Type", "UnrollRecursiveCall_SecondPass");
r->AddArg("Arg 0", std::to_string(args));
});
visited_.Insert(args, ArgState::kEvaluating);
second_pass_.pop();
base::StatusOr<std::optional<int64_t>> result = Evaluate(args);
RETURN_IF_ERROR(result.status());
std::optional<int64_t> maybe_int_result = result.value();
if (!maybe_int_result.has_value()) {
continue;
}
visited_.Insert(args, ArgState::kEvaluated);
memoizer_.Memoize(args, SqlValue::Long(*maybe_int_result));
}
return base::OkStatus();
}
private:
// This function returns:
// - base::ErrStatus if the evaluation of the function failed.
// - std::nullopt if the function returned a non-integer value.
// - the result of the function otherwise.
base::StatusOr<std::optional<int64_t>> Evaluate(Memoizer::MemoizedArgs args) {
RETURN_IF_ERROR(MaybeBindIntArgument(stmt_, prototype_.function_name,
prototype_.arguments[0], args));
base::StatusOr<SqlValue> result = EvaluateScalarStatement(
stmt_, engine_->sqlite_engine()->db(), prototype_);
sqlite3_reset(stmt_);
sqlite3_clear_bindings(stmt_);
RETURN_IF_ERROR(result.status());
if (result->type != SqlValue::Type::kLong) {
return std::optional<int64_t>(std::nullopt);
}
return std::optional<int64_t>(result->long_value);
}
PerfettoSqlEngine* engine_;
sqlite3_stmt* stmt_;
const Prototype& prototype_;
Memoizer& memoizer_;
// Current state of the evaluation.
enum class State {
kComputingFirstPass,
kComputingSecondPass,
};
State state_ = State::kComputingFirstPass;
// A state of evaluation of a given argument.
enum class ArgState {
kScheduled,
kEvaluating,
kEvaluated,
};
// See the class-level comment for the explanation of the two passes.
std::queue<Memoizer::MemoizedArgs> first_pass_;
base::FlatHashMap<Memoizer::MemoizedArgs, ArgState> visited_;
std::stack<Memoizer::MemoizedArgs> second_pass_;
};
// State of a function defined with CREATE_FUNCTION. It is shared by all
// invocations of the function, including recursive ones. It holds:
// - the definition,
// - one prepared statement per active recursion level,
// - the memoization cache, and
// - while a recursive call is being unrolled, the RecursiveCallUnroller.
class CreatedFunction::Context {
 public:
  explicit Context(PerfettoSqlEngine* engine) : engine_(engine) {}

  // Prepares a statement for the function body and pushes it onto the stack
  // of allocated statements. On success the definition is marked valid.
  base::Status PrepareStatement() {
    base::StatusOr<SqliteEngine::PreparedStatement> stmt =
        CreateStatement(engine_, sql_, prototype_str_);
    RETURN_IF_ERROR(stmt.status());
    is_valid_ = true;
    stmts_.push_back(std::move(stmt.value()));
    return base::OkStatus();
  }

  // Sets the definition of the function. Should be called only when the
  // function is invalid (i.e. when it is first created or when the previous
  // statement failed to prepare).
  void Reset(Prototype prototype,
             std::string prototype_str,
             sql_argument::Type return_type,
             std::string sql) {
    // Re-registration of valid functions is not allowed.
    PERFETTO_DCHECK(!is_valid_);
    PERFETTO_DCHECK(stmts_.empty());

    prototype_ = std::move(prototype);
    prototype_str_ = std::move(prototype_str);
    return_type_ = return_type;
    sql_ = std::move(sql);
  }

  // Called on entry to every invocation. Increments the recursion level and
  // prepares a statement for it if none has been allocated for that level
  // yet. The level is incremented even if preparation fails; PopStackEntry
  // undoes the increment in either case.
  base::Status PushStackEntry() {
    ++current_recursion_level_;
    if (current_recursion_level_ > stmts_.size()) {
      return PrepareStatement();
    }
    return base::OkStatus();
  }

  // Returns the statement that is used for the current invocation.
  sqlite3_stmt* CurrentStatement() {
    PERFETTO_DCHECK(current_recursion_level_ > 0);
    PERFETTO_DCHECK(current_recursion_level_ <= stmts_.size());
    return stmts_[current_recursion_level_ - 1].sqlite_stmt();
  }

  // Called each time an invocation returns. Resets the statement that this
  // invocation used and leaves its recursion level.
  //
  // The branches below handle being called after Run failed, including when
  // PushStackEntry could not prepare a statement.
  void PopStackEntry() {
    if (current_recursion_level_ == 0) {
      // Nothing was pushed (e.g. Run failed before reaching PushStackEntry),
      // so there is no statement to reset and no level to leave.
      return;
    }
    if (current_recursion_level_ <= stmts_.size()) {
      sqlite3_reset(CurrentStatement());
      sqlite3_clear_bindings(CurrentStatement());
    }
    // Always undo PushStackEntry's increment, even when preparing this
    // level's statement failed. Otherwise the level would stay ahead of
    // |stmts_|, and a later successful prepare would make CurrentStatement()
    // index past the end of |stmts_|.
    --current_recursion_level_;
  }

  // Routes a memoizable call to the active unroller. Without one, the call
  // is always evaluated normally.
  base::StatusOr<RecursiveCallUnroller::FunctionCallState> OnFunctionCall(
      Memoizer::MemoizedArgs args) {
    if (!recursive_call_unroller_) {
      return RecursiveCallUnroller::FunctionCallState::kEvaluate;
    }
    return recursive_call_unroller_->OnFunctionCall(args);
  }

  // Called before checking the function for memoization.
  base::Status UnrollRecursiveCallIfNeeded(Memoizer::MemoizedArgs args) {
    if (!memoizer_.enabled() || !is_in_recursive_call() ||
        recursive_call_unroller_) {
      return base::OkStatus();
    }
    // If we are in a recursive call, we need to check if we have already
    // computed the result for the current arguments.
    if (memoizer_.HasMemoizedValue(args)) {
      return base::OkStatus();
    }

    // If we are at the beginning of a function call that:
    // - is recursive,
    // - can be memoized,
    // - hasn't been memoized already, and
    // - hasn't started unrolling yet,
    // start the unrolling and run the unrolling loop.
    recursive_call_unroller_ = std::make_unique<RecursiveCallUnroller>(
        engine_, CurrentStatement(), prototype_, memoizer_);
    auto status = recursive_call_unroller_->Run(args);
    recursive_call_unroller_.reset();
    return status;
  }

  // Schedules |stmt| to be checked by ValidateEmptyStatements, which verifies
  // that it has no more rows.
  void ScheduleEmptyStatementValidation(sqlite3_stmt* stmt) {
    empty_stmts_to_validate_.push_back(stmt);
  }

  // Checks the scheduled statements, most recently scheduled first, for
  // having no more rows. Stops at the first error; statements that have not
  // been checked yet stay scheduled.
  base::Status ValidateEmptyStatements() {
    while (!empty_stmts_to_validate_.empty()) {
      sqlite3_stmt* stmt = empty_stmts_to_validate_.back();
      empty_stmts_to_validate_.pop_back();
      RETURN_IF_ERROR(
          CheckNoMoreRows(stmt, engine_->sqlite_engine()->db(), prototype_));
    }
    return base::OkStatus();
  }

  bool is_in_recursive_call() const { return current_recursion_level_ > 1; }

  base::Status EnableMemoization() {
    return memoizer_.EnableMemoization(prototype_);
  }

  PerfettoSqlEngine* engine() const { return engine_; }
  const Prototype& prototype() const { return prototype_; }
  sql_argument::Type return_type() const { return return_type_; }
  const std::string& sql() const { return sql_; }
  bool is_valid() const { return is_valid_; }
  Memoizer& memoizer() { return memoizer_; }

 private:
  PerfettoSqlEngine* engine_;
  Prototype prototype_;
  std::string prototype_str_;
  sql_argument::Type return_type_;
  std::string sql_;

  // Perfetto SQL functions support recursion. Each function call in the stack
  // needs its own statement, so we maintain a stack of prepared statements
  // and use the top one for each new call, allocating a new one if needed.
  std::vector<SqliteEngine::PreparedStatement> stmts_;
  // A list of statements to verify to ensure that they don't have more rows
  // in VerifyPostConditions.
  std::vector<sqlite3_stmt*> empty_stmts_to_validate_;
  size_t current_recursion_level_ = 0;

  // Function re-registration is not allowed, but the user is allowed to define
  // the function again if the first call failed. The |is_valid_| flag tracks
  // whether the current function definition is valid (in which case
  // re-registration is not allowed).
  bool is_valid_ = false;
  Memoizer memoizer_;
  // Set if we are in the middle of unrolling a recursive call.
  std::unique_ptr<RecursiveCallUnroller> recursive_call_unroller_;
};
// Evaluates one call of a function defined with CREATE_FUNCTION.
//
// Validates the argument count and types against the prototype, then pushes
// a statement stack entry for this recursion level.
//
// Calls with a single integer argument may then:
// - return NULL (first pass of recursive-call unrolling),
// - run the unroller, or
// - return a memoized value.
//
// Otherwise the call binds |argv| to the current statement, evaluates the
// body, stores the single result in |out|, and memoizes it if memoization is
// enabled.
//
// NOTE(review): the argument-count and type-check errors below are returned
// before PushStackEntry(). If Cleanup() (which calls PopStackEntry()) also
// runs when Run fails, those paths pop an entry that was never pushed.
// Confirm how the SqlFunction wrapper invokes Cleanup.
base::Status CreatedFunction::Run(CreatedFunction::Context* ctx,
size_t argc,
sqlite3_value** argv,
SqlValue& out,
Destructors&) {
if (argc != ctx->prototype().arguments.size()) {
return base::ErrStatus(
"%s: invalid number of args; expected %zu, received %zu",
ctx->prototype().function_name.c_str(),
ctx->prototype().arguments.size(), argc);
}
// Type check all the arguments.
for (size_t i = 0; i < argc; ++i) {
sqlite3_value* arg = argv[i];
sql_argument::Type type = ctx->prototype().arguments[i].type();
base::Status status = sqlite_utils::TypeCheckSqliteValue(
arg, sql_argument::TypeToSqlValueType(type),
sql_argument::TypeToHumanFriendlyString(type));
if (!status.ok()) {
return base::ErrStatus("%s[arg=%s]: argument %zu %s",
ctx->prototype().function_name.c_str(),
sqlite3_value_text(arg), i, status.c_message());
}
}
// Enter the function and ensure that we have a statement allocated.
RETURN_IF_ERROR(ctx->PushStackEntry());
// Set only for calls with exactly one integer argument. Only those calls
// take part in memoization and recursive-call unrolling.
std::optional<Memoizer::MemoizedArgs> memoized_args =
Memoizer::AsMemoizedArgs(argc, argv);
if (memoized_args) {
// If we are in the middle of an recursive calls unrolling, we might want to
// ignore the function invocation. See the comment in RecursiveCallUnroller
// for more details.
base::StatusOr<RecursiveCallUnroller::FunctionCallState> unroll_state =
ctx->OnFunctionCall(*memoized_args);
RETURN_IF_ERROR(unroll_state.status());
if (*unroll_state ==
RecursiveCallUnroller::FunctionCallState::kIgnoreDueToFirstPass) {
// Return NULL.
// NOTE(review): |out| is not written here, so this relies on the caller
// passing in a null SqlValue.
return base::OkStatus();
}
// If this is a recursive call that is not memoized and not already being
// unrolled, run the unroller now. It may memoize |*memoized_args|.
RETURN_IF_ERROR(ctx->UnrollRecursiveCallIfNeeded(*memoized_args));
std::optional<SqlValue> memoized_value =
ctx->memoizer().GetMemoizedValue(*memoized_args);
if (memoized_value) {
out = *memoized_value;
return base::OkStatus();
}
}
// Record the call and its arguments (rendered as text) in the metatrace.
PERFETTO_TP_TRACE(
metatrace::Category::FUNCTION, "SQL_FUNCTION_CALL",
[ctx, argv](metatrace::Record* r) {
r->AddArg("Function", ctx->prototype().function_name.c_str());
for (uint32_t i = 0; i < ctx->prototype().arguments.size(); ++i) {
std::string key = "Arg " + std::to_string(i);
const char* value =
reinterpret_cast<const char*>(sqlite3_value_text(argv[i]));
r->AddArg(base::StringView(key),
value ? base::StringView(value) : base::StringView("NULL"));
}
});
RETURN_IF_ERROR(
BindArguments(ctx->CurrentStatement(), ctx->prototype(), argc, argv));
auto result = EvaluateScalarStatement(ctx->CurrentStatement(),
ctx->engine()->sqlite_engine()->db(),
ctx->prototype());
RETURN_IF_ERROR(result.status());
// String/bytes data in |out| may point into SQLite-owned memory (see the
// note on EvaluateScalarStatement). Memoize() keeps its own owned copy.
out = result.value();
// The body must produce at most one row. VerifyPostConditions later steps
// this statement again to check that.
ctx->ScheduleEmptyStatementValidation(ctx->CurrentStatement());
if (memoized_args) {
ctx->memoizer().Memoize(*memoized_args, out);
}
return base::OkStatus();
}
// Invocation epilogue: resets this invocation's statement and leaves its
// recursion level via Context::PopStackEntry().
void CreatedFunction::Cleanup(CreatedFunction::Context* ctx) {
  CreatedFunction::Context& context = *ctx;
  context.PopStackEntry();
}
// Checks that every statement scheduled by Run (after reading its first row)
// has no further rows.
base::Status CreatedFunction::VerifyPostConditions(
    CreatedFunction::Context* ctx) {
  base::Status validation = ctx->ValidateEmptyStatements();
  return validation;
}
} // namespace
// Implements CREATE_FUNCTION(prototype, return_type, sql).
//
// Arguments:
// - argv[0]: the function prototype (e.g. 'f(x INT)').
// - argv[1]: the return type.
// - argv[2]: the SQL body.
// All three must be non-NULL strings.
//
// If no function with this name and argument count is registered yet, it is
// registered before its statement is prepared, so that the body can call the
// function recursively. If a valid definition already exists, the call
// succeeds only when the prototype, return type and SQL are all unchanged. A
// function whose earlier definition never prepared successfully may be
// redefined.
base::Status CreateFunction::Run(PerfettoSqlEngine* engine,
                                 size_t argc,
                                 sqlite3_value** argv,
                                 SqlValue&,
                                 Destructors&) {
  RETURN_IF_ERROR(sqlite_utils::CheckArgCount("CREATE_FUNCTION", argc, 3u));

  sqlite3_value* prototype_value = argv[0];
  sqlite3_value* return_type_value = argv[1];
  sqlite3_value* sql_defn_value = argv[2];

  // Type check all the arguments.
  {
    auto type_check = [prototype_value](sqlite3_value* value,
                                        SqlValue::Type type, const char* desc) {
      base::Status status = sqlite_utils::TypeCheckSqliteValue(value, type);
      // extract_string below needs non-null text, and sqlite3_value_text
      // returns NULL for SQL NULL values, so reject NULL explicitly.
      // NOTE(review): TypeCheckSqliteValue appears to accept NULL; if it does
      // not, this check is never reached.
      if (status.ok() && sqlite3_value_text(value) == nullptr) {
        status = base::ErrStatus("cannot be NULL");
      }
      if (!status.ok()) {
        // The prototype itself may be NULL, so don't hand a null pointer
        // to %s.
        const unsigned char* prototype_text =
            sqlite3_value_text(prototype_value);
        return base::ErrStatus(
            "CREATE_FUNCTION[prototype=%s]: %s %s",
            prototype_text ? reinterpret_cast<const char*>(prototype_text)
                           : "NULL",
            desc, status.c_message());
      }
      return base::OkStatus();
    };

    RETURN_IF_ERROR(type_check(prototype_value, SqlValue::Type::kString,
                               "function prototype (first argument)"));
    RETURN_IF_ERROR(type_check(return_type_value, SqlValue::Type::kString,
                               "return type (second argument)"));
    RETURN_IF_ERROR(type_check(sql_defn_value, SqlValue::Type::kString,
                               "SQL definition (third argument)"));
  }

  // Extract the arguments from the value wrappers.
  auto extract_string = [](sqlite3_value* value) -> base::StringView {
    return reinterpret_cast<const char*>(sqlite3_value_text(value));
  };
  base::StringView prototype_str = extract_string(prototype_value);
  base::StringView return_type_str = extract_string(return_type_value);
  std::string sql_defn_str = extract_string(sql_defn_value).ToStdString();

  // Parse all the arguments into a more friendly form.
  Prototype prototype;
  base::Status status = ParsePrototype(prototype_str, prototype);
  if (!status.ok()) {
    return base::ErrStatus("CREATE_FUNCTION[prototype=%s]: %s",
                           prototype_str.ToStdString().c_str(),
                           status.c_message());
  }

  // Parse the return type into an enum format.
  auto opt_return_type = sql_argument::ParseType(return_type_str);
  if (!opt_return_type) {
    return base::ErrStatus(
        "CREATE_FUNCTION[prototype=%s, return=%s]: unknown return type "
        "specified",
        prototype_str.ToStdString().c_str(),
        return_type_str.ToStdString().c_str());
  }

  std::string function_name = prototype.function_name;
  int created_argc = static_cast<int>(prototype.arguments.size());
  // NOTE(review): assumes any existing function with this name/argc was
  // created by CREATE_FUNCTION; a context of another kind would make this
  // static_cast invalid.
  auto* ctx = static_cast<CreatedFunction::Context*>(
      engine->sqlite_engine()->GetFunctionContext(prototype.function_name,
                                                  created_argc));
  if (!ctx) {
    // We register the function with SQLite before we prepare the statement so
    // the statement can reference the function itself, enabling recursive
    // calls.
    std::unique_ptr<CreatedFunction::Context> created_fn_ctx =
        std::make_unique<CreatedFunction::Context>(engine);
    ctx = created_fn_ctx.get();
    RETURN_IF_ERROR(engine->RegisterSqlFunction<CreatedFunction>(
        function_name.c_str(), created_argc, std::move(created_fn_ctx)));
  }
  if (ctx->is_valid()) {
    // If the function already exists, just verify that the prototype, return
    // type and SQL match exactly what we already had registered. By doing
    // this, we avoid the problem plaguing C++ macros where macro ordering
    // determines which one gets run.
    if (ctx->prototype() != prototype) {
      return base::ErrStatus(
          "CREATE_FUNCTION[prototype=%s]: function prototype changed",
          prototype_str.ToStdString().c_str());
    }

    if (ctx->return_type() != *opt_return_type) {
      return base::ErrStatus(
          "CREATE_FUNCTION[prototype=%s]: return type changed from %s to %s",
          prototype_str.ToStdString().c_str(),
          sql_argument::TypeToHumanFriendlyString(ctx->return_type()),
          return_type_str.ToStdString().c_str());
    }

    if (ctx->sql() != sql_defn_str) {
      return base::ErrStatus(
          "CREATE_FUNCTION[prototype=%s]: function SQL changed from %s to %s",
          prototype_str.ToStdString().c_str(), ctx->sql().c_str(),
          sql_defn_str.c_str());
    }
    return base::OkStatus();
  }

  ctx->Reset(std::move(prototype), prototype_str.ToStdString(),
             *opt_return_type, std::move(sql_defn_str));

  // Ideally, we would unregister the function here if the statement prep
  // failed, but SQLite doesn't allow unregistering functions inside active
  // statements. So instead we'll just try to prepare the statement when calling
  // this function, which will return an error.
  return ctx->PrepareStatement();
}
// Implements EXPERIMENTAL_MEMOIZE(function_name): enables memoization for the
// one-argument function |function_name|.
//
// Fails if:
// - the argument is not a string,
// - no one-argument function with that name is registered, or
// - that function's argument is not an INT.
base::Status ExperimentalMemoize::Run(PerfettoSqlEngine* engine,
                                      size_t argc,
                                      sqlite3_value** argv,
                                      SqlValue&,
                                      Destructors&) {
  RETURN_IF_ERROR(
      sqlite_utils::CheckArgCount("EXPERIMENTAL_MEMOIZE", argc, 1u));
  // Report argument errors under the same function name as every other error
  // produced by EXPERIMENTAL_MEMOIZE.
  base::StatusOr<std::string> function_name = sqlite_utils::ExtractStringArg(
      "EXPERIMENTAL_MEMOIZE", "function_name", argv[0]);
  RETURN_IF_ERROR(function_name.status());

  // Only functions registered with exactly one argument can be memoized.
  constexpr size_t kSupportedArgCount = 1;
  // NOTE(review): assumes a function found here was created by
  // CREATE_FUNCTION; a context of another kind would make this static_cast
  // invalid.
  CreatedFunction::Context* ctx = static_cast<CreatedFunction::Context*>(
      engine->sqlite_engine()->GetFunctionContext(function_name->c_str(),
                                                  kSupportedArgCount));
  if (!ctx) {
    return base::ErrStatus(
        "EXPERIMENTAL_MEMOIZE: Function %s(INT) does not exist",
        function_name->c_str());
  }
  return ctx->EnableMemoization();
}
} // namespace trace_processor
} // namespace perfetto