Merge "proto_filter: add option to passthrough submessages without recursing"
diff --git a/src/protozero/filtering/filter_util.cc b/src/protozero/filtering/filter_util.cc
index 8e81cb0..d17a36b 100644
--- a/src/protozero/filtering/filter_util.cc
+++ b/src/protozero/filtering/filter_util.cc
@@ -67,9 +67,13 @@
 FilterUtil::FilterUtil() = default;
 FilterUtil::~FilterUtil() = default;
 
-bool FilterUtil::LoadMessageDefinition(const std::string& proto_file,
-                                       const std::string& root_message,
-                                       const std::string& proto_dir_path) {
+bool FilterUtil::LoadMessageDefinition(
+    const std::string& proto_file,
+    const std::string& root_message,
+    const std::string& proto_dir_path,
+    const std::set<std::string>& passthrough_fields) {
+  passthrough_fields_ = passthrough_fields;
+  passthrough_fields_seen_.clear();
   // The protobuf compiler doesn't like backslashes and prints an error like:
   // Error C:\it7mjanpw3\perfetto-a16500 -1:0: Backslashes, consecutive slashes,
   // ".", or ".." are not allowed in the virtual path.
@@ -122,6 +126,22 @@
   // future without realizing) when performing the Dedupe() pass.
   DescriptorsByNameMap descriptors_by_full_name;
   ParseProtoDescriptor(root_msg, &descriptors_by_full_name);
+
+  // If the user specified a set of fields to pass through, print an error and
+  // fail if any of those fields has not been seen while recursing in the
+  // schema. This prevents typos or naming changes from being silently ignored.
+  std::vector<std::string> unused_passthrough;
+  std::set_difference(passthrough_fields_.begin(), passthrough_fields_.end(),
+                      passthrough_fields_seen_.begin(),
+                      passthrough_fields_seen_.end(),
+                      std::back_inserter(unused_passthrough));
+  for (const std::string& message_and_field : unused_passthrough) {
+    PERFETTO_ELOG("Field not found %s", message_and_field.c_str());
+  }
+  if (!unused_passthrough.empty()) {
+    PERFETTO_ELOG("Syntax: perfetto.protos.MessageName:field_name");
+    return false;
+  }
   return true;
 }
 
@@ -145,7 +165,15 @@
     auto& field = msg->fields[field_id];
     field.name = proto_field->name();
     field.type = proto_field->type_name();
-    if (proto_field->message_type()) {
+
+    std::string message_and_field = msg->full_name + ":" + field.name;
+    bool passthrough = false;
+    if (passthrough_fields_.count(message_and_field)) {
+      field.type = "bytes";
+      passthrough = true;
+      passthrough_fields_seen_.insert(message_and_field);
+    }
+    if (proto_field->message_type() && !passthrough) {
       msg->has_nested_fields = true;
       // Recurse.
       field.nested_type = ParseProtoDescriptor(proto_field->message_type(),
@@ -253,8 +281,11 @@
       }
 
       const Message* nested_type = id_and_field.second.nested_type;
+      bool passthrough = false;
       if (nested_type) {
-        PERFETTO_CHECK(!result.simple_field() || !filter_bytecode);
+        // result.simple_field might be true if the generated bytecode is
+        // passing through a whole submessage without recursing.
+        passthrough = result.simple_field();
         if (seen_msgs.find(nested_type) == seen_msgs.end()) {
           seen_msgs.insert(nested_type);
           queue.emplace_back(result.nested_msg_index, nested_type);
@@ -267,6 +298,8 @@
       std::string stripped_nested =
           nested_type ? " " + StripPrefix(nested_type->full_name, root_prefix)
                       : "";
+      if (passthrough)
+        stripped_nested += "  # PASSTHROUGH";
       fprintf(print_stream_, "%-60s %3u %-8s %-32s%s\n", stripped_name.c_str(),
               field_id, field.type.c_str(), field.name.c_str(),
               stripped_nested.c_str());
diff --git a/src/protozero/filtering/filter_util.h b/src/protozero/filtering/filter_util.h
index 0c5e74f..276ec6d 100644
--- a/src/protozero/filtering/filter_util.h
+++ b/src/protozero/filtering/filter_util.h
@@ -22,6 +22,7 @@
 #include <list>
 #include <map>
 #include <optional>
+#include <set>
 #include <string>
 
 // We include this intentionally instead of forward declaring to allow
@@ -46,9 +47,14 @@
   // root_message: fully qualified message name (e.g., perfetto.protos.Trace).
   //     If empty, the first message in the file will be used.
   // proto_dir_path: the root for .proto includes. If empty uses CWD.
-  bool LoadMessageDefinition(const std::string& proto_file,
-                             const std::string& root_message,
-                             const std::string& proto_dir_path);
+  // passthrough_fields: an optional set of fields that should be passed
+  //     through transparently, without recursing further.
+  //     Syntax: "perfetto.protos.TracePacket:trace_config"
+  bool LoadMessageDefinition(
+      const std::string& proto_file,
+      const std::string& root_message,
+      const std::string& proto_dir_path,
+      const std::set<std::string>& passthrough_fields = {});
 
   // Deduplicates leaf messages having the same sets of field ids.
   // It changes the internal state and affects the behavior of next calls to
@@ -103,6 +109,12 @@
 
   // list<> because pointers need to be stable.
   std::list<Message> descriptors_;
+  std::set<std::string> passthrough_fields_;
+
+  // Used only for debugging aid, to print out an error message when the user
+  // specifies a field to pass through but it doesn't exist.
+  std::set<std::string> passthrough_fields_seen_;
+
   FILE* print_stream_ = stdout;
 };
 
diff --git a/src/protozero/filtering/filter_util_unittest.cc b/src/protozero/filtering/filter_util_unittest.cc
index 7ad4d6f..2b13ae6 100644
--- a/src/protozero/filtering/filter_util_unittest.cc
+++ b/src/protozero/filtering/filter_util_unittest.cc
@@ -302,5 +302,42 @@
             FilterToText(filter, bytecode));
 }
 
+TEST(SchemaParserTest, Passthrough) {
+  auto schema = MkTemp(R"(
+  syntax = "proto2";
+  message Root {
+    optional int32 i32 = 13;
+    optional TracePacket packet = 7;
+  }
+  message TraceConfig {
+    optional int32 f3 = 3;
+    optional int64 f4 = 4;
+  }
+  message TracePacket {
+    optional int32 f1 = 3;
+    optional int64 f2 = 4;
+    optional TraceConfig cfg = 5;
+  }
+  )");
+
+  FilterUtil filter;
+  std::set<std::string> passthrough{"TracePacket:cfg"};
+  ASSERT_TRUE(
+      filter.LoadMessageDefinition(schema.path(), "Root", "", passthrough));
+
+  EXPECT_EQ(R"(Root 7 message packet TracePacket
+Root 13 int32 i32
+TracePacket 3 int32 f1
+TracePacket 4 int64 f2
+TracePacket 5 bytes cfg
+)",
+            FilterToText(filter));
+
+  std::string bytecode = filter.GenerateFilterBytecode();
+  // If we generate bytecode from the schema itself, all fields are allowed and
+  // the result is identical to the unfiltered output.
+  EXPECT_EQ(FilterToText(filter), FilterToText(filter, bytecode));
+}
+
 }  // namespace
 }  // namespace protozero
