Fix SPLIT_ALL_COMMA_SEPARATED_VALUES and SPLIT_ALL_TOP_LEVEL_COMMA_SEPARATED_VALUES being too agressive for lambdas and unpacking. (#1173)

Co-authored-by: Bill Wendling <morbo@google.com>
diff --git a/CHANGELOG.md b/CHANGELOG.md
index d8127cf..ca880fb 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,8 @@
 ### Fixed
 - Fix SPLIT_ARGUMENTS_WHEN_COMMA_TERMINATED for one-item named argument lists
   by taking precedence over SPLIT_BEFORE_NAMED_ASSIGNS.
+- Fix SPLIT_ALL_COMMA_SEPARATED_VALUES and SPLIT_ALL_TOP_LEVEL_COMMA_SEPARATED_VALUES
+  being too agressive for lambdas and unpacking.
 
 ## [0.40.2] 2023-09-22
 ### Changes
diff --git a/yapf/pytree/subtype_assigner.py b/yapf/pytree/subtype_assigner.py
index 05d88b0..e3b3277 100644
--- a/yapf/pytree/subtype_assigner.py
+++ b/yapf/pytree/subtype_assigner.py
@@ -222,6 +222,11 @@
       if isinstance(child, pytree.Leaf) and child.value == '**':
         _AppendTokenSubtype(child, subtypes.BINARY_OPERATOR)
 
+  def Visit_lambdef(self, node):  # pylint: disable=invalid-name
+    # trailer: '(' [arglist] ')' | '[' subscriptlist ']' | '.' NAME
+    _AppendSubtypeRec(node, subtypes.LAMBDEF)
+    self.DefaultNodeVisit(node)
+
   def Visit_trailer(self, node):  # pylint: disable=invalid-name
     for child in node.children:
       self.Visit(child)
diff --git a/yapf/yapflib/format_decision_state.py b/yapf/yapflib/format_decision_state.py
index ce74313..06f3455 100644
--- a/yapf/yapflib/format_decision_state.py
+++ b/yapf/yapflib/format_decision_state.py
@@ -180,6 +180,10 @@
       return False
 
     if style.Get('SPLIT_ALL_COMMA_SEPARATED_VALUES') and previous.value == ',':
+      if (subtypes.COMP_FOR in current.subtypes or
+          subtypes.LAMBDEF in current.subtypes):
+        return False
+
       return True
 
     if (style.Get('FORCE_MULTILINE_DICT') and
@@ -188,6 +192,11 @@
 
     if (style.Get('SPLIT_ALL_TOP_LEVEL_COMMA_SEPARATED_VALUES') and
         previous.value == ','):
+
+      if (subtypes.COMP_FOR in current.subtypes or
+          subtypes.LAMBDEF in current.subtypes):
+        return False
+
       # Avoid breaking in a container that fits in the current line if possible
       opening = _GetOpeningBracket(current)
 
diff --git a/yapf/yapflib/subtypes.py b/yapf/yapflib/subtypes.py
index b4b7efe..3c234fb 100644
--- a/yapf/yapflib/subtypes.py
+++ b/yapf/yapflib/subtypes.py
@@ -38,3 +38,4 @@
 SIMPLE_EXPRESSION = 22
 PARAMETER_START = 23
 PARAMETER_STOP = 24
+LAMBDEF = 25
diff --git a/yapftests/reformatter_basic_test.py b/yapftests/reformatter_basic_test.py
index 24f34a6..d58343e 100644
--- a/yapftests/reformatter_basic_test.py
+++ b/yapftests/reformatter_basic_test.py
@@ -91,6 +91,30 @@
     """)
     llines = yapf_test_helper.ParseAndUnwrap(unformatted_code)
     self.assertCodeEqual(expected_formatted_code, reformatter.Reformat(llines))
+    unformatted_code = textwrap.dedent("""\
+        values = [ lambda arg1, arg2: arg1 + arg2 ]
+    """)  # noqa
+    expected_formatted_code = textwrap.dedent("""\
+        values = [
+            lambda arg1, arg2: arg1 + arg2
+        ]
+    """)
+    llines = yapf_test_helper.ParseAndUnwrap(unformatted_code)
+    self.assertCodeEqual(expected_formatted_code, reformatter.Reformat(llines))
+    unformatted_code = textwrap.dedent("""\
+        values = [
+            (some_arg1, some_arg2) for some_arg1, some_arg2 in values
+        ]
+    """)  # noqa
+    expected_formatted_code = textwrap.dedent("""\
+        values = [
+            (some_arg1,
+             some_arg2)
+            for some_arg1, some_arg2 in values
+        ]
+    """)
+    llines = yapf_test_helper.ParseAndUnwrap(unformatted_code)
+    self.assertCodeEqual(expected_formatted_code, reformatter.Reformat(llines))
     # There is a test for split_all_top_level_comma_separated_values, with
     # different expected value
     unformatted_code = textwrap.dedent("""\
@@ -161,6 +185,32 @@
     """)
     llines = yapf_test_helper.ParseAndUnwrap(unformatted_code)
     self.assertCodeEqual(expected_formatted_code, reformatter.Reformat(llines))
+    # Works the same way as split_all_comma_separated_values
+    unformatted_code = textwrap.dedent("""\
+        values = [ lambda arg1, arg2: arg1 + arg2 ]
+    """)  # noqa
+    expected_formatted_code = textwrap.dedent("""\
+        values = [
+            lambda arg1, arg2: arg1 + arg2
+        ]
+    """)
+    llines = yapf_test_helper.ParseAndUnwrap(unformatted_code)
+    self.assertCodeEqual(expected_formatted_code, reformatter.Reformat(llines))
+    # There is a test for split_all_comma_separated_values, with different
+    # expected value
+    unformatted_code = textwrap.dedent("""\
+        values = [
+            (some_arg1, some_arg2) for some_arg1, some_arg2 in values
+        ]
+    """)  # noqa
+    expected_formatted_code = textwrap.dedent("""\
+        values = [
+            (some_arg1, some_arg2)
+            for some_arg1, some_arg2 in values
+        ]
+    """)
+    llines = yapf_test_helper.ParseAndUnwrap(unformatted_code)
+    self.assertCodeEqual(expected_formatted_code, reformatter.Reformat(llines))
     # There is a test for split_all_comma_separated_values, with different
     # expected value
     unformatted_code = textwrap.dedent("""\