Fix FieldSet to not load lazy fields when the fields map is cloned. In particular, if MessageSet.Builder has build() called on it, but is then later further modified, we don't need to pre-load all lazy fields during this later modification.

PiperOrigin-RevId: 581941605
diff --git a/java/core/BUILD.bazel b/java/core/BUILD.bazel
index b896c42..30c0e90 100644
--- a/java/core/BUILD.bazel
+++ b/java/core/BUILD.bazel
@@ -2,10 +2,10 @@
 load("@rules_java//java:defs.bzl", "java_lite_proto_library", "java_proto_library")
 load("@rules_pkg//:mappings.bzl", "pkg_files", "strip_prefix")
 load("@rules_proto//proto:defs.bzl", "proto_lang_toolchain", "proto_library")
-load("//build_defs:java_opts.bzl", "protobuf_java_export", "protobuf_java_library", "protobuf_versioned_java_library")
-load("//conformance:defs.bzl", "conformance_test")
 load("//:protobuf.bzl", "internal_gen_well_known_protos_java")
 load("//:protobuf_version.bzl", "PROTOBUF_JAVA_VERSION")
+load("//build_defs:java_opts.bzl", "protobuf_java_export", "protobuf_java_library", "protobuf_versioned_java_library")
+load("//conformance:defs.bzl", "conformance_test")
 load("//java/internal:testing.bzl", "junit_tests")
 
 LITE_SRCS = [
@@ -472,6 +472,7 @@
     "src/test/java/com/google/protobuf/FieldPresenceTest.java",
     "src/test/java/com/google/protobuf/ForceFieldBuildersPreRun.java",
     "src/test/java/com/google/protobuf/GeneratedMessageTest.java",
+    "src/test/java/com/google/protobuf/LazilyParsedMessageSetTest.java",
     "src/test/java/com/google/protobuf/LazyFieldTest.java",
     "src/test/java/com/google/protobuf/LazyStringEndToEndTest.java",
     "src/test/java/com/google/protobuf/MapForProto2Test.java",
diff --git a/java/core/src/main/java/com/google/protobuf/FieldSet.java b/java/core/src/main/java/com/google/protobuf/FieldSet.java
index bb3eea9..a8ba1bd 100644
--- a/java/core/src/main/java/com/google/protobuf/FieldSet.java
+++ b/java/core/src/main/java/com/google/protobuf/FieldSet.java
@@ -173,7 +173,8 @@
   /** Get a simple map containing all the fields. */
   public Map<T, Object> getAllFields() {
     if (hasLazyField) {
-      SmallSortedMap<T, Object> result = cloneAllFieldsMap(fields, /* copyList */ false);
+      SmallSortedMap<T, Object> result =
+          cloneAllFieldsMap(fields, /* copyList= */ false, /* resolveLazyFields= */ true);
       if (fields.isImmutable()) {
         result.makeImmutable();
       }
@@ -183,22 +184,22 @@
   }
 
   private static <T extends FieldDescriptorLite<T>> SmallSortedMap<T, Object> cloneAllFieldsMap(
-      SmallSortedMap<T, Object> fields, boolean copyList) {
+      SmallSortedMap<T, Object> fields, boolean copyList, boolean resolveLazyFields) {
     SmallSortedMap<T, Object> result = SmallSortedMap.newFieldMap(DEFAULT_FIELD_MAP_ARRAY_SIZE);
     for (int i = 0; i < fields.getNumArrayEntries(); i++) {
-      cloneFieldEntry(result, fields.getArrayEntryAt(i), copyList);
+      cloneFieldEntry(result, fields.getArrayEntryAt(i), copyList, resolveLazyFields);
     }
     for (Map.Entry<T, Object> entry : fields.getOverflowEntries()) {
-      cloneFieldEntry(result, entry, copyList);
+      cloneFieldEntry(result, entry, copyList, resolveLazyFields);
     }
     return result;
   }
 
   private static <T extends FieldDescriptorLite<T>> void cloneFieldEntry(
-      Map<T, Object> map, Map.Entry<T, Object> entry, boolean copyList) {
+      Map<T, Object> map, Map.Entry<T, Object> entry, boolean copyList, boolean resolveLazyFields) {
     T key = entry.getKey();
     Object value = entry.getValue();
-    if (value instanceof LazyField) {
+    if (resolveLazyFields && value instanceof LazyField) {
       map.put(key, ((LazyField) value).getValue());
     } else if (copyList && value instanceof List) {
       map.put(key, new ArrayList<>((List<?>) value));
@@ -958,7 +959,8 @@
       SmallSortedMap<T, Object> fieldsForBuild = fields;
       if (hasNestedBuilders) {
         // Make a copy of the fields map with all Builders replaced by Message.
-        fieldsForBuild = cloneAllFieldsMap(fields, /* copyList */ false);
+        fieldsForBuild =
+            cloneAllFieldsMap(fields, /* copyList= */ false, /* resolveLazyFields= */ false);
         replaceBuilders(fieldsForBuild, partial);
       }
       FieldSet<T> fieldSet = new FieldSet<>(fieldsForBuild);
@@ -1030,7 +1032,10 @@
 
     /** Returns a new Builder using the fields from {@code fieldSet}. */
     public static <T extends FieldDescriptorLite<T>> Builder<T> fromFieldSet(FieldSet<T> fieldSet) {
-      Builder<T> builder = new Builder<T>(cloneAllFieldsMap(fieldSet.fields, /* copyList */ true));
+      Builder<T> builder =
+          new Builder<T>(
+              cloneAllFieldsMap(
+                  fieldSet.fields, /* copyList= */ true, /* resolveLazyFields= */ false));
       builder.hasLazyField = fieldSet.hasLazyField;
       return builder;
     }
@@ -1040,7 +1045,8 @@
     /** Get a simple map containing all the fields. */
     public Map<T, Object> getAllFields() {
       if (hasLazyField) {
-        SmallSortedMap<T, Object> result = cloneAllFieldsMap(fields, /* copyList */ false);
+        SmallSortedMap<T, Object> result =
+            cloneAllFieldsMap(fields, /* copyList= */ false, /* resolveLazyFields= */ true);
         if (fields.isImmutable()) {
           result.makeImmutable();
         } else {
@@ -1081,7 +1087,7 @@
 
     private void ensureIsMutable() {
       if (!isMutable) {
-        fields = cloneAllFieldsMap(fields, /* copyList */ true);
+        fields = cloneAllFieldsMap(fields, /* copyList= */ true, /* resolveLazyFields= */ false);
         isMutable = true;
       }
     }
diff --git a/java/core/src/test/java/com/google/protobuf/LazilyParsedMessageSetTest.java b/java/core/src/test/java/com/google/protobuf/LazilyParsedMessageSetTest.java
new file mode 100644
index 0000000..c41a381
--- /dev/null
+++ b/java/core/src/test/java/com/google/protobuf/LazilyParsedMessageSetTest.java
@@ -0,0 +1,162 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc.  All rights reserved.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file or at
+// https://developers.google.com/open-source/licenses/bsd
+
+package com.google.protobuf;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import protobuf_unittest.UnittestMset.RawMessageSet;
+import protobuf_unittest.UnittestMset.TestMessageSetExtension1;
+import protobuf_unittest.UnittestMset.TestMessageSetExtension2;
+import protobuf_unittest.UnittestMset.TestMessageSetExtension3;
+import proto2_wireformat_unittest.UnittestMsetWireFormat.TestMessageSet;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Tests related to handling of MessageSets with lazily parsed extensions. */
+@RunWith(JUnit4.class)
+public class LazilyParsedMessageSetTest {
+  private static final int TYPE_ID_1 =
+      TestMessageSetExtension1.getDescriptor().getExtensions().get(0).getNumber();
+  private static final int TYPE_ID_2 =
+      TestMessageSetExtension2.getDescriptor().getExtensions().get(0).getNumber();
+  private static final int TYPE_ID_3 =
+      TestMessageSetExtension3.getDescriptor().getExtensions().get(0).getNumber();
+  private static final ByteString CORRUPTED_MESSAGE_PAYLOAD =
+      ByteString.copyFrom(new byte[] {(byte) 0xff});
+
+  @Before
+  public void setUp() {
+    ExtensionRegistryLite.setEagerlyParseMessageSets(false);
+  }
+
+  @Test
+  public void testParseAndUpdateMessageSet_unaccessedLazyFieldsAreNotLoaded() throws Exception {
+    ExtensionRegistry extensionRegistry = ExtensionRegistry.newInstance();
+    extensionRegistry.add(TestMessageSetExtension1.messageSetExtension);
+    extensionRegistry.add(TestMessageSetExtension2.messageSetExtension);
+    extensionRegistry.add(TestMessageSetExtension3.messageSetExtension);
+
+    // Set up a TestMessageSet with 2 extensions. The first extension has corrupted payload
+    // data. The test below makes sure that we never load this extension. If we ever do, then we
+    // will handle the exception and replace the value with the default empty message (this behavior
+    // is tested below in testLoadCorruptedLazyField_getsReplacedWithEmptyMessage). Later on we
+    // check that when we serialize the message set, we still have corrupted payload for the first
+    // extension.
+    RawMessageSet inputRaw =
+        RawMessageSet.newBuilder()
+            .addItem(
+                RawMessageSet.Item.newBuilder()
+                    .setTypeId(TYPE_ID_1)
+                    .setMessage(CORRUPTED_MESSAGE_PAYLOAD))
+            .addItem(
+                RawMessageSet.Item.newBuilder()
+                    .setTypeId(TYPE_ID_2)
+                    .setMessage(
+                        TestMessageSetExtension2.newBuilder().setStr("foo").build().toByteString()))
+            .build();
+
+    ByteString inputData = inputRaw.toByteString();
+
+    // Re-parse as a TestMessageSet, so that all extensions are lazy
+    TestMessageSet messageSet = TestMessageSet.parseFrom(inputData, extensionRegistry);
+
+    // Update one extension and add a new one.
+    TestMessageSet.Builder builder = messageSet.toBuilder();
+    builder.setExtension(
+        TestMessageSetExtension2.messageSetExtension,
+        TestMessageSetExtension2.newBuilder().setStr("bar").build());
+
+    // Call .build() in the middle of updating the builder. This triggers a codepath that we want to
+    // make sure preserves lazy fields.
+    TestMessageSet unusedIntermediateMessageSet = builder.build();
+
+    builder.setExtension(
+        TestMessageSetExtension3.messageSetExtension,
+        TestMessageSetExtension3.newBuilder().setRequiredInt(666).build());
+
+    TestMessageSet updatedMessageSet = builder.build();
+
+    // Check that hasExtension call does not load lazy fields.
+    assertThat(updatedMessageSet.hasExtension(TestMessageSetExtension1.messageSetExtension))
+        .isTrue();
+
+    // Serialize. The first extension should still be unloaded and will get serialized using the
+    // same corrupted byte array.
+    ByteString outputData = updatedMessageSet.toByteString();
+
+    // Re-parse as RawMessageSet
+    RawMessageSet actualRaw =
+        RawMessageSet.parseFrom(outputData, ExtensionRegistry.getEmptyRegistry());
+
+    RawMessageSet expectedRaw =
+        RawMessageSet.newBuilder()
+            .addItem(
+                RawMessageSet.Item.newBuilder()
+                    .setTypeId(TYPE_ID_1)
+                    // This is the important part -- we want to make sure that the payload of the
+                    // 1st extensions is the same corrupted byte array. If we ever load the
+                    // extension during our manipulations above, then we would have replaced it with
+                    // the default empty message.
+                    .setMessage(CORRUPTED_MESSAGE_PAYLOAD))
+            .addItem(
+                RawMessageSet.Item.newBuilder()
+                    .setTypeId(TYPE_ID_2)
+                    .setMessage(
+                        TestMessageSetExtension2.newBuilder().setStr("bar").build().toByteString()))
+            .addItem(
+                RawMessageSet.Item.newBuilder()
+                    .setTypeId(TYPE_ID_3)
+                    .setMessage(
+                        TestMessageSetExtension3.newBuilder()
+                            .setRequiredInt(666)
+                            .build()
+                            .toByteString()))
+            .build();
+
+    assertThat(actualRaw).isEqualTo(expectedRaw);
+  }
+
+  @Test
+  public void testLoadCorruptedLazyField_getsReplacedWithEmptyMessage() throws Exception {
+    ExtensionRegistry extensionRegistry = ExtensionRegistry.newInstance();
+    extensionRegistry.add(TestMessageSetExtension1.messageSetExtension);
+
+    RawMessageSet inputRaw =
+        RawMessageSet.newBuilder()
+            .addItem(
+                RawMessageSet.Item.newBuilder()
+                    .setTypeId(TYPE_ID_1)
+                    .setMessage(CORRUPTED_MESSAGE_PAYLOAD))
+            .build();
+
+    ByteString inputData = inputRaw.toByteString();
+
+    // Re-parse as a TestMessageSet, so that all extensions are lazy
+    TestMessageSet messageSet = TestMessageSet.parseFrom(inputData, extensionRegistry);
+
+    assertThat(messageSet.getExtension(TestMessageSetExtension1.messageSetExtension))
+        .isEqualTo(TestMessageSetExtension1.getDefaultInstance());
+
+    // Serialize. The first extension should be serialized as an empty message.
+    ByteString outputData = messageSet.toByteString();
+
+    // Re-parse as RawMessageSet
+    RawMessageSet actualRaw =
+        RawMessageSet.parseFrom(outputData, ExtensionRegistry.getEmptyRegistry());
+
+    RawMessageSet expectedRaw =
+        RawMessageSet.newBuilder()
+            .addItem(
+                RawMessageSet.Item.newBuilder().setTypeId(TYPE_ID_1).setMessage(ByteString.empty()))
+            .build();
+
+    assertThat(actualRaw).isEqualTo(expectedRaw);
+  }
+}
diff --git a/java/core/src/test/java/com/google/protobuf/WireFormatTest.java b/java/core/src/test/java/com/google/protobuf/WireFormatTest.java
index 4afeff8..bbf8d0c 100644
--- a/java/core/src/test/java/com/google/protobuf/WireFormatTest.java
+++ b/java/core/src/test/java/com/google/protobuf/WireFormatTest.java
@@ -12,7 +12,6 @@
 import protobuf_unittest.UnittestMset.RawMessageSet;
 import protobuf_unittest.UnittestMset.TestMessageSetExtension1;
 import protobuf_unittest.UnittestMset.TestMessageSetExtension2;
-import protobuf_unittest.UnittestMset.TestMessageSetExtension3;
 import protobuf_unittest.UnittestProto;
 import protobuf_unittest.UnittestProto.TestAllExtensions;
 import protobuf_unittest.UnittestProto.TestAllTypes;
@@ -506,73 +505,6 @@
         .isEqualTo(123);
   }
 
-  @Test
-  public void testParseAndUpdateMessageSetExtensionEagerly() throws Exception {
-    testParseAndUpdateMessageSetExtensionEagerlyWithFlag(true);
-  }
-
-  @Test
-  public void testParseAndUpdateMessageSetExtensionNotEagerly() throws Exception {
-    testParseAndUpdateMessageSetExtensionEagerlyWithFlag(false);
-  }
-
-  private void testParseAndUpdateMessageSetExtensionEagerlyWithFlag(boolean eagerParsing)
-      throws Exception {
-    ExtensionRegistryLite.setEagerlyParseMessageSets(eagerParsing);
-    ExtensionRegistry extensionRegistry = ExtensionRegistry.newInstance();
-    extensionRegistry.add(TestMessageSetExtension1.messageSetExtension);
-    extensionRegistry.add(TestMessageSetExtension2.messageSetExtension);
-    extensionRegistry.add(TestMessageSetExtension3.messageSetExtension);
-
-    // Set up a RawMessageSet with 2 extensions
-    RawMessageSet raw =
-        RawMessageSet.newBuilder()
-            .addItem(
-                RawMessageSet.Item.newBuilder()
-                    .setTypeId(TYPE_ID_1)
-                    .setMessage(
-                        TestMessageSetExtension1.newBuilder().setI(123).build().toByteString())
-                    .build())
-            .addItem(
-                RawMessageSet.Item.newBuilder()
-                    .setTypeId(TYPE_ID_2)
-                    .setMessage(
-                        TestMessageSetExtension2.newBuilder().setStr("foo").build().toByteString())
-                    .build())
-            .build();
-
-    ByteString data = raw.toByteString();
-
-    // Parse as a TestMessageSet.
-    TestMessageSet messageSet = TestMessageSet.parseFrom(data, extensionRegistry);
-
-    // Update one extension and add a new one.
-    TestMessageSet.Builder builder = messageSet.toBuilder();
-    builder.setExtension(
-        TestMessageSetExtension2.messageSetExtension,
-        TestMessageSetExtension2.newBuilder().setStr("bar").build());
-    builder.setExtension(
-        TestMessageSetExtension3.messageSetExtension,
-        TestMessageSetExtension3.newBuilder().setRequiredInt(666).build());
-
-    TestMessageSet updatedMessageSet = builder.build();
-    // Check all 3 extensions
-    assertThat(updatedMessageSet.getExtension(TestMessageSetExtension1.messageSetExtension).getI())
-        .isEqualTo(123);
-    assertThat(
-            updatedMessageSet.getExtension(TestMessageSetExtension2.messageSetExtension).getStr())
-        .isEqualTo("bar");
-    assertThat(
-            updatedMessageSet
-                .getExtension(TestMessageSetExtension3.messageSetExtension)
-                .getRequiredInt())
-        .isEqualTo(666);
-
-    // Serialize and re-parse, and make sure we get the same message back
-    assertThat(TestMessageSet.parseFrom(updatedMessageSet.toByteString(), extensionRegistry))
-        .isEqualTo(updatedMessageSet);
-  }
-
   // ================================================================
   // oneof
   @Test