-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Parse inout annotations in function signatures #316
Changes from 6 commits
2bd7fdf
b5a590a
c455a54
68ece53
11f2e25
bf070ab
77a4dc8
ff545a7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,20 +1,24 @@ | ||
from abc import abstractmethod | ||
from collections.abc import Callable, Sequence | ||
from dataclasses import dataclass, field | ||
from typing import TYPE_CHECKING | ||
from typing import TYPE_CHECKING, TypeAlias | ||
|
||
from hugr.serialization import tys | ||
|
||
from guppylang.ast_util import AstNode | ||
from guppylang.definition.common import CompiledDef, Definition | ||
from guppylang.error import GuppyError | ||
from guppylang.tys.arg import Argument | ||
from guppylang.tys.param import Parameter, check_all_args | ||
from guppylang.tys.ty import OpaqueType, Type | ||
from guppylang.tys.ty import InputFlags, OpaqueType, Type | ||
|
||
if TYPE_CHECKING: | ||
from guppylang.checker.core import Globals | ||
|
||
|
||
FlaggedArgs: TypeAlias = Sequence[tuple[Argument, InputFlags]] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Most uses have gone from There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This no longer applied after ff545a7 |
||
|
||
|
||
@dataclass(frozen=True) | ||
class TypeDef(Definition): | ||
"""Abstract base class for type definitions.""" | ||
|
@@ -23,7 +27,7 @@ class TypeDef(Definition): | |
|
||
@abstractmethod | ||
def check_instantiate( | ||
self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None | ||
self, args: FlaggedArgs, globals: "Globals", loc: AstNode | None = None | ||
) -> Type: | ||
"""Checks if the type definition can be instantiated with the given arguments. | ||
|
||
|
@@ -42,12 +46,23 @@ class OpaqueTypeDef(TypeDef, CompiledDef): | |
bound: tys.TypeBound | None = None | ||
|
||
def check_instantiate( | ||
self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None | ||
self, args: FlaggedArgs, globals: "Globals", loc: AstNode | None = None | ||
) -> OpaqueType: | ||
"""Checks if the type definition can be instantiated with the given arguments. | ||
|
||
Returns the resulting concrete type or raises a user error if the arguments are | ||
invalid. | ||
""" | ||
args = check_no_flags(args, loc) | ||
check_all_args(self.params, args, self.name, loc) | ||
return OpaqueType(args, self) | ||
|
||
|
||
def check_no_flags(args: FlaggedArgs, loc: AstNode | None) -> list[Argument]: | ||
"""Checks that no argument to `check_instantiate` has any `@flags`.""" | ||
for _, flags in args: | ||
if flags != InputFlags.NoFlags: | ||
raise GuppyError( | ||
"`@` type annotations are not allowed in this position", loc | ||
) | ||
return [ty for ty, _ in args] |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -55,6 +55,16 @@ def py(*_args: Any) -> Any: | |
raise GuppyError("`py` can only by used in a Guppy context") | ||
|
||
|
||
class _Inout: | ||
"""Dummy class to support `@inout` annotations.""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is to allow However, have you thought about the "pythonic" alternative, which I believe would be Not seriously recommending you change, but how different / how much simpler might that make things, would it impact the design at all? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I created #359 for discussion.
I think this is mostly a syntax concern. The design wouldn't be affected much imo |
||
|
||
def __rmatmul__(self, other: Any) -> Any: | ||
return other | ||
|
||
|
||
inout = _Inout() | ||
|
||
|
||
class nat: | ||
"""Class to import in order to use nats.""" | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,14 +13,14 @@ | |
from guppylang.tys.arg import Argument, ConstArg, TypeArg | ||
from guppylang.tys.const import ConstValue | ||
from guppylang.tys.param import Parameter, TypeParam | ||
from guppylang.tys.ty import NoneType, NumericType, TupleType, Type | ||
from guppylang.tys.ty import InputFlags, NoneType, NumericType, TupleType, Type | ||
|
||
|
||
def arg_from_ast( | ||
node: AstNode, | ||
globals: Globals, | ||
param_var_mapping: dict[str, Parameter] | None = None, | ||
) -> Argument: | ||
) -> tuple[Argument, InputFlags]: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. arg! (argh!) Sorry but I got getting really confused here; what is an I think
Generally I think I'd favour breaking the lengthy Which is to say that (my gut feeling is) you'd do better to push error checking downwards towards the parsing/leaves, rather than parsing whatever and then raising error messages (like, found function parameter of type 6). Not sure if that would help with the error locations. Then, say, Note you've already gone this route with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Of course we should also think about likely future annotations. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Sorry for the confusion, yes
I added #358 to consider the rename
After reflecting a bit on this I think that special casing There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The neatest solution here (I think) is adding a new TypeParam. However it's definitely worth having a think about what the future flags might be - Special-casing Callable might be fine in the short-term, tho, and maybe even in the longer term if it turns out we don't need similar specials for any of the others... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Agreed, I think I'll do that for now 👍 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Another benefit of special casing |
||
"""Turns an AST expression into an argument.""" | ||
# A single identifier | ||
if isinstance(node, ast.Name): | ||
|
@@ -30,7 +30,8 @@ def arg_from_ast( | |
match globals[x]: | ||
# Either a defined type (e.g. `int`, `bool`, ...) | ||
case TypeDef() as defn: | ||
return TypeArg(defn.check_instantiate([], globals, node)) | ||
ty_arg = TypeArg(defn.check_instantiate([], globals, node)) | ||
return ty_arg, InputFlags.NoFlags | ||
# Or a parameter (e.g. `T`, `n`, ...) | ||
case ParamDef() as defn: | ||
if param_var_mapping is None: | ||
|
@@ -39,7 +40,7 @@ def arg_from_ast( | |
) | ||
if x not in param_var_mapping: | ||
param_var_mapping[x] = defn.to_param(len(param_var_mapping)) | ||
return param_var_mapping[x].to_bound() | ||
return param_var_mapping[x].to_bound(), InputFlags.NoFlags | ||
case defn: | ||
raise GuppyError( | ||
f"Expected a type, got {defn.description} `{defn.name}`", node | ||
|
@@ -70,7 +71,7 @@ def arg_from_ast( | |
for arg_node in arg_nodes | ||
] | ||
ty = defn.check_instantiate(args, globals, node) | ||
return TypeArg(ty) | ||
return TypeArg(ty), InputFlags.NoFlags | ||
# We don't allow parametrised variables like `T[int]` | ||
if isinstance(defn, ParamDef): | ||
raise GuppyError( | ||
|
@@ -79,16 +80,26 @@ def arg_from_ast( | |
node, | ||
) | ||
|
||
# An annotated argument, e.g. `int @inout` | ||
if isinstance(node, ast.BinOp) and isinstance(node.op, ast.MatMult): | ||
arg, flags = arg_from_ast(node.left, globals, param_var_mapping) | ||
match node.right: | ||
case ast.Name(id="inout"): | ||
flags |= InputFlags.Inout | ||
case _: | ||
raise GuppyError("Invalid annotation", node.right) | ||
return arg, flags | ||
|
||
# We allow tuple types to be written as `(int, bool)` | ||
if isinstance(node, ast.Tuple): | ||
ty = TupleType( | ||
[type_from_ast(el, globals, param_var_mapping) for el in node.elts] | ||
) | ||
return TypeArg(ty) | ||
return TypeArg(ty), InputFlags.NoFlags | ||
|
||
# `None` is represented as a `ast.Constant` node with value `None` | ||
if isinstance(node, ast.Constant) and node.value is None: | ||
return TypeArg(NoneType()) | ||
return TypeArg(NoneType()), InputFlags.NoFlags | ||
|
||
# Integer literals are turned into nat args since these are the only ones we support | ||
# right now. | ||
|
@@ -98,7 +109,7 @@ def arg_from_ast( | |
# `ast.UnaryOp` negation of a `ast.Constant(5)` | ||
assert node.value >= 0 | ||
nat_ty = NumericType(NumericType.Kind.Nat) | ||
return ConstArg(ConstValue(nat_ty, node.value)) | ||
return ConstArg(ConstValue(nat_ty, node.value)), InputFlags.NoFlags | ||
|
||
# Finally, we also support delayed annotations in strings | ||
if isinstance(node, ast.Constant) and isinstance(node.value, str): | ||
|
@@ -122,15 +133,27 @@ def arg_from_ast( | |
_type_param = TypeParam(0, "T", True) | ||
|
||
|
||
def type_with_flags_from_ast( | ||
node: AstNode, | ||
globals: Globals, | ||
param_var_mapping: dict[str, Parameter] | None = None, | ||
) -> tuple[Type, InputFlags]: | ||
"""Turns an AST expression into a Guppy type possibly annotated with @flags.""" | ||
# Parse an argument and check that it's valid for a `TypeParam` | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, it's amusing to think that we could have got this far with |
||
arg, flags = arg_from_ast(node, globals, param_var_mapping) | ||
return _type_param.check_arg(arg, node).ty, flags | ||
|
||
|
||
def type_from_ast( | ||
node: AstNode, | ||
globals: Globals, | ||
param_var_mapping: dict[str, Parameter] | None = None, | ||
) -> Type: | ||
"""Turns an AST expression into a Guppy type.""" | ||
# Parse an argument and check that it's valid for a `TypeParam` | ||
arg = arg_from_ast(node, globals, param_var_mapping) | ||
return _type_param.check_arg(arg, node).ty | ||
ty, flags = type_with_flags_from_ast(node, globals, param_var_mapping) | ||
if flags != InputFlags.NoFlags: | ||
raise GuppyError("`@` type annotations are not allowed in this position", node) | ||
return ty | ||
|
||
|
||
def type_row_from_ast(node: ast.expr, globals: "Globals") -> Sequence[Type]: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
Guppy compilation failed. Error in file $FILE:11 | ||
|
||
9: @guppy.declare(module) | ||
10: def foo(x: int @inout) -> qubit: ... | ||
^^^^^^^^^^ | ||
GuppyError: Non-linear type `int` cannot be annotated as `@inout` |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from guppylang.decorator import guppy | ||
from guppylang.module import GuppyModule | ||
from guppylang.prelude.builtins import inout | ||
from guppylang.prelude.quantum import qubit | ||
|
||
|
||
module = GuppyModule("test") | ||
|
||
|
||
@guppy.declare(module) | ||
def foo(x: int @inout) -> qubit: ... | ||
|
||
|
||
module.compile() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
Guppy compilation failed. Error in file $FILE:12 | ||
|
||
10: @guppy.declare(module) | ||
11: def foo(f: Callable[[int @inout], None]) -> None: ... | ||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||
GuppyError: Non-linear type `int` cannot be annotated as `@inout` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This does seem a bit like, we have two places that define the set(s) of flags legal in the signature of a function -
_CallableTypeDef
also makes similar checks. If you parsedtype_with_flags_from_ast
for the return type too, you could call some single checking function (maybe it takes alist[FlaggedArg]
and a singleFlaggedArg
and then_CallableTypeDef
first separates out the last element of the list) that issues appropriate errors and returns some appropriate e.g.(list[TypeArg],TypeArg)