Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Parse inout annotations in function signatures #316

Merged
merged 8 commits into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions guppylang/checker/func_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from guppylang.definition.common import DefId
from guppylang.error import GuppyError
from guppylang.nodes import CheckedNestedFunctionDef, NestedFunctionDef
from guppylang.tys.parsing import type_from_ast
from guppylang.tys.parsing import type_from_ast, type_with_flags_from_ast
from guppylang.tys.ty import FuncInput, FunctionType, InputFlags, NoneType

if TYPE_CHECKING:
Expand Down Expand Up @@ -148,8 +148,13 @@ def check_signature(func_def: ast.FunctionDef, globals: Globals) -> FunctionType
for inp in func_def.args.args:
if inp.annotation is None:
raise GuppyError("Argument type must be annotated", inp)
ty = type_from_ast(inp.annotation, globals, param_var_mapping)
inputs.append(FuncInput(ty, InputFlags.NoFlags))
ty, flags = type_with_flags_from_ast(inp.annotation, globals, param_var_mapping)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This does seem a bit like, we have two places that define the set(s) of flags legal in the signature of a function - _CallableTypeDef also makes similar checks. If you parsed type_with_flags_from_ast for the return type too, you could call some single checking function (maybe it takes a list[FlaggedArg] and a single FlaggedArg and then _CallableTypeDef first separates out the last element of the list) that issues appropriate errors and returns some appropriate e.g. (list[TypeArg],TypeArg)

if InputFlags.Inout in flags and not ty.linear:
raise GuppyError(
f"Non-linear type `{ty}` cannot be annotated as `@inout`",
inp.annotation,
)
inputs.append(FuncInput(ty, flags))
input_names.append(inp.arg)
ret_type = type_from_ast(func_def.returns, globals, param_var_mapping)

Expand Down
13 changes: 7 additions & 6 deletions guppylang/definition/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,9 @@
DefaultCallChecker,
)
from guppylang.definition.parameter import ParamDef
from guppylang.definition.ty import TypeDef
from guppylang.definition.ty import FlaggedArgs, TypeDef, check_no_flags
from guppylang.error import GuppyError, InternalGuppyError
from guppylang.hugr_builder.hugr import OutPortV
from guppylang.tys.arg import Argument
from guppylang.tys.param import Parameter, check_all_args
from guppylang.tys.parsing import type_from_ast
from guppylang.tys.ty import FuncInput, FunctionType, InputFlags, StructType, Type
Expand Down Expand Up @@ -133,7 +132,7 @@ def parse(self, globals: Globals) -> "ParsedStructDef":
return ParsedStructDef(self.id, self.name, cls_def, params, fields)

def check_instantiate(
self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None
self, args: FlaggedArgs, globals: "Globals", loc: AstNode | None = None
) -> Type:
raise InternalGuppyError("Tried to instantiate raw struct definition")

Expand Down Expand Up @@ -163,9 +162,10 @@ def check(self, globals: Globals) -> "CheckedStructDef":
)

def check_instantiate(
self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None
self, args: FlaggedArgs, globals: "Globals", loc: AstNode | None = None
) -> Type:
"""Checks if the struct can be instantiated with the given arguments."""
args = check_no_flags(args, loc)
check_all_args(self.params, args, self.name, loc)
# Obtain a checked version of this struct definition so we can construct a
# `StructType` instance
Expand All @@ -187,9 +187,10 @@ class CheckedStructDef(TypeDef, CompiledDef):
fields: Sequence[StructField]

def check_instantiate(
self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None
self, args: FlaggedArgs, globals: "Globals", loc: AstNode | None = None
) -> Type:
"""Checks if the struct can be instantiated with the given arguments."""
args = check_no_flags(args, loc)
check_all_args(self.params, args, self.name, loc)
return StructType(args, self)

Expand Down Expand Up @@ -290,7 +291,7 @@ class DummyStructDef(TypeDef):

def check_instantiate(
self,
args: Sequence[Argument],
args: FlaggedArgs,
globals: "Globals",
loc: AstNode | None = None,
) -> Type:
Expand Down
23 changes: 19 additions & 4 deletions guppylang/definition/ty.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,24 @@
from abc import abstractmethod
from collections.abc import Callable, Sequence
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, TypeAlias

from hugr.serialization import tys

from guppylang.ast_util import AstNode
from guppylang.definition.common import CompiledDef, Definition
from guppylang.error import GuppyError
from guppylang.tys.arg import Argument
from guppylang.tys.param import Parameter, check_all_args
from guppylang.tys.ty import OpaqueType, Type
from guppylang.tys.ty import InputFlags, OpaqueType, Type

if TYPE_CHECKING:
from guppylang.checker.core import Globals


FlaggedArgs: TypeAlias = Sequence[tuple[Argument, InputFlags]]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Most uses have gone from Sequence[Argument] to FlaggedArgs, but you do have a couple of tuple[Argument, InputFlags] elsewhere. Maybe define FlaggedArg = tuple[...] and use that (so Sequence[FlaggedArg] in many places)....albeit, elsewhere I suggest a bigger redifinition of FlaggedArg ;)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This no longer applied after ff545a7



@dataclass(frozen=True)
class TypeDef(Definition):
"""Abstract base class for type definitions."""
Expand All @@ -23,7 +27,7 @@ class TypeDef(Definition):

@abstractmethod
def check_instantiate(
self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None
self, args: FlaggedArgs, globals: "Globals", loc: AstNode | None = None
) -> Type:
"""Checks if the type definition can be instantiated with the given arguments.

Expand All @@ -42,12 +46,23 @@ class OpaqueTypeDef(TypeDef, CompiledDef):
bound: tys.TypeBound | None = None

def check_instantiate(
self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None
self, args: FlaggedArgs, globals: "Globals", loc: AstNode | None = None
) -> OpaqueType:
"""Checks if the type definition can be instantiated with the given arguments.

Returns the resulting concrete type or raises a user error if the arguments are
invalid.
"""
args = check_no_flags(args, loc)
check_all_args(self.params, args, self.name, loc)
return OpaqueType(args, self)


def check_no_flags(args: FlaggedArgs, loc: AstNode | None) -> list[Argument]:
"""Checks that no argument to `check_instantiate` has any `@flags`."""
for _, flags in args:
if flags != InputFlags.NoFlags:
raise GuppyError(
"`@` type annotations are not allowed in this position", loc
)
return [ty for ty, _ in args]
10 changes: 10 additions & 0 deletions guppylang/prelude/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,16 @@ def py(*_args: Any) -> Any:
raise GuppyError("`py` can only by used in a Guppy context")


class _Inout:
"""Dummy class to support `@inout` annotations."""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is to allow def foo(x: int @inout), since Python (in its infinite wisdom cough splutter boo) disallows def foo(x: @inout int). Obviously, it's fine to do matrix multiplication inside a type annotation but not invoke a decorator....

However, have you thought about the "pythonic" alternative, which I believe would be def foo(x: Annotated[int, "inout"]?@inout is a bastardization of python syntax anyway so what would users think? I admit Annotated is rather long - what about def foo(x: Inout[int]) ? (No way, I hear you say - and I see the downside - maybe that lengthy Annotated isn't sooo bad, then....)

Not seriously recommending you change, but how different / how much simpler might that make things, would it impact the design at all?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I created #359 for discussion.

Not seriously recommending you change, but how different / how much simpler might that make things, would it impact the design at all?

I think this is mostly a syntax concern. The design wouldn't be affected much imo


def __rmatmul__(self, other: Any) -> Any:
return other


inout = _Inout()


class nat:
"""Class to import in order to use nats."""

Expand Down
38 changes: 22 additions & 16 deletions guppylang/tys/builtin.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from collections.abc import Sequence
from dataclasses import dataclass, field
from itertools import repeat
from typing import TYPE_CHECKING, Literal

from hugr.serialization import tys

from guppylang.ast_util import AstNode
from guppylang.definition.common import DefId
from guppylang.definition.ty import OpaqueTypeDef, TypeDef
from guppylang.definition.ty import FlaggedArgs, OpaqueTypeDef, TypeDef, check_no_flags
from guppylang.error import GuppyError
from guppylang.tys.arg import Argument, ConstArg, TypeArg
from guppylang.tys.param import ConstParam, TypeParam
Expand Down Expand Up @@ -36,22 +35,28 @@ class _CallableTypeDef(TypeDef):
name: Literal["Callable"] = field(default="Callable", init=False)

def check_instantiate(
self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None
self, args: FlaggedArgs, globals: "Globals", loc: AstNode | None = None
) -> FunctionType:
# We get the inputs/output as a flattened list: `args = [*inputs, output]`.
if not args:
raise GuppyError(f"Missing parameter for type `{self.name}`", loc)
args = [
# TODO: Better error location
TypeParam(0, f"T{i}", can_be_linear=True).check_arg(arg, loc).ty
for i, arg in enumerate(args)
]
*input_tys, output = args
inputs = [
FuncInput(ty, flags)
for ty, flags in zip(input_tys, repeat(InputFlags.NoFlags), strict=False)
(TypeParam(0, f"T{i}", can_be_linear=True).check_arg(arg, loc).ty, flags)
for i, (arg, flags) in enumerate(args)
]
return FunctionType(list(inputs), output)
*inputs, (output_ty, output_flags) = args
for ty, flags in inputs:
if InputFlags.Inout in flags and not ty.linear:
raise GuppyError(
f"Non-linear type `{ty}` cannot be annotated as `@inout`",
loc, # TODO: Better error location
)
if output_flags != InputFlags.NoFlags:
raise GuppyError(
"`@` type annotations are not allowed in this position", loc
)
return FunctionType([FuncInput(ty, flags) for ty, flags in inputs], output_ty)


@dataclass(frozen=True)
Expand All @@ -64,8 +69,9 @@ class _TupleTypeDef(TypeDef):
name: Literal["tuple"] = field(default="tuple", init=False)

def check_instantiate(
self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None
self, args: FlaggedArgs, globals: "Globals", loc: AstNode | None = None
) -> TupleType:
args = check_no_flags(args, loc)
# We accept any number of arguments. If users just write `tuple`, we give them
# the empty tuple type. We just have to make sure that the args are of kind type
args = [
Expand All @@ -86,7 +92,7 @@ class _NoneTypeDef(TypeDef):
name: Literal["None"] = field(default="None", init=False)

def check_instantiate(
self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None
self, args: FlaggedArgs, globals: "Globals", loc: AstNode | None = None
) -> NoneType:
if args:
raise GuppyError("Type `None` is not parameterized", loc)
Expand All @@ -103,7 +109,7 @@ class _NumericTypeDef(TypeDef):
ty: NumericType

def check_instantiate(
self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None
self, args: FlaggedArgs, globals: "Globals", loc: AstNode | None = None
) -> NumericType:
if args:
raise GuppyError(f"Type `{self.name}` is not parameterized", loc)
Expand All @@ -119,10 +125,10 @@ class _ListTypeDef(OpaqueTypeDef):
"""

def check_instantiate(
self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None
self, args: FlaggedArgs, globals: "Globals", loc: AstNode | None = None
) -> OpaqueType:
if len(args) == 1:
[arg] = args
[arg] = check_no_flags(args, loc)
if isinstance(arg, TypeArg) and arg.ty.linear:
raise GuppyError(
"Type `list` cannot store linear data, use `linst` instead", loc
Expand Down
45 changes: 34 additions & 11 deletions guppylang/tys/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
from guppylang.tys.arg import Argument, ConstArg, TypeArg
from guppylang.tys.const import ConstValue
from guppylang.tys.param import Parameter, TypeParam
from guppylang.tys.ty import NoneType, NumericType, TupleType, Type
from guppylang.tys.ty import InputFlags, NoneType, NumericType, TupleType, Type


def arg_from_ast(
node: AstNode,
globals: Globals,
param_var_mapping: dict[str, Parameter] | None = None,
) -> Argument:
) -> tuple[Argument, InputFlags]:
Copy link
Contributor

@acl-cqc acl-cqc Jul 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

arg! (argh!) Sorry but I got getting really confused here; what is an Argument? Is it a value Hugr static TypeArg? (It certainly looks like one) - but then we wouldn't want it to be inout or have InputFlags, would we? More comments would really help, maybe also think about renaming Argument->StaticArg, say.

I think arg_from_ast parses both TypeArgs (Hugr TypeArgs of kind TypeParam::Type), and ConstArgs (Hugr TypeArgs of other kinds). But, hang on, ConstArgs are not function parameters (that'd be like def foo(x: 6) !). And @inout is only appropriate for the top/outermost level argument to a def, or to Callable, yet we make life difficult for all the other check_instantiates in order to fit in Callable (e.g. list[int @inout]) - I see 5 calls to check_no_flags, plus a couple that have no args to have any flags.

  • I think passing the flag by subclassing could be an easier route, you might still have an equivalent of check_no_flags but I reckon it might just fall out of normal check_all_args or checking-against-TypeParam (that is: make a normal Hugr TypeParam::Type unable to accept an InoutTypeArg, but Callable has a special TypeParam::MaybeInOutType that can accept both InoutTypeArg and normal TypeArg).
  • Alternatively....this'll be even more controversial....can you special-case Callable?

Generally I think I'd favour breaking the lengthy arg_from_ast up into flagged_arg_from_ast or (maybe_)inout_arg_from_ast that looks for BinOp with @ and then (usually) falls back to a shorter arg_from_ast that does not support that case, and thus, returns a TypeArg (| ConstArg) only. (Or returns Argument without separate flags, that'd be a first/lesser step from what's in the PR now).

Which is to say that (my gut feeling is) you'd do better to push error checking downwards towards the parsing/leaves, rather than parsing whatever and then raising error messages (like, found function parameter of type 6). Not sure if that would help with the error locations. Then, say, maybe_inout_arg_from_ast would (if not BinOp) call type_arg_from_ast (that returns TypeArg, and does not support ConstArg), and then there is const_or_type_arg_from_ast (used for arguments to parameterized types), that kind of thing.

Note you've already gone this route with type_from_arg vs type_with_flags_from_arg, so it's not thaaaat radical....

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Of course we should also think about likely future annotations. list[MyStruct @frozen] is plausible, true....

Copy link
Collaborator Author

@mark-koch mark-koch Jul 31, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

arg! (argh!) Sorry but I got getting really confused here; what is an Argument? Is it a value Hugr static TypeArg?

Sorry for the confusion, yes Argument is Guppy's version of static arguments to types. E.g. the type array[int, 5] has a TypeArg int and a ConstArg 5.

More comments would really help, maybe also think about renaming Argument->StaticArg, say.

I added #358 to consider the rename

And @inout is only appropriate for the top/outermost level argument to a def, or to Callable, yet we make life difficult for all the other check_instantiates in order to fit in Callable (e.g. list[int @inout]) [...] Alternatively....this'll be even more controversial....can you special-case Callable?

After reflecting a bit on this I think that special casing Callable is probaby the best solution in the near term. I initially wanted a general solution that can handle future flags like @frozen etc., but I agree it adds a lot of complexity and probably isn't worth it for now. I like your subclassing idea though, I'll probably go for that if we want more general flags in the future 👍

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The neatest solution here (I think) is adding a new TypeParam. However it's definitely worth having a think about what the future flags might be - Frozen<T> might be a lot like T (i.e. acceptable anywhere a T is?), so probably needs special consideration, whereas inout is clearly a bit special (it's not really a property of the parameter but of the whole function signature).

Special-casing Callable might be fine in the short-term, tho, and maybe even in the longer term if it turns out we don't need similar specials for any of the others...

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Special-casing Callable might be fine in the short-term, tho, and maybe even in the longer term if it turns out we don't need similar specials for any of the others...

Agreed, I think I'll do that for now 👍

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another benefit of special casing Callable is that we can now properly handle the list syntax Callable[[Arg1, Arg2], Return] without the flattening hack

