Merge "tp: allow RewriteAll function to deal with rewritten SqlSources" into main
diff --git a/src/trace_processor/perfetto_sql/engine/perfetto_sql_engine.cc b/src/trace_processor/perfetto_sql/engine/perfetto_sql_engine.cc
index b34f556..1066192 100644
--- a/src/trace_processor/perfetto_sql/engine/perfetto_sql_engine.cc
+++ b/src/trace_processor/perfetto_sql/engine/perfetto_sql_engine.cc
@@ -91,6 +91,13 @@
   return status;
 }
 
+// Rewrites |source| to a no-op statement ("SELECT 0 WHERE 0" returns no rows)
+// while preserving the context of |source| for tracebacks.
+SqlSource RewriteToDummySql(const SqlSource& source) {
+  return source.RewriteAllIgnoreExisting(
+      SqlSource::FromTraceProcessorImplementation("SELECT 0 WHERE 0"));
+}
+
 }  // namespace
 
 PerfettoSqlEngine::PerfettoSqlEngine(StringPool* pool)
@@ -202,15 +207,13 @@
           RegisterRuntimeTable(cst->name, cst->sql), parser.statement_sql()));
       // Since the rest of the code requires a statement, just use a no-value
       // dummy statement.
-      source = parser.statement_sql().FullRewrite(
-          SqlSource::FromTraceProcessorImplementation("SELECT 0 WHERE 0"));
+      source = RewriteToDummySql(parser.statement_sql());
     } else if (auto* include = std::get_if<PerfettoSqlParser::Include>(
                    &parser.statement())) {
       RETURN_IF_ERROR(ExecuteInclude(*include, parser));
       // Since the rest of the code requires a statement, just use a no-value
       // dummy statement.
-      source = parser.statement_sql().FullRewrite(
-          SqlSource::FromTraceProcessorImplementation("SELECT 0 WHERE 0"));
+      source = RewriteToDummySql(parser.statement_sql());
     } else {
       // If none of the above matched, this must just be an SQL statement
       // directly executable by SQLite.
@@ -450,8 +453,7 @@
 
     // Since the rest of the code requires a statement, just use a no-value
     // dummy statement.
-    return parser.statement_sql().FullRewrite(
-        SqlSource::FromTraceProcessorImplementation("SELECT 0 WHERE 0"));
+    return RewriteToDummySql(parser.statement_sql());
   }
 
   RuntimeTableFunction::State state{cf.prototype, cf.sql, {}, {}, std::nullopt};
@@ -565,7 +567,7 @@
 
   base::StackString<1024> create(
       "CREATE VIRTUAL TABLE %s USING runtime_table_function", fn_name.c_str());
-  return cf.sql.FullRewrite(
+  return cf.sql.RewriteAllIgnoreExisting(
       SqlSource::FromTraceProcessorImplementation(create.ToStdString()));
 }
 
diff --git a/src/trace_processor/sqlite/sql_source.cc b/src/trace_processor/sqlite/sql_source.cc
index 8a08ae1..449e638 100644
--- a/src/trace_processor/sqlite/sql_source.cc
+++ b/src/trace_processor/sqlite/sql_source.cc
@@ -141,9 +141,18 @@
   return source;
 }
 
-SqlSource SqlSource::FullRewrite(SqlSource source) const {
-  SqlSource::Rewriter rewriter(*this);
-  rewriter.Rewrite(0, static_cast<uint32_t>(sql().size()), source);
+SqlSource SqlSource::RewriteAllIgnoreExisting(SqlSource source) const {
+  // Reset any rewrites: start from a copy of |this| whose rewritten SQL is
+  // its original SQL so that |source| replaces the original statement.
+  SqlSource copy = *this;
+  copy.root_.rewritten_sql = copy.root_.original_sql;
+  copy.root_.rewrites.clear();
+
+  // Replace the whole of |copy|'s SQL: after the reset this is the original
+  // SQL, whose length may differ from |sql()| if |this| was rewritten.
+  SqlSource::Rewriter rewriter(std::move(copy));
+  rewriter.Rewrite(0, static_cast<uint32_t>(root_.original_sql.size()),
+                   std::move(source));
   return std::move(rewriter).Build();
 }
 
diff --git a/src/trace_processor/sqlite/sql_source.h b/src/trace_processor/sqlite/sql_source.h
index b4b9870..a91ce50 100644
--- a/src/trace_processor/sqlite/sql_source.h
+++ b/src/trace_processor/sqlite/sql_source.h
@@ -78,15 +78,14 @@
   // at |offset| with |len| characters.
   SqlSource Substr(uint32_t offset, uint32_t len) const;
 
-  // Creates a SqlSource instance with the execution SQL rewritten to
-  // |rewrite_sql| but preserving the context from |this|.
+  // Creates a SqlSource instance with the execution SQL rewritten to the SQL
+  // of |source| while preserving the context from |this|. Any existing
+  // rewrites of |this| are discarded first (i.e. the original SQL of |this| is
+  // replaced) so this may be called even if |IsRewritten()| returns true.
   //
   // This is useful when PerfettoSQL statements are transpiled into SQLite
   // statements but we want to preserve the context of the original statement.
-  //
-  // Note: this function should only be called if |this| has not already been
-  // rewritten (i.e. it is undefined behaviour if |IsRewritten()| returns true).
-  SqlSource FullRewrite(SqlSource) const;
+  SqlSource RewriteAllIgnoreExisting(SqlSource source) const;
 
   // Returns the SQL string backing this SqlSource instance;
   const std::string& sql() const { return root_.rewritten_sql; }
diff --git a/src/trace_processor/sqlite/sql_source_unittest.cc b/src/trace_processor/sqlite/sql_source_unittest.cc
index c0d5a88..ef8500b 100644
--- a/src/trace_processor/sqlite/sql_source_unittest.cc
+++ b/src/trace_processor/sqlite/sql_source_unittest.cc
@@ -53,10 +53,10 @@
             "          ^\n");
 }
 
-TEST(SqlSourceTest, FullRewrite) {
+TEST(SqlSourceTest, RewriteAllIgnoreExisting) {
   SqlSource source =
       SqlSource::FromExecuteQuery("macro!()")
-          .FullRewrite(SqlSource::FromTraceProcessorImplementation(
+          .RewriteAllIgnoreExisting(SqlSource::FromTraceProcessorImplementation(
               "SELECT * FROM slice"));
   ASSERT_EQ(source.sql(), "SELECT * FROM slice");
 
@@ -81,12 +81,12 @@
-TEST(SqlSourceTest, NestedFullRewrite) {
+TEST(SqlSourceTest, NestedRewriteAllIgnoreExisting) {
   SqlSource nested =
       SqlSource::FromTraceProcessorImplementation("nested!()")
-          .FullRewrite(SqlSource::FromTraceProcessorImplementation(
+          .RewriteAllIgnoreExisting(SqlSource::FromTraceProcessorImplementation(
               "SELECT * FROM slice"));
   ASSERT_EQ(nested.sql(), "SELECT * FROM slice");
 
-  SqlSource source =
-      SqlSource::FromExecuteQuery("macro!()").FullRewrite(std::move(nested));
+  SqlSource source = SqlSource::FromExecuteQuery("macro!()")
+                         .RewriteAllIgnoreExisting(std::move(nested));
   ASSERT_EQ(source.sql(), "SELECT * FROM slice");
 
   ASSERT_EQ(source.AsTraceback(0),
@@ -113,6 +113,34 @@
             "           ^\n");
 }
 
+TEST(SqlSourceTest, RewriteAllIgnoresExistingCorrectly) {
+  SqlSource foo =
+      SqlSource::FromExecuteQuery("foo!()").RewriteAllIgnoreExisting(
+          SqlSource::FromTraceProcessorImplementation("SELECT * FROM slice"));
+  // |foo| is already rewritten: rewriting it again should replace "foo!()"
+  // directly, leaving no trace of the intermediate "SELECT * FROM slice".
+  SqlSource source = foo.RewriteAllIgnoreExisting(
+      SqlSource::FromTraceProcessorImplementation("SELECT 0 WHERE 0"));
+  ASSERT_EQ(source.sql(), "SELECT 0 WHERE 0");
+
+  ASSERT_EQ(source.AsTraceback(0),
+            "Traceback (most recent call last):\n"
+            "  File \"stdin\" line 1 col 1\n"
+            "    foo!()\n"
+            "    ^\n"
+            "  Trace Processor Internal line 1 col 1\n"
+            "    SELECT 0 WHERE 0\n"
+            "    ^\n");
+  ASSERT_EQ(source.AsTraceback(4),
+            "Traceback (most recent call last):\n"
+            "  File \"stdin\" line 1 col 1\n"
+            "    foo!()\n"
+            "    ^\n"
+            "  Trace Processor Internal line 1 col 5\n"
+            "    SELECT 0 WHERE 0\n"
+            "        ^\n");
+}
+
 TEST(SqlSourceTest, Rewriter) {
   SqlSource::Rewriter rewriter(
       SqlSource::FromExecuteQuery("SELECT cols!() FROM slice"));