| # Protocol Buffers - Google's data interchange format |
| # Copyright 2008 Google Inc. All rights reserved. |
| # |
| # Use of this source code is governed by a BSD-style |
| # license that can be found in the LICENSE file or at |
| # https://developers.google.com/open-source/licenses/bsd |
| |
| """Code for decoding protocol buffer primitives. |
| |
| This code is very similar to encoder.py -- read the docs for that module first. |
| |
| A "decoder" is a function with the signature: |
| Decode(buffer, pos, end, message, field_dict) |
| The arguments are: |
| buffer: The string containing the encoded message. |
| pos: The current position in the string. |
| end: The position in the string where the current message ends. May be |
| less than len(buffer) if we're reading a sub-message. |
| message: The message object into which we're parsing. |
| field_dict: message._fields (avoids a hashtable lookup). |
| The decoder reads the field and stores it into field_dict, returning the new |
| buffer position. A decoder for a repeated field may proactively decode all of |
| the elements of that field, if they appear consecutively. |
| |
| Note that decoders may throw any of the following: |
| IndexError: Indicates a truncated message. |
| struct.error: Unpacking of a fixed-width field failed. |
| message.DecodeError: Other errors. |
| |
| Decoders are expected to raise an exception if they are called with pos > end. |
| This allows callers to be lax about bounds checking: it's fine to read past |
| "end" as long as you are sure that someone else will notice and throw an |
| exception later on. |
| |
| Something up the call stack is expected to catch IndexError and struct.error |
| and convert them to message.DecodeError. |
| |
| Decoders are constructed using decoder constructors with the signature: |
| MakeDecoder(field_number, is_repeated, is_packed, key, new_default) |
| The arguments are: |
| field_number: The field number of the field we want to decode. |
| is_repeated: Is the field a repeated field? (bool) |
| is_packed: Is the field a packed field? (bool) |
| key: The key to use when looking up the field within field_dict. |
| (This is actually the FieldDescriptor but nothing in this |
| file should depend on that.) |
| new_default: A function which takes a message object as a parameter and |
| returns a new instance of the default value for this field. |
| (This is called for repeated fields and sub-messages, when an |
| instance does not already exist.) |
| |
| As with encoders, we define a decoder constructor for every type of field. |
| Then, for every field of every message class we construct an actual decoder. |
| That decoder goes into a dict indexed by tag, so when we decode a message |
| we repeatedly read a tag, look up the corresponding decoder, and invoke it. |
| """ |
| |
| __author__ = 'kenton@google.com (Kenton Varda)' |
| |
| import math |
| import struct |
| |
| from google.protobuf import message |
| from google.protobuf.internal import containers |
| from google.protobuf.internal import encoder |
| from google.protobuf.internal import wire_format |
| |
| |
# Module-private alias for the DecodeError exception class. This is not for
# optimization, but rather to avoid conflicts with local variables (including
# decoder parameters) named "message", which would otherwise hide the imported
# "message" module inside those functions.
_DecodeError = message.DecodeError
| |
| |
| def _VarintDecoder(mask, result_type): |
| """Return an encoder for a basic varint value (does not include tag). |
| |
| Decoded values will be bitwise-anded with the given mask before being |
| returned, e.g. to limit them to 32 bits. The returned decoder does not |
| take the usual "end" parameter -- the caller is expected to do bounds checking |
| after the fact (often the caller can defer such checking until later). The |
| decoder returns a (value, new_pos) pair. |
| """ |
| |
| def DecodeVarint(buffer, pos: int=None): |
| result = 0 |
| shift = 0 |
| while 1: |
| if pos is None: |
| # Read from BytesIO |
| try: |
| b = buffer.read(1)[0] |
| except IndexError as e: |
| if shift == 0: |
| # End of BytesIO. |
| return None |
| else: |
| raise ValueError('Fail to read varint %s' % str(e)) |
| else: |
| b = buffer[pos] |
| pos += 1 |
| result |= ((b & 0x7f) << shift) |
| if not (b & 0x80): |
| result &= mask |
| result = result_type(result) |
| return result if pos is None else (result, pos) |
| shift += 7 |
| if shift >= 64: |
| raise _DecodeError('Too many bytes when decoding varint.') |
| |
| return DecodeVarint |
| |
| |
| def _SignedVarintDecoder(bits, result_type): |
| """Like _VarintDecoder() but decodes signed values.""" |
| |
| signbit = 1 << (bits - 1) |
| mask = (1 << bits) - 1 |
| |
| def DecodeVarint(buffer, pos): |
| result = 0 |
| shift = 0 |
| while 1: |
| b = buffer[pos] |
| result |= ((b & 0x7f) << shift) |
| pos += 1 |
| if not (b & 0x80): |
| result &= mask |
| result = (result ^ signbit) - signbit |
| result = result_type(result) |
| return (result, pos) |
| shift += 7 |
| if shift >= 64: |
| raise _DecodeError('Too many bytes when decoding varint.') |
| return DecodeVarint |
| |
# Every decoded integer, whether 32-bit or 64-bit on the wire, is represented
# as a Python int. These two keep all 64 bits of the decoded value.
_DecodeVarint = _VarintDecoder(0xFFFFFFFFFFFFFFFF, int)
_DecodeSignedVarint = _SignedVarintDecoder(64, int)

