Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add trio111, passing variables from context managers to nurseries opened outside them #22

Merged
5 changes: 3 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# Changelog
*[CalVer, YY.month.patch](https://calver.org/)*

## Future
- add TRIO112, nursery body with only a call to `nursery.start[_soon]` and not passing itself as a parameter can be replaced with a regular function call.
## 22.8.5
- Add TRIO111: Variable, from context manager opened inside nursery, passed to `start[_soon]` might be invalidly accesed while in use, due to context manager closing before the nursery. This is usually a bug, and nurseries should generally be the inner-most context manager.
- Add TRIO112: this single-task nursery could be replaced by awaiting the function call directly.

## 22.8.4
- Fix TRIO108 raising errors on yields in some sync code.
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,5 @@ pip install flake8-trio
Checkpoints are `await`, `async for`, and `async with` (on one of enter/exit).
- **TRIO109**: Async function definition with a `timeout` parameter - use `trio.[fail/move_on]_[after/at]` instead
- **TRIO110**: `while <condition>: await trio.sleep()` should be replaced by a `trio.Event`.
- **TRIO111**: Variable, from context manager opened inside nursery, passed to `start[_soon]` might be invalidly accesed while in use, due to context manager closing before the nursery. This is usually a bug, and nurseries should generally be the inner-most context manager.
- **TRIO112**: nursery body with only a call to `nursery.start[_soon]` and not passing itself as a parameter can be replaced with a regular function call.
146 changes: 124 additions & 22 deletions flake8_trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
)

# CalVer: YY.month.patch, e.g. first release of July 2022 == "22.7.1"
__version__ = "22.8.4"
__version__ = "22.8.5"


Error_codes = {
Expand Down Expand Up @@ -55,7 +55,12 @@
"`trio.[fail/move_on]_[after/at]` instead"
),
"TRIO110": "`while <condition>: await trio.sleep()` should be replaced by a `trio.Event`.",
"TRIO112": "Redundant nursery {}, consider replacing with a regular function call",
"TRIO111": (
"variable {2} is usable within the context manager on line {0}, but that "
"will close before nursery opened on line {1} - this is usually a bug. "
"Nurseries should generally be the inner-most context manager."
),
"TRIO112": "Redundant nursery {}, consider replacing with directly awaiting the function call",
}


Expand Down Expand Up @@ -162,10 +167,18 @@ def error(self, error: str, node: HasLineCol, *args: object):
if not self.suppress_errors:
self._problems.append(Error(error, node.lineno, node.col_offset, *args))

def get_state(self, *attrs: str) -> Dict[str, Any]:
def get_state(self, *attrs: str, copy: bool = False) -> Dict[str, Any]:
if not attrs:
attrs = tuple(self.__dict__.keys())
return {attr: getattr(self, attr) for attr in attrs if attr != "_problems"}
res: Dict[str, Any] = {}
for attr in attrs:
if attr == "_problems":
continue
value = getattr(self, attr)
if copy and hasattr(value, "copy"):
value = value.copy()
res[attr] = value
return res

def set_state(self, attrs: Dict[str, Any], copy: bool = False):
for attr, value in attrs.items():
Expand All @@ -187,37 +200,68 @@ def has_decorator(decorator_list: List[ast.expr], *names: str):
return False


# handles 100, 101, 106, 109, 110
# handles 100, 101, 106, 109, 110, 111, 112
class VisitorMiscChecks(Flake8TrioVisitor):
class NurseryCall(NamedTuple):
stack_index: int
name: str

class TrioContextManager(NamedTuple):
lineno: int
name: str
is_nursery: bool

def __init__(self):
super().__init__()

# variables only used for 101
# 101
self._yield_is_error = False
self._safe_decorator = False

# ---- 100, 101 ----
# 111
self._context_managers: List[VisitorMiscChecks.TrioContextManager] = []
self._nursery_call: Optional[VisitorMiscChecks.NurseryCall] = None

self.defaults = self.get_state(copy=True)

# ---- 100, 101, 111, 112 ----
def visit_With(self, node: Union[ast.With, ast.AsyncWith]):
# 100
self.check_for_trio100(node)
self.check_for_trio112(node)

# 101 for rest of function
outer = self.get_state("_yield_is_error")
outer = self.get_state("_yield_is_error", "_context_managers", copy=True)

# Check for a `with trio.<scope_creater>`
if not self._safe_decorator:
for item in (i.context_expr for i in node.items):
if (
get_matching_call(item, "open_nursery", *cancel_scope_names)
is not None
):
self._yield_is_error = True
break
for item in node.items:
# 101
# if there's no safe decorator,
# and it's not yet been determined that yield is error
# and this withitem opens a cancelscope:
# then yielding is unsafe
if (
not self._safe_decorator
and not self._yield_is_error
and get_matching_call(
item.context_expr, "open_nursery", *cancel_scope_names
)
is not None
):
self._yield_is_error = True

self.generic_visit(node)
# 111
# if a withitem is saved in a variable,
# push its line, variable, and whether it's a trio nursery
# to the _context_managers stack,
if isinstance(item.optional_vars, ast.Name):
self._context_managers.append(
self.TrioContextManager(
item.context_expr.lineno,
item.optional_vars.id,
get_matching_call(item.context_expr, "open_nursery")
is not None,
)
)

# reset yield_is_error
self.generic_visit(node)
self.set_state(outer)

visit_AsyncWith = visit_With
Expand All @@ -236,7 +280,7 @@ def check_for_trio100(self, node: Union[ast.With, ast.AsyncWith]):
# ---- 101 ----
def visit_FunctionDef(self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]):
outer = self.get_state()
self._yield_is_error = False
self.set_state(self.defaults, copy=True)

# check for @<context_manager_name> and @<library>.<context_manager_name>
if has_decorator(node.decorator_list, *context_manager_names):
Expand All @@ -251,6 +295,12 @@ def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef):
self.check_for_trio109(node)
self.visit_FunctionDef(node)

def visit_Lambda(self, node: ast.Lambda):
outer = self.get_state()
self.set_state(self.defaults, copy=True)
self.generic_visit(node)
self.set_state(outer)

# ---- 101 ----
def visit_Yield(self, node: ast.Yield):
if self._yield_is_error:
Expand All @@ -260,8 +310,11 @@ def visit_Yield(self, node: ast.Yield):

# ---- 109 ----
def check_for_trio109(self, node: ast.AsyncFunctionDef):
# pending configuration or a more sophisticated check, ignore
# all functions with a decorator
if node.decorator_list:
return

args = node.args
for arg in (*args.posonlyargs, *args.args, *args.kwonlyargs):
if arg.arg == "timeout":
Expand All @@ -277,6 +330,7 @@ def visit_Import(self, node: ast.Import):
for name in node.names:
if name.name == "trio" and name.asname is not None:
self.error("TRIO106", node)
self.generic_visit(node)

# ---- 110 ----
def visit_While(self, node: ast.While):
Expand All @@ -292,6 +346,53 @@ def check_for_trio110(self, node: ast.While):
):
self.error("TRIO110", node)

# ---- 111 ----
# if it's a <X>.start[_soon] call
# and <X> is a nursery listed in self._context_managers:
# Save <X>'s index in self._context_managers to guard against cm's higher in the
# stack being passed as parameters to it. (and save <X> for the error message)
def visit_Call(self, node: ast.Call):
outer = self.get_state("_nursery_call")

if (
isinstance(node.func, ast.Attribute)
and isinstance(node.func.value, ast.Name)
and node.func.attr in ("start", "start_soon")
):
self._nursery_call = None
for i, cm in enumerate(self._context_managers):
if node.func.value.id == cm.name:
# don't break upon finding a nursery in case there's multiple cm's
# on the stack with the same name
if cm.is_nursery:
self._nursery_call = self.NurseryCall(i, node.func.attr)
else:
self._nursery_call = None

self.generic_visit(node)
self.set_state(outer)

# If we're inside a <X>.start[_soon] call (where <X> is a nursery),
# and we're accessing a variable cm that's on the self._context_managers stack,
# with a higher index than <X>:
# Raise error since the scope of cm may close before the function passed to the
# nursery finishes.
def visit_Name(self, node: ast.Name):
self.generic_visit(node)
if self._nursery_call is None:
return

for i, cm in enumerate(self._context_managers):
if cm.name == node.id and i > self._nursery_call.stack_index:
self.error(
"TRIO111",
node,
cm.lineno,
self._context_managers[self._nursery_call.stack_index].lineno,
node.id,
self._nursery_call.name,
)

# if with has a withitem `trio.open_nursery() as <X>`,
# and the body is only a single expression <X>.start[_soon](),
# and does not pass <X> as a parameter to the expression
Expand Down Expand Up @@ -323,6 +424,7 @@ def check_for_trio112(self, node: Union[ast.With, ast.AsyncWith]):
self.error("TRIO112", item.context_expr, var_name)


# used in 102, 103 and 104
def critical_except(node: ast.ExceptHandler) -> Optional[Statement]:
def has_exception(node: Optional[ast.expr]) -> str:
if isinstance(node, ast.Name) and node.id == "BaseException":
Expand Down
41 changes: 25 additions & 16 deletions tests/test_flake8_trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,20 +72,27 @@ def test_eval(test: str, path: str):
try:
# Append a bunch of empty strings so string formatting gives garbage
# instead of throwing an exception
args = eval(
f"[{reg_match}]",
{
"lineno": lineno,
"line": lineno,
"Statement": Statement,
"Stmt": Statement,
},
)
try:
args = eval(
f"[{reg_match}]",
{
"lineno": lineno,
"line": lineno,
"Statement": Statement,
"Stmt": Statement,
},
)
except NameError:
print(f"failed to eval on line {lineno}", file=sys.stderr)
raise

except Exception as e:
print(f"lineno: {lineno}, line: {line}", file=sys.stderr)
raise e
col, *args = args
if args:
col, *args = args
else:
col = 0
assert isinstance(
col, int
), f'invalid column "{col}" @L{lineno}, in "{line}"'
Expand Down Expand Up @@ -163,13 +170,15 @@ def assert_expected_errors(plugin: Plugin, include: Iterable[str], *expected: Er

def print_first_diff(errors: Sequence[Error], expected: Sequence[Error]):
first_error_line: List[Error] = []
for e in errors:
if e.line == errors[0].line:
first_error_line.append(e)
first_expected_line: List[Error] = []
for e in expected:
if e.line == expected[0].line:
first_expected_line.append(e)
for err, exp in zip(errors, expected):
if err == exp:
continue
if not first_error_line or err.line == first_error_line[0]:
first_error_line.append(err)
if not first_expected_line or exp.line == first_expected_line[0]:
first_expected_line.append(exp)

if first_expected_line != first_error_line:
print(
"First lines with different errors",
Expand Down
Loading