Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Type check array subscripts #420

Merged
merged 2 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 41 additions & 2 deletions guppylang/checker/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,10 @@
#:
#: All places are equipped with a unique id, a type and an optional definition AST
#: location. During linearity checking, they are tracked separately.
Place: TypeAlias = "Variable | FieldAccess"
Place: TypeAlias = "Variable | FieldAccess | SubscriptAccess"

#: Unique identifier for a `Place`.
PlaceId: TypeAlias = "Variable.Id | FieldAccess.Id"
PlaceId: TypeAlias = "Variable.Id | FieldAccess.Id | SubscriptAccess.Id"


@dataclass(frozen=True)
Expand Down Expand Up @@ -154,6 +154,45 @@ def replace_defined_at(self, node: AstNode | None) -> "FieldAccess":
return replace(self, exact_defined_at=node)


@dataclass(frozen=True)
class SubscriptAccess:
"""A place identifying a subscript `place[item]` access."""

parent: Place
item: Variable
ty: Type
item_expr: ast.expr
getitem_call: ast.expr
#: Only populated if this place occurs in an inout position
setitem_call: ast.expr | None = None

@dataclass(frozen=True)
class Id:
"""Identifier for subscript places."""

parent: PlaceId
item: Variable.Id

@cached_property
def id(self) -> "SubscriptAccess.Id":
"""The unique `PlaceId` identifier for this place."""
return SubscriptAccess.Id(self.parent.id, self.item.id)

@cached_property
def defined_at(self) -> AstNode | None:
"""Optional location where this place was last assigned to."""
return self.parent.defined_at

@property
def describe(self) -> str:
"""A human-readable description of this place for error messages."""
return f"Subscript `{self}`"

def __str__(self) -> str:
"""String representation of this place."""
return f"{self.parent}[...]"


PyScope = dict[str, Any]


Expand Down
82 changes: 74 additions & 8 deletions guppylang/checker/expr_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import sys
import traceback
from contextlib import suppress
from dataclasses import replace
from typing import Any, NoReturn, cast

from guppylang.ast_util import (
Expand All @@ -35,12 +36,15 @@
with_loc,
with_type,
)
from guppylang.cfg.builder import tmp_vars
from guppylang.checker.core import (
Context,
DummyEvalDict,
FieldAccess,
Globals,
Locals,
Place,
SubscriptAccess,
Variable,
)
from guppylang.definition.ty import TypeDef
Expand All @@ -56,13 +60,15 @@
DesugaredListComp,
FieldAccessAndDrop,
GlobalName,
InoutReturnSentinel,
IterEnd,
IterHasNext,
IterNext,
LocalCall,
MakeIter,
PlaceNode,
PyExpr,
SubscriptAccessAndDrop,
TensorCall,
TypeApply,
)
Expand Down Expand Up @@ -447,7 +453,7 @@ def _synthesize_binary(
node,
)

def _synthesize_instance_func(
def synthesize_instance_func(
self,
node: ast.expr,
args: list[ast.expr],
Expand Down Expand Up @@ -495,16 +501,37 @@ def visit_Compare(self, node: ast.Compare) -> tuple[ast.expr, Type]:

def visit_Subscript(self, node: ast.Subscript) -> tuple[ast.expr, Type]:
node.value, ty = self.synthesize(node.value)
item_expr, item_ty = self.synthesize(node.slice)
# Give the item a unique name so we can refer to it later in case we also want
# to compile a call to `__setitem__`
item = Variable(next(tmp_vars), item_ty, item_expr)
item_node = with_type(item_ty, with_loc(item_expr, PlaceNode(place=item)))
# Check a call to the `__getitem__` instance function
exp_sig = FunctionType(
[
FuncInput(ty, InputFlags.NoFlags),
FuncInput(ty, InputFlags.Inout),
FuncInput(ExistentialTypeVar.fresh("Key", False), InputFlags.NoFlags),
],
ExistentialTypeVar.fresh("Val", False),
)
return self._synthesize_instance_func(
node.value, [node.slice], "__getitem__", "not subscriptable", exp_sig
getitem_expr, result_ty = self.synthesize_instance_func(
node.value, [item_node], "__getitem__", "not subscriptable", exp_sig
)
# Subscripting a place is itself a place
expr: ast.expr
if isinstance(node.value, PlaceNode):
place = SubscriptAccess(
node.value.place, item, result_ty, item_expr, getitem_expr
)
expr = PlaceNode(place=place)
else:
# If the subscript is not on a place, then there is no way to address the
# other indices after this one has been projected out (e.g. `f()[0]` makes
# you loose access to all elements besides 0).
expr = SubscriptAccessAndDrop(
item=item, item_expr=item_expr, getitem_expr=getitem_expr
)
return with_loc(node, expr), result_ty

def visit_Call(self, node: ast.Call) -> tuple[ast.expr, Type]:
if len(node.keywords) > 0:
Expand Down Expand Up @@ -550,7 +577,7 @@ def visit_MakeIter(self, node: MakeIter) -> tuple[ast.expr, Type]:
exp_sig = FunctionType(
[FuncInput(ty, InputFlags.NoFlags)], ExistentialTypeVar.fresh("Iter", False)
)
expr, ty = self._synthesize_instance_func(
expr, ty = self.synthesize_instance_func(
node.value, [], "__iter__", "not iterable", exp_sig
)

Expand All @@ -574,7 +601,7 @@ def visit_IterHasNext(self, node: IterHasNext) -> tuple[ast.expr, Type]:
exp_sig = FunctionType(
[FuncInput(ty, InputFlags.NoFlags)], TupleType([bool_type(), ty])
)
return self._synthesize_instance_func(
return self.synthesize_instance_func(
node.value, [], "__hasnext__", "not an iterator", exp_sig, True
)

Expand All @@ -584,14 +611,14 @@ def visit_IterNext(self, node: IterNext) -> tuple[ast.expr, Type]:
[FuncInput(ty, InputFlags.NoFlags)],
TupleType([ExistentialTypeVar.fresh("T", False), ty]),
)
return self._synthesize_instance_func(
return self.synthesize_instance_func(
node.value, [], "__next__", "not an iterator", exp_sig, True
)

def visit_IterEnd(self, node: IterEnd) -> tuple[ast.expr, Type]:
node.value, ty = self.synthesize(node.value)
exp_sig = FunctionType([FuncInput(ty, InputFlags.NoFlags)], NoneType())
return self._synthesize_instance_func(
return self.synthesize_instance_func(
node.value, [], "__end__", "not an iterator", exp_sig, True
)

Expand Down Expand Up @@ -714,6 +741,8 @@ def type_check_args(
new_args: list[ast.expr] = []
for inp, func_inp in zip(inputs, func_ty.inputs, strict=True):
a, s = ExprChecker(ctx).check(inp, func_inp.ty.substitute(subst), "argument")
if InputFlags.Inout in func_inp.flags and isinstance(a, PlaceNode):
a.place = check_inout_arg_place(a.place, ctx, a)
new_args.append(a)
subst |= s

Expand All @@ -734,6 +763,43 @@ def type_check_args(
return new_args, subst


def check_inout_arg_place(place: Place, ctx: Context, node: PlaceNode) -> Place:
"""Performs additional checks for place arguments in @inout position.

In particular, we need to check that places involving `place[item]` subscripts
implement the corresponding `__setitem__` method.
"""
match place:
case Variable():
return place
case FieldAccess(parent=parent):
return replace(place, parent=check_inout_arg_place(parent, ctx, node))
case SubscriptAccess(parent=parent, item=item, ty=ty):
# Check a call to the `__setitem__` instance function
exp_sig = FunctionType(
[
FuncInput(parent.ty, InputFlags.Inout),
FuncInput(item.ty, InputFlags.NoFlags),
FuncInput(ty, InputFlags.NoFlags),
],
NoneType(),
)
setitem_args = [
with_type(parent.ty, with_loc(node, PlaceNode(parent))),
with_type(item.ty, with_loc(node, PlaceNode(item))),
with_type(ty, with_loc(node, InoutReturnSentinel(var=place))),
]
setitem_call, _ = ExprSynthesizer(ctx).synthesize_instance_func(
setitem_args[0],
setitem_args[1:],
"__setitem__",
"not allowed in a subscripted `@inout` position",
exp_sig,
True,
)
return replace(place, setitem_call=setitem_call)


def synthesize_call(
func_ty: FunctionType, args: list[ast.expr], node: AstNode, ctx: Context
) -> tuple[list[ast.expr], Type, Inst]:
Expand Down
2 changes: 2 additions & 0 deletions guppylang/checker/linearity_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Locals,
Place,
PlaceId,
SubscriptAccess,
Variable,
)
from guppylang.definition.value import CallableDef
Expand Down Expand Up @@ -169,6 +170,7 @@ def _reassign_inout_args(self, func_ty: FunctionType, args: list[ast.expr]) -> N
match arg:
case PlaceNode(place=place):
for leaf in leaf_places(place):
assert not isinstance(leaf, SubscriptAccess)
leaf = leaf.replace_defined_at(arg)
self.scope.assign(leaf)
case arg if inp.ty.linear:
Expand Down
16 changes: 15 additions & 1 deletion guppylang/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,20 @@ class FieldAccessAndDrop(ast.expr):
)


class SubscriptAccessAndDrop(ast.expr):
"""A subscript element access on an object, dropping all the remaining items."""

item: "Variable"
item_expr: ast.expr
getitem_expr: ast.expr

_fields = (
"item",
"item_expr",
"getitem_expr",
)


class MakeIter(ast.expr):
"""Creates an iterator using the `__iter__` magic method.

Expand Down Expand Up @@ -205,7 +219,7 @@ class InoutReturnSentinel(ast.expr):
"""An invisible expression corresponding to an implicit use of @inout vars whenever
a function returns."""

var: "Variable | str"
var: "Place | str"

_fields = ("var",)

Expand Down
10 changes: 9 additions & 1 deletion guppylang/prelude/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,7 +655,15 @@ class Array:
"ArrayGet", args=[int_arg(), type_arg()], variable_remap={0: 1, 1: 0}
),
)
def __getitem__(self: array[T, n], idx: int) -> T: ...
def __getitem__(self: array[L, n] @ inout, idx: int) -> L: ...

@guppy.hugr_op(
builtins,
custom_op(
"ArraySet", args=[int_arg(), type_arg()], variable_remap={0: 1, 1: 0}
),
)
def __setitem__(self: array[L, n] @ inout, idx: int, value: L) -> None: ...

@guppy.custom(builtins, checker=ArrayLenChecker())
def __len__(self: array[T, n]) -> int: ...
Expand Down
7 changes: 0 additions & 7 deletions tests/error/array_errors/linear_index.err

This file was deleted.

17 changes: 0 additions & 17 deletions tests/error/array_errors/linear_index.py

This file was deleted.

7 changes: 7 additions & 0 deletions tests/error/inout_errors/subscript_not_setable.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Guppy compilation failed. Error in file $FILE:24

22: @guppy(module)
23: def test(c: MyImmutableContainer) -> MyImmutableContainer:
24: foo(c[0])
^^^^
GuppyTypeError: Expression of type `MyImmutableContainer` is not allowed in a subscripted `@inout` position since it does not implement the `__setitem__` method
28 changes: 28 additions & 0 deletions tests/error/inout_errors/subscript_not_setable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.prelude.builtins import inout
from guppylang.prelude.quantum import qubit, quantum

module = GuppyModule("test")
module.load(quantum)


@guppy.declare(module)
def foo(q: qubit @inout) -> None: ...


@guppy.struct(module)
class MyImmutableContainer:
q: qubit

@guppy.declare(module)
def __getitem__(self: "MyImmutableContainer" @inout, idx: int) -> qubit: ...


@guppy(module)
def test(c: MyImmutableContainer) -> MyImmutableContainer:
foo(c[0])
return c


module.compile()
7 changes: 7 additions & 0 deletions tests/error/type_errors/not_subscriptable.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Guppy compilation failed. Error in file $FILE:9

7: @guppy(module)
8: def foo(x: int) -> None:
9: x[0]
^
GuppyTypeError: Expression of type `int` is not subscriptable
12 changes: 12 additions & 0 deletions tests/error/type_errors/not_subscriptable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from guppylang.decorator import guppy
from guppylang.module import GuppyModule

module = GuppyModule("test")


@guppy(module)
def foo(x: int) -> None:
x[0]


module.compile()
7 changes: 7 additions & 0 deletions tests/error/type_errors/subscript_bad_item.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Guppy compilation failed. Error in file $FILE:10

8: @guppy(module)
9: def foo(xs: array[int, 42]) -> int:
10: return xs[1.0]
^^^
GuppyTypeError: Expected argument of type `int`, got `float`
Loading
Loading