| # Protocol Buffers - Google's data interchange format |
| # Copyright 2008 Google Inc. All rights reserved. |
| # |
| # Use of this source code is governed by a BSD-style |
| # license that can be found in the LICENSE file or at |
| # https://developers.google.com/open-source/licenses/bsd |
| |
| # This code is meant to work on Python 2.4 and above only. |
| # |
| # TODO: Helpers for verbose, common checks like seeing if a |
| # descriptor's cpp_type is CPPTYPE_MESSAGE. |
| |
| """Contains a metaclass and helper functions used to create |
| protocol message classes from Descriptor objects at runtime. |
| |
| Recall that a metaclass is the "type" of a class. |
| (A class is to a metaclass what an instance is to a class.) |
| |
| In this case, we use the GeneratedProtocolMessageType metaclass |
| to inject all the useful functionality into the classes |
| output by the protocol compiler at compile-time. |
| |
| The upshot of all this is that the real implementation |
| details for ALL pure-Python protocol buffers are *here in |
| this file*. |
| """ |
| |
| __author__ = 'robinson@google.com (Will Robinson)' |
| |
| import datetime |
| from io import BytesIO |
| import struct |
| import sys |
| import warnings |
| import weakref |
| |
| from google.protobuf import descriptor as descriptor_mod |
| from google.protobuf import message as message_mod |
| from google.protobuf import text_format |
| # We use "as" to avoid name collisions with variables. |
| from google.protobuf.internal import api_implementation |
| from google.protobuf.internal import containers |
| from google.protobuf.internal import decoder |
| from google.protobuf.internal import encoder |
| from google.protobuf.internal import enum_type_wrapper |
| from google.protobuf.internal import extension_dict |
| from google.protobuf.internal import message_listener as message_listener_mod |
| from google.protobuf.internal import type_checkers |
| from google.protobuf.internal import well_known_types |
| from google.protobuf.internal import wire_format |
| |
| _FieldDescriptor = descriptor_mod.FieldDescriptor |
| _AnyFullTypeName = 'google.protobuf.Any' |
| _StructFullTypeName = 'google.protobuf.Struct' |
| _ListValueFullTypeName = 'google.protobuf.ListValue' |
| _ExtensionDict = extension_dict._ExtensionDict |
| |
class GeneratedProtocolMessageType(type):
  """Metaclass that builds protocol message classes from Descriptors at runtime.

  Every method described by the Message class gets an implementation here, and
  every field of the message gets a property for reading/writing it. Slots are
  installed as well, so that "setting" a nonexistent field fails instead of
  silently creating an attribute that would never be serialized or
  deserialized.

  The protocol compiler uses this metaclass for the classes it emits. Clients
  may also create message classes by hand at runtime, for example:

    mydescriptor = Descriptor(.....)
    factory = symbol_database.Default()
    factory.pool.AddDescriptor(mydescriptor)
    MyProtoClass = factory.GetPrototype(mydescriptor)
    myproto_instance = MyProtoClass()
    myproto_instance.foo_field = 23
    ...
  """

  # Must be consistent with the protocol-compiler code in
  # proto2/compiler/internal/generator.*.
  _DESCRIPTOR_KEY = 'DESCRIPTOR'

  def __new__(cls, name, bases, dictionary):
    """Allocates the class object for a runtime-generated message type.

    __new__ is overridden because the class dictionary has to receive its
    '__slots__' entry (and nested-extension attributes) before the class
    object is created.

    Args:
      name: Name of the class being created (required by the metaclass
        protocol).
      bases: Base classes of the class being created (should be
        message.Message). A well-known-type mixin is appended when the
        descriptor's full name is listed in well_known_types.WKTBASES.
      dictionary: Class dictionary of the class being created.
        dictionary[_DESCRIPTOR_KEY] must hold the Descriptor describing this
        protocol message type.

    Returns:
      The newly allocated class, or the class already registered on the
      descriptor as `_concrete_class` if there is one.

    Raises:
      RuntimeError: The descriptor entry is a str, i.e. the generated code
        only works with the python cpp extension.
    """
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]

    if isinstance(descriptor, str):
      raise RuntimeError('The generated code only work with python cpp '
                         'extension, but it is using pure python runtime.')

    # Reuse an already-built concrete class rather than creating a second one;
    # a duplicate would break messages that already exist with the first
    # class. This most commonly happens in `text_format.py` with descriptors
    # from a custom pool, where a prototype is requested for a descriptor that
    # already has a concrete class.
    already_built = getattr(descriptor, '_concrete_class', None)
    if already_built:
      return already_built

    full_name = descriptor.full_name
    if full_name in well_known_types.WKTBASES:
      bases = bases + (well_known_types.WKTBASES[full_name],)
    _AddClassAttributesForNestedExtensions(descriptor, dictionary)
    _AddSlots(descriptor, dictionary)

    return super().__new__(cls, name, bases, dictionary)

  def __init__(cls, name, bases, dictionary):
    """Populates the freshly created class with all message behavior.

    Adds enum values, an __init__ method, implementations of the Message
    methods, and properties for all fields of the protocol type.

    Args:
      name: Name of the class (required by the metaclass protocol).
      bases: Base classes of the class being constructed (should be
        message.Message).
      dictionary: Class dictionary of the class being constructed.
        dictionary[_DESCRIPTOR_KEY] must hold the Descriptor describing this
        protocol message type.
    """
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]

    # A class that __new__ fetched from `_concrete_class` has been initialized
    # before; there is nothing left to do for it.
    registered = getattr(descriptor, '_concrete_class', None)
    if registered:
      assert registered is cls, (
          'Duplicate `GeneratedProtocolMessageType` created for descriptor %r'
          % (descriptor.full_name))
      return

    cls._message_set_decoders_by_tag = {}
    cls._fields_by_tag = {}
    if (descriptor.has_options and
        descriptor.GetOptions().message_set_wire_format):
      item_decoder = decoder.MessageSetItemDecoder(descriptor)
      cls._message_set_decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
          item_decoder,
          None,
      )

    # Attach stuff to each FieldDescriptor for quick lookup later on.
    for field in descriptor.fields:
      _AttachFieldHelpers(cls, field)

    if descriptor.is_extendable and hasattr(descriptor.file, 'pool'):
      for extension in descriptor.file.pool.FindAllExtensions(descriptor):
        _AttachFieldHelpers(cls, extension)

    # Register the class on the descriptor before the helpers below run.
    descriptor._concrete_class = cls  # pylint: disable=protected-access
    _AddEnumValues(descriptor, cls)
    _AddInitMethod(descriptor, cls)
    _AddPropertiesForFields(descriptor, cls)
    _AddPropertiesForExtensions(descriptor, cls)
    _AddStaticMethods(cls)
    _AddMessageMethods(descriptor, cls)
    _AddPrivateHelperMethods(descriptor, cls)

    super().__init__(name, bases, dictionary)
