Add support for HasExtension/GetExtension without ExtensionRegistry.
Update signatures to use T* Ptr<T> consistently.
Cross language blocks utility updated to use GetArena(T*)
Fixes arena_ for Proxy/CProxy ::Access class now consistently fills in arena_ (where message was created in).
PiperOrigin-RevId: 543457599
diff --git a/protos/BUILD b/protos/BUILD
index be62cb7..e6b1a7f 100644
--- a/protos/BUILD
+++ b/protos/BUILD
@@ -71,6 +71,7 @@
visibility = ["//visibility:public"],
deps = [
"//:message_copy",
+ "//:message_promote",
"//:mini_table",
"//:upb",
"@com_google_absl//absl/status",
diff --git a/protos/protos.cc b/protos/protos.cc
index 7a651c4..229ce4c 100644
--- a/protos/protos.cc
+++ b/protos/protos.cc
@@ -28,6 +28,8 @@
#include "protos/protos.h"
#include "absl/strings/str_format.h"
+#include "upb/message/promote.h"
+#include "upb/wire/common.h"
namespace protos {
@@ -90,6 +92,27 @@
return extension_registry.registry_;
}
+bool HasExtensionOrUnknown(const upb_Message* msg,
+ const upb_MiniTableExtension* eid) {
+ return _upb_Message_Getext(msg, eid) != nullptr ||
+ upb_MiniTable_FindUnknown(msg, eid->field.number,
+ kUpb_WireFormat_DefaultDepthLimit)
+ .status == kUpb_FindUnknown_Ok;
+}
+
+const upb_Message_Extension* GetOrPromoteExtension(
+ upb_Message* msg, const upb_MiniTableExtension* eid, upb_Arena* arena) {
+ const upb_Message_Extension* ext = _upb_Message_Getext(msg, eid);
+ if (ext == nullptr) {
+ upb_GetExtension_Status ext_status = upb_MiniTable_GetOrPromoteExtension(
+ (upb_Message*)msg, eid, kUpb_WireFormat_DefaultDepthLimit, arena, &ext);
+ if (ext_status != kUpb_GetExtension_Ok) {
+ return nullptr;
+ }
+ }
+ return ext;
+}
+
absl::StatusOr<absl::string_view> Serialize(const upb_Message* message,
const upb_MiniTable* mini_table,
upb_Arena* arena, int options) {
diff --git a/protos/protos.h b/protos/protos.h
index abeb427..43b9885 100644
--- a/protos/protos.h
+++ b/protos/protos.h
@@ -33,6 +33,7 @@
#include "absl/status/status.h"
#include "absl/status/statusor.h"
+#include "upb/mem/arena.h"
#include "upb/message/copy.h"
#include "upb/message/extension_internal.h"
#include "upb/mini_table/types.h"
@@ -195,8 +196,8 @@
}
template <typename T>
-typename T::CProxy CreateMessage(upb_Message* msg) {
- return typename T::CProxy(msg);
+typename T::CProxy CreateMessage(upb_Message* msg, upb_Arena* arena) {
+ return typename T::CProxy(msg, arena);
}
class ExtensionMiniTableProvider {
@@ -242,12 +243,12 @@
}
template <typename T>
-upb_Arena* GetArena(const T& message) {
- return static_cast<upb_Arena*>(message.GetInternalArena());
+upb_Arena* GetArena(Ptr<T> message) {
+ return static_cast<upb_Arena*>(message->GetInternalArena());
}
template <typename T>
-upb_Arena* GetArena(Ptr<T> message) {
+upb_Arena* GetArena(T* message) {
return static_cast<upb_Arena*>(message->GetInternalArena());
}
@@ -268,6 +269,12 @@
const upb_MiniTable* mini_table,
upb_Arena* arena, int options);
+bool HasExtensionOrUnknown(const upb_Message* msg,
+ const upb_MiniTableExtension* eid);
+
+const upb_Message_Extension* GetOrPromoteExtension(
+ upb_Message* msg, const upb_MiniTableExtension* eid, upb_Arena* arena);
+
} // namespace internal
class ExtensionRegistry {
@@ -306,17 +313,18 @@
template <typename T, typename Extendee, typename Extension,
typename = EnableIfProtosClass<T>>
bool HasExtension(
- const T& message,
+ const Ptr<T>& message,
const ::protos::internal::ExtensionIdentifier<Extendee, Extension>& id) {
- return _upb_Message_Getext(message.msg(), id.mini_table_ext()) != nullptr;
+ return ::protos::internal::HasExtensionOrUnknown(
+ ::protos::internal::GetInternalMsg(message), id.mini_table_ext());
}
template <typename T, typename Extendee, typename Extension,
typename = EnableIfProtosClass<T>>
bool HasExtension(
- const Ptr<T>& message,
+ const T* message,
const ::protos::internal::ExtensionIdentifier<Extendee, Extension>& id) {
- return _upb_Message_Getext(message->msg(), id.mini_table_ext()) != nullptr;
+ return HasExtension(protos::Ptr(message), id);
}
template <typename T, typename Extendee, typename Extension,
@@ -324,35 +332,17 @@
void ClearExtension(
const Ptr<T>& message,
const ::protos::internal::ExtensionIdentifier<Extendee, Extension>& id) {
- _upb_Message_ClearExtensionField(message->msg(), id.mini_table_ext());
+ static_assert(!std::is_const_v<T>, "");
+ _upb_Message_ClearExtensionField(::protos::internal::GetInternalMsg(message),
+ id.mini_table_ext());
}
template <typename T, typename Extendee, typename Extension,
typename = EnableIfProtosClass<T>>
void ClearExtension(
- const T& message,
+ T* message,
const ::protos::internal::ExtensionIdentifier<Extendee, Extension>& id) {
- _upb_Message_ClearExtensionField(message.msg(), id.mini_table_ext());
-}
-
-template <typename T, typename Extendee, typename Extension,
- typename = EnableIfProtosClass<T>>
-absl::Status SetExtension(
- const T& message,
- const ::protos::internal::ExtensionIdentifier<Extendee, Extension>& id,
- Extension& value) {
- auto* message_arena = static_cast<upb_Arena*>(message.GetInternalArena());
- upb_Message_Extension* msg_ext = _upb_Message_GetOrCreateExtension(
- message.msg(), id.mini_table_ext(), message_arena);
- if (!msg_ext) {
- return MessageAllocationError();
- }
- auto* extension_arena = static_cast<upb_Arena*>(value.GetInternalArena());
- if (message_arena != extension_arena) {
- upb_Arena_Fuse(message_arena, extension_arena);
- }
- msg_ext->data.ptr = value.msg();
- return absl::OkStatus();
+ ClearExtension(::protos::Ptr(message), id);
}
template <typename T, typename Extendee, typename Extension,
@@ -377,16 +367,11 @@
template <typename T, typename Extendee, typename Extension,
typename = EnableIfProtosClass<T>>
-absl::StatusOr<Ptr<const Extension>> GetExtension(
- const T& message,
- const ::protos::internal::ExtensionIdentifier<Extendee, Extension>& id) {
- const upb_Message_Extension* ext =
- _upb_Message_Getext(message.msg(), id.mini_table_ext());
- if (!ext) {
- return ExtensionNotFoundError(id.mini_table_ext()->field.number);
- }
- return Ptr<const Extension>(
- ::protos::internal::CreateMessage<Extension>(ext->data.ptr));
+absl::Status SetExtension(
+ T* message,
+ const ::protos::internal::ExtensionIdentifier<Extendee, Extension>& id,
+ Extension& value) {
+ return ::protos::SetExtension(::protos::Ptr(message), id, value);
}
template <typename T, typename Extendee, typename Extension,
@@ -394,34 +379,22 @@
absl::StatusOr<Ptr<const Extension>> GetExtension(
const Ptr<T>& message,
const ::protos::internal::ExtensionIdentifier<Extendee, Extension>& id) {
- const upb_Message_Extension* ext =
- _upb_Message_Getext(message->msg(), id.mini_table_ext());
+ const upb_Message_Extension* ext = ::protos::internal::GetOrPromoteExtension(
+ ::protos::internal::GetInternalMsg(message), id.mini_table_ext(),
+ ::protos::internal::GetArena(message));
if (!ext) {
return ExtensionNotFoundError(id.mini_table_ext()->field.number);
}
- return Ptr<const Extension>(
- ::protos::internal::CreateMessage<Extension>(ext->data.ptr));
+ return Ptr<const Extension>(::protos::internal::CreateMessage<Extension>(
+ ext->data.ptr, ::protos::internal::GetArena(message)));
}
-template <typename T>
-bool Parse(T& message, absl::string_view bytes) {
- upb_Message_Clear(message.msg(), ::protos::internal::GetMiniTable(&message));
- auto* arena = static_cast<upb_Arena*>(message.GetInternalArena());
- return upb_Decode(bytes.data(), bytes.size(), message.msg(), T::minitable(),
- /* extreg= */ nullptr, /* options= */ 0,
- arena) == kUpb_DecodeStatus_Ok;
-}
-
-template <typename T>
-bool Parse(T& message, absl::string_view bytes,
- const ::protos::ExtensionRegistry& extension_registry) {
- upb_Message_Clear(message.msg(), ::protos::internal::GetMiniTable(message));
- auto* arena = static_cast<upb_Arena*>(message.GetInternalArena());
- return upb_Decode(bytes.data(), bytes.size(), message.msg(),
- ::protos::internal::GetMiniTable(message),
- /* extreg= */
- ::protos::internal::GetUpbExtensions(extension_registry),
- /* options= */ 0, arena) == kUpb_DecodeStatus_Ok;
+template <typename T, typename Extendee, typename Extension,
+ typename = EnableIfProtosClass<T>>
+absl::StatusOr<Ptr<const Extension>> GetExtension(
+ const T* message,
+ const ::protos::internal::ExtensionIdentifier<Extendee, Extension>& id) {
+ return GetExtension(protos::Ptr(message), id);
}
template <typename T>
@@ -447,6 +420,12 @@
}
template <typename T>
+bool Parse(T* message, absl::string_view bytes,
+ const ::protos::ExtensionRegistry& extension_registry) {
+ return Parse(protos::Ptr(message, bytes, extension_registry));
+}
+
+template <typename T>
bool Parse(T* message, absl::string_view bytes) {
upb_Message_Clear(message->msg(), ::protos::internal::GetMiniTable(message));
auto* arena = static_cast<upb_Arena*>(message->GetInternalArena());
diff --git a/protos/repeated_field.h b/protos/repeated_field.h
index 3bdb20d..6233109 100644
--- a/protos/repeated_field.h
+++ b/protos/repeated_field.h
@@ -64,7 +64,8 @@
using Array = add_const_if_T_is_const<T, upb_Array>;
public:
- explicit RepeatedFieldProxyBase(Array* arr) : arr_(arr) {}
+ explicit RepeatedFieldProxyBase(Array* arr, upb_Arena* arena)
+ : arr_(arr), arena_(arena) {}
size_t size() const { return arr_ != nullptr ? upb_Array_Size(arr_) : 0; }
@@ -78,6 +79,7 @@
inline upb_Message* GetMessage(size_t n) const;
Array* arr_;
+ upb_Arena* arena_;
};
template <class T>
@@ -98,12 +100,9 @@
class RepeatedFieldProxyMutableBase : public RepeatedFieldProxyBase<T> {
public:
RepeatedFieldProxyMutableBase(upb_Array* arr, upb_Arena* arena)
- : RepeatedFieldProxyBase<T>(arr), arena_(arena) {}
+ : RepeatedFieldProxyBase<T>(arr, arena) {}
- void clear() { upb_Array_Resize(this->arr_, 0, arena_); }
-
- protected:
- upb_Arena* arena_;
+ void clear() { upb_Array_Resize(this->arr_, 0, this->arena_); }
};
// RepeatedField proxy for repeated messages.
@@ -117,8 +116,8 @@
static constexpr bool kIsConst = std::is_const_v<T>;
public:
- explicit RepeatedFieldProxy(const upb_Array* arr)
- : RepeatedFieldProxyBase<T>(arr) {}
+ explicit RepeatedFieldProxy(const upb_Array* arr, upb_Arena* arena)
+ : RepeatedFieldProxyBase<T>(arr, arena) {}
RepeatedFieldProxy(upb_Array* arr, upb_Arena* arena)
: RepeatedFieldProxyMutableBase<T>(arr, arena) {}
// Constructor used by ::protos::Ptr.
@@ -128,7 +127,7 @@
typename T::CProxy operator[](size_t n) const {
upb_MessageValue message_value = upb_Array_Get(this->arr_, n);
return ::protos::internal::CreateMessage<typename std::remove_const_t<T>>(
- (upb_Message*)message_value.msg_val);
+ (upb_Message*)message_value.msg_val, this->arena_);
}
// TODO(b:/280069986) : Audit/Finalize based on Iterator Design.
@@ -156,7 +155,7 @@
void push_back(T&& msg) {
upb_MessageValue message_value;
message_value.msg_val = GetInternalMsg(&msg);
- upb_Arena_Fuse(GetArena(msg), this->arena_);
+ upb_Arena_Fuse(GetArena(&msg), this->arena_);
upb_Array_Append(this->arr_, message_value, this->arena_);
T moved_msg = std::move(msg);
}
@@ -177,8 +176,8 @@
public:
// Immutable constructor.
- explicit RepeatedFieldStringProxy(const upb_Array* arr)
- : RepeatedFieldProxyBase<T>(arr) {}
+ explicit RepeatedFieldStringProxy(const upb_Array* arr, upb_Arena* arena)
+ : RepeatedFieldProxyBase<T>(arr, arena) {}
// Mutable constructor.
RepeatedFieldStringProxy(upb_Array* arr, upb_Arena* arena)
: RepeatedFieldProxyMutableBase<T>(arr, arena) {}
@@ -210,8 +209,8 @@
static constexpr bool kIsConst = std::is_const_v<T>;
public:
- explicit RepeatedFieldScalarProxy(const upb_Array* arr)
- : RepeatedFieldProxyBase<T>(arr) {}
+ explicit RepeatedFieldScalarProxy(const upb_Array* arr, upb_Arena* arena)
+ : RepeatedFieldProxyBase<T>(arr, arena) {}
RepeatedFieldScalarProxy(upb_Array* arr, upb_Arena* arena)
: RepeatedFieldProxyMutableBase<T>(arr, arena) {}
// Constructor used by ::protos::Ptr.
diff --git a/protos_generator/gen_accessors.cc b/protos_generator/gen_accessors.cc
index f7324e8..1b4f9ed 100644
--- a/protos_generator/gen_accessors.cc
+++ b/protos_generator/gen_accessors.cc
@@ -262,7 +262,8 @@
if (!has_$2()) {
return $4::default_instance();
}
- return ::protos::internal::CreateMessage<$4>((upb_Message*)($3_$5(msg_)));
+ return ::protos::internal::CreateMessage<$4>(
+ (upb_Message*)($3_$5(msg_)), arena_);
}
)cc",
class_name, MessagePtrConstType(field, /* is_const */ true),
@@ -338,7 +339,7 @@
$5* msg_value;
$7bool success = $4_$9_get(msg_, $8, &msg_value);
if (success) {
- return ::protos::internal::CreateMessage<$6>(msg_value);
+ return ::protos::internal::CreateMessage<$6>(msg_value, arena_);
}
return absl::NotFoundError("");
}
diff --git a/protos_generator/gen_messages.cc b/protos_generator/gen_messages.cc
index d8cc569..669d122 100644
--- a/protos_generator/gen_messages.cc
+++ b/protos_generator/gen_messages.cc
@@ -111,9 +111,13 @@
class $0Access {
public:
$0Access() {}
- $0Access($1* msg, upb_Arena* arena) : msg_(msg), arena_(arena) {} // NOLINT
+ $0Access($1* msg, upb_Arena* arena) : msg_(msg), arena_(arena) {
+ assert(arena != nullptr);
+ } // NOLINT
$0Access(const $1* msg, upb_Arena* arena)
- : msg_(const_cast<$1*>(msg)), arena_(arena) {} // NOLINT
+ : msg_(const_cast<$1*>(msg)), arena_(arena) {
+ assert(arena != nullptr);
+ } // NOLINT
void* GetInternalArena() const { return arena_; }
)cc",
ClassName(descriptor), MessageName(descriptor));
@@ -222,7 +226,7 @@
absl::string_view bytes,
const ::protos::ExtensionRegistry& extension_registry,
int options));
- friend upb_Arena* ::protos::internal::GetArena<$0>(const $0& message);
+ friend upb_Arena* ::protos::internal::GetArena<$0>($0* message);
friend upb_Arena* ::protos::internal::GetArena<$0>(::protos::Ptr<$0> message);
friend $0(::protos::internal::MoveMessage<$0>(upb_Message* msg,
upb_Arena* arena));
@@ -279,7 +283,7 @@
const $0Proxy* message);
friend const upb_MiniTable* ::protos::internal::GetMiniTable<$0Proxy>(
::protos::Ptr<$0Proxy> message);
- friend upb_Arena* ::protos::internal::GetArena<$2>(const $2& message);
+ friend upb_Arena* ::protos::internal::GetArena<$2>($2* message);
friend upb_Arena* ::protos::internal::GetArena<$2>(::protos::Ptr<$2> message);
friend $0Proxy(::protos::CloneMessage(::protos::Ptr<$2> message,
::upb::Arena& arena));
@@ -302,7 +306,8 @@
class $0CProxy final : private internal::$0Access {
public:
$0CProxy() = delete;
- $0CProxy(const $0* m) : internal::$0Access(m->msg_, nullptr) {}
+ $0CProxy(const $0* m)
+ : internal::$0Access(m->msg_, ::protos::internal::GetArena(m)) {}
$0CProxy($0Proxy m);
using $0Access::GetInternalArena;
)cc",
@@ -315,8 +320,9 @@
output(
R"cc(
private:
- $0CProxy(void* msg) : internal::$0Access(($1*)msg, nullptr){};
- friend $0::CProxy(::protos::internal::CreateMessage<$0>(upb_Message* msg));
+ $0CProxy(void* msg, upb_Arena* arena) : internal::$0Access(($1*)msg, arena){};
+ friend $0::CProxy(::protos::internal::CreateMessage<$0>(
+ upb_Message* msg, upb_Arena* arena));
friend class RepeatedFieldProxy;
friend class ::protos::Ptr<$0>;
friend class ::protos::Ptr<const $0>;
@@ -390,9 +396,13 @@
R"cc(
struct $0DefaultTypeInternal {
$1* msg;
+ upb_Arena* arena;
};
- $0DefaultTypeInternal _$0_default_instance_ =
- $0DefaultTypeInternal{$1_new(upb_Arena_New())};
+ static $0DefaultTypeInternal _$0DefaultTypeBuilder() {
+ upb_Arena* arena = upb_Arena_New();
+ return $0DefaultTypeInternal{$1_new(arena), arena};
+ }
+ $0DefaultTypeInternal _$0_default_instance_ = _$0DefaultTypeBuilder();
)cc",
ClassName(descriptor), MessageName(descriptor));
@@ -400,7 +410,8 @@
R"cc(
::protos::Ptr<const $0> $0::default_instance() {
return ::protos::internal::CreateMessage<$0>(
- (upb_Message *)_$0_default_instance_.msg);
+ (upb_Message *)_$0_default_instance_.msg,
+ _$0_default_instance_.arena);
}
)cc",
ClassName(descriptor));
diff --git a/protos_generator/gen_repeated_fields.cc b/protos_generator/gen_repeated_fields.cc
index 48d77c1..15219f8 100644
--- a/protos_generator/gen_repeated_fields.cc
+++ b/protos_generator/gen_repeated_fields.cc
@@ -153,7 +153,8 @@
size_t len;
auto* ptr = $3_$5(msg_, &len);
assert(index < len);
- return ::protos::internal::CreateMessage<$4>((upb_Message*)*(ptr + index));
+ return ::protos::internal::CreateMessage<$4>(
+ (upb_Message*)*(ptr + index), arena_);
}
)cc",
class_name, MessagePtrConstType(field, /* is_const */ true),
@@ -192,7 +193,7 @@
const ::protos::RepeatedField<const $1>::CProxy $0::$2() const {
size_t size;
const upb_Array* arr = _$3_$4_$5(msg_, &size);
- return ::protos::RepeatedField<const $1>::CProxy(arr);
+ return ::protos::RepeatedField<const $1>::CProxy(arr, arena_);
};
::protos::Ptr<::protos::RepeatedField<$1>> $0::mutable_$2() {
size_t size;
@@ -258,7 +259,7 @@
const ::protos::RepeatedField<$1>::CProxy $0::$2() const {
size_t size;
const upb_Array* arr = _$3_$4_$5(msg_, &size);
- return ::protos::RepeatedField<$1>::CProxy(arr);
+ return ::protos::RepeatedField<$1>::CProxy(arr, arena_);
};
::protos::Ptr<::protos::RepeatedField<$1>> $0::mutable_$2() {
size_t size;
@@ -322,7 +323,7 @@
const ::protos::RepeatedField<$1>::CProxy $0::$2() const {
size_t size;
const upb_Array* arr = _$3_$4_$5(msg_, &size);
- return ::protos::RepeatedField<$1>::CProxy(arr);
+ return ::protos::RepeatedField<$1>::CProxy(arr, arena_);
};
::protos::Ptr<::protos::RepeatedField<$1>> $0::mutable_$2() {
size_t size;
diff --git a/protos_generator/tests/test_generated.cc b/protos_generator/tests/test_generated.cc
index f0406ad..bfeffff 100644
--- a/protos_generator/tests/test_generated.cc
+++ b/protos_generator/tests/test_generated.cc
@@ -25,6 +25,7 @@
#include <limits>
#include <memory>
+#include <string>
#include <utility>
#include "gtest/gtest.h"
@@ -626,7 +627,7 @@
TEST(CppGeneratedCode, HasExtension) {
TestModel model;
- EXPECT_EQ(false, ::protos::HasExtension(model, theme));
+ EXPECT_EQ(false, ::protos::HasExtension(&model, theme));
}
TEST(CppGeneratedCode, HasExtensionPtr) {
@@ -636,9 +637,9 @@
TEST(CppGeneratedCode, ClearExtensionWithEmptyExtension) {
TestModel model;
- EXPECT_EQ(false, ::protos::HasExtension(model, theme));
- ::protos::ClearExtension(model, theme);
- EXPECT_EQ(false, ::protos::HasExtension(model, theme));
+ EXPECT_EQ(false, ::protos::HasExtension(&model, theme));
+ ::protos::ClearExtension(&model, theme);
+ EXPECT_EQ(false, ::protos::HasExtension(&model, theme));
}
TEST(CppGeneratedCode, ClearExtensionWithEmptyExtensionPtr) {
@@ -652,9 +653,9 @@
TestModel model;
ThemeExtension extension1;
extension1.set_ext_name("Hello World");
- EXPECT_EQ(false, ::protos::HasExtension(model, theme));
- EXPECT_EQ(true, ::protos::SetExtension(model, theme, extension1).ok());
- EXPECT_EQ(true, ::protos::HasExtension(model, theme));
+ EXPECT_EQ(false, ::protos::HasExtension(&model, theme));
+ EXPECT_EQ(true, ::protos::SetExtension(&model, theme, extension1).ok());
+ EXPECT_EQ(true, ::protos::HasExtension(&model, theme));
}
TEST(CppGeneratedCode, SetExtensionOnMutableChild) {
@@ -674,10 +675,10 @@
TestModel model;
ThemeExtension extension1;
extension1.set_ext_name("Hello World");
- EXPECT_EQ(false, ::protos::HasExtension(model, theme));
- EXPECT_EQ(true, ::protos::SetExtension(model, theme, extension1).ok());
+ EXPECT_EQ(false, ::protos::HasExtension(&model, theme));
+ EXPECT_EQ(true, ::protos::SetExtension(&model, theme, extension1).ok());
EXPECT_EQ("Hello World",
- ::protos::GetExtension(model, theme).value()->ext_name());
+ ::protos::GetExtension(&model, theme).value()->ext_name());
}
TEST(CppGeneratedCode, GetExtensionOnMutableChild) {
@@ -750,14 +751,13 @@
model.set_str1("Test123");
ThemeExtension extension1;
extension1.set_ext_name("Hello World");
- EXPECT_EQ(true, ::protos::SetExtension(model, theme, extension1).ok());
+ EXPECT_EQ(true, ::protos::SetExtension(&model, theme, extension1).ok());
::upb::Arena arena;
auto bytes = ::protos::Serialize(&model, arena);
EXPECT_EQ(true, bytes.ok());
TestModel parsed_model = ::protos::Parse<TestModel>(bytes.value()).value();
EXPECT_EQ("Test123", parsed_model.str1());
- // Should not return an extension since we did not pass ExtensionRegistry.
- EXPECT_EQ(false, ::protos::GetExtension(parsed_model, theme).ok());
+ EXPECT_EQ(true, ::protos::GetExtension(&parsed_model, theme).ok());
}
TEST(CppGeneratedCode, ParseIntoPtrToModel) {
@@ -765,7 +765,7 @@
model.set_str1("Test123");
ThemeExtension extension1;
extension1.set_ext_name("Hello World");
- EXPECT_EQ(true, ::protos::SetExtension(model, theme, extension1).ok());
+ EXPECT_EQ(true, ::protos::SetExtension(&model, theme, extension1).ok());
::upb::Arena arena;
auto bytes = ::protos::Serialize(&model, arena);
EXPECT_EQ(true, bytes.ok());
@@ -773,8 +773,9 @@
::protos::CreateMessage<TestModel>(arena);
EXPECT_TRUE(::protos::Parse(parsed_model, bytes.value()));
EXPECT_EQ("Test123", parsed_model->str1());
- // Should not return an extension since we did not pass ExtensionRegistry.
- EXPECT_EQ(false, ::protos::GetExtension(parsed_model, theme).ok());
+ // Should return an extension even if we don't pass ExtensionRegistry
+ // by promoting unknown.
+ EXPECT_EQ(true, ::protos::GetExtension(parsed_model, theme).ok());
}
TEST(CppGeneratedCode, ParseWithExtensionRegistry) {
@@ -782,9 +783,9 @@
model.set_str1("Test123");
ThemeExtension extension1;
extension1.set_ext_name("Hello World");
- EXPECT_EQ(true, ::protos::SetExtension(model, theme, extension1).ok());
- EXPECT_EQ(true, ::protos::SetExtension(model, ThemeExtension::theme_extension,
- extension1)
+ EXPECT_EQ(true, ::protos::SetExtension(&model, theme, extension1).ok());
+ EXPECT_EQ(true, ::protos::SetExtension(
+ &model, ThemeExtension::theme_extension, extension1)
.ok());
::upb::Arena arena;
auto bytes = ::protos::Serialize(&model, arena);
@@ -794,12 +795,12 @@
TestModel parsed_model =
::protos::Parse<TestModel>(bytes.value(), extensions).value();
EXPECT_EQ("Test123", parsed_model.str1());
- EXPECT_EQ(true, ::protos::GetExtension(parsed_model, theme).ok());
- EXPECT_EQ(true, ::protos::GetExtension(parsed_model,
+ EXPECT_EQ(true, ::protos::GetExtension(&parsed_model, theme).ok());
+ EXPECT_EQ(true, ::protos::GetExtension(&parsed_model,
ThemeExtension::theme_extension)
.ok());
EXPECT_EQ("Hello World", ::protos::GetExtension(
- parsed_model, ThemeExtension::theme_extension)
+ &parsed_model, ThemeExtension::theme_extension)
.value()
->ext_name());
}
@@ -898,7 +899,7 @@
new_child->set_child_str1("text in child");
ThemeExtension extension1;
extension1.set_ext_name("name in extension");
- EXPECT_TRUE(::protos::SetExtension(model, theme, extension1).ok());
+ EXPECT_TRUE(::protos::SetExtension(&model, theme, extension1).ok());
EXPECT_TRUE(model.mutable_child_model_1()->has_child_str1());
// Clear using Ptr<T>
::protos::ClearMessage(model.mutable_child_model_1());
@@ -915,14 +916,14 @@
new_child.value()->set_child_str1("text in child");
ThemeExtension extension1;
extension1.set_ext_name("name in extension");
- EXPECT_TRUE(::protos::SetExtension(model, theme, extension1).ok());
+ EXPECT_TRUE(::protos::SetExtension(&model, theme, extension1).ok());
// Clear using T*
::protos::ClearMessage(&model);
// Verify that scalars, repeated fields and extensions are cleared.
EXPECT_FALSE(model.has_int64());
EXPECT_FALSE(model.has_str2());
EXPECT_TRUE(model.child_models().empty());
- EXPECT_FALSE(::protos::HasExtension(model, theme));
+ EXPECT_FALSE(::protos::HasExtension(&model, theme));
}
TEST(CppGeneratedCode, DeepCopy) {
@@ -935,13 +936,35 @@
new_child.value()->set_child_str1("text in child");
ThemeExtension extension1;
extension1.set_ext_name("name in extension");
- EXPECT_TRUE(::protos::SetExtension(model, theme, extension1).ok());
+ EXPECT_TRUE(::protos::SetExtension(&model, theme, extension1).ok());
TestModel target;
target.set_b1(true);
::protos::DeepCopy(&model, &target);
- EXPECT_FALSE(target.b1()) << "Target was not cleared before copying content";
+ EXPECT_FALSE(target.b1()) << "Target was not cleared before copying content ";
EXPECT_EQ(target.str2(), "Hello");
- EXPECT_TRUE(::protos::HasExtension(target, theme));
+ EXPECT_TRUE(::protos::HasExtension(&target, theme));
+}
+
+TEST(CppGeneratedCode, HasExtensionAndRegistry) {
+ // Fill model.
+ TestModel source;
+ source.set_int64(5);
+ source.set_str2("Hello");
+ auto new_child = source.add_child_models();
+ ASSERT_TRUE(new_child.ok());
+ new_child.value()->set_child_str1("text in child");
+ ThemeExtension extension1;
+ extension1.set_ext_name("name in extension");
+ ASSERT_TRUE(::protos::SetExtension(&source, theme, extension1).ok());
+
+ // Now that we have a source model with extension data, serialize.
+ ::protos::Arena arena;
+ std::string data = std::string(::protos::Serialize(&source, arena).value());
+
+ // Test with ExtensionRegistry
+ ::protos::ExtensionRegistry extensions({&theme}, arena);
+ TestModel parsed_model = ::protos::Parse<TestModel>(data, extensions).value();
+ EXPECT_TRUE(::protos::HasExtension(&parsed_model, theme));
}
// TODO(b/288491350) : Add BUILD rule to test failures below.