# Use these versions for values which must be limited to 32 bits.
_DecodeVarint32 = _VarintDecoder(0xFFFFFFFF, int)
_DecodeSignedVarint32 = _SignedVarintDecoder(32, int)
| |
| |
def ReadTag(buffer, pos):
  """Read a tag from the memoryview, and return a (tag_bytes, new_pos) tuple.

  The tag is returned as its raw bytes rather than as a decoded integer, so
  that the bytes can be used directly as a key when looking up the proper
  decoder. This trades work that would be done in pure Python (decoding a
  varint) for work that is done in C (hashing a byte string). In a low-level
  language decoding the varint would be much cheaper, but not in Python.

  Args:
    buffer: memoryview object of the encoded bytes
    pos: int of the current position to start from

  Returns:
    Tuple[bytes, int] of the tag data and new position.
  """
  tag_end = pos
  while True:
    has_more = buffer[tag_end] & 0x80
    tag_end += 1
    if not has_more:
      # The byte just consumed had no continuation bit: the tag is complete.
      break
  return buffer[pos:tag_end].tobytes(), tag_end
| |
| |
| # -------------------------------------------------------------------- |
| |
| |
| def _SimpleDecoder(wire_type, decode_value): |
| """Return a constructor for a decoder for fields of a particular type. |
| |
| Args: |
| wire_type: The field's wire type. |
| decode_value: A function which decodes an individual value, e.g. |
| _DecodeVarint() |
| """ |
| |
| def SpecificDecoder(field_number, is_repeated, is_packed, key, new_default, |
| clear_if_default=False): |
| if is_packed: |
| local_DecodeVarint = _DecodeVarint |
| def DecodePackedField(buffer, pos, end, message, field_dict): |
| value = field_dict.get(key) |
| if value is None: |
| value = field_dict.setdefault(key, new_default(message)) |
| (endpoint, pos) = local_DecodeVarint(buffer, pos) |
| endpoint += pos |
| if endpoint > end: |
| raise _DecodeError('Truncated message.') |
| while pos < endpoint: |
| (element, pos) = decode_value(buffer, pos) |
| value.append(element) |
| if pos > endpoint: |
| del value[-1] # Discard corrupt value. |
| raise _DecodeError('Packed element was truncated.') |
| return pos |
| return DecodePackedField |
| elif is_repeated: |
| tag_bytes = encoder.TagBytes(field_number, wire_type) |
| tag_len = len(tag_bytes) |
| def DecodeRepeatedField(buffer, pos, end, message, field_dict): |
| value = field_dict.get(key) |
| if value is None: |
| value = field_dict.setdefault(key, new_default(message)) |
| while 1: |
| (element, new_pos) = decode_value(buffer, pos) |
| value.append(element) |
| # Predict that the next tag is another copy of the same repeated |
| # field. |
| pos = new_pos + tag_len |
| if buffer[new_pos:pos] != tag_bytes or new_pos >= end: |
| # Prediction failed. Return. |
| if new_pos > end: |
| raise _DecodeError('Truncated message.') |
| return new_pos |
| return DecodeRepeatedField |
| else: |
| def DecodeField(buffer, pos, end, message, field_dict): |
| (new_value, pos) = decode_value(buffer, pos) |
| if pos > end: |
| raise _DecodeError('Truncated message.') |
| if clear_if_default and not new_value: |
| field_dict.pop(key, None) |
| else: |
| field_dict[key] = new_value |
| return pos |
| return DecodeField |
| |
| return SpecificDecoder |
| |
| |
def _ModifiedDecoder(wire_type, decode_value, modify_value):
  """Like SimpleDecoder but additionally invokes modify_value on every value
  before storing it.  Usually modify_value is ZigZagDecode.

  Args:
      wire_type:  The field's wire type.
      decode_value:  Function decoding one raw value; called as
        decode_value(buffer, pos) and returning (value, new_pos).
      modify_value:  Function applied to each raw value.
  """

  # Reusing _SimpleDecoder is slightly slower than copying a bunch of code, but
  # not enough to make a significant difference.

  def InnerDecode(buffer, pos):
    raw, next_pos = decode_value(buffer, pos)
    return (modify_value(raw), next_pos)

  return _SimpleDecoder(wire_type, InnerDecode)
| |
| |
def _StructPackDecoder(wire_type, format):
  """Return a constructor for a decoder for a fixed-width field.

  Args:
      wire_type:  The field's wire type.
      format:  The format string to pass to struct.unpack(). Only the first
        unpacked item is used as the field value.
  """

  width = struct.calcsize(format)
  unpack = struct.unpack

  # Reusing _SimpleDecoder is slightly slower than copying a bunch of code, but
  # not enough to make a significant difference.

  def InnerDecode(buffer, pos):
    stop = pos + width
    # struct.error is deliberately left to propagate: someone up-stack is
    # expected to catch it and convert it to _DecodeError, so that no
    # exception-handling block has to be set up for every single value.
    fields = unpack(format, buffer[pos:stop])
    return (fields[0], stop)

  return _SimpleDecoder(wire_type, InnerDecode)