| |
| |
| # Stateless helpers for GeneratedProtocolMessageType below. |
| # Outside clients should not access these directly. |
| # |
| # I opted not to make any of these methods on the metaclass, to make it more |
| # clear that I'm not really using any state there and to keep clients from |
| # thinking that they have direct access to these construction helpers. |
| |
| |
| def _PropertyName(proto_field_name): |
| """Returns the name of the public property attribute which |
| clients can use to get and (in some cases) set the value |
| of a protocol message field. |
| |
| Args: |
| proto_field_name: The protocol message field name, exactly |
| as it appears (or would appear) in a .proto file. |
| """ |
| # TODO: Escape Python keywords (e.g., yield), and test this support. |
| # nnorwitz makes my day by writing: |
| # """ |
| # FYI. See the keyword module in the stdlib. This could be as simple as: |
| # |
| # if keyword.iskeyword(proto_field_name): |
| # return proto_field_name + "_" |
| # return proto_field_name |
| # """ |
| # Kenton says: The above is a BAD IDEA. People rely on being able to use |
| # getattr() and setattr() to reflectively manipulate field values. If we |
| # rename the properties, then every such user has to also make sure to apply |
| # the same transformation. Note that currently if you name a field "yield", |
| # you can still access it just fine using getattr/setattr -- it's not even |
| # that cumbersome to do so. |
| # TODO: Remove this method entirely if/when everyone agrees with my |
| # position. |
| return proto_field_name |
| |
| |
| def _AddSlots(message_descriptor, dictionary): |
| """Adds a __slots__ entry to dictionary, containing the names of all valid |
| attributes for this message type. |
| |
| Args: |
| message_descriptor: A Descriptor instance describing this message type. |
| dictionary: Class dictionary to which we'll add a '__slots__' entry. |
| """ |
| dictionary['__slots__'] = ['_cached_byte_size', |
| '_cached_byte_size_dirty', |
| '_fields', |
| '_unknown_fields', |
| '_is_present_in_parent', |
| '_listener', |
| '_listener_for_children', |
| '__weakref__', |
| '_oneofs'] |
| |
| |
def _IsMessageSetExtension(field):
  """Tells whether field is an optional message extension of a MessageSet.

  Args:
    field: A FieldDescriptor.

  Returns:
    A truthy value only if field is an extension, its containing type has
    options with message_set_wire_format set, and the field itself is an
    optional field of message type.
  """
  extended_type = field.containing_type
  return (
      field.is_extension
      and extended_type.has_options
      # GetOptions() is consulted only when the type actually has options.
      and extended_type.GetOptions().message_set_wire_format
      and field.type == _FieldDescriptor.TYPE_MESSAGE
      and field.label == _FieldDescriptor.LABEL_OPTIONAL
  )
| |
| |
def _IsMapField(field):
  """Tells whether field is a map field.

  Args:
    field: A FieldDescriptor.

  Returns:
    False for a field that is not of message type; otherwise the
    `_is_map_entry` flag of the field's message type.
  """
  if field.type != _FieldDescriptor.TYPE_MESSAGE:
    return False
  return field.message_type._is_map_entry
| |
| |
def _IsMessageMapField(field):
  """Tells whether the values of a map field are messages.

  Args:
    field: A FieldDescriptor whose message type has a field named 'value'.

  Returns:
    True if that 'value' field has cpp_type CPPTYPE_MESSAGE.
  """
  entry_fields = field.message_type.fields_by_name
  return entry_fields['value'].cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE
| |
def _AttachFieldHelpers(cls, field_descriptor):
  """Prepares a field for use by cls: default constructor and tag lookup.

  Stores a default-value constructor on the field descriptor and registers
  the field in cls._fields_by_tag under its encoded tag bytes.

  Args:
    cls: The message class under construction; its _fields_by_tag dict is
      updated.
    field_descriptor: FieldDescriptor of the field (or extension) to attach.
  """
  field_descriptor._default_constructor = _DefaultValueConstructorForField(
      field_descriptor
  )

  def Register(wire_type, packed):
    # Maps the encoded tag to (field, whether that encoding is packed).
    key = encoder.TagBytes(field_descriptor.number, wire_type)
    cls._fields_by_tag[key] = (field_descriptor, packed)

  natural_wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[
      field_descriptor.type
  ]
  Register(natural_wire_type, False)

  repeated = field_descriptor.label == _FieldDescriptor.LABEL_REPEATED
  if repeated and wire_format.IsTypePackable(field_descriptor.type):
    # To support wire compatibility of adding packed = true, also register the
    # packed encoding regardless of the field's options.
    Register(wire_format.WIRETYPE_LENGTH_DELIMITED, True)
| |
| |
| def _MaybeAddEncoder(cls, field_descriptor): |
| if hasattr(field_descriptor, '_encoder'): |
| return |
| is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED) |
| is_map_entry = _IsMapField(field_descriptor) |
| is_packed = field_descriptor.is_packed |
| |
| if is_map_entry: |
| field_encoder = encoder.MapEncoder(field_descriptor) |
| sizer = encoder.MapSizer(field_descriptor, |
| _IsMessageMapField(field_descriptor)) |
| elif _IsMessageSetExtension(field_descriptor): |
| field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number) |
| sizer = encoder.MessageSetItemSizer(field_descriptor.number) |
| else: |
| field_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type]( |
| field_descriptor.number, is_repeated, is_packed) |
| sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type]( |
| field_descriptor.number, is_repeated, is_packed) |
| |
| field_descriptor._sizer = sizer |
| field_descriptor._encoder = field_encoder |
| |
| |
| def _MaybeAddDecoder(cls, field_descriptor): |
| if hasattr(field_descriptor, '_decoders'): |
| return |
| |
| is_repeated = field_descriptor.label == _FieldDescriptor.LABEL_REPEATED |
| is_map_entry = _IsMapField(field_descriptor) |
| helper_decoders = {} |
| |
| def AddDecoder(is_packed): |
| decode_type = field_descriptor.type |
| if (decode_type == _FieldDescriptor.TYPE_ENUM and |
| not field_descriptor.enum_type.is_closed): |
| decode_type = _FieldDescriptor.TYPE_INT32 |
| |
| oneof_descriptor = None |
| if field_descriptor.containing_oneof is not None: |
| oneof_descriptor = field_descriptor |
| |
| if is_map_entry: |
| is_message_map = _IsMessageMapField(field_descriptor) |
| |
| field_decoder = decoder.MapDecoder( |
| field_descriptor, _GetInitializeDefaultForMap(field_descriptor), |
| is_message_map) |
| elif decode_type == _FieldDescriptor.TYPE_STRING: |
| field_decoder = decoder.StringDecoder( |
| field_descriptor.number, is_repeated, is_packed, |
| field_descriptor, field_descriptor._default_constructor, |
| not field_descriptor.has_presence) |
| elif field_descriptor.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: |
| field_decoder = type_checkers.TYPE_TO_DECODER[decode_type]( |
| field_descriptor.number, is_repeated, is_packed, |
| field_descriptor, field_descriptor._default_constructor) |
| else: |
| field_decoder = type_checkers.TYPE_TO_DECODER[decode_type]( |
| field_descriptor.number, is_repeated, is_packed, |
| # pylint: disable=protected-access |
| field_descriptor, field_descriptor._default_constructor, |
| not field_descriptor.has_presence) |
| |
| helper_decoders[is_packed] = field_decoder |
| |
| AddDecoder(False) |
| |
| if is_repeated and wire_format.IsTypePackable(field_descriptor.type): |
| # To support wire compatibility of adding packed = true, add a decoder for |
| # packed values regardless of the field's options. |
| AddDecoder(True) |
| |
| field_descriptor._decoders = helper_decoders |
| |
| |
| def _AddClassAttributesForNestedExtensions(descriptor, dictionary): |
| extensions = descriptor.extensions_by_name |
| for extension_name, extension_field in extensions.items(): |
| assert extension_name not in dictionary |
| dictionary[extension_name] = extension_field |
| |
| |
| def _AddEnumValues(descriptor, cls): |
| """Sets class-level attributes for all enum fields defined in this message. |
| |
| Also exporting a class-level object that can name enum values. |
| |
| Args: |
| descriptor: Descriptor object for this message type. |
| cls: Class we're constructing for this message type. |
| """ |
| for enum_type in descriptor.enum_types: |
| setattr(cls, enum_type.name, enum_type_wrapper.EnumTypeWrapper(enum_type)) |
| for enum_value in enum_type.values: |
| setattr(cls, enum_value.name, enum_value.number) |
| |
| |
def _GetInitializeDefaultForMap(field):
  """Returns a factory that creates the empty container for a map field.

  Args:
    field: FieldDescriptor of the map field. Its message type must have 'key'
      and 'value' fields.

  Returns:
    A one-argument function taking the parent message and returning a
    containers.MessageMap (message values) or containers.ScalarMap (all other
    values) wired to the parent's `_listener_for_children`.

  Raises:
    ValueError: field is not a repeated field.
  """
  if field.label != _FieldDescriptor.LABEL_REPEATED:
    raise ValueError('map_entry set on non-repeated field %s' % (
        field.name))

  entry_fields = field.message_type.fields_by_name
  key_checker = type_checkers.GetTypeChecker(entry_fields['key'])
  value_field = entry_fields['value']

  if not _IsMessageMapField(field):
    value_checker = type_checkers.GetTypeChecker(value_field)

    def MakePrimitiveMapDefault(message):
      return containers.ScalarMap(
          message._listener_for_children, key_checker, value_checker,
          field.message_type)

    return MakePrimitiveMapDefault

  def MakeMessageMapDefault(message):
    return containers.MessageMap(
        message._listener_for_children, value_field.message_type, key_checker,
        field.message_type)

  return MakeMessageMapDefault
