upb: copy the wire decode recursion-depth-checking code to the wire encoder
PiperOrigin-RevId: 524873082
diff --git a/python/message.c b/python/message.c
index 2fe5126..7c876a7 100644
--- a/python/message.c
+++ b/python/message.c
@@ -1528,7 +1528,7 @@
const upb_MiniTable* layout = upb_MessageDef_MiniTable(msgdef);
size_t size = 0;
// Python does not currently have any effective limit on serialization depth.
- int options = UPB_ENCODE_MAXDEPTH(UINT32_MAX);
+ int options = upb_EncodeOptions_MaxDepth(UINT16_MAX);
if (check_required) options |= kUpb_EncodeOption_CheckRequired;
if (deterministic) options |= kUpb_EncodeOption_Deterministic;
char* pb;
diff --git a/upb/message/test.cc b/upb/message/test.cc
index 8da2920..a204d9c 100644
--- a/upb/message/test.cc
+++ b/upb/message/test.cc
@@ -507,12 +507,17 @@
// static void DecodeEncodeArbitrarySchemaAndPayload(
// const upb::fuzz::MiniTableFuzzInput& input, std::string_view proto_payload,
// int decode_options, int encode_options) {
+// // The value of 80 used here is empirical and intended to roughly represent
+// // the tiny 64K stack size used by the test framework. We still see the
+// // occasional stack overflow at 90, so far 80 has worked 100% of the time.
+// decode_options = upb_Decode_LimitDepth(decode_options, 80);
+// encode_options = upb_Encode_LimitDepth(encode_options, 80);
+//
// upb::Arena arena;
// upb_ExtensionRegistry* exts;
// const upb_MiniTable* mini_table =
// upb::fuzz::BuildMiniTable(input, &exts, arena.ptr());
// if (!mini_table) return;
-// decode_options = upb_Decode_LimitDepth(decode_options, 80);
// upb_Message* msg = upb_Message_New(mini_table, arena.ptr());
// upb_Decode(proto_payload.data(), proto_payload.size(), msg, mini_table, exts,
// decode_options, arena.ptr());
diff --git a/upb/wire/encode.h b/upb/wire/encode.h
index 8067fcb..f7e456a 100644
--- a/upb/wire/encode.h
+++ b/upb/wire/encode.h
@@ -48,24 +48,37 @@
* memory during encode. */
kUpb_EncodeOption_Deterministic = 1,
- /* When set, unknown fields are not printed. */
+ // When set, unknown fields are not printed.
kUpb_EncodeOption_SkipUnknown = 2,
- /* When set, the encode will fail if any required fields are missing. */
+ // When set, the encode will fail if any required fields are missing.
kUpb_EncodeOption_CheckRequired = 4,
};
-#define UPB_ENCODE_MAXDEPTH(depth) ((depth) << 16)
-
typedef enum {
kUpb_EncodeStatus_Ok = 0,
- kUpb_EncodeStatus_OutOfMemory = 1, // Arena alloc failed
- kUpb_EncodeStatus_MaxDepthExceeded = 2, // Exceeded UPB_ENCODE_MAXDEPTH
+ kUpb_EncodeStatus_OutOfMemory = 1, // Arena alloc failed
+ kUpb_EncodeStatus_MaxDepthExceeded = 2,
// kUpb_EncodeOption_CheckRequired failed but the parse otherwise succeeded.
kUpb_EncodeStatus_MissingRequired = 3,
} upb_EncodeStatus;
+UPB_INLINE uint32_t upb_EncodeOptions_MaxDepth(uint16_t depth) {
+ return (uint32_t)depth << 16;
+}
+
+UPB_INLINE uint16_t upb_EncodeOptions_GetMaxDepth(uint32_t options) {
+ return options >> 16;
+}
+
+// Enforce an upper bound on recursion depth.
+UPB_INLINE int upb_Encode_LimitDepth(uint32_t encode_options, uint32_t limit) {
+ uint32_t max_depth = upb_EncodeOptions_GetMaxDepth(encode_options);
+ if (max_depth > limit) max_depth = limit;
+ return upb_EncodeOptions_MaxDepth(max_depth) | (encode_options & 0xffff);
+}
+
upb_EncodeStatus upb_Encode(const void* msg, const upb_MiniTable* l,
int options, upb_Arena* arena, char** buf,
size_t* size);