Merge "metrics: add support for extension fields in descriptors"
diff --git a/protos/perfetto/common/descriptor.proto b/protos/perfetto/common/descriptor.proto
index be0fff5..d0d86ab 100644
--- a/protos/perfetto/common/descriptor.proto
+++ b/protos/perfetto/common/descriptor.proto
@@ -43,9 +43,9 @@
   // All top-level definitions in this file.
   repeated DescriptorProto message_type = 4;
   repeated EnumDescriptorProto enum_type = 5;
+  repeated FieldDescriptorProto extension = 7;
 
   reserved 6;
-  reserved 7;
   reserved 8;
   reserved 9;
   reserved 12;
@@ -137,7 +137,9 @@
   // namespace).
   optional string type_name = 6;
 
-  reserved 2;
+  // For extensions, this is the name of the type being extended.  It is
+  // resolved in the same manner as type_name.
+  optional string extendee = 2;
 
   // For numeric types, contains the original text representation of the value.
   // For booleans, "true" or "false".
diff --git a/src/trace_processor/metrics/descriptors.cc b/src/trace_processor/metrics/descriptors.cc
index c487101..35104e1 100644
--- a/src/trace_processor/metrics/descriptors.cc
+++ b/src/trace_processor/metrics/descriptors.cc
@@ -22,6 +22,103 @@
 namespace trace_processor {
 namespace metrics {
 
+namespace {
+
+// Builds a FieldDescriptor from a serialized FieldDescriptorProto. A field
+// without an explicit type is treated as a message; its type_name is stored
+// verbatim and resolved to a fully-qualified name later.
+FieldDescriptor CreateFieldFromDecoder(
+    const protos::pbzero::FieldDescriptorProto::Decoder& f_decoder) {
+  using FieldDescriptorProto = protos::pbzero::FieldDescriptorProto;
+  std::string type_name =
+      f_decoder.has_type_name()
+          ? base::StringView(f_decoder.type_name()).ToStdString()
+          : "";
+  // TODO(lalitm): add support for enums here.
+  uint32_t type = f_decoder.has_type() ? static_cast<uint32_t>(f_decoder.type())
+                                       : FieldDescriptorProto::TYPE_MESSAGE;
+  return FieldDescriptor(
+      base::StringView(f_decoder.name()).ToStdString(),
+      static_cast<uint32_t>(f_decoder.number()), type, std::move(type_name),
+      f_decoder.label() == FieldDescriptorProto::LABEL_REPEATED);
+}
+
+}  // namespace
+
+base::Optional<uint32_t> DescriptorPool::ResolveShortType(
+    const std::string& parent_path,
+    const std::string& short_type) {
+  // User-supplied descriptors (via ExtendMetricsProto) may contain a message
+  // field without a type_name; treat that as unresolvable rather than
+  // tripping a DCHECK so the caller can report an error.
+  if (short_type.empty())
+    return base::nullopt;
+
+  // A leading '.' marks a fully-qualified name (the form protoc writes into
+  // descriptor sets). Look it up as-is: prefixing it with a scope could make
+  // an unrelated nested type with the same suffix shadow the intended one.
+  if (short_type[0] == '.')
+    return FindDescriptorIdx(short_type);
+
+  // Relative name: try the innermost scope first, then walk outwards one
+  // component at a time, finishing with the root ("") scope.
+  auto opt_idx = FindDescriptorIdx(parent_path + '.' + short_type);
+  if (opt_idx)
+    return opt_idx;
+
+  if (parent_path.empty())
+    return base::nullopt;
+
+  auto parent_dot_idx = parent_path.rfind('.');
+  auto parent_substr = parent_dot_idx == std::string::npos
+                           ? ""
+                           : parent_path.substr(0, parent_dot_idx);
+  return ResolveShortType(parent_substr, short_type);
+}
+
+// Decodes one file-level extension and appends it as a field of the message
+// it extends. |package_name| is the declaring file's package prefixed with
+// '.', as built by AddFromFileDescriptorSet.
+util::Status DescriptorPool::AddExtensionField(const std::string& package_name,
+                                               const uint8_t* field_desc_proto,
+                                               size_t size) {
+  using FieldDescriptorProto = protos::pbzero::FieldDescriptorProto;
+  FieldDescriptorProto::Decoder f_decoder(field_desc_proto, size);
+  auto field = CreateFieldFromDecoder(f_decoder);
+
+  // Resolve the extendee with the same scoping rules as field types so that
+  // both fully-qualified (".pkg.Msg") and package-relative names work, and
+  // files without a package (package_name == ".") are handled too.
+  std::string extendee_name =
+      base::StringView(f_decoder.extendee()).ToStdString();
+  auto extendee = ResolveShortType(package_name, extendee_name);
+  if (!extendee.has_value()) {
+    return util::ErrStatus("Extendee does not exist %s", extendee_name.c_str());
+  }
+
+  // A relative type name on an extension is scoped to the file declaring the
+  // extension, but the final resolution pass resolves every field relative to
+  // the message that owns it (here: the extendee). Qualify the name now while
+  // the declaring scope is still known.
+  const std::string& raw_type = field.raw_type_name();
+  bool has_relative_type =
+      (field.type() == FieldDescriptorProto::TYPE_MESSAGE ||
+       field.type() == FieldDescriptorProto::TYPE_ENUM) &&
+      !raw_type.empty() && raw_type[0] != '.';
+  if (has_relative_type) {
+    auto type_idx = ResolveShortType(package_name, raw_type);
+    if (!type_idx.has_value()) {
+      return util::ErrStatus("Unable to find type %s of extension field %s",
+                             raw_type.c_str(), field.name().c_str());
+    }
+    field = FieldDescriptor(field.name(), field.number(), field.type(),
+                            descriptors_[type_idx.value()].full_name(),
+                            field.is_repeated());
+  }
+  descriptors_[extendee.value()].AddField(std::move(field));
+  return util::OkStatus();
+}
+
 void DescriptorPool::AddNestedProtoDescriptors(
     const std::string& package_name,
     base::Optional<uint32_t> parent_idx,
@@ -38,16 +97,7 @@
   ProtoDescriptor proto_descriptor(package_name, full_name, parent_idx);
   for (auto it = decoder.field(); it; ++it) {
     FieldDescriptorProto::Decoder f_decoder(it->data(), it->size());
-    std::string type_name =
-        f_decoder.has_type_name()
-            ? base::StringView(f_decoder.type_name()).ToStdString()
-            : "";
-    FieldDescriptor field(
-        base::StringView(f_decoder.name()).ToStdString(),
-        static_cast<uint32_t>(f_decoder.number()),
-        static_cast<uint32_t>(f_decoder.type()), std::move(type_name),
-        f_decoder.label() == FieldDescriptorProto::LABEL_REPEATED);
-    proto_descriptor.AddField(std::move(field));
+    proto_descriptor.AddField(CreateFieldFromDecoder(f_decoder));
   }
   descriptors_.emplace_back(std::move(proto_descriptor));
 
@@ -57,7 +107,7 @@
   }
 }
 
-void DescriptorPool::AddFromFileDescriptorSet(
+util::Status DescriptorPool::AddFromFileDescriptorSet(
     const uint8_t* file_descriptor_set_proto,
     size_t size) {
   // First pass: extract all the message descriptors from the file and add them
@@ -73,17 +123,43 @@
     }
   }
 
-  // Second pass: resolve the types of all the fields to the correct indiices.
+  // Second pass: extract all the extension fields and add them to the
+  // messages they extend.
+  for (auto it = proto.file(); it; ++it) {
+    protos::pbzero::FileDescriptorProto::Decoder file(it->data(), it->size());
+
+    std::string package = "." + base::StringView(file.package()).ToStdString();
+    for (auto ext_it = file.extension(); ext_it; ++ext_it) {
+      auto status = AddExtensionField(package, ext_it->data(), ext_it->size());
+      if (!status.ok())
+        return status;
+    }
+
+    // TODO(lalitm): we don't currently support nested extensions as they are
+    // relatively niche and probably shouldn't be used in metrics because they
+    // are confusing. Add the code for it here if we find a use for them in
+    // the future.
+  }
+
+  // Third pass: resolve the type of every message/enum field to its full name.
   using FieldDescriptorProto = protos::pbzero::FieldDescriptorProto;
   for (auto& descriptor : descriptors_) {
     for (auto& field : *descriptor.mutable_fields()) {
       if (field.type() == FieldDescriptorProto::TYPE_MESSAGE ||
           field.type() == FieldDescriptorProto::TYPE_ENUM) {
-        field.set_message_type_idx(
-            FindDescriptorIdx(field.raw_type_name()).value());
+        auto opt_desc =
+            ResolveShortType(descriptor.full_name(), field.raw_type_name());
+        if (!opt_desc.has_value()) {
+          return util::ErrStatus(
+              "Unable to find short type %s in field inside message %s",
+              field.raw_type_name().c_str(), descriptor.full_name().c_str());
+        }
+        field.set_resolved_type_name(
+            descriptors_[opt_desc.value()].full_name());
       }
     }
   }
+  return util::OkStatus();
 }
 
 base::Optional<uint32_t> DescriptorPool::FindDescriptorIdx(
diff --git a/src/trace_processor/metrics/descriptors.h b/src/trace_processor/metrics/descriptors.h
index d6d167a..37ef287 100644
--- a/src/trace_processor/metrics/descriptors.h
+++ b/src/trace_processor/metrics/descriptors.h
@@ -22,6 +22,7 @@
 #include <vector>
 
 #include "perfetto/base/optional.h"
+#include "perfetto/trace_processor/basic_types.h"
 
 namespace perfetto {
 namespace trace_processor {
@@ -39,18 +40,20 @@
   uint32_t number() const { return number_; }
   uint32_t type() const { return type_; }
   const std::string& raw_type_name() const { return raw_type_name_; }
+  const std::string& resolved_type_name() const { return resolved_type_name_; }
   bool is_repeated() const { return is_repeated_; }
 
-  void set_message_type_idx(uint32_t idx) { message_type_idx_ = idx; }
+  void set_resolved_type_name(const std::string& resolved_type_name) {
+    resolved_type_name_ = resolved_type_name;
+  }
 
  private:
   std::string name_;
   uint32_t number_;
   uint32_t type_;
   std::string raw_type_name_;
+  std::string resolved_type_name_;
   bool is_repeated_;
-
-  base::Optional<uint32_t> message_type_idx_;
 };
 
 class ProtoDescriptor {
@@ -87,8 +90,9 @@
 
 class DescriptorPool {
  public:
-  void AddFromFileDescriptorSet(const uint8_t* file_descriptor_set_proto,
-                                size_t size);
+  util::Status AddFromFileDescriptorSet(
+      const uint8_t* file_descriptor_set_proto,
+      size_t size);
 
   base::Optional<uint32_t> FindDescriptorIdx(
       const std::string& full_name) const;
@@ -103,6 +107,15 @@
                                  const uint8_t* descriptor_proto,
                                  size_t size);
 
+  util::Status AddExtensionField(const std::string& package_name,
+                                 const uint8_t* field_desc_proto,
+                                 size_t size);
+
+  // Recursively searches for |short_type| in |parent_path| and then in each
+  // enclosing message/package scope, returning its descriptor index.
+  base::Optional<uint32_t> ResolveShortType(const std::string& parent_path,
+                                            const std::string& short_type);
+
   std::vector<ProtoDescriptor> descriptors_;
 };
 
diff --git a/src/trace_processor/metrics/metrics.cc b/src/trace_processor/metrics/metrics.cc
index 427a050..c3f7433 100644
--- a/src/trace_processor/metrics/metrics.cc
+++ b/src/trace_processor/metrics/metrics.cc
@@ -256,10 +256,12 @@
   }
 
   auto actual_type_name = single.type_name().ToStdString();
-  if (actual_type_name != field.raw_type_name()) {
+  if (actual_type_name != field.resolved_type_name()) {
+    // |field| describes the expected type, |single| the actual one.
     return util::ErrStatus("Field %s has wrong type (expected %s, was %s)",
-                           field.name().c_str(), actual_type_name.c_str(),
-                           field.raw_type_name().c_str());
+                           field.name().c_str(),
+                           field.resolved_type_name().c_str(),
+                           actual_type_name.c_str());
   }
 
   if (!single.has_protobuf()) {
diff --git a/src/trace_processor/trace_processor_impl.cc b/src/trace_processor/trace_processor_impl.cc
index f4cacee..844f464 100644
--- a/src/trace_processor/trace_processor_impl.cc
+++ b/src/trace_processor/trace_processor_impl.cc
@@ -456,8 +456,7 @@
 
 util::Status TraceProcessorImpl::ExtendMetricsProto(const uint8_t* data,
                                                     size_t size) {
-  pool_.AddFromFileDescriptorSet(data, size);
-  return util::OkStatus();
+  return pool_.AddFromFileDescriptorSet(data, size);
 }
 
 util::Status TraceProcessorImpl::ComputeMetric(