| |
def _DefaultValueConstructorForField(field):
  """Builds the function that produces the default value of a field.

  Args:
    field: FieldDescriptor object for this field.

  Returns:
    A function of one argument, `message` (the Message instance containing
    this field, or a weakref proxy of same), that returns a default value for
    the field. The default value may refer back to |message| via a weak
    reference.

  Raises:
    ValueError: A repeated field declares a default value other than [].
  """
  if _IsMapField(field):
    return _GetInitializeDefaultForMap(field)

  holds_messages = field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    if field.has_default_value and field.default_value != []:
      raise ValueError('Repeated field default value not empty list: %s' % (
          field.default_value))

    if holds_messages:
      # Only the message type is captured here; _concrete_class cannot be
      # looked at yet since it might not have been set (this depends on the
      # order in which the classes are initialized).
      message_type = field.message_type

      def MakeRepeatedMessageDefault(message):
        return containers.RepeatedCompositeFieldContainer(
            message._listener_for_children, message_type)

      return MakeRepeatedMessageDefault

    type_checker = type_checkers.GetTypeChecker(field)

    def MakeRepeatedScalarDefault(message):
      return containers.RepeatedScalarFieldContainer(
          message._listener_for_children, type_checker)

    return MakeRepeatedScalarDefault

  if holds_messages:
    message_type = field.message_type

    def MakeSubMessageDefault(message):
      # _concrete_class may not yet be initialized; asking the message factory
      # for the class creates it.
      if not hasattr(message_type, '_concrete_class'):
        from google.protobuf import message_factory
        message_factory.GetMessageClass(message_type)
      result = message_type._concrete_class()
      if field.containing_oneof is not None:
        listener = _OneofListener(message, field)
      else:
        listener = message._listener_for_children
      result._SetListener(listener)
      return result

    return MakeSubMessageDefault

  def MakeScalarDefault(message):
    # TODO: This may be broken since there may not be
    # default_value.  Combine with has_default_value somehow.
    return field.default_value

  return MakeScalarDefault
| |
| |
| def _ReraiseTypeErrorWithFieldName(message_name, field_name): |
| """Re-raise the currently-handled TypeError with the field name added.""" |
| exc = sys.exc_info()[1] |
| if len(exc.args) == 1 and type(exc) is TypeError: |
| # simple TypeError; add field name to exception message |
| exc = TypeError('%s for field %s.%s' % (str(exc), message_name, field_name)) |
| |
| # re-raise possibly-amended exception with original traceback: |
| raise exc.with_traceback(sys.exc_info()[2]) |
| |
| |
def _AddInitMethod(message_descriptor, cls):
  """Adds an __init__ method to cls.

  The generated __init__ resets the per-instance bookkeeping attributes and
  then applies every keyword argument as an initial field value.

  Args:
    message_descriptor: Descriptor of the message type being built.
    cls: The message class under construction.
  """

  def _GetIntegerEnumValue(enum_type, value):
    """Convert a string or integer enum value to an integer.

    If the value is a string, it is converted to the enum value in
    enum_type with the same name.  If the value is not a string, it's
    returned as-is.  (No conversion or bounds-checking is done.)

    Raises:
      ValueError: value is a string that names no value of enum_type.
    """
    if isinstance(value, str):
      try:
        return enum_type.values_by_name[value].number
      except KeyError:
        raise ValueError('Enum type %s: unknown label "%s"' % (
            enum_type.full_name, value))
    return value

  def init(self, **kwargs):
    self._cached_byte_size = 0
    self._cached_byte_size_dirty = len(kwargs) > 0
    self._fields = {}
    # Contains a mapping from oneof field descriptors to the descriptor
    # of the currently set field in that oneof field.
    self._oneofs = {}

    # _unknown_fields is () when empty for efficiency, and will be turned into
    # a list if fields are added.
    self._unknown_fields = ()
    self._is_present_in_parent = False
    self._listener = message_listener_mod.NullMessageListener()
    self._listener_for_children = _Listener(self)
    for field_name, field_value in kwargs.items():
      field = _GetFieldByName(message_descriptor, field_name)
      # Defensive check only: _GetFieldByName as defined in this module raises
      # ValueError for an unknown name rather than returning None.
      if field is None:
        raise TypeError('%s() got an unexpected keyword argument "%s"' %
                        (message_descriptor.name, field_name))
      if field_value is None:
        # field=None is the same as no field at all.
        continue
      if field.label == _FieldDescriptor.LABEL_REPEATED:
        field_copy = field._default_constructor(self)
        if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:  # Composite
          if _IsMapField(field):
            if _IsMessageMapField(field):
              for key in field_value:
                field_copy[key].MergeFrom(field_value[key])
            else:
              field_copy.update(field_value)
          else:
            for val in field_value:
              if isinstance(val, dict):
                field_copy.add(**val)
              else:
                field_copy.add().MergeFrom(val)
        else:  # Scalar
          if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
            field_value = [_GetIntegerEnumValue(field.enum_type, val)
                           for val in field_value]
          field_copy.extend(field_value)
        self._fields[field] = field_copy
      elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        field_copy = field._default_constructor(self)
        # Message to merge into field_copy, if the value is (or is converted
        # to) a message; stays None when field_copy was filled in directly.
        new_val = None
        if isinstance(field_value, message_mod.Message):
          new_val = field_value
        elif isinstance(field_value, dict):
          if field.message_type.full_name == _StructFullTypeName:
            field_copy.Clear()
            if len(field_value) == 1 and 'fields' in field_value:
              # A dict whose only key is 'fields' is ambiguous: it is tried as
              # Struct content first and as constructor keywords second.
              try:
                field_copy.update(field_value)
              except Exception:
                # Fall back to init normal message field
                field_copy.Clear()
                new_val = field.message_type._concrete_class(**field_value)
            else:
              field_copy.update(field_value)
          else:
            new_val = field.message_type._concrete_class(**field_value)
        elif hasattr(field_copy, '_internal_assign'):
          field_copy._internal_assign(field_value)
        else:
          raise TypeError(
              'Message field {0}.{1} must be initialized with a '
              'dict or instance of same class, got {2}.'.format(
                  message_descriptor.name,
                  field_name,
                  type(field_value).__name__,
              )
          )

        # Compare against None explicitly: a message object may evaluate as
        # falsy (e.g. if its class defines __len__ or __bool__), and such a
        # value was still supplied by the caller and must be merged.
        if new_val is not None:
          try:
            field_copy.MergeFrom(new_val)
          except TypeError:
            _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
        self._fields[field] = field_copy
      else:
        if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
          field_value = _GetIntegerEnumValue(field.enum_type, field_value)
        try:
          setattr(self, field_name, field_value)
        except TypeError:
          _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)

  init.__module__ = None
  init.__doc__ = None
  cls.__init__ = init