diff --git a/src/protozero/filtering/message_filter_unittest.cc b/src/protozero/filtering/message_filter_unittest.cc
index 83cb911..3159bfc 100644
--- a/src/protozero/filtering/message_filter_unittest.cc
+++ b/src/protozero/filtering/message_filter_unittest.cc
@@ -128,6 +128,87 @@
   }
 }
 
+TEST(MessageFilterTest, Passthrough) {
+  auto schema = perfetto::base::TempFile::Create();
+  static const char kSchema[] = R"(
+  syntax = "proto2";
+  message TracePacket {
+    optional int64 timestamp = 1;
+    optional TraceConfig cfg = 2;
+    optional TraceConfig cfg_filtered = 3;
+    optional string other = 4;
+  };
+  message SubConfig {
+    optional string f4 = 6;
+  }
+  message TraceConfig {
+    optional int64 f1 = 3;
+    optional string f2 = 4;
+    optional SubConfig f3 = 5;
+  }
+  )";
+
+  perfetto::base::WriteAll(*schema, kSchema, strlen(kSchema));
+  perfetto::base::FlushFile(*schema);
+
+  FilterUtil filter;
+  ASSERT_TRUE(filter.LoadMessageDefinition(
+      schema.path(), "", "", {"TracePacket:other", "TracePacket:cfg"}));
+  std::string bytecode = filter.GenerateFilterBytecode();
+  ASSERT_GT(bytecode.size(), 0u);
+
+  HeapBuffered<Message> msg;
+  msg->AppendVarInt(/*field_id=*/1, 10);
+  msg->AppendString(/*field_id=*/4, "other_string");
+
+  // Fill `cfg`.
+  auto* nest = msg->BeginNestedMessage<Message>(/*field_id=*/2);
+  nest->AppendVarInt(/*field_id=*/3, 100);
+  nest->AppendString(/*field_id=*/4, "f2.payload");
+  nest->AppendString(/*field_id=*/99, "not_in_original_schema");
+  auto* nest2 = nest->BeginNestedMessage<Message>(/*field_id=*/5);
+  nest2->AppendString(/*field_id=*/6, "subconfig.f4");
+  nest2->Finalize();
+  nest->Finalize();
+
+  // Fill `cfg_filtered`.
+  nest = msg->BeginNestedMessage<Message>(/*field_id=*/3);
+  nest->AppendVarInt(/*field_id=*/3, 200);  // This should be propagated.
+  nest->AppendVarInt(/*field_id=*/6, 300);  // This should be filtered out.
+  nest->Finalize();
+
+  MessageFilter flt;
+  ASSERT_TRUE(flt.LoadFilterBytecode(bytecode.data(), bytecode.size()));
+
+  std::vector<uint8_t> encoded = msg.SerializeAsArray();
+
+  auto filtered = flt.FilterMessage(encoded.data(), encoded.size());
+  ASSERT_LT(filtered.size, encoded.size());
+
+  ProtoDecoder dec(filtered.data.get(), filtered.size);
+  EXPECT_EQ(dec.FindField(1).as_int64(), 10);
+  EXPECT_EQ(dec.FindField(4).as_std_string(), "other_string");
+
+  EXPECT_TRUE(dec.FindField(2).valid());
+  ProtoDecoder nest_dec(dec.FindField(2).as_bytes());
+  EXPECT_EQ(nest_dec.FindField(3).as_int32(), 100);
+  EXPECT_EQ(nest_dec.FindField(4).as_std_string(), "f2.payload");
+  EXPECT_TRUE(nest_dec.FindField(5).valid());
+  ProtoDecoder nest_dec2(nest_dec.FindField(5).as_bytes());
+  EXPECT_EQ(nest_dec2.FindField(6).as_std_string(), "subconfig.f4");
+
+  // Field 99 should be preserved even though it isn't in the original schema,
+  // because the whole `cfg` (TraceConfig) submessage was passed through.
+  EXPECT_TRUE(nest_dec.FindField(99).valid());
+  EXPECT_EQ(nest_dec.FindField(99).as_std_string(), "not_in_original_schema");
+
+  // Check that the field `cfg_filtered` contains only `f1`,`f2`,`f3`.
+  EXPECT_TRUE(dec.FindField(3).valid());
+  ProtoDecoder nest_dec3(dec.FindField(3).as_bytes());
+  EXPECT_EQ(nest_dec3.FindField(3).as_int32(), 200);
+  EXPECT_FALSE(nest_dec3.FindField(6).valid());
+}
+
 TEST(MessageFilterTest, ChangeRoot) {
   auto schema = perfetto::base::TempFile::Create();
   static const char kSchema[] = R"(
diff --git a/src/tools/proto_filter/proto_filter.cc b/src/tools/proto_filter/proto_filter.cc
index 04385c2..7bfcb74 100644
--- a/src/tools/proto_filter/proto_filter.cc
+++ b/src/tools/proto_filter/proto_filter.cc
@@ -49,7 +49,7 @@
 # Generate the filter bytecode from a .proto schema
 
   proto_filter -r perfetto.protos.Trace -s protos/perfetto/trace/trace.proto \
-               -F /tmp/bytecode [--dedupe]
+               -F /tmp/bytecode [--dedupe] [-x protos.Message:field_to_pass]
 
 # List the used/filtered fields from a trace file
 
@@ -72,14 +72,15 @@
       {"help", no_argument, nullptr, 'h'},
       {"version", no_argument, nullptr, 'v'},
       {"dedupe", no_argument, nullptr, 'd'},
-      {"proto_path", no_argument, nullptr, 'I'},
-      {"schema_in", no_argument, nullptr, 's'},
-      {"root_message", no_argument, nullptr, 'r'},
-      {"msg_in", no_argument, nullptr, 'i'},
-      {"msg_out", no_argument, nullptr, 'o'},
-      {"filter_in", no_argument, nullptr, 'f'},
-      {"filter_out", no_argument, nullptr, 'F'},
-      {"filter_oct_out", no_argument, nullptr, 'T'},
+      {"proto_path", required_argument, nullptr, 'I'},
+      {"schema_in", required_argument, nullptr, 's'},
+      {"root_message", required_argument, nullptr, 'r'},
+      {"msg_in", required_argument, nullptr, 'i'},
+      {"msg_out", required_argument, nullptr, 'o'},
+      {"filter_in", required_argument, nullptr, 'f'},
+      {"filter_out", required_argument, nullptr, 'F'},
+      {"filter_oct_out", required_argument, nullptr, 'T'},
+      {"passthrough", required_argument, nullptr, 'x'},
       {nullptr, 0, nullptr, 0}};
 
   std::string msg_in;
