Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Hugr lowering for array indexing and add integration tests #422

Merged
merged 5 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 41 additions & 2 deletions guppylang/checker/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,10 @@
#:
#: All places are equipped with a unique id, a type and an optional definition AST
#: location. During linearity checking, they are tracked separately.
Place: TypeAlias = "Variable | FieldAccess"
Place: TypeAlias = "Variable | FieldAccess | SubscriptAccess"

#: Unique identifier for a `Place`.
PlaceId: TypeAlias = "Variable.Id | FieldAccess.Id"
PlaceId: TypeAlias = "Variable.Id | FieldAccess.Id | SubscriptAccess.Id"


@dataclass(frozen=True)
Expand Down Expand Up @@ -154,6 +154,45 @@ def replace_defined_at(self, node: AstNode | None) -> "FieldAccess":
return replace(self, exact_defined_at=node)


@dataclass(frozen=True)
class SubscriptAccess:
    """A place that refers to a subscript access `place[item]` on another place."""

    #: The place that is being subscripted
    parent: Place
    #: The variable holding the subscript item
    item: Variable
    #: The type of this place
    ty: Type
    #: The expression that `item` stands for
    item_expr: ast.expr
    #: The call used to read the item out of `parent`
    getitem_call: ast.expr
    #: Only populated if this place occurs in an inout position
    setitem_call: ast.expr | None = None

    @dataclass(frozen=True)
    class Id:
        """Identifier for subscript places.

        Made up of the identifier of the subscripted place and the identifier of
        the item variable.
        """

        parent: PlaceId
        item: Variable.Id

    @cached_property
    def id(self) -> "SubscriptAccess.Id":
        """The unique `PlaceId` identifier for this place."""
        parent_id = self.parent.id
        item_id = self.item.id
        return SubscriptAccess.Id(parent=parent_id, item=item_id)

    @cached_property
    def defined_at(self) -> AstNode | None:
        """Optional location where this place was last assigned to.

        Subscript places do not track a location of their own; the one of the
        parent place is reported instead.
        """
        parent = self.parent
        return parent.defined_at

    @property
    def describe(self) -> str:
        """A human-readable description of this place for error messages."""
        return f"Subscript `{self}`"

    def __str__(self) -> str:
        """String representation of this place.

        The item is not spelled out, only the parent followed by `[...]`.
        """
        return f"{self.parent}[...]"


PyScope = dict[str, Any]


Expand Down
82 changes: 74 additions & 8 deletions guppylang/checker/expr_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import sys
import traceback
from contextlib import suppress
from dataclasses import replace
from typing import Any, NoReturn, cast

from guppylang.ast_util import (
Expand All @@ -35,12 +36,15 @@
with_loc,
with_type,
)
from guppylang.cfg.builder import tmp_vars
from guppylang.checker.core import (
Context,
DummyEvalDict,
FieldAccess,
Globals,
Locals,
Place,
SubscriptAccess,
Variable,
)
from guppylang.definition.ty import TypeDef
Expand All @@ -56,13 +60,15 @@
DesugaredListComp,
FieldAccessAndDrop,
GlobalName,
InoutReturnSentinel,
IterEnd,
IterHasNext,
IterNext,
LocalCall,
MakeIter,
PlaceNode,
PyExpr,
SubscriptAccessAndDrop,
TensorCall,
TypeApply,
)
Expand Down Expand Up @@ -447,7 +453,7 @@ def _synthesize_binary(
node,
)

