Skip to content

Commit

Permalink
Merge branch 'feat/inout' into inout/parse
Browse files Browse the repository at this point in the history
  • Loading branch information
mark-koch committed Jul 29, 2024
2 parents 68ece53 + 86c8b31 commit 11f2e25
Show file tree
Hide file tree
Showing 20 changed files with 272 additions and 57 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pull-request.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ on:
- main
pull_request:
branches:
- main
- '**'
merge_group:
types: [checks_requested]
workflow_dispatch: {}
Expand Down
2 changes: 1 addition & 1 deletion .release-please-manifest.json
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
{
".": "0.6.2"
".": "0.7.0"
}
18 changes: 18 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,23 @@
# Changelog

## [0.7.0](https://github.com/CQCL/guppylang/compare/v0.6.2...v0.7.0) (2024-07-25)


### ⚠ BREAKING CHANGES

* `qubit`s are now reset on allocation

### Features

* `qubit`s are now reset on allocation, and `dirty_qubit` added ([#325](https://github.com/CQCL/guppylang/issues/325)) ([4a9e205](https://github.com/CQCL/guppylang/commit/4a9e20529a4d0859f010fad62ba06f62ca1c98ce))
* Allow access to struct fields and mutation of linear ones ([#295](https://github.com/CQCL/guppylang/issues/295)) ([6698b75](https://github.com/CQCL/guppylang/commit/6698b75b01421cd1fa545219786266fb0c1da05b)), closes [#293](https://github.com/CQCL/guppylang/issues/293)
* Allow redefinition of names in guppy modules ([#326](https://github.com/CQCL/guppylang/issues/326)) ([314409c](https://github.com/CQCL/guppylang/commit/314409cd63b544d0fdbf16db66201b08dead81fe)), closes [#307](https://github.com/CQCL/guppylang/issues/307)


### Bug Fixes

* Use correct hook for error printing inside jupyter notebooks ([#324](https://github.com/CQCL/guppylang/issues/324)) ([bfdb003](https://github.com/CQCL/guppylang/commit/bfdb003d454d3d8fb6385c2c758dab56ab622496)), closes [#323](https://github.com/CQCL/guppylang/issues/323)

## [0.6.2](https://github.com/CQCL/guppylang/compare/v0.6.1...v0.6.2) (2024-07-10)


Expand Down
5 changes: 4 additions & 1 deletion guppylang/cfg/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,9 @@ def visit_Pass(self, node: ast.Pass, bb: BB, jumps: Jumps) -> BB | None:
def visit_FunctionDef(
self, node: ast.FunctionDef, bb: BB, jumps: Jumps
) -> BB | None:
from guppylang.checker.func_checker import check_signature
from guppylang.checker.func_checker import check_signature, parse_docstring

node, docstring = parse_docstring(node)

func_ty = check_signature(node, self.globals)
returns_none = isinstance(func_ty.output, NoneType)
Expand All @@ -220,6 +222,7 @@ def visit_FunctionDef(
new_node = NestedFunctionDef(
cfg,
func_ty,
docstring=docstring,
name=node.name,
args=node.args,
body=node.body,
Expand Down
23 changes: 22 additions & 1 deletion guppylang/checker/func_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def check_nested_func_def(
from guppylang.definition.function import ParsedFunctionDef

func = ParsedFunctionDef(
def_id, func_def.name, func_def, func_ty, globals.python_scope
def_id, func_def.name, func_def, func_ty, globals.python_scope, None
)
globals = ctx.globals | Globals(
{func.id: func}, {func_def.name: func.id}, {}, {}
Expand Down Expand Up @@ -177,3 +177,24 @@ def check_signature(func_def: ast.FunctionDef, globals: Globals) -> FunctionType
input_names,
sorted(param_var_mapping.values(), key=lambda v: v.idx),
)


def parse_docstring(func_ast: ast.FunctionDef) -> tuple[ast.FunctionDef, str | None]:
"""Check if the first line of a function is a docstring.
If it is, return the function with the docstring removed, plus the docstring.
Else, return the original function and `None`
"""
docstring = None
match func_ast.body:
case [doc, *xs]:
if (
isinstance(doc, ast.Expr)
and isinstance(doc.value, ast.Constant)
and isinstance(doc.value.value, str)
):
docstring = doc.value.value
func_ast.body = xs
case _:
pass
return func_ast, docstring
18 changes: 18 additions & 0 deletions guppylang/compiler/expr_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,19 @@
GlobalName,
LocalCall,
PlaceNode,
ResultExpr,
TensorCall,
TypeApply,
)
from guppylang.tys.arg import ConstArg, TypeArg
from guppylang.tys.builtin import bool_type, get_element_type, is_list_type
from guppylang.tys.const import ConstValue
from guppylang.tys.subst import Inst
from guppylang.tys.ty import (
BoundTypeVar,
FunctionType,
NoneType,
NumericType,
TupleType,
Type,
type_to_row,
Expand Down Expand Up @@ -294,6 +298,20 @@ def visit_FieldAccessAndDrop(self, node: FieldAccessAndDrop) -> OutPortV:
unpack = self.graph.add_unpack_tuple(struct_port)
return unpack.out_port(node.struct_ty.fields.index(node.field))

def visit_ResultExpr(self, node: ResultExpr) -> OutPortV:
type_args = [
TypeArg(node.ty),
ConstArg(ConstValue(value=node.tag, ty=NumericType(NumericType.Kind.Nat))),
]
op = ops.CustomOp(
extension="tket2.results",
op_name="Result",
args=[arg.to_hugr() for arg in type_args],
parent=UNDEFINED,
)
self.graph.add_node(ops.OpType(op), inputs=[self.visit(node.value)])
return self._pack_returns([], NoneType())

def visit_DesugaredListComp(self, node: DesugaredListComp) -> OutPortV:
from guppylang.compiler.stmt_compiler import StmtCompiler

Expand Down
1 change: 1 addition & 0 deletions guppylang/compiler/func_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def compile_local_func_def(
func,
func.ty,
{},
None,
func.cfg,
def_node,
)
Expand Down
2 changes: 1 addition & 1 deletion guppylang/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def custom(
"""

def dec(f: PyFunc) -> RawCustomFunctionDef:
func_ast = parse_py_func(f)
func_ast, docstring = parse_py_func(f)
if not has_empty_body(func_ast):
raise GuppyError(
"Body of custom function declaration must be empty",
Expand Down
15 changes: 12 additions & 3 deletions guppylang/definition/declaration.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,15 @@ class RawFunctionDecl(ParsableDef):

def parse(self, globals: Globals) -> "CheckedFunctionDecl":
"""Parses and checks the user-provided signature of the function."""
func_ast = parse_py_func(self.python_func)
func_ast, docstring = parse_py_func(self.python_func)
ty = check_signature(func_ast, globals)
if not has_empty_body(func_ast):
raise GuppyError(
"Body of function declaration must be empty", func_ast.body[0]
)
return CheckedFunctionDecl(self.id, self.name, func_ast, ty, self.python_func)
return CheckedFunctionDecl(
self.id, self.name, func_ast, ty, self.python_func, docstring
)


@dataclass(frozen=True)
Expand All @@ -46,6 +48,7 @@ class CheckedFunctionDecl(RawFunctionDecl, CompilableDef, CallableDef):
"""

defined_at: ast.FunctionDef
docstring: str | None

def check_call(
self, args: list[ast.expr], ty: Type, node: AstNode, ctx: Context
Expand All @@ -69,7 +72,13 @@ def compile_outer(self, graph: Hugr, parent: Node) -> "CompiledFunctionDecl":
"""Adds a Hugr `FuncDecl` node for this function to the Hugr."""
node = graph.add_declare(self.ty, parent, self.name)
return CompiledFunctionDecl(
self.id, self.name, self.defined_at, self.ty, self.python_func, node
self.id,
self.name,
self.defined_at,
self.ty,
self.python_func,
self.docstring,
node,
)


Expand Down
19 changes: 14 additions & 5 deletions guppylang/definition/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@
from guppylang.checker.cfg_checker import CheckedCFG
from guppylang.checker.core import Context, Globals, PyScope
from guppylang.checker.expr_checker import check_call, synthesize_call
from guppylang.checker.func_checker import check_global_func_def, check_signature
from guppylang.checker.func_checker import (
check_global_func_def,
check_signature,
parse_docstring,
)
from guppylang.compiler.core import CompiledGlobals, DFContainer
from guppylang.compiler.func_compiler import compile_global_func_def
from guppylang.definition.common import CheckableDef, CompilableDef, ParsableDef
Expand Down Expand Up @@ -39,13 +43,15 @@ class RawFunctionDef(ParsableDef):

def parse(self, globals: Globals) -> "ParsedFunctionDef":
"""Parses and checks the user-provided signature of the function."""
func_ast = parse_py_func(self.python_func)
func_ast, docstring = parse_py_func(self.python_func)
ty = check_signature(func_ast, globals)
if ty.parametrized:
raise GuppyError(
"Generic function definitions are not supported yet", func_ast
)
return ParsedFunctionDef(self.id, self.name, func_ast, ty, self.python_scope)
return ParsedFunctionDef(
self.id, self.name, func_ast, ty, self.python_scope, docstring
)


@dataclass(frozen=True)
Expand All @@ -59,6 +65,7 @@ class ParsedFunctionDef(CheckableDef, CallableDef):
python_scope: PyScope
defined_at: ast.FunctionDef
ty: FunctionType
docstring: str | None

description: str = field(default="function", init=False)

Expand All @@ -73,6 +80,7 @@ def check(self, globals: Globals) -> "CheckedFunctionDef":
self.defined_at,
self.ty,
self.python_scope,
self.docstring,
cfg,
)

Expand Down Expand Up @@ -119,6 +127,7 @@ def compile_outer(self, graph: Hugr, parent: Node) -> "CompiledFunctionDef":
self.defined_at,
self.ty,
self.python_scope,
self.docstring,
self.cfg,
def_node,
)
Expand Down Expand Up @@ -161,7 +170,7 @@ def compile_inner(self, graph: Hugr, globals: CompiledGlobals) -> None:
compile_global_func_def(self, self.hugr_node, graph, globals)


def parse_py_func(f: PyFunc) -> ast.FunctionDef:
def parse_py_func(f: PyFunc) -> tuple[ast.FunctionDef, str | None]:
source_lines, line_offset = inspect.getsourcelines(f)
source = "".join(source_lines) # Lines already have trailing \n's
source = textwrap.dedent(source)
Expand All @@ -172,4 +181,4 @@ def parse_py_func(f: PyFunc) -> ast.FunctionDef:
annotate_location(func_ast, source, file, line_offset)
if not isinstance(func_ast, ast.FunctionDef):
raise GuppyError("Expected a function definition", func_ast)
return func_ast
return parse_docstring(func_ast)
27 changes: 23 additions & 4 deletions guppylang/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,29 @@ class InternalGuppyError(Exception):
@contextmanager
def exception_hook(hook: ExceptHook) -> Iterator[None]:
"""Sets a custom `excepthook` for the scope of a 'with' block."""
old_hook = sys.excepthook
sys.excepthook = hook
yield
sys.excepthook = old_hook
try:
# Check if we're inside a jupyter notebook since it uses its own exception
# hook. If we're in a regular interpreter, this line will raise a `NameError`
ipython_shell = get_ipython() # type: ignore[name-defined]

def ipython_excepthook(
shell: Any,
etype: type[BaseException],
value: BaseException,
tb: TracebackType | None,
tb_offset: Any = None,
) -> Any:
return hook(etype, value, tb)

ipython_shell.set_custom_exc((GuppyError,), ipython_excepthook)
yield
ipython_shell.set_custom_exc((), None)
except NameError:
# Otherwise, override the regular sys.excepthook
old_hook = sys.excepthook
sys.excepthook = hook
yield
sys.excepthook = old_hook


def format_source_location(
Expand Down
12 changes: 3 additions & 9 deletions guppylang/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from types import ModuleType
from typing import Any, Union

from guppylang.ast_util import AstNode
from guppylang.checker.core import Globals, PyScope
from guppylang.compiler.core import CompiledGlobals
from guppylang.definition.common import (
Expand Down Expand Up @@ -104,14 +103,16 @@ def load(self, m: Union[ModuleType, "GuppyModule"]) -> None:
def register_def(self, defn: RawDef, instance: TypeDef | None = None) -> None:
"""Registers a definition with this module.
If the name of the definition is already defined, the new definition
replaces the old.
Optionally, the definition can be marked as an instance method by passing the
corresponding instance type definition.
"""
self._check_not_yet_compiled()
if self._instance_func_buffer is not None and not isinstance(defn, TypeDef):
self._instance_func_buffer[defn.name] = defn
else:
self._check_name_available(defn.name, defn.defined_at)
if isinstance(defn, TypeDef | ParamDef):
self._raw_type_defs[defn.id] = defn
else:
Expand Down Expand Up @@ -228,13 +229,6 @@ def _check_not_yet_compiled(self) -> None:
if self._compiled:
raise GuppyError(f"The module `{self.name}` has already been compiled")

def _check_name_available(self, name: str, node: AstNode | None) -> None:
if self.contains(name):
raise GuppyError(
f"Module `{self.name}` already contains a definition named `{name}`",
node,
)


def get_py_scope(f: PyFunc) -> PyScope:
"""Returns a mapping of all variables captured by a Python function.
Expand Down
11 changes: 11 additions & 0 deletions guppylang/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,9 +188,20 @@ class PyExpr(ast.expr):
_fields = ("value",)


class ResultExpr(ast.expr):
"""A `result(tag, value)` expression."""

value: ast.expr
ty: Type
tag: int

_fields = ("value", "ty", "tag")


class NestedFunctionDef(ast.FunctionDef):
cfg: "CFG"
ty: FunctionType
docstring: str | None

def __init__(self, cfg: "CFG", ty: FunctionType, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
Expand Down
Loading

0 comments on commit 11f2e25

Please sign in to comment.