| // Protocol Buffers - Google's data interchange format |
| // Copyright 2008 Google Inc. All rights reserved. |
| // |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file or at |
| // https://developers.google.com/open-source/licenses/bsd |
| |
| #include "google/protobuf/map_field.h" |
| |
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <string>
#include <type_traits>
| |
| #include "absl/functional/overload.h" |
| #include "absl/log/absl_check.h" |
| #include "absl/synchronization/mutex.h" |
| #include "google/protobuf/arena.h" |
| #include "google/protobuf/descriptor.h" |
| #include "google/protobuf/map.h" |
| #include "google/protobuf/port.h" |
| #include "google/protobuf/raw_ptr.h" |
| #include "google/protobuf/repeated_ptr_field.h" |
| |
| // Must be included last. |
| #include "google/protobuf/port_def.inc" |
| |
| namespace google { |
| namespace protobuf { |
| namespace internal { |
| |
| MapFieldBase::~MapFieldBase() { delete maybe_payload(); } |
| |
| void MapFieldBase::MergeFrom(Arena* arena, const MapFieldBase& other) { |
| MutableMap()->UntypedMergeFrom(arena, other.GetMap()); |
| } |
| |
| void MapFieldBase::Swap(Arena* arena, MapFieldBase* other, Arena* other_arena) { |
| ABSL_DCHECK_EQ(arena, this->arena()); |
| ABSL_DCHECK_EQ(other_arena, other->arena()); |
| |
| if (arena == other_arena) { |
| InternalSwap(other); |
| return; |
| } |
| MapFieldBase::SwapPayload(*this, *other); |
| GetMapRaw().UntypedSwap(arena, other->GetMapRaw(), other_arena); |
| } |
| |
| const Message* MapFieldBase::GetPrototype() const { |
| const void* p = prototype_or_payload_.load(std::memory_order_acquire); |
| if (IsPayload(p)) { |
| return ToPayload(p)->prototype(); |
| } |
| return reinterpret_cast<const Message*>(p); |
| } |
| |
// Dispatches on the runtime key type of `map_key`.
//
// For each supported key type, `map` is viewed as the matching
// KeyMapBase<KeyBaseType> specialization (const-qualified when `Map` is
// const). The key is converted to that map's lookup view with
// TransparentSupport<KeyBaseType>::ToView, and the function returns
// `f(typed_map, key_view)`. Signed integer keys share the unsigned KeyMapBase
// of the same width (INT32 -> uint32_t, INT64 -> uint64_t). Any other key type
// is treated as unreachable.
template <typename Map, typename F>
auto VisitMapKey(const MapKey& map_key, Map& map, F f) {
  switch (map_key.type()) {
#define HANDLE_TYPE(CPPTYPE, Type, KeyBaseType)                               \
  case FieldDescriptor::CPPTYPE_##CPPTYPE: {                                  \
    using KMB = KeyMapBase<KeyBaseType>;                                      \
    return f(                                                                 \
        static_cast<                                                          \
            std::conditional_t<std::is_const_v<Map>, const KMB&, KMB&>>(map), \
        TransparentSupport<KeyBaseType>::ToView(map_key.Get##Type##Value())); \
  }
    HANDLE_TYPE(INT32, Int32, uint32_t);
    HANDLE_TYPE(UINT32, UInt32, uint32_t);
    HANDLE_TYPE(INT64, Int64, uint64_t);
    HANDLE_TYPE(UINT64, UInt64, uint64_t);
    HANDLE_TYPE(BOOL, Bool, bool);
    HANDLE_TYPE(STRING, String, std::string);
#undef HANDLE_TYPE
    default:
      // Map keys can only be integral, bool, or string types.
      Unreachable();
  }
}
| |
| bool MapFieldBase::InsertOrLookupMapValueNoSync(const MapKey& map_key, |
| MapValueRef* val) { |
| if (LookupMapValueNoSync(map_key, static_cast<MapValueConstRef*>(val))) { |
| return false; |
| } |
| |
| auto& map = GetMapRaw(); |
| Arena* arena = map.arena(); |
| |
| NodeBase* node = map.AllocNode(arena); |
| map.VisitValue(node, [&](auto* v) { InitializeKeyValue(v); }); |
| val->SetValue(map.GetVoidValue(node)); |
| |
| return VisitMapKey(map_key, map, [&](auto& map, const auto& key) { |
| InitializeKeyValue(map.GetKey(node), key); |
| map.InsertOrReplaceNode( |
| arena, |
| static_cast<typename std::decay_t<decltype(map)>::KeyNode*>(node)); |
| return true; |
| }); |
| } |
| |
| bool MapFieldBase::DeleteMapValue(Arena* arena, const MapKey& map_key) { |
| return VisitMapKey(map_key, *MutableMap(), |
| [arena](auto& map, const auto& key) { |
| return map.EraseImpl(arena, key); |
| }); |
| } |
| |
| void MapFieldBase::ClearMapNoSync() { |
| GetMapRaw().ClearTable(arena(), /*reset=*/true); |
| } |
| |
| template <bool kIsMutable> |
| void MapFieldBase::SetMapIteratorValue( |
| MapIteratorBase<kIsMutable>* map_iter) const { |
| if (map_iter->iter_.Equals(UntypedMapBase::EndIterator())) return; |
| |
| const UntypedMapBase& map = *map_iter->iter_.m_; |
| NodeBase* node = map_iter->iter_.node_; |
| auto& key = map_iter->key_; |
| map.VisitKey(node, |
| absl::Overload{ |
| [&](const std::string* v) { key.val_.string_value = *v; }, |
| [&](const auto* v) { |
| // Memcpy the scalar into the union. |
| memcpy(static_cast<void*>(&key.val_), v, sizeof(*v)); |
| }, |
| }); |
| map_iter->value_.SetValue(map.GetVoidValue(node)); |
| } |
| |
| bool MapFieldBase::LookupMapValueNoSync(const MapKey& map_key, |
| MapValueConstRef* val) const { |
| auto& map = GetMapRaw(); |
| if (map.empty()) return false; |
| |
| return VisitMapKey(map_key, map, [&](auto& map, const auto& key) { |
| auto res = map.FindHelper(key); |
| if (res.node == nullptr) { |
| return false; |
| } |
| if (val != nullptr) { |
| val->SetValue(map.GetVoidValue(res.node)); |
| } |
| return true; |
| }); |
| } |
| |
| void MapFieldBase::MapBegin(MapIterator* map_iter) const { |
| map_iter->iter_ = GetMap().begin(); |
| SetMapIteratorValue(map_iter); |
| } |
| |
| void MapFieldBase::MapEnd(MapIterator* map_iter) const { |
| map_iter->iter_ = UntypedMapBase::EndIterator(); |
| } |
| |
| void MapFieldBase::ConstMapBegin(ConstMapIterator* map_iter) const { |
| map_iter->iter_ = GetMap().begin(); |
| SetMapIteratorValue(map_iter); |
| } |
| |
| void MapFieldBase::ConstMapEnd(ConstMapIterator* map_iter) const { |
| map_iter->iter_ = UntypedMapBase::EndIterator(); |
| } |
| |
| template <bool kIsMutable> |
| bool MapFieldBase::EqualIterator(const MapIteratorBase<kIsMutable>& a, |
| const MapIteratorBase<kIsMutable>& b) const { |
| return a.iter_.Equals(b.iter_); |
| } |
| |
| template <bool kIsMutable> |
| void MapFieldBase::IncreaseIterator( |
| MapIteratorBase<kIsMutable>* map_iter) const { |
| map_iter->iter_.PlusPlus(); |
| SetMapIteratorValue(map_iter); |
| } |
| |
| template <bool kIsMutable> |
| void MapFieldBase::CopyIterator( |
| MapIteratorBase<kIsMutable>* this_iter, |
| const MapIteratorBase<kIsMutable>& that_iter) const { |
| this_iter->iter_ = that_iter.iter_; |
| this_iter->key_.SetType(that_iter.key_.type()); |
| // MapValueRef::type() fails when containing data is null. However, if |
| // this_iter points to MapEnd, data can be null. |
| this_iter->value_.SetType( |
| static_cast<FieldDescriptor::CppType>(that_iter.value_.type_)); |
| SetMapIteratorValue(this_iter); |
| } |
| |
| const RepeatedPtrFieldBase& MapFieldBase::GetRepeatedField() const { |
| ConstAccess(); |
| return SyncRepeatedFieldWithMap(false); |
| } |
| |
| RepeatedPtrFieldBase* MapFieldBase::MutableRepeatedField() { |
| MutableAccess(); |
| auto& res = SyncRepeatedFieldWithMap(true); |
| SetRepeatedDirty(); |
| return const_cast<RepeatedPtrFieldBase*>(&res); |
| } |
| |
// Exchanges the values held by two atomics using relaxed loads/stores.
//
// This is not an atomic swap. The two values are snapshotted and then written
// back crosswise, so callers must already exclude concurrent writers.
template <typename T>
static void SwapRelaxed(std::atomic<T>& a, std::atomic<T>& b) {
  const T old_b = b.load(std::memory_order_relaxed);
  const T old_a = a.load(std::memory_order_relaxed);
  b.store(old_a, std::memory_order_relaxed);
  a.store(old_b, std::memory_order_relaxed);
}
| |
// Returns the reflection payload, creating and publishing it on first use.
//
// `prototype_or_payload_` holds either the prototype Message pointer or a
// tagged pointer to the payload (distinguished by IsPayload). If no payload is
// installed yet, this:
//   1. installs the callback used to sync the map from the repeated field,
//   2. creates a payload wrapping the current prototype (on arena() if any),
//   3. publishes it with compare_exchange_strong.
// When several threads race, exactly one CAS succeeds. Losers drop their own
// payload and use the winner's.
MapFieldBase::ReflectionPayload& MapFieldBase::PayloadSlow() const {
  const void* p = prototype_or_payload_.load(std::memory_order_acquire);
  if (!IsPayload(p)) {
    // Inject the sync callback.
    // The callback pulls pending repeated-field changes into the map and, for
    // mutable access, marks the map as the side holding newer data.
    sync_map_with_repeated.store(
        [](auto& map, bool is_mutable) {
          const auto& self = static_cast<const MapFieldBase&>(map);
          self.SyncMapWithRepeatedField();
          if (is_mutable) const_cast<MapFieldBase&>(self).SetMapDirty();
        },
        std::memory_order_relaxed);

    // Not yet a payload, so `p` is the prototype.
    const Message* prototype = static_cast<const Message*>(p);
    auto* payload =
        Arena::Create<ReflectionPayload>(arena(), arena(), prototype);

    auto new_p = ToTaggedPtr(payload);
    // On failure, compare_exchange_strong stores the currently installed
    // value into `p`.
    if (prototype_or_payload_.compare_exchange_strong(
            p, new_p, std::memory_order_acq_rel)) {
      // We were able to store it.
      p = new_p;
    } else {
      // Someone beat us to it. Throw away the one we made. `p` already contains
      // the one we want.
      // Only heap allocations are deleted here; an arena-allocated loser is
      // presumably reclaimed with the arena.
      if (arena() == nullptr) delete payload;
    }
  }
  return *ToPayload(p);
}
| |
| void MapFieldBase::SwapPayload(MapFieldBase& lhs, MapFieldBase& rhs) { |
| if (lhs.arena() == rhs.arena()) { |
| SwapRelaxed(lhs.prototype_or_payload_, rhs.prototype_or_payload_); |
| return; |
| } |
| auto* p1 = lhs.maybe_payload(); |
| auto* p2 = rhs.maybe_payload(); |
| if (p1 == nullptr && p2 == nullptr) return; |
| |
| if (p1 == nullptr) p1 = &lhs.payload(); |
| if (p2 == nullptr) p2 = &rhs.payload(); |
| p1->Swap(*p2); |
| } |
| |
| void MapFieldBase::InternalSwap(MapFieldBase* other) { |
| GetMapRaw().InternalSwap(&other->GetMapRaw()); |
| SwapPayload(*this, *other); |
| } |
| |
| size_t MapFieldBase::SpaceUsedExcludingSelfLong() const { |
| ConstAccess(); |
| size_t size = 0; |
| if (auto* p = maybe_payload()) { |
| absl::MutexLock lock(&p->mutex()); |
| // Measure the map under the lock, because there could be some repeated |
| // field data that might be sync'd back into the map. |
| size = GetMapRaw().SpaceUsedExcludingSelfLong(); |
| size += p->repeated_field().SpaceUsedExcludingSelfLong(); |
| ConstAccess(); |
| } else { |
| // Only measure the map without the repeated field, because it is not there. |
| size = GetMapRaw().SpaceUsedExcludingSelfLong(); |
| ConstAccess(); |
| } |
| return size; |
| } |
| |
| bool MapFieldBase::IsMapValid() const { |
| ConstAccess(); |
| // "Acquire" insures the operation after SyncRepeatedFieldWithMap won't get |
| // executed before state_ is checked. |
| return state() != STATE_MODIFIED_REPEATED; |
| } |
| |
| bool MapFieldBase::IsRepeatedFieldValid() const { |
| ConstAccess(); |
| return state() != STATE_MODIFIED_MAP; |
| } |
| |
| void MapFieldBase::SetRepeatedDirty() { |
| MutableAccess(); |
| // These are called by (non-const) mutator functions. So by our API it's the |
| // callers responsibility to have these calls properly ordered. |
| payload().set_state_relaxed(STATE_MODIFIED_REPEATED); |
| } |
| |
// Returns the repeated-field view, first rebuilding it from the map if the map
// holds newer changes (state STATE_MODIFIED_MAP).
//
// `for_mutation`: only consulted when no payload exists yet. If false and the
// map is empty, the object RawPtr<const RepeatedPtrFieldBase>() refers to is
// returned without creating a payload (presumably a shared empty instance --
// confirm in raw_ptr.h).
//
// The rebuild runs under the payload mutex. The state is re-checked after
// locking, so only one thread performs it.
const RepeatedPtrFieldBase& MapFieldBase::SyncRepeatedFieldWithMap(
    bool for_mutation) const {
  ConstAccess();
  if (state() == STATE_MODIFIED_MAP) {
    auto* p = maybe_payload();
    if (p == nullptr) {
      // If we have no payload, and we do not want to mutate the object, and the
      // map is empty, then do nothing.
      // This prevents modifying global default instances which might be in ro
      // memory.
      if (!for_mutation && GetMapRaw().empty()) {
        return *RawPtr<const RepeatedPtrFieldBase>();
      }
      p = &payload();
    }

    {
      absl::MutexLock lock(&p->mutex());
      // Double check state, because another thread may have seen the same
      // state and done the synchronization before the current thread.
      if (p->load_state_relaxed() == STATE_MODIFIED_MAP) {
        const_cast<MapFieldBase*>(this)->SyncRepeatedFieldWithMapNoLock();
        // Release store: presumably pairs with the acquire load in state(), so
        // a reader observing CLEAN also observes the rebuilt repeated field.
        p->set_state_release(CLEAN);
      }
    }
    ConstAccess();
    return static_cast<const RepeatedPtrFieldBase&>(p->repeated_field());
  }
  // No map-side changes pending. Note that payload() may create the payload
  // if none exists yet (see PayloadSlow).
  return static_cast<const RepeatedPtrFieldBase&>(payload().repeated_field());
}
| |
// Rebuilds the repeated-field view from the current map contents.
//
// The caller must hold the payload mutex. The repeated field is cleared, then
// for each map element a new entry message is created from the prototype (on
// this field's arena). Its key and value fields are set through reflection.
void MapFieldBase::SyncRepeatedFieldWithMapNoLock() {
  const Message* prototype = GetPrototype();
  const Reflection* reflection = prototype->GetReflection();
  const Descriptor* descriptor = prototype->GetDescriptor();
  const FieldDescriptor* key_des = descriptor->map_key();
  const FieldDescriptor* val_des = descriptor->map_value();

  RepeatedPtrField<Message>& rep = payload().repeated_field();
  rep.Clear();

  // Iterate the raw map directly (GetMapRaw) rather than through GetMap().
  ConstMapIterator it(this, descriptor);
  ConstMapIterator end(this, descriptor);

  it.iter_ = GetMapRaw().begin();
  SetMapIteratorValue(&it);
  end.iter_ = UntypedMapBase::EndIterator();

  Arena* arena = this->arena();
  for (; !EqualIterator(it, end); IncreaseIterator(&it)) {
    // Append a fresh entry message created from the prototype.
    Message* new_entry = reinterpret_cast<Message*>(
        rep.AddInternal(arena, [prototype](Arena* arena, void*& ptr) {
          ptr = prototype->New(arena);
        }));

    // Copy the key. Map keys are integral, bool, or string only.
    const MapKey& map_key = it.GetKey();
    switch (key_des->cpp_type()) {
      case FieldDescriptor::CPPTYPE_STRING:
        reflection->SetString(new_entry, key_des,
                              std::string(map_key.GetStringValue()));
        break;
      case FieldDescriptor::CPPTYPE_INT64:
        reflection->SetInt64(new_entry, key_des, map_key.GetInt64Value());
        break;
      case FieldDescriptor::CPPTYPE_INT32:
        reflection->SetInt32(new_entry, key_des, map_key.GetInt32Value());
        break;
      case FieldDescriptor::CPPTYPE_UINT64:
        reflection->SetUInt64(new_entry, key_des, map_key.GetUInt64Value());
        break;
      case FieldDescriptor::CPPTYPE_UINT32:
        reflection->SetUInt32(new_entry, key_des, map_key.GetUInt32Value());
        break;
      case FieldDescriptor::CPPTYPE_BOOL:
        reflection->SetBool(new_entry, key_des, map_key.GetBoolValue());
        break;
      default:
        Unreachable();
    }

    // Copy the value. Every CppType is handled; message values are deep-copied.
    const MapValueConstRef& map_val = it.GetValueRef();
    switch (val_des->cpp_type()) {
      case FieldDescriptor::CPPTYPE_STRING:
        reflection->SetString(new_entry, val_des,
                              std::string(map_val.GetStringValue()));
        break;
      case FieldDescriptor::CPPTYPE_INT64:
        reflection->SetInt64(new_entry, val_des, map_val.GetInt64Value());
        break;
      case FieldDescriptor::CPPTYPE_INT32:
        reflection->SetInt32(new_entry, val_des, map_val.GetInt32Value());
        break;
      case FieldDescriptor::CPPTYPE_UINT64:
        reflection->SetUInt64(new_entry, val_des, map_val.GetUInt64Value());
        break;
      case FieldDescriptor::CPPTYPE_UINT32:
        reflection->SetUInt32(new_entry, val_des, map_val.GetUInt32Value());
        break;
      case FieldDescriptor::CPPTYPE_BOOL:
        reflection->SetBool(new_entry, val_des, map_val.GetBoolValue());
        break;
      case FieldDescriptor::CPPTYPE_DOUBLE:
        reflection->SetDouble(new_entry, val_des, map_val.GetDoubleValue());
        break;
      case FieldDescriptor::CPPTYPE_FLOAT:
        reflection->SetFloat(new_entry, val_des, map_val.GetFloatValue());
        break;
      case FieldDescriptor::CPPTYPE_ENUM:
        reflection->SetEnumValue(new_entry, val_des, map_val.GetEnumValue());
        break;
      case FieldDescriptor::CPPTYPE_MESSAGE: {
        const Message& message = map_val.GetMessageValue();
        reflection->MutableMessage(new_entry, val_des)->CopyFrom(message);
        break;
      }
    }
  }
}
| |
| void MapFieldBase::SyncMapWithRepeatedField() const { |
| ConstAccess(); |
| // acquire here matches with release below to ensure that we can only see a |
| // value of CLEAN after all previous changes have been synced. |
| if (state() == STATE_MODIFIED_REPEATED) { |
| auto& p = payload(); |
| { |
| absl::MutexLock lock(&p.mutex()); |
| // Double check state, because another thread may have seen the same state |
| // and done the synchronization before the current thread. |
| if (p.load_state_relaxed() == STATE_MODIFIED_REPEATED) { |
| const_cast<MapFieldBase*>(this)->SyncMapWithRepeatedFieldNoLock(); |
| p.set_state_release(CLEAN); |
| } |
| } |
| ConstAccess(); |
| } |
| } |
| |
// Rebuilds the map from the repeated-field view.
//
// The caller must hold the payload mutex. The map is cleared; then each entry
// message's key and value are read via reflection and inserted. Because
// InsertOrLookupMapValueNoSync returns an existing entry for a repeated key and
// the value is then overwritten, later entries with the same key win.
void MapFieldBase::SyncMapWithRepeatedFieldNoLock() {
  ClearMapNoSync();

  RepeatedPtrField<Message>& rep = payload().repeated_field();

  if (rep.empty()) return;

  // All entries share one message type, so the first element supplies the
  // reflection and descriptors for the whole loop.
  const Message* prototype = &rep[0];
  const Reflection* reflection = prototype->GetReflection();
  const Descriptor* descriptor = prototype->GetDescriptor();
  const FieldDescriptor* key_des = descriptor->map_key();
  const FieldDescriptor* val_des = descriptor->map_value();

  for (const Message& elem : rep) {
    // MapKey type will be set later.
    // The scratch space lives for this iteration only; it backs the string
    // view passed to SetStringValue (presumably copied by MapKey -- confirm).
    Reflection::ScratchSpace map_key_scratch_space;
    MapKey map_key;
    switch (key_des->cpp_type()) {
      case FieldDescriptor::CPPTYPE_STRING:
        map_key.SetStringValue(
            reflection->GetStringView(elem, key_des, map_key_scratch_space));
        break;
      case FieldDescriptor::CPPTYPE_INT64:
        map_key.SetInt64Value(reflection->GetInt64(elem, key_des));
        break;
      case FieldDescriptor::CPPTYPE_INT32:
        map_key.SetInt32Value(reflection->GetInt32(elem, key_des));
        break;
      case FieldDescriptor::CPPTYPE_UINT64:
        map_key.SetUInt64Value(reflection->GetUInt64(elem, key_des));
        break;
      case FieldDescriptor::CPPTYPE_UINT32:
        map_key.SetUInt32Value(reflection->GetUInt32(elem, key_des));
        break;
      case FieldDescriptor::CPPTYPE_BOOL:
        map_key.SetBoolValue(reflection->GetBool(elem, key_des));
        break;
      default:
        Unreachable();
    }

    // Find or create the map slot, then overwrite its value from the entry.
    MapValueRef map_val;
    map_val.SetType(val_des->cpp_type());
    InsertOrLookupMapValueNoSync(map_key, &map_val);

    switch (val_des->cpp_type()) {
#define HANDLE_TYPE(CPPTYPE, METHOD)                                    \
  case FieldDescriptor::CPPTYPE_##CPPTYPE:                              \
    map_val.Set##METHOD##Value(reflection->Get##METHOD(elem, val_des)); \
    break;
      HANDLE_TYPE(INT32, Int32);
      HANDLE_TYPE(INT64, Int64);
      HANDLE_TYPE(UINT32, UInt32);
      HANDLE_TYPE(UINT64, UInt64);
      HANDLE_TYPE(DOUBLE, Double);
      HANDLE_TYPE(FLOAT, Float);
      HANDLE_TYPE(BOOL, Bool);
      HANDLE_TYPE(STRING, String);
#undef HANDLE_TYPE
      case FieldDescriptor::CPPTYPE_ENUM:
        map_val.SetEnumValue(reflection->GetEnumValue(elem, val_des));
        break;
      case FieldDescriptor::CPPTYPE_MESSAGE: {
        map_val.MutableMessageValue()->CopyFrom(
            reflection->GetMessage(elem, val_des));
        break;
      }
    }
  }
}
| |
| void MapFieldBase::Clear() { |
| if (ReflectionPayload* p = maybe_payload()) { |
| p->repeated_field().Clear(); |
| } |
| |
| ClearMapNoSync(); |
| // Data in map and repeated field are both empty, but we can't set status |
| // CLEAN. Because clear is a generated API, we cannot invalidate previous |
| // reference to map. |
| SetMapDirty(); |
| } |
| |
| int MapFieldBase::size() const { return GetMap().size(); } |
| |
| bool MapFieldBase::InsertOrLookupMapValue(const MapKey& map_key, |
| MapValueRef* val) { |
| SyncMapWithRepeatedField(); |
| SetMapDirty(); |
| return InsertOrLookupMapValueNoSync(map_key, val); |
| } |
| |
| void MapFieldBase::ReflectionPayload::Swap(ReflectionPayload& other) { |
| repeated_field().Swap(&other.repeated_field()); |
| SwapRelaxed(state_, other.state_); |
| } |
| |
| } // namespace internal |
| |
| template <bool kIsMutable> |
| MapIteratorBase<kIsMutable>::MapIteratorBase(MessageT* message, |
| const FieldDescriptor* field) { |
| const Reflection* reflection = message->GetReflection(); |
| if constexpr (kIsMutable) { |
| map_ = reflection->MutableMapData(message, field); |
| } else { |
| map_ = reflection->GetMapData(*message, field); |
| } |
| key_.SetType(field->message_type()->map_key()->cpp_type()); |
| value_.SetType(field->message_type()->map_value()->cpp_type()); |
| } |
| |
| template <bool kIsMutable> |
| MapIteratorBase<kIsMutable>& MapIteratorBase<kIsMutable>::operator=( |
| const MapIteratorBase& other) { |
| map_ = other.map_; |
| map_->CopyIterator(this, other); |
| return *this; |
| } |
| |
| template <bool kIsMutable> |
| bool MapIteratorBase<kIsMutable>::operator==( |
| const MapIteratorBase<kIsMutable>& other) const { |
| return map_->EqualIterator(*this, other); |
| } |
| |
| template <bool kIsMutable> |
| typename MapIteratorBase<kIsMutable>::DerivedIterator& |
| MapIteratorBase<kIsMutable>::operator++() { |
| map_->IncreaseIterator(this); |
| return static_cast<DerivedIterator&>(*this); |
| } |
| |
| template <bool kIsMutable> |
| typename MapIteratorBase<kIsMutable>::DerivedIterator |
| MapIteratorBase<kIsMutable>::operator++(int) { |
| // iter_ is copied from Map<...>::iterator, no need to |
| // copy from its self again. Use the same implementation |
| // with operator++() |
| map_->IncreaseIterator(this); |
| return *static_cast<DerivedIterator*>(this); |
| } |
| |
| template <bool kIsMutable> |
| MapIteratorBase<kIsMutable>::MapIteratorBase(MapFieldBase* map, |
| const Descriptor* descriptor) { |
| map_ = map; |
| key_.SetType(descriptor->map_key()->cpp_type()); |
| value_.SetType(descriptor->map_value()->cpp_type()); |
| } |
| |
// Explicitly instantiate both iterator flavors (const and mutable) in this
// translation unit, where the member definitions above live.
template class MapIteratorBase</*kIsMutable=*/false>;
template class MapIteratorBase</*kIsMutable=*/true>;
| |
| } // namespace protobuf |
| } // namespace google |
| |
| #include "google/protobuf/port_undef.inc" |