Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix async iterator body stripping #15491

Merged
merged 5 commits into from
Jun 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 33 additions & 25 deletions mypy/fastparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,40 +521,48 @@ def translate_stmt_list(
return [block]

stack = self.class_and_function_stack
if self.strip_function_bodies and len(stack) == 1 and stack[0] == "F":
# Fast case for stripping function bodies
if (
can_strip
and self.strip_function_bodies
and len(stack) == 1
and stack[0] == "F"
and not is_coroutine
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can strip if there is no yield right? Also IIUC this is the actual fix, while the change below is just a refactoring, right?

Copy link
Collaborator Author

@hauntsaninja hauntsaninja Jun 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Change below isn't just refactoring, since the stripping below used to only happen in methods.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, wait, but then now the stripping in last part will occur even if stack[0] != "F", is it safe?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added that check, and fixed a bug I introduced in the logic for empty bodies

):
return []

res: list[Statement] = []
for stmt in stmts:
node = self.visit(stmt)
res.append(node)

if (
self.strip_function_bodies
and can_strip
and stack[-2:] == ["C", "F"]
and not is_possible_trivial_body(res)
):
# We only strip method bodies if they don't assign to an attribute, as
# this may define an attribute which has an externally visible effect.
visitor = FindAttributeAssign()
for s in res:
s.accept(visitor)
if visitor.found:
break
else:
if is_coroutine:
# Yields inside an async function affect the return type and should not
# be stripped.
yield_visitor = FindYield()
# Slow case for stripping function bodies
if can_strip and self.strip_function_bodies:
if stack[-2:] == ["C", "F"]:
if is_possible_trivial_body(res):
can_strip = False
else:
# We only strip method bodies if they don't assign to an attribute, as
# this may define an attribute which has an externally visible effect.
visitor = FindAttributeAssign()
for s in res:
s.accept(yield_visitor)
if yield_visitor.found:
s.accept(visitor)
if visitor.found:
can_strip = False
break
else:
return []
else:
return []

if can_strip and stack[-1] == "F" and is_coroutine:
# Yields inside an async function affect the return type and should not
# be stripped.
yield_visitor = FindYield()
for s in res:
s.accept(yield_visitor)
if yield_visitor.found:
can_strip = False
break

if can_strip:
return []
return res

def translate_type_comment(
Expand Down
13 changes: 10 additions & 3 deletions test-data/unit/check-async-await.test
Original file line number Diff line number Diff line change
Expand Up @@ -945,17 +945,21 @@ async def bar(x: Union[A, B]) -> None:
[typing fixtures/typing-async.pyi]

[case testAsyncIteratorWithIgnoredErrors]
from m import L
import m

async def func(l: L) -> None:
async def func(l: m.L) -> None:
reveal_type(l.get_iterator) # N: Revealed type is "def () -> typing.AsyncIterator[builtins.str]"
reveal_type(l.get_iterator2) # N: Revealed type is "def () -> typing.AsyncIterator[builtins.str]"
async for i in l.get_iterator():
reveal_type(i) # N: Revealed type is "builtins.str"

reveal_type(m.get_generator) # N: Revealed type is "def () -> typing.AsyncGenerator[builtins.int, None]"
async for i2 in m.get_generator():
reveal_type(i2) # N: Revealed type is "builtins.int"

[file m.py]
# mypy: ignore-errors=True
from typing import AsyncIterator
from typing import AsyncIterator, AsyncGenerator

class L:
async def some_func(self, i: int) -> str:
Expand All @@ -968,6 +972,9 @@ class L:
if self:
a = (yield 'x')

async def get_generator() -> AsyncGenerator[int, None]:
yield 1

[builtins fixtures/async_await.pyi]
[typing fixtures/typing-async.pyi]

Expand Down