This CL implements `msg.<field>()` and `msg.<field>_mut()` accessors for maps with primitive-typed keys and values, for the UPB kernel only.
Support for the C++ kernel and non-scalar value types will be implemented in follow-up CLs.
PiperOrigin-RevId: 580453646
diff --git a/rust/BUILD b/rust/BUILD
index 3e0ca71..61f8236 100644
--- a/rust/BUILD
+++ b/rust/BUILD
@@ -65,7 +65,10 @@
# setting.
rust_library(
name = "protobuf_upb",
- srcs = PROTOBUF_SHARED + ["upb.rs"],
+ srcs = PROTOBUF_SHARED + [
+ "map.rs",
+ "upb.rs",
+ ],
crate_root = "shared.rs",
rustc_flags = ["--cfg=upb_kernel"],
visibility = [
@@ -82,6 +85,9 @@
name = "protobuf_upb_test",
crate = ":protobuf_upb",
rustc_flags = ["--cfg=upb_kernel"],
+ deps = [
+ "@crate_index//:googletest",
+ ],
)
# The Rust Protobuf runtime using the cpp kernel.
diff --git a/rust/internal.rs b/rust/internal.rs
index d7ea96b..be9ddb2 100644
--- a/rust/internal.rs
+++ b/rust/internal.rs
@@ -62,6 +62,17 @@
_data: [u8; 0],
_marker: std::marker::PhantomData<(*mut u8, ::std::marker::PhantomPinned)>,
}
+
+ /// Opaque pointee for [`RawMap`].
+ ///
+ /// This type is not meant to be dereferenced in Rust code.
+ /// It is only meant to provide type safety for raw pointers
+ /// which are manipulated behind FFI.
+ #[repr(C)]
+ pub struct RawMapData {
+ _data: [u8; 0], // zero-sized: no contents are visible to Rust
+ _marker: std::marker::PhantomData<(*mut u8, ::std::marker::PhantomPinned)>, // makes the type !Send, !Sync and !Unpin
+ }
}
/// A raw pointer to the underlying message for this runtime.
@@ -73,6 +84,9 @@
/// A raw pointer to the underlying repeated field container for this runtime.
pub type RawRepeatedField = NonNull<_opaque_pointees::RawRepeatedFieldData>;
+/// A raw pointer to the underlying map container for this runtime.
+pub type RawMap = NonNull<_opaque_pointees::RawMapData>;
+
/// Represents an ABI-stable version of `NonNull<[u8]>`/`string_view` (a
/// borrowed slice of bytes) for FFI use only.
///
diff --git a/rust/map.rs b/rust/map.rs
new file mode 100644
index 0000000..450b854
--- /dev/null
+++ b/rust/map.rs
@@ -0,0 +1,77 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2023 Google LLC. All rights reserved.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file or at
+// https://developers.google.com/open-source/licenses/bsd
+
+use crate::{
+ __internal::Private,
+ __runtime::{Map, MapInner, MapValueType},
+};
+
+#[derive(Clone, Copy)] // NOTE(review): derive adds `K: Clone`/`V: Clone` bounds, unlike `Map`'s manual impls.
+#[repr(transparent)]
+pub struct MapView<'a, K: ?Sized, V: ?Sized> {
+ inner: Map<'a, K, V>, // read-only handle: no mutating methods are exposed on `MapView`
+}
+
+#[derive(Clone, Copy)] // NOTE(review): `Copy` lets callers duplicate a mutable handle; confirm this aliasing is intended.
+#[repr(transparent)]
+pub struct MapMut<'a, K: ?Sized, V: ?Sized> {
+ inner: Map<'a, K, V>, // `insert`/`remove`/`clear` are generated per key type by `impl_scalar_map_keys!`
+}
+
+impl<'a, K: ?Sized, V: ?Sized> MapView<'a, K, V> {
+ pub fn from_inner(_private: Private, inner: MapInner<'a>) -> Self { // unchecked: `inner` must hold `K`/`V` entries
+ Self { inner: Map::<'a, K, V>::from_inner(_private, inner) }
+ }
+ /// Returns the number of entries in the map.
+ pub fn len(&self) -> usize {
+ self.inner.len()
+ }
+ /// Returns `true` if the map has no entries.
+ pub fn is_empty(&self) -> bool {
+ self.len() == 0
+ }
+}
+/// Constructor for [`MapMut`]; its mutators are generated by `impl_scalar_map_keys!`.
+impl<'a, K: ?Sized, V: ?Sized> MapMut<'a, K, V> {
+ pub fn from_inner(_private: Private, inner: MapInner<'a>) -> Self { // unchecked: `inner` must hold `K`/`V` entries
+ Self { inner: Map::<'a, K, V>::from_inner(_private, inner) }
+ }
+}
+/// Generates `get` on `MapView` and `insert`/`remove`/`clear` on `MapMut` for each listed key type.
+macro_rules! impl_scalar_map_keys {
+ ($(key_type $type:ty;)*) => {
+ $(
+ impl<'a, V: MapValueType> MapView<'a, $type, V> {
+ pub fn get(&self, key: $type) -> Option<V> { // `None` if `key` is absent
+ self.inner.get(key)
+ }
+ }
+ /// Mutating accessors for maps keyed by this scalar type.
+ impl<'a, V: MapValueType> MapMut<'a, $type, V> {
+ pub fn insert(&mut self, key: $type, value: V) -> bool { // NOTE(review): forwards `upb_Map_Set`'s result; confirm what `false` means
+ self.inner.insert(key, value)
+ }
+ /// Removes `key`, returning its value if it was present.
+ pub fn remove(&mut self, key: $type) -> Option<V> {
+ self.inner.remove(key)
+ }
+ /// Removes all entries from the map.
+ pub fn clear(&mut self) {
+ self.inner.clear()
+ }
+ }
+ )*
+ };
+}
+// Keep this key list in sync with the `MapKeyType` impls in upb.rs.
+impl_scalar_map_keys!(
+ key_type i32;
+ key_type u32;
+ key_type i64;
+ key_type u64;
+ key_type bool;
+);
diff --git a/rust/shared.rs b/rust/shared.rs
index 7c3a3d1..5aa054c 100644
--- a/rust/shared.rs
+++ b/rust/shared.rs
@@ -17,6 +17,8 @@
/// These are the items protobuf users can access directly.
#[doc(hidden)]
pub mod __public {
+ #[cfg(upb_kernel)]
+ pub use crate::map::{MapMut, MapView};
pub use crate::optional::{AbsentField, FieldEntry, Optional, PresentField};
pub use crate::primitive::{PrimitiveMut, SingularPrimitiveMut};
pub use crate::proxied::{
@@ -44,6 +46,8 @@
pub mod __runtime;
mod macros;
+#[cfg(upb_kernel)]
+mod map;
mod optional;
mod primitive;
mod proxied;
diff --git a/rust/test/BUILD b/rust/test/BUILD
index 9416781..1c8255c 100644
--- a/rust/test/BUILD
+++ b/rust/test/BUILD
@@ -1,10 +1,10 @@
-load("@rules_cc//cc:defs.bzl", "cc_proto_library")
load(
"//rust:defs.bzl",
"rust_cc_proto_library",
"rust_proto_library",
"rust_upb_proto_library",
)
+load("@rules_cc//cc:defs.bzl", "cc_proto_library")
UNITTEST_PROTO_TARGET = "//src/google/protobuf:test_protos"
UNITTEST_CC_PROTO_TARGET = "//src/google/protobuf:cc_test_protos"
@@ -330,3 +330,12 @@
],
deps = [":nested_proto"],
)
+
+rust_upb_proto_library(
+ name = "map_unittest_upb_rust_proto",
+ testonly = True,
+ visibility = [
+ "//rust/test/shared:__subpackages__",
+ ],
+ deps = ["//src/google/protobuf:map_unittest_proto"],
+)
diff --git a/rust/test/shared/BUILD b/rust/test/shared/BUILD
index 73be626..b5864a3 100644
--- a/rust/test/shared/BUILD
+++ b/rust/test/shared/BUILD
@@ -23,7 +23,7 @@
"//rust:protobuf_upb": "protobuf",
},
deps = [
- "//third_party/gtest_rust/googletest",
+ "@crate_index//:googletest",
"//rust:protobuf_upb",
],
)
@@ -35,7 +35,7 @@
"//rust:protobuf_cpp": "protobuf",
},
deps = [
- "//third_party/gtest_rust/googletest",
+ "@crate_index//:googletest",
"//rust:protobuf_cpp",
],
)
@@ -48,7 +48,7 @@
"not_build:arm",
],
deps = [
- "//third_party/gtest_rust/googletest",
+ "@crate_index//:googletest",
"//rust/test:child_upb_rust_proto",
"//rust/test:parent_upb_rust_proto",
],
@@ -62,7 +62,7 @@
"not_build:arm",
],
deps = [
- "//third_party/gtest_rust/googletest",
+ "@crate_index//:googletest",
"//rust/test:child_cc_rust_proto",
"//rust/test:parent_cc_rust_proto",
],
@@ -104,7 +104,7 @@
"not_build:arm",
],
deps = [
- "//third_party/gtest_rust/googletest",
+ "@crate_index//:googletest",
"//rust/test:reserved_cc_rust_proto",
"//rust/test:unittest_cc_rust_proto",
],
@@ -118,7 +118,7 @@
"not_build:arm",
],
deps = [
- "//third_party/gtest_rust/googletest",
+ "@crate_index//:googletest",
"//rust/test:reserved_upb_rust_proto",
"//rust/test:unittest_upb_rust_proto",
],
@@ -159,7 +159,7 @@
"not_build:arm",
],
deps = [
- "//third_party/gtest_rust/googletest",
+ "@crate_index//:googletest",
"//rust:protobuf_cpp",
"//rust/test:unittest_cc_rust_proto",
"//rust/test/shared:matchers_cpp",
@@ -181,7 +181,7 @@
"not_build:arm",
],
deps = [
- "//third_party/gtest_rust/googletest",
+ "@crate_index//:googletest",
"//rust:protobuf_upb",
"//rust/test:unittest_upb_rust_proto",
"//rust/test/shared:matchers_upb",
@@ -200,7 +200,7 @@
"not_build:arm",
],
deps = [
- "//third_party/gtest_rust/googletest",
+ "@crate_index//:googletest",
"//rust:protobuf_cpp",
"//rust/test:unittest_proto3_cc_rust_proto",
"//rust/test:unittest_proto3_optional_cc_rust_proto",
@@ -220,7 +220,7 @@
"not_build:arm",
],
deps = [
- "//third_party/gtest_rust/googletest",
+ "@crate_index//:googletest",
"//rust:protobuf_upb",
"//rust/test:unittest_proto3_optional_upb_rust_proto",
"//rust/test:unittest_proto3_upb_rust_proto",
@@ -236,7 +236,7 @@
"not_build:arm",
],
deps = [
- "//third_party/gtest_rust/googletest",
+ "@crate_index//:googletest",
"//rust/test:unittest_upb_rust_proto",
],
)
@@ -249,7 +249,7 @@
"not_build:arm",
],
deps = [
- "//third_party/gtest_rust/googletest",
+ "@crate_index//:googletest",
"//rust/test:unittest_cc_rust_proto",
],
)
@@ -273,3 +273,15 @@
],
deps = ["//rust/test:nested_upb_rust_proto"],
)
+
+rust_test(
+ name = "accessors_map_upb_test",
+ srcs = ["accessors_map_test.rs"],
+ proc_macro_deps = [
+ "@crate_index//:paste",
+ ],
+ deps = [
+ "@crate_index//:googletest",
+ "//rust/test:map_unittest_upb_rust_proto",
+ ],
+)
diff --git a/rust/test/shared/accessors_map_test.rs b/rust/test/shared/accessors_map_test.rs
new file mode 100644
index 0000000..a401d94
--- /dev/null
+++ b/rust/test/shared/accessors_map_test.rs
@@ -0,0 +1,43 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2023 Google LLC. All rights reserved.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file or at
+// https://developers.google.com/open-source/licenses/bsd
+
+use googletest::prelude::*;
+use map_unittest_proto::proto2_unittest::TestMap;
+use paste::paste;
+// For each tuple, generates a test that inserts a default key/value pair and checks `len`.
+macro_rules! generate_map_primitives_tests {
+ (
+ $(($k_type:ty, $v_type:ty, $k_field:ident, $v_field:ident)),*
+ ) => {
+ paste! { $(
+ #[test]
+ fn [< test_map_ $k_field _ $v_field >]() {
+ let mut msg = TestMap::new();
+ let k: $k_type = Default::default();
+ let v: $v_type = Default::default();
+ assert_that!(msg.[< map_ $k_field _ $v_field _mut>]().insert(k, v), eq(true));
+ assert_that!(msg.[< map_ $k_field _ $v_field >]().len(), eq(1));
+ }
+ )* }
+ };
+}
+// Tuples: (Rust key type, Rust value type, key field name, value field name) of `TestMap`.
+generate_map_primitives_tests!(
+ (i32, i32, int32, int32),
+ (i64, i64, int64, int64),
+ (u32, u32, uint32, uint32),
+ (u64, u64, uint64, uint64),
+ (i32, i32, sint32, sint32),
+ (i64, i64, sint64, sint64),
+ (u32, u32, fixed32, fixed32),
+ (u64, u64, fixed64, fixed64),
+ (i32, i32, sfixed32, sfixed32),
+ (i64, i64, sfixed64, sfixed64),
+ (i32, f32, int32, float),
+ (i32, f64, int32, double),
+ (bool, bool, bool, bool)
+);
diff --git a/rust/upb.rs b/rust/upb.rs
index b6fd7b3..bab1a9b 100644
--- a/rust/upb.rs
+++ b/rust/upb.rs
@@ -7,7 +7,7 @@
//! UPB FFI wrapper code for use by Rust Protobuf.
-use crate::__internal::{Private, PtrAndLen, RawArena, RawMessage, RawRepeatedField};
+use crate::__internal::{Private, PtrAndLen, RawArena, RawMap, RawMessage, RawRepeatedField};
use std::alloc;
use std::alloc::Layout;
use std::cell::UnsafeCell;
@@ -324,7 +324,7 @@
// Transcribed from google3/third_party/upb/upb/message/value.h
#[repr(C)]
#[derive(Clone, Copy)]
-union upb_MessageValue {
+pub union upb_MessageValue {
bool_val: bool,
float_val: std::ffi::c_float,
double_val: std::ffi::c_double,
@@ -341,7 +341,7 @@
// Transcribed from google3/third_party/upb/upb/base/descriptor_constants.h
#[repr(C)]
#[allow(dead_code)]
-enum UpbCType {
+pub enum UpbCType {
Bool = 1,
Float = 2,
Int32 = 3,
@@ -435,6 +435,8 @@
/// RepeatedView.
///
/// # Safety
+/// The returned array must never be mutated.
+///
/// TODO: Split RepeatedFieldInner into mut and const variants to
/// enforce safety. The returned array must never be mutated.
pub unsafe fn empty_array() -> RepeatedFieldInner<'static> {
@@ -451,9 +453,205 @@
REPEATED_FIELD.with(|inner| *inner)
}
+/// Returns a static thread-local empty [`MapInner`] for use in a
+/// `MapView`.
+///
+/// # Safety
+/// The returned map must never be mutated.
+///
+/// TODO: Split MapInner into mut and const variants to
+/// enforce this requirement in the type system.
+pub unsafe fn empty_map() -> MapInner<'static> {
+ fn new_map_inner() -> MapInner<'static> {
+ // TODO: Consider creating empty map in C; this leaks one `Arena` per calling thread.
+ let arena = Box::leak::<'static>(Box::new(Arena::new()));
+ // Provide `i32` as a placeholder type; `MapInner` does not record key/value types.
+ Map::<'static, i32, i32>::new(arena).inner
+ }
+ thread_local! {
+ static MAP: MapInner<'static> = new_map_inner();
+ }
+ // `MapInner` is `Copy`, so each call hands out a copy of this thread's handle.
+ MAP.with(|inner| *inner)
+}
+/// Untyped map handle: the raw upb map plus the arena passed to upb for allocations.
+#[derive(Clone, Copy, Debug)]
+pub struct MapInner<'msg> {
+ pub raw: RawMap,
+ pub arena: &'msg Arena,
+}
+/// Typed wrapper over [`MapInner`]; `K` and `V` exist only at the type level.
+#[derive(Debug)]
+pub struct Map<'msg, K: ?Sized, V: ?Sized> {
+ inner: MapInner<'msg>,
+ _phantom_key: PhantomData<&'msg mut K>,
+ _phantom_value: PhantomData<&'msg mut V>,
+}
+
+// These use manual impls instead of derives to avoid unnecessary bounds on `K`
+// and `V`. This problem is referred to as "perfect derive".
+// https://smallcultfollowing.com/babysteps/blog/2022/04/12/implied-bounds-and-perfect-derive/
+impl<'msg, K: ?Sized, V: ?Sized> Copy for Map<'msg, K, V> {}
+impl<'msg, K: ?Sized, V: ?Sized> Clone for Map<'msg, K, V> {
+ fn clone(&self) -> Map<'msg, K, V> {
+ *self
+ }
+}
+/// Operations that need no bounds on `K` or `V`.
+impl<'msg, K: ?Sized, V: ?Sized> Map<'msg, K, V> {
+ pub fn len(&self) -> usize { // number of entries, as reported by `upb_Map_Size`
+ unsafe { upb_Map_Size(self.inner.raw) }
+ }
+ /// Returns `true` if the map has no entries.
+ pub fn is_empty(&self) -> bool {
+ self.len() == 0
+ }
+ /// Wraps `inner` without checking that its upb key/value types match `K`/`V`.
+ pub fn from_inner(_private: Private, inner: MapInner<'msg>) -> Self {
+ Map { inner, _phantom_key: PhantomData, _phantom_value: PhantomData }
+ }
+ /// Removes all entries from the map.
+ pub fn clear(&mut self) {
+ unsafe { upb_Map_Clear(self.inner.raw) }
+ }
+}
+/// Conversions between Rust scalars and `upb_MessageValue` for map keys and values.
+/// # Safety
+/// Implementers of this trait must ensure that `pack_message_value` returns
+/// a `upb_MessageValue` whose active variant matches `Self` and `upb_ctype`.
+pub unsafe trait MapType {
+ /// # Safety
+ /// The active variant of `outer` must be the one corresponding to `Self`.
+ unsafe fn unpack_message_value(_private: Private, outer: upb_MessageValue) -> Self;
+ /// Wraps `inner` in a `upb_MessageValue` with `Self`'s variant active.
+ fn pack_message_value(_private: Private, inner: Self) -> upb_MessageValue;
+ /// The upb C type passed to `upb_Map_New` for this key/value type.
+ fn upb_ctype(_private: Private) -> UpbCType;
+ /// Placeholder used to initialize out-params before upb fills them in.
+ fn zero_value(_private: Private) -> Self;
+}
+
+/// Marker trait for types that can be used as map keys.
+pub trait MapKeyType: MapType {}
+
+/// Marker trait for types that can be used as map values.
+pub trait MapValueType: MapType {}
+/// Implements [`MapType`] and [`MapValueType`] for each `(type, union field, upb C type, zero)` row.
+macro_rules! impl_scalar_map_value_types {
+ ($($type:ty, $union_field:ident, $upb_tag:expr, $zero_val:literal;)*) => {
+ $(
+ unsafe impl MapType for $type {
+ unsafe fn unpack_message_value(_private: Private, outer: upb_MessageValue) -> Self {
+ unsafe { outer.$union_field } // SAFETY: the caller guarantees this union field is the active one.
+ }
+
+ fn pack_message_value(_private: Private, inner: Self) -> upb_MessageValue {
+ upb_MessageValue { $union_field: inner }
+ }
+
+ fn upb_ctype(_private: Private) -> UpbCType {
+ $upb_tag
+ }
+
+ fn zero_value(_private: Private) -> Self {
+ $zero_val
+ }
+ }
+
+ impl MapValueType for $type {}
+ )*
+ };
+}
+// Rows: Rust type, `upb_MessageValue` field, upb C type, zero value.
+impl_scalar_map_value_types!(
+ f32, float_val, UpbCType::Float, 0f32;
+ f64, double_val, UpbCType::Double, 0f64;
+ i32, int32_val, UpbCType::Int32, 0i32;
+ u32, uint32_val, UpbCType::UInt32, 0u32;
+ i64, int64_val, UpbCType::Int64, 0i64;
+ u64, uint64_val, UpbCType::UInt64, 0u64;
+ bool, bool_val, UpbCType::Bool, false;
+);
+/// Implements the [`MapKeyType`] marker trait for each listed type.
+macro_rules! impl_scalar_map_key_types {
+ ($($type:ty;)*) => {
+ $(
+ impl MapKeyType for $type {}
+ )*
+ };
+}
+// `f32`/`f64` are omitted: protobuf map keys must be integral, bool, or string types.
+impl_scalar_map_key_types!(
+ i32; u32; i64; u64; bool;
+);
+/// Operations that use `K`/`V` to pick the `upb_MessageValue` variant and upb C type.
+impl<'msg, K: MapKeyType, V: MapValueType> Map<'msg, K, V> {
+ pub fn new(arena: &'msg Arena) -> Self { // allocates the upb map on `arena`
+ unsafe { // NOTE(review): `RawMap` is `NonNull`; confirm `upb_Map_New` cannot return null (e.g. on OOM).
+ let raw_map = upb_Map_New(arena.raw(), K::upb_ctype(Private), V::upb_ctype(Private));
+ Map {
+ inner: MapInner { raw: raw_map, arena },
+ _phantom_key: PhantomData,
+ _phantom_value: PhantomData,
+ }
+ }
+ }
+ /// Returns the value stored for `key`, or `None` if it is absent.
+ pub fn get(&self, key: K) -> Option<V> {
+ let mut val = V::pack_message_value(Private, V::zero_value(Private)); // out-param for `upb_Map_Get`
+ let found =
+ unsafe { upb_Map_Get(self.inner.raw, K::pack_message_value(Private, key), &mut val) };
+ if !found {
+ return None;
+ }
+ Some(unsafe { V::unpack_message_value(Private, val) })
+ }
+ /// Sets `key` to `value` and returns the result of `upb_Map_Set`.
+ pub fn insert(&mut self, key: K, value: V) -> bool { // NOTE(review): confirm whether `false` means allocation failure or an existing key.
+ unsafe {
+ upb_Map_Set(
+ self.inner.raw,
+ K::pack_message_value(Private, key),
+ V::pack_message_value(Private, value),
+ self.inner.arena.raw(),
+ )
+ }
+ }
+ /// Removes `key`, returning its previous value if it was present.
+ pub fn remove(&mut self, key: K) -> Option<V> {
+ let mut val = V::pack_message_value(Private, V::zero_value(Private)); // out-param for `upb_Map_Delete`
+ let removed = unsafe {
+ upb_Map_Delete(self.inner.raw, K::pack_message_value(Private, key), &mut val)
+ };
+ if !removed {
+ return None;
+ }
+ Some(unsafe { V::unpack_message_value(Private, val) })
+ }
+}
+// Bindings to the upb map C API (`upb/collections/map.h`, included by upb_api.c).
+extern "C" {
+ fn upb_Map_New(arena: RawArena, key_type: UpbCType, value_type: UpbCType) -> RawMap;
+ fn upb_Map_Size(map: RawMap) -> usize;
+ fn upb_Map_Set(
+ map: RawMap,
+ key: upb_MessageValue,
+ value: upb_MessageValue,
+ arena: RawArena,
+ ) -> bool;
+ fn upb_Map_Get(map: RawMap, key: upb_MessageValue, value: *mut upb_MessageValue) -> bool;
+ fn upb_Map_Delete(
+ map: RawMap,
+ key: upb_MessageValue,
+ removed_value: *mut upb_MessageValue,
+ ) -> bool;
+ fn upb_Map_Clear(map: RawMap);
+}
+
#[cfg(test)]
mod tests {
use super::*;
+ use googletest::prelude::*;
#[test]
fn test_arena_new_and_free() {
@@ -507,4 +705,46 @@
assert_eq!(arr.get(arr.len() - 1), Some(i));
}
}
+ // Round-trips insert/get/remove/clear on an `i32` -> `i32` map.
+ #[test]
+ fn i32_i32_map() {
+ let arena = Arena::new();
+ let mut map = Map::<'_, i32, i32>::new(&arena);
+ assert_that!(map.len(), eq(0));
+
+ assert_that!(map.insert(1, 2), eq(true));
+ assert_that!(map.get(1), eq(Some(2)));
+ assert_that!(map.get(3), eq(None));
+ assert_that!(map.len(), eq(1));
+
+ assert_that!(map.remove(1), eq(Some(2)));
+ assert_that!(map.len(), eq(0));
+ assert_that!(map.remove(1), eq(None));
+
+ assert_that!(map.insert(4, 5), eq(true));
+ assert_that!(map.insert(6, 7), eq(true));
+ map.clear();
+ assert_that!(map.len(), eq(0));
+ }
+ // Same sequence as `i32_i32_map`, with `i64` keys and `f64` values.
+ #[test]
+ fn i64_f64_map() {
+ let arena = Arena::new();
+ let mut map = Map::<'_, i64, f64>::new(&arena);
+ assert_that!(map.len(), eq(0));
+
+ assert_that!(map.insert(1, 2.5), eq(true));
+ assert_that!(map.get(1), eq(Some(2.5)));
+ assert_that!(map.get(3), eq(None));
+ assert_that!(map.len(), eq(1));
+
+ assert_that!(map.remove(1), eq(Some(2.5)));
+ assert_that!(map.len(), eq(0));
+ assert_that!(map.remove(1), eq(None));
+
+ assert_that!(map.insert(4, 5.1), eq(true));
+ assert_that!(map.insert(6, 7.2), eq(true));
+ map.clear();
+ assert_that!(map.len(), eq(0));
+ }
}
diff --git a/rust/upb_kernel/upb_api.c b/rust/upb_kernel/upb_api.c
index a30b4dd..e2a67eb 100644
--- a/rust/upb_kernel/upb_api.c
+++ b/rust/upb_kernel/upb_api.c
@@ -8,5 +8,6 @@
#define UPB_BUILD_API
+#include "upb/collections/map.h" // IWYU pragma: keep
#include "upb/collections/array.h" // IWYU pragma: keep
#include "upb/mem/arena.h" // IWYU pragma: keep