def _synthesize_instance_func(
def synthesize_instance_func(
self,
node: ast.expr,
args: list[ast.expr],
Expand Down Expand Up @@ -495,16 +501,37 @@ def visit_Compare(self, node: ast.Compare) -> tuple[ast.expr, Type]:

def visit_Subscript(self, node: ast.Subscript) -> tuple[ast.expr, Type]:
node.value, ty = self.synthesize(node.value)
item_expr, item_ty = self.synthesize(node.slice)
# Give the item a unique name so we can refer to it later in case we also want
# to compile a call to `__setitem__`
item = Variable(next(tmp_vars), item_ty, item_expr)
item_node = with_type(item_ty, with_loc(item_expr, PlaceNode(place=item)))
# Check a call to the `__getitem__` instance function
exp_sig = FunctionType(
[
FuncInput(ty, InputFlags.NoFlags),
FuncInput(ty, InputFlags.Inout),
FuncInput(ExistentialTypeVar.fresh("Key", False), InputFlags.NoFlags),
],
ExistentialTypeVar.fresh("Val", False),
)
return self._synthesize_instance_func(
node.value, [node.slice], "__getitem__", "not subscriptable", exp_sig
getitem_expr, result_ty = self.synthesize_instance_func(
node.value, [item_node], "__getitem__", "not subscriptable", exp_sig
)
# Subscripting a place is itself a place
expr: ast.expr
if isinstance(node.value, PlaceNode):
place = SubscriptAccess(
node.value.place, item, result_ty, item_expr, getitem_expr
)
expr = PlaceNode(place=place)
else:
# If the subscript is not on a place, then there is no way to address the
# other indices after this one has been projected out (e.g. `f()[0]` makes
# you lose access to all elements besides 0).
expr = SubscriptAccessAndDrop(
item=item, item_expr=item_expr, getitem_expr=getitem_expr
)
return with_loc(node, expr), result_ty

def visit_Call(self, node: ast.Call) -> tuple[ast.expr, Type]:
if len(node.keywords) > 0:
Expand Down Expand Up @@ -550,7 +577,7 @@ def visit_MakeIter(self, node: MakeIter) -> tuple[ast.expr, Type]:
exp_sig = FunctionType(
[FuncInput(ty, InputFlags.NoFlags)], ExistentialTypeVar.fresh("Iter", False)
)
expr, ty = self._synthesize_instance_func(
expr, ty = self.synthesize_instance_func(
node.value, [], "__iter__", "not iterable", exp_sig
)

Expand All @@ -574,7 +601,7 @@ def visit_IterHasNext(self, node: IterHasNext) -> tuple[ast.expr, Type]:
exp_sig = FunctionType(
[FuncInput(ty, InputFlags.NoFlags)], TupleType([bool_type(), ty])
)
return self._synthesize_instance_func(
return self.synthesize_instance_func(
node.value, [], "__hasnext__", "not an iterator", exp_sig, True
)

Expand All @@ -584,14 +611,14 @@ def visit_IterNext(self, node: IterNext) -> tuple[ast.expr, Type]:
[FuncInput(ty, InputFlags.NoFlags)],
TupleType([ExistentialTypeVar.fresh("T", False), ty]),
)
return self._synthesize_instance_func(
return self.synthesize_instance_func(
node.value, [], "__next__", "not an iterator", exp_sig, True
)

def visit_IterEnd(self, node: IterEnd) -> tuple[ast.expr, Type]:
node.value, ty = self.synthesize(node.value)
exp_sig = FunctionType([FuncInput(ty, InputFlags.NoFlags)], NoneType())
return self._synthesize_instance_func(
return self.synthesize_instance_func(
node.value, [], "__end__", "not an iterator", exp_sig, True
)

Expand Down Expand Up @@ -714,6 +741,8 @@ def type_check_args(
new_args: list[ast.expr] = []
for inp, func_inp in zip(inputs, func_ty.inputs, strict=True):
a, s = ExprChecker(ctx).check(inp, func_inp.ty.substitute(subst), "argument")
if InputFlags.Inout in func_inp.flags and isinstance(a, PlaceNode):
a.place = check_inout_arg_place(a.place, ctx, a)
new_args.append(a)
subst |= s

Expand All @@ -734,6 +763,43 @@ def type_check_args(
return new_args, subst


def check_inout_arg_place(place: Place, ctx: Context, node: PlaceNode) -> Place:
"""Performs additional checks for place arguments in @inout position.

In particular, we need to check that places involving `place[item]` subscripts
implement the corresponding `__setitem__` method.
"""
match place:
case Variable():
return place
case FieldAccess(parent=parent):
return replace(place, parent=check_inout_arg_place(parent, ctx, node))
case SubscriptAccess(parent=parent, item=item, ty=ty):
# Check a call to the `__setitem__` instance function
exp_sig = FunctionType(
[
FuncInput(parent.ty, InputFlags.Inout),
FuncInput(item.ty, InputFlags.NoFlags),
FuncInput(ty, InputFlags.NoFlags),
],
NoneType(),
)
setitem_args = [
with_type(parent.ty, with_loc(node, PlaceNode(parent))),
with_type(item.ty, with_loc(node, PlaceNode(item))),
with_type(ty, with_loc(node, InoutReturnSentinel(var=place))),
]
setitem_call, _ = ExprSynthesizer(ctx).synthesize_instance_func(
setitem_args[0],
setitem_args[1:],
"__setitem__",
"not assignable",
exp_sig,
True,
)
return replace(place, setitem_call=setitem_call)


def synthesize_call(
func_ty: FunctionType, args: list[ast.expr], node: AstNode, ctx: Context
) -> tuple[list[ast.expr], Type, Inst]:
Expand Down
76 changes: 65 additions & 11 deletions guppylang/checker/linearity_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Locals,
Place,
PlaceId,
SubscriptAccess,
Variable,
)
from guppylang.definition.value import CallableDef
Expand All @@ -31,9 +32,14 @@
InoutReturnSentinel,
LocalCall,
PlaceNode,
SubscriptAccessAndDrop,
TensorCall,
)
from guppylang.tys.ty import FunctionType, InputFlags, StructType
from guppylang.tys.ty import (
FunctionType,
InputFlags,
StructType,
)


class Scope(Locals[PlaceId, Place]):
Expand Down Expand Up @@ -135,16 +141,31 @@ def visit_PlaceNode(self, node: PlaceNode, /, is_inout_arg: bool = False) -> Non
"ownership of the value.",
node,
)
for place in leaf_places(node.place):
x = place.id
if (use := self.scope.used(x)) and place.ty.linear:
# Places involving subscripts are handled differently since we ignore everything
# after the subscript for the purposes of linearity checking
if subscript := contains_subscript(node.place):
if not is_inout_arg and subscript.parent.ty.linear:
raise GuppyError(
f"{place.describe} with linear type `{place.ty}` was already "
"used (at {0})",
"Subscripting on expression with linear type "
f"`{subscript.parent.ty}` is only allowed in `@inout` position",
node,
[use],
)
self.scope.use(x, node)
self.scope.assign(subscript.item)
# Visiting the `__getitem__(place.parent, place.item)` call ensures that we
# linearity-check the parent and element.
self.visit(subscript.getitem_call)
# For all other places, we record uses of all leaves
else:
for place in leaf_places(node.place):
x = place.id
if (use := self.scope.used(x)) and place.ty.linear:
raise GuppyError(
f"{place.describe} with linear type `{place.ty}` was already "
"used (at {0})",
node,
[use],
)
self.scope.use(x, node)

def visit_Assign(self, node: ast.Assign) -> None:
self.visit(node.value)
Expand All @@ -168,9 +189,7 @@ def _reassign_inout_args(self, func_ty: FunctionType, args: list[ast.expr]) -> N
if InputFlags.Inout in inp.flags:
match arg:
case PlaceNode(place=place):
for leaf in leaf_places(place):
leaf = leaf.replace_defined_at(arg)
self.scope.assign(leaf)
self._reassign_single_inout_arg(place, arg)
case arg if inp.ty.linear:
raise GuppyError(
f"Inout argument with linear type `{inp.ty}` would be "
Expand All @@ -180,6 +199,19 @@ def _reassign_inout_args(self, func_ty: FunctionType, args: list[ast.expr]) -> N
arg,
)

def _reassign_single_inout_arg(self, place: Place, node: ast.expr) -> None:
    """Gives a single inout argument back to the scope after a function call.

    Places without subscripts have all of their leaf places reassigned, with `node`
    recorded as the new definition location. Places involving a subscript are
    handed back by visiting the `__setitem__` call of the rightmost subscript and
    then repeating the process for the parent of that subscript.
    """
    subscript = contains_subscript(place)
    if subscript is None:
        for leaf in leaf_places(place):
            assert not isinstance(leaf, SubscriptAccess)
            self.scope.assign(leaf.replace_defined_at(node))
        return
    # The `__setitem__` call must have been filled in for inout arguments
    assert subscript.setitem_call is not None
    self.visit(subscript.setitem_call)
    self._reassign_single_inout_arg(subscript.parent, node)

def visit_GlobalCall(self, node: GlobalCall) -> None:
func = self.globals[node.def_id]
assert isinstance(func, CallableDef)
Expand Down Expand Up @@ -212,6 +244,19 @@ def visit_FieldAccessAndDrop(self, node: FieldAccessAndDrop) -> None:
node.value,
)

def visit_SubscriptAccessAndDrop(self, node: SubscriptAccessAndDrop) -> None:
    """Checks a subscript access on a value that is not a place.

    Since the value can no longer be accessed once the item has been projected
    out, the access is rejected if the type of the `__getitem__` expression is
    linear.
    """
    elem_ty = get_type(node.getitem_expr)
    if not elem_ty.linear:
        # Check the item expression first and assign the item variable before
        # visiting the `__getitem__` expression
        self.visit(node.item_expr)
        self.scope.assign(node.item)
        self.visit(node.getitem_expr)
        return
    raise GuppyTypeError(
        f"Remaining linear items with type `{elem_ty}` are not used", node
    )

def visit_Expr(self, node: ast.Expr) -> None:
# An expression statement where the return value is discarded
self.visit(node.value)
Expand Down Expand Up @@ -355,6 +400,15 @@ def leaf_places(place: Place) -> Iterator[Place]:
yield place


def contains_subscript(place: Place) -> SubscriptAccess | None:
    """Looks for a subscript access inside a place and returns the rightmost one.

    The chain of parents is followed starting at `place` itself. The first
    `SubscriptAccess` found on the way is returned; `None` is returned if a
    variable is reached without finding one.
    """
    current = place
    while True:
        if isinstance(current, Variable):
            return None
        if isinstance(current, SubscriptAccess):
            return current
        current = current.parent


def is_inout_var(place: Place) -> TypeGuard[Variable]:
    """Tells whether a place is a variable that carries the @inout flag."""
    if not isinstance(place, Variable):
        return False
    return InputFlags.Inout in place.flags
Expand Down
Loading
Loading