blob: d92eb9c168ef3f4d9ae6de875d8c1c0bc34e69c2 [file]
// Copyright (C) 2025 Rivos Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stdlib.h>
#include <limits>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <google/protobuf/compiler/code_generator.h>
#include <google/protobuf/compiler/plugin.h>
#include <google/protobuf/descriptor.h>
#include <google/protobuf/descriptor.pb.h>
#include <google/protobuf/io/printer.h>
#include <google/protobuf/io/zero_copy_stream.h>
#include "perfetto/ext/base/string_utils.h"
namespace protozero {
namespace {
using google::protobuf::Descriptor;
using google::protobuf::EnumDescriptor;
using google::protobuf::EnumValueDescriptor;
using google::protobuf::FieldDescriptor;
using google::protobuf::FileDescriptor;
using google::protobuf::compiler::GeneratorContext;
using google::protobuf::io::Printer;
using google::protobuf::io::ZeroCopyOutputStream;
using perfetto::base::ReplaceAll;
using perfetto::base::SplitString;
using perfetto::base::StripChars;
using perfetto::base::StripPrefix;
using perfetto::base::StripSuffix;
using perfetto::base::ToUpper;
using perfetto::base::Uppercase;
void Assert(bool condition) {
if (!condition)
abort();
}
// Maximum line length for single-line pb_enum! macro invocations.
// Enums that would produce output longer than this threshold will be
// formatted across multiple lines for readability.
constexpr size_t kMaxSingleLinePbEnumLength = 60;
struct FileDescriptorComp {
bool operator()(const FileDescriptor* lhs, const FileDescriptor* rhs) const {
int comp = lhs->name().compare(rhs->name());
Assert(comp != 0 || lhs == rhs);
return comp < 0;
}
};
struct DescriptorComp {
bool operator()(const Descriptor* lhs, const Descriptor* rhs) const {
int comp = lhs->full_name().compare(rhs->full_name());
Assert(comp != 0 || lhs == rhs);
return comp < 0;
}
};
struct EnumDescriptorComp {
bool operator()(const EnumDescriptor* lhs, const EnumDescriptor* rhs) const {
int comp = lhs->full_name().compare(rhs->full_name());
Assert(comp != 0 || lhs == rhs);
return comp < 0;
}
};
inline std::string ProtoStubName(const FileDescriptor* proto) {
return StripSuffix(std::string(proto->name()), ".proto");
}
std::string IntLiteralString(int number) {
// Special case for -2147483648. If int is 32-bit, the compiler will
// misinterpret it.
if (number == std::numeric_limits<int32_t>::min()) {
return "-2147483647 - 1";
}
return std::to_string(number);
}
class GeneratorJob {
public:
GeneratorJob(const FileDescriptor* file, Printer* stub_rs_printer)
: source_(file), stub_rs_(stub_rs_printer) {}
bool GenerateStubs() {
Preprocess();
GeneratePrologue();
for (const EnumDescriptor* enumeration : enums_)
GenerateEnumDescriptor(enumeration);
for (const Descriptor* message : messages_)
GenerateMessageDescriptor(message);
for (const auto& key_value : extensions_)
GenerateExtensionDescriptor(key_value.first, key_value.second);
return error_.empty();
}
void SetOption(const std::string& name, const std::string& value) {
if (name == "wrapper_namespace") {
wrapper_namespace_ = value;
} else if (name == "path_strip_prefix") {
path_strip_prefix_ = value;
} else if (name == "path_add_prefix") {
path_add_prefix_ = value;
} else if (name == "invoker") {
invoker_ = value;
} else if (name == "external_crate") {
external_crate_ = value;
} else if (name == "local_files") {
for (const auto& f : SplitString(value, "|")) {
local_files_.insert(std::string(f));
}
} else {
Abort(std::string() + "Unknown plugin option '" + name + "'.");
}
}
// If generator fails to produce stubs for a particular proto definitions
// it finishes with undefined output and writes the first error occurred.
const std::string& GetFirstError() const { return error_; }
private:
// Only the first error will be recorded.
void Abort(const std::string& reason) {
if (error_.empty())
error_ = reason;
}
// Get Rust struct name corresponding to proto descriptor (simple name only).
template <class T>
inline std::string GetRustStructName(const T* descriptor) {
return StripChars(std::string(descriptor->name()), ".", '_');
}
// Get full Rust struct name including parent type names for nested messages.
std::string GetFullRustMessageName(const Descriptor* descriptor) {
std::string name;
if (descriptor->containing_type()) {
name.append(GetFullRustMessageName(descriptor->containing_type()));
}
name.append(GetRustStructName(descriptor));
return name;
}
std::string FieldToRustTypeName(const FieldDescriptor* field) {
switch (field->type()) {
case FieldDescriptor::TYPE_BOOL:
return "bool";
case FieldDescriptor::TYPE_INT32:
return "i32";
case FieldDescriptor::TYPE_INT64:
return "i64";
case FieldDescriptor::TYPE_UINT32:
return "u32";
case FieldDescriptor::TYPE_UINT64:
return "u64";
case FieldDescriptor::TYPE_SINT32:
return "i32";
case FieldDescriptor::TYPE_SINT64:
return "i64";
case FieldDescriptor::TYPE_FIXED32:
return "u32";
case FieldDescriptor::TYPE_FIXED64:
return "u64";
case FieldDescriptor::TYPE_SFIXED32:
return "i32";
case FieldDescriptor::TYPE_SFIXED64:
return "i64";
case FieldDescriptor::TYPE_FLOAT:
return "f32";
case FieldDescriptor::TYPE_DOUBLE:
return "f64";
case FieldDescriptor::TYPE_ENUM: {
std::string name;
if (field->enum_type()->containing_type()) {
name.append(GetRustStructName(field->enum_type()->containing_type()));
}
name.append(GetRustStructName(field->enum_type()));
return name;
}
case FieldDescriptor::TYPE_STRING:
return "String";
case FieldDescriptor::TYPE_BYTES:
return "String";
case FieldDescriptor::TYPE_MESSAGE:
return GetFullRustMessageName(field->message_type());
case FieldDescriptor::TYPE_GROUP:
Abort("Groups not supported.");
return "";
}
Abort("Unrecognized FieldDescriptor::Type.");
return "";
}
void CollectDescriptors() {
// Collect message descriptors in DFS order.
std::vector<const Descriptor*> stack;
stack.reserve(static_cast<size_t>(source_->message_type_count()));
for (int i = 0; i < source_->message_type_count(); ++i)
stack.push_back(source_->message_type(i));
while (!stack.empty()) {
const Descriptor* message = stack.back();
stack.pop_back();
if (message->extension_count() > 0) {
if (message->field_count() > 0 || message->nested_type_count() > 0 ||
message->enum_type_count() > 0) {
Abort("message with extend blocks shouldn't contain anything else");
}
// Iterate over all fields in "extend" blocks.
for (int i = 0; i < message->extension_count(); ++i) {
const FieldDescriptor* extension = message->extension(i);
// Protoc plugin API does not group fields in "extend" blocks.
// As the support for extensions in protozero is limited, the code
// assumes that extend blocks are located inside a wrapper message and
// name of this message is used to group them.
std::string extension_name =
GetRustStructName(extension->extension_scope());
extensions_[extension_name].push_back(extension);
}
} else {
messages_.push_back(message);
for (int i = 0; i < message->nested_type_count(); ++i) {
stack.push_back(message->nested_type(i));
// Emit a use statemewnt for nested message types, as the outer
// struct will refer to them.
referenced_messages_.insert(message->nested_type(i));
}
}
}
// Collect enums.
for (int i = 0; i < source_->enum_type_count(); ++i)
enums_.push_back(source_->enum_type(i));
for (const Descriptor* message : messages_) {
for (int i = 0; i < message->enum_type_count(); ++i) {
enums_.push_back(message->enum_type(i));
}
}
}
void CollectDependencies() {
// Public import basically means that callers only need to import this
// proto in order to use the stuff publicly imported by this proto.
for (int i = 0; i < source_->public_dependency_count(); ++i)
public_imports_.insert(source_->public_dependency(i));
if (source_->weak_dependency_count() > 0)
Abort("Weak imports are not supported.");
// Validations. Collect public imports (of collected imports) in DFS order.
// Visibilty for current proto:
// - all imports listed in current proto,
// - public imports of everything imported (recursive).
std::vector<const FileDescriptor*> stack;
for (int i = 0; i < source_->dependency_count(); ++i) {
const FileDescriptor* imp = source_->dependency(i);
stack.push_back(imp);
if (public_imports_.count(imp) == 0) {
private_imports_.insert(imp);
}
}
while (!stack.empty()) {
const FileDescriptor* imp = stack.back();
stack.pop_back();
for (int i = 0; i < imp->public_dependency_count(); ++i) {
stack.push_back(imp->public_dependency(i));
}
}
// Collect descriptors of messages and enums used in current proto.
// It will be used to generate the necessary "use" statements.
for (const Descriptor* message : messages_) {
for (int i = 0; i < message->field_count(); ++i) {
const FieldDescriptor* field = message->field(i);
if (field->type() == FieldDescriptor::TYPE_MESSAGE) {
if (public_imports_.count(field->message_type()->file()) == 0) {
referenced_messages_.insert(field->message_type());
}
} else if (field->type() == FieldDescriptor::TYPE_ENUM) {
if (public_imports_.count(field->enum_type()->file()) == 0) {
referenced_enums_.insert(field->enum_type());
}
}
}
}
// Collect dependencies for extension fields (base message and field types).
for (const auto& key_value : extensions_) {
for (const FieldDescriptor* field : key_value.second) {
// The extended message type (e.g. TrackEvent).
const Descriptor* extendee = field->containing_type();
if (public_imports_.count(extendee->file()) == 0) {
referenced_messages_.insert(extendee);
}
// Message/enum types used in extension fields.
if (field->type() == FieldDescriptor::TYPE_MESSAGE) {
if (public_imports_.count(field->message_type()->file()) == 0) {
referenced_messages_.insert(field->message_type());
}
} else if (field->type() == FieldDescriptor::TYPE_ENUM) {
if (public_imports_.count(field->enum_type()->file()) == 0) {
referenced_enums_.insert(field->enum_type());
}
}
}
}
}
void Preprocess() {
// Package name maps to a series of namespaces.
package_ = source_->package();
namespaces_ = SplitString(package_, ".");
if (!wrapper_namespace_.empty())
namespaces_.push_back(wrapper_namespace_);
full_namespace_prefix_ = "";
for (size_t i = 0; i < namespaces_.size(); i++) {
full_namespace_prefix_ += namespaces_[i];
if (i + 1 != namespaces_.size()) {
full_namespace_prefix_ += "_";
}
}
CollectDescriptors();
CollectDependencies();
}
// Print top header, namespaces and forward declarations.
void GeneratePrologue() {
stub_rs_->Print(
R"(// Copyright (C) 2025 Rivos Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
)");
stub_rs_->Print(
"// Autogenerated by the ProtoZero Rust compiler plugin.\n");
if (!invoker_.empty()) {
stub_rs_->Print("// Invoked by $invoker$\n", "invoker", invoker_);
}
stub_rs_->Print("// DO NOT EDIT.\n\n");
if (!enums_.empty()) {
stub_rs_->Print("use crate::pb_enum;\n");
}
if (!messages_.empty() || !extensions_.empty()) {
stub_rs_->Print("use crate::pb_msg;\n");
}
if (!extensions_.empty()) {
stub_rs_->Print("use crate::pb_msg_ext;\n");
}
// Print use statements for public imports, enums and messages.
std::vector<std::string> imports;
for (const FileDescriptor* dependency : public_imports_) {
imports.push_back(ProtoStubName(dependency));
}
for (const EnumDescriptor* e : referenced_enums_) {
if (e->file() != source_) {
imports.push_back(ProtoStubName(e->file()));
}
}
for (const Descriptor* m : referenced_messages_) {
if (m->file() != source_) {
imports.push_back(ProtoStubName(m->file()));
}
}
std::sort(imports.begin(), imports.end());
// Remove consecutive duplicates.
auto last = std::unique(imports.begin(), imports.end());
imports.erase(last, imports.end());
for (const std::string& imp : imports) {
std::string mod_path = imp;
if (!path_strip_prefix_.empty()) {
mod_path = StripPrefix(imp, path_strip_prefix_);
}
// When external_crate is set and this import is not a local file,
// use the external crate path instead of crate::.
bool is_external =
!external_crate_.empty() && local_files_.count(imp + ".proto") == 0;
std::string crate_prefix = is_external ? external_crate_ : "crate";
stub_rs_->Print("use $crate$::protos$mod$::*;\n", "crate", crate_prefix,
"mod", ReplaceAll(mod_path, "/", "::"));
}
}
void GenerateEnumDescriptor(const EnumDescriptor* enumeration) {
std::string name;
if (enumeration->containing_type()) {
name.append(GetRustStructName(enumeration->containing_type()));
}
name.append(GetRustStructName(enumeration));
// Build enum values content and calculate single-line length
std::string values_content;
for (int i = 0; i < enumeration->value_count(); ++i) {
const EnumValueDescriptor* value = enumeration->value(i);
if (i > 0) {
values_content += ", ";
}
values_content += value->name();
values_content += ": ";
values_content += IntLiteralString(value->number());
}
// Single-line format: "pb_enum!(" + name + " { " + values + " });"
// For empty: "pb_enum!(" + name + " {});" = 14 + name.length()
// For non-empty: "pb_enum!(" + name + " { " + values + " });" = 16 +
// name.length() + values.length()
size_t single_line_length =
enumeration->value_count() == 0
? 14 + name.length()
: 16 + name.length() + values_content.length();
if (single_line_length <= kMaxSingleLinePbEnumLength) {
if (enumeration->value_count() == 0) {
stub_rs_->Print("\npb_enum!($name$ {});\n", "name", name);
} else {
stub_rs_->Print("\npb_enum!($name$ { $values$ });\n", "name", name,
"values", values_content);
}
} else {
stub_rs_->Print("\npb_enum!($name$ {\n", "name", name);
for (int i = 0; i < enumeration->value_count(); ++i) {
const EnumValueDescriptor* value = enumeration->value(i);
const std::string value_name = std::string(value->name());
stub_rs_->Print(" ");
stub_rs_->Print("$val$: $number$,\n", "val", value_name, "number",
IntLiteralString(value->number()));
}
stub_rs_->Print("});\n");
}
}
// Generate field descriptor content as a string (without
// indentation/newline).
std::string GetFieldDescriptorContent(const FieldDescriptor* field) {
std::string name = std::string(field->lowercase_name());
std::string id = std::to_string(field->number());
if (field->type() == FieldDescriptor::TYPE_MESSAGE) {
std::string inner_struct = GetFullRustMessageName(field->message_type());
return name + ": " + inner_struct + ", msg, " + id + ",";
}
std::string type = FieldToRustTypeName(field);
std::string kind;
switch (field->type()) {
case FieldDescriptor::TYPE_ENUM:
kind = "enum";
break;
case FieldDescriptor::TYPE_GROUP:
Abort("Groups not supported.");
return "";
default:
kind = "primitive";
break;
}
return name + ": " + type + ", " + kind + ", " + id + ",";
}
void GenerateMessageDescriptor(const Descriptor* message) {
std::string name = GetFullRustMessageName(message);
if (message->field_count() == 0) {
stub_rs_->Print("\npb_msg!($name$ {});\n", "name", name);
} else {
stub_rs_->Print("\npb_msg!($name$ {\n", "name", name);
for (int i = 0; i < message->field_count(); ++i) {
stub_rs_->Print(" $field$\n", "field",
GetFieldDescriptorContent(message->field(i)));
}
stub_rs_->Print("});\n");
}
}
void GenerateExtensionDescriptor(
const std::string& /*extension_name*/,
const std::vector<const FieldDescriptor*>& fields) {
// The extended message name (e.g. "TrackEvent").
std::string base_name =
GetFullRustMessageName(fields[0]->containing_type());
stub_rs_->Print("\npb_msg_ext!($base$ {\n", "base", base_name);
for (const FieldDescriptor* field : fields) {
stub_rs_->Print(" $field$\n", "field",
GetFieldDescriptorContent(field));
}
stub_rs_->Print("});\n");
}
const FileDescriptor* const source_;
Printer* const stub_rs_;
std::string error_;
std::string package_;
std::string wrapper_namespace_;
std::string path_strip_prefix_;
std::string path_add_prefix_;
std::string external_crate_;
std::set<std::string> local_files_;
std::string invoker_;
std::vector<std::string> namespaces_;
std::string full_namespace_prefix_;
std::vector<const Descriptor*> messages_;
std::vector<const EnumDescriptor*> enums_;
std::map<std::string, std::vector<const FieldDescriptor*>> extensions_;
// The custom *Comp comparators are to ensure determinism of the generator.
std::set<const FileDescriptor*, FileDescriptorComp> public_imports_;
std::set<const FileDescriptor*, FileDescriptorComp> private_imports_;
std::set<const Descriptor*, DescriptorComp> referenced_messages_;
std::set<const EnumDescriptor*, EnumDescriptorComp> referenced_enums_;
};
class ProtoZeroRustGenerator
: public google::protobuf::compiler::CodeGenerator {
public:
explicit ProtoZeroRustGenerator();
~ProtoZeroRustGenerator() override;
// CodeGenerator implementation
bool Generate(const google::protobuf::FileDescriptor* file,
const std::string& options,
GeneratorContext* context,
std::string* error) const override;
};
ProtoZeroRustGenerator::ProtoZeroRustGenerator() {}
ProtoZeroRustGenerator::~ProtoZeroRustGenerator() {}
bool ProtoZeroRustGenerator::Generate(const FileDescriptor* file,
const std::string& options,
GeneratorContext* context,
std::string* error) const {
const std::unique_ptr<ZeroCopyOutputStream> stub_rs_file_stream(
context->Open(ProtoStubName(file) + ".pz.rs"));
// Variables are delimited by $.
Printer stub_rs_printer(stub_rs_file_stream.get(), '$');
GeneratorJob job(file, &stub_rs_printer);
// Parse additional options.
for (const std::string& option : SplitString(options, ",")) {
std::vector<std::string> option_pair = SplitString(option, "=");
job.SetOption(option_pair[0], option_pair[1]);
}
if (!job.GenerateStubs()) {
*error = job.GetFirstError();
return false;
}
return true;
}
} // namespace
} // namespace protozero
int main(int argc, char* argv[]) {
protozero::ProtoZeroRustGenerator generator;
return google::protobuf::compiler::PluginMain(argc, argv, &generator);
}