"""Turns an AST expression into an argument."""
# A single identifier
if isinstance(node, ast.Name):
Expand All @@ -30,7 +30,8 @@ def arg_from_ast(
match globals[x]:
# Either a defined type (e.g. `int`, `bool`, ...)
case TypeDef() as defn:
return TypeArg(defn.check_instantiate([], globals, node))
ty_arg = TypeArg(defn.check_instantiate([], globals, node))
return ty_arg, InputFlags.NoFlags
# Or a parameter (e.g. `T`, `n`, ...)
case ParamDef() as defn:
if param_var_mapping is None:
Expand All @@ -39,7 +40,7 @@ def arg_from_ast(
)
if x not in param_var_mapping:
param_var_mapping[x] = defn.to_param(len(param_var_mapping))
return param_var_mapping[x].to_bound()
return param_var_mapping[x].to_bound(), InputFlags.NoFlags
case defn:
raise GuppyError(
f"Expected a type, got {defn.description} `{defn.name}`", node
Expand Down Expand Up @@ -70,7 +71,7 @@ def arg_from_ast(
for arg_node in arg_nodes
]
ty = defn.check_instantiate(args, globals, node)
return TypeArg(ty)
return TypeArg(ty), InputFlags.NoFlags
# We don't allow parametrised variables like `T[int]`
if isinstance(defn, ParamDef):
raise GuppyError(
Expand All @@ -79,16 +80,26 @@ def arg_from_ast(
node,
)

# An annotated argument, e.g. `int @inout`
if isinstance(node, ast.BinOp) and isinstance(node.op, ast.MatMult):
arg, flags = arg_from_ast(node.left, globals, param_var_mapping)
match node.right:
case ast.Name(id="inout"):
flags |= InputFlags.Inout
case _:
raise GuppyError("Invalid annotation", node.right)
return arg, flags

# We allow tuple types to be written as `(int, bool)`
if isinstance(node, ast.Tuple):
ty = TupleType(
[type_from_ast(el, globals, param_var_mapping) for el in node.elts]
)
return TypeArg(ty)
return TypeArg(ty), InputFlags.NoFlags

# `None` is represented as a `ast.Constant` node with value `None`
if isinstance(node, ast.Constant) and node.value is None:
return TypeArg(NoneType())
return TypeArg(NoneType()), InputFlags.NoFlags

# Integer literals are turned into nat args since these are the only ones we support
# right now.
Expand All @@ -98,7 +109,7 @@ def arg_from_ast(
# `ast.UnaryOp` negation of a `ast.Constant(5)`
assert node.value >= 0
nat_ty = NumericType(NumericType.Kind.Nat)
return ConstArg(ConstValue(nat_ty, node.value))
return ConstArg(ConstValue(nat_ty, node.value)), InputFlags.NoFlags

# Finally, we also support delayed annotations in strings
if isinstance(node, ast.Constant) and isinstance(node.value, str):
Expand All @@ -122,15 +133,27 @@ def arg_from_ast(
_type_param = TypeParam(0, "T", True)


def type_with_flags_from_ast(
node: AstNode,
globals: Globals,
param_var_mapping: dict[str, Parameter] | None = None,
) -> tuple[Type, InputFlags]:
"""Turns an AST expression into a Guppy type possibly annotated with @flags."""
# Parse an argument and check that it's valid for a `TypeParam`
Copy link
Contributor

@acl-cqc acl-cqc Jul 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it's amusing to think that we could have got this far with x: 6 @inout ;-). But I'm reassured that you are making the right check!

arg, flags = arg_from_ast(node, globals, param_var_mapping)
return _type_param.check_arg(arg, node).ty, flags


def type_from_ast(
node: AstNode,
globals: Globals,
param_var_mapping: dict[str, Parameter] | None = None,
) -> Type:
"""Turns an AST expression into a Guppy type."""
# Parse an argument and check that it's valid for a `TypeParam`
arg = arg_from_ast(node, globals, param_var_mapping)
return _type_param.check_arg(arg, node).ty
ty, flags = type_with_flags_from_ast(node, globals, param_var_mapping)
if flags != InputFlags.NoFlags:
raise GuppyError("`@` type annotations are not allowed in this position", node)
return ty


def type_row_from_ast(node: ast.expr, globals: "Globals") -> Sequence[Type]:
Expand Down
Empty file.
6 changes: 6 additions & 0 deletions tests/error/inout_errors/nonlinear.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Guppy compilation failed. Error in file $FILE:11

9: @guppy.declare(module)
10: def foo(x: int @inout) -> qubit: ...
^^^^^^^^^^
GuppyError: Non-linear type `int` cannot be annotated as `@inout`
14 changes: 14 additions & 0 deletions tests/error/inout_errors/nonlinear.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.prelude.builtins import inout
from guppylang.prelude.quantum import qubit


module = GuppyModule("test")


@guppy.declare(module)
def foo(x: int @inout) -> qubit: ...


module.compile()
6 changes: 6 additions & 0 deletions tests/error/inout_errors/nonlinear_callable.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Guppy compilation failed. Error in file $FILE:12

10: @guppy.declare(module)
11: def foo(f: Callable[[int @inout], None]) -> None: ...
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
GuppyError: Non-linear type `int` cannot be annotated as `@inout`
Loading
Loading