Skip to content

Commit

Permalink
feat: Move flag parsing into arg_from_ast (#349)
Browse files Browse the repository at this point in the history
Closes #347 and closes #312
  • Loading branch information
mark-koch authored Jul 30, 2024
1 parent 11f2e25 commit bf070ab
Show file tree
Hide file tree
Showing 15 changed files with 187 additions and 59 deletions.
29 changes: 8 additions & 21 deletions guppylang/checker/func_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from guppylang.definition.common import DefId
from guppylang.error import GuppyError
from guppylang.nodes import CheckedNestedFunctionDef, NestedFunctionDef
from guppylang.tys.parsing import type_from_ast
from guppylang.tys.parsing import type_from_ast, type_with_flags_from_ast
from guppylang.tys.ty import FuncInput, FunctionType, InputFlags, NoneType

if TYPE_CHECKING:
Expand Down Expand Up @@ -146,27 +146,14 @@ def check_signature(func_def: ast.FunctionDef, globals: Globals) -> FunctionType
inputs = []
input_names = []
for inp in func_def.args.args:
ty_ast = inp.annotation
if ty_ast is None:
if inp.annotation is None:
raise GuppyError("Argument type must be annotated", inp)
flags = InputFlags.NoFlags
# Detect `@flag` argument annotations
# TODO: This doesn't work if the type annotation is a string forward ref. We
# should rethink how we handle these...
if isinstance(ty_ast, ast.BinOp) and isinstance(ty_ast.op, ast.MatMult):
ty = type_from_ast(ty_ast.left, globals, param_var_mapping)
match ty_ast.right:
case ast.Name(id="inout"):
if not ty.linear:
raise GuppyError(
f"Non-linear type `{ty}` cannot be annotated as `@inout`",
ty_ast.right,
)
flags |= InputFlags.Inout
case _:
raise GuppyError("Invalid annotation", ty_ast.right)
else:
ty = type_from_ast(ty_ast, globals, param_var_mapping)
ty, flags = type_with_flags_from_ast(inp.annotation, globals, param_var_mapping)
if InputFlags.Inout in flags and not ty.linear:
raise GuppyError(
f"Non-linear type `{ty}` cannot be annotated as `@inout`",
inp.annotation,
)
inputs.append(FuncInput(ty, flags))
input_names.append(inp.arg)
ret_type = type_from_ast(func_def.returns, globals, param_var_mapping)
Expand Down
13 changes: 7 additions & 6 deletions guppylang/definition/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,9 @@
DefaultCallChecker,
)
from guppylang.definition.parameter import ParamDef
from guppylang.definition.ty import TypeDef
from guppylang.definition.ty import FlaggedArgs, TypeDef, check_no_flags
from guppylang.error import GuppyError, InternalGuppyError
from guppylang.hugr_builder.hugr import OutPortV
from guppylang.tys.arg import Argument
from guppylang.tys.param import Parameter, check_all_args
from guppylang.tys.parsing import type_from_ast
from guppylang.tys.ty import FuncInput, FunctionType, InputFlags, StructType, Type
Expand Down Expand Up @@ -133,7 +132,7 @@ def parse(self, globals: Globals) -> "ParsedStructDef":
return ParsedStructDef(self.id, self.name, cls_def, params, fields)

def check_instantiate(
self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None
self, args: FlaggedArgs, globals: "Globals", loc: AstNode | None = None
) -> Type:
raise InternalGuppyError("Tried to instantiate raw struct definition")

Expand Down Expand Up @@ -163,9 +162,10 @@ def check(self, globals: Globals) -> "CheckedStructDef":
)

def check_instantiate(
self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None
self, args: FlaggedArgs, globals: "Globals", loc: AstNode | None = None
) -> Type:
"""Checks if the struct can be instantiated with the given arguments."""
args = check_no_flags(args, loc)
check_all_args(self.params, args, self.name, loc)
# Obtain a checked version of this struct definition so we can construct a
# `StructType` instance
Expand All @@ -187,9 +187,10 @@ class CheckedStructDef(TypeDef, CompiledDef):
fields: Sequence[StructField]