| |
| |
| def _GetFieldByName(message_descriptor, field_name): |
| """Returns a field descriptor by field name. |
| |
| Args: |
| message_descriptor: A Descriptor describing all fields in message. |
| field_name: The name of the field to retrieve. |
| Returns: |
| The field descriptor associated with the field name. |
| """ |
| try: |
| return message_descriptor.fields_by_name[field_name] |
| except KeyError: |
| raise ValueError('Protocol message %s has no "%s" field.' % |
| (message_descriptor.name, field_name)) |
| |
| |
| def _AddPropertiesForFields(descriptor, cls): |
| """Adds properties for all fields in this protocol message type.""" |
| for field in descriptor.fields: |
| _AddPropertiesForField(field, cls) |
| |
| if descriptor.is_extendable: |
| # _ExtensionDict is just an adaptor with no state so we allocate a new one |
| # every time it is accessed. |
| cls.Extensions = property(lambda self: _ExtensionDict(self)) |
| |
| |
def _AddPropertiesForField(field, cls):
  """Adds a public property for a protocol message field.

  Clients can use this property to get and (in the case of non-repeated
  scalar fields) directly set the value of a protocol message field. A
  `<FIELD_NAME>_FIELD_NUMBER` class constant holding the field number is
  added as well.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # Catch it if we add other types that we should
  # handle specially here.
  assert _FieldDescriptor.MAX_CPPTYPE == 10

  setattr(cls, field.name.upper() + '_FIELD_NUMBER', field.number)

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    add_property = _AddPropertiesForRepeatedField
  elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    add_property = _AddPropertiesForNonRepeatedCompositeField
  else:
    add_property = _AddPropertiesForNonRepeatedScalarField
  add_property(field, cls)
| |
| |
| class _FieldProperty(property): |
| __slots__ = ('DESCRIPTOR',) |
| |
| def __init__(self, descriptor, getter, setter, doc): |
| property.__init__(self, getter, setter, doc=doc) |
| self.DESCRIPTOR = descriptor |
| |
| |
def _AddPropertiesForRepeatedField(field, cls):
  """Adds a public property for a "repeated" protocol message field.

  Reading the property returns the container stored for the field, creating
  it with the field's default constructor on first access. Assigning to the
  property always raises.

  Note that when clients add values to these containers, we perform
  type-checking in the case of repeated scalar fields, and we also set any
  necessary "has" bits as a side-effect.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name

  def getter(self):
    existing = self._fields.get(field)
    if existing is not None:
      return existing

    # Construct a new object to represent this field, then atomically check
    # if another thread has preempted us and, if not, swap in the new object
    # we just created.  If someone has preempted us, we take that object and
    # discard ours.
    # WARNING:  We are relying on setdefault() being atomic.  This is true
    #   in CPython but we haven't investigated others.  This warning appears
    #   in several other locations in this file.
    return self._fields.setdefault(field, field._default_constructor(self))

  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # We define a setter just so we can throw an exception with a more
  # helpful error message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to repeated field '
                         '"%s" in protocol message object.' % proto_field_name)

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  field_property = _FieldProperty(field, getter, setter, doc=doc)
  setattr(cls, _PropertyName(proto_field_name), field_property)
| |
| |
def _AddPropertiesForNonRepeatedScalarField(field, cls):
  """Adds a public property for a nonrepeated, scalar protocol message field.

  Clients can use this property to get and directly set the value of the
  field. Setting the value through the property type-checks it and updates
  the stored field state (and the oneof state for oneof members) as a
  side-effect.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  value_checker = type_checkers.GetTypeChecker(field)
  fallback_value = field.default_value

  def getter(self):
    # TODO: This may be broken since there may not be
    # default_value.  Combine with has_default_value somehow.
    return self._fields.get(field, fallback_value)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  def field_setter(self, new_value):
    # pylint: disable=protected-access
    try:
      new_value = value_checker.CheckValue(new_value)
    except TypeError as e:
      raise TypeError(
          'Cannot set %s to %.1024r: %s' % (field.full_name, new_value, e))

    # Testing the value for truthiness captures all of the proto3 defaults
    # (0, 0.0, enum 0, and False). A field without presence that is set to
    # such a value is simply dropped from _fields.
    store_value = field.has_presence or new_value
    if store_value:
      self._fields[field] = new_value
    else:
      self._fields.pop(field, None)

    # Check _cached_byte_size_dirty inline to improve performance, since scalar
    # setters are called frequently.
    if not self._cached_byte_size_dirty:
      self._Modified()

  if field.containing_oneof:
    # Members of a oneof also have to record themselves as the active member.
    def setter(self, new_value):
      field_setter(self, new_value)
      self._UpdateOneofState(field)
  else:
    setter = field_setter

  setter.__module__ = None
  setter.__doc__ = 'Setter for %s.' % proto_field_name

  # Add a property to encapsulate the getter/setter.
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  field_property = _FieldProperty(field, getter, setter, doc=doc)
  setattr(cls, _PropertyName(proto_field_name), field_property)
| |
| |
def _AddPropertiesForNonRepeatedCompositeField(field, cls):
  """Adds a public property for a nonrepeated, composite protocol message field.

  A composite field is a "group" or "message" field.

  Clients can use this property to get the value of the field. Assignment is
  rejected, except for fields whose message type is Timestamp, Duration,
  Struct or ListValue, where the assigned value is converted into the
  existing sub-message.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # TODO: Remove duplication with similar method
  # for non-repeated scalars.
  proto_field_name = field.name

  def getter(self):
    existing = self._fields.get(field)
    if existing is not None:
      return existing

    # Construct a new object to represent this field, then atomically check
    # if another thread has preempted us and, if not, swap in the new object
    # we just created.  If someone has preempted us, we take that object and
    # discard ours.
    # WARNING:  We are relying on setdefault() being atomic.  This is true
    #   in CPython but we haven't investigated others.  This warning appears
    #   in several other locations in this file.
    return self._fields.setdefault(field, field._default_constructor(self))

  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # The setter exists for the few well-known types that accept assignment and
  # otherwise only to throw an exception with a more helpful error message.
  def setter(self, new_value):
    type_name = field.message_type.full_name
    if type_name == 'google.protobuf.Timestamp':
      getter(self)  # makes sure the sub-message exists in _fields
      self._fields[field].FromDatetime(new_value)
      return
    if type_name == 'google.protobuf.Duration':
      getter(self)
      self._fields[field].FromTimedelta(new_value)
      return
    if type_name == _StructFullTypeName:
      getter(self)
      self._fields[field].Clear()
      self._fields[field].update(new_value)
      return
    if type_name == _ListValueFullTypeName:
      getter(self)
      self._fields[field].Clear()
      self._fields[field].extend(new_value)
      return
    raise AttributeError(
        'Assignment not allowed to composite field '
        '"%s" in protocol message object.' % proto_field_name
    )

  # Add a property to encapsulate the getter.
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  field_property = _FieldProperty(field, getter, setter, doc=doc)
  setattr(cls, _PropertyName(proto_field_name), field_property)
