// Protocol Buffers - Google's data interchange format
// Copyright 2023 Google LLC.  All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
//     * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//     * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
//     * Neither the name of Google LLC nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
31#include "python/convert.h"
32
33#include "python/message.h"
34#include "python/protobuf.h"
Eric Salo07fba1d2023-09-29 14:50:56 -070035#include "upb/message/map.h"
Adam Cozzette501ecec2023-09-26 14:36:20 -070036#include "upb/reflection/message.h"
37#include "upb/util/compare.h"
Joshua Haberman1711ebd2023-11-01 10:33:11 -070038#include "utf8_range.h"
Adam Cozzette501ecec2023-09-26 14:36:20 -070039
40// Must be last.
41#include "upb/port/def.inc"
42
43PyObject* PyUpb_UpbToPy(upb_MessageValue val, const upb_FieldDef* f,
44 PyObject* arena) {
45 switch (upb_FieldDef_CType(f)) {
46 case kUpb_CType_Enum:
47 case kUpb_CType_Int32:
48 return PyLong_FromLong(val.int32_val);
49 case kUpb_CType_Int64:
50 return PyLong_FromLongLong(val.int64_val);
51 case kUpb_CType_UInt32:
52 return PyLong_FromSize_t(val.uint32_val);
53 case kUpb_CType_UInt64:
54 return PyLong_FromUnsignedLongLong(val.uint64_val);
55 case kUpb_CType_Float:
56 return PyFloat_FromDouble(val.float_val);
57 case kUpb_CType_Double:
58 return PyFloat_FromDouble(val.double_val);
59 case kUpb_CType_Bool:
60 return PyBool_FromLong(val.bool_val);
61 case kUpb_CType_Bytes:
62 return PyBytes_FromStringAndSize(val.str_val.data, val.str_val.size);
63 case kUpb_CType_String: {
64 PyObject* ret =
65 PyUnicode_DecodeUTF8(val.str_val.data, val.str_val.size, NULL);
66 // If the string can't be decoded in UTF-8, just return a bytes object
67 // that contains the raw bytes. This can't happen if the value was
68 // assigned using the members of the Python message object, but can happen
69 // if the values were parsed from the wire (binary).
70 if (ret == NULL) {
71 PyErr_Clear();
72 ret = PyBytes_FromStringAndSize(val.str_val.data, val.str_val.size);
73 }
74 return ret;
75 }
76 case kUpb_CType_Message:
77 return PyUpb_Message_Get((upb_Message*)val.msg_val,
78 upb_FieldDef_MessageSubDef(f), arena);
79 default:
80 PyErr_Format(PyExc_SystemError,
81 "Getting a value from a field of unknown type %d",
82 upb_FieldDef_CType(f));
83 return NULL;
84 }
85}
86
87static bool PyUpb_GetInt64(PyObject* obj, int64_t* val) {
88 // We require that the value is either an integer or has an __index__
89 // conversion.
90 obj = PyNumber_Index(obj);
91 if (!obj) return false;
92 // If the value is already a Python long, PyLong_AsLongLong() retrieves it.
93 // Otherwise is converts to integer using __int__.
94 *val = PyLong_AsLongLong(obj);
95 bool ok = true;
96 if (PyErr_Occurred()) {
97 assert(PyErr_ExceptionMatches(PyExc_OverflowError));
98 PyErr_Clear();
99 PyErr_Format(PyExc_ValueError, "Value out of range: %S", obj);
100 ok = false;
101 }
102 Py_DECREF(obj);
103 return ok;
104}
105
106static bool PyUpb_GetUint64(PyObject* obj, uint64_t* val) {
107 // We require that the value is either an integer or has an __index__
108 // conversion.
109 obj = PyNumber_Index(obj);
110 if (!obj) return false;
111 *val = PyLong_AsUnsignedLongLong(obj);
112 bool ok = true;
113 if (PyErr_Occurred()) {
114 assert(PyErr_ExceptionMatches(PyExc_OverflowError));
115 PyErr_Clear();
116 PyErr_Format(PyExc_ValueError, "Value out of range: %S", obj);
117 ok = false;
118 }
119 Py_DECREF(obj);
120 return ok;
121}
122
123static bool PyUpb_GetInt32(PyObject* obj, int32_t* val) {
124 int64_t i64;
125 if (!PyUpb_GetInt64(obj, &i64)) return false;
126 if (i64 < INT32_MIN || i64 > INT32_MAX) {
127 PyErr_Format(PyExc_ValueError, "Value out of range: %S", obj);
128 return false;
129 }
130 *val = i64;
131 return true;
132}
133
134static bool PyUpb_GetUint32(PyObject* obj, uint32_t* val) {
135 uint64_t u64;
136 if (!PyUpb_GetUint64(obj, &u64)) return false;
137 if (u64 > UINT32_MAX) {
138 PyErr_Format(PyExc_ValueError, "Value out of range: %S", obj);
139 return false;
140 }
141 *val = u64;
142 return true;
143}
144
145// If `arena` is specified, copies the string data into the given arena.
146// Otherwise aliases the given data.
147static upb_MessageValue PyUpb_MaybeCopyString(const char* ptr, size_t size,
148 upb_Arena* arena) {
149 upb_MessageValue ret;
150 ret.str_val.size = size;
151 if (arena) {
152 char* buf = upb_Arena_Malloc(arena, size);
153 memcpy(buf, ptr, size);
154 ret.str_val.data = buf;
155 } else {
156 ret.str_val.data = ptr;
157 }
158 return ret;
159}
160
161const char* upb_FieldDef_TypeString(const upb_FieldDef* f) {
162 switch (upb_FieldDef_CType(f)) {
163 case kUpb_CType_Double:
164 return "double";
165 case kUpb_CType_Float:
166 return "float";
167 case kUpb_CType_Int64:
168 return "int64";
169 case kUpb_CType_Int32:
170 return "int32";
171 case kUpb_CType_UInt64:
172 return "uint64";
173 case kUpb_CType_UInt32:
174 return "uint32";
175 case kUpb_CType_Enum:
176 return "enum";
177 case kUpb_CType_Bool:
178 return "bool";
179 case kUpb_CType_String:
180 return "string";
181 case kUpb_CType_Bytes:
182 return "bytes";
183 case kUpb_CType_Message:
184 return "message";
185 }
186 UPB_UNREACHABLE();
187}
188
189static bool PyUpb_PyToUpbEnum(PyObject* obj, const upb_EnumDef* e,
190 upb_MessageValue* val) {
191 if (PyUnicode_Check(obj)) {
192 Py_ssize_t size;
193 const char* name = PyUnicode_AsUTF8AndSize(obj, &size);
194 const upb_EnumValueDef* ev =
195 upb_EnumDef_FindValueByNameWithSize(e, name, size);
196 if (!ev) {
197 PyErr_Format(PyExc_ValueError, "unknown enum label \"%s\"", name);
198 return false;
199 }
200 val->int32_val = upb_EnumValueDef_Number(ev);
201 return true;
202 } else {
203 int32_t i32;
204 if (!PyUpb_GetInt32(obj, &i32)) return false;
205 if (upb_FileDef_Syntax(upb_EnumDef_File(e)) == kUpb_Syntax_Proto2 &&
206 !upb_EnumDef_CheckNumber(e, i32)) {
207 PyErr_Format(PyExc_ValueError, "invalid enumerator %d", (int)i32);
208 return false;
209 }
210 val->int32_val = i32;
211 return true;
212 }
213}
214
215bool PyUpb_IsNumpyNdarray(PyObject* obj, const upb_FieldDef* f) {
216 PyObject* type_name_obj =
217 PyObject_GetAttrString((PyObject*)Py_TYPE(obj), "__name__");
218 bool is_ndarray = false;
219 if (!strcmp(PyUpb_GetStrData(type_name_obj), "ndarray")) {
220 PyErr_Format(PyExc_TypeError,
221 "%S has type ndarray, but expected one of: %s", obj,
222 upb_FieldDef_TypeString(f));
223 is_ndarray = true;
224 }
225 Py_DECREF(type_name_obj);
226 return is_ndarray;
227}
228
229bool PyUpb_PyToUpb(PyObject* obj, const upb_FieldDef* f, upb_MessageValue* val,
230 upb_Arena* arena) {
231 switch (upb_FieldDef_CType(f)) {
232 case kUpb_CType_Enum:
233 return PyUpb_PyToUpbEnum(obj, upb_FieldDef_EnumSubDef(f), val);
234 case kUpb_CType_Int32:
235 return PyUpb_GetInt32(obj, &val->int32_val);
236 case kUpb_CType_Int64:
237 return PyUpb_GetInt64(obj, &val->int64_val);
238 case kUpb_CType_UInt32:
239 return PyUpb_GetUint32(obj, &val->uint32_val);
240 case kUpb_CType_UInt64:
241 return PyUpb_GetUint64(obj, &val->uint64_val);
242 case kUpb_CType_Float:
243 if (PyUpb_IsNumpyNdarray(obj, f)) return false;
244 val->float_val = PyFloat_AsDouble(obj);
245 return !PyErr_Occurred();
246 case kUpb_CType_Double:
247 if (PyUpb_IsNumpyNdarray(obj, f)) return false;
248 val->double_val = PyFloat_AsDouble(obj);
249 return !PyErr_Occurred();
250 case kUpb_CType_Bool:
251 if (PyUpb_IsNumpyNdarray(obj, f)) return false;
252 val->bool_val = PyLong_AsLong(obj);
253 return !PyErr_Occurred();
254 case kUpb_CType_Bytes: {
255 char* ptr;
256 Py_ssize_t size;
257 if (PyBytes_AsStringAndSize(obj, &ptr, &size) < 0) return false;
258 *val = PyUpb_MaybeCopyString(ptr, size, arena);
259 return true;
260 }
261 case kUpb_CType_String: {
262 Py_ssize_t size;
Adam Cozzette501ecec2023-09-26 14:36:20 -0700263 if (PyBytes_Check(obj)) {
Joshua Haberman1711ebd2023-11-01 10:33:11 -0700264 // Use the object's bytes if they are valid UTF-8.
265 char* ptr;
266 if (PyBytes_AsStringAndSize(obj, &ptr, &size) < 0) return false;
267 if (utf8_range2((const unsigned char*)ptr, size) != 0) {
268 // Invalid UTF-8. Try to convert the message to a Python Unicode
269 // object, even though we know this will fail, just to get the
270 // idiomatic Python error message.
271 obj = PyUnicode_FromEncodedObject(obj, "utf-8", NULL);
272 assert(!obj);
273 return false;
274 }
275 *val = PyUpb_MaybeCopyString(ptr, size, arena);
276 return true;
277 } else {
278 const char* ptr;
279 ptr = PyUnicode_AsUTF8AndSize(obj, &size);
280 if (PyErr_Occurred()) return false;
281 *val = PyUpb_MaybeCopyString(ptr, size, arena);
282 return true;
Adam Cozzette501ecec2023-09-26 14:36:20 -0700283 }
Adam Cozzette501ecec2023-09-26 14:36:20 -0700284 }
285 case kUpb_CType_Message:
286 PyErr_Format(PyExc_ValueError, "Message objects may not be assigned");
287 return false;
288 default:
289 PyErr_Format(PyExc_SystemError,
290 "Getting a value from a field of unknown type %d",
291 upb_FieldDef_CType(f));
292 return false;
293 }
294}
295
296bool upb_Message_IsEqual(const upb_Message* msg1, const upb_Message* msg2,
297 const upb_MessageDef* m);
298
299// -----------------------------------------------------------------------------
300// Equal
301// -----------------------------------------------------------------------------
302
303bool PyUpb_ValueEq(upb_MessageValue val1, upb_MessageValue val2,
304 const upb_FieldDef* f) {
305 switch (upb_FieldDef_CType(f)) {
306 case kUpb_CType_Bool:
307 return val1.bool_val == val2.bool_val;
308 case kUpb_CType_Int32:
309 case kUpb_CType_UInt32:
310 case kUpb_CType_Enum:
311 return val1.int32_val == val2.int32_val;
312 case kUpb_CType_Int64:
313 case kUpb_CType_UInt64:
314 return val1.int64_val == val2.int64_val;
315 case kUpb_CType_Float:
316 return val1.float_val == val2.float_val;
317 case kUpb_CType_Double:
318 return val1.double_val == val2.double_val;
319 case kUpb_CType_String:
320 case kUpb_CType_Bytes:
321 return val1.str_val.size == val2.str_val.size &&
322 memcmp(val1.str_val.data, val2.str_val.data, val1.str_val.size) ==
323 0;
324 case kUpb_CType_Message:
325 return upb_Message_IsEqual(val1.msg_val, val2.msg_val,
326 upb_FieldDef_MessageSubDef(f));
327 default:
328 return false;
329 }
330}
331
332bool PyUpb_Map_IsEqual(const upb_Map* map1, const upb_Map* map2,
333 const upb_FieldDef* f) {
334 assert(upb_FieldDef_IsMap(f));
335 if (map1 == map2) return true;
336
337 size_t size1 = map1 ? upb_Map_Size(map1) : 0;
338 size_t size2 = map2 ? upb_Map_Size(map2) : 0;
339 if (size1 != size2) return false;
340 if (size1 == 0) return true;
341
342 const upb_MessageDef* entry_m = upb_FieldDef_MessageSubDef(f);
343 const upb_FieldDef* val_f = upb_MessageDef_Field(entry_m, 1);
344 size_t iter = kUpb_Map_Begin;
345
346 upb_MessageValue key, val1;
347 while (upb_Map_Next(map1, &key, &val1, &iter)) {
348 upb_MessageValue val2;
349 if (!upb_Map_Get(map2, key, &val2)) return false;
350 if (!PyUpb_ValueEq(val1, val2, val_f)) return false;
351 }
352
353 return true;
354}
355
356static bool PyUpb_ArrayElem_IsEqual(const upb_Array* arr1,
357 const upb_Array* arr2, size_t i,
358 const upb_FieldDef* f) {
359 assert(i < upb_Array_Size(arr1));
360 assert(i < upb_Array_Size(arr2));
361 upb_MessageValue val1 = upb_Array_Get(arr1, i);
362 upb_MessageValue val2 = upb_Array_Get(arr2, i);
363 return PyUpb_ValueEq(val1, val2, f);
364}
365
366bool PyUpb_Array_IsEqual(const upb_Array* arr1, const upb_Array* arr2,
367 const upb_FieldDef* f) {
368 assert(upb_FieldDef_IsRepeated(f) && !upb_FieldDef_IsMap(f));
369 if (arr1 == arr2) return true;
370
371 size_t n1 = arr1 ? upb_Array_Size(arr1) : 0;
372 size_t n2 = arr2 ? upb_Array_Size(arr2) : 0;
373 if (n1 != n2) return false;
374
375 // Half the length rounded down. Important: the empty list rounds to 0.
376 size_t half = n1 / 2;
377
378 // Search from the ends-in. We expect differences to more quickly manifest
379 // at the ends than in the middle. If the length is odd we will miss the
380 // middle element.
381 for (size_t i = 0; i < half; i++) {
382 if (!PyUpb_ArrayElem_IsEqual(arr1, arr2, i, f)) return false;
383 if (!PyUpb_ArrayElem_IsEqual(arr1, arr2, n1 - 1 - i, f)) return false;
384 }
385
386 // For an odd-lengthed list, pick up the middle element.
387 if (n1 & 1) {
388 if (!PyUpb_ArrayElem_IsEqual(arr1, arr2, half, f)) return false;
389 }
390
391 return true;
392}
393
394bool upb_Message_IsEqual(const upb_Message* msg1, const upb_Message* msg2,
395 const upb_MessageDef* m) {
396 if (msg1 == msg2) return true;
397 if (upb_Message_ExtensionCount(msg1) != upb_Message_ExtensionCount(msg2))
398 return false;
399
400 // Compare messages field-by-field. This is slightly tricky, because while
401 // we can iterate over normal fields in a predictable order, the extension
402 // order is unpredictable and may be different between msg1 and msg2.
403 // So we use the following strategy:
404 // 1. Iterate over all msg1 fields (including extensions).
405 // 2. For non-extension fields, we find the corresponding field by simply
406 // using upb_Message_Next(msg2). If the two messages have the same set
407 // of fields, this will yield the same field.
408 // 3. For extension fields, we have to actually search for the corresponding
409 // field, which we do with upb_Message_GetFieldByDef(msg2, ext_f1).
410 // 4. Once iteration over msg1 is complete, we call upb_Message_Next(msg2)
411 // one
412 // final time to verify that we have visited all of msg2's regular fields
413 // (we pass NULL for ext_dict so that iteration will *not* return
414 // extensions).
415 //
416 // We don't need to visit all of msg2's extensions, because we verified up
417 // front that both messages have the same number of extensions.
418 const upb_DefPool* symtab = upb_FileDef_Pool(upb_MessageDef_File(m));
419 const upb_FieldDef *f1, *f2;
420 upb_MessageValue val1, val2;
421 size_t iter1 = kUpb_Message_Begin;
422 size_t iter2 = kUpb_Message_Begin;
423 while (upb_Message_Next(msg1, m, symtab, &f1, &val1, &iter1)) {
424 if (upb_FieldDef_IsExtension(f1)) {
425 val2 = upb_Message_GetFieldByDef(msg2, f1);
426 } else {
427 if (!upb_Message_Next(msg2, m, NULL, &f2, &val2, &iter2) || f1 != f2) {
428 return false;
429 }
430 }
431
432 if (upb_FieldDef_IsMap(f1)) {
433 if (!PyUpb_Map_IsEqual(val1.map_val, val2.map_val, f1)) return false;
434 } else if (upb_FieldDef_IsRepeated(f1)) {
435 if (!PyUpb_Array_IsEqual(val1.array_val, val2.array_val, f1)) {
436 return false;
437 }
438 } else {
439 if (!PyUpb_ValueEq(val1, val2, f1)) return false;
440 }
441 }
442
443 if (upb_Message_Next(msg2, m, NULL, &f2, &val2, &iter2)) return false;
444
445 size_t usize1, usize2;
446 const char* uf1 = upb_Message_GetUnknown(msg1, &usize1);
447 const char* uf2 = upb_Message_GetUnknown(msg2, &usize2);
448 // 100 is arbitrary, we're trying to prevent stack overflow but it's not
449 // obvious how deep we should allow here.
450 return upb_Message_UnknownFieldsAreEqual(uf1, usize1, uf2, usize2, 100) ==
451 kUpb_UnknownCompareResult_Equal;
452}
453
454#include "upb/port/undef.inc"