@@ -90,11 +91,12 @@
   std::string filter_oct_out;
   std::string proto_path;
   std::string root_message_arg;
+  std::set<std::string> passthrough_fields;
   bool dedupe = false;
 
   for (;;) {
     int option =
-        getopt_long(argc, argv, "hvdI:s:r:i:o:f:F:T:", long_options, nullptr);
+        getopt_long(argc, argv, "hvdI:s:r:i:o:f:F:T:x:", long_options, nullptr);
 
     if (option == -1)
       break;  // EOF.
@@ -149,6 +151,11 @@
       continue;
     }
 
+    if (option == 'x') {
+      passthrough_fields.insert(optarg);
+      continue;
+    }
+
     if (option == 'h') {
       fprintf(stdout, kUsage);
       exit(0);
@@ -175,8 +182,8 @@
   protozero::FilterUtil filter;
   if (!schema_in.empty()) {
     PERFETTO_LOG("Loading proto schema from %s", schema_in.c_str());
-    if (!filter.LoadMessageDefinition(schema_in, root_message_arg,
-                                      proto_path)) {
+    if (!filter.LoadMessageDefinition(schema_in, root_message_arg, proto_path,
+                                      passthrough_fields)) {
       PERFETTO_ELOG("Failed to parse proto schema from %s", schema_in.c_str());
       return 1;
     }
@@ -264,7 +271,7 @@
   if (!msg_out.empty()) {
     PERFETTO_LOG("Writing filtered proto bytes (%zu bytes) into %s",
                  msg_filtered_data.size(), msg_out.c_str());
-    auto fd = base::OpenFile(msg_out, O_WRONLY | O_CREAT, 0644);
+    auto fd = base::OpenFile(msg_out, O_WRONLY | O_TRUNC | O_CREAT, 0644);
     base::WriteAll(*fd, msg_filtered_data.data(), msg_filtered_data.size());
   }