def check_instantiate(
self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None
self, args: FlaggedArgs, globals: "Globals", loc: AstNode | None = None
) -> Type:
"""Checks if the struct can be instantiated with the given arguments."""
args = check_no_flags(args, loc)
check_all_args(self.params, args, self.name, loc)
return StructType(args, self)

Expand Down Expand Up @@ -290,7 +291,7 @@ class DummyStructDef(TypeDef):

def check_instantiate(
self,
args: Sequence[Argument],
args: FlaggedArgs,
globals: "Globals",
loc: AstNode | None = None,
) -> Type:
Expand Down
23 changes: 19 additions & 4 deletions guppylang/definition/ty.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,24 @@
from abc import abstractmethod
from collections.abc import Callable, Sequence
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, TypeAlias

from hugr.serialization import tys

from guppylang.ast_util import AstNode
from guppylang.definition.common import CompiledDef, Definition
from guppylang.error import GuppyError
from guppylang.tys.arg import Argument
from guppylang.tys.param import Parameter, check_all_args
from guppylang.tys.ty import OpaqueType, Type
from guppylang.tys.ty import InputFlags, OpaqueType, Type

if TYPE_CHECKING:
from guppylang.checker.core import Globals


FlaggedArgs: TypeAlias = Sequence[tuple[Argument, InputFlags]]


@dataclass(frozen=True)
class TypeDef(Definition):
"""Abstract base class for type definitions."""
Expand All @@ -23,7 +27,7 @@ class TypeDef(Definition):

@abstractmethod
def check_instantiate(
self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None
self, args: FlaggedArgs, globals: "Globals", loc: AstNode | None = None
) -> Type:
"""Checks if the type definition can be instantiated with the given arguments.
Expand All @@ -42,12 +46,23 @@ class OpaqueTypeDef(TypeDef, CompiledDef):
bound: tys.TypeBound | None = None

def check_instantiate(
self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None
self, args: FlaggedArgs, globals: "Globals", loc: AstNode | None = None
) -> OpaqueType:
"""Checks if the type definition can be instantiated with the given arguments.
Returns the resulting concrete type or raises a user error if the arguments are
invalid.
"""
args = check_no_flags(args, loc)
check_all_args(self.params, args, self.name, loc)
return OpaqueType(args, self)


def check_no_flags(args: FlaggedArgs, loc: AstNode | None) -> list[Argument]:
"""Checks that no argument to `check_instantiate` has any `@flags`."""
for _, flags in args:
if flags != InputFlags.NoFlags:
raise GuppyError(
"`@` type annotations are not allowed in this position", loc
)
return [ty for ty, _ in args]
38 changes: 22 additions & 16 deletions guppylang/tys/builtin.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from collections.abc import Sequence
from dataclasses import dataclass, field
from itertools import repeat
from typing import TYPE_CHECKING, Literal

from hugr.serialization import tys

from guppylang.ast_util import AstNode
from guppylang.definition.common import DefId
from guppylang.definition.ty import OpaqueTypeDef, TypeDef
from guppylang.definition.ty import FlaggedArgs, OpaqueTypeDef, TypeDef, check_no_flags
from guppylang.error import GuppyError
from guppylang.tys.arg import Argument, ConstArg, TypeArg
from guppylang.tys.param import ConstParam, TypeParam
Expand Down Expand Up @@ -36,22 +35,28 @@ class _CallableTypeDef(TypeDef):
name: Literal["Callable"] = field(default="Callable", init=False)

def check_instantiate(
self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None
self, args: FlaggedArgs, globals: "Globals", loc: AstNode | None = None
) -> FunctionType:
# We get the inputs/output as a flattened list: `args = [*inputs, output]`.
if not args:
raise GuppyError(f"Missing parameter for type `{self.name}`", loc)
args = [
# TODO: Better error location
TypeParam(0, f"T{i}", can_be_linear=True).check_arg(arg, loc).ty
for i, arg in enumerate(args)
]
*input_tys, output = args
inputs = [
FuncInput(ty, flags)
for ty, flags in zip(input_tys, repeat(InputFlags.NoFlags), strict=False)
(TypeParam(0, f"T{i}", can_be_linear=True).check_arg(arg, loc).ty, flags)
for i, (arg, flags) in enumerate(args)
]
return FunctionType(list(inputs), output)
*inputs, (output_ty, output_flags) = args
for ty, flags in inputs:
if InputFlags.Inout in flags and not ty.linear:
raise GuppyError(
f"Non-linear type `{ty}` cannot be annotated as `@inout`",
loc, # TODO: Better error location
)
if output_flags != InputFlags.NoFlags:
raise GuppyError(
"`@` type annotations are not allowed in this position", loc
)
return FunctionType([FuncInput(ty, flags) for ty, flags in inputs], output_ty)


@dataclass(frozen=True)
Expand All @@ -64,8 +69,9 @@ class _TupleTypeDef(TypeDef):
name: Literal["tuple"] = field(default="tuple", init=False)

def check_instantiate(
self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None
self, args: FlaggedArgs, globals: "Globals", loc: AstNode | None = None
) -> TupleType:
args = check_no_flags(args, loc)
# We accept any number of arguments. If users just write `tuple`, we give them
# the empty tuple type. We just have to make sure that the args are of kind type
args = [
Expand All @@ -86,7 +92,7 @@ class _NoneTypeDef(TypeDef):
name: Literal["None"] = field(default="None", init=False)

def check_instantiate(
self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None
self, args: FlaggedArgs, globals: "Globals", loc: AstNode | None = None
) -> NoneType:
if args:
raise GuppyError("Type `None` is not parameterized", loc)
Expand All @@ -103,7 +109,7 @@ class _NumericTypeDef(TypeDef):
ty: NumericType

def check_instantiate(
self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None
self, args: FlaggedArgs, globals: "Globals", loc: AstNode | None = None
) -> NumericType:
if args:
raise GuppyError(f"Type `{self.name}` is not parameterized", loc)
Expand All @@ -119,10 +125,10 @@ class _ListTypeDef(OpaqueTypeDef):
"""

