Skip to content

Commit

Permalink
feat: Parse inout annotations in function signatures (#316)
Browse files Browse the repository at this point in the history
Closes #309. See #282 for context.
  • Loading branch information
mark-koch committed Aug 9, 2024
1 parent 86c8b31 commit 4316ac9
Show file tree
Hide file tree
Showing 25 changed files with 367 additions and 59 deletions.
16 changes: 8 additions & 8 deletions guppylang/checker/func_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
from guppylang.definition.common import DefId
from guppylang.error import GuppyError
from guppylang.nodes import CheckedNestedFunctionDef, NestedFunctionDef
from guppylang.tys.parsing import type_from_ast
from guppylang.tys.ty import FuncInput, FunctionType, InputFlags, NoneType
from guppylang.tys.parsing import parse_function_io_types
from guppylang.tys.ty import FunctionType, NoneType

if TYPE_CHECKING:
from guppylang.tys.param import Parameter
Expand Down Expand Up @@ -143,19 +143,19 @@ def check_signature(func_def: ast.FunctionDef, globals: Globals) -> FunctionType

# TODO: Prepopulate mapping when using Python 3.12 style generic functions
param_var_mapping: dict[str, Parameter] = {}
inputs = []
input_nodes = []
input_names = []
for inp in func_def.args.args:
if inp.annotation is None:
raise GuppyError("Argument type must be annotated", inp)
ty = type_from_ast(inp.annotation, globals, param_var_mapping)
inputs.append(FuncInput(ty, InputFlags.NoFlags))
input_nodes.append(inp.annotation)
input_names.append(inp.arg)
ret_type = type_from_ast(func_def.returns, globals, param_var_mapping)

inputs, output = parse_function_io_types(
input_nodes, func_def.returns, func_def, globals, param_var_mapping
)
return FunctionType(
inputs,
ret_type,
output,
input_names,
sorted(param_var_mapping.values(), key=lambda v: v.idx),
)
Expand Down
10 changes: 10 additions & 0 deletions guppylang/prelude/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,16 @@ def py(*_args: Any) -> Any:
raise GuppyError("`py` can only by used in a Guppy context")


class _Inout:
"""Dummy class to support `@inout` annotations."""

def __rmatmul__(self, other: Any) -> Any:
return other


inout = _Inout()


class nat:
"""Class to import in order to use nats."""

Expand Down
25 changes: 5 additions & 20 deletions guppylang/tys/builtin.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,17 @@
from collections.abc import Sequence
from dataclasses import dataclass, field
from itertools import repeat
from typing import TYPE_CHECKING, Literal

from hugr.serialization import tys

from guppylang.ast_util import AstNode
from guppylang.definition.common import DefId
from guppylang.definition.ty import OpaqueTypeDef, TypeDef
from guppylang.error import GuppyError
from guppylang.error import GuppyError, InternalGuppyError
from guppylang.tys.arg import Argument, ConstArg, TypeArg
from guppylang.tys.param import ConstParam, TypeParam
from guppylang.tys.ty import (
FuncInput,
FunctionType,
InputFlags,
NoneType,
NumericType,
OpaqueType,
Expand All @@ -27,7 +24,7 @@


@dataclass(frozen=True)
class _CallableTypeDef(TypeDef):
class CallableTypeDef(TypeDef):
"""Type definition associated with the builtin `Callable` type.
Any impls on functions can be registered with this definition.
Expand All @@ -38,20 +35,8 @@ class _CallableTypeDef(TypeDef):
def check_instantiate(
self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None
) -> FunctionType:
# We get the inputs/output as a flattened list: `args = [*inputs, output]`.
if not args:
raise GuppyError(f"Missing parameter for type `{self.name}`", loc)
args = [
# TODO: Better error location
TypeParam(0, f"T{i}", can_be_linear=True).check_arg(arg, loc).ty
for i, arg in enumerate(args)
]
*input_tys, output = args
inputs = [
FuncInput(ty, flags)
for ty, flags in zip(input_tys, repeat(InputFlags.NoFlags), strict=False)
]
return FunctionType(list(inputs), output)
# Callable types are constructed using special logic in the type parser
raise InternalGuppyError("Tried to `Callable` type via `check_instantiate`")


@dataclass(frozen=True)
Expand Down Expand Up @@ -157,7 +142,7 @@ def _array_to_hugr(args: Sequence[Argument]) -> tys.Type:
return tys.Type(ty)


callable_type_def = _CallableTypeDef(DefId.fresh(), None)
callable_type_def = CallableTypeDef(DefId.fresh(), None)
tuple_type_def = _TupleTypeDef(DefId.fresh(), None)
none_type_def = _NoneTypeDef(DefId.fresh(), None)
bool_type_def = OpaqueTypeDef(
Expand Down
156 changes: 125 additions & 31 deletions guppylang/tys/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,18 @@
from guppylang.definition.ty import TypeDef
from guppylang.error import GuppyError
from guppylang.tys.arg import Argument, ConstArg, TypeArg
from guppylang.tys.builtin import CallableTypeDef
from guppylang.tys.const import ConstValue
from guppylang.tys.param import Parameter, TypeParam
from guppylang.tys.ty import NoneType, NumericType, TupleType, Type
from guppylang.tys.ty import (
FuncInput,
FunctionType,
InputFlags,
NoneType,
NumericType,
TupleType,
Type,
)


def arg_from_ast(
Expand All @@ -28,6 +37,11 @@ def arg_from_ast(
if x not in globals:
raise GuppyError("Unknown identifier", node)
match globals[x]:
# Special case for the `Callable` type
case CallableTypeDef():
return TypeArg(
_parse_callable_type([], node, globals, param_var_mapping)
)
# Either a defined type (e.g. `int`, `bool`, ...)
case TypeDef() as defn:
return TypeArg(defn.check_instantiate([], globals, node))
Expand All @@ -50,21 +64,16 @@ def arg_from_ast(
x = node.value.id
if x in globals:
defn = globals[x]
if isinstance(defn, TypeDef):
arg_nodes = (
node.slice.elts
if isinstance(node.slice, ast.Tuple)
else [node.slice]
arg_nodes = (
node.slice.elts if isinstance(node.slice, ast.Tuple) else [node.slice]
)
if isinstance(defn, CallableTypeDef):
# Special case for the `Callable[[S1, S2, ...], T]` type to support the
# input list syntax and @inout annotations.
return TypeArg(
_parse_callable_type(arg_nodes, node, globals, param_var_mapping)
)
# Hack: Flatten argument lists to support the `Callable` type. For
# example, we turn `Callable[[int, int], bool]` into
# `Callable[int, int, bool]`.
# TODO: We can get rid of this once we added support for variadic params
arg_nodes = [
n
for arg in arg_nodes
for n in (arg.elts if isinstance(arg, ast.List) else (arg,))
]
if isinstance(defn, TypeDef):
args = [
arg_from_ast(arg_node, globals, param_var_mapping)
for arg_node in arg_nodes
Expand Down Expand Up @@ -102,35 +111,120 @@ def arg_from_ast(

# Finally, we also support delayed annotations in strings
if isinstance(node, ast.Constant) and isinstance(node.value, str):
try:
[stmt] = ast.parse(node.value).body
if not isinstance(stmt, ast.Expr):
raise GuppyError("Invalid Guppy type", node)
set_location_from(stmt, loc=node)
shift_loc(
stmt,
delta_lineno=node.lineno - 1, # -1 since lines start at 1
delta_col_offset=node.col_offset + 1, # +1 to remove the `"`
)
return arg_from_ast(stmt.value, globals, param_var_mapping)
except (SyntaxError, ValueError):
raise GuppyError("Invalid Guppy type", node) from None
node = _parse_delayed_annotation(node.value, node)
return arg_from_ast(node, globals, param_var_mapping)

raise GuppyError("Not a valid type argument", node)


def _parse_delayed_annotation(ast_str: str, node: ast.Constant) -> ast.expr:
"""Parses a delayed type annotation in a string."""
try:
[stmt] = ast.parse(ast_str).body
if not isinstance(stmt, ast.Expr):
raise GuppyError("Invalid Guppy type", node)
set_location_from(stmt, loc=node)
shift_loc(
stmt,
delta_lineno=node.lineno - 1, # -1 since lines start at 1
delta_col_offset=node.col_offset + 1, # +1 to remove the `"`
)
except (SyntaxError, ValueError):
raise GuppyError("Invalid Guppy type", node) from None
else:
return stmt.value


def _parse_callable_type(
args: list[ast.expr],
loc: AstNode,
globals: Globals,
param_var_mapping: dict[str, Parameter] | None,
) -> FunctionType:
"""Helper function to parse a `Callable[[<arguments>], <return type>]` type."""
err = (
"Function types should be specified via "
"`Callable[[<arguments>], <return type>]`"
)
if len(args) != 2:
raise GuppyError(err, loc)
[inputs, output] = args
if not isinstance(inputs, ast.List):
raise GuppyError(err, loc)
inouts, output = parse_function_io_types(
inputs.elts, output, loc, globals, param_var_mapping
)
return FunctionType(inouts, output)


def parse_function_io_types(
input_nodes: list[ast.expr],
output_node: ast.expr,
loc: AstNode,
globals: Globals,
param_var_mapping: dict[str, Parameter] | None,
) -> tuple[list[FuncInput], Type]:
"""Parses the inputs and output types of a function type.
This function takes care of parsing `@inout` annotations and any related checks.
Returns the parsed input and output types.
"""
inputs = []
for inp in input_nodes:
ty, flags = type_with_flags_from_ast(inp, globals, param_var_mapping)
if InputFlags.Inout in flags and not ty.linear:
raise GuppyError(
f"Non-linear type `{ty}` cannot be annotated as `@inout`", loc
)
inputs.append(FuncInput(ty, flags))
output = type_from_ast(output_node, globals, param_var_mapping)
return inputs, output


_type_param = TypeParam(0, "T", True)


def type_with_flags_from_ast(
node: AstNode,
globals: Globals,
param_var_mapping: dict[str, Parameter] | None = None,
) -> tuple[Type, InputFlags]:
"""Turns an AST expression into a Guppy type with some optional @flags."""
# Check for `type @flag` annotations
if isinstance(node, ast.BinOp) and isinstance(node.op, ast.MatMult):
ty, flags = type_with_flags_from_ast(node.left, globals, param_var_mapping)
match node.right:
case ast.Name(id="inout"):
if not ty.linear:
raise GuppyError(
f"Non-linear type `{ty}` cannot be annotated as `@inout`",
node.right,
)
flags |= InputFlags.Inout
case _:
raise GuppyError("Invalid annotation", node.right)
return ty, flags
# We also need to handle the case that this could be a delayed string annotation
elif isinstance(node, ast.Constant) and isinstance(node.value, str):
node = _parse_delayed_annotation(node.value, node)
return type_with_flags_from_ast(node, globals, param_var_mapping)
else:
# Parse an argument and check that it's valid for a `TypeParam`
arg = arg_from_ast(node, globals, param_var_mapping)
return _type_param.check_arg(arg, node).ty, InputFlags.NoFlags


def type_from_ast(
node: AstNode,
globals: Globals,
param_var_mapping: dict[str, Parameter] | None = None,
) -> Type:
"""Turns an AST expression into a Guppy type."""
# Parse an argument and check that it's valid for a `TypeParam`
arg = arg_from_ast(node, globals, param_var_mapping)
return _type_param.check_arg(arg, node).ty
ty, flags = type_with_flags_from_ast(node, globals, param_var_mapping)
if flags != InputFlags.NoFlags:
raise GuppyError("`@` type annotations are not allowed in this position", node)
return ty


def type_row_from_ast(node: ast.expr, globals: "Globals") -> Sequence[Type]:
Expand Down
Empty file.
6 changes: 6 additions & 0 deletions tests/error/inout_errors/nonlinear.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Guppy compilation failed. Error in file $FILE:11

9: @guppy.declare(module)
10: def foo(x: int @inout) -> qubit: ...
^^^^^
GuppyError: Non-linear type `int` cannot be annotated as `@inout`
14 changes: 14 additions & 0 deletions tests/error/inout_errors/nonlinear.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.prelude.builtins import inout
from guppylang.prelude.quantum import qubit


module = GuppyModule("test")


@guppy.declare(module)
def foo(x: int @inout) -> qubit: ...


module.compile()
6 changes: 6 additions & 0 deletions tests/error/inout_errors/nonlinear_callable.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Guppy compilation failed. Error in file $FILE:12

10: @guppy.declare(module)
11: def foo(f: Callable[[int @inout], None]) -> None: ...
^^^^^
GuppyError: Non-linear type `int` cannot be annotated as `@inout`
15 changes: 15 additions & 0 deletions tests/error/inout_errors/nonlinear_callable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from typing import Callable

from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.prelude.builtins import inout


module = GuppyModule("test")


@guppy.declare(module)
def foo(f: Callable[[int @inout], None]) -> None: ...


module.compile()
6 changes: 6 additions & 0 deletions tests/error/misc_errors/callable_no_args.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Guppy compilation failed. Error in file $FILE:10

8: @guppy.declare(module)
9: def foo(f: Callable) -> None: ...
^^^^^^^^
GuppyError: Function types should be specified via `Callable[[<arguments>], <return type>]`
13 changes: 13 additions & 0 deletions tests/error/misc_errors/callable_no_args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from typing import Callable

from guppylang.decorator import guppy
from guppylang.module import GuppyModule


module = GuppyModule("test")

@guppy.declare(module)
def foo(f: Callable) -> None: ...


module.compile()
6 changes: 6 additions & 0 deletions tests/error/misc_errors/callable_not_list1.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Guppy compilation failed. Error in file $FILE:10

8: @guppy.declare(module)
9: def foo(f: "Callable[int, float, bool]") -> None: ...
^^^^^^^^^^^^^^^^^^^^^^^^^^
GuppyError: Function types should be specified via `Callable[[<arguments>], <return type>]`
Loading

0 comments on commit 4316ac9

Please sign in to comment.