| |
| |
def _FloatDecoder():
  """Returns a decoder for a float field.

  Non-finite 32-bit values (infinities and NaNs) are recognised from their
  byte pattern and returned as math.inf / -math.inf / math.nan rather than
  being passed to struct.unpack. This was introduced as a workaround for a
  struct.unpack bug with such values in old Python versions.
  """

  unpack = struct.unpack

  def InnerDecode(buffer, pos):
    """Decode serialized float to a float and new position.

    Args:
      buffer: memoryview of the serialized bytes
      pos: int, position in the memory view to start at.

    Returns:
      Tuple[float, int] of the deserialized float value and new position
      in the serialized data.
    """
    # The value is 32 bits in little-endian byte order.  Bit 1 is the sign
    # bit, bits 2-9 represent the exponent, and bits 10-32 are the significand.
    stop = pos + 4
    raw = buffer[pos:stop].tobytes()
    top_byte = raw[3:4]

    # All exponent bits set (top byte 0x7F or 0xFF, and the high bit of the
    # next byte) means the value is non-finite; handle it without struct.
    if top_byte in b'\x7F\xFF' and raw[2:3] >= b'\x80':
      if raw[0:3] != b'\x00\x00\x80':
        # At least one significand bit is set: not a number.
        return (math.nan, stop)
      if top_byte == b'\xFF':
        # Sign bit is set.
        return (-math.inf, stop)
      return (math.inf, stop)

    # struct.error is deliberately left to propagate: someone up-stack is
    # expected to catch it and convert it to _DecodeError.
    return (unpack('<f', raw)[0], stop)

  return _SimpleDecoder(wire_format.WIRETYPE_FIXED32, InnerDecode)
| |
| |
def _DoubleDecoder():
  """Returns a decoder for a double field.

  NaN bit patterns are recognised directly and returned as math.nan instead of
  going through struct.unpack. This was introduced as a workaround for a
  struct.unpack bug with not-a-number in old Python versions.
  """

  unpack = struct.unpack

  def InnerDecode(buffer, pos):
    """Decode serialized double to a double and new position.

    Args:
      buffer: memoryview of the serialized bytes.
      pos: int, position in the memory view to start at.

    Returns:
      Tuple[float, int] of the decoded double value and new position
      in the serialized data.
    """
    # The value is 64 bits in little-endian byte order.  Bit 1 is the sign
    # bit, bits 2-12 represent the exponent, and bits 13-64 are the
    # significand.
    stop = pos + 8
    raw = buffer[pos:stop].tobytes()

    # All exponent bits set plus at least one significand bit set means the
    # value is not a number.  (All exponent bits set with a zero significand
    # is an infinity, which is left to struct.unpack below.)
    exponent_all_ones = raw[7:8] in b'\x7F\xFF' and raw[6:7] >= b'\xF0'
    if exponent_all_ones and raw[0:7] != b'\x00\x00\x00\x00\x00\x00\xF0':
      return (math.nan, stop)

    # struct.error is deliberately left to propagate: someone up-stack is
    # expected to catch it and convert it to _DecodeError.
    return (unpack('<d', raw)[0], stop)

  return _SimpleDecoder(wire_format.WIRETYPE_FIXED64, InnerDecode)
