Add an indirection to sub-messages pointers to allow for static tree shaking.

PiperOrigin-RevId: 640369522
diff --git a/upb/mini_descriptor/decode.c b/upb/mini_descriptor/decode.c
index d355112..1094fc0 100644
--- a/upb/mini_descriptor/decode.c
+++ b/upb/mini_descriptor/decode.c
@@ -27,6 +27,7 @@
 #include "upb/mini_table/field.h"
 #include "upb/mini_table/internal/field.h"
 #include "upb/mini_table/internal/message.h"
+#include "upb/mini_table/internal/sub.h"
 #include "upb/mini_table/message.h"
 #include "upb/mini_table/sub.h"
 
@@ -407,11 +408,15 @@
                                        upb_SubCounts sub_counts) {
   uint32_t total_count = sub_counts.submsg_count + sub_counts.subenum_count;
   size_t subs_bytes = sizeof(*d->table->UPB_PRIVATE(subs)) * total_count;
-  upb_MiniTableSub* subs = upb_Arena_Malloc(d->arena, subs_bytes);
+  size_t ptrs_bytes = sizeof(upb_MiniTable*) * sub_counts.submsg_count;
+  upb_MiniTableSubInternal* subs = upb_Arena_Malloc(d->arena, subs_bytes);
+  const upb_MiniTable** subs_ptrs = upb_Arena_Malloc(d->arena, ptrs_bytes);
   upb_MdDecoder_CheckOutOfMemory(&d->base, subs);
+  upb_MdDecoder_CheckOutOfMemory(&d->base, subs_ptrs);
   uint32_t i = 0;
   for (; i < sub_counts.submsg_count; i++) {
-    subs[i].UPB_PRIVATE(submsg) = UPB_PRIVATE(_upb_MiniTable_Empty)();
+    subs_ptrs[i] = UPB_PRIVATE(_upb_MiniTable_Empty)();
+    subs[i].UPB_PRIVATE(submsg) = &subs_ptrs[i];
   }
   if (sub_counts.subenum_count) {
     upb_MiniTableField* f = d->fields;
diff --git a/upb/mini_descriptor/link.c b/upb/mini_descriptor/link.c
index 093150b..5dec59e 100644
--- a/upb/mini_descriptor/link.c
+++ b/upb/mini_descriptor/link.c
@@ -9,10 +9,14 @@
 
 #include <stddef.h>
 #include <stdint.h>
+#include <string.h>
 
 #include "upb/base/descriptor_constants.h"
 #include "upb/mini_table/enum.h"
 #include "upb/mini_table/field.h"
+#include "upb/mini_table/internal/field.h"
+#include "upb/mini_table/internal/message.h"
+#include "upb/mini_table/internal/sub.h"
 #include "upb/mini_table/message.h"
 #include "upb/mini_table/sub.h"
 
@@ -51,11 +55,11 @@
   }
 
   int idx = field->UPB_PRIVATE(submsg_index);
-  upb_MiniTableSub* table_subs = (void*)table->UPB_PRIVATE(subs);
+  upb_MiniTableSubInternal* table_subs = (void*)table->UPB_PRIVATE(subs);
   // TODO: Add this assert back once YouTube is updated to not call
   // this function repeatedly.
   // UPB_ASSERT(UPB_PRIVATE(_upb_MiniTable_IsEmpty)(table_sub->submsg));
-  table_subs[idx] = upb_MiniTableSub_FromMessage(sub);
+  memcpy((void*)table_subs[idx].UPB_PRIVATE(submsg), &sub, sizeof(void*));
   return true;
 }
 
diff --git a/upb/mini_table/internal/message.h b/upb/mini_table/internal/message.h
index d5b1ae4..2c618ce 100644
--- a/upb/mini_table/internal/message.h
+++ b/upb/mini_table/internal/message.h
@@ -46,7 +46,7 @@
 
 // LINT.IfChange(minitable_struct_definition)
 struct upb_MiniTable {
-  const union upb_MiniTableSub* UPB_PRIVATE(subs);
+  const upb_MiniTableSubInternal* UPB_PRIVATE(subs);
   const struct upb_MiniTableField* UPB_ONLYBITS(fields);
 
   // Must be aligned to sizeof(void*). Doesn't include internal members like
@@ -99,9 +99,10 @@
   return &m->UPB_ONLYBITS(fields)[i];
 }
 
