Implement rust repeated scalars for cpp and upb
PiperOrigin-RevId: 574261929
diff --git a/Cargo.bazel.lock b/Cargo.bazel.lock
index 8067490..3839525 100644
--- a/Cargo.bazel.lock
+++ b/Cargo.bazel.lock
@@ -1,5 +1,5 @@
{
- "checksum": "8bc2d235f612e77f4dca1b6886cc8bd14df348168fea27a687805ed9518a8f1a",
+ "checksum": "641f887b045ff0fc19f64df79b53d96d77d1c03c96069036d84bd1104ddc0000",
"crates": {
"aho-corasick 1.1.2": {
"name": "aho-corasick",
@@ -108,6 +108,15 @@
"selects": {}
},
"edition": "2018",
+ "proc_macro_deps": {
+ "common": [
+ {
+ "id": "paste 1.0.14",
+ "target": "paste"
+ }
+ ],
+ "selects": {}
+ },
"version": "0.0.1"
},
"license": null
@@ -318,6 +327,59 @@
},
"license": "MIT OR Apache-2.0"
},
+ "paste 1.0.14": {
+ "name": "paste",
+ "version": "1.0.14",
+ "repository": {
+ "Http": {
+ "url": "https://crates.io/api/v1/crates/paste/1.0.14/download",
+ "sha256": "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c"
+ }
+ },
+ "targets": [
+ {
+ "ProcMacro": {
+ "crate_name": "paste",
+ "crate_root": "src/lib.rs",
+ "srcs": [
+ "**/*.rs"
+ ]
+ }
+ },
+ {
+ "BuildScript": {
+ "crate_name": "build_script_build",
+ "crate_root": "build.rs",
+ "srcs": [
+ "**/*.rs"
+ ]
+ }
+ }
+ ],
+ "library_target_name": "paste",
+ "common_attrs": {
+ "compile_data_glob": [
+ "**"
+ ],
+ "deps": {
+ "common": [
+ {
+ "id": "paste 1.0.14",
+ "target": "build_script_build"
+ }
+ ],
+ "selects": {}
+ },
+ "edition": "2018",
+ "version": "1.0.14"
+ },
+ "build_script_attrs": {
+ "data_glob": [
+ "**"
+ ]
+ },
+ "license": "MIT OR Apache-2.0"
+ },
"proc-macro2 1.0.69": {
"name": "proc-macro2",
"version": "1.0.69",
@@ -769,4 +831,3 @@
},
"conditions": {}
}
-
diff --git a/Cargo.lock b/Cargo.lock
index e075e97..ea70571 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -22,6 +22,7 @@
version = "0.0.1"
dependencies = [
"googletest",
+ "paste",
]
[[package]]
@@ -62,6 +63,12 @@
]
[[package]]
+name = "paste"
+version = "1.0.14"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c"
+
+[[package]]
name = "proc-macro2"
version = "1.0.69"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -130,4 +137,3 @@
version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b"
-
diff --git a/WORKSPACE b/WORKSPACE
index 2c665be..d0d66ec 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -193,6 +193,9 @@
"googletest": crate.spec(
version = ">0.0.0",
),
+ "paste": crate.spec(
+ version = ">=1",
+ ),
},
)
diff --git a/rust/BUILD b/rust/BUILD
index dd23b44..3e0ca71 100644
--- a/rust/BUILD
+++ b/rust/BUILD
@@ -52,6 +52,7 @@
"optional.rs",
"primitive.rs",
"proxied.rs",
+ "repeated.rs",
"shared.rs",
"string.rs",
"vtable.rs",
@@ -92,8 +93,14 @@
name = "protobuf_cpp",
srcs = PROTOBUF_SHARED + ["cpp.rs"],
crate_root = "shared.rs",
+ proc_macro_deps = [
+ "@crate_index//:paste",
+ ],
rustc_flags = ["--cfg=cpp_kernel"],
- deps = [":utf8"],
+ deps = [
+ ":utf8",
+ "//rust/cpp_kernel:cpp_api",
+ ],
)
rust_test(
diff --git a/rust/cpp.rs b/rust/cpp.rs
index c2af02d..6ef1240 100644
--- a/rust/cpp.rs
+++ b/rust/cpp.rs
@@ -7,7 +7,8 @@
// Rust Protobuf runtime using the C++ kernel.
-use crate::__internal::{Private, RawArena, RawMessage};
+use crate::__internal::{Private, RawArena, RawMessage, RawRepeatedField};
+use paste::paste;
use std::alloc::Layout;
use std::cell::UnsafeCell;
use std::fmt;
@@ -35,6 +36,7 @@
impl Arena {
/// Allocates a fresh arena.
#[inline]
+ #[allow(clippy::new_without_default)]
pub fn new() -> Self {
Self { ptr: NonNull::dangling(), _not_sync: PhantomData }
}
@@ -182,6 +184,116 @@
val
}
+/// RepeatedField impls delegate out to `extern "C"` functions defined in
+/// `cpp_kernel/cpp_api.cc` and store either a RepeatedField* or a
+/// RepeatedPtrField* depending on the type.
+///
+/// Note: even though this type is `Copy`, it should only be copied by
+/// protobuf internals that can maintain mutation invariants:
+///
+/// - No concurrent mutation for any two fields in a message: this means
+/// mutators cannot be `Send` but are `Sync`.
+/// - If there are multiple accessible `Mut` to a single message at a time, they
+/// must be different fields, and not be in the same oneof. As such, a `Mut`
+/// cannot be `Clone` but *can* reborrow itself with `.as_mut()`, which
+/// converts `&'b mut Mut<'a, T>` to `Mut<'b, T>`.
+#[derive(Clone, Copy)]
+pub struct RepeatedField<'msg, T: ?Sized> {
+ inner: RepeatedFieldInner<'msg>, // raw container pointer tagged with the `'msg` lifetime
+ _phantom: PhantomData<&'msg mut T>, // records the element type `T` without storing one
+}
+
+/// CPP runtime-specific arguments for initializing a RepeatedField.
+/// See RepeatedField comment about mutation invariants for when this type can
+/// be copied.
+#[derive(Clone, Copy)]
+pub struct RepeatedFieldInner<'msg> {
+ pub raw: RawRepeatedField, // non-null pointer handed to the `__pb_rust_RepeatedField_*` FFI functions
+ pub _phantom: PhantomData<&'msg ()>, // carries the `'msg` lifetime; holds no data
+}
+
+impl<'msg, T: ?Sized> RepeatedField<'msg, T> {
+    pub fn from_inner(_private: Private, inner: RepeatedFieldInner<'msg>) -> Self {
+        Self { inner, _phantom: PhantomData } // `_private` is a marker argument only; it is never read
+    }
+}
+impl<'msg> RepeatedField<'msg, i32> {} // NOTE(review): empty impl block with no items — looks like a leftover placeholder
+
+pub trait RepeatedScalarOps {
+ fn new_repeated_field() -> RawRepeatedField; // creates a fresh container for `Self` elements
+ fn push(f: RawRepeatedField, v: Self); // appends `v` to `f`
+ fn len(f: RawRepeatedField) -> usize; // number of elements in `f`
+ fn get(f: RawRepeatedField, i: usize) -> Self; // reads element `i`; callers are expected to bounds-check first
+ fn set(f: RawRepeatedField, i: usize, v: Self); // overwrites element `i`; callers are expected to bounds-check first
+}
+
+macro_rules! impl_repeated_scalar_ops {
+ ($($t: ty),*) => {
+ paste! { $( // `paste!` glues the pieces inside each `[< ... >]` into one identifier per scalar type
+ extern "C" { // these symbol names must match the ones generated in cpp_kernel/cpp_api.cc
+ fn [< __pb_rust_RepeatedField_ $t _new >]() -> RawRepeatedField;
+ fn [< __pb_rust_RepeatedField_ $t _add >](f: RawRepeatedField, v: $t);
+ fn [< __pb_rust_RepeatedField_ $t _size >](f: RawRepeatedField) -> usize;
+ fn [< __pb_rust_RepeatedField_ $t _get >](f: RawRepeatedField, i: usize) -> $t;
+ fn [< __pb_rust_RepeatedField_ $t _set >](f: RawRepeatedField, i: usize, v: $t);
+ }
+ impl RepeatedScalarOps for $t { // NOTE(review): safe fns over a raw pointer; nothing here checks that `f` is a live RepeatedField<$t>*
+ fn new_repeated_field() -> RawRepeatedField {
+ unsafe { [< __pb_rust_RepeatedField_ $t _new >]() }
+ }
+ fn push(f: RawRepeatedField, v: Self) {
+ unsafe { [< __pb_rust_RepeatedField_ $t _add >](f, v) }
+ }
+ fn len(f: RawRepeatedField) -> usize {
+ unsafe { [< __pb_rust_RepeatedField_ $t _size >](f) }
+ }
+ fn get(f: RawRepeatedField, i: usize) -> Self {
+ unsafe { [< __pb_rust_RepeatedField_ $t _get >](f, i) }
+ }
+ fn set(f: RawRepeatedField, i: usize, v: Self) {
+ unsafe { [< __pb_rust_RepeatedField_ $t _set >](f, i, v) }
+ }
+ }
+ )* }
+ };
+}
+
+impl_repeated_scalar_ops!(i32, u32, i64, u64, f32, f64, bool);
+
+impl<'msg, T: RepeatedScalarOps> RepeatedField<'msg, T> {
+ #[allow(clippy::new_without_default, dead_code)]
+ /// new() is not currently used in our normal pathways, it is only used
+ /// for testing. Existing `RepeatedField<>`s are owned by, and retrieved
+ /// from, the containing `Message`.
+ pub fn new() -> Self {
+ Self::from_inner(
+ Private,
+ RepeatedFieldInner::<'msg> { raw: T::new_repeated_field(), _phantom: PhantomData },
+ )
+ }
+ pub fn push(&mut self, val: T) { // appends `val` at the end
+ T::push(self.inner.raw, val)
+ }
+ pub fn len(&self) -> usize { // current element count
+ T::len(self.inner.raw)
+ }
+ pub fn is_empty(&self) -> bool {
+ self.len() == 0
+ }
+ pub fn get(&self, index: usize) -> Option<T> { // `None` when `index` is out of range
+ if index >= self.len() {
+ return None;
+ }
+ Some(T::get(self.inner.raw, index))
+ }
+ pub fn set(&mut self, index: usize, val: T) { // out-of-range writes are silently ignored
+ if index >= self.len() {
+ return;
+ }
+ T::set(self.inner.raw, index, val)
+ }
+}
+
#[cfg(test)]
mod tests {
use super::*;
@@ -201,4 +313,27 @@
let serialized_data = SerializedData { data: NonNull::new(ptr).unwrap(), len: len };
assert_eq!(&*serialized_data, b"Hello world");
}
+
+ #[test]
+ fn repeated_field() { // covers new/len/push/get for i32, u32, f64 and bool; `set` is not exercised here
+ let mut r = RepeatedField::<i32>::new();
+ assert_eq!(r.len(), 0);
+ r.push(32);
+ assert_eq!(r.get(0), Some(32));
+
+ let mut r = RepeatedField::<u32>::new();
+ assert_eq!(r.len(), 0);
+ r.push(32);
+ assert_eq!(r.get(0), Some(32));
+
+ let mut r = RepeatedField::<f64>::new();
+ assert_eq!(r.len(), 0);
+ r.push(0.1234f64);
+ assert_eq!(r.get(0), Some(0.1234));
+
+ let mut r = RepeatedField::<bool>::new();
+ assert_eq!(r.len(), 0);
+ r.push(true);
+ assert_eq!(r.get(0), Some(true));
+ }
}
diff --git a/rust/cpp_kernel/BUILD b/rust/cpp_kernel/BUILD
index 245772c..d10f9e7 100644
--- a/rust/cpp_kernel/BUILD
+++ b/rust/cpp_kernel/BUILD
@@ -4,13 +4,14 @@
cc_library(
name = "cpp_api",
+ srcs = ["cpp_api.cc"],
hdrs = ["cpp_api.h"],
visibility = [
"//src/google/protobuf:__subpackages__",
"//rust:__subpackages__",
],
deps = [
- ":rust_alloc_for_cpp_api",
+ ":rust_alloc_for_cpp_api", # buildcleaner: keep
"//:protobuf_nowkt",
],
)
diff --git a/rust/cpp_kernel/cpp_api.cc b/rust/cpp_kernel/cpp_api.cc
new file mode 100644
index 0000000..8ff79d8
--- /dev/null
+++ b/rust/cpp_kernel/cpp_api.cc
@@ -0,0 +1,35 @@
+#include "google/protobuf/repeated_field.h"
+
+extern "C" {
+// Defines the extern "C" new/add/size/get/set functions for one RepeatedField<ty>.
+#define expose_repeated_field_methods(ty, rust_ty) \
+  google::protobuf::RepeatedField<ty>* __pb_rust_RepeatedField_##rust_ty##_new() { \
+    return new google::protobuf::RepeatedField<ty>(); \
+  } \
+  void __pb_rust_RepeatedField_##rust_ty##_add(google::protobuf::RepeatedField<ty>* r, \
+                                               ty val) { \
+    r->Add(val); \
+  } \
+  size_t __pb_rust_RepeatedField_##rust_ty##_size( \
+      google::protobuf::RepeatedField<ty>* r) { \
+    return static_cast<size_t>(r->size()); /* size() is a non-negative int */ \
+  } \
+  ty __pb_rust_RepeatedField_##rust_ty##_get(google::protobuf::RepeatedField<ty>* r, \
+                                             size_t index) { \
+    return r->Get(static_cast<int>(index)); /* RepeatedField indexes with int */ \
+  } \
+  void __pb_rust_RepeatedField_##rust_ty##_set(google::protobuf::RepeatedField<ty>* r, \
+                                               size_t index, ty val) { \
+    r->Set(static_cast<int>(index), val); /* void function: no `return` */ \
+  }
+
+expose_repeated_field_methods(int32_t, i32);
+expose_repeated_field_methods(uint32_t, u32);
+expose_repeated_field_methods(float, f32);
+expose_repeated_field_methods(double, f64);
+expose_repeated_field_methods(bool, bool);
+expose_repeated_field_methods(uint64_t, u64);
+expose_repeated_field_methods(int64_t, i64);
+
+#undef expose_repeated_field_methods
+}
diff --git a/rust/internal.rs b/rust/internal.rs
index 1e0f536..e56c9dc 100644
--- a/rust/internal.rs
+++ b/rust/internal.rs
@@ -51,6 +51,17 @@
_data: [u8; 0],
_marker: std::marker::PhantomData<(*mut u8, ::std::marker::PhantomPinned)>,
}
+
+ /// Opaque pointee for [`RawRepeatedField`].
+ ///
+ /// This type is not meant to be dereferenced in Rust code.
+ /// It is only meant to provide type safety for raw pointers
+ /// which are manipulated behind FFI.
+ #[repr(C)]
+ pub struct RawRepeatedFieldData {
+ _data: [u8; 0], // zero-length array: the struct carries no Rust-visible data
+ _marker: std::marker::PhantomData<(*mut u8, ::std::marker::PhantomPinned)>, // raw-pointer + `PhantomPinned` marker opts out of `Send`, `Sync` and `Unpin`
+ }
}
/// A raw pointer to the underlying message for this runtime.
@@ -59,6 +70,9 @@
/// A raw pointer to the underlying arena for this runtime.
pub type RawArena = NonNull<_opaque_pointees::RawArenaData>;
+/// A raw pointer to the underlying repeated field container for this runtime.
+pub type RawRepeatedField = NonNull<_opaque_pointees::RawRepeatedFieldData>;
+
/// Represents an ABI-stable version of `NonNull<[u8]>`/`string_view` (a
/// borrowed slice of bytes) for FFI use only.
///
diff --git a/rust/repeated.rs b/rust/repeated.rs
new file mode 100644
index 0000000..b824de8
--- /dev/null
+++ b/rust/repeated.rs
@@ -0,0 +1,119 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2023 Google LLC. All rights reserved.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file or at
+// https://developers.google.com/open-source/licenses/bsd
+
+//! Repeated scalar fields are implemented around the runtime-specific
+//! `RepeatedField` struct. `RepeatedField` stores an opaque pointer to the
+//! runtime-specific representation of a repeated scalar (`upb_Array*` on upb,
+//! and `RepeatedField<T>*` on cpp).
+use std::marker::PhantomData;
+
+use crate::{
+ __internal::{Private, RawRepeatedField},
+ __runtime::{RepeatedField, RepeatedFieldInner},
+};
+
+#[derive(Clone, Copy)]
+pub struct RepeatedFieldRef<'a> {
+ pub repeated_field: RawRepeatedField, // raw pointer to the runtime's repeated-field container
+ pub _phantom: PhantomData<&'a mut ()>, // carries the `'a` lifetime; holds no data
+}
+
+unsafe impl<'a> Send for RepeatedFieldRef<'a> {} // NOTE(review): the wrapped `NonNull` is not `Send` by default and no SAFETY justification is given — confirm
+unsafe impl<'a> Sync for RepeatedFieldRef<'a> {} // NOTE(review): likewise for `Sync` — confirm what upholds this
+
+#[derive(Clone, Copy)]
+#[repr(transparent)]
+pub struct RepeatedView<'a, T: ?Sized> {
+ inner: RepeatedField<'a, T>, // sole field; `repr(transparent)` gives the view the same layout as this field
+}
+
+impl<'msg, T: ?Sized> RepeatedView<'msg, T> {
+    pub fn from_inner(_private: Private, inner: RepeatedFieldInner<'msg>) -> Self {
+        Self { inner: RepeatedField::from_inner(_private, inner) } // lifetime and element type are inferred from `Self`
+    }
+}
+
+pub struct RepeatedFieldIter<'a, T> {
+ inner: RepeatedField<'a, T>, // field being iterated
+ current_index: usize, // index of the next element `next()` will read
+}
+
+impl<'a, T> std::fmt::Debug for RepeatedView<'a, T> {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_tuple("RepeatedView").finish() // prints only the type name; no elements are formatted
+ }
+}
+
+#[repr(transparent)]
+pub struct RepeatedMut<'a, T: ?Sized> {
+ inner: RepeatedField<'a, T>, // sole field; same layout as `RepeatedView`, which the `Deref` cast relies on
+}
+
+impl<'msg, T: ?Sized> RepeatedMut<'msg, T> {
+ pub fn from_inner(_private: Private, inner: RepeatedFieldInner<'msg>) -> Self { // `_private` is a marker argument that is only forwarded
+ Self { inner: RepeatedField::from_inner(_private, inner) }
+ }
+}
+
+impl<'a, T> std::ops::Deref for RepeatedMut<'a, T> {
+ type Target = RepeatedView<'a, T>;
+ fn deref(&self) -> &Self::Target {
+ // SAFETY:
+ // - `Repeated{View,Mut}<'a, T>` are both `#[repr(transparent)]` over
+ // `RepeatedField<'a, T>`.
+ // - Both wrappers therefore share one layout, so the reference cast is valid.
+ unsafe { &*(self as *const Self as *const RepeatedView<'a, T>) }
+ }
+}
+
+macro_rules! impl_repeated_primitives {
+ ($($t:ty),*) => {
+ $(
+ impl<'a> RepeatedView<'a, $t> { // read-only accessors, forwarded to the runtime `RepeatedField`
+ pub fn len(&self) -> usize {
+ self.inner.len()
+ }
+ pub fn is_empty(&self) -> bool {
+ self.len() == 0
+ }
+ pub fn get(&self, index: usize) -> Option<$t> {
+ self.inner.get(index)
+ }
+ }
+
+ impl<'a> RepeatedMut<'a, $t> { // mutating accessors; reads come from `RepeatedView` via `Deref`
+ pub fn push(&mut self, val: $t) {
+ self.inner.push(val)
+ }
+ pub fn set(&mut self, index: usize, val: $t) {
+ self.inner.set(index, val)
+ }
+ }
+
+ impl<'a> std::iter::Iterator for RepeatedFieldIter<'a, $t> {
+ type Item = $t;
+ fn next(&mut self) -> Option<Self::Item> {
+ let val = self.inner.get(self.current_index);
+ if val.is_some() { // advance only on a hit, so the index never runs past the end
+ self.current_index += 1;
+ }
+ val
+ }
+ }
+
+ impl<'a> std::iter::IntoIterator for RepeatedView<'a, $t> {
+ type Item = $t;
+ type IntoIter = RepeatedFieldIter<'a, $t>;
+ fn into_iter(self) -> Self::IntoIter {
+ RepeatedFieldIter { inner: self.inner, current_index: 0 } // yields elements by value, starting at index 0
+ }
+ }
+ )*
+ }
+}
+
+impl_repeated_primitives!(i32, u32, bool, f32, f64, i64, u64);
diff --git a/rust/shared.rs b/rust/shared.rs
index 3c4408d..f8a9d11 100644
--- a/rust/shared.rs
+++ b/rust/shared.rs
@@ -22,6 +22,7 @@
pub use crate::proxied::{
Mut, MutProxy, Proxied, ProxiedWithPresence, SettableValue, View, ViewProxy,
};
+ pub use crate::repeated::{RepeatedFieldRef, RepeatedMut, RepeatedView};
pub use crate::string::{BytesMut, ProtoStr, ProtoStrMut};
}
pub use __public::*;
@@ -46,6 +47,7 @@
mod optional;
mod primitive;
mod proxied;
+mod repeated;
mod string;
mod vtable;
diff --git a/rust/test/cpp/interop/test_utils.cc b/rust/test/cpp/interop/test_utils.cc
index d5c3784..5c27a95 100644
--- a/rust/test/cpp/interop/test_utils.cc
+++ b/rust/test/cpp/interop/test_utils.cc
@@ -8,7 +8,7 @@
#include <cstddef>
#include "absl/strings/string_view.h"
-#include "google/protobuf/rust/cpp_kernel/cpp_api.h"
+#include "rust/cpp_kernel/cpp_api.h"
#include "google/protobuf/unittest.pb.h"
extern "C" void MutateTestAllTypes(protobuf_unittest::TestAllTypes* msg) {
diff --git a/rust/test/shared/BUILD b/rust/test/shared/BUILD
index e5be7b7..79d43dd 100644
--- a/rust/test/shared/BUILD
+++ b/rust/test/shared/BUILD
@@ -151,6 +151,9 @@
"//rust:protobuf_cpp": "protobuf",
"//rust/test/shared:matchers_cpp": "matchers",
},
+ proc_macro_deps = [
+ "@crate_index//:paste",
+ ],
tags = [
# TODO: Enable testing on arm once we support sanitizers for Rust on Arm.
"not_build:arm",
@@ -170,6 +173,9 @@
"//rust:protobuf_upb": "protobuf",
"//rust/test/shared:matchers_upb": "matchers",
},
+ proc_macro_deps = [
+ "@crate_index//:paste",
+ ],
tags = [
# TODO: Enable testing on arm once we support sanitizers for Rust on Arm.
"not_build:arm",
diff --git a/rust/test/shared/accessors_test.rs b/rust/test/shared/accessors_test.rs
index 40098b1..910f9f1 100644
--- a/rust/test/shared/accessors_test.rs
+++ b/rust/test/shared/accessors_test.rs
@@ -9,6 +9,7 @@
use googletest::prelude::*;
use matchers::{is_set, is_unset};
+use paste::paste;
use protobuf::Optional;
use unittest_proto::proto2_unittest::{TestAllTypes, TestAllTypes_};
@@ -398,3 +399,51 @@
// This should show it set to the OneofBytes but its not supported yet.
assert_that!(msg.oneof_field(), matches_pattern!(not_set(_)));
}
+
+macro_rules! generate_repeated_numeric_test {
+ ($(($t: ty, $field: ident)),*) => { // one `(rust type, proto field suffix)` pair per generated test
+ paste! { $(
+ #[test]
+ fn [< test_repeated_ $field _accessors >]() {
+ let mut msg = TestAllTypes::new();
+ assert_that!(msg.[< repeated_ $field >]().len(), eq(0));
+ assert_that!(msg.[<repeated_ $field >]().get(0), none());
+
+ let mut mutator = msg.[<repeated_ $field _mut >]();
+ mutator.push(1 as $t);
+ assert_that!(mutator.len(), eq(1));
+ assert_that!(mutator.get(0), some(eq(1 as $t)));
+ mutator.set(0, 2 as $t);
+ assert_that!(mutator.get(0), some(eq(2 as $t)));
+ mutator.push(1 as $t);
+
+ assert_that!(mutator.into_iter().collect::<Vec<_>>(), eq(vec![2 as $t, 1 as $t])); // iteration order matches insertion order
+ }
+ )* }
+ };
+}
+
+generate_repeated_numeric_test!(
+ (i32, int32),
+ (u32, uint32),
+ (i64, int64),
+ (u64, uint64),
+ (f32, float),
+ (f64, double)
+);
+
+#[test]
+fn test_repeated_bool_accessors() { // written by hand because the numeric macro's `1 as $t` casts do not apply to `bool`
+ let mut msg = TestAllTypes::new();
+ assert_that!(msg.repeated_bool().len(), eq(0));
+ assert_that!(msg.repeated_bool().get(0), none());
+
+ let mut mutator = msg.repeated_bool_mut();
+ mutator.push(true);
+ assert_that!(mutator.len(), eq(1));
+ assert_that!(mutator.get(0), some(eq(true)));
+ mutator.set(0, false);
+ assert_that!(mutator.get(0), some(eq(false)));
+ mutator.push(true);
+ assert_that!(mutator.into_iter().collect::<Vec<_>>(), eq(vec![false, true]));
+}
diff --git a/rust/upb.rs b/rust/upb.rs
index cd5cc77..57f9c55 100644
--- a/rust/upb.rs
+++ b/rust/upb.rs
@@ -7,7 +7,7 @@
//! UPB FFI wrapper code for use by Rust Protobuf.
-use crate::__internal::{Private, RawArena, RawMessage};
+use crate::__internal::{Private, PtrAndLen, RawArena, RawMessage, RawRepeatedField};
use std::alloc;
use std::alloc::Layout;
use std::cell::UnsafeCell;
@@ -284,6 +284,149 @@
}
}
+/// RepeatedFieldInner contains a `upb_Array*` as well as a reference to an
+/// `Arena`, most likely that of the containing `Message`. upb requires an Arena
+/// to perform mutations on a repeated field.
+#[derive(Clone, Copy, Debug)]
+pub struct RepeatedFieldInner<'msg> {
+ pub raw: RawRepeatedField, // the `upb_Array*`
+ pub arena: &'msg Arena, // arena handed to `upb_Array_Append` when pushing
+}
+
+#[derive(Clone, Copy, Debug)]
+pub struct RepeatedField<'msg, T: ?Sized> {
+ inner: RepeatedFieldInner<'msg>, // array pointer plus the arena used for appends
+ _phantom: PhantomData<&'msg mut T>, // records the element type `T` without storing one
+}
+
+impl<'msg, T: ?Sized> RepeatedField<'msg, T> {
+ pub fn len(&self) -> usize {
+ unsafe { upb_Array_Size(self.inner.raw) } // NOTE(review): relies on `inner.raw` pointing at a live upb_Array; `from_inner` does not check this
+ }
+ pub fn is_empty(&self) -> bool {
+ self.len() == 0
+ }
+ pub fn from_inner(_private: Private, inner: RepeatedFieldInner<'msg>) -> Self { // `_private` is a marker argument; it is never read
+ Self { inner, _phantom: PhantomData }
+ }
+}
+
+// Transcribed from google3/third_party/upb/upb/message/value.h
+#[repr(C)]
+#[derive(Clone, Copy)]
+union upb_MessageValue { // passed and returned by value across the upb FFI boundary
+ bool_val: bool,
+ float_val: std::ffi::c_float,
+ double_val: std::ffi::c_double,
+ uint32_val: u32,
+ int32_val: i32,
+ uint64_val: u64,
+ int64_val: i64,
+ array_val: *const std::ffi::c_void,
+ map_val: *const std::ffi::c_void,
+ msg_val: *const std::ffi::c_void,
+ str_val: PtrAndLen, // assumes `PtrAndLen` matches the layout of upb's string view — TODO confirm
+}
+
+// Transcribed from google3/third_party/upb/upb/base/descriptor_constants.h
+#[repr(C)]
+#[allow(dead_code)] // not every variant is referenced from this file
+enum UpbCType { // cast to `c_int` and passed to `upb_Array_New` to select the element type
+ Bool = 1,
+ Float = 2,
+ Int32 = 3,
+ UInt32 = 4,
+ Enum = 5,
+ Message = 6,
+ Double = 7,
+ Int64 = 8,
+ UInt64 = 9,
+ String = 10,
+ Bytes = 11,
+}
+
+extern "C" {
+ #[allow(dead_code)] // only reached through `RepeatedField::new`, which is itself `allow(dead_code)`
+ fn upb_Array_New(a: RawArena, r#type: std::ffi::c_int) -> RawRepeatedField;
+ fn upb_Array_Size(arr: RawRepeatedField) -> usize;
+ fn upb_Array_Set(arr: RawRepeatedField, i: usize, val: upb_MessageValue);
+ fn upb_Array_Get(arr: RawRepeatedField, i: usize) -> upb_MessageValue;
+ fn upb_Array_Append(arr: RawRepeatedField, val: upb_MessageValue, arena: RawArena); // NOTE(review): verify against upb's array.h whether this returns a success flag that is being dropped
+}
+
+macro_rules! impl_repeated_primitives {
+    ($(($rs_type:ty, $union_field:ident, $upb_tag:expr)),*) => { // (Rust type, `upb_MessageValue` field, `UpbCType` tag)
+        $(
+            impl<'msg> RepeatedField<'msg, $rs_type> {
+                #[allow(dead_code)] // only used by tests and `empty_array`
+                fn new(arena: &'msg Arena) -> Self {
+                    Self {
+                        inner: RepeatedFieldInner {
+                            raw: unsafe { upb_Array_New(arena.raw, $upb_tag as std::ffi::c_int) },
+                            arena,
+                        },
+                        _phantom: PhantomData,
+                    }
+                }
+                pub fn push(&mut self, val: $rs_type) { // appends `val`, allocating from the stored arena
+                    unsafe { upb_Array_Append(
+                        self.inner.raw,
+                        upb_MessageValue { $union_field: val },
+                        self.inner.arena.raw(),
+                    ) }
+                }
+                pub fn get(&self, i: usize) -> Option<$rs_type> { // `None` when `i` is out of range
+                    if i >= self.len() {
+                        None
+                    } else {
+                        unsafe { Some(upb_Array_Get(self.inner.raw, i).$union_field) }
+                    }
+                }
+                pub fn set(&mut self, i: usize, val: $rs_type) { // `&mut self`: writes need exclusive access, as in the cpp kernel
+                    if i >= self.len() {
+                        return; // out-of-range writes are silently ignored
+                    }
+                    unsafe { upb_Array_Set(
+                        self.inner.raw,
+                        i,
+                        upb_MessageValue { $union_field: val },
+                    ) }
+                }
+            }
+        )*
+    }
+}
+
+impl_repeated_primitives!(
+ (bool, bool_val, UpbCType::Bool),
+ (f32, float_val, UpbCType::Float),
+ (f64, double_val, UpbCType::Double),
+ (i32, int32_val, UpbCType::Int32),
+ (u32, uint32_val, UpbCType::UInt32),
+ (i64, int64_val, UpbCType::Int64),
+ (u64, uint64_val, UpbCType::UInt64)
+);
+
+/// Returns a static thread-local empty RepeatedFieldInner for use in a
+/// RepeatedView.
+///
+/// # Safety
+/// TODO: Split RepeatedFieldInner into mut and const variants to
+/// enforce safety. The returned array must never be mutated.
+pub unsafe fn empty_array() -> RepeatedFieldInner<'static> {
+ // TODO: Consider creating empty array in C.
+ fn new_repeated_field_inner() -> RepeatedFieldInner<'static> {
+ let arena = Box::leak::<'static>(Box::new(Arena::new())); // leaked on purpose so the arena can be borrowed for `'static`
+ // Provide `i32` as a placeholder type.
+ RepeatedField::<'static, i32>::new(arena).inner
+ }
+ thread_local! {
+ static REPEATED_FIELD: RepeatedFieldInner<'static> = new_repeated_field_inner(); // initialised lazily, once per thread
+ }
+
+ REPEATED_FIELD.with(|inner| *inner) // `RepeatedFieldInner` is `Copy`, so callers get a copy of the shared value
+}
+
#[cfg(test)]
mod tests {
use super::*;
@@ -309,4 +452,35 @@
};
assert_eq!(&*serialized_data, b"Hello world");
}
+
+    #[test]
+    fn i32_array() { // push/get/set round trip, then 2048 appends to force the array to grow
+        let arena = Arena::new(); // only borrowed immutably, so no `mut` binding is needed
+        let mut arr = RepeatedField::<i32>::new(&arena);
+        assert_eq!(arr.len(), 0);
+        arr.push(1);
+        assert_eq!(arr.get(0), Some(1));
+        assert_eq!(arr.len(), 1);
+        arr.set(0, 3);
+        assert_eq!(arr.get(0), Some(3));
+        for i in 0..2048 {
+            arr.push(i);
+            assert_eq!(arr.get(arr.len() - 1), Some(i));
+        }
+    }
+    #[test]
+    fn u32_array() { // same round trip as the i32 test, for the unsigned union field
+        let arena = Arena::new(); // `new` takes `&Arena`, so neither `mut` nor `&mut` is needed
+        let mut arr = RepeatedField::<u32>::new(&arena);
+        assert_eq!(arr.len(), 0);
+        arr.push(1);
+        assert_eq!(arr.get(0), Some(1));
+        assert_eq!(arr.len(), 1);
+        arr.set(0, 3);
+        assert_eq!(arr.get(0), Some(3));
+        for i in 0..2048 {
+            arr.push(i);
+            assert_eq!(arr.get(arr.len() - 1), Some(i));
+        }
+    }
}
diff --git a/rust/upb_kernel/BUILD b/rust/upb_kernel/BUILD
index dc65231..b06f182 100644
--- a/rust/upb_kernel/BUILD
+++ b/rust/upb_kernel/BUILD
@@ -8,6 +8,7 @@
"//rust:__subpackages__",
],
deps = [
+ "//upb:collections",
"//upb:mem",
],
)
diff --git a/rust/upb_kernel/upb_api.c b/rust/upb_kernel/upb_api.c
index 985749d..a30b4dd 100644
--- a/rust/upb_kernel/upb_api.c
+++ b/rust/upb_kernel/upb_api.c
@@ -8,4 +8,5 @@
#define UPB_BUILD_API
-#include "upb/mem/arena.h" // IWYU pragma: keep
+#include "upb/collections/array.h" // IWYU pragma: keep
+#include "upb/mem/arena.h" // IWYU pragma: keep
diff --git a/src/google/protobuf/compiler/rust/BUILD.bazel b/src/google/protobuf/compiler/rust/BUILD.bazel
index f404bea..1429b9b 100644
--- a/src/google/protobuf/compiler/rust/BUILD.bazel
+++ b/src/google/protobuf/compiler/rust/BUILD.bazel
@@ -51,6 +51,7 @@
name = "accessors",
srcs = [
"accessors/accessors.cc",
+ "accessors/repeated_scalar.cc",
"accessors/singular_message.cc",
"accessors/singular_scalar.cc",
"accessors/singular_string.cc",
diff --git a/src/google/protobuf/compiler/rust/accessors/accessor_generator.h b/src/google/protobuf/compiler/rust/accessors/accessor_generator.h
index e3a4534..3bf1dca 100644
--- a/src/google/protobuf/compiler/rust/accessors/accessor_generator.h
+++ b/src/google/protobuf/compiler/rust/accessors/accessor_generator.h
@@ -86,6 +86,14 @@
void InThunkCc(Context<FieldDescriptor> field) const override;
};
+class RepeatedScalar final : public AccessorGenerator {  // accessor generator for repeated scalar fields
+ public:
+ ~RepeatedScalar() override = default;
+ void InMsgImpl(Context<FieldDescriptor> field) const override;  // Rust accessors inside the message impl
+ void InExternC(Context<FieldDescriptor> field) const override;  // Rust `extern "C"` thunk declarations
+ void InThunkCc(Context<FieldDescriptor> field) const override;  // C++ thunk definitions
+};
+
class UnsupportedField final : public AccessorGenerator {
public:
~UnsupportedField() override = default;
diff --git a/src/google/protobuf/compiler/rust/accessors/accessors.cc b/src/google/protobuf/compiler/rust/accessors/accessors.cc
index 82b38cc..fa5c876 100644
--- a/src/google/protobuf/compiler/rust/accessors/accessors.cc
+++ b/src/google/protobuf/compiler/rust/accessors/accessors.cc
@@ -29,10 +29,6 @@
return std::make_unique<UnsupportedField>();
}
- if (desc.is_repeated()) {
- return std::make_unique<UnsupportedField>();
- }
-
switch (desc.type()) {
case FieldDescriptor::TYPE_INT32:
case FieldDescriptor::TYPE_INT64:
@@ -47,11 +43,20 @@
case FieldDescriptor::TYPE_FLOAT:
case FieldDescriptor::TYPE_DOUBLE:
case FieldDescriptor::TYPE_BOOL:
+ if (desc.is_repeated()) {
+ return std::make_unique<RepeatedScalar>();
+ }
return std::make_unique<SingularScalar>();
case FieldDescriptor::TYPE_BYTES:
case FieldDescriptor::TYPE_STRING:
+ if (desc.is_repeated()) {
+ return std::make_unique<UnsupportedField>();
+ }
return std::make_unique<SingularString>();
case FieldDescriptor::TYPE_MESSAGE:
+ if (desc.is_repeated()) {
+ return std::make_unique<UnsupportedField>();
+ }
return std::make_unique<SingularMessage>();
default:
diff --git a/src/google/protobuf/compiler/rust/accessors/repeated_scalar.cc b/src/google/protobuf/compiler/rust/accessors/repeated_scalar.cc
new file mode 100644
index 0000000..8f0a762
--- /dev/null
+++ b/src/google/protobuf/compiler/rust/accessors/repeated_scalar.cc
@@ -0,0 +1,156 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2023 Google LLC. All rights reserved.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file or at
+// https://developers.google.com/open-source/licenses/bsd
+
+#include "absl/strings/string_view.h"
+#include "google/protobuf/compiler/cpp/helpers.h"
+#include "google/protobuf/compiler/rust/accessors/accessor_generator.h"
+#include "google/protobuf/compiler/rust/context.h"
+#include "google/protobuf/compiler/rust/naming.h"
+#include "google/protobuf/descriptor.h"
+
+namespace google {
+namespace protobuf {
+namespace compiler {
+namespace rust {
+
+void RepeatedScalar::InMsgImpl(Context<FieldDescriptor> field) const {  // Emits the Rust getter and `_mut` accessor for one repeated scalar field.
+ field.Emit({{"field", field.desc().name()},
+ {"Scalar", PrimitiveRsTypeName(field.desc())},
+ {"getter_thunk", Thunk(field, "get")},
+ {"getter_mut_thunk", Thunk(field, "get_mut")},
+ {"getter",
+ [&] {
+ if (field.is_upb()) {  // upb getter can yield no array, so the template falls back to `empty_array()`
+ field.Emit({}, R"rs(
+ pub fn r#$field$(&self) -> $pb$::RepeatedView<'_, $Scalar$> {
+ let inner = unsafe {
+ $getter_thunk$(
+ self.inner.msg,
+ /* optional size pointer */ std::ptr::null(),
+ ) }
+ .map_or_else(|| unsafe {$pbr$::empty_array()}, |raw| {
+ $pbr$::RepeatedFieldInner{ raw, arena: &self.inner.arena }
+ });
+ $pb$::RepeatedView::from_inner($pbi$::Private, inner)
+ }
+ )rs");
+ } else {
+ field.Emit({}, R"rs(
+ pub fn r#$field$(&self) -> $pb$::RepeatedView<'_, $Scalar$> {
+ $pb$::RepeatedView::from_inner(
+ $pbi$::Private,
+ $pbr$::RepeatedFieldInner{
+ raw: unsafe { $getter_thunk$(self.inner.msg) },
+ _phantom: std::marker::PhantomData,
+ },
+ )
+ }
+ )rs");
+ }
+ }},
+ {"clearer_thunk", Thunk(field, "clear")},  // NOTE(review): no template in this function references $clearer_thunk$
+ {"field_mutator_getter",
+ [&] {
+ if (field.is_upb()) {  // upb mutable getter additionally takes the message's arena
+ field.Emit({}, R"rs(
+ pub fn r#$field$_mut(&mut self) -> $pb$::RepeatedMut<'_, $Scalar$> {
+ $pb$::RepeatedMut::from_inner(
+ $pbi$::Private,
+ $pbr$::RepeatedFieldInner{
+ raw: unsafe { $getter_mut_thunk$(
+ self.inner.msg,
+ /* optional size pointer */ std::ptr::null(),
+ self.inner.arena.raw(),
+ ) },
+ arena: &self.inner.arena,
+ },
+ )
+ }
+ )rs");
+ } else {
+ field.Emit({}, R"rs(
+ pub fn r#$field$_mut(&mut self) -> $pb$::RepeatedMut<'_, $Scalar$> {
+ $pb$::RepeatedMut::from_inner(
+ $pbi$::Private,
+ $pbr$::RepeatedFieldInner{
+ raw: unsafe { $getter_mut_thunk$(self.inner.msg)},
+ _phantom: std::marker::PhantomData,
+ },
+ )
+ }
+ )rs");
+ }
+ }}},
+ R"rs(
+ $getter$
+ $field_mutator_getter$
+ )rs");
+}
+
+void RepeatedScalar::InExternC(Context<FieldDescriptor> field) const {  // Emits the Rust-side declarations of the clear/get/get_mut thunks.
+ field.Emit({{"Scalar", PrimitiveRsTypeName(field.desc())},  // NOTE(review): no template in this function references $Scalar$
+ {"getter_thunk", Thunk(field, "get")},
+ {"getter_mut_thunk", Thunk(field, "get_mut")},
+ {"getter",
+ [&] {
+ if (field.is_upb()) {  // upb signatures carry a size pointer (and an arena for the mutable getter)
+ field.Emit(R"rs(
+ fn $getter_mut_thunk$(
+ raw_msg: $pbi$::RawMessage,
+ size: *const usize,
+ arena: $pbi$::RawArena,
+ ) -> $pbi$::RawRepeatedField;
+ // Returns `None` when returned array pointer is NULL.
+ fn $getter_thunk$(
+ raw_msg: $pbi$::RawMessage,
+ size: *const usize,
+ ) -> Option<$pbi$::RawRepeatedField>;
+ )rs");
+ } else {
+ field.Emit(R"rs(
+ fn $getter_mut_thunk$(raw_msg: $pbi$::RawMessage) -> $pbi$::RawRepeatedField;
+ fn $getter_thunk$(raw_msg: $pbi$::RawMessage) -> $pbi$::RawRepeatedField;
+ )rs");
+ }
+ }},
+ {"clearer_thunk", Thunk(field, "clear")}},
+ R"rs(
+ fn $clearer_thunk$(raw_msg: $pbi$::RawMessage);
+ $getter$
+ )rs");
+}
+
+void RepeatedScalar::InThunkCc(Context<FieldDescriptor> field) const {  // Emits the C++ definitions of the clear/get_mut/get thunks.
+ field.Emit({{"field", cpp::FieldName(&field.desc())},
+ {"Scalar", cpp::PrimitiveTypeName(field.desc().cpp_type())},  // C++ element type of the RepeatedField
+ {"QualifiedMsg",
+ cpp::QualifiedClassName(field.desc().containing_type())},
+ {"clearer_thunk", Thunk(field, "clear")},
+ {"getter_thunk", Thunk(field, "get")},
+ {"getter_mut_thunk", Thunk(field, "get_mut")},
+ {"impls",
+ [&] {
+ field.Emit(
+ R"cc(
+ void $clearer_thunk$($QualifiedMsg$* msg) {
+ msg->clear_$field$();
+ }
+ google::protobuf::RepeatedField<$Scalar$>* $getter_mut_thunk$($QualifiedMsg$* msg) {
+ return msg->mutable_$field$();
+ }
+ const google::protobuf::RepeatedField<$Scalar$>& $getter_thunk$($QualifiedMsg$& msg) {
+ return msg.$field$();
+ }
+ )cc");
+ }}},
+ "$impls$");
+}
+
+} // namespace rust
+} // namespace compiler
+} // namespace protobuf
+} // namespace google
diff --git a/src/google/protobuf/compiler/rust/naming.cc b/src/google/protobuf/compiler/rust/naming.cc
index eb0b30b..bb1c35a 100644
--- a/src/google/protobuf/compiler/rust/naming.cc
+++ b/src/google/protobuf/compiler/rust/naming.cc
@@ -64,22 +64,28 @@
namespace {
template <typename T>
-std::string Thunk(Context<T> field, absl::string_view op) {
+std::string FieldPrefix(Context<T> field) {
// NOTE: When field.is_upb(), this functions outputs must match the symbols
// that the upbc plugin generates exactly. Failure to do so correctly results
// in a link-time failure.
absl::string_view prefix = field.is_cpp() ? "__rust_proto_thunk__" : "";
- std::string thunk =
+ std::string thunk_prefix =
absl::StrCat(prefix, GetUnderscoreDelimitedFullName(
field.WithDesc(field.desc().containing_type())));
+ return thunk_prefix;
+}
+
+template <typename T>
+std::string Thunk(Context<T> field, absl::string_view op) {
+ std::string thunk = FieldPrefix(field);
absl::string_view format;
if (field.is_upb() && op == "get") {
// upb getter is simply the field name (no "get" in the name).
format = "_$1";
- } else if (field.is_upb() && op == "case") {
- // upb oneof case function is x_case compared to has/set/clear which are in
- // the other order e.g. clear_x.
+ } else if (field.is_upb() && (op == "case")) {
+ // some upb functions are in the order x_op compared to has/set/clear which
+ // are in the other order e.g. op_x.
format = "_$1_$0";
} else {
format = "_$0_$1";
@@ -89,9 +95,32 @@
return thunk;
}
+std::string ThunkRepeated(Context<FieldDescriptor> field,
+ absl::string_view op) {
+ if (!field.is_upb()) {  // non-upb kernels use the ordinary thunk naming
+ return Thunk<FieldDescriptor>(field, op);
+ }
+
+ std::string thunk = absl::StrCat("_", FieldPrefix(field));  // upb array accessors get an extra leading underscore
+ absl::string_view format;
+ if (op == "get") {
+ format = "_$1_upb_array";
+ } else if (op == "get_mut") {
+ format = "_$1_mutable_upb_array";
+ } else {
+ return Thunk<FieldDescriptor>(field, op);  // every other op (e.g. "clear") falls back to the ordinary naming
+ }
+
+ absl::SubstituteAndAppend(&thunk, format, op, field.desc().name());  // $0 = op (unused by both formats), $1 = field name
+ return thunk;
+}
+
} // namespace
std::string Thunk(Context<FieldDescriptor> field, absl::string_view op) {
+ if (field.desc().is_repeated()) {
+ return ThunkRepeated(field, op);
+ }
return Thunk<FieldDescriptor>(field, op);
}