| #include "google/protobuf/reflection_visit_fields.h" |
| |
| #include <cstddef> |
| #include <cstdint> |
| #include <string> |
| #include <utility> |
| #include <vector> |
| |
| #include <gmock/gmock.h> |
| #include <gtest/gtest.h> |
| #include "absl/log/absl_check.h" |
| #include "absl/strings/cord.h" |
| #include "absl/strings/string_view.h" |
| #include "google/protobuf/arena.h" |
| #include "google/protobuf/io/coded_stream.h" |
| #include "google/protobuf/map_test_util.h" |
| #include "google/protobuf/map_unittest.pb.h" |
| #include "google/protobuf/test_util.h" |
| #include "google/protobuf/unittest.pb.h" |
| #include "google/protobuf/unittest_mset.pb.h" |
| #include "google/protobuf/unittest_mset_wire_format.pb.h" |
| #include "google/protobuf/wire_format.h" |
| #include "google/protobuf/wire_format_lite.h" |
| |
| namespace google { |
| namespace protobuf { |
| namespace internal { |
| namespace { |
| |
| #ifdef __cpp_if_constexpr |
| |
| using ::protobuf_unittest::NestedTestAllTypes; |
| using ::protobuf_unittest::TestAllExtensions; |
| using ::protobuf_unittest::TestAllTypes; |
| using ::protobuf_unittest::TestMap; |
| using ::protobuf_unittest::TestOneof2; |
| using ::protobuf_unittest::TestPackedExtensions; |
| using ::protobuf_unittest::TestPackedTypes; |
| using ::proto2_wireformat_unittest::TestMessageSet; |
| |
| struct TestParam { |
| absl::string_view name; |
| Message* (*create_message)(Arena& arena); |
| }; |
| |
| class VisitFieldsTest : public testing::TestWithParam<TestParam> { |
| public: |
| VisitFieldsTest() { |
| message_ = GetParam().create_message(arena_); |
| reflection_ = message_->GetReflection(); |
| } |
| |
| protected: |
| Arena arena_; |
| Message* message_; |
| const Reflection* reflection_; |
| }; |
| |
| TEST_P(VisitFieldsTest, VisitedFieldsCountMatchesListFields) { |
| uint32_t count = 0; |
| VisitFields(*message_, [&](auto info) { ++count; }); |
| |
| std::vector<const FieldDescriptor*> fields; |
| reflection_->ListFields(*message_, &fields); |
| |
| EXPECT_EQ(count, fields.size()); |
| } |
| |
| TEST_P(VisitFieldsTest, VisitedMessageFieldsCountMatchesListFields) { |
| uint32_t count = 0; |
| VisitFields(*message_, [&](auto info) { ++count; }, FieldMask::kMessage); |
| |
| std::vector<const FieldDescriptor*> fields; |
| reflection_->ListFields(*message_, &fields); |
| int message_count = 0; |
| for (auto field : fields) { |
| if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) ++message_count; |
| } |
| |
| EXPECT_EQ(count, message_count); |
| } |
| |
| // Counts present message fields using ListFields() where: |
| // --N elements in a repeated message field are counted N times |
| // --M message values in a map field are counted M times. |
| // --A map field whose value type is not message is ignored. |
| int CountAllMessageFieldsViaListFields(const Reflection* reflection, |
| const Message& message) { |
| std::vector<const FieldDescriptor*> fields; |
| reflection->ListFields(message, &fields); |
| |
| int message_count = 0; |
| for (auto field : fields) { |
| if (field->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) continue; |
| if (field->is_map()) { |
| if (field->message_type()->map_value()->cpp_type() != |
| FieldDescriptor::CPPTYPE_MESSAGE) |
| continue; |
| } |
| if (field->is_repeated()) { |
| message_count += reflection->FieldSize(message, field); |
| } else { |
| ++message_count; |
| } |
| } |
| return message_count; |
| } |
| |
| TEST_P(VisitFieldsTest, VisitMessageFieldsCountIncludesRepeatedElements) { |
| int count = 0; |
| VisitMessageFields(*message_, [&](const Message& msg) { ++count; }); |
| |
| EXPECT_EQ(count, CountAllMessageFieldsViaListFields(reflection_, *message_)); |
| } |
| |
| TEST_P(VisitFieldsTest, |
| VisitMutableMessageFieldsCountIncludesRepeatedElements) { |
| int count = 0; |
| VisitMutableMessageFields(*message_, [&](Message& msg) { ++count; }); |
| |
| EXPECT_EQ(count, CountAllMessageFieldsViaListFields(reflection_, *message_)); |
| } |
| |
| TEST_P(VisitFieldsTest, ClearByVisitFieldsMustBeEmpty) { |
| VisitFields(*message_, [](auto info) { info.Clear(); }); |
| |
| EXPECT_EQ(message_->ByteSizeLong(), 0); |
| } |
| |
| TEST_P(VisitFieldsTest, ClearByVisitFieldsRevisitNone) { |
| VisitFields(*message_, [](auto info) { info.Clear(); }); |
| |
| uint32_t count = 0; |
| VisitFields(*message_, [&](auto info) { ++count; }); |
| |
| EXPECT_EQ(count, 0); |
| } |
| |
| void MutateNothingByVisit(Message& message) { |
| VisitFields(message, [&](auto info) { |
| if constexpr (info.is_map) { |
| info.VisitElements([&](auto key, auto val) { |
| if constexpr (val.cpp_type == FieldDescriptor::CPPTYPE_MESSAGE) { |
| auto* msg = val.Mutable(); |
| (void)msg->ByteSizeLong(); |
| } else { |
| val.Set(val.Get()); |
| } |
| }); |
| } else if constexpr (info.is_repeated) { |
| if constexpr (info.cpp_type == FieldDescriptor::CPPTYPE_MESSAGE) { |
| size_t byte_size = 0; |
| for (auto& it : info.Mutable()) { |
| byte_size += it.ByteSizeLong(); |
| } |
| } else if constexpr (info.cpp_type == FieldDescriptor::CPPTYPE_STRING) { |
| if constexpr (info.is_cord) { |
| for (auto& it : info.Mutable()) { |
| absl::Cord tmp = it; |
| it = tmp; |
| } |
| } else if constexpr (info.is_string_piece) { |
| for (auto& it : info.Mutable()) { |
| it.CopyFrom({it.data(), it.size()}); |
| } |
| } else { |
| for (auto& it : info.Mutable()) { |
| std::string tmp; |
| tmp = it; |
| it = tmp; |
| } |
| } |
| } else { |
| for (auto& it : info.Mutable()) { |
| it = *⁢ // Avoid -Wself-assign. |
| } |
| } |
| } else { |
| if constexpr (info.cpp_type == FieldDescriptor::CPPTYPE_MESSAGE) { |
| auto& msg = info.Mutable(); |
| ABSL_CHECK_GT(msg.ByteSizeLong(), 0); |
| } else if constexpr (info.cpp_type == FieldDescriptor::CPPTYPE_STRING) { |
| if constexpr (info.is_cord) { |
| info.Set(info.Get()); |
| } else { |
| absl::string_view view = info.Get(); |
| info.Set({view.data(), view.size()}); |
| } |
| } else { |
| info.Set(info.Get()); |
| } |
| } |
| }); |
| } |
| |
| TEST_P(VisitFieldsTest, MutateNothingByVisitIdempotent) { |
| std::string data; |
| ASSERT_TRUE(message_->SerializeToString(&data)); |
| |
| MutateNothingByVisit(*message_); |
| |
| // Checking the identity by comparing serialize bytes is discouraged, but this |
| // allows us to be type-agnostic for this test. Also, the back to back |
| // serialization should be stable. |
| EXPECT_EQ(data, message_->SerializeAsString()); |
| } |
| |
| template <typename InfoT> |
| inline size_t MapKeyByteSizeLong(FieldDescriptor::Type type, InfoT info) { |
| // There is a bug in GCC 9.5 where if-constexpr arguments are not understood |
| // if passed into a helper function. A reproduction of the bug can be found |
| // at: https://godbolt.org/z/65qW3vGhP |
| // This is fixed in GCC 10.1+. |
| (void)type; // Suppress -Wunused-but-set-parameter |
| |
| if constexpr (info.cpp_type == FieldDescriptor::CPPTYPE_STRING) { |
| return WireFormatLite::StringSize(info.Get()); |
| } else { |
| return MapPrimitiveFieldByteSize<info.cpp_type>(type, info.Get()); |
| } |
| } |
| |
| template <typename InfoT> |
| inline size_t MapValueByteSizeLong(FieldDescriptor::Type type, InfoT info) { |
| // There is a bug in GCC 9.5 where if-constexpr arguments are not understood |
| // if passed into a helper function. A reproduction of the bug can be found |
| // at: https://godbolt.org/z/65qW3vGhP |
| // This is fixed in GCC 10.1+. |
| (void)type; // Suppress -Wunused-but-set-parameter |
| |
| if constexpr (info.cpp_type == FieldDescriptor::CPPTYPE_STRING) { |
| return WireFormatLite::StringSize(info.Get()); |
| } else if constexpr (info.cpp_type == FieldDescriptor::CPPTYPE_MESSAGE) { |
| ABSL_CHECK_NE(type, FieldDescriptor::TYPE_GROUP); |
| return WireFormatLite::MessageSize(info.Get()); |
| } else { |
| return MapPrimitiveFieldByteSize<info.cpp_type>(type, info.Get()); |
| } |
| } |
| |
| size_t ByteSizeLongByVisit(const Message& message) { |
| size_t byte_size = 0; |
| |
| VisitFields(message, [&](auto info) { |
| if constexpr (info.is_map) { |
| // A map field is encoded as repeated messages like: |
| // |
| // message MapField { |
| // message MapEntry { |
| // 1: key |
| // 2: value |
| // }; |
| // repeated MapEntry entry = <FIELD_NUMBER>; |
| // } |
| // where the key and the value tag are always encoded as two bytes. (no |
| // group). |
| // |
| // Visit each element to calculate the inner size (key, value) and the |
| // length field for LENGTH_DELIMITED field. |
| size_t size = static_cast<size_t>(info.size()); |
| ABSL_CHECK_GT(size, 0u); |
| byte_size += size * WireFormat::TagSize(info.number(), |
| FieldDescriptor::TYPE_MESSAGE); |
| |
| FieldDescriptor::Type key_type = info.key_type(); |
| FieldDescriptor::Type val_type = info.value_type(); |
| |
| info.VisitElements([&](auto key, auto val) { |
| size_t inner_size = 2 + MapKeyByteSizeLong(key_type, key) + |
| MapValueByteSizeLong(val_type, val); |
| |
| byte_size += io::CodedOutputStream::VarintSize32( |
| static_cast<uint32_t>(inner_size)) + |
| inner_size; |
| }); |
| } else if constexpr (info.is_repeated) { |
| if (info.is_packed()) { |
| size_t payload_size = info.FieldByteSize(); |
| byte_size += |
| WireFormat::TagSize(info.number(), FieldDescriptor::TYPE_STRING) + |
| payload_size + |
| io::CodedOutputStream::VarintSize32( |
| static_cast<uint32_t>(payload_size)); |
| } else { |
| size_t size = static_cast<size_t>(info.size()); |
| ABSL_CHECK_GT(size, 0u); |
| byte_size += size * WireFormat::TagSize(info.number(), info.type()) + |
| info.FieldByteSize(); |
| } |
| } else { |
| if constexpr (info.cpp_type == FieldDescriptor::CPPTYPE_MESSAGE) { |
| if constexpr (info.is_extension) { |
| if (info.type() == FieldDescriptor::TYPE_MESSAGE && |
| info.is_message_set) { |
| byte_size += |
| WireFormatLite::kMessageSetItemTagsSize + |
| io::CodedOutputStream::VarintSize32( |
| static_cast<uint32_t>(info.number())) + |
| WireFormatLite::LengthDelimitedSize(info.FieldByteSize()); |
| return; |
| } |
| } |
| if (info.type() == FieldDescriptor::TYPE_GROUP) { |
| byte_size += WireFormat::TagSize(info.number(), info.type()) + |
| info.FieldByteSize(); |
| } else if (info.type() == FieldDescriptor::TYPE_MESSAGE) { |
| byte_size += |
| WireFormat::TagSize(info.number(), info.type()) + |
| WireFormatLite::LengthDelimitedSize(info.FieldByteSize()); |
| } |
| } else { |
| byte_size += WireFormat::TagSize(info.number(), info.type()) + |
| info.FieldByteSize(); |
| } |
| } |
| }); |
| return byte_size; |
| } |
| |
| TEST_P(VisitFieldsTest, ByteSizeByVisitFieldsMatchesCodegen) { |
| EXPECT_EQ(ByteSizeLongByVisit(*message_), message_->ByteSizeLong()); |
| } |
| |
| TestMessageSet* CreateTestMessageSet(Arena& arena) { |
| auto* msg = Arena::Create<TestMessageSet>(&arena); |
| auto* ext1 = msg->MutableExtension( |
| unittest::TestMessageSetExtension1::message_set_extension); |
| ext1->set_i(-1); |
| auto* ext3 = ext1->mutable_recursive()->MutableExtension( |
| unittest::TestMessageSetExtension3::message_set_extension); |
| ext3->set_required_int(-1); |
| ext3->mutable_msg()->set_b(0); |
| return msg; |
| } |
| |
| NestedTestAllTypes* CreateNestedTestAllTypes(Arena& arena) { |
| auto* msg = Arena::Create<NestedTestAllTypes>(&arena); |
| TestUtil::SetAllFields(msg->mutable_payload()); |
| TestUtil::SetAllFields(msg->mutable_lazy_child()->mutable_payload()); |
| return msg; |
| } |
| |
| INSTANTIATE_TEST_SUITE_P( |
| ReflectionVisitFieldsTest, VisitFieldsTest, |
| testing::Values( |
| TestParam{"TestAllTypes", |
| [](Arena& arena) -> Message* { |
| auto* msg = Arena::Create<TestAllTypes>(&arena); |
| TestUtil::SetAllFields(msg); |
| return msg; |
| }}, |
| TestParam{"TestAllExtensions", |
| [](Arena& arena) -> Message* { |
| auto* msg = Arena::Create<TestAllExtensions>(&arena); |
| TestUtil::SetAllExtensions(msg); |
| return msg; |
| }}, |
| TestParam{"TestAllExtensionsLazy", |
| [](Arena& arena) -> Message* { |
| TestAllExtensions original; |
| TestUtil::SetAllExtensions(&original); |
| auto* parsed = Arena::Create<TestAllExtensions>(&arena); |
| ABSL_CHECK( |
| parsed->ParseFromString(original.SerializeAsString())); |
| return parsed; |
| }}, |
| TestParam{"TestMap", |
| [](Arena& arena) -> Message* { |
| auto* msg = Arena::Create<TestMap>(&arena); |
| MapTestUtil::SetMapFields(msg); |
| return msg; |
| }}, |
| TestParam{"TestMessageSet", |
| [](Arena& arena) -> Message* { |
| return CreateTestMessageSet(arena); |
| }}, |
| TestParam{"TestMessageSetLazy", |
| [](Arena& arena) -> Message* { |
| auto* original = CreateTestMessageSet(arena); |
| auto* parsed = Arena::Create<TestMessageSet>(&arena); |
| ABSL_CHECK( |
| parsed->ParseFromString(original->SerializeAsString())); |
| return parsed; |
| }}, |
| TestParam{"TestOneof2LazyField", |
| [](Arena& arena) -> Message* { |
| auto* msg = Arena::Create<TestOneof2>(&arena); |
| TestUtil::SetOneof2(msg); |
| msg->mutable_foo_lazy_message()->set_moo_int(0); |
| return msg; |
| }}, |
| TestParam{"TestPacked", |
| [](Arena& arena) -> Message* { |
| auto* msg = Arena::Create<TestPackedTypes>(&arena); |
| TestUtil::SetPackedFields(msg); |
| return msg; |
| }}, |
| TestParam{"TestPackedExtensions", |
| [](Arena& arena) -> Message* { |
| auto* msg = Arena::Create<TestPackedExtensions>(&arena); |
| TestUtil::SetPackedExtensions(msg); |
| return msg; |
| }}, |
| TestParam{"NestedTestAllTypes", |
| [](Arena& arena) -> Message* { |
| return CreateNestedTestAllTypes(arena); |
| }}, |
| TestParam{"NestedTestAllTypesLazy", |
| [](Arena& arena) -> Message* { |
| auto* original = CreateNestedTestAllTypes(arena); |
| auto* parsed = Arena::Create<NestedTestAllTypes>(&arena); |
| ABSL_CHECK( |
| parsed->ParseFromString(original->SerializeAsString())); |
| return parsed; |
| }}), |
| [](const testing::TestParamInfo<TestParam>& info) { |
| return std::string(info.param.name); |
| }); |
| |
| template <typename T> |
| void MutateMapValue(TestMap& message, absl::string_view name, int index, |
| T&& callback) { |
| const Reflection* reflection = message.GetReflection(); |
| const Descriptor* descriptor = message.GetDescriptor(); |
| const FieldDescriptor* field = descriptor->FindFieldByName(name); |
| |
| auto* map_entry = reflection->MutableRepeatedMessage(&message, field, index); |
| const FieldDescriptor* val_field = map_entry->GetDescriptor()->map_value(); |
| ABSL_CHECK_NE(val_field, nullptr); |
| callback(map_entry->GetReflection(), map_entry, val_field); |
| } |
| |
| TEST(ReflectionVisitTest, VisitMapAfterMutableRepeated) { |
| TestMap message; |
| auto& map = *message.mutable_map_int32_int32(); |
| map[0] = 0; |
| map[1] = 0; |
| |
| // Reflectively overwrites values to 200 for all entries. This forces |
| // conversion to a mutable repeated field. |
| auto set_int32_val = [&](const Reflection* reflection, Message* msg, |
| const FieldDescriptor* field) { |
| reflection->SetInt32(msg, field, 200); |
| }; |
| MutateMapValue(message, "map_int32_int32", 0, set_int32_val); |
| MutateMapValue(message, "map_int32_int32", 1, set_int32_val); |
| |
| // Later visit fields must be map fields synced with the change. |
| std::vector<std::pair<int32_t, int32_t>> key_val_pairs; |
| VisitFields(message, [&](auto info) { |
| if constexpr (info.is_map) { |
| ASSERT_EQ(info.key_type(), FieldDescriptor::TYPE_INT32); |
| ASSERT_EQ(info.value_type(), FieldDescriptor::TYPE_INT32); |
| |
| info.VisitElements([&](auto key, auto val) { |
| if constexpr (key.cpp_type == FieldDescriptor::CPPTYPE_INT32 && |
| val.cpp_type == FieldDescriptor::CPPTYPE_INT32) { |
| key_val_pairs.emplace_back(key.Get(), val.Get()); |
| } |
| }); |
| } |
| }); |
| |
| EXPECT_THAT(key_val_pairs, testing::UnorderedElementsAre( |
| testing::Pair(0, 200), testing::Pair(1, 200))); |
| } |
| |
| #endif // __cpp_if_constexpr |
| |
| } // namespace |
| } // namespace internal |
| } // namespace protobuf |
| } // namespace google |