-UPB_INLINE const union upb_MiniTableSub UPB_PRIVATE(
-    _upb_MiniTable_GetSubByIndex)(const struct upb_MiniTable* m, uint32_t i) {
-  return m->UPB_PRIVATE(subs)[i];
+UPB_INLINE const struct upb_MiniTable* UPB_PRIVATE(
+    _upb_MiniTable_GetSubTableByIndex)(const struct upb_MiniTable* m,
+                                       uint32_t i) {
+  return *m->UPB_PRIVATE(subs)[i].UPB_PRIVATE(submsg);
 }
 
 UPB_API_INLINE const struct upb_MiniTable* upb_MiniTable_SubMessage(
@@ -109,7 +110,8 @@
   if (upb_MiniTableField_CType(f) != kUpb_CType_Message) {
     return NULL;
   }
-  return m->UPB_PRIVATE(subs)[f->UPB_PRIVATE(submsg_index)].UPB_PRIVATE(submsg);
+  return UPB_PRIVATE(_upb_MiniTable_GetSubTableByIndex)(
+      m, f->UPB_PRIVATE(submsg_index));
 }
 
 UPB_API_INLINE const struct upb_MiniTable* upb_MiniTable_GetSubMessageTable(
diff --git a/upb/mini_table/internal/sub.h b/upb/mini_table/internal/sub.h
index 967b557..4c21569 100644
--- a/upb/mini_table/internal/sub.h
+++ b/upb/mini_table/internal/sub.h
@@ -11,6 +11,11 @@
 // Must be last.
 #include "upb/port/def.inc"
 
+typedef union {
+  const struct upb_MiniTable* const* UPB_PRIVATE(submsg);
+  const struct upb_MiniTableEnum* UPB_PRIVATE(subenum);
+} upb_MiniTableSubInternal;
+
 union upb_MiniTableSub {
   const struct upb_MiniTable* UPB_PRIVATE(submsg);
   const struct upb_MiniTableEnum* UPB_PRIVATE(subenum);
diff --git a/upb/wire/decode.c b/upb/wire/decode.c
index 59602ae..2cb1a44 100644
--- a/upb/wire/decode.c
+++ b/upb/wire/decode.c
@@ -35,7 +35,7 @@
 #include "upb/mini_table/field.h"
 #include "upb/mini_table/internal/field.h"
 #include "upb/mini_table/internal/message.h"
-#include "upb/mini_table/internal/size_log2.h"
+#include "upb/mini_table/internal/sub.h"
 #include "upb/mini_table/message.h"
 #include "upb/mini_table/sub.h"
 #include "upb/port/atomic.h"
@@ -97,15 +97,15 @@
 // Returns the MiniTable corresponding to a given MiniTableField
 // from an array of MiniTableSubs.
 static const upb_MiniTable* _upb_MiniTableSubs_MessageByField(
-    const upb_MiniTableSub* subs, const upb_MiniTableField* field) {
-  return upb_MiniTableSub_Message(subs[field->UPB_PRIVATE(submsg_index)]);
+    const upb_MiniTableSubInternal* subs, const upb_MiniTableField* field) {
+  return *subs[field->UPB_PRIVATE(submsg_index)].UPB_PRIVATE(submsg);
 }
 
 // Returns the MiniTableEnum corresponding to a given MiniTableField
 // from an array of MiniTableSub.
 static const upb_MiniTableEnum* _upb_MiniTableSubs_EnumByField(
-    const upb_MiniTableSub* subs, const upb_MiniTableField* field) {
-  return upb_MiniTableSub_Enum(subs[field->UPB_PRIVATE(submsg_index)]);
+    const upb_MiniTableSubInternal* subs, const upb_MiniTableField* field) {
+  return subs[field->UPB_PRIVATE(submsg_index)].UPB_PRIVATE(subenum);
 }
 
 static const char* _upb_Decoder_DecodeMessage(upb_Decoder* d, const char* ptr,
@@ -240,11 +240,10 @@
   }
 }
 
-static upb_Message* _upb_Decoder_NewSubMessage(upb_Decoder* d,
-                                               const upb_MiniTableSub* subs,
-                                               const upb_MiniTableField* field,
-                                               upb_TaggedMessagePtr* target) {
-  const upb_MiniTable* subl = _upb_MiniTableSubs_MessageByField(subs, field);
+static upb_Message* _upb_Decoder_NewSubMessage2(upb_Decoder* d,
+                                                const upb_MiniTable* subl,
+                                                const upb_MiniTableField* field,
+                                                upb_TaggedMessagePtr* target) {
   UPB_ASSERT(subl);
   upb_Message* msg = _upb_Message_New(subl, &d->arena);
   if (!msg) _upb_Decoder_ErrorJmp(d, kUpb_DecodeStatus_OutOfMemory);
@@ -265,8 +264,15 @@
   return msg;
 }
 
+static upb_Message* _upb_Decoder_NewSubMessage(
+    upb_Decoder* d, const upb_MiniTableSubInternal* subs,
+    const upb_MiniTableField* field, upb_TaggedMessagePtr* target) {
+  const upb_MiniTable* subl = _upb_MiniTableSubs_MessageByField(subs, field);
+  return _upb_Decoder_NewSubMessage2(d, subl, field, target);
+}
+
 static upb_Message* _upb_Decoder_ReuseSubMessage(
-    upb_Decoder* d, const upb_MiniTableSub* subs,
+    upb_Decoder* d, const upb_MiniTableSubInternal* subs,
     const upb_MiniTableField* field, upb_TaggedMessagePtr* target) {
   upb_TaggedMessagePtr tagged = *target;
   const upb_MiniTable* subl = _upb_MiniTableSubs_MessageByField(subs, field);
@@ -319,7 +325,7 @@
 UPB_FORCEINLINE
 const char* _upb_Decoder_DecodeSubMessage(upb_Decoder* d, const char* ptr,
                                           upb_Message* submsg,
-                                          const upb_MiniTableSub* subs,
+                                          const upb_MiniTableSubInternal* subs,
                                           const upb_MiniTableField* field,
                                           int size) {
   int saved_delta = upb_EpsCopyInputStream_PushLimit(&d->input, ptr, size);
@@ -352,7 +358,7 @@
 UPB_FORCEINLINE
 const char* _upb_Decoder_DecodeKnownGroup(upb_Decoder* d, const char* ptr,
                                           upb_Message* submsg,
-                                          const upb_MiniTableSub* subs,
+                                          const upb_MiniTableSubInternal* subs,
                                           const upb_MiniTableField* field) {
   const upb_MiniTable* subl = _upb_MiniTableSubs_MessageByField(subs, field);
   UPB_ASSERT(subl);
@@ -403,12 +409,10 @@
 }
 
 UPB_NOINLINE
-static const char* _upb_Decoder_DecodeEnumArray(upb_Decoder* d, const char* ptr,
-                                                upb_Message* msg,
-                                                upb_Array* arr,
-                                                const upb_MiniTableSub* subs,
-                                                const upb_MiniTableField* field,
-                                                wireval* val) {
+static const char* _upb_Decoder_DecodeEnumArray(
+    upb_Decoder* d, const char* ptr, upb_Message* msg, upb_Array* arr,
+    const upb_MiniTableSubInternal* subs, const upb_MiniTableField* field,
+    wireval* val) {
   const upb_MiniTableEnum* e = _upb_MiniTableSubs_EnumByField(subs, field);
   if (!_upb_Decoder_CheckEnum(d, ptr, msg, e, field, val)) return ptr;
   void* mem = UPB_PTR_AT(upb_Array_MutableDataPtr(arr),
@@ -484,7 +488,7 @@
 UPB_NOINLINE
 static const char* _upb_Decoder_DecodeEnumPacked(
     upb_Decoder* d, const char* ptr, upb_Message* msg, upb_Array* arr,
-    const upb_MiniTableSub* subs, const upb_MiniTableField* field,
+    const upb_MiniTableSubInternal* subs, const upb_MiniTableField* field,
     wireval* val) {
   const upb_MiniTableEnum* e = _upb_MiniTableSubs_EnumByField(subs, field);
   int saved_limit = upb_EpsCopyInputStream_PushLimit(&d->input, ptr, val->size);
@@ -518,11 +522,10 @@
   return ret;
 }
 
-static const char* _upb_Decoder_DecodeToArray(upb_Decoder* d, const char* ptr,
-                                              upb_Message* msg,
-                                              const upb_MiniTableSub* subs,
-                                              const upb_MiniTableField* field,
-                                              wireval* val, int op) {
+static const char* _upb_Decoder_DecodeToArray(
+    upb_Decoder* d, const char* ptr, upb_Message* msg,
+    const upb_MiniTableSubInternal* subs, const upb_MiniTableField* field,
+    wireval* val, int op) {
   upb_Array** arrp = UPB_PTR_AT(msg, field->UPB_PRIVATE(offset), void);
   upb_Array* arr = *arrp;
   void* mem;
@@ -623,11 +626,10 @@
   return ret;
 }
 
-static const char* _upb_Decoder_DecodeToMap(upb_Decoder* d, const char* ptr,
-                                            upb_Message* msg,
-                                            const upb_MiniTableSub* subs,
-                                            const upb_MiniTableField* field,
-                                            wireval* val) {
+static const char* _upb_Decoder_DecodeToMap(
+    upb_Decoder* d, const char* ptr, upb_Message* msg,
+    const upb_MiniTableSubInternal* subs, const upb_MiniTableField* field,
+    wireval* val) {
   upb_Map** map_p = UPB_PTR_AT(msg, field->UPB_PRIVATE(offset), upb_Map*);
   upb_Map* map = *map_p;
   upb_MapEntry ent;
@@ -688,8 +690,8 @@
 
 static const char* _upb_Decoder_DecodeToSubMessage(
     upb_Decoder* d, const char* ptr, upb_Message* msg,
-    const upb_MiniTableSub* subs, const upb_MiniTableField* field, wireval* val,
-    int op) {
+    const upb_MiniTableSubInternal* subs, const upb_MiniTableField* field,
+    wireval* val, int op) {
   void* mem = UPB_PTR_AT(msg, field->UPB_PRIVATE(offset), void);
   int type = field->UPB_PRIVATE(descriptortype);
 
@@ -819,9 +821,9 @@
   if (UPB_UNLIKELY(!ext)) {
     _upb_Decoder_ErrorJmp(d, kUpb_DecodeStatus_OutOfMemory);
   }
-  upb_Message* submsg = _upb_Decoder_NewSubMessage(
-      d, &ext->ext->UPB_PRIVATE(sub), &ext->ext->UPB_PRIVATE(field),
-      (upb_TaggedMessagePtr*)&ext->data);
+  upb_Message* submsg = _upb_Decoder_NewSubMessage2(
+      d, ext->ext->UPB_PRIVATE(sub).UPB_PRIVATE(submsg),
+      &ext->ext->UPB_PRIVATE(field), (upb_TaggedMessagePtr*)&ext->data);
   upb_DecodeStatus status = upb_Decode(
       data, size, submsg, upb_MiniTableExtension_GetSubMessage(item_mt),
       d->extreg, d->options, &d->arena);
@@ -1022,8 +1024,9 @@
     // unlinked.
     do {
       UPB_ASSERT(upb_MiniTableField_CType(oneof) == kUpb_CType_Message);
-      const upb_MiniTableSub* oneof_sub =
-          &mt->UPB_PRIVATE(subs)[oneof->UPB_PRIVATE(submsg_index)];
+      const upb_MiniTable* oneof_sub =
+          *mt->UPB_PRIVATE(subs)[oneof->UPB_PRIVATE(submsg_index)].UPB_PRIVATE(
+              submsg);
       UPB_ASSERT(!oneof_sub);
     } while (upb_MiniTable_NextOneofField(mt, &oneof));
   }
@@ -1161,8 +1164,9 @@
                                           const upb_MiniTable* layout,
                                           const upb_MiniTableField* field,
                                           int op, wireval* val) {
-  const upb_MiniTableSub* subs = layout->UPB_PRIVATE(subs);
+  const upb_MiniTableSubInternal* subs = layout->UPB_PRIVATE(subs);
   uint8_t mode = field->UPB_PRIVATE(mode);
+  upb_MiniTableSubInternal ext_sub;
 
   if (UPB_UNLIKELY(mode & kUpb_LabelFlags_IsExtension)) {
     const upb_MiniTableExtension* ext_layout =
@@ -1174,7 +1178,14 @@
     }
     d->unknown_msg = msg;
     msg = (upb_Message*)&ext->data;
-    subs = &ext->ext->UPB_PRIVATE(sub);
+    if (upb_MiniTableField_IsSubMessage(&ext->ext->UPB_PRIVATE(field))) {
+      ext_sub.UPB_PRIVATE(submsg) =
+          &ext->ext->UPB_PRIVATE(sub).UPB_PRIVATE(submsg);
+    } else {
+      ext_sub.UPB_PRIVATE(subenum) =
+          ext->ext->UPB_PRIVATE(sub).UPB_PRIVATE(subenum);
+    }
+    subs = &ext_sub;
   }
 
   switch (mode & kUpb_FieldMode_Mask) {
diff --git a/upb/wire/encode.c b/upb/wire/encode.c
index 0b6d7c3..5764199 100644
--- a/upb/wire/encode.c
+++ b/upb/wire/encode.c
@@ -35,14 +35,21 @@
 #include "upb/mini_table/field.h"
 #include "upb/mini_table/internal/field.h"
 #include "upb/mini_table/internal/message.h"
+#include "upb/mini_table/internal/sub.h"
 #include "upb/mini_table/message.h"
-#include "upb/mini_table/sub.h"
 #include "upb/wire/internal/constants.h"
 #include "upb/wire/types.h"
 
 // Must be last.
 #include "upb/port/def.inc"
 
+// Returns the MiniTable corresponding to a given MiniTableField
+// from an array of MiniTableSubs.
+static const upb_MiniTable* _upb_Encoder_GetSubMiniTable(
+    const upb_MiniTableSubInternal* subs, const upb_MiniTableField* field) {
+  return *subs[field->UPB_PRIVATE(submsg_index)].UPB_PRIVATE(submsg);
+}
+
 #define UPB_PB_VARINT_MAX_LEN 10
 
 UPB_NOINLINE
@@ -224,7 +231,7 @@
 }
 
 static void encode_scalar(upb_encstate* e, const void* _field_mem,
-                          const upb_MiniTableSub* subs,
+                          const upb_MiniTableSubInternal* subs,
                           const upb_MiniTableField* f) {
   const char* field_mem = _field_mem;
   int wire_type;
@@ -273,8 +280,7 @@
     case kUpb_FieldType_Group: {
       size_t size;
       upb_TaggedMessagePtr submsg = *(upb_TaggedMessagePtr*)field_mem;
-      const upb_MiniTable* subm =
-          upb_MiniTableSub_Message(subs[f->UPB_PRIVATE(submsg_index)]);
+      const upb_MiniTable* subm = _upb_Encoder_GetSubMiniTable(subs, f);
       if (submsg == 0) {
         return;
       }
@@ -288,8 +294,7 @@
     case kUpb_FieldType_Message: {
       size_t size;
       upb_TaggedMessagePtr submsg = *(upb_TaggedMessagePtr*)field_mem;
-      const upb_MiniTable* subm =
-          upb_MiniTableSub_Message(subs[f->UPB_PRIVATE(submsg_index)]);
+      const upb_MiniTable* subm = _upb_Encoder_GetSubMiniTable(subs, f);
       if (submsg == 0) {
         return;
       }
@@ -309,7 +314,7 @@
 }
 
 static void encode_array(upb_encstate* e, const upb_Message* msg,
-                         const upb_MiniTableSub* subs,
+                         const upb_MiniTableSubInternal* subs,
                          const upb_MiniTableField* f) {
   const upb_Array* arr = *UPB_PTR_AT(msg, f->UPB_PRIVATE(offset), upb_Array*);
   bool packed = upb_MiniTableField_IsPacked(f);
@@ -379,8 +384,7 @@
     case kUpb_FieldType_Group: {
       const upb_TaggedMessagePtr* start = upb_Array_DataPtr(arr);
       const upb_TaggedMessagePtr* ptr = start + upb_Array_Size(arr);
-      const upb_MiniTable* subm =
-          upb_MiniTableSub_Message(subs[f->UPB_PRIVATE(submsg_index)]);
+      const upb_MiniTable* subm = _upb_Encoder_GetSubMiniTable(subs, f);
       if (--e->depth == 0) encode_err(e, kUpb_EncodeStatus_MaxDepthExceeded);
       do {
         size_t size;
@@ -395,8 +399,7 @@
     case kUpb_FieldType_Message: {
       const upb_TaggedMessagePtr* start = upb_Array_DataPtr(arr);
       const upb_TaggedMessagePtr* ptr = start + upb_Array_Size(arr);
-      const upb_MiniTable* subm =
-          upb_MiniTableSub_Message(subs[f->UPB_PRIVATE(submsg_index)]);
+      const upb_MiniTable* subm = _upb_Encoder_GetSubMiniTable(subs, f);
       if (--e->depth == 0) encode_err(e, kUpb_EncodeStatus_MaxDepthExceeded);
       do {
         size_t size;
@@ -432,11 +435,10 @@
 }
 
 static void encode_map(upb_encstate* e, const upb_Message* msg,
-                       const upb_MiniTableSub* subs,
+                       const upb_MiniTableSubInternal* subs,
                        const upb_MiniTableField* f) {
   const upb_Map* map = *UPB_PTR_AT(msg, f->UPB_PRIVATE(offset), const upb_Map*);
-  const upb_MiniTable* layout =
-      upb_MiniTableSub_Message(subs[f->UPB_PRIVATE(submsg_index)]);
+  const upb_MiniTable* layout = _upb_Encoder_GetSubMiniTable(subs, f);
   UPB_ASSERT(upb_MiniTable_FieldCount(layout) == 2);
 
   if (!map || !upb_Map_Size(map)) return;
@@ -465,7 +467,6 @@
 }
 
 static bool encode_shouldencode(upb_encstate* e, const upb_Message* msg,
-                                const upb_MiniTableSub* subs,
                                 const upb_MiniTableField* f) {
   if (f->presence == 0) {
     // Proto3 presence or map/array.
@@ -504,7 +505,7 @@
 }
 
 static void encode_field(upb_encstate* e, const upb_Message* msg,
-                         const upb_MiniTableSub* subs,
+                         const upb_MiniTableSubInternal* subs,
                          const upb_MiniTableField* field) {
   switch (UPB_PRIVATE(_upb_MiniTableField_Mode)(field)) {
     case kUpb_FieldMode_Array:
@@ -539,7 +540,14 @@
   if (UPB_UNLIKELY(is_message_set)) {
     encode_msgset_item(e, ext);
   } else {
-    encode_field(e, (upb_Message*)&ext->data, &ext->ext->UPB_PRIVATE(sub),
+    upb_MiniTableSubInternal sub;
+    if (upb_MiniTableField_IsSubMessage(&ext->ext->UPB_PRIVATE(field))) {
+      sub.UPB_PRIVATE(submsg) = &ext->ext->UPB_PRIVATE(sub).UPB_PRIVATE(submsg);
+    } else {
+      sub.UPB_PRIVATE(subenum) =
+          ext->ext->UPB_PRIVATE(sub).UPB_PRIVATE(subenum);
+    }
+    encode_field(e, (upb_Message*)&ext->data, &sub,
                  &ext->ext->UPB_PRIVATE(field));
   }
 }
@@ -595,7 +603,7 @@
     const upb_MiniTableField* first = &m->UPB_PRIVATE(fields)[0];
     while (f != first) {
       f--;
-      if (encode_shouldencode(e, msg, m->UPB_PRIVATE(subs), f)) {
+      if (encode_shouldencode(e, msg, f)) {
         encode_field(e, msg, m->UPB_PRIVATE(subs), f);
       }
     }
@@ -682,4 +690,4 @@
     default:
       return "Unknown encode status";
   }
-}
\ No newline at end of file
+}
diff --git a/upb_generator/protoc-gen-upb_minitable.cc b/upb_generator/protoc-gen-upb_minitable.cc
index 7c15308..9e80e19 100644
--- a/upb_generator/protoc-gen-upb_minitable.cc
+++ b/upb_generator/protoc-gen-upb_minitable.cc
@@ -70,6 +70,10 @@
   return absl::StrCat(ExtensionIdentBase(ext), "_", ext.name(), "_ext");
 }
 
+std::string MessagePtrName(upb::MessageDefPtr message) {
+  return MessageInitName(message) + "_ptr";
+}
+
 const char* kEnumsInit = "enums_layout";
 const char* kExtensionsInit = "extensions_layout";
 const char* kMessagesInit = "messages_layout";
@@ -312,10 +316,11 @@
   output("  $0,\n", upb::generator::FieldInitializer(field, field64, field32));
 }
 
-std::string GetSub(upb::FieldDefPtr field) {
+std::string GetSub(upb::FieldDefPtr field, bool is_extension) {
   if (auto message_def = field.message_type()) {
     return absl::Substitute("{.UPB_PRIVATE(submsg) = &$0}",
-                            MessageInitName(message_def));
+                            is_extension ? MessageInitName(message_def)
+                                         : MessagePtrName(message_def));
   }
 
   if (auto enum_def = field.enum_subdef()) {
@@ -345,17 +350,18 @@
     uint32_t index = f->UPB_PRIVATE(submsg_index);
     if (index != kUpb_NoSub) {
       const int f_number = upb_MiniTableField_Number(f);
-      auto pair =
-          subs.emplace(index, GetSub(message.FindFieldByNumber(f_number)));
+      upb::FieldDefPtr field = message.FindFieldByNumber(f_number);
+      auto pair = subs.emplace(index, GetSub(field, false));
       ABSL_CHECK(pair.second);
     }
   }
-  // Write upb_MiniTableSub table for sub messages referenced from fields.
+  // Write upb_MiniTableSubInternal table for sub messages referenced from
+  // fields.
   if (!subs.empty()) {
     std::string submsgs_array_name = msg_name + "_submsgs";
     submsgs_array_ref = "&" + submsgs_array_name + "[0]";
-    output("static const upb_MiniTableSub $0[$1] = {\n", submsgs_array_name,
-           subs.size());
+    output("static const upb_MiniTableSubInternal $0[$1] = {\n",
+           submsgs_array_name, subs.size());
 
     int i = 0;
     for (const auto& pair : subs) {
@@ -421,6 +427,8 @@
     output("  })\n");
   }
   output("};\n\n");
+  output("const upb_MiniTable* $0 = &$1;\n", MessagePtrName(message),
+         MessageInitName(message));
 }
 
 void WriteEnum(upb::EnumDefPtr e, Output& output) {
@@ -492,7 +500,7 @@
                     Output& output) {
   output("$0,\n", FieldInitializer(pools, ext));
   output("  &$0,\n", MessageInitName(ext.containing_type()));
-  output("  $0,\n", GetSub(ext));
+  output("  $0,\n", GetSub(ext, true));
 }
 
 int WriteExtensions(const DefPoolPair& pools, upb::FileDefPtr file,
@@ -571,6 +579,7 @@
 
   for (auto message : this_file_messages) {
     output("extern const upb_MiniTable $0;\n", MessageInitName(message));
+    output("extern const upb_MiniTable* $0;\n", MessagePtrName(message));
   }
   for (auto ext : this_file_exts) {
     output("extern const upb_MiniTableExtension $0;\n", ExtensionLayout(ext));