diff --git a/pyupgrade/_ast_helpers.py b/pyupgrade/_ast_helpers.py index fdd34abf..6a556c5a 100644 --- a/pyupgrade/_ast_helpers.py +++ b/pyupgrade/_ast_helpers.py @@ -42,3 +42,11 @@ def has_starargs(call: ast.Call) -> bool: any(k.arg is None for k in call.keywords) or any(isinstance(a, ast.Starred) for a in call.args) ) + + +def contains_await(node: ast.AST) -> bool: + for node_ in ast.walk(node): + if isinstance(node_, ast.Await): + return True + else: + return False diff --git a/pyupgrade/_main.py b/pyupgrade/_main.py index 24bfbb67..f27807a9 100644 --- a/pyupgrade/_main.py +++ b/pyupgrade/_main.py @@ -25,6 +25,7 @@ from pyupgrade._ast_helpers import ast_parse from pyupgrade._ast_helpers import ast_to_offset +from pyupgrade._ast_helpers import contains_await from pyupgrade._ast_helpers import has_starargs from pyupgrade._data import FUNCS from pyupgrade._data import Settings @@ -520,14 +521,6 @@ def _format_params(call: ast.Call) -> Set[str]: return params -def _contains_await(node: ast.AST) -> bool: - for node_ in ast.walk(node): - if isinstance(node_, ast.Await): - return True - else: - return False - - class FindPy36Plus(ast.NodeVisitor): def __init__(self, *, min_version: Version) -> None: self.fstrings: Dict[Offset, ast.Call] = {} @@ -600,7 +593,7 @@ def visit_Call(self, node: ast.Call) -> None: if not candidate: i += 1 else: - if self.min_version >= (3, 7) or not _contains_await(node): + if self.min_version >= (3, 7) or not contains_await(node): self.fstrings[ast_to_offset(node)] = node self.generic_visit(node) diff --git a/pyupgrade/_plugins/generator_expressions_pep289.py b/pyupgrade/_plugins/generator_expressions_pep289.py index 82e66abe..f90a5913 100644 --- a/pyupgrade/_plugins/generator_expressions_pep289.py +++ b/pyupgrade/_plugins/generator_expressions_pep289.py @@ -8,6 +8,7 @@ from tokenize_rt import Token from pyupgrade._ast_helpers import ast_to_offset +from pyupgrade._ast_helpers import contains_await from pyupgrade._data import register from pyupgrade._data import State from pyupgrade._data import TokenFunc @@ -55,7 +56,8 @@ def visit_Call( not any( generator.is_async for generator in node.args[0].generators - ) + ) and + not contains_await(node.args[0]) ): if len(node.args) == 1 and not node.keywords: yield ast_to_offset(node.args[0]), _delete_list_comp_brackets diff --git a/tests/features/generator_expressions_pep289_test.py b/tests/features/generator_expressions_pep289_test.py index 5d91d36c..f3877983 100644 --- a/tests/features/generator_expressions_pep289_test.py +++ b/tests/features/generator_expressions_pep289_test.py @@ -28,6 +28,13 @@ ' sum([i async for i in foo()])\n', id='Contains async', ), + pytest.param( + 'tuple([\n' + ' await self._configure_component(hass, controller_config)\n' + ' for controller_config in configs\n' + '])\n', + id='Contains await', + ), ), ) def test_fix_generator_expressions_noop(s):