| |
| |
| def _AddPropertiesForExtensions(descriptor, cls): |
| """Adds properties for all fields in this protocol message type.""" |
| extensions = descriptor.extensions_by_name |
| for extension_name, extension_field in extensions.items(): |
| constant_name = extension_name.upper() + '_FIELD_NUMBER' |
| setattr(cls, constant_name, extension_field.number) |
| |
| # TODO: Migrate all users of these attributes to functions like |
| # pool.FindExtensionByNumber(descriptor). |
| if descriptor.file is not None: |
| # TODO: Use cls.MESSAGE_FACTORY.pool when available. |
| pool = descriptor.file.pool |
| |
| def _AddStaticMethods(cls): |
| def FromString(s): |
| message = cls() |
| message.MergeFromString(s) |
| return message |
| cls.FromString = staticmethod(FromString) |
| |
| |
def _IsPresent(item):
  """Decides whether an entry of _fields belongs in ListFields() output.

  Args:
    item: A (FieldDescriptor, value) tuple from _fields.

  Returns:
    For a repeated field, whether the value is non-empty; for a message
    field, the value's _is_present_in_parent flag; True for anything else.
  """
  field, value = item
  if field.label == _FieldDescriptor.LABEL_REPEATED:
    return bool(value)
  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    return value._is_present_in_parent
  return True
| |
| |
| def _AddListFieldsMethod(message_descriptor, cls): |
| """Helper for _AddMessageMethods().""" |
| |
| def ListFields(self): |
| all_fields = [item for item in self._fields.items() if _IsPresent(item)] |
| all_fields.sort(key = lambda item: item[0].number) |
| return all_fields |
| |
| cls.ListFields = ListFields |
| |
| |
| def _AddHasFieldMethod(message_descriptor, cls): |
| """Helper for _AddMessageMethods().""" |
| |
| hassable_fields = {} |
| for field in message_descriptor.fields: |
| if field.label == _FieldDescriptor.LABEL_REPEATED: |
| continue |
| # For proto3, only submessages and fields inside a oneof have presence. |
| if not field.has_presence: |
| continue |
| hassable_fields[field.name] = field |
| |
| # Has methods are supported for oneof descriptors. |
| for oneof in message_descriptor.oneofs: |
| hassable_fields[oneof.name] = oneof |
| |
| def HasField(self, field_name): |
| try: |
| field = hassable_fields[field_name] |
| except KeyError as exc: |
| raise ValueError('Protocol message %s has no non-repeated field "%s" ' |
| 'nor has presence is not available for this field.' % ( |
| message_descriptor.full_name, field_name)) from exc |
| |
| if isinstance(field, descriptor_mod.OneofDescriptor): |
| try: |
| return HasField(self, self._oneofs[field].name) |
| except KeyError: |
| return False |
| else: |
| if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: |
| value = self._fields.get(field) |
| return value is not None and value._is_present_in_parent |
| else: |
| return field in self._fields |
| |
| cls.HasField = HasField |
| |
| |
| def _AddClearFieldMethod(message_descriptor, cls): |
| """Helper for _AddMessageMethods().""" |
| def ClearField(self, field_name): |
| try: |
| field = message_descriptor.fields_by_name[field_name] |
| except KeyError: |
| try: |
| field = message_descriptor.oneofs_by_name[field_name] |
| if field in self._oneofs: |
| field = self._oneofs[field] |
| else: |
| return |
| except KeyError: |
| raise ValueError('Protocol message %s has no "%s" field.' % |
| (message_descriptor.name, field_name)) |
| |
| if field in self._fields: |
| # To match the C++ implementation, we need to invalidate iterators |
| # for map fields when ClearField() happens. |
| if hasattr(self._fields[field], 'InvalidateIterators'): |
| self._fields[field].InvalidateIterators() |
| |
| # Note: If the field is a sub-message, its listener will still point |
| # at us. That's fine, because the worst than can happen is that it |
| # will call _Modified() and invalidate our byte size. Big deal. |
| del self._fields[field] |
| |
| if self._oneofs.get(field.containing_oneof, None) is field: |
| del self._oneofs[field.containing_oneof] |
| |
| # Always call _Modified() -- even if nothing was changed, this is |
| # a mutating method, and thus calling it should cause the field to become |
| # present in the parent message. |
| self._Modified() |
| |
| cls.ClearField = ClearField |
| |
| |
def _AddClearExtensionMethod(cls):
  """Helper for _AddMessageMethods(): installs ClearExtension() on cls."""

  def ClearExtension(self, field_descriptor):
    extension_dict._VerifyExtensionHandle(self, field_descriptor)

    # Same approach as ClearField(): drop the stored value when there is one
    # and report the mutation either way.
    self._fields.pop(field_descriptor, None)
    self._Modified()

  cls.ClearExtension = ClearExtension
| |
| |
def _AddHasExtensionMethod(cls):
  """Helper for _AddMessageMethods(): installs HasExtension() on cls."""

  def HasExtension(self, field_descriptor):
    extension_dict._VerifyExtensionHandle(self, field_descriptor)
    # Presence is not defined for repeated extensions.
    if field_descriptor.label == _FieldDescriptor.LABEL_REPEATED:
      raise KeyError('"%s" is repeated.' % field_descriptor.full_name)

    if field_descriptor.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      return field_descriptor in self._fields

    # A message extension counts only once it was marked present.
    submessage = self._fields.get(field_descriptor)
    return submessage is not None and submessage._is_present_in_parent

  cls.HasExtension = HasExtension
| |
def _InternalUnpackAny(msg):
  """Unpacks Any message and returns the unpacked message.

  This internal method is different from public Any Unpack method which takes
  the target message as argument. _InternalUnpackAny method does not have
  target message type and need to find the message type in descriptor pool.

  Args:
    msg: An Any message to be unpacked.

  Returns:
    The unpacked message, or None when msg has an empty type_url or names a
    type that the default symbol database's pool does not know.
  """
  # TODO: Don't use the factory of generated messages.
  # To make Any work with custom factories, use the message factory of the
  # parent message.
  # pylint: disable=g-import-not-at-top
  from google.protobuf import symbol_database
  factory = symbol_database.Default()

  type_url = msg.type_url

  if not type_url:
    return None

  # TODO: For now we just strip the hostname.  Better logic will be
  # required.
  type_name = type_url.split('/')[-1]
  try:
    descriptor = factory.pool.FindMessageTypeByName(type_name)
  except KeyError:
    # The pool reports an unknown type by raising KeyError rather than by
    # returning None.  Treat it as "cannot unpack" so that callers such as
    # __eq__ can fall back to comparing the raw fields.
    return None

  if descriptor is None:
    return None

  message_class = factory.GetPrototype(descriptor)
  message = message_class()

  message.ParseFromString(msg.value)
  return message
