diff --git a/guppylang/checker/core.py b/guppylang/checker/core.py index 3a03dd31..60ec1c4d 100644 --- a/guppylang/checker/core.py +++ b/guppylang/checker/core.py @@ -60,10 +60,10 @@ #: #: All places are equipped with a unique id, a type and an optional definition AST #: location. During linearity checking, they are tracked separately. -Place: TypeAlias = "Variable | FieldAccess" +Place: TypeAlias = "Variable | FieldAccess | SubscriptAccess" #: Unique identifier for a `Place`. -PlaceId: TypeAlias = "Variable.Id | FieldAccess.Id" +PlaceId: TypeAlias = "Variable.Id | FieldAccess.Id | SubscriptAccess.Id" @dataclass(frozen=True) @@ -154,6 +154,45 @@ def replace_defined_at(self, node: AstNode | None) -> "FieldAccess": return replace(self, exact_defined_at=node) +@dataclass(frozen=True) +class SubscriptAccess: + """A place identifying a subscript `place[item]` access.""" + + parent: Place + item: Variable + ty: Type + item_expr: ast.expr + getitem_call: ast.expr + #: Only populated if this place occurs in an inout position + setitem_call: ast.expr | None = None + + @dataclass(frozen=True) + class Id: + """Identifier for subscript places.""" + + parent: PlaceId + item: Variable.Id + + @cached_property + def id(self) -> "SubscriptAccess.Id": + """The unique `PlaceId` identifier for this place.""" + return SubscriptAccess.Id(self.parent.id, self.item.id) + + @cached_property + def defined_at(self) -> AstNode | None: + """Optional location where this place was last assigned to.""" + return self.parent.defined_at + + @property + def describe(self) -> str: + """A human-readable description of this place for error messages.""" + return f"Subscript `{self}`" + + def __str__(self) -> str: + """String representation of this place.""" + return f"{self.parent}[...]" + + PyScope = dict[str, Any] diff --git a/guppylang/checker/expr_checker.py b/guppylang/checker/expr_checker.py index 556e0257..273ee2b7 100644 --- a/guppylang/checker/expr_checker.py +++ b/guppylang/checker/expr_checker.py @@ -24,6 +24,7 @@ import sys import traceback from contextlib import suppress +from dataclasses import replace from typing import Any, NoReturn, cast from guppylang.ast_util import ( @@ -35,12 +36,15 @@ with_loc, with_type, ) +from guppylang.cfg.builder import tmp_vars from guppylang.checker.core import ( Context, DummyEvalDict, FieldAccess, Globals, Locals, + Place, + SubscriptAccess, Variable, ) from guppylang.definition.common import Definition @@ -58,6 +62,7 @@ DesugaredListComp, FieldAccessAndDrop, GlobalName, + InoutReturnSentinel, IterEnd, IterHasNext, IterNext, @@ -66,6 +71,7 @@ PartialApply, PlaceNode, PyExpr, + SubscriptAccessAndDrop, TensorCall, TypeApply, ) @@ -491,7 +497,7 @@ def _synthesize_binary( node, ) - def _synthesize_instance_func( + def synthesize_instance_func( self, node: ast.expr, args: list[ast.expr], @@ -539,16 +545,37 @@ def visit_Compare(self, node: ast.Compare) -> tuple[ast.expr, Type]: def visit_Subscript(self, node: ast.Subscript) -> tuple[ast.expr, Type]: node.value, ty = self.synthesize(node.value) + item_expr, item_ty = self.synthesize(node.slice) + # Give the item a unique name so we can refer to it later in case we also want + # to compile a call to `__setitem__` + item = Variable(next(tmp_vars), item_ty, item_expr) + item_node = with_type(item_ty, with_loc(item_expr, PlaceNode(place=item))) + # Check a call to the `__getitem__` instance function exp_sig = FunctionType( [ - FuncInput(ty, InputFlags.NoFlags), + FuncInput(ty, InputFlags.Inout), FuncInput(ExistentialTypeVar.fresh("Key", False), InputFlags.NoFlags), ], ExistentialTypeVar.fresh("Val", False), ) - return self._synthesize_instance_func( - node.value, [node.slice], "__getitem__", "not subscriptable", exp_sig + getitem_expr, result_ty = self.synthesize_instance_func( + node.value, [item_node], "__getitem__", "not subscriptable", exp_sig ) + # Subscripting a place is itself a place + expr: ast.expr + if isinstance(node.value, PlaceNode): + place = SubscriptAccess( + node.value.place, item, result_ty, item_expr, getitem_expr + ) + expr = PlaceNode(place=place) + else: + # If the subscript is not on a place, then there is no way to address the + # other indices after this one has been projected out (e.g. `f()[0]` makes + # you loose access to all elements besides 0). + expr = SubscriptAccessAndDrop( + item=item, item_expr=item_expr, getitem_expr=getitem_expr + ) + return with_loc(node, expr), result_ty def visit_Call(self, node: ast.Call) -> tuple[ast.expr, Type]: if len(node.keywords) > 0: @@ -600,7 +627,7 @@ def visit_MakeIter(self, node: MakeIter) -> tuple[ast.expr, Type]: exp_sig = FunctionType( [FuncInput(ty, InputFlags.NoFlags)], ExistentialTypeVar.fresh("Iter", False) ) - expr, ty = self._synthesize_instance_func( + expr, ty = self.synthesize_instance_func( node.value, [], "__iter__", "not iterable", exp_sig ) @@ -624,7 +651,7 @@ def visit_IterHasNext(self, node: IterHasNext) -> tuple[ast.expr, Type]: exp_sig = FunctionType( [FuncInput(ty, InputFlags.NoFlags)], TupleType([bool_type(), ty]) ) - return self._synthesize_instance_func( + return self.synthesize_instance_func( node.value, [], "__hasnext__", "not an iterator", exp_sig, True ) @@ -634,14 +661,14 @@ def visit_IterNext(self, node: IterNext) -> tuple[ast.expr, Type]: [FuncInput(ty, InputFlags.NoFlags)], TupleType([ExistentialTypeVar.fresh("T", False), ty]), ) - return self._synthesize_instance_func( + return self.synthesize_instance_func( node.value, [], "__next__", "not an iterator", exp_sig, True ) def visit_IterEnd(self, node: IterEnd) -> tuple[ast.expr, Type]: node.value, ty = self.synthesize(node.value) exp_sig = FunctionType([FuncInput(ty, InputFlags.NoFlags)], NoneType()) - return self._synthesize_instance_func( + return self.synthesize_instance_func( node.value, [], "__end__", "not an iterator", exp_sig, True ) @@ -764,6 +791,8 @@ def type_check_args( new_args: list[ast.expr] = [] for inp, func_inp in zip(inputs, func_ty.inputs, strict=True): a, s = ExprChecker(ctx).check(inp, func_inp.ty.substitute(subst), "argument") + if InputFlags.Inout in func_inp.flags and isinstance(a, PlaceNode): + a.place = check_inout_arg_place(a.place, ctx, a) new_args.append(a) subst |= s @@ -784,6 +813,43 @@ def type_check_args( return new_args, subst +def check_inout_arg_place(place: Place, ctx: Context, node: PlaceNode) -> Place: + """Performs additional checks for place arguments in @inout position. + + In particular, we need to check that places involving `place[item]` subscripts + implement the corresponding `__setitem__` method. + """ + match place: + case Variable(): + return place + case FieldAccess(parent=parent): + return replace(place, parent=check_inout_arg_place(parent, ctx, node)) + case SubscriptAccess(parent=parent, item=item, ty=ty): + # Check a call to the `__setitem__` instance function + exp_sig = FunctionType( + [ + FuncInput(parent.ty, InputFlags.Inout), + FuncInput(item.ty, InputFlags.NoFlags), + FuncInput(ty, InputFlags.NoFlags), + ], + NoneType(), + ) + setitem_args = [ + with_type(parent.ty, with_loc(node, PlaceNode(parent))), + with_type(item.ty, with_loc(node, PlaceNode(item))), + with_type(ty, with_loc(node, InoutReturnSentinel(var=place))), + ] + setitem_call, _ = ExprSynthesizer(ctx).synthesize_instance_func( + setitem_args[0], + setitem_args[1:], + "__setitem__", + "not allowed in a subscripted `@inout` position", + exp_sig, + True, + ) + return replace(place, setitem_call=setitem_call) + + def synthesize_call( func_ty: FunctionType, args: list[ast.expr], node: AstNode, ctx: Context ) -> tuple[list[ast.expr], Type, Inst]: diff --git a/guppylang/checker/linearity_checker.py b/guppylang/checker/linearity_checker.py index b467103a..1bede919 100644 --- a/guppylang/checker/linearity_checker.py +++ b/guppylang/checker/linearity_checker.py @@ -18,6 +18,7 @@ Locals, Place, PlaceId, + SubscriptAccess, Variable, ) from guppylang.definition.custom import CustomFunctionDef @@ -33,9 +34,15 @@ LocalCall, PartialApply, PlaceNode, + SubscriptAccessAndDrop, TensorCall, ) -from guppylang.tys.ty import FuncInput, FunctionType, InputFlags, StructType +from guppylang.tys.ty import ( + FuncInput, + FunctionType, + InputFlags, + StructType, +) class Scope(Locals[PlaceId, Place]): @@ -137,16 +144,31 @@ def visit_PlaceNode(self, node: PlaceNode, /, is_inout_arg: bool = False) -> Non "ownership of the value.", node, ) - for place in leaf_places(node.place): - x = place.id - if (use := self.scope.used(x)) and place.ty.linear: + # Places involving subscripts are handled differently since we ignore everything + # after the subscript for the purposes of linearity checking + if subscript := contains_subscript(node.place): + if not is_inout_arg and subscript.parent.ty.linear: raise GuppyError( - f"{place.describe} with linear type `{place.ty}` was already " - "used (at {0})", + "Subscripting on expression with linear type " + f"`{subscript.parent.ty}` is only allowed in `@inout` position", node, - [use], ) - self.scope.use(x, node) + self.scope.assign(subscript.item) + # Visiting the `__getitem__(place.parent, place.item)` call ensures that we + # linearity-check the parent and element. + self.visit(subscript.getitem_call) + # For all other places, we record uses of all leafs + else: + for place in leaf_places(node.place): + x = place.id + if (use := self.scope.used(x)) and place.ty.linear: + raise GuppyError( + f"{place.describe} with linear type `{place.ty}` was already " + "used (at {0})", + node, + [use], + ) + self.scope.use(x, node) def visit_Assign(self, node: ast.Assign) -> None: self.visit(node.value) @@ -170,9 +192,7 @@ def _reassign_inout_args(self, func_ty: FunctionType, args: list[ast.expr]) -> N if InputFlags.Inout in inp.flags: match arg: case PlaceNode(place=place): - for leaf in leaf_places(place): - leaf = leaf.replace_defined_at(arg) - self.scope.assign(leaf) + self._reassign_single_inout_arg(place, arg) case arg if inp.ty.linear: raise GuppyError( f"Inout argument with linear type `{inp.ty}` would be " @@ -182,6 +202,19 @@ def _reassign_inout_args(self, func_ty: FunctionType, args: list[ast.expr]) -> N arg, ) + def _reassign_single_inout_arg(self, place: Place, node: ast.expr) -> None: + """Helper function to reassign a single inout argument after a function call.""" + # Places involving subscripts are given back by visiting the `__setitem__` call + if subscript := contains_subscript(place): + assert subscript.setitem_call is not None + self.visit(subscript.setitem_call) + self._reassign_single_inout_arg(subscript.parent, node) + else: + for leaf in leaf_places(place): + assert not isinstance(leaf, SubscriptAccess) + leaf = leaf.replace_defined_at(node) + self.scope.assign(leaf) + def visit_GlobalCall(self, node: GlobalCall) -> None: func = self.globals[node.def_id] assert isinstance(func, CallableDef) @@ -233,6 +266,19 @@ def visit_FieldAccessAndDrop(self, node: FieldAccessAndDrop) -> None: node.value, ) + def visit_SubscriptAccessAndDrop(self, node: SubscriptAccessAndDrop) -> None: + # A subscript access on a value that is not a place. This means the value can no + # longer be accessed after the item has been projected out. Thus, this is only + # legal if the items in the container are not linear + elem_ty = get_type(node.getitem_expr) + if elem_ty.linear: + raise GuppyTypeError( + f"Remaining linear items with type `{elem_ty}` are not used", node + ) + self.visit(node.item_expr) + self.scope.assign(node.item) + self.visit(node.getitem_expr) + def visit_Expr(self, node: ast.Expr) -> None: # An expression statement where the return value is discarded self.visit(node.value) @@ -376,6 +422,15 @@ def leaf_places(place: Place) -> Iterator[Place]: yield place +def contains_subscript(place: Place) -> SubscriptAccess | None: + """Checks if a place contains a subscript access and returns the rightmost one.""" + while not isinstance(place, Variable): + if isinstance(place, SubscriptAccess): + return place + place = place.parent + return None + + def is_inout_var(place: Place) -> TypeGuard[Variable]: """Checks whether a place is an @inout variable.""" return isinstance(place, Variable) and InputFlags.Inout in place.flags diff --git a/guppylang/compiler/core.py b/guppylang/compiler/core.py index 0abc642a..067e0600 100644 --- a/guppylang/compiler/core.py +++ b/guppylang/compiler/core.py @@ -75,6 +75,9 @@ def __setitem__(self, place: Place, port: Wire) -> None: else: self.locals[place.id] = port + def __contains__(self, place: Place) -> bool: + return place.id in self.locals + def __copy__(self) -> "DFContainer": # Make a copy of the var map so that mutating the copy doesn't # mutate our variable mapping diff --git a/guppylang/compiler/expr_compiler.py b/guppylang/compiler/expr_compiler.py index dbc35dce..6dbc8edc 100644 --- a/guppylang/compiler/expr_compiler.py +++ b/guppylang/compiler/expr_compiler.py @@ -18,6 +18,7 @@ from guppylang.ast_util import AstVisitor, get_type, with_loc, with_type from guppylang.cfg.builder import tmp_vars from guppylang.checker.core import Variable +from guppylang.checker.linearity_checker import contains_subscript from guppylang.compiler.core import CompilerBase, DFContainer from guppylang.definition.custom import CustomFunctionDef from guppylang.definition.value import CompiledCallableDef, CompiledValueDef @@ -28,10 +29,12 @@ FieldAccessAndDrop, GlobalCall, GlobalName, + InoutReturnSentinel, LocalCall, PartialApply, PlaceNode, ResultExpr, + SubscriptAccessAndDrop, TensorCall, TypeApply, ) @@ -170,6 +173,10 @@ def visit_Constant(self, node: ast.Constant) -> Wire: raise InternalGuppyError("Unsupported constant expression in compiler") def visit_PlaceNode(self, node: PlaceNode) -> Wire: + if subscript := contains_subscript(node.place): + if subscript.item not in self.dfg: + self.dfg[subscript.item] = self.visit(subscript.item_expr) + self.dfg[subscript] = self.visit(subscript.getitem_call) return self.dfg[node.place] def visit_GlobalName(self, node: GlobalName) -> Wire: @@ -186,6 +193,10 @@ def visit_GlobalName(self, node: GlobalName) -> Wire: def visit_Name(self, node: ast.Name) -> Wire: raise InternalGuppyError("Node should have been removed during type checking.") + def visit_InoutReturnSentinel(self, node: InoutReturnSentinel) -> Wire: + assert not isinstance(node.var, str) + return self.dfg[node.var] + def visit_Tuple(self, node: ast.Tuple) -> Wire: elems = [self.visit(e) for e in node.elts] types = [get_type(e) for e in node.elts] @@ -232,9 +243,18 @@ def _update_inout_ports( """Helper method that updates the ports for @inout arguments after a call.""" for inp, arg in zip(func_ty.inputs, args, strict=True): if InputFlags.Inout in inp.flags: - # Linearity checker ensures that inout arguments are places - assert isinstance(arg, PlaceNode) + # Linearity checker ensures that inout arguments that are not places + # can be safely dropped after the call returns + if not isinstance(arg, PlaceNode): + next(inout_ports) + continue self.dfg[arg.place] = next(inout_ports) + # Places involving subscripts need to generate code for the appropriate + # `__setitem__` call. Nested subscripts are handled automatically since + # `arg.place.parent` occurs as an inout arg of this call, so will also + # be recursively reassigned. + if subscript := contains_subscript(arg.place): + self.visit(subscript.setitem_call) assert next(inout_ports, None) is None, "Too many inout return ports" def visit_LocalCall(self, node: LocalCall) -> Wire: @@ -382,6 +402,10 @@ def visit_FieldAccessAndDrop(self, node: FieldAccessAndDrop) -> Wire: field_idx ] + def visit_SubscriptAccessAndDrop(self, node: SubscriptAccessAndDrop) -> Wire: + self.dfg[node.item] = self.visit(node.item_expr) + return self.visit(node.getitem_expr) + def visit_ResultExpr(self, node: ResultExpr) -> Wire: extra_args = [] if isinstance(node.base_ty, NumericType): diff --git a/guppylang/nodes.py b/guppylang/nodes.py index 197738a8..47cd76c7 100644 --- a/guppylang/nodes.py +++ b/guppylang/nodes.py @@ -110,6 +110,20 @@ class FieldAccessAndDrop(ast.expr): ) +class SubscriptAccessAndDrop(ast.expr): + """A subscript element access on an object, dropping all the remaining items.""" + + item: "Variable" + item_expr: ast.expr + getitem_expr: ast.expr + + _fields = ( + "item", + "item_expr", + "getitem_expr", + ) + + class MakeIter(ast.expr): """Creates an iterator using the `__iter__` magic method. @@ -221,7 +235,7 @@ class InoutReturnSentinel(ast.expr): """An invisible expression corresponding to an implicit use of @inout vars whenever a function returns.""" - var: "Variable | str" + var: "Place | str" _fields = ("var",) diff --git a/guppylang/prelude/_internal/compiler.py b/guppylang/prelude/_internal/compiler.py index d9bde883..27fb9be8 100644 --- a/guppylang/prelude/_internal/compiler.py +++ b/guppylang/prelude/_internal/compiler.py @@ -1,11 +1,14 @@ import hugr from hugr import Wire, ops from hugr import tys as ht +from hugr import val as hv +from hugr.dfg import _DfBase from hugr.std.float import FLOAT_T from guppylang.definition.custom import ( CustomCallCompiler, ) +from guppylang.definition.value import CallReturnWires from guppylang.error import InternalGuppyError from guppylang.tys.arg import ConstArg, TypeArg from guppylang.tys.builtin import array_type @@ -244,3 +247,194 @@ def compile(self, args: list[Wire]) -> list[Wire]: quantum_op("Reset")(ht.FunctionType([ht.Qubit], [ht.Qubit]), []), q ) return [q] + + +class ArrayGetitemCompiler(CustomCallCompiler): + """Compiler for the `array.__getitem__` function.""" + + def build_classical_getitem( + self, + array: Wire, + array_ty: ht.Type, + idx: Wire, + idx_ty: ht.Type, + elem_ty: ht.Type, + ) -> CallReturnWires: + """Lowers a call to `array.__getitem__` for classical arrays.""" + [ty_arg, len_arg] = self.type_args + op = ops.Custom( + extension="guppy.unsupported.array", + signature=ht.FunctionType([array_ty, idx_ty], [array_ty, elem_ty]), + name="get", + args=[len_arg.to_hugr(), ty_arg.to_hugr()], + ) + node = self.builder.add_op(op, array, idx) + return CallReturnWires(regular_returns=[node[1]], inout_returns=[node[0]]) + + def build_linear_getitem( + self, + array: Wire, + array_ty: ht.Type, + idx: Wire, + idx_ty: ht.Type, + elem_ty: ht.Type, + ) -> CallReturnWires: + """Lowers a call to `array.__getitem__` for linear arrays.""" + # Swap out the element at the given index with `None`. The `to_hugr` + # implementation of the array type ensures that linear element types are turned + # into optionals. + elem_opt_ty = ht.Sum([[elem_ty], []]) + none = self.builder.add_op(ops.Tag(1, elem_opt_ty)) + length = self.type_args[1].to_hugr() + array, elem_opt = build_array_set( + self.builder, + array, + array_ty, + idx, + idx_ty, + none, + elem_opt_ty, + length, + ) + # Make sure that the element we got out is not None + conditional = self.builder.add_conditional(elem_opt) + with conditional.add_case(0) as case: + case.set_outputs(*case.inputs()) + with conditional.add_case(1) as case: + error = build_error(case, 1, "Linear array element has already been used") + case.set_outputs(build_panic(case, [], [elem_ty], error)) + return CallReturnWires(regular_returns=[conditional], inout_returns=[array]) + + def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires: + [array, idx] = args + [array_ty, idx_ty] = self.ty.input + [elem_ty, *_] = self.ty.output + if elem_ty.type_bound() == ht.TypeBound.Any: + return self.build_linear_getitem(array, array_ty, idx, idx_ty, elem_ty) + else: + return self.build_classical_getitem(array, array_ty, idx, idx_ty, elem_ty) + + def compile(self, args: list[Wire]) -> list[Wire]: + raise InternalGuppyError("Call compile_with_inouts instead") + + +class ArraySetitemCompiler(CustomCallCompiler): + """Compiler for the `array.__setitem__` function.""" + + def build_classical_setitem( + self, + array: Wire, + array_ty: ht.Type, + idx: Wire, + idx_ty: ht.Type, + elem: Wire, + elem_ty: ht.Type, + length: ht.TypeArg, + ) -> CallReturnWires: + """Lowers a call to `array.__setitem__` for classical arrays.""" + array, _ = build_array_set( + self.builder, array, array_ty, idx, idx_ty, elem, elem_ty, length + ) + return CallReturnWires(regular_returns=[], inout_returns=[array]) + + def build_linear_setitem( + self, + array: Wire, + array_ty: ht.Type, + idx: Wire, + idx_ty: ht.Type, + elem: Wire, + elem_ty: ht.Type, + length: ht.TypeArg, + ) -> CallReturnWires: + """Lowers a call to `array.__setitem__` for linear arrays.""" + # Embed the element into an optional + elem_opt_ty = ht.Sum([[elem_ty], []]) + elem = self.builder.add_op(ops.Tag(0, elem_opt_ty), elem) + array, old_elem = build_array_set( + self.builder, array, array_ty, idx, idx_ty, elem, elem_opt_ty, length + ) + # Check that the old element was `None` + conditional = self.builder.add_conditional(old_elem) + with conditional.add_case(0) as case: + error = build_error(case, 1, "Linear array element has not been used") + build_panic(case, [elem_ty], [], error, *case.inputs()) + case.set_outputs() + with conditional.add_case(1) as case: + case.set_outputs() + return CallReturnWires(regular_returns=[], inout_returns=[array]) + + def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires: + [array, idx, elem] = args + [array_ty, idx_ty, elem_ty] = self.ty.input + length = self.type_args[1].to_hugr() + if elem_ty.type_bound() == ht.TypeBound.Any: + return self.build_linear_setitem( + array, array_ty, idx, idx_ty, elem, elem_ty, length + ) + else: + return self.build_classical_setitem( + array, array_ty, idx, idx_ty, elem, elem_ty, length + ) + + def compile(self, args: list[Wire]) -> list[Wire]: + raise InternalGuppyError("Call compile_with_inouts instead") + + +#: The Hugr error type +error_ty = ht.Opaque( + id="error", bound=ht.TypeBound.Copyable, args=[], extension="prelude" +) + + +def build_panic( + # TODO: Change to `_DfBase[ops.DfParentOp]` once `_DfBase` is covariant + builder: _DfBase[ops.Case], + in_tys: ht.TypeRow, + out_tys: ht.TypeRow, + err: Wire, + *args: Wire, +) -> Wire: + """Builds a panic operation.""" + op = ops.Custom( + extension="prelude", + signature=ht.FunctionType([error_ty, *in_tys], out_tys), + name="panic", + args=[ + ht.SequenceArg([ht.TypeTypeArg(ty) for ty in in_tys]), + ht.SequenceArg([ht.TypeTypeArg(ty) for ty in out_tys]), + ], + ) + return builder.add_op(op, err, *args) + + +def build_error(builder: _DfBase[ops.Case], signal: int, msg: str) -> Wire: + """Constructs and loads a static error value.""" + val = hv.Extension( + name="ConstError", + typ=error_ty, + val={"signal": signal, "message": msg}, + extensions=["prelude"], + ) + return builder.load(builder.add_const(val)) + + +def build_array_set( + builder: _DfBase[ops.DfParentOp], + array: Wire, + array_ty: ht.Type, + idx: Wire, + idx_ty: ht.Type, + elem: Wire, + elem_ty: ht.Type, + length: ht.TypeArg, +) -> tuple[Wire, Wire]: + """Builds an array set operation, returning the original element.""" + op = ops.Custom( + extension="guppy.unsupported.array", + signature=ht.FunctionType([array_ty, idx_ty, elem_ty], [array_ty, elem_ty]), + name="get", + args=[length, ht.TypeTypeArg(elem_ty)], + ) + array, swapped_elem = iter(builder.add_op(op, array, idx, elem)) + return array, swapped_elem diff --git a/guppylang/prelude/builtins.py b/guppylang/prelude/builtins.py index 111a0626..4840ae3a 100644 --- a/guppylang/prelude/builtins.py +++ b/guppylang/prelude/builtins.py @@ -20,6 +20,8 @@ UnsupportedChecker, ) from guppylang.prelude._internal.compiler import ( + ArrayGetitemCompiler, + ArraySetitemCompiler, FloatBoolCompiler, FloatDivmodCompiler, FloatFloordivCompiler, @@ -36,7 +38,6 @@ linst_op, list_op, logic_op, - type_arg, ) from guppylang.tys.builtin import ( array_type_def, @@ -654,13 +655,11 @@ def sort(self: linst[T]) -> None: ... @guppy.extend_type(builtins, array_type_def) class Array: - @guppy.hugr_op( - builtins, - custom_op( - "ArrayGet", args=[int_arg(), type_arg()], variable_remap={0: 1, 1: 0} - ), - ) - def __getitem__(self: array[T, n], idx: int) -> T: ... + @guppy.custom(builtins, ArrayGetitemCompiler()) + def __getitem__(self: array[L, n] @ inout, idx: int) -> L: ... + + @guppy.custom(builtins, ArraySetitemCompiler()) + def __setitem__(self: array[L, n] @ inout, idx: int, value: L) -> None: ... @guppy.custom(builtins, checker=ArrayLenChecker()) def __len__(self: array[T, n]) -> int: ... diff --git a/guppylang/tys/builtin.py b/guppylang/tys/builtin.py index 427ea1e3..70159ee6 100644 --- a/guppylang/tys/builtin.py +++ b/guppylang/tys/builtin.py @@ -133,10 +133,17 @@ def _array_to_hugr(args: Sequence[Argument]) -> ht.Type: [ty_arg, len_arg] = args assert isinstance(ty_arg, TypeArg) assert isinstance(len_arg, ConstArg) + # Linear elements are turned into an optional to enable unsafe indexing. + # See `ArrayGetitemCompiler` for details. + elem_ty: ht.Type + if ty_arg.ty.linear: + elem_ty = ht.Sum([[ty_arg.ty.to_hugr()], []]) + else: + elem_ty = ty_arg.ty.to_hugr() return ht.Opaque( extension="prelude", id="array", - args=[len_arg.to_hugr(), ty_arg.to_hugr()], + args=[len_arg.to_hugr(), ht.TypeTypeArg(elem_ty)], bound=ty_arg.ty.hugr_bound, ) diff --git a/tests/error/array_errors/linear_index.err b/tests/error/array_errors/linear_index.err deleted file mode 100644 index dfa0a247..00000000 --- a/tests/error/array_errors/linear_index.err +++ /dev/null @@ -1,7 +0,0 @@ -Guppy compilation failed. Error in file $FILE:14 - -12: @guppy(module) -13: def main(qs: array[qubit, 42]) -> int: -14: return qs[0] - ^^ -GuppyTypeError: Cannot instantiate non-linear type variable `T` in type `forall n, T: nat. (array[T, n], int) -> T` with linear type `qubit` diff --git a/tests/error/array_errors/subscript_after_use.err b/tests/error/array_errors/subscript_after_use.err new file mode 100644 index 00000000..098b28ed --- /dev/null +++ b/tests/error/array_errors/subscript_after_use.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:18 + +16: @guppy(module) +17: def main(qs: array[qubit, 42]) -> array[qubit, 42]: +18: return foo(qs, qs[0]) + ^^ +GuppyError: Variable `qs` with linear type `array[qubit, 42]` was already used (at 18:15) diff --git a/tests/error/array_errors/subscript_after_use.py b/tests/error/array_errors/subscript_after_use.py new file mode 100644 index 00000000..b05d31a9 --- /dev/null +++ b/tests/error/array_errors/subscript_after_use.py @@ -0,0 +1,21 @@ +import guppylang.prelude.quantum as quantum +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from guppylang.prelude.builtins import array, inout +from guppylang.prelude.quantum import qubit + + +module = GuppyModule("test") +module.load_all(quantum) + + +@guppy.declare(module) +def foo(qs: array[qubit, 42], q: qubit @inout) -> array[qubit, 42]: ... + + +@guppy(module) +def main(qs: array[qubit, 42]) -> array[qubit, 42]: + return foo(qs, qs[0]) + + +module.compile() \ No newline at end of file diff --git a/tests/error/array_errors/subscript_drop.err b/tests/error/array_errors/subscript_drop.err new file mode 100644 index 00000000..448abbbc --- /dev/null +++ b/tests/error/array_errors/subscript_drop.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:18 + +16: @guppy(module) +17: def main() -> qubit: +18: return foo()[0] + ^^^^^^^^ +GuppyTypeError: Remaining linear items with type `qubit` are not used diff --git a/tests/error/array_errors/subscript_drop.py b/tests/error/array_errors/subscript_drop.py new file mode 100644 index 00000000..8775a1d4 --- /dev/null +++ b/tests/error/array_errors/subscript_drop.py @@ -0,0 +1,21 @@ +import guppylang.prelude.quantum as quantum +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from guppylang.prelude.builtins import array +from guppylang.prelude.quantum import qubit + + +module = GuppyModule("test") +module.load_all(quantum) + + +@guppy.declare(module) +def foo() -> array[qubit, 10]: ... + + +@guppy(module) +def main() -> qubit: + return foo()[0] + + +module.compile() \ No newline at end of file diff --git a/tests/error/array_errors/subscript_non_inout.err b/tests/error/array_errors/subscript_non_inout.err new file mode 100644 index 00000000..55df1950 --- /dev/null +++ b/tests/error/array_errors/subscript_non_inout.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:14 + +12: @guppy(module) +13: def main(qs: array[qubit, 42]) -> tuple[qubit, array[qubit, 42]]: +14: q = qs[0] + ^^^^^ +GuppyError: Subscripting on expression with linear type `array[qubit, 42]` is only allowed in `@inout` position diff --git a/tests/error/array_errors/linear_index.py b/tests/error/array_errors/subscript_non_inout.py similarity index 75% rename from tests/error/array_errors/linear_index.py rename to tests/error/array_errors/subscript_non_inout.py index a99dff87..0455bcfb 100644 --- a/tests/error/array_errors/linear_index.py +++ b/tests/error/array_errors/subscript_non_inout.py @@ -10,8 +10,9 @@ @guppy(module) -def main(qs: array[qubit, 42]) -> int: - return qs[0] +def main(qs: array[qubit, 42]) -> tuple[qubit, array[qubit, 42]]: + q = qs[0] + return q, qs module.compile() \ No newline at end of file diff --git a/tests/error/array_errors/use_after_subscript.err b/tests/error/array_errors/use_after_subscript.err new file mode 100644 index 00000000..566f65de --- /dev/null +++ b/tests/error/array_errors/use_after_subscript.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:18 + +16: @guppy(module) +17: def main(qs: array[qubit, 42]) -> array[qubit, 42]: +18: return foo(qs[0], qs) + ^^^^^ +GuppyError: Variable `qs` with linear type `array[qubit, 42]` was already used (at 18:22) diff --git a/tests/error/array_errors/use_after_subscript.py b/tests/error/array_errors/use_after_subscript.py new file mode 100644 index 00000000..938686f7 --- /dev/null +++ b/tests/error/array_errors/use_after_subscript.py @@ -0,0 +1,21 @@ +import guppylang.prelude.quantum as quantum +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from guppylang.prelude.builtins import array, inout +from guppylang.prelude.quantum import qubit + + +module = GuppyModule("test") +module.load_all(quantum) + + +@guppy.declare(module) +def foo(q: qubit @inout, qs: array[qubit, 42]) -> array[qubit, 42]: ... + + +@guppy(module) +def main(qs: array[qubit, 42]) -> array[qubit, 42]: + return foo(qs[0], qs) + + +module.compile() \ No newline at end of file diff --git a/tests/error/inout_errors/subscript_not_setable.err b/tests/error/inout_errors/subscript_not_setable.err new file mode 100644 index 00000000..510cfca5 --- /dev/null +++ b/tests/error/inout_errors/subscript_not_setable.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:24 + +22: @guppy(module) +23: def test(c: MyImmutableContainer) -> MyImmutableContainer: +24: foo(c[0]) + ^^^^ +GuppyTypeError: Expression of type `MyImmutableContainer` is not allowed in a subscripted `@inout` position since it does not implement the `__setitem__` method diff --git a/tests/error/inout_errors/subscript_not_setable.py b/tests/error/inout_errors/subscript_not_setable.py new file mode 100644 index 00000000..ca592bc5 --- /dev/null +++ b/tests/error/inout_errors/subscript_not_setable.py @@ -0,0 +1,28 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from guppylang.prelude.builtins import inout +from guppylang.prelude.quantum import qubit, quantum + +module = GuppyModule("test") +module.load_all(quantum) + + +@guppy.declare(module) +def foo(q: qubit @inout) -> None: ... + + +@guppy.struct(module) +class MyImmutableContainer: + q: qubit + + @guppy.declare(module) + def __getitem__(self: "MyImmutableContainer" @inout, idx: int) -> qubit: ... + + +@guppy(module) +def test(c: MyImmutableContainer) -> MyImmutableContainer: + foo(c[0]) + return c + + +module.compile() diff --git a/tests/error/type_errors/not_subscriptable.err b/tests/error/type_errors/not_subscriptable.err new file mode 100644 index 00000000..0611a3bf --- /dev/null +++ b/tests/error/type_errors/not_subscriptable.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:9 + +7: @guppy(module) +8: def foo(x: int) -> None: +9: x[0] + ^ +GuppyTypeError: Expression of type `int` is not subscriptable diff --git a/tests/error/type_errors/not_subscriptable.py b/tests/error/type_errors/not_subscriptable.py new file mode 100644 index 00000000..a93e2961 --- /dev/null +++ b/tests/error/type_errors/not_subscriptable.py @@ -0,0 +1,12 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + +module = GuppyModule("test") + + +@guppy(module) +def foo(x: int) -> None: + x[0] + + +module.compile() diff --git a/tests/error/type_errors/subscript_bad_item.err b/tests/error/type_errors/subscript_bad_item.err new file mode 100644 index 00000000..7337c3b0 --- /dev/null +++ b/tests/error/type_errors/subscript_bad_item.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:10 + +8: @guppy(module) +9: def foo(xs: array[int, 42]) -> int: +10: return xs[1.0] + ^^^ +GuppyTypeError: Expected argument of type `int`, got `float` diff --git a/tests/error/type_errors/subscript_bad_item.py b/tests/error/type_errors/subscript_bad_item.py new file mode 100644 index 00000000..ee4e0a6e --- /dev/null +++ b/tests/error/type_errors/subscript_bad_item.py @@ -0,0 +1,13 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from guppylang.prelude.builtins import array + +module = GuppyModule("test") + + +@guppy(module) +def foo(xs: array[int, 42]) -> int: + return xs[1.0] + + +module.compile() diff --git a/tests/integration/test_array.py b/tests/integration/test_array.py index 12481ac6..ea62087d 100644 --- a/tests/integration/test_array.py +++ b/tests/integration/test_array.py @@ -1,11 +1,15 @@ +import pytest from hugr import ops from hugr.std.int import IntVal from guppylang.decorator import guppy from guppylang.module import GuppyModule -from guppylang.prelude.builtins import array +from guppylang.prelude.builtins import array, inout from tests.util import compile_guppy +from guppylang.prelude.quantum import qubit +import guppylang.prelude.quantum as quantum + def test_len(validate): module = GuppyModule("test") @@ -23,6 +27,7 @@ def main(xs: array[float, 42]) -> int: assert val.val.v == 42 +@pytest.mark.skip("Skipped until Hugr lowering is updated") def test_index(validate): @compile_guppy def main(xs: array[int, 5], i: int) -> int: @@ -56,3 +61,139 @@ def main(ys: array[int, 0]) -> array[array[int, 0], 2]: validate(main) + +def test_subscript_drop_rest(validate): + module = GuppyModule("test") + module.load_all(quantum) + + @guppy.declare(module) + def foo() -> array[int, 10]: ... + + @guppy(module) + def main() -> int: + return foo()[0] + + validate(module.compile()) + + +def test_linear_subscript(validate): + module = GuppyModule("test") + module.load_all(quantum) + + @guppy.declare(module) + def foo(q: qubit @inout) -> None: ... + + @guppy(module) + def main(qs: array[qubit, 42], i: int) -> array[qubit, 42]: + foo(qs[i]) + return qs + + validate(module.compile()) + + +def test_inout_subscript(validate): + module = GuppyModule("test") + module.load_all(quantum) + + @guppy.declare(module) + def foo(q: qubit @inout) -> None: ... + + @guppy(module) + def main(qs: array[qubit, 42] @inout, i: int) -> None: + foo(qs[i]) + + validate(module.compile()) + + +def test_multi_subscripts(validate): + module = GuppyModule("test") + module.load_all(quantum) + + @guppy.declare(module) + def foo(q1: qubit @inout, q2: qubit @inout) -> None: ... + + @guppy(module) + def main(qs: array[qubit, 42]) -> array[qubit, 42]: + foo(qs[0], qs[1]) + foo(qs[0], qs[0]) # Will panic at runtime + return qs + + validate(module.compile()) + + +def test_struct_array(validate): + module = GuppyModule("test") + module.load_all(quantum) + + @guppy.struct(module) + class S: + q1: qubit + q2: qubit + + @guppy.declare(module) + def foo(q1: qubit @inout, q2: qubit @inout) -> None: ... + + @guppy(module) + def main(ss: array[S, 10]) -> array[S, 10]: + # This will panic at runtime :( + # To make this work, we would need to replace the qubits in the struct + # with `qubit | None` and write back `None` after `q1` has been extracted... + foo(ss[0].q1, ss[0].q2) + return ss + + validate(module.compile()) + + +def test_nested_subscripts(validate): + module = GuppyModule("test") + module.load_all(quantum) + + @guppy.declare(module) + def foo(q: qubit @inout) -> None: ... + + @guppy.declare(module) + def bar( + q1: qubit @inout, q2: qubit @inout, q3: qubit @inout, q4: qubit @inout + ) -> None: ... + + @guppy(module) + def main(qs: array[array[qubit, 13], 42]) -> array[array[qubit, 13], 42]: + foo(qs[0][0]) + # The following should work *without* panicking at runtime! Accessing `qs[0][0]` + # replaces one qubit with `None` but puts everything back into `qs` before + # going to the next argument. + bar(qs[0][0], qs[0][1], qs[1][0], qs[1][1]) + return qs + + validate(module.compile()) + + +def test_struct_nested_subscript(validate): + module = GuppyModule("test") + module.load_all(quantum) + + @guppy.struct(module) + class C: + c: qubit + blah: int + + @guppy.struct(module) + class B: + ys: array[array[C, 10], 20] + foo: C + + @guppy.struct(module) + class A: + xs: array[B, 42] + bar: qubit + baz: tuple[B, B] + + @guppy.declare(module) + def foo(q1: qubit @inout) -> None: ... + + @guppy(module) + def main(a: A, i: int, j: int, k: int) -> A: + foo(a.xs[i].ys[j][k].c) + return a + + validate(module.compile()) diff --git a/tests/integration/test_list.py b/tests/integration/test_list.py index e90f1bfe..30a1a6d5 100644 --- a/tests/integration/test_list.py +++ b/tests/integration/test_list.py @@ -1,3 +1,5 @@ +import pytest + from tests.util import compile_guppy @@ -37,6 +39,7 @@ def test(xs: list[int]) -> list[int]: validate(test) +@pytest.mark.skip("Requires updating lists to use inout") def test_subscript(validate): @compile_guppy def test(xs: list[float], i: int) -> float: