From 7eeebc41964d4c52a7924d70ee31f2e03cf3c973 Mon Sep 17 00:00:00 2001 From: Shantanu Jain Date: Sat, 7 Sep 2024 21:13:46 -0700 Subject: [PATCH 1/5] pure refactoring --- src/black/linegen.py | 75 +++++++++++++++++++++++++------------------- 1 file changed, 42 insertions(+), 33 deletions(-) diff --git a/src/black/linegen.py b/src/black/linegen.py index 46945ca2a14..401188fda6a 100644 --- a/src/black/linegen.py +++ b/src/black/linegen.py @@ -1079,6 +1079,40 @@ def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None ) +def _ensure_trailing_comma(leaves: List[Leaf], original: Line, opening_bracket: Leaf) -> bool: + if not leaves: + return False + # Ensure a trailing comma for imports + if original.is_import: + return True + # ...and standalone function arguments + if not original.is_def: + return False + if opening_bracket.value != "(": + return False + # Don't add commas if we already have any commas + if any( + leaf.type == token.COMMA + and ( + Preview.typed_params_trailing_comma not in original.mode + or not is_part_of_annotation(leaf) + ) + for leaf in leaves + ): + return False + # Don't add commas inside parenthesized return annotations + if get_annotation_type(leaves[0]) == "return": + return False + # Don't add commas inside PEP 604 unions + if ( + leaves[0].parent + and leaves[0].parent.next_sibling + and leaves[0].parent.next_sibling.type == token.VBAR + ): + return False + return True + + def bracket_split_build_line( leaves: List[Leaf], original: Line, @@ -1099,40 +1133,15 @@ def bracket_split_build_line( if component is _BracketSplitComponent.body: result.inside_brackets = True result.depth += 1 - if leaves: - no_commas = ( - # Ensure a trailing comma for imports and standalone function arguments - original.is_def - # Don't add one after any comments or within type annotations - and opening_bracket.value == "(" - # Don't add one if there's already one there - and not any( - leaf.type == token.COMMA - and ( - Preview.typed_params_trailing_comma not in original.mode - or not is_part_of_annotation(leaf) - ) - for leaf in leaves - ) - # Don't add one inside parenthesized return annotations - and get_annotation_type(leaves[0]) != "return" - # Don't add one inside PEP 604 unions - and not ( - leaves[0].parent - and leaves[0].parent.next_sibling - and leaves[0].parent.next_sibling.type == token.VBAR - ) - ) - - if original.is_import or no_commas: - for i in range(len(leaves) - 1, -1, -1): - if leaves[i].type == STANDALONE_COMMENT: - continue + if _ensure_trailing_comma(leaves, original, opening_bracket): + for i in range(len(leaves) - 1, -1, -1): + if leaves[i].type == STANDALONE_COMMENT: + continue - if leaves[i].type != token.COMMA: - new_comma = Leaf(token.COMMA, ",") - leaves.insert(i + 1, new_comma) - break + if leaves[i].type != token.COMMA: + new_comma = Leaf(token.COMMA, ",") + leaves.insert(i + 1, new_comma) + break leaves_to_track: Set[LeafID] = set() if component is _BracketSplitComponent.head: From 97a90ae3dcfc83923c268808283ca23c37eb10db Mon Sep 17 00:00:00 2001 From: Shantanu Jain Date: Sat, 7 Sep 2024 22:46:25 -0700 Subject: [PATCH 2/5] fix crashes --- CHANGES.md | 3 ++ src/black/linegen.py | 17 ++++-- src/black/nodes.py | 1 + src/black/trans.py | 2 +- .../funcdef_return_type_trailing_comma.py | 53 +++++++++++++++++++ 5 files changed, 70 insertions(+), 6 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index b9909b80149..c3e39fb0abd 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -18,6 +18,9 @@ +- Fix crashes involving comments in parenthesised return types or `X | Y` style unions. + (#4453) + ### Preview style diff --git a/src/black/linegen.py b/src/black/linegen.py index 401188fda6a..ba6e906a388 100644 --- a/src/black/linegen.py +++ b/src/black/linegen.py @@ -1079,7 +1079,9 @@ def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None ) -def _ensure_trailing_comma(leaves: List[Leaf], original: Line, opening_bracket: Leaf) -> bool: +def _ensure_trailing_comma( + leaves: List[Leaf], original: Line, opening_bracket: Leaf +) -> bool: if not leaves: return False # Ensure a trailing comma for imports @@ -1100,14 +1102,19 @@ def _ensure_trailing_comma(leaves: List[Leaf], original: Line, opening_bracket: for leaf in leaves ): return False + + # Find a leaf with a parent (comments don't have parents) + leaf_with_parent = next((leaf for leaf in leaves if leaf.parent), None) + if leaf_with_parent is None: + return True # Don't add commas inside parenthesized return annotations - if get_annotation_type(leaves[0]) == "return": + if get_annotation_type(leaf_with_parent) == "return": return False # Don't add commas inside PEP 604 unions if ( - leaves[0].parent - and leaves[0].parent.next_sibling - and leaves[0].parent.next_sibling.type == token.VBAR + leaf_with_parent.parent + and leaf_with_parent.parent.next_sibling + and leaf_with_parent.parent.next_sibling.type == token.VBAR ): return False return True diff --git a/src/black/nodes.py b/src/black/nodes.py index dae787939ea..bf8e9e1a36a 100644 --- a/src/black/nodes.py +++ b/src/black/nodes.py @@ -1012,6 +1012,7 @@ def get_annotation_type(leaf: Leaf) -> Literal["return", "param", None]: def is_part_of_annotation(leaf: Leaf) -> bool: """Returns whether this leaf is part of a type annotation.""" + assert leaf.parent is not None return get_annotation_type(leaf) is not None diff --git a/src/black/trans.py b/src/black/trans.py index 29a978c6b71..1853584108d 100644 --- a/src/black/trans.py +++ b/src/black/trans.py @@ -488,7 +488,7 @@ def do_match(self, line: Line) -> TMatchResult: break i += 1 - if not is_part_of_annotation(leaf) and not contains_comment: + if not contains_comment and not is_part_of_annotation(leaf): string_indices.append(idx) # Advance to the next non-STRING leaf. diff --git a/tests/data/cases/funcdef_return_type_trailing_comma.py b/tests/data/cases/funcdef_return_type_trailing_comma.py index 9b9b9c673de..2db4a85920a 100644 --- a/tests/data/cases/funcdef_return_type_trailing_comma.py +++ b/tests/data/cases/funcdef_return_type_trailing_comma.py @@ -142,6 +142,31 @@ def SimplePyFn( Buffer[UInt8, 2], Buffer[UInt8, 2], ]: ... + +def foo() -> ( + # comment inside parenthesised return type + int +): + ... + +def foo() -> ( + # comment inside parenthesised return type + # more + int + # another +): + ... + +def foo() -> ( + # comment inside parenthesised new union return type + int | str | bytes +): + ... + +def foo() -> ( + # comment inside plain tuple +): + pass # output # normal, short, function definition def foo(a, b) -> tuple[int, float]: ... @@ -299,3 +324,31 @@ def SimplePyFn( Buffer[UInt8, 2], Buffer[UInt8, 2], ]: ... + + +def foo() -> ( + # comment inside parenthesised return type + int +): ... + + +def foo() -> ( + # comment inside parenthesised return type + # more + int + # another +): ... + + +def foo() -> ( + # comment inside parenthesised new union return type + int + | str + | bytes +): ... + + +def foo() -> ( + # comment inside plain tuple +): + pass From 4d07330677f126076af8e70ab0a9ac5d07647584 Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Sun, 15 Sep 2024 18:31:39 -0700 Subject: [PATCH 3/5] Move tests to non-preview --- .../funcdef_return_type_trailing_comma.py | 52 ------------------ tests/data/cases/function_trailing_comma.py | 53 +++++++++++++++++++ 2 files changed, 53 insertions(+), 52 deletions(-) diff --git a/tests/data/cases/funcdef_return_type_trailing_comma.py b/tests/data/cases/funcdef_return_type_trailing_comma.py index 2db4a85920a..14fd763d9d1 100644 --- a/tests/data/cases/funcdef_return_type_trailing_comma.py +++ b/tests/data/cases/funcdef_return_type_trailing_comma.py @@ -143,30 +143,6 @@ def SimplePyFn( Buffer[UInt8, 2], ]: ... -def foo() -> ( - # comment inside parenthesised return type - int -): - ... - -def foo() -> ( - # comment inside parenthesised return type - # more - int - # another -): - ... - -def foo() -> ( - # comment inside parenthesised new union return type - int | str | bytes -): - ... - -def foo() -> ( - # comment inside plain tuple -): - pass # output # normal, short, function definition def foo(a, b) -> tuple[int, float]: ... @@ -324,31 +300,3 @@ def SimplePyFn( Buffer[UInt8, 2], Buffer[UInt8, 2], ]: ... - - -def foo() -> ( - # comment inside parenthesised return type - int -): ... - - -def foo() -> ( - # comment inside parenthesised return type - # more - int - # another -): ... - - -def foo() -> ( - # comment inside parenthesised new union return type - int - | str - | bytes -): ... - - -def foo() -> ( - # comment inside plain tuple -): - pass diff --git a/tests/data/cases/function_trailing_comma.py b/tests/data/cases/function_trailing_comma.py index 92f46e27516..ce2d4ebaa2c 100644 --- a/tests/data/cases/function_trailing_comma.py +++ b/tests/data/cases/function_trailing_comma.py @@ -60,6 +60,31 @@ def func() -> ((also_super_long_type_annotation_that_may_cause_an_AST_related_cr argument1, (one, two,), argument4, argument5, argument6 ) +def foo() -> ( + # comment inside parenthesised return type + int +): + ... + +def foo() -> ( + # comment inside parenthesised return type + # more + int + # another +): + ... + +def foo() -> ( + # comment inside parenthesised new union return type + int | str | bytes +): + ... + +def foo() -> ( + # comment inside plain tuple +): + pass + # output def f( @@ -176,3 +201,31 @@ def func() -> ( argument5, argument6, ) + + +def foo() -> ( + # comment inside parenthesised return type + int +): ... + + +def foo() -> ( + # comment inside parenthesised return type + # more + int + # another +): ... + + +def foo() -> ( + # comment inside parenthesised new union return type + int + | str + | bytes +): ... + + +def foo() -> ( + # comment inside plain tuple +): + pass From f54a22405f6eb8b7be005fb2e55ab3c4fc01c743 Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Sun, 15 Sep 2024 18:35:48 -0700 Subject: [PATCH 4/5] a few more test cases --- tests/data/cases/function_trailing_comma.py | 36 +++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/tests/data/cases/function_trailing_comma.py b/tests/data/cases/function_trailing_comma.py index ce2d4ebaa2c..0d58d9cdca0 100644 --- a/tests/data/cases/function_trailing_comma.py +++ b/tests/data/cases/function_trailing_comma.py @@ -85,6 +85,22 @@ def foo() -> ( ): pass +def foo(arg: (# comment with non-return annotation + int + # comment with non-return annotation +)): + pass + +variable: ( # annotation + because + # why not +) + +variable: ( + because + # why not +) + # output def f( @@ -229,3 +245,23 @@ def foo() -> ( # comment inside plain tuple ): pass + + +def foo( + arg: ( # comment with non-return annotation + int + # comment with non-return annotation + ), +): + pass + + +variable: ( # annotation + because + # why not +) + +variable: ( + because + # why not +) From 0781f6a9dc1aabd32ac007cc483e0f119c0f0c5b Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Sun, 15 Sep 2024 18:38:40 -0700 Subject: [PATCH 5/5] more --- tests/data/cases/function_trailing_comma.py | 41 +++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/tests/data/cases/function_trailing_comma.py b/tests/data/cases/function_trailing_comma.py index 0d58d9cdca0..63cf3999c2e 100644 --- a/tests/data/cases/function_trailing_comma.py +++ b/tests/data/cases/function_trailing_comma.py @@ -91,6 +91,23 @@ def foo(arg: (# comment with non-return annotation )): pass +def foo(arg: (# comment with non-return annotation + int | range | memoryview + # comment with non-return annotation +)): + pass + +def foo(arg: (# only before + int +)): + pass + +def foo(arg: ( + int + # only after +)): + pass + variable: ( # annotation because # why not @@ -256,6 +273,30 @@ def foo( pass +def foo( + arg: ( # comment with non-return annotation + int + | range + | memoryview + # comment with non-return annotation + ), +): + pass + + +def foo(arg: int): # only before + pass + + +def foo( + arg: ( + int + # only after + ), +): + pass + + variable: ( # annotation because # why not