| |
| |
def _AddEqualsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs __eq__() on cls."""

  def __eq__(self, other):
    own_type = self.DESCRIPTOR.full_name

    # ListValue and Struct can be compared directly with a native list/dict.
    if own_type == _ListValueFullTypeName and isinstance(other, list):
      return self._internal_compare(other)
    if own_type == _StructFullTypeName and isinstance(other, dict):
      return self._internal_compare(other)

    # Anything that is not a message of exactly this type is left to the
    # other operand to decide.
    if not isinstance(other, message_mod.Message):
      return NotImplemented
    if other.DESCRIPTOR != self.DESCRIPTOR:
      return NotImplemented

    if self is other:
      return True

    if own_type == _AnyFullTypeName:
      # Compare the payloads when both sides could be unpacked; otherwise
      # fall through to the field-by-field comparison below.
      any_a = _InternalUnpackAny(self)
      any_b = _InternalUnpackAny(other)
      if any_a and any_b:
        return any_a == any_b

    if not self.ListFields() == other.ListFields():
      return False

    # TODO: Fix UnknownFieldSet to consider MessageSet extensions,
    # then use it for the comparison.
    # Unknown fields are compared irrespective of their order.
    return sorted(self._unknown_fields) == sorted(other._unknown_fields)

  cls.__eq__ = __eq__
| |
| |
def _AddStrMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs __str__() on cls."""

  def __str__(self):
    # str(message) is the text-format rendering of the message.
    rendered = text_format.MessageToString(self)
    return rendered

  setattr(cls, '__str__', __str__)
| |
| |
def _AddReprMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs __repr__() on cls."""

  def __repr__(self):
    # repr(message) is the text-format rendering of the message.
    rendered = text_format.MessageToString(self)
    return rendered

  setattr(cls, '__repr__', __repr__)
| |
| |
def _AddUnicodeMethod(unused_message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs __unicode__() on cls.

  __unicode__() returns the text-format rendering of the message (produced
  with as_utf8=True) as a text string.
  """

  def __unicode__(self):
    text = text_format.MessageToString(self, as_utf8=True)
    # Only a bytes result needs decoding.  A str result has no decode()
    # method, so decoding unconditionally would raise AttributeError.
    if isinstance(text, bytes):
      return text.decode('utf-8')
    return text

  cls.__unicode__ = __unicode__
| |
| |
| def _AddContainsMethod(message_descriptor, cls): |
| |
| if message_descriptor.full_name == 'google.protobuf.Struct': |
| def __contains__(self, key): |
| return key in self.fields |
| elif message_descriptor.full_name == 'google.protobuf.ListValue': |
| def __contains__(self, value): |
| return value in self.items() |
| else: |
| def __contains__(self, field): |
| return self.HasField(field) |
| |
| cls.__contains__ = __contains__ |
| |
| |
def _BytesForNonRepeatedElement(value, field_number, field_type):
  """Computes the serialized size of a single, non-repeated element.

  The count covers the tag as well as any other overhead that comes with
  serializing the value.

  Args:
    value: Value we're serializing.
    field_number: Field number of this value.  (The field number is part of
      the varint-encoded tag and therefore affects the total size.)
    field_type: The type of the field.  One of the TYPE_* constants
      within FieldDescriptor.

  Returns:
    The number of bytes needed to serialize the element.

  Raises:
    message_mod.EncodeError: If a KeyError occurs while looking up or running
      the size function for field_type.
  """
  try:
    return type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type](field_number, value)
  except KeyError:
    raise message_mod.EncodeError('Unrecognized field type: %d' % field_type)
| |
| |
| def _AddByteSizeMethod(message_descriptor, cls): |
| """Helper for _AddMessageMethods().""" |
| |
| def ByteSize(self): |
| if not self._cached_byte_size_dirty: |
| return self._cached_byte_size |
| |
| size = 0 |
| descriptor = self.DESCRIPTOR |
| if descriptor._is_map_entry: |
| # Fields of map entry should always be serialized. |
| key_field = descriptor.fields_by_name['key'] |
| _MaybeAddEncoder(cls, key_field) |
| size = key_field._sizer(self.key) |
| value_field = descriptor.fields_by_name['value'] |
| _MaybeAddEncoder(cls, value_field) |
| size += value_field._sizer(self.value) |
| else: |
| for field_descriptor, field_value in self.ListFields(): |
| _MaybeAddEncoder(cls, field_descriptor) |
| size += field_descriptor._sizer(field_value) |
| for tag_bytes, value_bytes in self._unknown_fields: |
| size += len(tag_bytes) + len(value_bytes) |
| |
| self._cached_byte_size = size |
| self._cached_byte_size_dirty = False |
| self._listener_for_children.dirty = False |
| return size |
| |
| cls.ByteSize = ByteSize |
| |
| |
| def _AddSerializeToStringMethod(message_descriptor, cls): |
| """Helper for _AddMessageMethods().""" |
| |
| def SerializeToString(self, **kwargs): |
| # Check if the message has all of its required fields set. |
| if not self.IsInitialized(): |
| raise message_mod.EncodeError( |
| 'Message %s is missing required fields: %s' % ( |
| self.DESCRIPTOR.full_name, ','.join(self.FindInitializationErrors()))) |
| return self.SerializePartialToString(**kwargs) |
| cls.SerializeToString = SerializeToString |
| |
| |
| def _AddSerializePartialToStringMethod(message_descriptor, cls): |
| """Helper for _AddMessageMethods().""" |
| |
| def SerializePartialToString(self, **kwargs): |
| out = BytesIO() |
| self._InternalSerialize(out.write, **kwargs) |
| return out.getvalue() |
| cls.SerializePartialToString = SerializePartialToString |
| |
| def InternalSerialize(self, write_bytes, deterministic=None): |
| if deterministic is None: |
| deterministic = ( |
| api_implementation.IsPythonDefaultSerializationDeterministic()) |
| else: |
| deterministic = bool(deterministic) |
| |
| descriptor = self.DESCRIPTOR |
| if descriptor._is_map_entry: |
| # Fields of map entry should always be serialized. |
| key_field = descriptor.fields_by_name['key'] |
| _MaybeAddEncoder(cls, key_field) |
| key_field._encoder(write_bytes, self.key, deterministic) |
| value_field = descriptor.fields_by_name['value'] |
| _MaybeAddEncoder(cls, value_field) |
| value_field._encoder(write_bytes, self.value, deterministic) |
| else: |
| for field_descriptor, field_value in self.ListFields(): |
| _MaybeAddEncoder(cls, field_descriptor) |
| field_descriptor._encoder(write_bytes, field_value, deterministic) |
| for tag_bytes, value_bytes in self._unknown_fields: |
| write_bytes(tag_bytes) |
| write_bytes(value_bytes) |
| cls._InternalSerialize = InternalSerialize |
| |
| |
def _AddMergeFromStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  Installs MergeFromString() and the _InternalParse() routine it relies on.
  """
  def MergeFromString(self, serialized):
    """Parses serialized and merges the result into this message.

    Args:
      serialized: Object supporting the buffer protocol (it is wrapped in a
        memoryview) holding the serialized data.

    Returns:
      The length of serialized.

    Raises:
      message_mod.DecodeError: If parsing stops before the end of the input,
        or an IndexError, TypeError or struct.error occurs while parsing.
    """
    serialized = memoryview(serialized)
    length = len(serialized)
    try:
      if self._InternalParse(serialized, 0, length) != length:
        # The only reason _InternalParse would return early is if it
        # encountered an end-group tag.
        raise message_mod.DecodeError('Unexpected end-group tag.')
    except (IndexError, TypeError):
      # Now ord(buf[p:p+1]) == ord('') gets TypeError.
      raise message_mod.DecodeError('Truncated message.')
    except struct.error as e:
      raise message_mod.DecodeError(e)
    return length   # Return this for legacy reasons.
  cls.MergeFromString = MergeFromString

  # Bound to locals once, outside of the parse loop below.
  local_ReadTag = decoder.ReadTag
  local_SkipField = decoder.SkipField
  fields_by_tag = cls._fields_by_tag
  message_set_decoders_by_tag = cls._message_set_decoders_by_tag

  def InternalParse(self, buffer, pos, end):
    """Parses serialized bytes from buffer into this message.

    Args:
      self: Message, instance of the proto message object.
      buffer: memoryview of the serialized data.
      pos: int, position to start in the serialized data.
      end: int, end position of the serialized data.

    Returns:
      int, the position at which parsing stopped.  This equals end unless an
      unknown field made the decoder report -1, in which case it is the
      position of that field's tag.
    """
    # Guard against internal misuse, since this function is called internally
    # quite extensively, and its easy to accidentally pass bytes.
    assert isinstance(buffer, memoryview)
    self._Modified()
    field_dict = self._fields
    while pos != end:
      (tag_bytes, new_pos) = local_ReadTag(buffer, pos)
      # Decoders registered in the message-set table take precedence over the
      # regular per-field lookup.
      field_decoder, field_des = message_set_decoders_by_tag.get(
          tag_bytes, (None, None)
      )
      if field_decoder:
        pos = field_decoder(buffer, new_pos, end, self, field_dict)
        continue
      # Entries of fields_by_tag are (field descriptor, is_packed) pairs.
      field_des, is_packed = fields_by_tag.get(tag_bytes, (None, None))
      if field_des is None:
        # Unknown tag: keep the raw bytes in _unknown_fields.  That attribute
        # may hold an empty non-list value (Clear() stores a tuple), hence
        # the replacement with a fresh list before appending.
        if not self._unknown_fields:   # pylint: disable=protected-access
          self._unknown_fields = []    # pylint: disable=protected-access
        # pylint: disable=protected-access
        (tag, _) = decoder._DecodeVarint(tag_bytes, 0)
        field_number, wire_type = wire_format.UnpackTag(tag)
        if field_number == 0:
          raise message_mod.DecodeError('Field number 0 is illegal.')
        # TODO: remove old_pos.
        old_pos = new_pos
        # The decoded data is not used here; only the -1 position check
        # matters before the field is skipped below.
        (data, new_pos) = decoder._DecodeUnknownField(
            buffer, new_pos, wire_type)  # pylint: disable=protected-access
        if new_pos == -1:
          return pos
        # TODO: remove _unknown_fields.
        new_pos = local_SkipField(buffer, old_pos, end, tag_bytes)
        if new_pos == -1:
          return pos
        self._unknown_fields.append(
            (tag_bytes, buffer[old_pos:new_pos].tobytes()))
        pos = new_pos
      else:
        _MaybeAddDecoder(cls, field_des)
        # _decoders is indexed by is_packed to pick the matching decoder.
        field_decoder = field_des._decoders[is_packed]
        pos = field_decoder(buffer, new_pos, end, self, field_dict)
        if field_des.containing_oneof:
          self._UpdateOneofState(field_des)
    return pos
  cls._InternalParse = InternalParse
| |
| |
| def _AddIsInitializedMethod(message_descriptor, cls): |
| """Adds the IsInitialized and FindInitializationError methods to the |
| protocol message class.""" |
| |
| required_fields = [field for field in message_descriptor.fields |
| if field.label == _FieldDescriptor.LABEL_REQUIRED] |
| |
| def IsInitialized(self, errors=None): |
| """Checks if all required fields of a message are set. |
| |
| Args: |
| errors: A list which, if provided, will be populated with the field |
| paths of all missing required fields. |
| |
| Returns: |
| True iff the specified message has all required fields set. |
| """ |
| |
| # Performance is critical so we avoid HasField() and ListFields(). |
| |
| for field in required_fields: |
| if (field not in self._fields or |
| (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and |
| not self._fields[field]._is_present_in_parent)): |
| if errors is not None: |
| errors.extend(self.FindInitializationErrors()) |
| return False |
| |
| for field, value in list(self._fields.items()): # dict can change size! |
| if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: |
| if field.label == _FieldDescriptor.LABEL_REPEATED: |
| if (field.message_type._is_map_entry): |
| continue |
| for element in value: |
| if not element.IsInitialized(): |
| if errors is not None: |
| errors.extend(self.FindInitializationErrors()) |
| return False |
| elif value._is_present_in_parent and not value.IsInitialized(): |
| if errors is not None: |
| errors.extend(self.FindInitializationErrors()) |
| return False |
| |
| return True |
| |
| cls.IsInitialized = IsInitialized |
| |
| def FindInitializationErrors(self): |
| """Finds required fields which are not initialized. |
| |
| Returns: |
| A list of strings. Each string is a path to an uninitialized field from |
| the top-level message, e.g. "foo.bar[5].baz". |
| """ |
| |
| errors = [] # simplify things |
| |
| for field in required_fields: |
| if not self.HasField(field.name): |
| errors.append(field.name) |
| |
| for field, value in self.ListFields(): |
| if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: |
| if field.is_extension: |
| name = '(%s)' % field.full_name |
| else: |
| name = field.name |
| |
| if _IsMapField(field): |
| if _IsMessageMapField(field): |
| for key in value: |
| element = value[key] |
| prefix = '%s[%s].' % (name, key) |
| sub_errors = element.FindInitializationErrors() |
| errors += [prefix + error for error in sub_errors] |
| else: |
| # ScalarMaps can't have any initialization errors. |
| pass |
| elif field.label == _FieldDescriptor.LABEL_REPEATED: |
| for i in range(len(value)): |
| element = value[i] |
| prefix = '%s[%d].' % (name, i) |
| sub_errors = element.FindInitializationErrors() |
| errors += [prefix + error for error in sub_errors] |
| else: |
| prefix = name + '.' |
| sub_errors = value.FindInitializationErrors() |
| errors += [prefix + error for error in sub_errors] |
| |
| return errors |
| |
| cls.FindInitializationErrors = FindInitializationErrors |
| |
| |
| def _FullyQualifiedClassName(klass): |
| module = klass.__module__ |
| name = getattr(klass, '__qualname__', klass.__name__) |
| if module in (None, 'builtins', '__builtin__'): |
| return name |
| return module + '.' + name |
| |
| |
def _AddMergeFromMethod(cls):
  """Installs MergeFrom() on cls."""
  LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED
  CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE

  def _GetOrCreate(message, field):
    # Returns the value stored for field, constructing a new object to
    # represent the field first when none is stored yet.
    stored = message._fields
    current = stored.get(field)
    if current is None:
      current = field._default_constructor(message)
      stored[field] = current
    return current

  def MergeFrom(self, msg):
    if not isinstance(msg, cls):
      raise TypeError(
          'Parameter to MergeFrom() must be instance of same class: '
          'expected %s got %s.' % (_FullyQualifiedClassName(cls),
                                   _FullyQualifiedClassName(msg.__class__)))

    assert msg is not self
    self._Modified()

    for field, value in msg._fields.items():
      if field.label == LABEL_REPEATED:
        _GetOrCreate(self, field).MergeFrom(value)
      elif field.cpp_type == CPPTYPE_MESSAGE:
        # Submessages that were never marked present are left alone.
        if value._is_present_in_parent:
          _GetOrCreate(self, field).MergeFrom(value)
      else:
        # Singular non-message values simply replace ours.
        self._fields[field] = value
        if field.containing_oneof:
          self._UpdateOneofState(field)

    incoming_unknown = msg._unknown_fields
    if incoming_unknown:
      if not self._unknown_fields:
        self._unknown_fields = []
      self._unknown_fields.extend(incoming_unknown)

  cls.MergeFrom = MergeFrom
| |
| |
| def _AddWhichOneofMethod(message_descriptor, cls): |
| def WhichOneof(self, oneof_name): |
| """Returns the name of the currently set field inside a oneof, or None.""" |
| try: |
| field = message_descriptor.oneofs_by_name[oneof_name] |
| except KeyError: |
| raise ValueError( |
| 'Protocol message has no oneof "%s" field.' % oneof_name) |
| |
| nested_field = self._oneofs.get(field, None) |
| if nested_field is not None and self.HasField(nested_field.name): |
| return nested_field.name |
| else: |
| return None |
| |
| cls.WhichOneof = WhichOneof |
| |
| |
| def _Clear(self): |
| # Clear fields. |
| self._fields = {} |
| self._unknown_fields = () |
| |
| self._oneofs = {} |
| self._Modified() |
| |
| |
| def _UnknownFields(self): |
| raise NotImplementedError('Please use the add-on feaure ' |
| 'unknown_fields.UnknownFieldSet(message) in ' |
| 'unknown_fields.py instead.') |
| |
| |
| def _DiscardUnknownFields(self): |
| self._unknown_fields = [] |
| for field, value in self.ListFields(): |
| if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: |
| if _IsMapField(field): |
| if _IsMessageMapField(field): |
| for key in value: |
| value[key].DiscardUnknownFields() |
| elif field.label == _FieldDescriptor.LABEL_REPEATED: |
| for sub_message in value: |
| sub_message.DiscardUnknownFields() |
| else: |
| value.DiscardUnknownFields() |
| |
| |
| def _SetListener(self, listener): |
| if listener is None: |
| self._listener = message_listener_mod.NullMessageListener() |
| else: |
| self._listener = listener |
| |
| |
def _AddMessageMethods(message_descriptor, cls):
  """Adds implementations of all Message methods to cls."""
  for add_method in (
      _AddListFieldsMethod,
      _AddHasFieldMethod,
      _AddClearFieldMethod,
  ):
    add_method(message_descriptor, cls)

  # The extension accessors only exist on extendable messages.
  if message_descriptor.is_extendable:
    _AddClearExtensionMethod(cls)
    _AddHasExtensionMethod(cls)

  for add_method in (
      _AddEqualsMethod,
      _AddStrMethod,
      _AddReprMethod,
      _AddUnicodeMethod,
      _AddContainsMethod,
      _AddByteSizeMethod,
      _AddSerializeToStringMethod,
      _AddSerializePartialToStringMethod,
      _AddMergeFromStringMethod,
      _AddIsInitializedMethod,
  ):
    add_method(message_descriptor, cls)

  _AddMergeFromMethod(cls)
  _AddWhichOneofMethod(message_descriptor, cls)

  # Adds methods which do not depend on cls.
  cls.Clear = _Clear
  cls.DiscardUnknownFields = _DiscardUnknownFields
  cls._SetListener = _SetListener
| |
| |
| def _AddPrivateHelperMethods(message_descriptor, cls): |
| """Adds implementation of private helper methods to cls.""" |
| |
| def Modified(self): |
| """Sets the _cached_byte_size_dirty bit to true, |
| and propagates this to our listener iff this was a state change. |
| """ |
| |
| # Note: Some callers check _cached_byte_size_dirty before calling |
| # _Modified() as an extra optimization. So, if this method is ever |
| # changed such that it does stuff even when _cached_byte_size_dirty is |
| # already true, the callers need to be updated. |
| if not self._cached_byte_size_dirty: |
| self._cached_byte_size_dirty = True |
| self._listener_for_children.dirty = True |
| self._is_present_in_parent = True |
| self._listener.Modified() |
| |
| def _UpdateOneofState(self, field): |
| """Sets field as the active field in its containing oneof. |
| |
| Will also delete currently active field in the oneof, if it is different |
| from the argument. Does not mark the message as modified. |
| """ |
| other_field = self._oneofs.setdefault(field.containing_oneof, field) |
| if other_field is not field: |
| del self._fields[other_field] |
| self._oneofs[field.containing_oneof] = field |
| |
| cls._Modified = Modified |
| cls.SetInParent = Modified |
| cls._UpdateOneofState = _UpdateOneofState |
| |
| |
| class _Listener(object): |
| |
| """MessageListener implementation that a parent message registers with its |
| child message. |
| |
| In order to support semantics like: |
| |
| foo.bar.baz.moo = 23 |
| assert foo.HasField('bar') |
| |
| ...child objects must have back references to their parents. |
| This helper class is at the heart of this support. |
| """ |
| |
| def __init__(self, parent_message): |
| """Args: |
| parent_message: The message whose _Modified() method we should call when |
| we receive Modified() messages. |
| """ |
| # This listener establishes a back reference from a child (contained) object |
| # to its parent (containing) object. We make this a weak reference to avoid |
| # creating cyclic garbage when the client finishes with the 'parent' object |
| # in the tree. |
| if isinstance(parent_message, weakref.ProxyType): |
| self._parent_message_weakref = parent_message |
| else: |
| self._parent_message_weakref = weakref.proxy(parent_message) |
| |
| # As an optimization, we also indicate directly on the listener whether |
| # or not the parent message is dirty. This way we can avoid traversing |
| # up the tree in the common case. |
| self.dirty = False |
| |
| def Modified(self): |
| if self.dirty: |
| return |
| try: |
| # Propagate the signal to our parents iff this is the first field set. |
| self._parent_message_weakref._Modified() |
| except ReferenceError: |
| # We can get here if a client has kept a reference to a child object, |
| # and is now setting a field on it, but the child's parent has been |
| # garbage-collected. This is not an error. |
| pass |
| |
| |
class _OneofListener(_Listener):
  """Special listener implementation for setting composite oneof fields."""

  def __init__(self, parent_message, field):
    """Args:
      parent_message: The message whose _Modified() method we should call when
        we receive Modified() messages.
      field: The descriptor of the field being set in the parent message.
    """
    super().__init__(parent_message)
    self._field = field

  def Modified(self):
    """Also updates the state of the containing oneof in the parent message."""
    parent = self._parent_message_weakref
    try:
      # Make our field the active member of its oneof first, then signal
      # the modification as usual.
      parent._UpdateOneofState(self._field)
      super().Modified()
    except ReferenceError:
      # Raised once the weakly referenced parent is gone; ignored.
      pass