diff --git a/guppylang/checker/func_checker.py b/guppylang/checker/func_checker.py index c79c9e97..cd04ca8f 100644 --- a/guppylang/checker/func_checker.py +++ b/guppylang/checker/func_checker.py @@ -16,7 +16,7 @@ from guppylang.definition.common import DefId from guppylang.error import GuppyError from guppylang.nodes import CheckedNestedFunctionDef, NestedFunctionDef -from guppylang.tys.parsing import type_from_ast +from guppylang.tys.parsing import type_from_ast, type_with_flags_from_ast from guppylang.tys.ty import FuncInput, FunctionType, InputFlags, NoneType if TYPE_CHECKING: @@ -146,27 +146,14 @@ def check_signature(func_def: ast.FunctionDef, globals: Globals) -> FunctionType inputs = [] input_names = [] for inp in func_def.args.args: - ty_ast = inp.annotation - if ty_ast is None: + if inp.annotation is None: raise GuppyError("Argument type must be annotated", inp) - flags = InputFlags.NoFlags - # Detect `@flag` argument annotations - # TODO: This doesn't work if the type annotation is a string forward ref. We - # should rethink how we handle these... - if isinstance(ty_ast, ast.BinOp) and isinstance(ty_ast.op, ast.MatMult): - ty = type_from_ast(ty_ast.left, globals, param_var_mapping) - match ty_ast.right: - case ast.Name(id="inout"): - if not ty.linear: - raise GuppyError( - f"Non-linear type `{ty}` cannot be annotated as `@inout`", - ty_ast.right, - ) - flags |= InputFlags.Inout - case _: - raise GuppyError("Invalid annotation", ty_ast.right) - else: - ty = type_from_ast(ty_ast, globals, param_var_mapping) + ty, flags = type_with_flags_from_ast(inp.annotation, globals, param_var_mapping) + if InputFlags.Inout in flags and not ty.linear: + raise GuppyError( + f"Non-linear type `{ty}` cannot be annotated as `@inout`", + inp.annotation, + ) inputs.append(FuncInput(ty, flags)) input_names.append(inp.arg) ret_type = type_from_ast(func_def.returns, globals, param_var_mapping) diff --git a/guppylang/definition/struct.py b/guppylang/definition/struct.py index 79ffd375..61e43415 100644 --- a/guppylang/definition/struct.py +++ b/guppylang/definition/struct.py @@ -21,10 +21,9 @@ DefaultCallChecker, ) from guppylang.definition.parameter import ParamDef -from guppylang.definition.ty import TypeDef +from guppylang.definition.ty import FlaggedArgs, TypeDef, check_no_flags from guppylang.error import GuppyError, InternalGuppyError from guppylang.hugr_builder.hugr import OutPortV -from guppylang.tys.arg import Argument from guppylang.tys.param import Parameter, check_all_args from guppylang.tys.parsing import type_from_ast from guppylang.tys.ty import FuncInput, FunctionType, InputFlags, StructType, Type @@ -133,7 +132,7 @@ def parse(self, globals: Globals) -> "ParsedStructDef": return ParsedStructDef(self.id, self.name, cls_def, params, fields) def check_instantiate( - self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None + self, args: FlaggedArgs, globals: "Globals", loc: AstNode | None = None ) -> Type: raise InternalGuppyError("Tried to instantiate raw struct definition") @@ -163,9 +162,10 @@ def check(self, globals: Globals) -> "CheckedStructDef": ) def check_instantiate( - self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None + self, args: FlaggedArgs, globals: "Globals", loc: AstNode | None = None ) -> Type: """Checks if the struct can be instantiated with the given arguments.""" + args = check_no_flags(args, loc) check_all_args(self.params, args, self.name, loc) # Obtain a checked version of this struct definition so we can construct a # `StructType` instance @@ -187,9 +187,10 @@ class CheckedStructDef(TypeDef, CompiledDef): fields: Sequence[StructField] def check_instantiate( - self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None + self, args: FlaggedArgs, globals: "Globals", loc: AstNode | None = None ) -> Type: """Checks if the struct can be instantiated with the given arguments.""" + args = check_no_flags(args, loc) check_all_args(self.params, args, self.name, loc) return StructType(args, self) @@ -290,7 +291,7 @@ class DummyStructDef(TypeDef): def check_instantiate( self, - args: Sequence[Argument], + args: FlaggedArgs, globals: "Globals", loc: AstNode | None = None, ) -> Type: diff --git a/guppylang/definition/ty.py b/guppylang/definition/ty.py index f430fe7d..c301707b 100644 --- a/guppylang/definition/ty.py +++ b/guppylang/definition/ty.py @@ -1,20 +1,24 @@ from abc import abstractmethod from collections.abc import Callable, Sequence from dataclasses import dataclass, field -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, TypeAlias from hugr.serialization import tys from guppylang.ast_util import AstNode from guppylang.definition.common import CompiledDef, Definition +from guppylang.error import GuppyError from guppylang.tys.arg import Argument from guppylang.tys.param import Parameter, check_all_args -from guppylang.tys.ty import OpaqueType, Type +from guppylang.tys.ty import InputFlags, OpaqueType, Type if TYPE_CHECKING: from guppylang.checker.core import Globals +FlaggedArgs: TypeAlias = Sequence[tuple[Argument, InputFlags]] + + @dataclass(frozen=True) class TypeDef(Definition): """Abstract base class for type definitions.""" @@ -23,7 +27,7 @@ class TypeDef(Definition): @abstractmethod def check_instantiate( - self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None + self, args: FlaggedArgs, globals: "Globals", loc: AstNode | None = None ) -> Type: """Checks if the type definition can be instantiated with the given arguments. @@ -42,12 +46,23 @@ class OpaqueTypeDef(TypeDef, CompiledDef): bound: tys.TypeBound | None = None def check_instantiate( - self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None + self, args: FlaggedArgs, globals: "Globals", loc: AstNode | None = None ) -> OpaqueType: """Checks if the type definition can be instantiated with the given arguments. Returns the resulting concrete type or raises a user error if the arguments are invalid. """ + args = check_no_flags(args, loc) check_all_args(self.params, args, self.name, loc) return OpaqueType(args, self) + + +def check_no_flags(args: FlaggedArgs, loc: AstNode | None) -> list[Argument]: + """Checks that no argument to `check_instantiate` has any `@flags`.""" + for _, flags in args: + if flags != InputFlags.NoFlags: + raise GuppyError( + "`@` type annotations are not allowed in this position", loc + ) + return [ty for ty, _ in args] diff --git a/guppylang/tys/builtin.py b/guppylang/tys/builtin.py index 6de919dc..8ee95a52 100644 --- a/guppylang/tys/builtin.py +++ b/guppylang/tys/builtin.py @@ -1,13 +1,12 @@ from collections.abc import Sequence from dataclasses import dataclass, field -from itertools import repeat from typing import TYPE_CHECKING, Literal from hugr.serialization import tys from guppylang.ast_util import AstNode from guppylang.definition.common import DefId -from guppylang.definition.ty import OpaqueTypeDef, TypeDef +from guppylang.definition.ty import FlaggedArgs, OpaqueTypeDef, TypeDef, check_no_flags from guppylang.error import GuppyError from guppylang.tys.arg import Argument, ConstArg, TypeArg from guppylang.tys.param import ConstParam, TypeParam @@ -36,22 +35,28 @@ class _CallableTypeDef(TypeDef): name: Literal["Callable"] = field(default="Callable", init=False) def check_instantiate( - self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None + self, args: FlaggedArgs, globals: "Globals", loc: AstNode | None = None ) -> FunctionType: # We get the inputs/output as a flattened list: `args = [*inputs, output]`. if not args: raise GuppyError(f"Missing parameter for type `{self.name}`", loc) args = [ # TODO: Better error location - TypeParam(0, f"T{i}", can_be_linear=True).check_arg(arg, loc).ty - for i, arg in enumerate(args) - ] - *input_tys, output = args - inputs = [ - FuncInput(ty, flags) - for ty, flags in zip(input_tys, repeat(InputFlags.NoFlags), strict=False) + (TypeParam(0, f"T{i}", can_be_linear=True).check_arg(arg, loc).ty, flags) + for i, (arg, flags) in enumerate(args) ] - return FunctionType(list(inputs), output) + *inputs, (output_ty, output_flags) = args + for ty, flags in inputs: + if InputFlags.Inout in flags and not ty.linear: + raise GuppyError( + f"Non-linear type `{ty}` cannot be annotated as `@inout`", + loc, # TODO: Better error location + ) + if output_flags != InputFlags.NoFlags: + raise GuppyError( + "`@` type annotations are not allowed in this position", loc + ) + return FunctionType([FuncInput(ty, flags) for ty, flags in inputs], output_ty) @dataclass(frozen=True) @@ -64,8 +69,9 @@ class _TupleTypeDef(TypeDef): name: Literal["tuple"] = field(default="tuple", init=False) def check_instantiate( - self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None + self, args: FlaggedArgs, globals: "Globals", loc: AstNode | None = None ) -> TupleType: + args = check_no_flags(args, loc) # We accept any number of arguments. If users just write `tuple`, we give them # the empty tuple type. We just have to make sure that the args are of kind type args = [ @@ -86,7 +92,7 @@ class _NoneTypeDef(TypeDef): name: Literal["None"] = field(default="None", init=False) def check_instantiate( - self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None + self, args: FlaggedArgs, globals: "Globals", loc: AstNode | None = None ) -> NoneType: if args: raise GuppyError("Type `None` is not parameterized", loc) @@ -103,7 +109,7 @@ class _NumericTypeDef(TypeDef): ty: NumericType def check_instantiate( - self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None + self, args: FlaggedArgs, globals: "Globals", loc: AstNode | None = None ) -> NumericType: if args: raise GuppyError(f"Type `{self.name}` is not parameterized", loc) @@ -119,10 +125,10 @@ class _ListTypeDef(OpaqueTypeDef): """ def check_instantiate( - self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None + self, args: FlaggedArgs, globals: "Globals", loc: AstNode | None = None ) -> OpaqueType: if len(args) == 1: - [arg] = args + [arg] = check_no_flags(args, loc) if isinstance(arg, TypeArg) and arg.ty.linear: raise GuppyError( "Type `list` cannot store linear data, use `linst` instead", loc diff --git a/guppylang/tys/parsing.py b/guppylang/tys/parsing.py index c73222b3..137f95a1 100644 --- a/guppylang/tys/parsing.py +++ b/guppylang/tys/parsing.py @@ -13,14 +13,14 @@ from guppylang.tys.arg import Argument, ConstArg, TypeArg from guppylang.tys.const import ConstValue from guppylang.tys.param import Parameter, TypeParam -from guppylang.tys.ty import NoneType, NumericType, TupleType, Type +from guppylang.tys.ty import InputFlags, NoneType, NumericType, TupleType, Type def arg_from_ast( node: AstNode, globals: Globals, param_var_mapping: dict[str, Parameter] | None = None, -) -> Argument: +) -> tuple[Argument, InputFlags]: """Turns an AST expression into an argument.""" # A single identifier if isinstance(node, ast.Name): @@ -30,7 +30,8 @@ def arg_from_ast( match globals[x]: # Either a defined type (e.g. `int`, `bool`, ...) case TypeDef() as defn: - return TypeArg(defn.check_instantiate([], globals, node)) + ty_arg = TypeArg(defn.check_instantiate([], globals, node)) + return ty_arg, InputFlags.NoFlags # Or a parameter (e.g. `T`, `n`, ...) case ParamDef() as defn: if param_var_mapping is None: @@ -39,7 +40,7 @@ def arg_from_ast( ) if x not in param_var_mapping: param_var_mapping[x] = defn.to_param(len(param_var_mapping)) - return param_var_mapping[x].to_bound() + return param_var_mapping[x].to_bound(), InputFlags.NoFlags case defn: raise GuppyError( f"Expected a type, got {defn.description} `{defn.name}`", node @@ -70,7 +71,7 @@ def arg_from_ast( for arg_node in arg_nodes ] ty = defn.check_instantiate(args, globals, node) - return TypeArg(ty) + return TypeArg(ty), InputFlags.NoFlags # We don't allow parametrised variables like `T[int]` if isinstance(defn, ParamDef): raise GuppyError( @@ -79,16 +80,26 @@ def arg_from_ast( node, ) + # An annotated argument, e.g. `int @inout` + if isinstance(node, ast.BinOp) and isinstance(node.op, ast.MatMult): + arg, flags = arg_from_ast(node.left, globals, param_var_mapping) + match node.right: + case ast.Name(id="inout"): + flags |= InputFlags.Inout + case _: + raise GuppyError("Invalid annotation", node.right) + return arg, flags + # We allow tuple types to be written as `(int, bool)` if isinstance(node, ast.Tuple): ty = TupleType( [type_from_ast(el, globals, param_var_mapping) for el in node.elts] ) - return TypeArg(ty) + return TypeArg(ty), InputFlags.NoFlags # `None` is represented as a `ast.Constant` node with value `None` if isinstance(node, ast.Constant) and node.value is None: - return TypeArg(NoneType()) + return TypeArg(NoneType()), InputFlags.NoFlags # Integer literals are turned into nat args since these are the only ones we support # right now. @@ -98,7 +109,7 @@ def arg_from_ast( # `ast.UnaryOp` negation of a `ast.Constant(5)` assert node.value >= 0 nat_ty = NumericType(NumericType.Kind.Nat) - return ConstArg(ConstValue(nat_ty, node.value)) + return ConstArg(ConstValue(nat_ty, node.value)), InputFlags.NoFlags # Finally, we also support delayed annotations in strings if isinstance(node, ast.Constant) and isinstance(node.value, str): @@ -122,15 +133,27 @@ def arg_from_ast( _type_param = TypeParam(0, "T", True) +def type_with_flags_from_ast( + node: AstNode, + globals: Globals, + param_var_mapping: dict[str, Parameter] | None = None, +) -> tuple[Type, InputFlags]: + """Turns an AST expression into a Guppy type possibly annotated with @flags.""" + # Parse an argument and check that it's valid for a `TypeParam` + arg, flags = arg_from_ast(node, globals, param_var_mapping) + return _type_param.check_arg(arg, node).ty, flags + + def type_from_ast( node: AstNode, globals: Globals, param_var_mapping: dict[str, Parameter] | None = None, ) -> Type: """Turns an AST expression into a Guppy type.""" - # Parse an argument and check that it's valid for a `TypeParam` - arg = arg_from_ast(node, globals, param_var_mapping) - return _type_param.check_arg(arg, node).ty + ty, flags = type_with_flags_from_ast(node, globals, param_var_mapping) + if flags != InputFlags.NoFlags: + raise GuppyError("`@` type annotations are not allowed in this position", node) + return ty def type_row_from_ast(node: ast.expr, globals: "Globals") -> Sequence[Type]: diff --git a/tests/error/inout_errors/nonlinear.err b/tests/error/inout_errors/nonlinear.err index 0e6c0e26..b7de73c2 100644 --- a/tests/error/inout_errors/nonlinear.err +++ b/tests/error/inout_errors/nonlinear.err @@ -2,5 +2,5 @@ Guppy compilation failed. Error in file $FILE:11 9: @guppy.declare(module) 10: def foo(x: int @inout) -> qubit: ... - ^^^^^ + ^^^^^^^^^^ GuppyError: Non-linear type `int` cannot be annotated as `@inout` diff --git a/tests/error/inout_errors/nonlinear_callable.err b/tests/error/inout_errors/nonlinear_callable.err new file mode 100644 index 00000000..4df85b50 --- /dev/null +++ b/tests/error/inout_errors/nonlinear_callable.err @@ -0,0 +1,6 @@ +Guppy compilation failed. Error in file $FILE:12 + +10: @guppy.declare(module) +11: def foo(f: Callable[[int @inout], None]) -> None: ... + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +GuppyError: Non-linear type `int` cannot be annotated as `@inout` diff --git a/tests/error/inout_errors/nonlinear_callable.py b/tests/error/inout_errors/nonlinear_callable.py new file mode 100644 index 00000000..ed21ce94 --- /dev/null +++ b/tests/error/inout_errors/nonlinear_callable.py @@ -0,0 +1,15 @@ +from typing import Callable + +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from guppylang.prelude.builtins import inout + + +module = GuppyModule("test") + + +@guppy.declare(module) +def foo(f: Callable[[int @inout], None]) -> None: ... + + +module.compile() diff --git a/tests/error/misc_errors/nested_arg_flag.err b/tests/error/misc_errors/nested_arg_flag.err new file mode 100644 index 00000000..4f06ad67 --- /dev/null +++ b/tests/error/misc_errors/nested_arg_flag.err @@ -0,0 +1,6 @@ +Guppy compilation failed. Error in file $FILE:12 + +10: @guppy.declare(module) +11: def foo(x: list[qubit @inout]) -> qubit: ... + ^^^^^^^^^^^^^^^^^^ +GuppyError: `@` type annotations are not allowed in this position diff --git a/tests/error/misc_errors/nested_arg_flag.py b/tests/error/misc_errors/nested_arg_flag.py new file mode 100644 index 00000000..fb8961a9 --- /dev/null +++ b/tests/error/misc_errors/nested_arg_flag.py @@ -0,0 +1,15 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from guppylang.prelude.builtins import inout +from guppylang.prelude.quantum import quantum, qubit + + +module = GuppyModule("test") +module.load(quantum) + + +@guppy.declare(module) +def foo(x: list[qubit @inout]) -> qubit: ... + + +module.compile() diff --git a/tests/error/misc_errors/return_flag.err b/tests/error/misc_errors/return_flag.err new file mode 100644 index 00000000..096ad4a6 --- /dev/null +++ b/tests/error/misc_errors/return_flag.err @@ -0,0 +1,6 @@ +Guppy compilation failed. Error in file $FILE:12 + +10: @guppy.declare(module) +11: def foo() -> qubit @inout: ... + ^^^^^^^^^^^^ +GuppyError: `@` type annotations are not allowed in this position diff --git a/tests/error/misc_errors/return_flag.py b/tests/error/misc_errors/return_flag.py new file mode 100644 index 00000000..6191abc5 --- /dev/null +++ b/tests/error/misc_errors/return_flag.py @@ -0,0 +1,15 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from guppylang.prelude.builtins import inout +from guppylang.prelude.quantum import quantum, qubit + + +module = GuppyModule("test") +module.load(quantum) + + +@guppy.declare(module) +def foo() -> qubit @inout: ... + + +module.compile() diff --git a/tests/error/misc_errors/return_flag_callable.err b/tests/error/misc_errors/return_flag_callable.err new file mode 100644 index 00000000..3c69888d --- /dev/null +++ b/tests/error/misc_errors/return_flag_callable.err @@ -0,0 +1,6 @@ +Guppy compilation failed. Error in file $FILE:14 + +12: @guppy.declare(module) +13: def foo(f: Callable[[], qubit @inout]) -> None: ... + ^^^^^^^^^^^^^^^^^^^^^^^^^^ +GuppyError: `@` type annotations are not allowed in this position diff --git a/tests/error/misc_errors/return_flag_callable.py b/tests/error/misc_errors/return_flag_callable.py new file mode 100644 index 00000000..88534cba --- /dev/null +++ b/tests/error/misc_errors/return_flag_callable.py @@ -0,0 +1,17 @@ +from typing import Callable + +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from guppylang.prelude.builtins import inout +from guppylang.prelude.quantum import quantum, qubit + + +module = GuppyModule("test") +module.load(quantum) + + +@guppy.declare(module) +def foo(f: Callable[[], qubit @inout]) -> None: ... + + +module.compile() diff --git a/tests/integration/test_inout.py b/tests/integration/test_inout.py index dfe7a512..5b4c789f 100644 --- a/tests/integration/test_inout.py +++ b/tests/integration/test_inout.py @@ -14,3 +14,13 @@ def test_declare(validate): def test(q: qubit @inout) -> qubit: ... validate(module.compile()) + + +def test_string_annotation(validate): + module = GuppyModule("test") + module.load(quantum) + + @guppy.declare(module) + def test(q: "qubit @inout") -> qubit: ... + + validate(module.compile())