From 59e82c8b86bf3b28c5779a829d94e5f380a73e66 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 27 Aug 2024 17:30:10 +0100 Subject: [PATCH 1/6] Init From 61997cbbd4dcaa432379b8d237e3cafd90b365e3 Mon Sep 17 00:00:00 2001 From: Mark Koch <48097969+mark-koch@users.noreply.github.com> Date: Fri, 30 Aug 2024 15:47:58 +0100 Subject: [PATCH 2/6] feat: Type check array subscripts (#420) Closes #416. See #415 for context. * Adds a new `SubscriptAccess` place that will be used to track array subscripts during linearity checking * This place is emitted when checking a subscript AST node * Ensures that subscripts in inout positions also implement a `__setitem__` method --- guppylang/checker/core.py | 43 +++++++++- guppylang/checker/expr_checker.py | 82 +++++++++++++++++-- guppylang/checker/linearity_checker.py | 2 + guppylang/nodes.py | 16 +++- guppylang/prelude/builtins.py | 10 ++- tests/error/array_errors/linear_index.err | 7 -- tests/error/array_errors/linear_index.py | 17 ---- .../inout_errors/subscript_not_setable.err | 7 ++ .../inout_errors/subscript_not_setable.py | 28 +++++++ tests/error/type_errors/not_subscriptable.err | 7 ++ tests/error/type_errors/not_subscriptable.py | 12 +++ .../error/type_errors/subscript_bad_item.err | 7 ++ tests/error/type_errors/subscript_bad_item.py | 13 +++ tests/integration/test_array.py | 2 + tests/integration/test_list.py | 3 + 15 files changed, 220 insertions(+), 36 deletions(-) delete mode 100644 tests/error/array_errors/linear_index.err delete mode 100644 tests/error/array_errors/linear_index.py create mode 100644 tests/error/inout_errors/subscript_not_setable.err create mode 100644 tests/error/inout_errors/subscript_not_setable.py create mode 100644 tests/error/type_errors/not_subscriptable.err create mode 100644 tests/error/type_errors/not_subscriptable.py create mode 100644 tests/error/type_errors/subscript_bad_item.err create mode 100644 tests/error/type_errors/subscript_bad_item.py diff --git a/guppylang/checker/core.py b/guppylang/checker/core.py index 3a03dd31..60ec1c4d 100644 --- a/guppylang/checker/core.py +++ b/guppylang/checker/core.py @@ -60,10 +60,10 @@ #: #: All places are equipped with a unique id, a type and an optional definition AST #: location. During linearity checking, they are tracked separately. -Place: TypeAlias = "Variable | FieldAccess" +Place: TypeAlias = "Variable | FieldAccess | SubscriptAccess" #: Unique identifier for a `Place`. -PlaceId: TypeAlias = "Variable.Id | FieldAccess.Id" +PlaceId: TypeAlias = "Variable.Id | FieldAccess.Id | SubscriptAccess.Id" @dataclass(frozen=True) @@ -154,6 +154,45 @@ def replace_defined_at(self, node: AstNode | None) -> "FieldAccess": return replace(self, exact_defined_at=node) +@dataclass(frozen=True) +class SubscriptAccess: + """A place identifying a subscript `place[item]` access.""" + + parent: Place + item: Variable + ty: Type + item_expr: ast.expr + getitem_call: ast.expr + #: Only populated if this place occurs in an inout position + setitem_call: ast.expr | None = None + + @dataclass(frozen=True) + class Id: + """Identifier for subscript places.""" + + parent: PlaceId + item: Variable.Id + + @cached_property + def id(self) -> "SubscriptAccess.Id": + """The unique `PlaceId` identifier for this place.""" + return SubscriptAccess.Id(self.parent.id, self.item.id) + + @cached_property + def defined_at(self) -> AstNode | None: + """Optional location where this place was last assigned to.""" + return self.parent.defined_at + + @property + def describe(self) -> str: + """A human-readable description of this place for error messages.""" + return f"Subscript `{self}`" + + def __str__(self) -> str: + """String representation of this place.""" + return f"{self.parent}[...]" + + PyScope = dict[str, Any] diff --git a/guppylang/checker/expr_checker.py b/guppylang/checker/expr_checker.py index cef50fb4..1242baa5 100644 --- a/guppylang/checker/expr_checker.py +++ b/guppylang/checker/expr_checker.py @@ -24,6 +24,7 @@ import sys import traceback from contextlib import suppress +from dataclasses import replace from typing import Any, NoReturn, cast from guppylang.ast_util import ( @@ -35,12 +36,15 @@ with_loc, with_type, ) +from guppylang.cfg.builder import tmp_vars from guppylang.checker.core import ( Context, DummyEvalDict, FieldAccess, Globals, Locals, + Place, + SubscriptAccess, Variable, ) from guppylang.definition.ty import TypeDef @@ -56,6 +60,7 @@ DesugaredListComp, FieldAccessAndDrop, GlobalName, + InoutReturnSentinel, IterEnd, IterHasNext, IterNext, @@ -63,6 +68,7 @@ MakeIter, PlaceNode, PyExpr, + SubscriptAccessAndDrop, TensorCall, TypeApply, ) @@ -447,7 +453,7 @@ def _synthesize_binary( node, ) - def _synthesize_instance_func( + def synthesize_instance_func( self, node: ast.expr, args: list[ast.expr], @@ -495,16 +501,37 @@ def visit_Compare(self, node: ast.Compare) -> tuple[ast.expr, Type]: def visit_Subscript(self, node: ast.Subscript) -> tuple[ast.expr, Type]: node.value, ty = self.synthesize(node.value) + item_expr, item_ty = self.synthesize(node.slice) + # Give the item a unique name so we can refer to it later in case we also want + # to compile a call to `__setitem__` + item = Variable(next(tmp_vars), item_ty, item_expr) + item_node = with_type(item_ty, with_loc(item_expr, PlaceNode(place=item))) + # Check a call to the `__getitem__` instance function exp_sig = FunctionType( [ - FuncInput(ty, InputFlags.NoFlags), + FuncInput(ty, InputFlags.Inout), FuncInput(ExistentialTypeVar.fresh("Key", False), InputFlags.NoFlags), ], ExistentialTypeVar.fresh("Val", False), ) - return self._synthesize_instance_func( - node.value, [node.slice], "__getitem__", "not subscriptable", exp_sig + getitem_expr, result_ty = self.synthesize_instance_func( + node.value, [item_node], "__getitem__", "not subscriptable", exp_sig ) + # Subscripting a place is itself a place + expr: ast.expr + if isinstance(node.value, PlaceNode): + place = SubscriptAccess( + node.value.place, item, result_ty, item_expr, getitem_expr + ) + expr = PlaceNode(place=place) + else: + # If the subscript is not on a place, then there is no way to address the + # other indices after this one has been projected out (e.g. `f()[0]` makes + # you loose access to all elements besides 0). + expr = SubscriptAccessAndDrop( + item=item, item_expr=item_expr, getitem_expr=getitem_expr + ) + return with_loc(node, expr), result_ty def visit_Call(self, node: ast.Call) -> tuple[ast.expr, Type]: if len(node.keywords) > 0: @@ -550,7 +577,7 @@ def visit_MakeIter(self, node: MakeIter) -> tuple[ast.expr, Type]: exp_sig = FunctionType( [FuncInput(ty, InputFlags.NoFlags)], ExistentialTypeVar.fresh("Iter", False) ) - expr, ty = self._synthesize_instance_func( + expr, ty = self.synthesize_instance_func( node.value, [], "__iter__", "not iterable", exp_sig ) @@ -574,7 +601,7 @@ def visit_IterHasNext(self, node: IterHasNext) -> tuple[ast.expr, Type]: exp_sig = FunctionType( [FuncInput(ty, InputFlags.NoFlags)], TupleType([bool_type(), ty]) ) - return self._synthesize_instance_func( + return self.synthesize_instance_func( node.value, [], "__hasnext__", "not an iterator", exp_sig, True ) @@ -584,14 +611,14 @@ def visit_IterNext(self, node: IterNext) -> tuple[ast.expr, Type]: [FuncInput(ty, InputFlags.NoFlags)], TupleType([ExistentialTypeVar.fresh("T", False), ty]), ) - return self._synthesize_instance_func( + return self.synthesize_instance_func( node.value, [], "__next__", "not an iterator", exp_sig, True ) def visit_IterEnd(self, node: IterEnd) -> tuple[ast.expr, Type]: node.value, ty = self.synthesize(node.value) exp_sig = FunctionType([FuncInput(ty, InputFlags.NoFlags)], NoneType()) - return self._synthesize_instance_func( + return self.synthesize_instance_func( node.value, [], "__end__", "not an iterator", exp_sig, True ) @@ -714,6 +741,8 @@ def type_check_args( new_args: list[ast.expr] = [] for inp, func_inp in zip(inputs, func_ty.inputs, strict=True): a, s = ExprChecker(ctx).check(inp, func_inp.ty.substitute(subst), "argument") + if InputFlags.Inout in func_inp.flags and isinstance(a, PlaceNode): + a.place = check_inout_arg_place(a.place, ctx, a) new_args.append(a) subst |= s @@ -734,6 +763,43 @@ def type_check_args( return new_args, subst +def check_inout_arg_place(place: Place, ctx: Context, node: PlaceNode) -> Place: + """Performs additional checks for place arguments in @inout position. + + In particular, we need to check that places involving `place[item]` subscripts + implement the corresponding `__setitem__` method. + """ + match place: + case Variable(): + return place + case FieldAccess(parent=parent): + return replace(place, parent=check_inout_arg_place(parent, ctx, node)) + case SubscriptAccess(parent=parent, item=item, ty=ty): + # Check a call to the `__setitem__` instance function + exp_sig = FunctionType( + [ + FuncInput(parent.ty, InputFlags.Inout), + FuncInput(item.ty, InputFlags.NoFlags), + FuncInput(ty, InputFlags.NoFlags), + ], + NoneType(), + ) + setitem_args = [ + with_type(parent.ty, with_loc(node, PlaceNode(parent))), + with_type(item.ty, with_loc(node, PlaceNode(item))), + with_type(ty, with_loc(node, InoutReturnSentinel(var=place))), + ] + setitem_call, _ = ExprSynthesizer(ctx).synthesize_instance_func( + setitem_args[0], + setitem_args[1:], + "__setitem__", + "not allowed in a subscripted `@inout` position", + exp_sig, + True, + ) + return replace(place, setitem_call=setitem_call) + + def synthesize_call( func_ty: FunctionType, args: list[ast.expr], node: AstNode, ctx: Context ) -> tuple[list[ast.expr], Type, Inst]: diff --git a/guppylang/checker/linearity_checker.py b/guppylang/checker/linearity_checker.py index 534246c3..7a518646 100644 --- a/guppylang/checker/linearity_checker.py +++ b/guppylang/checker/linearity_checker.py @@ -18,6 +18,7 @@ Locals, Place, PlaceId, + SubscriptAccess, Variable, ) from guppylang.definition.value import CallableDef @@ -169,6 +170,7 @@ def _reassign_inout_args(self, func_ty: FunctionType, args: list[ast.expr]) -> N match arg: case PlaceNode(place=place): for leaf in leaf_places(place): + assert not isinstance(leaf, SubscriptAccess) leaf = leaf.replace_defined_at(arg) self.scope.assign(leaf) case arg if inp.ty.linear: diff --git a/guppylang/nodes.py b/guppylang/nodes.py index a79b41d4..6b3394d2 100644 --- a/guppylang/nodes.py +++ b/guppylang/nodes.py @@ -94,6 +94,20 @@ class FieldAccessAndDrop(ast.expr): ) +class SubscriptAccessAndDrop(ast.expr): + """A subscript element access on an object, dropping all the remaining items.""" + + item: "Variable" + item_expr: ast.expr + getitem_expr: ast.expr + + _fields = ( + "item", + "item_expr", + "getitem_expr", + ) + + class MakeIter(ast.expr): """Creates an iterator using the `__iter__` magic method. @@ -205,7 +219,7 @@ class InoutReturnSentinel(ast.expr): """An invisible expression corresponding to an implicit use of @inout vars whenever a function returns.""" - var: "Variable | str" + var: "Place | str" _fields = ("var",) diff --git a/guppylang/prelude/builtins.py b/guppylang/prelude/builtins.py index d3303323..56c2f818 100644 --- a/guppylang/prelude/builtins.py +++ b/guppylang/prelude/builtins.py @@ -655,7 +655,15 @@ class Array: "ArrayGet", args=[int_arg(), type_arg()], variable_remap={0: 1, 1: 0} ), ) - def __getitem__(self: array[T, n], idx: int) -> T: ... + def __getitem__(self: array[L, n] @ inout, idx: int) -> L: ... + + @guppy.hugr_op( + builtins, + custom_op( + "ArraySet", args=[int_arg(), type_arg()], variable_remap={0: 1, 1: 0} + ), + ) + def __setitem__(self: array[L, n] @ inout, idx: int, value: L) -> None: ... @guppy.custom(builtins, checker=ArrayLenChecker()) def __len__(self: array[T, n]) -> int: ... diff --git a/tests/error/array_errors/linear_index.err b/tests/error/array_errors/linear_index.err deleted file mode 100644 index dfa0a247..00000000 --- a/tests/error/array_errors/linear_index.err +++ /dev/null @@ -1,7 +0,0 @@ -Guppy compilation failed. Error in file $FILE:14 - -12: @guppy(module) -13: def main(qs: array[qubit, 42]) -> int: -14: return qs[0] - ^^ -GuppyTypeError: Cannot instantiate non-linear type variable `T` in type `forall n, T: nat. (array[T, n], int) -> T` with linear type `qubit` diff --git a/tests/error/array_errors/linear_index.py b/tests/error/array_errors/linear_index.py deleted file mode 100644 index 79a6bc02..00000000 --- a/tests/error/array_errors/linear_index.py +++ /dev/null @@ -1,17 +0,0 @@ -import guppylang.prelude.quantum as quantum -from guppylang.decorator import guppy -from guppylang.module import GuppyModule -from guppylang.prelude.builtins import array -from guppylang.prelude.quantum import qubit - - -module = GuppyModule("test") -module.load(quantum) - - -@guppy(module) -def main(qs: array[qubit, 42]) -> int: - return qs[0] - - -module.compile() \ No newline at end of file diff --git a/tests/error/inout_errors/subscript_not_setable.err b/tests/error/inout_errors/subscript_not_setable.err new file mode 100644 index 00000000..510cfca5 --- /dev/null +++ b/tests/error/inout_errors/subscript_not_setable.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:24 + +22: @guppy(module) +23: def test(c: MyImmutableContainer) -> MyImmutableContainer: +24: foo(c[0]) + ^^^^ +GuppyTypeError: Expression of type `MyImmutableContainer` is not allowed in a subscripted `@inout` position since it does not implement the `__setitem__` method diff --git a/tests/error/inout_errors/subscript_not_setable.py b/tests/error/inout_errors/subscript_not_setable.py new file mode 100644 index 00000000..7a1323fc --- /dev/null +++ b/tests/error/inout_errors/subscript_not_setable.py @@ -0,0 +1,28 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from guppylang.prelude.builtins import inout +from guppylang.prelude.quantum import qubit, quantum + +module = GuppyModule("test") +module.load(quantum) + + +@guppy.declare(module) +def foo(q: qubit @inout) -> None: ... + + +@guppy.struct(module) +class MyImmutableContainer: + q: qubit + + @guppy.declare(module) + def __getitem__(self: "MyImmutableContainer" @inout, idx: int) -> qubit: ... + + +@guppy(module) +def test(c: MyImmutableContainer) -> MyImmutableContainer: + foo(c[0]) + return c + + +module.compile() diff --git a/tests/error/type_errors/not_subscriptable.err b/tests/error/type_errors/not_subscriptable.err new file mode 100644 index 00000000..0611a3bf --- /dev/null +++ b/tests/error/type_errors/not_subscriptable.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:9 + +7: @guppy(module) +8: def foo(x: int) -> None: +9: x[0] + ^ +GuppyTypeError: Expression of type `int` is not subscriptable diff --git a/tests/error/type_errors/not_subscriptable.py b/tests/error/type_errors/not_subscriptable.py new file mode 100644 index 00000000..a93e2961 --- /dev/null +++ b/tests/error/type_errors/not_subscriptable.py @@ -0,0 +1,12 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + +module = GuppyModule("test") + + +@guppy(module) +def foo(x: int) -> None: + x[0] + + +module.compile() diff --git a/tests/error/type_errors/subscript_bad_item.err b/tests/error/type_errors/subscript_bad_item.err new file mode 100644 index 00000000..7337c3b0 --- /dev/null +++ b/tests/error/type_errors/subscript_bad_item.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:10 + +8: @guppy(module) +9: def foo(xs: array[int, 42]) -> int: +10: return xs[1.0] + ^^^ +GuppyTypeError: Expected argument of type `int`, got `float` diff --git a/tests/error/type_errors/subscript_bad_item.py b/tests/error/type_errors/subscript_bad_item.py new file mode 100644 index 00000000..ee4e0a6e --- /dev/null +++ b/tests/error/type_errors/subscript_bad_item.py @@ -0,0 +1,13 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from guppylang.prelude.builtins import array + +module = GuppyModule("test") + + +@guppy(module) +def foo(xs: array[int, 42]) -> int: + return xs[1.0] + + +module.compile() diff --git a/tests/integration/test_array.py b/tests/integration/test_array.py index 49ac3fc7..1ce01107 100644 --- a/tests/integration/test_array.py +++ b/tests/integration/test_array.py @@ -1,3 +1,4 @@ +import pytest from hugr import ops from hugr.std.int import IntVal @@ -23,6 +24,7 @@ def main(xs: array[float, 42]) -> int: assert val.val.v == 42 +@pytest.mark.skip("Skipped until Hugr lowering is updated") def test_index(validate): @compile_guppy def main(xs: array[int, 5], i: int) -> int: diff --git a/tests/integration/test_list.py b/tests/integration/test_list.py index e90f1bfe..30a1a6d5 100644 --- a/tests/integration/test_list.py +++ b/tests/integration/test_list.py @@ -1,3 +1,5 @@ +import pytest + from tests.util import compile_guppy @@ -37,6 +39,7 @@ def test(xs: list[int]) -> list[int]: validate(test) +@pytest.mark.skip("Requires updating lists to use inout") def test_subscript(validate): @compile_guppy def test(xs: list[float], i: int) -> float: From ded9e1c9332e8d6ae08c9535f16ac2c5c3cac66c Mon Sep 17 00:00:00 2001 From: Mark Koch <48097969+mark-koch@users.noreply.github.com> Date: Fri, 30 Aug 2024 16:59:16 +0100 Subject: [PATCH 3/6] feat: Update linearity checker to handle subscripts (#421) Closes #417 and closes #252. See #415 for context. --- guppylang/checker/linearity_checker.py | 76 ++++++++++++++++--- .../array_errors/subscript_after_use.err | 7 ++ .../error/array_errors/subscript_after_use.py | 21 +++++ tests/error/array_errors/subscript_drop.err | 7 ++ tests/error/array_errors/subscript_drop.py | 21 +++++ .../array_errors/subscript_non_inout.err | 7 ++ .../error/array_errors/subscript_non_inout.py | 18 +++++ .../array_errors/use_after_subscript.err | 7 ++ .../error/array_errors/use_after_subscript.py | 21 +++++ 9 files changed, 173 insertions(+), 12 deletions(-) create mode 100644 tests/error/array_errors/subscript_after_use.err create mode 100644 tests/error/array_errors/subscript_after_use.py create mode 100644 tests/error/array_errors/subscript_drop.err create mode 100644 tests/error/array_errors/subscript_drop.py create mode 100644 tests/error/array_errors/subscript_non_inout.err create mode 100644 tests/error/array_errors/subscript_non_inout.py create mode 100644 tests/error/array_errors/use_after_subscript.err create mode 100644 tests/error/array_errors/use_after_subscript.py diff --git a/guppylang/checker/linearity_checker.py b/guppylang/checker/linearity_checker.py index 7a518646..21cb1b6a 100644 --- a/guppylang/checker/linearity_checker.py +++ b/guppylang/checker/linearity_checker.py @@ -32,9 +32,14 @@ InoutReturnSentinel, LocalCall, PlaceNode, + SubscriptAccessAndDrop, TensorCall, ) -from guppylang.tys.ty import FunctionType, InputFlags, StructType +from guppylang.tys.ty import ( + FunctionType, + InputFlags, + StructType, +) class Scope(Locals[PlaceId, Place]): @@ -136,16 +141,31 @@ def visit_PlaceNode(self, node: PlaceNode, /, is_inout_arg: bool = False) -> Non "ownership of the value.", node, ) - for place in leaf_places(node.place): - x = place.id - if (use := self.scope.used(x)) and place.ty.linear: + # Places involving subscripts are handled differently since we ignore everything + # after the subscript for the purposes of linearity checking + if subscript := contains_subscript(node.place): + if not is_inout_arg and subscript.parent.ty.linear: raise GuppyError( - f"{place.describe} with linear type `{place.ty}` was already " - "used (at {0})", + "Subscripting on expression with linear type " + f"`{subscript.parent.ty}` is only allowed in `@inout` position", node, - [use], ) - self.scope.use(x, node) + self.scope.assign(subscript.item) + # Visiting the `__getitem__(place.parent, place.item)` call ensures that we + # linearity-check the parent and element. + self.visit(subscript.getitem_call) + # For all other places, we record uses of all leafs + else: + for place in leaf_places(node.place): + x = place.id + if (use := self.scope.used(x)) and place.ty.linear: + raise GuppyError( + f"{place.describe} with linear type `{place.ty}` was already " + "used (at {0})", + node, + [use], + ) + self.scope.use(x, node) def visit_Assign(self, node: ast.Assign) -> None: self.visit(node.value) @@ -169,10 +189,7 @@ def _reassign_inout_args(self, func_ty: FunctionType, args: list[ast.expr]) -> N if InputFlags.Inout in inp.flags: match arg: case PlaceNode(place=place): - for leaf in leaf_places(place): - assert not isinstance(leaf, SubscriptAccess) - leaf = leaf.replace_defined_at(arg) - self.scope.assign(leaf) + self._reassign_single_inout_arg(place, arg) case arg if inp.ty.linear: raise GuppyError( f"Inout argument with linear type `{inp.ty}` would be " @@ -182,6 +199,19 @@ def _reassign_inout_args(self, func_ty: FunctionType, args: list[ast.expr]) -> N arg, ) + def _reassign_single_inout_arg(self, place: Place, node: ast.expr) -> None: + """Helper function to reassign a single inout argument after a function call.""" + # Places involving subscripts are given back by visiting the `__setitem__` call + if subscript := contains_subscript(place): + assert subscript.setitem_call is not None + self.visit(subscript.setitem_call) + self._reassign_single_inout_arg(subscript.parent, node) + else: + for leaf in leaf_places(place): + assert not isinstance(leaf, SubscriptAccess) + leaf = leaf.replace_defined_at(node) + self.scope.assign(leaf) + def visit_GlobalCall(self, node: GlobalCall) -> None: func = self.globals[node.def_id] assert isinstance(func, CallableDef) @@ -214,6 +244,19 @@ def visit_FieldAccessAndDrop(self, node: FieldAccessAndDrop) -> None: node.value, ) + def visit_SubscriptAccessAndDrop(self, node: SubscriptAccessAndDrop) -> None: + # A subscript access on a value that is not a place. This means the value can no + # longer be accessed after the item has been projected out. Thus, this is only + # legal if the items in the container are not linear + elem_ty = get_type(node.getitem_expr) + if elem_ty.linear: + raise GuppyTypeError( + f"Remaining linear items with type `{elem_ty}` are not used", node + ) + self.visit(node.item_expr) + self.scope.assign(node.item) + self.visit(node.getitem_expr) + def visit_Expr(self, node: ast.Expr) -> None: # An expression statement where the return value is discarded self.visit(node.value) @@ -357,6 +400,15 @@ def leaf_places(place: Place) -> Iterator[Place]: yield place +def contains_subscript(place: Place) -> SubscriptAccess | None: + """Checks if a place contains a subscript access and returns the rightmost one.""" + while not isinstance(place, Variable): + if isinstance(place, SubscriptAccess): + return place + place = place.parent + return None + + def is_inout_var(place: Place) -> TypeGuard[Variable]: """Checks whether a place is an @inout variable.""" return isinstance(place, Variable) and InputFlags.Inout in place.flags diff --git a/tests/error/array_errors/subscript_after_use.err b/tests/error/array_errors/subscript_after_use.err new file mode 100644 index 00000000..098b28ed --- /dev/null +++ b/tests/error/array_errors/subscript_after_use.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:18 + +16: @guppy(module) +17: def main(qs: array[qubit, 42]) -> array[qubit, 42]: +18: return foo(qs, qs[0]) + ^^ +GuppyError: Variable `qs` with linear type `array[qubit, 42]` was already used (at 18:15) diff --git a/tests/error/array_errors/subscript_after_use.py b/tests/error/array_errors/subscript_after_use.py new file mode 100644 index 00000000..37c9e6d2 --- /dev/null +++ b/tests/error/array_errors/subscript_after_use.py @@ -0,0 +1,21 @@ +import guppylang.prelude.quantum as quantum +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from guppylang.prelude.builtins import array, inout +from guppylang.prelude.quantum import qubit + + +module = GuppyModule("test") +module.load(quantum) + + +@guppy.declare(module) +def foo(qs: array[qubit, 42], q: qubit @inout) -> array[qubit, 42]: ... + + +@guppy(module) +def main(qs: array[qubit, 42]) -> array[qubit, 42]: + return foo(qs, qs[0]) + + +module.compile() \ No newline at end of file diff --git a/tests/error/array_errors/subscript_drop.err b/tests/error/array_errors/subscript_drop.err new file mode 100644 index 00000000..448abbbc --- /dev/null +++ b/tests/error/array_errors/subscript_drop.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:18 + +16: @guppy(module) +17: def main() -> qubit: +18: return foo()[0] + ^^^^^^^^ +GuppyTypeError: Remaining linear items with type `qubit` are not used diff --git a/tests/error/array_errors/subscript_drop.py b/tests/error/array_errors/subscript_drop.py new file mode 100644 index 00000000..efd91b92 --- /dev/null +++ b/tests/error/array_errors/subscript_drop.py @@ -0,0 +1,21 @@ +import guppylang.prelude.quantum as quantum +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from guppylang.prelude.builtins import array +from guppylang.prelude.quantum import qubit + + +module = GuppyModule("test") +module.load(quantum) + + +@guppy.declare(module) +def foo() -> array[qubit, 10]: ... + + +@guppy(module) +def main() -> qubit: + return foo()[0] + + +module.compile() \ No newline at end of file diff --git a/tests/error/array_errors/subscript_non_inout.err b/tests/error/array_errors/subscript_non_inout.err new file mode 100644 index 00000000..55df1950 --- /dev/null +++ b/tests/error/array_errors/subscript_non_inout.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:14 + +12: @guppy(module) +13: def main(qs: array[qubit, 42]) -> tuple[qubit, array[qubit, 42]]: +14: q = qs[0] + ^^^^^ +GuppyError: Subscripting on expression with linear type `array[qubit, 42]` is only allowed in `@inout` position diff --git a/tests/error/array_errors/subscript_non_inout.py b/tests/error/array_errors/subscript_non_inout.py new file mode 100644 index 00000000..70756ead --- /dev/null +++ b/tests/error/array_errors/subscript_non_inout.py @@ -0,0 +1,18 @@ +import guppylang.prelude.quantum as quantum +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from guppylang.prelude.builtins import array +from guppylang.prelude.quantum import qubit + + +module = GuppyModule("test") +module.load(quantum) + + +@guppy(module) +def main(qs: array[qubit, 42]) -> tuple[qubit, array[qubit, 42]]: + q = qs[0] + return q, qs + + +module.compile() \ No newline at end of file diff --git a/tests/error/array_errors/use_after_subscript.err b/tests/error/array_errors/use_after_subscript.err new file mode 100644 index 00000000..566f65de --- /dev/null +++ b/tests/error/array_errors/use_after_subscript.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:18 + +16: @guppy(module) +17: def main(qs: array[qubit, 42]) -> array[qubit, 42]: +18: return foo(qs[0], qs) + ^^^^^ +GuppyError: Variable `qs` with linear type `array[qubit, 42]` was already used (at 18:22) diff --git a/tests/error/array_errors/use_after_subscript.py b/tests/error/array_errors/use_after_subscript.py new file mode 100644 index 00000000..49b5e808 --- /dev/null +++ b/tests/error/array_errors/use_after_subscript.py @@ -0,0 +1,21 @@ +import guppylang.prelude.quantum as quantum +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from guppylang.prelude.builtins import array, inout +from guppylang.prelude.quantum import qubit + + +module = GuppyModule("test") +module.load(quantum) + + +@guppy.declare(module) +def foo(q: qubit @inout, qs: array[qubit, 42]) -> array[qubit, 42]: ... + + +@guppy(module) +def main(qs: array[qubit, 42]) -> array[qubit, 42]: + return foo(qs[0], qs) + + +module.compile() \ No newline at end of file From 851c8d11b2af827f9fd8b9c81cb18d6092091381 Mon Sep 17 00:00:00 2001 From: Mark Koch <48097969+mark-koch@users.noreply.github.com> Date: Fri, 30 Aug 2024 17:09:05 +0100 Subject: [PATCH 4/6] feat: Hugr lowering for array indexing and add integration tests (#422) Closes #418. See #415 for context. Note that `array.__getitem__` and `array.__setitem__` are currently lowered to dummy Hugr ops. The "swapping with `None`" logic explained in #415 will follow in a future PR (see #419) --- guppylang/compiler/core.py | 3 + guppylang/compiler/expr_compiler.py | 28 ++++- guppylang/prelude/_internal/compiler.py | 44 ++++++++ guppylang/prelude/builtins.py | 17 +-- tests/integration/test_array.py | 143 +++++++++++++++++++++++- 5 files changed, 219 insertions(+), 16 deletions(-) diff --git a/guppylang/compiler/core.py b/guppylang/compiler/core.py index 0abc642a..067e0600 100644 --- a/guppylang/compiler/core.py +++ b/guppylang/compiler/core.py @@ -75,6 +75,9 @@ def __setitem__(self, place: Place, port: Wire) -> None: else: self.locals[place.id] = port + def __contains__(self, place: Place) -> bool: + return place.id in self.locals + def __copy__(self) -> "DFContainer": # Make a copy of the var map so that mutating the copy doesn't # mutate our variable mapping diff --git a/guppylang/compiler/expr_compiler.py b/guppylang/compiler/expr_compiler.py index 207c61d8..beebc886 100644 --- a/guppylang/compiler/expr_compiler.py +++ b/guppylang/compiler/expr_compiler.py @@ -18,6 +18,7 @@ from guppylang.ast_util import AstVisitor, get_type, with_loc, with_type from guppylang.cfg.builder import tmp_vars from guppylang.checker.core import Variable +from guppylang.checker.linearity_checker import contains_subscript from guppylang.compiler.core import CompilerBase, DFContainer from guppylang.definition.value import CompiledCallableDef, CompiledValueDef from guppylang.error import GuppyError, InternalGuppyError @@ -27,9 +28,11 @@ FieldAccessAndDrop, GlobalCall, GlobalName, + InoutReturnSentinel, LocalCall, PlaceNode, ResultExpr, + SubscriptAccessAndDrop, TensorCall, TypeApply, ) @@ -167,6 +170,10 @@ def visit_Constant(self, node: ast.Constant) -> Wire: raise InternalGuppyError("Unsupported constant expression in compiler") def visit_PlaceNode(self, node: PlaceNode) -> Wire: + if subscript := contains_subscript(node.place): + if subscript.item not in self.dfg: + self.dfg[subscript.item] = self.visit(subscript.item_expr) + self.dfg[subscript] = self.visit(subscript.getitem_call) return self.dfg[node.place] def visit_GlobalName(self, node: GlobalName) -> Wire: @@ -183,6 +190,10 @@ def visit_GlobalName(self, node: GlobalName) -> Wire: def visit_Name(self, node: ast.Name) -> Wire: raise InternalGuppyError("Node should have been removed during type checking.") + def visit_InoutReturnSentinel(self, node: InoutReturnSentinel) -> Wire: + assert not isinstance(node.var, str) + return self.dfg[node.var] + def visit_Tuple(self, node: ast.Tuple) -> Wire: elems = [self.visit(e) for e in node.elts] types = [get_type(e) for e in node.elts] @@ -229,9 +240,18 @@ def _update_inout_ports( """Helper method that updates the ports for @inout arguments after a call.""" for inp, arg in zip(func_ty.inputs, args, strict=True): if InputFlags.Inout in inp.flags: - # Linearity checker ensures that inout arguments are places - assert isinstance(arg, PlaceNode) + # Linearity checker ensures that inout arguments that are not places + # can be safely dropped after the call returns + if not isinstance(arg, PlaceNode): + next(inout_ports) + continue self.dfg[arg.place] = next(inout_ports) + # Places involving subscripts need to generate code for the appropriate + # `__setitem__` call. Nested subscripts are handled automatically since + # `arg.place.parent` occurs as an inout arg of this call, so will also + # be recursively reassigned. + if subscript := contains_subscript(arg.place): + self.visit(subscript.setitem_call) assert next(inout_ports, None) is None, "Too many inout return ports" def visit_LocalCall(self, node: LocalCall) -> Wire: @@ -362,6 +382,10 @@ def visit_FieldAccessAndDrop(self, node: FieldAccessAndDrop) -> Wire: field_idx ] + def visit_SubscriptAccessAndDrop(self, node: SubscriptAccessAndDrop) -> Wire: + self.dfg[node.item] = self.visit(node.item_expr) + return self.visit(node.getitem_expr) + def visit_ResultExpr(self, node: ResultExpr) -> Wire: extra_args = [] if isinstance(node.base_ty, NumericType): diff --git a/guppylang/prelude/_internal/compiler.py b/guppylang/prelude/_internal/compiler.py index 798031e0..24451a94 100644 --- a/guppylang/prelude/_internal/compiler.py +++ b/guppylang/prelude/_internal/compiler.py @@ -6,6 +6,8 @@ from guppylang.definition.custom import ( CustomCallCompiler, ) +from guppylang.definition.value import CallReturnWires +from guppylang.error import InternalGuppyError from guppylang.tys.ty import NumericType # Note: Hugr's INT_T is 64bits, but guppy defaults to 32bits @@ -216,3 +218,45 @@ def compile(self, args: list[Wire]) -> list[Wire]: quantum_op("Reset")(ht.FunctionType([ht.Qubit], [ht.Qubit]), []), q ) return [q] + + +class ArrayGetitemCompiler(CustomCallCompiler): + """Compiler for the `array.__getitem__` function.""" + + def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires: + [array, idx] = args + array_ty = self.ty.input[0] + idx_ty = self.ty.input[1] + elem_ty = self.ty.output[0] + op = ops.Custom( + extension="guppy.unsupported.array", + signature=ht.FunctionType([array_ty, idx_ty], [array_ty, elem_ty]), + name="GetItem", + args=[arg.to_hugr() for arg in self.type_args], + ) + node = self.builder.add_op(op, array, idx) + return CallReturnWires(regular_returns=[node[1]], inout_returns=[node[0]]) + + def compile(self, args: list[Wire]) -> list[Wire]: + raise InternalGuppyError("Call compile_with_inouts instead") + + +class ArraySetitemCompiler(CustomCallCompiler): + """Compiler for the `array.__setitem__` function.""" + + def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires: + [array, idx, elem] = args + array_ty = self.ty.input[0] + idx_ty = self.ty.input[1] + elem_ty = self.ty.input[2] + op = ops.Custom( + extension="guppy.unsupported.array", + signature=ht.FunctionType([array_ty, idx_ty, elem_ty], [array_ty]), + name="SetItem", + args=[arg.to_hugr() for arg in self.type_args], + ) + node = self.builder.add_op(op, array, idx, elem) + return CallReturnWires(regular_returns=[], inout_returns=[node[0]]) + + def compile(self, args: list[Wire]) -> list[Wire]: + raise InternalGuppyError("Call compile_with_inouts instead") diff --git a/guppylang/prelude/builtins.py b/guppylang/prelude/builtins.py index 56c2f818..f376603d 100644 --- a/guppylang/prelude/builtins.py +++ b/guppylang/prelude/builtins.py @@ -19,6 +19,8 @@ UnsupportedChecker, ) from guppylang.prelude._internal.compiler import ( + ArrayGetitemCompiler, + ArraySetitemCompiler, FloatBoolCompiler, FloatDivmodCompiler, FloatFloordivCompiler, @@ -34,7 +36,6 @@ linst_op, list_op, logic_op, - type_arg, ) from guppylang.tys.builtin import ( array_type_def, @@ -649,20 +650,10 @@ def sort(self: linst[T]) -> None: ... @guppy.extend_type(builtins, array_type_def) class Array: - @guppy.hugr_op( - builtins, - custom_op( - "ArrayGet", args=[int_arg(), type_arg()], variable_remap={0: 1, 1: 0} - ), - ) + @guppy.custom(builtins, ArrayGetitemCompiler()) def __getitem__(self: array[L, n] @ inout, idx: int) -> L: ... - @guppy.hugr_op( - builtins, - custom_op( - "ArraySet", args=[int_arg(), type_arg()], variable_remap={0: 1, 1: 0} - ), - ) + @guppy.custom(builtins, ArraySetitemCompiler()) def __setitem__(self: array[L, n] @ inout, idx: int, value: L) -> None: ... @guppy.custom(builtins, checker=ArrayLenChecker()) diff --git a/tests/integration/test_array.py b/tests/integration/test_array.py index 1ce01107..200752d5 100644 --- a/tests/integration/test_array.py +++ b/tests/integration/test_array.py @@ -4,9 +4,12 @@ from guppylang.decorator import guppy from guppylang.module import GuppyModule -from guppylang.prelude.builtins import array +from guppylang.prelude.builtins import array, inout from tests.util import compile_guppy +from guppylang.prelude.quantum import qubit +import guppylang.prelude.quantum as quantum + def test_len(validate): module = GuppyModule("test") @@ -31,3 +34,141 @@ def main(xs: array[int, 5], i: int) -> int: return xs[0] + xs[i] + xs[xs[2 * i]] validate(main) + + +def test_subscript_drop_rest(validate): + module = GuppyModule("test") + module.load(quantum) + + @guppy.declare(module) + def foo() -> array[int, 10]: ... + + @guppy(module) + def main() -> int: + return foo()[0] + + validate(module.compile()) + + +def test_linear_subscript(validate): + module = GuppyModule("test") + module.load(quantum) + + @guppy.declare(module) + def foo(q: qubit @inout) -> None: ... + + @guppy(module) + def main(qs: array[qubit, 42], i: int) -> array[qubit, 42]: + foo(qs[i]) + return qs + + validate(module.compile()) + + +def test_inout_subscript(validate): + module = GuppyModule("test") + module.load(quantum) + + @guppy.declare(module) + def foo(q: qubit @inout) -> None: ... + + @guppy(module) + def main(qs: array[qubit, 42] @inout, i: int) -> None: + foo(qs[i]) + + validate(module.compile()) + + +def test_multi_subscripts(validate): + module = GuppyModule("test") + module.load(quantum) + + @guppy.declare(module) + def foo(q1: qubit @inout, q2: qubit @inout) -> None: ... + + @guppy(module) + def main(qs: array[qubit, 42]) -> array[qubit, 42]: + foo(qs[0], qs[1]) + foo(qs[0], qs[0]) # Will panic at runtime + return qs + + validate(module.compile()) + + +def test_struct_array(validate): + module = GuppyModule("test") + module.load(quantum) + + @guppy.struct(module) + class S: + q1: qubit + q2: qubit + + @guppy.declare(module) + def foo(q1: qubit @inout, q2: qubit @inout) -> None: ... + + @guppy(module) + def main(ss: array[S, 10]) -> array[S, 10]: + # This will panic at runtime :( + # To make this work, we would need to replace the qubits in the struct + # with `qubit | None` and write back `None` after `q1` has been extracted... + foo(ss[0].q1, ss[0].q2) + return ss + + validate(module.compile()) + + +def test_nested_subscripts(validate): + module = GuppyModule("test") + module.load(quantum) + + @guppy.declare(module) + def foo(q: qubit @inout) -> None: ... + + @guppy.declare(module) + def bar( + q1: qubit @inout, q2: qubit @inout, q3: qubit @inout, q4: qubit @inout + ) -> None: ... + + @guppy(module) + def main(qs: array[array[qubit, 13], 42]) -> array[array[qubit, 13], 42]: + foo(qs[0][0]) + # The following should work *without* panicking at runtime! Accessing `qs[0][0]` + # replaces one qubit with `None` but puts everything back into `qs` before + # going to the next argument. + bar(qs[0][0], qs[0][1], qs[1][0], qs[1][1]) + return qs + + validate(module.compile()) + + +def test_struct_nested_subscript(validate): + module = GuppyModule("test") + module.load(quantum) + + @guppy.struct(module) + class C: + c: qubit + blah: int + + @guppy.struct(module) + class B: + ys: array[array[C, 10], 20] + foo: C + + @guppy.struct(module) + class A: + xs: array[B, 42] + bar: qubit + baz: tuple[B, B] + + @guppy.declare(module) + def foo(q1: qubit @inout) -> None: ... + + @guppy(module) + def main(a: A, i: int, j: int, k: int) -> A: + foo(a.xs[i].ys[j][k].c) + return a + + validate(module.compile()) + From f3c3d2251f4528bedad37e98ece8215a2e83f3a9 Mon Sep 17 00:00:00 2001 From: Mark Koch <48097969+mark-koch@users.noreply.github.com> Date: Wed, 4 Sep 2024 11:26:53 +0100 Subject: [PATCH 5/6] feat: Lower linear array elements to optionals (#447) Closes #419 --- guppylang/prelude/_internal/compiler.py | 187 +++++++++++++++++++++--- guppylang/tys/builtin.py | 9 +- 2 files changed, 177 insertions(+), 19 deletions(-) diff --git a/guppylang/prelude/_internal/compiler.py b/guppylang/prelude/_internal/compiler.py index 5bcc6645..27fb9be8 100644 --- a/guppylang/prelude/_internal/compiler.py +++ b/guppylang/prelude/_internal/compiler.py @@ -1,6 +1,8 @@ import hugr from hugr import Wire, ops from hugr import tys as ht +from hugr import val as hv +from hugr.dfg import _DfBase from hugr.std.float import FLOAT_T from guppylang.definition.custom import ( @@ -250,20 +252,68 @@ def compile(self, args: list[Wire]) -> list[Wire]: class ArrayGetitemCompiler(CustomCallCompiler): """Compiler for the `array.__getitem__` function.""" - def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires: - [array, idx] = args - array_ty = self.ty.input[0] - idx_ty = self.ty.input[1] - elem_ty = self.ty.output[0] + def build_classical_getitem( + self, + array: Wire, + array_ty: ht.Type, + idx: Wire, + idx_ty: ht.Type, + elem_ty: ht.Type, + ) -> CallReturnWires: + """Lowers a call to `array.__getitem__` for classical arrays.""" + [ty_arg, len_arg] = self.type_args op = ops.Custom( extension="guppy.unsupported.array", signature=ht.FunctionType([array_ty, idx_ty], [array_ty, elem_ty]), - name="GetItem", - args=[arg.to_hugr() for arg in self.type_args], + name="get", + args=[len_arg.to_hugr(), ty_arg.to_hugr()], ) node = self.builder.add_op(op, array, idx) return CallReturnWires(regular_returns=[node[1]], inout_returns=[node[0]]) + def build_linear_getitem( + self, + array: Wire, + array_ty: ht.Type, + idx: Wire, + idx_ty: ht.Type, + elem_ty: ht.Type, + ) -> CallReturnWires: + """Lowers a call to `array.__getitem__` for linear arrays.""" + # Swap out the element at the given index with `None`. The `to_hugr` + # implementation of the array type ensures that linear element types are turned + # into optionals. + elem_opt_ty = ht.Sum([[elem_ty], []]) + none = self.builder.add_op(ops.Tag(1, elem_opt_ty)) + length = self.type_args[1].to_hugr() + array, elem_opt = build_array_set( + self.builder, + array, + array_ty, + idx, + idx_ty, + none, + elem_opt_ty, + length, + ) + # Make sure that the element we got out is not None + conditional = self.builder.add_conditional(elem_opt) + with conditional.add_case(0) as case: + case.set_outputs(*case.inputs()) + with conditional.add_case(1) as case: + error = build_error(case, 1, "Linear array element has already been used") + case.set_outputs(build_panic(case, [], [elem_ty], error)) + return CallReturnWires(regular_returns=[conditional], inout_returns=[array]) + + def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires: + [array, idx] = args + [array_ty, idx_ty] = self.ty.input + [elem_ty, *_] = self.ty.output + if elem_ty.type_bound() == ht.TypeBound.Any: + return self.build_linear_getitem(array, array_ty, idx, idx_ty, elem_ty) + else: + return self.build_classical_getitem(array, array_ty, idx, idx_ty, elem_ty) + def compile(self, args: list[Wire]) -> list[Wire]: raise InternalGuppyError("Call compile_with_inouts instead") @@ -271,19 +321,120 @@ def compile(self, args: list[Wire]) -> list[Wire]: class ArraySetitemCompiler(CustomCallCompiler): """Compiler for the `array.__setitem__` function.""" + def build_classical_setitem( + self, + array: Wire, + array_ty: ht.Type, + idx: Wire, + idx_ty: ht.Type, + elem: Wire, + elem_ty: ht.Type, + length: ht.TypeArg, + ) -> CallReturnWires: + """Lowers a call to `array.__setitem__` for classical arrays.""" + array, _ = build_array_set( + self.builder, array, array_ty, idx, idx_ty, elem, elem_ty, length + ) + return CallReturnWires(regular_returns=[], inout_returns=[array]) + + def build_linear_setitem( + self, + array: Wire, + array_ty: ht.Type, + idx: Wire, + idx_ty: ht.Type, + elem: Wire, + elem_ty: ht.Type, + length: ht.TypeArg, + ) -> CallReturnWires: + """Lowers a call to `array.__setitem__` for linear arrays.""" + # Embed the element into an optional + elem_opt_ty = ht.Sum([[elem_ty], []]) + elem = self.builder.add_op(ops.Tag(0, elem_opt_ty), elem) + array, old_elem = build_array_set( + self.builder, array, array_ty, idx, idx_ty, elem, elem_opt_ty, length + ) + # Check that the old element was `None` + conditional = self.builder.add_conditional(old_elem) + with conditional.add_case(0) as case: + error = build_error(case, 1, "Linear array element has not been used") + build_panic(case, [elem_ty], [], error, *case.inputs()) + case.set_outputs() + with conditional.add_case(1) as case: + case.set_outputs() + return CallReturnWires(regular_returns=[], inout_returns=[array]) + def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires: [array, idx, elem] = args - array_ty = self.ty.input[0] - idx_ty = self.ty.input[1] - elem_ty = self.ty.input[2] - op = ops.Custom( - extension="guppy.unsupported.array", - signature=ht.FunctionType([array_ty, idx_ty, elem_ty], [array_ty]), - name="SetItem", - args=[arg.to_hugr() for arg in self.type_args], - ) - node = self.builder.add_op(op, array, idx, elem) - return CallReturnWires(regular_returns=[], inout_returns=[node[0]]) + [array_ty, idx_ty, elem_ty] = self.ty.input + length = self.type_args[1].to_hugr() + if elem_ty.type_bound() == ht.TypeBound.Any: + return self.build_linear_setitem( + array, array_ty, idx, idx_ty, elem, elem_ty, length + ) + else: + return self.build_classical_setitem( + array, array_ty, idx, idx_ty, elem, elem_ty, length + ) def compile(self, args: list[Wire]) -> list[Wire]: raise InternalGuppyError("Call compile_with_inouts instead") + + +#: The Hugr error type +error_ty = ht.Opaque( + id="error", bound=ht.TypeBound.Copyable, args=[], extension="prelude" +) + + +def build_panic( + # TODO: Change to `_DfBase[ops.DfParentOp]` once `_DfBase` is covariant + builder: _DfBase[ops.Case], + in_tys: ht.TypeRow, + out_tys: ht.TypeRow, + err: Wire, + *args: Wire, +) -> Wire: + """Builds a panic operation.""" + op = ops.Custom( + extension="prelude", + signature=ht.FunctionType([error_ty, *in_tys], out_tys), + name="panic", + args=[ + ht.SequenceArg([ht.TypeTypeArg(ty) for ty in in_tys]), + ht.SequenceArg([ht.TypeTypeArg(ty) for ty in out_tys]), + ], + ) + return builder.add_op(op, err, *args) + + +def build_error(builder: _DfBase[ops.Case], signal: int, msg: str) -> Wire: + """Constructs and loads a static error value.""" + val = hv.Extension( + name="ConstError", + typ=error_ty, + val={"signal": signal, "message": msg}, + extensions=["prelude"], + ) + return builder.load(builder.add_const(val)) + + +def build_array_set( + builder: _DfBase[ops.DfParentOp], + array: Wire, + array_ty: ht.Type, + idx: Wire, + idx_ty: ht.Type, + elem: Wire, + elem_ty: ht.Type, + length: ht.TypeArg, +) -> tuple[Wire, Wire]: + """Builds an array set operation, returning the original element.""" + op = ops.Custom( + extension="guppy.unsupported.array", + signature=ht.FunctionType([array_ty, idx_ty, elem_ty], [array_ty, elem_ty]), + name="get", + args=[length, ht.TypeTypeArg(elem_ty)], + ) + array, swapped_elem = iter(builder.add_op(op, array, idx, elem)) + return array, swapped_elem diff --git a/guppylang/tys/builtin.py b/guppylang/tys/builtin.py index 427ea1e3..70159ee6 100644 --- a/guppylang/tys/builtin.py +++ b/guppylang/tys/builtin.py @@ -133,10 +133,17 @@ def _array_to_hugr(args: Sequence[Argument]) -> ht.Type: [ty_arg, len_arg] = args assert isinstance(ty_arg, TypeArg) assert isinstance(len_arg, ConstArg) + # Linear elements are turned into an optional to enable unsafe indexing. + # See `ArrayGetitemCompiler` for details. + elem_ty: ht.Type + if ty_arg.ty.linear: + elem_ty = ht.Sum([[ty_arg.ty.to_hugr()], []]) + else: + elem_ty = ty_arg.ty.to_hugr() return ht.Opaque( extension="prelude", id="array", - args=[len_arg.to_hugr(), ty_arg.to_hugr()], + args=[len_arg.to_hugr(), ht.TypeTypeArg(elem_ty)], bound=ty_arg.ty.hugr_bound, ) From 85dcceb76c9aa5f4004933fe8927f006cee99ed3 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Wed, 4 Sep 2024 12:31:03 +0200 Subject: [PATCH 6/6] Fix tests --- tests/error/array_errors/subscript_after_use.py | 2 +- tests/error/array_errors/subscript_drop.py | 2 +- tests/error/array_errors/use_after_subscript.py | 2 +- tests/error/inout_errors/subscript_not_setable.py | 2 +- tests/integration/test_array.py | 14 +++++++------- 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/error/array_errors/subscript_after_use.py b/tests/error/array_errors/subscript_after_use.py index 37c9e6d2..b05d31a9 100644 --- a/tests/error/array_errors/subscript_after_use.py +++ b/tests/error/array_errors/subscript_after_use.py @@ -6,7 +6,7 @@ module = GuppyModule("test") -module.load(quantum) +module.load_all(quantum) @guppy.declare(module) diff --git a/tests/error/array_errors/subscript_drop.py b/tests/error/array_errors/subscript_drop.py index efd91b92..8775a1d4 100644 --- a/tests/error/array_errors/subscript_drop.py +++ b/tests/error/array_errors/subscript_drop.py @@ -6,7 +6,7 @@ module = GuppyModule("test") -module.load(quantum) +module.load_all(quantum) @guppy.declare(module) diff --git a/tests/error/array_errors/use_after_subscript.py b/tests/error/array_errors/use_after_subscript.py index 49b5e808..938686f7 100644 --- a/tests/error/array_errors/use_after_subscript.py +++ b/tests/error/array_errors/use_after_subscript.py @@ -6,7 +6,7 @@ module = GuppyModule("test") -module.load(quantum) +module.load_all(quantum) @guppy.declare(module) diff --git a/tests/error/inout_errors/subscript_not_setable.py b/tests/error/inout_errors/subscript_not_setable.py index 7a1323fc..ca592bc5 100644 --- a/tests/error/inout_errors/subscript_not_setable.py +++ b/tests/error/inout_errors/subscript_not_setable.py @@ -4,7 +4,7 @@ from guppylang.prelude.quantum import qubit, quantum module = GuppyModule("test") -module.load(quantum) +module.load_all(quantum) @guppy.declare(module) diff --git a/tests/integration/test_array.py b/tests/integration/test_array.py index 93dcb14d..ea62087d 100644 --- a/tests/integration/test_array.py +++ b/tests/integration/test_array.py @@ -64,7 +64,7 @@ def main(ys: array[int, 0]) -> array[array[int, 0], 2]: def test_subscript_drop_rest(validate): module = GuppyModule("test") - module.load(quantum) + module.load_all(quantum) @guppy.declare(module) def foo() -> array[int, 10]: ... @@ -78,7 +78,7 @@ def main() -> int: def test_linear_subscript(validate): module = GuppyModule("test") - module.load(quantum) + module.load_all(quantum) @guppy.declare(module) def foo(q: qubit @inout) -> None: ... @@ -93,7 +93,7 @@ def main(qs: array[qubit, 42], i: int) -> array[qubit, 42]: def test_inout_subscript(validate): module = GuppyModule("test") - module.load(quantum) + module.load_all(quantum) @guppy.declare(module) def foo(q: qubit @inout) -> None: ... @@ -107,7 +107,7 @@ def main(qs: array[qubit, 42] @inout, i: int) -> None: def test_multi_subscripts(validate): module = GuppyModule("test") - module.load(quantum) + module.load_all(quantum) @guppy.declare(module) def foo(q1: qubit @inout, q2: qubit @inout) -> None: ... @@ -123,7 +123,7 @@ def main(qs: array[qubit, 42]) -> array[qubit, 42]: def test_struct_array(validate): module = GuppyModule("test") - module.load(quantum) + module.load_all(quantum) @guppy.struct(module) class S: @@ -146,7 +146,7 @@ def main(ss: array[S, 10]) -> array[S, 10]: def test_nested_subscripts(validate): module = GuppyModule("test") - module.load(quantum) + module.load_all(quantum) @guppy.declare(module) def foo(q: qubit @inout) -> None: ... @@ -170,7 +170,7 @@ def main(qs: array[array[qubit, 13], 42]) -> array[array[qubit, 13], 42]: def test_struct_nested_subscript(validate): module = GuppyModule("test") - module.load(quantum) + module.load_all(quantum) @guppy.struct(module) class C: