# blob: b3351a498fa427305703220d6c8bb42a85aeee7d [file] [log] [blame]
# -*- coding: utf-8 -*-
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
#
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd
"""Tests Nextgen Pythonic protobuf APIs."""
import io
import unittest
from google.protobuf import proto
from google.protobuf.internal import encoder
from google.protobuf.internal import test_util
from google.protobuf.internal import testing_refleaks
from google.protobuf.internal import _parameterized
from google.protobuf import unittest_pb2
from google.protobuf import unittest_proto3_arena_pb2
@_parameterized.named_parameters(
    ('_proto2', unittest_pb2),
    ('_proto3', unittest_proto3_arena_pb2),
)
@testing_refleaks.TestCase
class ProtoTest(unittest.TestCase):
  """Exercises proto.serialize/parse and their length-prefixed variants.

  Each test method takes the message module supplied by the
  named_parameters decorator: the proto2 ``unittest_pb2`` module or the
  proto3 ``unittest_proto3_arena_pb2`` module.
  """

  def test_simple_serialize_parse(self, message_module):
    """A fully populated message survives a serialize -> parse round trip."""
    message_type = message_module.TestAllTypes
    original = message_type()
    test_util.SetAllFields(original)
    round_tripped = proto.parse(message_type, proto.serialize(original))
    self.assertEqual(original, round_tripped)

  def test_serialize_parse_length_prefixed_empty(self, message_module):
    """An empty message round-trips through the length-prefixed helpers."""
    empty = message_module.TestAllTypes()
    sink = io.BytesIO()
    proto.serialize_length_prefixed(empty, sink)
    source = io.BytesIO(sink.getvalue())
    self.assertEqual(
        proto.parse_length_prefixed(message_module.TestAllTypes, source),
        empty,
    )

  def test_parse_length_prefixed_truncated(self, message_module):
    """A length prefix larger than the available payload raises ValueError."""
    buffer = io.BytesIO()
    # Announce a 9999-byte message, then supply only a 2-byte one.
    encoder._VarintEncoder()(buffer.write, 9999)
    small = message_module.TestAllTypes(optional_int32=1)
    buffer.write(proto.serialize(small))
    # Rewind so the parser reads from the start of what was just written.
    buffer.seek(0)
    with self.assertRaises(ValueError) as raised:
      proto.parse_length_prefixed(message_module.TestAllTypes, buffer)
    expected_text = (
        'Truncated message or non-buffered input_bytes: '
        'Expected 9999 bytes but only 2 bytes parsed for '
        'TestAllTypes.'
    )
    self.assertEqual(str(raised.exception), expected_text)

  def test_serialize_length_prefixed_fake_io(self, message_module):
    """A sink whose write() reports 0 bytes written triggers TypeError."""

    class FakeBytesIO(io.BytesIO):
      # Reports that nothing was written, simulating a short write.
      def write(self, b: bytes) -> int:
        return 0

    message = message_module.TestAllTypes(optional_int32=123)
    sink = FakeBytesIO()
    with self.assertRaises(TypeError) as raised:
      proto.serialize_length_prefixed(message, sink)
    self.assertIn(
        'Failed to write complete message (wrote: 0, expected: 2)',
        str(raised.exception),
    )
# Golden output of serialize_length_prefixed for three TestAllTypes messages
# with optional_int32 = 0, 1, 2 and optional_string = 'hi'. Each line below is
# one message: a one-byte varint length, then the message body. In the body,
# b'\x08<n>' encodes optional_int32 (tag 0x08: field 1, varint) and b'r\x02hi'
# encodes optional_string (tag 0x72 == 'r': field 14, length-delimited, 2 bytes).
#
# The proto3 stream omits the zero-valued optional_int32 of the first message,
# so that message is only 4 bytes long.
_EXPECTED_PROTO3 = (
    b'\x04' b'r\x02hi'
    b'\x06' b'\x08\x01' b'r\x02hi'
    b'\x06' b'\x08\x02' b'r\x02hi'
)
# The proto2 stream writes the explicitly set optional_int32 = 0, so every
# message is 6 bytes long.
_EXPECTED_PROTO2 = (
    b'\x06' b'\x08\x00' b'r\x02hi'
    b'\x06' b'\x08\x01' b'r\x02hi'
    b'\x06' b'\x08\x02' b'r\x02hi'
)
@_parameterized.named_parameters(
    ('_proto2', unittest_pb2, _EXPECTED_PROTO2),
    ('_proto3', unittest_proto3_arena_pb2, _EXPECTED_PROTO3),
)
@testing_refleaks.TestCase
class LengthPrefixedWithGolden(unittest.TestCase):
  """Checks length-prefixed serialization against golden wire bytes.

  Each test receives a message module and the matching golden byte string,
  which holds TestAllTypes messages with optional_int32 = 0, 1, 2, ... and
  optional_string = 'hi', written back to back.
  """

  # Number of messages encoded in each golden byte string; shared by both
  # tests so the writer and the reader agree on it.
  _NUMBER_OF_MESSAGES = 3

  def test_serialize_length_prefixed(self, message_module, expected):
    """Writing the messages back to back reproduces the golden bytes."""
    out = io.BytesIO()
    for index in range(self._NUMBER_OF_MESSAGES):
      msg = message_module.TestAllTypes(
          optional_int32=index, optional_string='hi'
      )
      proto.serialize_length_prefixed(msg, out)
    self.assertEqual(out.getvalue(), expected)

  def test_parse_length_prefixed(self, message_module, input_bytes):
    """Reading the golden bytes yields exactly the expected messages.

    The loop is bounded by the expected message count instead of running
    until the parser returns None. A stream that ends early therefore fails
    at the first missing message, and leftover data fails the explicit
    end-of-stream check.
    """
    input_io = io.BytesIO(input_bytes)
    for index in range(self._NUMBER_OF_MESSAGES):
      msg = proto.parse_length_prefixed(message_module.TestAllTypes, input_io)
      self.assertIsNotNone(msg, 'stream ended before message %d' % index)
      self.assertEqual(msg.optional_int32, index)
      self.assertEqual(msg.optional_string, 'hi')
    # parse_length_prefixed signals end of input by returning None.
    self.assertIsNone(
        proto.parse_length_prefixed(message_module.TestAllTypes, input_io)
    )
if __name__ == '__main__':
  # Run this module's test cases when it is executed directly as a script.
  unittest.main()