Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dont do generator rewrite if list comp contains await #493

Merged
merged 1 commit into from
Jul 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions pyupgrade/_ast_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,11 @@ def has_starargs(call: ast.Call) -> bool:
any(k.arg is None for k in call.keywords) or
any(isinstance(a, ast.Starred) for a in call.args)
)


def contains_await(node: ast.AST) -> bool:
for node_ in ast.walk(node):
if isinstance(node_, ast.Await):
return True
else:
return False
11 changes: 2 additions & 9 deletions pyupgrade/_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from pyupgrade._ast_helpers import ast_parse
from pyupgrade._ast_helpers import ast_to_offset
from pyupgrade._ast_helpers import contains_await
from pyupgrade._ast_helpers import has_starargs
from pyupgrade._data import FUNCS
from pyupgrade._data import Settings
Expand Down Expand Up @@ -520,14 +521,6 @@ def _format_params(call: ast.Call) -> Set[str]:
return params


def _contains_await(node: ast.AST) -> bool:
for node_ in ast.walk(node):
if isinstance(node_, ast.Await):
return True
else:
return False


class FindPy36Plus(ast.NodeVisitor):
def __init__(self, *, min_version: Version) -> None:
self.fstrings: Dict[Offset, ast.Call] = {}
Expand Down Expand Up @@ -600,7 +593,7 @@ def visit_Call(self, node: ast.Call) -> None:
if not candidate:
i += 1
else:
if self.min_version >= (3, 7) or not _contains_await(node):
if self.min_version >= (3, 7) or not contains_await(node):
self.fstrings[ast_to_offset(node)] = node

self.generic_visit(node)
Expand Down
4 changes: 3 additions & 1 deletion pyupgrade/_plugins/generator_expressions_pep289.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from tokenize_rt import Token

from pyupgrade._ast_helpers import ast_to_offset
from pyupgrade._ast_helpers import contains_await
from pyupgrade._data import register
from pyupgrade._data import State
from pyupgrade._data import TokenFunc
Expand Down Expand Up @@ -55,7 +56,8 @@ def visit_Call(
not any(
generator.is_async
for generator in node.args[0].generators
)
) and
not contains_await(node.args[0])
):
if len(node.args) == 1 and not node.keywords:
yield ast_to_offset(node.args[0]), _delete_list_comp_brackets
Expand Down
7 changes: 7 additions & 0 deletions tests/features/generator_expressions_pep289_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@
' sum([i async for i in foo()])\n',
id='Contains async',
),
pytest.param(
'tuple([\n'
' await self._configure_component(hass, controller_config)\n'
' for controller_config in configs\n'
'])\n',
id='Contains await',
),
),
)
def test_fix_generator_expressions_noop(s):
Expand Down