Merge "proto_filter: add option to passthrough submessages without recursing"
diff --git a/src/protozero/filtering/filter_util.cc b/src/protozero/filtering/filter_util.cc
index 8e81cb0..d17a36b 100644
--- a/src/protozero/filtering/filter_util.cc
+++ b/src/protozero/filtering/filter_util.cc
@@ -67,9 +67,13 @@
FilterUtil::FilterUtil() = default;
FilterUtil::~FilterUtil() = default;
-bool FilterUtil::LoadMessageDefinition(const std::string& proto_file,
- const std::string& root_message,
- const std::string& proto_dir_path) {
+bool FilterUtil::LoadMessageDefinition(
+ const std::string& proto_file,
+ const std::string& root_message,
+ const std::string& proto_dir_path,
+ const std::set<std::string>& passthrough_fields) {
+ passthrough_fields_ = passthrough_fields;
+ passthrough_fields_seen_.clear();
// The protobuf compiler doesn't like backslashes and prints an error like:
// Error C:\it7mjanpw3\perfetto-a16500 -1:0: Backslashes, consecutive slashes,
// ".", or ".." are not allowed in the virtual path.
@@ -122,6 +126,22 @@
// future without realizing) when performing the Dedupe() pass.
DescriptorsByNameMap descriptors_by_full_name;
ParseProtoDescriptor(root_msg, &descriptors_by_full_name);
+
+ // If the user specified a set of fields to pass through, print an error and
+ // fail if any of the passed fields have not been seen while recursing in the
+ // schema. This prevents typos or naming changes from being silently ignored.
+ std::vector<std::string> unused_passthrough;
+ std::set_difference(passthrough_fields_.begin(), passthrough_fields_.end(),
+ passthrough_fields_seen_.begin(),
+ passthrough_fields_seen_.end(),
+ std::back_inserter(unused_passthrough));
+ for (const std::string& message_and_field : unused_passthrough) {
+ PERFETTO_ELOG("Field not found %s", message_and_field.c_str());
+ }
+ if (!unused_passthrough.empty()) {
+ PERFETTO_ELOG("Syntax: perfetto.protos.MessageName:field_name");
+ return false;
+ }
return true;
}
@@ -145,7 +165,15 @@
auto& field = msg->fields[field_id];
field.name = proto_field->name();
field.type = proto_field->type_name();
- if (proto_field->message_type()) {
+
+ std::string message_and_field = msg->full_name + ":" + field.name;
+ bool passthrough = false;
+ if (passthrough_fields_.count(message_and_field)) {
+ field.type = "bytes";
+ passthrough = true;
+ passthrough_fields_seen_.insert(message_and_field);
+ }
+ if (proto_field->message_type() && !passthrough) {
msg->has_nested_fields = true;
// Recurse.
field.nested_type = ParseProtoDescriptor(proto_field->message_type(),
@@ -253,8 +281,11 @@
}
const Message* nested_type = id_and_field.second.nested_type;
+ bool passthrough = false;
if (nested_type) {
- PERFETTO_CHECK(!result.simple_field() || !filter_bytecode);
+ // result.simple_field might be true if the generated bytecode is
+ // passing through a whole submessage without recursing.
+ passthrough = result.simple_field();
if (seen_msgs.find(nested_type) == seen_msgs.end()) {
seen_msgs.insert(nested_type);
queue.emplace_back(result.nested_msg_index, nested_type);
@@ -267,6 +298,8 @@
std::string stripped_nested =
nested_type ? " " + StripPrefix(nested_type->full_name, root_prefix)
: "";
+ if (passthrough)
+ stripped_nested += " # PASSTHROUGH";
fprintf(print_stream_, "%-60s %3u %-8s %-32s%s\n", stripped_name.c_str(),
field_id, field.type.c_str(), field.name.c_str(),
stripped_nested.c_str());
diff --git a/src/protozero/filtering/filter_util.h b/src/protozero/filtering/filter_util.h
index 0c5e74f..276ec6d 100644
--- a/src/protozero/filtering/filter_util.h
+++ b/src/protozero/filtering/filter_util.h
@@ -22,6 +22,7 @@
#include <list>
#include <map>
#include <optional>
+#include <set>
#include <string>
// We include this intentionally instead of forward declaring to allow
@@ -46,9 +47,14 @@
// root_message: fully qualified message name (e.g., perfetto.protos.Trace).
// If empty, the first message in the file will be used.
// proto_dir_path: the root for .proto includes. If empty uses CWD.
- bool LoadMessageDefinition(const std::string& proto_file,
- const std::string& root_message,
- const std::string& proto_dir_path);
+ // passthrough_fields: an optional set of fields that should be passed
+ // through transparently, without recursing into their submessages.
+ // Syntax: "perfetto.protos.TracePacket:trace_config"
+ bool LoadMessageDefinition(
+ const std::string& proto_file,
+ const std::string& root_message,
+ const std::string& proto_dir_path,
+ const std::set<std::string>& passthrough_fields = {});
// Deduplicates leaf messages having the same sets of field ids.
// It changes the internal state and affects the behavior of next calls to
@@ -103,6 +109,12 @@
// list<> because pointers need to be stable.
std::list<Message> descriptors_;
+ std::set<std::string> passthrough_fields_;
+
+ // Used only for debugging aid, to print out an error message when the user
+ // specifies a field to pass through but it doesn't exist.
+ std::set<std::string> passthrough_fields_seen_;
+
FILE* print_stream_ = stdout;
};
diff --git a/src/protozero/filtering/filter_util_unittest.cc b/src/protozero/filtering/filter_util_unittest.cc
index 7ad4d6f..2b13ae6 100644
--- a/src/protozero/filtering/filter_util_unittest.cc
+++ b/src/protozero/filtering/filter_util_unittest.cc
@@ -302,5 +302,42 @@
FilterToText(filter, bytecode));
}
+TEST(SchemaParserTest, Passthrough) {
+ auto schema = MkTemp(R"(
+ syntax = "proto2";
+ message Root {
+ optional int32 i32 = 13;
+ optional TracePacket packet = 7;
+ }
+ message TraceConfig {
+ optional int32 f3 = 3;
+ optional int64 f4 = 4;
+ }
+ message TracePacket {
+ optional int32 f1 = 3;
+ optional int64 f2 = 4;
+ optional TraceConfig cfg = 5;
+ }
+ )");
+ // Pass `cfg` through: it must show up as `bytes`, with no TraceConfig rows.
+ FilterUtil filter;
+ std::set<std::string> passthrough{"TracePacket:cfg"};
+ ASSERT_TRUE(
+ filter.LoadMessageDefinition(schema.path(), "Root", "", passthrough));
+
+ EXPECT_EQ(R"(Root 7 message packet TracePacket
+Root 13 int32 i32
+TracePacket 3 int32 f1
+TracePacket 4 int64 f2
+TracePacket 5 bytes cfg
+)",
+ FilterToText(filter));
+
+ std::string bytecode = filter.GenerateFilterBytecode();
+ // If we generate bytecode from the schema itself, all fields are allowed and
+ // the result is identical to the unfiltered output.
+ EXPECT_EQ(FilterToText(filter), FilterToText(filter, bytecode));
+}
+
} // namespace
} // namespace protozero
diff --git a/src/protozero/filtering/message_filter_unittest.cc b/src/protozero/filtering/message_filter_unittest.cc
index 83cb911..3159bfc 100644
--- a/src/protozero/filtering/message_filter_unittest.cc
+++ b/src/protozero/filtering/message_filter_unittest.cc
@@ -128,6 +128,87 @@
}
}
+TEST(MessageFilterTest, Passthrough) {
+ auto schema = perfetto::base::TempFile::Create();
+ static const char kSchema[] = R"(
+ syntax = "proto2";
+ message TracePacket {
+ optional int64 timestamp = 1;
+ optional TraceConfig cfg = 2;
+ optional TraceConfig cfg_filtered = 3;
+ optional string other = 4;
+ };
+ message SubConfig {
+ optional string f4 = 6;
+ }
+ message TraceConfig {
+ optional int64 f1 = 3;
+ optional string f2 = 4;
+ optional SubConfig f3 = 5;
+ }
+ )";
+
+ perfetto::base::WriteAll(*schema, kSchema, strlen(kSchema));
+ perfetto::base::FlushFile(*schema);
+ // `other` and `cfg` are passed through; `cfg_filtered` is filtered normally.
+ FilterUtil filter;
+ ASSERT_TRUE(filter.LoadMessageDefinition(
+ schema.path(), "", "", {"TracePacket:other", "TracePacket:cfg"}));
+ std::string bytecode = filter.GenerateFilterBytecode();
+ ASSERT_GT(bytecode.size(), 0u);
+
+ HeapBuffered<Message> msg;
+ msg->AppendVarInt(/*field_id=*/1, 10);
+ msg->AppendString(/*field_id=*/4, "other_string");
+
+ // Fill `cfg`.
+ auto* nest = msg->BeginNestedMessage<Message>(/*field_id=*/2);
+ nest->AppendVarInt(/*field_id=*/3, 100);
+ nest->AppendString(/*field_id=*/4, "f2.payload");
+ nest->AppendString(/*field_id=*/99, "not_in_original_schema");
+ auto* nest2 = nest->BeginNestedMessage<Message>(/*field_id=*/5);
+ nest2->AppendString(/*field_id=*/6, "subconfig.f4");
+ nest2->Finalize();
+ nest->Finalize();
+
+ // Fill `cfg_filtered`.
+ nest = msg->BeginNestedMessage<Message>(/*field_id=*/3);
+ nest->AppendVarInt(/*field_id=*/3, 200); // This should be propagated.
+ nest->AppendVarInt(/*field_id=*/6, 300); // This should be filtered out.
+ nest->Finalize();
+
+ MessageFilter flt;
+ ASSERT_TRUE(flt.LoadFilterBytecode(bytecode.data(), bytecode.size()));
+
+ std::vector<uint8_t> encoded = msg.SerializeAsArray();
+
+ auto filtered = flt.FilterMessage(encoded.data(), encoded.size());
+ ASSERT_LT(filtered.size, encoded.size());
+
+ ProtoDecoder dec(filtered.data.get(), filtered.size);
+ EXPECT_EQ(dec.FindField(1).as_int64(), 10);
+ EXPECT_EQ(dec.FindField(4).as_std_string(), "other_string");
+
+ EXPECT_TRUE(dec.FindField(2).valid());
+ ProtoDecoder nest_dec(dec.FindField(2).as_bytes());
+ EXPECT_EQ(nest_dec.FindField(3).as_int32(), 100);
+ EXPECT_EQ(nest_dec.FindField(4).as_std_string(), "f2.payload");
+ EXPECT_TRUE(nest_dec.FindField(5).valid());
+ ProtoDecoder nest_dec2(nest_dec.FindField(5).as_bytes());
+ EXPECT_EQ(nest_dec2.FindField(6).as_std_string(), "subconfig.f4");
+
+ // Field 99 should be preserved anyway, even if it wasn't in the original
+ // schema, because the whole `cfg` submessage was passed through.
+ EXPECT_TRUE(nest_dec.FindField(99).valid());
+ EXPECT_EQ(nest_dec.FindField(99).as_std_string(), "not_in_original_schema");
+
+ // Check that the field `cfg_filtered` contains only `f1`,`f2`,`f3`.
+ EXPECT_TRUE(dec.FindField(3).valid());
+ ProtoDecoder nest_dec3(dec.FindField(3).as_bytes());
+ EXPECT_EQ(nest_dec3.FindField(3).as_int32(), 200);
+ EXPECT_FALSE(nest_dec3.FindField(6).valid());
+}
+
TEST(MessageFilterTest, ChangeRoot) {
auto schema = perfetto::base::TempFile::Create();
static const char kSchema[] = R"(
diff --git a/src/tools/proto_filter/proto_filter.cc b/src/tools/proto_filter/proto_filter.cc
index 04385c2..7bfcb74 100644
--- a/src/tools/proto_filter/proto_filter.cc
+++ b/src/tools/proto_filter/proto_filter.cc
@@ -49,7 +49,7 @@
# Generate the filter bytecode from a .proto schema
proto_filter -r perfetto.protos.Trace -s protos/perfetto/trace/trace.proto \
- -F /tmp/bytecode [--dedupe]
+ -F /tmp/bytecode [--dedupe] [-x protos.Message:field_to_pass]
# List the used/filtered fields from a trace file
@@ -72,14 +72,15 @@
{"help", no_argument, nullptr, 'h'},
{"version", no_argument, nullptr, 'v'},
{"dedupe", no_argument, nullptr, 'd'},
- {"proto_path", no_argument, nullptr, 'I'},
- {"schema_in", no_argument, nullptr, 's'},
- {"root_message", no_argument, nullptr, 'r'},
- {"msg_in", no_argument, nullptr, 'i'},
- {"msg_out", no_argument, nullptr, 'o'},
- {"filter_in", no_argument, nullptr, 'f'},
- {"filter_out", no_argument, nullptr, 'F'},
- {"filter_oct_out", no_argument, nullptr, 'T'},
+ {"proto_path", required_argument, nullptr, 'I'},
+ {"schema_in", required_argument, nullptr, 's'},
+ {"root_message", required_argument, nullptr, 'r'},
+ {"msg_in", required_argument, nullptr, 'i'},
+ {"msg_out", required_argument, nullptr, 'o'},
+ {"filter_in", required_argument, nullptr, 'f'},
+ {"filter_out", required_argument, nullptr, 'F'},
+ {"filter_oct_out", required_argument, nullptr, 'T'},
+ {"passthrough", required_argument, nullptr, 'x'},
{nullptr, 0, nullptr, 0}};
std::string msg_in;
@@ -90,11 +91,12 @@
std::string filter_oct_out;
std::string proto_path;
std::string root_message_arg;
+ std::set<std::string> passthrough_fields;
bool dedupe = false;
for (;;) {
int option =
- getopt_long(argc, argv, "hvdI:s:r:i:o:f:F:T:", long_options, nullptr);
+ getopt_long(argc, argv, "hvdI:s:r:i:o:f:F:T:x:", long_options, nullptr);
if (option == -1)
break; // EOF.
@@ -149,6 +151,11 @@
continue;
}
+ if (option == 'x') {
+ passthrough_fields.insert(optarg);
+ continue;
+ }
+
if (option == 'h') {
fprintf(stdout, kUsage);
exit(0);
@@ -175,8 +182,8 @@
protozero::FilterUtil filter;
if (!schema_in.empty()) {
PERFETTO_LOG("Loading proto schema from %s", schema_in.c_str());
- if (!filter.LoadMessageDefinition(schema_in, root_message_arg,
- proto_path)) {
+ if (!filter.LoadMessageDefinition(schema_in, root_message_arg, proto_path,
+ passthrough_fields)) {
PERFETTO_ELOG("Failed to parse proto schema from %s", schema_in.c_str());
return 1;
}
@@ -264,7 +271,7 @@
if (!msg_out.empty()) {
PERFETTO_LOG("Writing filtered proto bytes (%zu bytes) into %s",
msg_filtered_data.size(), msg_out.c_str());
- auto fd = base::OpenFile(msg_out, O_WRONLY | O_CREAT, 0644);
+ auto fd = base::OpenFile(msg_out, O_WRONLY | O_TRUNC | O_CREAT, 0644);
base::WriteAll(*fd, msg_filtered_data.data(), msg_filtered_data.size());
}