def check_instantiate(
self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None
self, args: FlaggedArgs, globals: "Globals", loc: AstNode | None = None
) -> OpaqueType:
if len(args) == 1:
[arg] = args
[arg] = check_no_flags(args, loc)
if isinstance(arg, TypeArg) and arg.ty.linear:
raise GuppyError(
"Type `list` cannot store linear data, use `linst` instead", loc
Expand Down
45 changes: 34 additions & 11 deletions guppylang/tys/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
from guppylang.tys.arg import Argument, ConstArg, TypeArg
from guppylang.tys.const import ConstValue
from guppylang.tys.param import Parameter, TypeParam
from guppylang.tys.ty import NoneType, NumericType, TupleType, Type
from guppylang.tys.ty import InputFlags, NoneType, NumericType, TupleType, Type


def arg_from_ast(
node: AstNode,
globals: Globals,
param_var_mapping: dict[str, Parameter] | None = None,
) -> Argument:
) -> tuple[Argument, InputFlags]:
"""Turns an AST expression into an argument."""
# A single identifier
if isinstance(node, ast.Name):
Expand All @@ -30,7 +30,8 @@ def arg_from_ast(
match globals[x]:
# Either a defined type (e.g. `int`, `bool`, ...)
case TypeDef() as defn:
return TypeArg(defn.check_instantiate([], globals, node))
ty_arg = TypeArg(defn.check_instantiate([], globals, node))
return ty_arg, InputFlags.NoFlags
# Or a parameter (e.g. `T`, `n`, ...)
case ParamDef() as defn:
if param_var_mapping is None:
Expand All @@ -39,7 +40,7 @@ def arg_from_ast(
)
if x not in param_var_mapping:
param_var_mapping[x] = defn.to_param(len(param_var_mapping))
return param_var_mapping[x].to_bound()
return param_var_mapping[x].to_bound(), InputFlags.NoFlags
case defn:
raise GuppyError(
f"Expected a type, got {defn.description} `{defn.name}`", node
Expand Down Expand Up @@ -70,7 +71,7 @@ def arg_from_ast(
for arg_node in arg_nodes
]
ty = defn.check_instantiate(args, globals, node)
return TypeArg(ty)
return TypeArg(ty), InputFlags.NoFlags
# We don't allow parametrised variables like `T[int]`
if isinstance(defn, ParamDef):
raise GuppyError(
Expand All @@ -79,16 +80,26 @@ def arg_from_ast(
node,
)

# An annotated argument, e.g. `int @inout`
if isinstance(node, ast.BinOp) and isinstance(node.op, ast.MatMult):
arg, flags = arg_from_ast(node.left, globals, param_var_mapping)
match node.right:
case ast.Name(id="inout"):
flags |= InputFlags.Inout
case _:
raise GuppyError("Invalid annotation", node.right)
return arg, flags

# We allow tuple types to be written as `(int, bool)`
if isinstance(node, ast.Tuple):
ty = TupleType(
[type_from_ast(el, globals, param_var_mapping) for el in node.elts]
)
return TypeArg(ty)
return TypeArg(ty), InputFlags.NoFlags

# `None` is represented as a `ast.Constant` node with value `None`
if isinstance(node, ast.Constant) and node.value is None:
return TypeArg(NoneType())
return TypeArg(NoneType()), InputFlags.NoFlags

# Integer literals are turned into nat args since these are the only ones we support
# right now.
Expand All @@ -98,7 +109,7 @@ def arg_from_ast(
# `ast.UnaryOp` negation of a `ast.Constant(5)`
assert node.value >= 0
nat_ty = NumericType(NumericType.Kind.Nat)
return ConstArg(ConstValue(nat_ty, node.value))
return ConstArg(ConstValue(nat_ty, node.value)), InputFlags.NoFlags

# Finally, we also support delayed annotations in strings
if isinstance(node, ast.Constant) and isinstance(node.value, str):
Expand All @@ -122,15 +133,27 @@ def arg_from_ast(
_type_param = TypeParam(0, "T", True)


def type_with_flags_from_ast(
node: AstNode,
globals: Globals,
param_var_mapping: dict[str, Parameter] | None = None,
) -> tuple[Type, InputFlags]:
"""Turns an AST expression into a Guppy type possibly annotated with @flags."""
# Parse an argument and check that it's valid for a `TypeParam`
arg, flags = arg_from_ast(node, globals, param_var_mapping)
return _type_param.check_arg(arg, node).ty, flags


def type_from_ast(
node: AstNode,
globals: Globals,
param_var_mapping: dict[str, Parameter] | None = None,
) -> Type:
"""Turns an AST expression into a Guppy type."""
# Parse an argument and check that it's valid for a `TypeParam`
arg = arg_from_ast(node, globals, param_var_mapping)
return _type_param.check_arg(arg, node).ty
ty, flags = type_with_flags_from_ast(node, globals, param_var_mapping)
if flags != InputFlags.NoFlags:
raise GuppyError("`@` type annotations are not allowed in this position", node)
return ty


def type_row_from_ast(node: ast.expr, globals: "Globals") -> Sequence[Type]:
Expand Down
2 changes: 1 addition & 1 deletion tests/error/inout_errors/nonlinear.err
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@ Guppy compilation failed. Error in file $FILE:11

9: @guppy.declare(module)
10: def foo(x: int @inout) -> qubit: ...
^^^^^
^^^^^^^^^^
GuppyError: Non-linear type `int` cannot be annotated as `@inout`
6 changes: 6 additions & 0 deletions tests/error/inout_errors/nonlinear_callable.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Guppy compilation failed. Error in file $FILE:12

10: @guppy.declare(module)
11: def foo(f: Callable[[int @inout], None]) -> None: ...
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
GuppyError: Non-linear type `int` cannot be annotated as `@inout`
15 changes: 15 additions & 0 deletions tests/error/inout_errors/nonlinear_callable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from typing import Callable

from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.prelude.builtins import inout


module = GuppyModule("test")


@guppy.declare(module)
def foo(f: Callable[[int @inout], None]) -> None: ...


module.compile()
Loading

0 comments on commit bf070ab

Please sign in to comment.