| |
| |
def EnumDecoder(field_number, is_repeated, is_packed, key, new_default,
                clear_if_default=False):
  """Returns a decoder for enum field.

  Enum values are read as signed 32-bit varints. A number that is present in
  key.enum_type.values_by_number is stored in the field; any other number is
  preserved by appending (tag_bytes, raw_value_bytes) to
  message._unknown_fields instead.

  Args:
    field_number: The field number of the enum field.
    is_repeated: Whether the field is repeated.
    is_packed: Whether the field is packed.
    key: Key used in field_dict; must expose an `enum_type` attribute.
    new_default: Function taking the message and returning a new container
      for repeated/packed fields.
    clear_if_default: If true, a singular value of 0 removes the field from
      field_dict instead of storing it.

  Returns:
    A decoder function with the signature
    Decode(buffer, pos, end, message, field_dict).
  """
  enum_type = key.enum_type
  # The varint tag for this field never changes, so encode it once here
  # instead of once per unrecognised value.
  tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT)
  if is_packed:
    local_DecodeVarint = _DecodeVarint

    def DecodePackedField(buffer, pos, end, message, field_dict):
      """Decode serialized packed enum to its value and a new position.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      (endpoint, pos) = local_DecodeVarint(buffer, pos)
      endpoint += pos
      if endpoint > end:
        raise _DecodeError('Truncated message.')
      # Whether the most recently decoded element was a known enum number;
      # needed to undo the right append if that element turns out truncated.
      is_known = True
      while pos < endpoint:
        value_start_pos = pos
        (element, pos) = _DecodeSignedVarint32(buffer, pos)
        # pylint: disable=protected-access
        is_known = element in enum_type.values_by_number
        if is_known:
          value.append(element)
        else:
          if not message._unknown_fields:
            message._unknown_fields = []
          message._unknown_fields.append(
              (tag_bytes, buffer[value_start_pos:pos].tobytes()))
        # pylint: enable=protected-access
      if pos > endpoint:
        # The last element ran past the declared packed length: discard it
        # from wherever it was recorded.
        if is_known:
          del value[-1]  # Discard corrupt value.
        else:
          del message._unknown_fields[-1]  # pylint: disable=protected-access
        raise _DecodeError('Packed element was truncated.')
      return pos

    return DecodePackedField
  elif is_repeated:
    tag_len = len(tag_bytes)

    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      """Decode serialized repeated enum to its value and a new position.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        (element, new_pos) = _DecodeSignedVarint32(buffer, pos)
        # pylint: disable=protected-access
        if element in enum_type.values_by_number:
          value.append(element)
        else:
          if not message._unknown_fields:
            message._unknown_fields = []
          message._unknown_fields.append(
              (tag_bytes, buffer[pos:new_pos].tobytes()))
        # pylint: enable=protected-access
        # Predict that the next tag is another copy of the same repeated
        # field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos >= end:
          # Prediction failed.  Return.
          if new_pos > end:
            raise _DecodeError('Truncated message.')
          return new_pos

    return DecodeRepeatedField
  else:

    def DecodeField(buffer, pos, end, message, field_dict):
      """Decode serialized singular enum to its value and a new position.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      value_start_pos = pos
      (enum_value, pos) = _DecodeSignedVarint32(buffer, pos)
      if pos > end:
        raise _DecodeError('Truncated message.')
      if clear_if_default and not enum_value:
        field_dict.pop(key, None)
        return pos
      # pylint: disable=protected-access
      if enum_value in enum_type.values_by_number:
        field_dict[key] = enum_value
      else:
        if not message._unknown_fields:
          message._unknown_fields = []
        message._unknown_fields.append(
            (tag_bytes, buffer[value_start_pos:pos].tobytes()))
      # pylint: enable=protected-access
      return pos

    return DecodeField
| |
| |
| # -------------------------------------------------------------------- |
| |
| |
# Decoder constructors for the scalar field types. Each name below is a
# constructor (see the module docstring for its signature), not a decoder.

# Signed varints: the decoded value is reinterpreted as a two's-complement
# number of the given width.
Int32Decoder = _SimpleDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeSignedVarint32)

Int64Decoder = _SimpleDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeSignedVarint)

# Unsigned varints.
UInt32Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint32)
UInt64Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint)

# Varints that are read unsigned and then passed through ZigZagDecode.
SInt32Decoder = _ModifiedDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeVarint32, wire_format.ZigZagDecode)
SInt64Decoder = _ModifiedDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeVarint, wire_format.ZigZagDecode)

# Note that Python conveniently guarantees that when using the '<' prefix on
# formats, they will also have the same size across all platforms (as opposed
# to without the prefix, where their sizes depend on the C compiler's basic
# type sizes).
Fixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<I')
Fixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<Q')
SFixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<i')
SFixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<q')
# Floating-point types use dedicated decoders that special-case non-finite
# bit patterns before falling back to struct.unpack.
FloatDecoder = _FloatDecoder()
DoubleDecoder = _DoubleDecoder()

# Booleans are read as a varint and converted with bool().
BoolDecoder = _ModifiedDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeVarint, bool)
| |
| |
def StringDecoder(field_number, is_repeated, is_packed, key, new_default,
                  clear_if_default=False):
  """Returns a decoder for a string field.

  The wire value is a varint length followed by that many bytes, which are
  decoded as UTF-8. Packed string fields are not supported. With
  clear_if_default set, a singular zero-length string removes the field from
  field_dict instead of storing an empty string.
  """
  assert not is_packed

  read_length = _DecodeVarint

  def _ConvertToUnicode(memview):
    """Convert byte to unicode."""
    raw = memview.tobytes()
    try:
      return str(raw, 'utf-8')
    except UnicodeDecodeError as e:
      # add more information to the error message and re-raise it.
      e.reason = '%s in field: %s' % (e, key.full_name)
      raise

  if not is_repeated:

    def DecodeField(buffer, pos, end, message, field_dict):
      (size, pos) = read_length(buffer, pos)
      text_end = pos + size
      if text_end > end:
        raise _DecodeError('Truncated string.')
      if not clear_if_default or size:
        field_dict[key] = _ConvertToUnicode(buffer[pos:text_end])
      else:
        field_dict.pop(key, None)
      return text_end

    return DecodeField

  tag_bytes = encoder.TagBytes(field_number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
  tag_len = len(tag_bytes)

  def DecodeRepeatedField(buffer, pos, end, message, field_dict):
    strings = field_dict.get(key)
    if strings is None:
      strings = field_dict.setdefault(key, new_default(message))
    while True:
      (size, pos) = read_length(buffer, pos)
      text_end = pos + size
      if text_end > end:
        raise _DecodeError('Truncated string.')
      strings.append(_ConvertToUnicode(buffer[pos:text_end]))
      # Predict that the next tag is another copy of the same repeated field;
      # if so, keep decoding without returning to the caller.
      pos = text_end + tag_len
      if buffer[text_end:pos] == tag_bytes and text_end != end:
        continue
      # Prediction failed.  Return.
      return text_end

  return DecodeRepeatedField
| |
| |
def BytesDecoder(field_number, is_repeated, is_packed, key, new_default,
                 clear_if_default=False):
  """Returns a decoder for a bytes field.

  The wire value is a varint length followed by that many bytes, which are
  copied out of the buffer as a bytes object. Packed bytes fields are not
  supported. With clear_if_default set, a singular zero-length value removes
  the field from field_dict instead of storing empty bytes.
  """
  assert not is_packed

  read_length = _DecodeVarint

  if not is_repeated:

    def DecodeField(buffer, pos, end, message, field_dict):
      (size, pos) = read_length(buffer, pos)
      data_end = pos + size
      if data_end > end:
        raise _DecodeError('Truncated string.')
      if not clear_if_default or size:
        field_dict[key] = buffer[pos:data_end].tobytes()
      else:
        field_dict.pop(key, None)
      return data_end

    return DecodeField

  tag_bytes = encoder.TagBytes(field_number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
  tag_len = len(tag_bytes)

  def DecodeRepeatedField(buffer, pos, end, message, field_dict):
    chunks = field_dict.get(key)
    if chunks is None:
      chunks = field_dict.setdefault(key, new_default(message))
    while True:
      (size, pos) = read_length(buffer, pos)
      data_end = pos + size
      if data_end > end:
        raise _DecodeError('Truncated string.')
      chunks.append(buffer[pos:data_end].tobytes())
      # Predict that the next tag is another copy of the same repeated field;
      # if so, keep decoding without returning to the caller.
      pos = data_end + tag_len
      if buffer[data_end:pos] == tag_bytes and data_end != end:
        continue
      # Prediction failed.  Return.
      return data_end

  return DecodeRepeatedField
| |
| |
def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a group field.

  A group is parsed by handing the buffer to the sub-message's _InternalParse
  and then requiring that the bytes at the position it returns are this
  field's end-group tag. Packed group fields are not supported.
  """
  assert not is_packed

  end_tag_bytes = encoder.TagBytes(field_number,
                                   wire_format.WIRETYPE_END_GROUP)
  end_tag_len = len(end_tag_bytes)

  if not is_repeated:

    def DecodeField(buffer, pos, end, message, field_dict):
      group = field_dict.get(key)
      if group is None:
        group = field_dict.setdefault(key, new_default(message))
      # Read sub-message.
      pos = group._InternalParse(buffer, pos, end)
      # Read end tag.
      after_end_tag = pos + end_tag_len
      if buffer[pos:after_end_tag] != end_tag_bytes or after_end_tag > end:
        raise _DecodeError('Missing group end tag.')
      return after_end_tag

    return DecodeField

  tag_bytes = encoder.TagBytes(field_number,
                               wire_format.WIRETYPE_START_GROUP)
  tag_len = len(tag_bytes)

  def DecodeRepeatedField(buffer, pos, end, message, field_dict):
    while True:
      # The container is looked up again for every element, creating it on
      # first use.
      groups = field_dict.get(key)
      if groups is None:
        groups = field_dict.setdefault(key, new_default(message))
      # Read sub-message.
      pos = groups.add()._InternalParse(buffer, pos, end)
      # Read end tag.
      after_end_tag = pos + end_tag_len
      if buffer[pos:after_end_tag] != end_tag_bytes or after_end_tag > end:
        raise _DecodeError('Missing group end tag.')
      # Predict that the next tag is another copy of the same repeated field;
      # if so, keep decoding without returning to the caller.
      pos = after_end_tag + tag_len
      if buffer[after_end_tag:pos] == tag_bytes and after_end_tag != end:
        continue
      # Prediction failed.  Return.
      return after_end_tag

  return DecodeRepeatedField
| |
| |
def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a message field.

  The wire value is a varint length followed by the serialized sub-message,
  which is parsed with the sub-message's _InternalParse restricted to exactly
  that byte range. Packed message fields are not supported.
  """
  assert not is_packed

  read_length = _DecodeVarint

  if not is_repeated:

    def DecodeField(buffer, pos, end, message, field_dict):
      submessage = field_dict.get(key)
      if submessage is None:
        submessage = field_dict.setdefault(key, new_default(message))
      # Read length.
      (size, pos) = read_length(buffer, pos)
      body_end = pos + size
      if body_end > end:
        raise _DecodeError('Truncated message.')
      # Read sub-message.
      parsed_to = submessage._InternalParse(buffer, pos, body_end)
      if parsed_to != body_end:
        # The only reason _InternalParse would return early is if it
        # encountered an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')
      return body_end

    return DecodeField

  tag_bytes = encoder.TagBytes(field_number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
  tag_len = len(tag_bytes)

  def DecodeRepeatedField(buffer, pos, end, message, field_dict):
    submessages = field_dict.get(key)
    if submessages is None:
      submessages = field_dict.setdefault(key, new_default(message))
    while True:
      # Read length.
      (size, pos) = read_length(buffer, pos)
      body_end = pos + size
      if body_end > end:
        raise _DecodeError('Truncated message.')
      # Read sub-message.
      parsed_to = submessages.add()._InternalParse(buffer, pos, body_end)
      if parsed_to != body_end:
        # The only reason _InternalParse would return early is if it
        # encountered an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')
      # Predict that the next tag is another copy of the same repeated field;
      # if so, keep decoding without returning to the caller.
      pos = body_end + tag_len
      if buffer[body_end:pos] == tag_bytes and body_end != end:
        continue
      # Prediction failed.  Return.
      return body_end

  return DecodeRepeatedField
| |
| |
| # -------------------------------------------------------------------- |
| |
# Start-group tag of field 1, i.e. the tag that opens a MessageSet item.
MESSAGE_SET_ITEM_TAG = encoder.TagBytes(1, wire_format.WIRETYPE_START_GROUP)


def MessageSetItemDecoder(descriptor):
  """Returns a decoder for a MessageSet item.

  The parameter is the message Descriptor (not used by the decoder itself).

  The message set message looks like this:
    message MessageSet {
      repeated group Item = 1 {
        required int32 type_id = 2;
        required string message = 3;
      }
    }

  The returned decoder is invoked with pos just past the item's start-group
  tag. If the item's type_id matches a known extension, the payload is parsed
  into that extension's message; otherwise the raw item bytes are appended to
  message._unknown_fields.
  """

  type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
  message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
  item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)

  local_ReadTag = ReadTag
  local_DecodeVarint = _DecodeVarint
  local_SkipField = SkipField

  def DecodeItem(buffer, pos, end, message, field_dict):
    """Decode serialized message set to its value and new position.

    Args:
      buffer: memoryview of the serialized bytes.
      pos: int, position in the memory view to start at.
      end: int, end position of serialized data
      message: Message object to store unknown fields in
      field_dict: Map[Descriptor, Any] to store decoded values in.

    Returns:
      int, new position in serialized data.

    Raises:
      message.DecodeError: The item is truncated, lacks type_id or message,
        or contains a mismatched end-group tag.
    """
    message_set_item_start = pos
    # -1 marks "not seen yet" for each of the two required parts.
    type_id = -1
    message_start = -1
    message_end = -1

    # Technically, type_id and message can appear in any order, so we need
    # a little loop here.
    while 1:
      (tag_bytes, pos) = local_ReadTag(buffer, pos)
      if tag_bytes == type_id_tag_bytes:
        (type_id, pos) = local_DecodeVarint(buffer, pos)
      elif tag_bytes == message_tag_bytes:
        # Only remember where the payload sits; it is parsed after the loop,
        # once type_id is known.
        (size, message_start) = local_DecodeVarint(buffer, pos)
        pos = message_end = message_start + size
      elif tag_bytes == item_end_tag_bytes:
        break
      else:
        pos = local_SkipField(buffer, pos, end, tag_bytes)
        if pos == -1:
          raise _DecodeError('Missing group end tag.')

    if pos > end:
      raise _DecodeError('Truncated message.')

    if type_id == -1:
      raise _DecodeError('MessageSet item missing type_id.')
    if message_start == -1:
      raise _DecodeError('MessageSet item missing message.')

    extension = message.Extensions._FindExtensionByNumber(type_id)
    # pylint: disable=protected-access
    if extension is not None:
      value = field_dict.get(extension)
      if value is None:
        message_type = extension.message_type
        if not hasattr(message_type, '_concrete_class'):
          # message_factory is not imported at module level in this file, so
          # bring it into scope here; without this the call below raises
          # NameError. NOTE(review): kept function-local on the assumption
          # that a top-level import could be circular -- confirm.
          # pylint: disable=g-import-not-at-top
          from google.protobuf import message_factory
          message_factory.GetMessageClass(message_type)
        value = field_dict.setdefault(
            extension, message_type._concrete_class())
      if value._InternalParse(buffer, message_start, message_end) != message_end:
        # The only reason _InternalParse would return early is if it
        # encountered an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')
    else:
      # Unknown extension: keep the whole item, from just after its start
      # tag through its end tag, keyed by the item start tag.
      if not message._unknown_fields:
        message._unknown_fields = []
      message._unknown_fields.append(
          (MESSAGE_SET_ITEM_TAG, buffer[message_set_item_start:pos].tobytes()))
    # pylint: enable=protected-access

    return pos

  return DecodeItem
| |
| |
def UnknownMessageSetItemDecoder():
  """Returns a decoder for a Unknown MessageSet item.

  The returned function takes a buffer holding one MessageSet item (starting
  just after the item's start-group tag) and returns a
  (type_id, message_bytes) tuple.
  """

  type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
  message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
  item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)

  def DecodeUnknownItem(buffer):
    """Extract the type id and raw payload from a serialized item.

    Args:
      buffer: memoryview of the serialized item bytes.

    Returns:
      Tuple[int, bytes] of the item's type_id and its message payload.

    Raises:
      message.DecodeError: The item is truncated, lacks type_id or message,
        or contains a mismatched end-group tag.
    """
    pos = 0
    end = len(buffer)
    # -1 marks "not seen yet". type_id must be initialised here as well,
    # otherwise an item without a type_id field would fail with
    # UnboundLocalError instead of the DecodeError raised below.
    type_id = -1
    message_start = -1
    message_end = -1
    # type_id and message can appear in any order, so loop until the end tag.
    while 1:
      (tag_bytes, pos) = ReadTag(buffer, pos)
      if tag_bytes == type_id_tag_bytes:
        (type_id, pos) = _DecodeVarint(buffer, pos)
      elif tag_bytes == message_tag_bytes:
        (size, message_start) = _DecodeVarint(buffer, pos)
        pos = message_end = message_start + size
      elif tag_bytes == item_end_tag_bytes:
        break
      else:
        pos = SkipField(buffer, pos, end, tag_bytes)
        if pos == -1:
          raise _DecodeError('Missing group end tag.')

    if pos > end:
      raise _DecodeError('Truncated message.')

    if type_id == -1:
      raise _DecodeError('MessageSet item missing type_id.')
    if message_start == -1:
      raise _DecodeError('MessageSet item missing message.')

    return (type_id, buffer[message_start:message_end].tobytes())

  return DecodeUnknownItem
| |
| # -------------------------------------------------------------------- |
| |
def MapDecoder(field_descriptor, new_default, is_message_map):
  """Returns a decoder for a map field.

  Each map entry arrives as a length-delimited sub-message with `key` and
  `value` attributes. Entries are parsed one at a time into a scratch
  message and then copied into the map container.

  Args:
    field_descriptor: Descriptor of the map field; also used as the key into
      field_dict.
    new_default: Function taking the message and returning a new map
      container.
    is_message_map: True if the map values are messages, in which case they
      are merged in with CopyFrom instead of being assigned.
  """

  key = field_descriptor
  # Can't read _concrete_class yet; might not be initialized.
  entry_type = field_descriptor.message_type
  read_length = _DecodeVarint
  tag_bytes = encoder.TagBytes(field_descriptor.number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
  tag_len = len(tag_bytes)

  def DecodeMap(buffer, pos, end, message, field_dict):
    # One scratch entry message is reused (cleared) for every entry.
    entry = entry_type._concrete_class()
    mapping = field_dict.get(key)
    if mapping is None:
      mapping = field_dict.setdefault(key, new_default(message))
    while True:
      # Read length.
      (size, pos) = read_length(buffer, pos)
      entry_end = pos + size
      if entry_end > end:
        raise _DecodeError('Truncated message.')
      # Read sub-message.
      entry.Clear()
      if entry._InternalParse(buffer, pos, entry_end) != entry_end:
        # The only reason _InternalParse would return early is if it
        # encountered an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')

      if not is_message_map:
        mapping[entry.key] = entry.value
      else:
        mapping[entry.key].CopyFrom(entry.value)

      # Predict that the next tag is another copy of the same repeated field;
      # if so, keep decoding without returning to the caller.
      pos = entry_end + tag_len
      if buffer[entry_end:pos] == tag_bytes and entry_end != end:
        continue
      # Prediction failed.  Return.
      return entry_end

  return DecodeMap
| |
| # -------------------------------------------------------------------- |
| # Optimization is not as heavy here because calls to SkipField() are rare, |
| # except for handling end-group tags. |
| |
| def _SkipVarint(buffer, pos, end): |
| """Skip a varint value. Returns the new position.""" |
| # Previously ord(buffer[pos]) raised IndexError when pos is out of range. |
| # With this code, ord(b'') raises TypeError. Both are handled in |
| # python_message.py to generate a 'Truncated message' error. |
| while ord(buffer[pos:pos+1].tobytes()) & 0x80: |
| pos += 1 |
| pos += 1 |
| if pos > end: |
| raise _DecodeError('Truncated message.') |
| return pos |
| |
| def _SkipFixed64(buffer, pos, end): |
| """Skip a fixed64 value. Returns the new position.""" |
| |
| pos += 8 |
| if pos > end: |
| raise _DecodeError('Truncated message.') |
| return pos |
| |
| |
| def _DecodeFixed64(buffer, pos): |
| """Decode a fixed64.""" |
| new_pos = pos + 8 |
| return (struct.unpack('<Q', buffer[pos:new_pos])[0], new_pos) |
| |
| |
def _SkipLengthDelimited(buffer, pos, end):
  """Skips a length-delimited value and returns the position after it.

  The value is a varint length followed by that many payload bytes.

  Args:
    buffer: Buffer holding the encoded data.
    pos: Position of the length varint.
    end: Position where the current message ends.

  Raises:
    _DecodeError: The payload would extend past end.
  """
  payload_len, payload_start = _DecodeVarint(buffer, pos)
  payload_end = payload_start + payload_len
  if payload_end <= end:
    return payload_end
  raise _DecodeError('Truncated message.')
| |
| |
def _SkipGroup(buffer, pos, end):
  """Skips a sub-group and returns the position after its end-group tag.

  Fields are read and skipped one at a time until SkipField() reports an
  end-group tag by returning -1.

  Args:
    buffer: Buffer holding the encoded data.
    pos: Position of the first tag inside the group.
    end: Position where the current message ends.
  """
  tag_bytes, pos = ReadTag(buffer, pos)
  skipped_to = SkipField(buffer, pos, end, tag_bytes)
  while skipped_to != -1:
    tag_bytes, pos = ReadTag(buffer, skipped_to)
    skipped_to = SkipField(buffer, pos, end, tag_bytes)
  # pos sits just past the tag for which SkipField() returned -1.
  return pos
| |
| |
def _DecodeUnknownFieldSet(buffer, pos, end_pos=None):
  """Decodes consecutive fields into a new UnknownFieldSet.

  Decoding stops when an end-group tag is read or, if end_pos is given, once
  pos reaches end_pos.

  Args:
    buffer: Buffer holding the encoded data.
    pos: Position of the first tag to read.
    end_pos: Optional position at which to stop; None means read until an
      end-group tag is found.

  Returns:
    An (unknown_field_set, new_pos) tuple.  After an end-group tag, new_pos is
    the position just past that tag.
  """
  fields = containers.UnknownFieldSet()
  while True:
    if end_pos is not None and pos >= end_pos:
      break
    tag_bytes, pos = ReadTag(buffer, pos)
    tag, _ = _DecodeVarint(tag_bytes, 0)
    field_number, wire_type = wire_format.UnpackTag(tag)
    if wire_type == wire_format.WIRETYPE_END_GROUP:
      break
    data, pos = _DecodeUnknownField(buffer, pos, wire_type)
    # pylint: disable=protected-access
    fields._add(field_number, wire_type, data)

  return (fields, pos)
| |
| |
def _DecodeUnknownField(buffer, pos, wire_type):
  """Decodes the value of one unknown field according to its wire type.

  Args:
    buffer: Buffer holding the encoded data; slices must provide tobytes().
    pos: Position of the first byte after the field's tag.
    wire_type: Wire type taken from the field's tag.

  Returns:
    A (data, new_pos) tuple.  data is an integer for varint and fixed-width
    fields, bytes for length-delimited fields, and an UnknownFieldSet for
    groups.  An end-group wire type yields (0, -1).

  Raises:
    _DecodeError: wire_type is not one of the handled wire types.
  """
  if wire_type == wire_format.WIRETYPE_VARINT:
    data, new_pos = _DecodeVarint(buffer, pos)
    return (data, new_pos)

  if wire_type == wire_format.WIRETYPE_FIXED64:
    data, new_pos = _DecodeFixed64(buffer, pos)
    return (data, new_pos)

  if wire_type == wire_format.WIRETYPE_FIXED32:
    data, new_pos = _DecodeFixed32(buffer, pos)
    return (data, new_pos)

  if wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED:
    # Varint length first, then that many bytes copied out of the buffer.
    length, start = _DecodeVarint(buffer, pos)
    stop = start + length
    return (buffer[start:stop].tobytes(), stop)

  if wire_type == wire_format.WIRETYPE_START_GROUP:
    data, new_pos = _DecodeUnknownFieldSet(buffer, pos)
    return (data, new_pos)

  if wire_type == wire_format.WIRETYPE_END_GROUP:
    return (0, -1)

  raise _DecodeError('Wrong wire type in tag.')
| |
| |
| def _EndGroup(buffer, pos, end): |
| """Skipping an END_GROUP tag returns -1 to tell the parent loop to break.""" |
| |
| return -1 |
| |
| |
| def _SkipFixed32(buffer, pos, end): |
| """Skip a fixed32 value. Returns the new position.""" |
| |
| pos += 4 |
| if pos > end: |
| raise _DecodeError('Truncated message.') |
| return pos |
| |
| |
| def _DecodeFixed32(buffer, pos): |
| """Decode a fixed32.""" |
| |
| new_pos = pos + 4 |
| return (struct.unpack('<I', buffer[pos:new_pos])[0], new_pos) |
| |
| |
def _RaiseInvalidWireType(buffer, pos, end):
  """Skipper installed for wire types that cannot be skipped.

  The arguments are ignored; the call always fails.

  Raises:
    _DecodeError: Always.
  """
  raise _DecodeError('Tag had invalid wire type.')
| |
def _FieldSkipper():
  """Builds and returns the SkipField function."""

  # Dispatch table indexed by the wire type extracted from the tag.  The
  # last two slots have no skipper of their own and always raise.
  skippers = (
      _SkipVarint,
      _SkipFixed64,
      _SkipLengthDelimited,
      _SkipGroup,
      _EndGroup,
      _SkipFixed32,
      _RaiseInvalidWireType,
      _RaiseInvalidWireType,
  )

  type_mask = wire_format.TAG_TYPE_MASK

  def SkipField(buffer, pos, end, tag_bytes):
    """Skips a field with the specified tag.

    |pos| should point to the byte immediately after the tag.

    Returns:
        The new position (after the tag value), or -1 if the tag is an end-group
        tag (in which case the calling loop should break).
    """

    # The wire type is always in the first byte since varints are little-endian.
    first_byte = ord(tag_bytes[0:1])
    skip = skippers[first_byte & type_mask]
    return skip(buffer, pos, end)

  return SkipField


# Module-level skipper shared by the decoders in this file.
SkipField = _FieldSkipper()