Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

6: async checkpoints #16

Merged
merged 9 commits into from
Jul 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# Changelog
*[CalVer, YY.month.patch](https://calver.org/)*

## Future
## 22.7.5
- Add TRIO103: `except BaseException` or `except trio.Cancelled` with a code path that doesn't re-raise
- Add TRIO104: "Cancelled and BaseException must be re-raised" if user tries to return or raise a different exception.
- Added TRIO107: Async functions must have at least one checkpoint on every code path, unless an exception is raised
- Added TRIO108: Early return from async function must have at least one checkpoint on every code path before it.

## 22.7.4
- Added TRIO105 check for not immediately `await`ing async trio functions.
Expand Down
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,7 @@ pip install flake8-trio
- **TRIO104**: `Cancelled` and `BaseException` must be re-raised - when a user tries to `return` or `raise` a different exception.
- **TRIO105**: Calling a trio async function without immediately `await`ing it.
- **TRIO106**: trio must be imported with `import trio` for the linter to work
-
- **TRIO107**: Async functions must have at least one checkpoint on every code path, unless an exception is raised
- **TRIO108**: Early return from async function must have at least one checkpoint on every code path before it, unless an exception is raised.
Checkpoints are `await`, `async with` `async for`.
126 changes: 121 additions & 5 deletions flake8_trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@

import ast
import tokenize
from typing import Any, Collection, Generator, List, Optional, Tuple, Type, Union
from typing import Any, Generator, Iterable, List, Optional, Tuple, Type, Union

# CalVer: YY.month.patch, e.g. first release of July 2022 == "22.7.1"
__version__ = "22.7.4"
__version__ = "22.7.5"


Error = Tuple[int, int, str, Type[Any]]
Expand Down Expand Up @@ -47,6 +47,16 @@ def run(cls, tree: ast.AST) -> Generator[Error, None, None]:
visitor.visit(tree)
yield from visitor.problems

def visit_nodes(self, nodes: Union[ast.AST, Iterable[ast.AST]]) -> None:
if isinstance(nodes, ast.AST):
self.visit(nodes)
else:
for node in nodes:
self.visit(node)

def error(self, error: str, lineno: int, col: int, *args: Any, **kwargs: Any):
self.problems.append(make_error(error, lineno, col, *args, **kwargs))


class TrioScope:
def __init__(self, node: ast.Call, funcname: str, packagename: str):
Expand Down Expand Up @@ -88,7 +98,7 @@ def get_trio_scope(node: ast.AST, *names: str) -> Optional[TrioScope]:
return None


def has_decorator(decorator_list: List[ast.expr], names: Collection[str]):
def has_decorator(decorator_list: List[ast.expr], *names: str):
for dec in decorator_list:
if (isinstance(dec, ast.Name) and dec.id in names) or (
isinstance(dec, ast.Attribute) and dec.attr in names
Expand Down Expand Up @@ -135,7 +145,7 @@ def visit_FunctionDef(
self._yield_is_error = False

# check for @<context_manager_name> and @<library>.<context_manager_name>
if has_decorator(node.decorator_list, context_manager_names):
if has_decorator(node.decorator_list, *context_manager_names):
self._context_manager = True

self.generic_visit(node)
Expand Down Expand Up @@ -238,7 +248,7 @@ def visit_FunctionDef(
outer_cm = self._context_manager

# check for @<context_manager_name> and @<library>.<context_manager_name>
if has_decorator(node.decorator_list, context_manager_names):
if has_decorator(node.decorator_list, *context_manager_names):
self._context_manager = True

self.generic_visit(node)
Expand Down Expand Up @@ -462,6 +472,110 @@ def visit_Call(self, node: ast.Call):
self.generic_visit(node)


class Visitor107_108(Flake8TrioVisitor):
def __init__(self) -> None:
super().__init__()
self.all_await = True

def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef):
outer = self.all_await

# do not require checkpointing if overloading
self.all_await = has_decorator(node.decorator_list, "overload")
self.generic_visit(node)

if not self.all_await:
self.error(TRIO107, node.lineno, node.col_offset)

self.all_await = outer

def visit_Return(self, node: ast.Return):
self.generic_visit(node)
if not self.all_await:
self.error(TRIO108, node.lineno, node.col_offset)
# avoid duplicate error messages
self.all_await = True

# disregard raise's in nested functions
def visit_FunctionDef(self, node: ast.FunctionDef):
outer = self.all_await
self.generic_visit(node)
self.all_await = outer

# checkpoint functions
def visit_Await(
self, node: Union[ast.Await, ast.AsyncFor, ast.AsyncWith, ast.Raise]
):
self.generic_visit(node)
self.all_await = True

visit_AsyncFor = visit_Await
visit_AsyncWith = visit_Await

# raising exception means we don't need to checkpoint so we can treat it as one
visit_Raise = visit_Await

# valid checkpoint if there's valid checkpoints (or raise) in at least one of:
# (try or else) and all excepts
# finally
def visit_Try(self, node: ast.Try):
if self.all_await:
self.generic_visit(node)
return

# check try body
self.visit_nodes(node.body)
body_await = self.all_await
self.all_await = False

# check that all except handlers checkpoint (await or most likely raise)
all_except_await = True
for handler in node.handlers:
self.visit_nodes(handler)
all_except_await &= self.all_await
self.all_await = False

# check else
self.visit_nodes(node.orelse)

# (try or else) and all excepts
self.all_await = (body_await or self.all_await) and all_except_await

# finally can check on it's own
self.visit_nodes(node.finalbody)

# valid checkpoint if both body and orelse have checkpoints
def visit_If(self, node: Union[ast.If, ast.IfExp]):
if self.all_await:
self.generic_visit(node)
return

# ignore checkpoints in condition
self.visit_nodes(node.test)
self.all_await = False

# check body
self.visit_nodes(node.body)
body_await = self.all_await
self.all_await = False

self.visit_nodes(node.orelse)

# checkpoint if both body and else
self.all_await = body_await and self.all_await

# inline if
visit_IfExp = visit_If

# ignore checkpoints in loops due to continue/break shenanigans
def visit_While(self, node: Union[ast.While, ast.For]):
outer = self.all_await
self.generic_visit(node)
self.all_await = outer

visit_For = visit_While


class Plugin:
name = __name__
version = __version__
Expand All @@ -487,3 +601,5 @@ def run(self) -> Generator[Tuple[int, int, str, Type[Any]], None, None]:
TRIO104 = "TRIO104: Cancelled (and therefore BaseException) must be re-raised"
TRIO105 = "TRIO105: Trio async function {} must be immediately awaited"
TRIO106 = "TRIO106: trio must be imported with `import trio` for the linter to work"
TRIO107 = "TRIO107: Async functions must have at least one checkpoint on every code path, unless an exception is raised"
TRIO108 = "TRIO108: Early return from async function must have at least one checkpoint on every code path before it."
32 changes: 31 additions & 1 deletion tests/test_flake8_trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
TRIO104,
TRIO105,
TRIO106,
TRIO107,
TRIO108,
Error,
Plugin,
make_error,
Expand Down Expand Up @@ -94,7 +96,7 @@ def test_trio102(self):
make_error(TRIO102, 92, 8),
make_error(TRIO102, 94, 8),
make_error(TRIO102, 101, 12),
make_error(TRIO102, 123, 12),
make_error(TRIO102, 124, 12),
)

def test_trio103_104(self):
Expand Down Expand Up @@ -173,6 +175,34 @@ def test_trio106(self):
make_error(TRIO106, 6, 0),
)

def test_trio107_108(self):
self.assert_expected_errors(
"trio107_108.py",
make_error(TRIO107, 13, 0),
# if
make_error(TRIO107, 18, 0),
make_error(TRIO107, 36, 0),
# ifexp
make_error(TRIO107, 46, 0),
# loops
make_error(TRIO107, 51, 0),
make_error(TRIO107, 56, 0),
make_error(TRIO107, 69, 0),
make_error(TRIO107, 74, 0),
# try
make_error(TRIO107, 83, 0),
# early return
make_error(TRIO108, 140, 4),
make_error(TRIO108, 145, 8),
# nested function definition
make_error(TRIO107, 149, 0),
make_error(TRIO107, 159, 4),
make_error(TRIO107, 163, 0),
make_error(TRIO107, 170, 8),
make_error(TRIO107, 168, 0),
make_error(TRIO107, 174, 0),
)


@pytest.mark.fuzz
class TestFuzz(unittest.TestCase):
Expand Down
4 changes: 3 additions & 1 deletion tests/test_trio_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,6 @@ def runTest(self):
self.assertNotIn(lineno, func_error_lines, msg=test)
func_error_lines.add(lineno)

self.assertSetEqual(file_error_lines, func_error_lines, msg=test)
self.assertSequenceEqual(
sorted(file_error_lines), sorted(func_error_lines), msg=test
)
1 change: 1 addition & 0 deletions tests/trio100_py39.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ async def function_name():
trio.move_on_after(5), # error
):
pass
await function_name() # avoid TRIO107
5 changes: 3 additions & 2 deletions tests/trio102.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

async def foo():
try:
pass
await foo() # avoid TRIO107
finally:
with trio.move_on_after(deadline=30) as s:
s.shield = True
Expand Down Expand Up @@ -107,11 +107,12 @@ async def foo2():
yield 1
finally:
await foo() # safe
await foo() # avoid TRIO107


async def foo3():
try:
pass
await foo() # avoid TRIO107
finally:
with trio.move_on_after(30) as s, trio.fail_after(5):
s.shield = True
Expand Down
Loading