From 28be8a4fca3a7a4348dfa071f0bc703782b78ab6 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah <seyon.sivarajah@quantinuum.com> Date: Wed, 8 Jan 2025 10:52:53 +0000 Subject: [PATCH] feat: add `panic` builtin function (#757) Closes #756 --------- Co-authored-by: Mark Koch <48097969+mark-koch@users.noreply.github.com> Co-authored-by: Mark Koch <mark.koch@quantinuum.com> --- guppylang/compiler/expr_compiler.py | 15 ++++++++ guppylang/nodes.py | 9 +++++ guppylang/std/_internal/checker.py | 37 ++++++++++++++++++ guppylang/std/_internal/compiler/prelude.py | 5 ++- guppylang/std/builtins.py | 5 +++ tests/error/misc_errors/panic_msg_empty.err | 10 +++++ tests/error/misc_errors/panic_msg_empty.py | 7 ++++ tests/error/misc_errors/panic_msg_not_str.err | 8 ++++ tests/error/misc_errors/panic_msg_not_str.py | 7 ++++ .../misc_errors/panic_tag_not_static.err | 8 ++++ .../error/misc_errors/panic_tag_not_static.py | 7 ++++ tests/integration/test_panic.py | 38 +++++++++++++++++++ tests/integration/test_quantum.py | 13 +++++++ 13 files changed, 167 insertions(+), 2 deletions(-) create mode 100644 tests/error/misc_errors/panic_msg_empty.err create mode 100644 tests/error/misc_errors/panic_msg_empty.py create mode 100644 tests/error/misc_errors/panic_msg_not_str.err create mode 100644 tests/error/misc_errors/panic_msg_not_str.py create mode 100644 tests/error/misc_errors/panic_tag_not_static.err create mode 100644 tests/error/misc_errors/panic_tag_not_static.py create mode 100644 tests/integration/test_panic.py diff --git a/guppylang/compiler/expr_compiler.py b/guppylang/compiler/expr_compiler.py index 524e7fd9..ca25edbb 100644 --- a/guppylang/compiler/expr_compiler.py +++ b/guppylang/compiler/expr_compiler.py @@ -41,6 +41,7 @@ GlobalName, InoutReturnSentinel, LocalCall, + PanicExpr, PartialApply, PlaceNode, ResultExpr, @@ -53,6 +54,7 @@ from guppylang.std._internal.compiler.list import ( list_new, ) +from guppylang.std._internal.compiler.prelude import build_error, build_panic from guppylang.tys.arg import Argument from guppylang.tys.builtin import ( get_element_type, @@ -473,6 +475,19 @@ def visit_ResultExpr(self, node: ResultExpr) -> Wire: self.builder.add_op(op, self.visit(node.value)) return self._pack_returns([], NoneType()) + def visit_PanicExpr(self, node: PanicExpr) -> Wire: + err = build_error(self.builder, 1, node.msg) + in_tys = [get_type(e).to_hugr() for e in node.values] + out_tys = [ty.to_hugr() for ty in type_to_row(get_type(node))] + outs = build_panic( + self.builder, + in_tys, + out_tys, + err, + *(self.visit(e) for e in node.values), + ).outputs() + return self._pack_returns(list(outs), get_type(node)) + def visit_DesugaredListComp(self, node: DesugaredListComp) -> Wire: # Make up a name for the list under construction and bind it to an empty list list_ty = get_type(node) diff --git a/guppylang/nodes.py b/guppylang/nodes.py index 89d03787..63383741 100644 --- a/guppylang/nodes.py +++ b/guppylang/nodes.py @@ -281,6 +281,15 @@ class ResultExpr(ast.expr): _fields = ("value", "base_ty", "array_len", "tag") +class PanicExpr(ast.expr): + """A `panic(msg, *args)` expression.""" + + msg: str + values: list[ast.expr] + + _fields = ("msg", "values") + + class InoutReturnSentinel(ast.expr): """An invisible expression corresponding to an implicit use of borrowed vars whenever a function returns.""" diff --git a/guppylang/std/_internal/checker.py b/guppylang/std/_internal/checker.py index 55d5613f..9ac74a57 100644 --- a/guppylang/std/_internal/checker.py +++ b/guppylang/std/_internal/checker.py @@ -32,6 +32,7 @@ GenericParamValue, GlobalCall, MakeIter, + PanicExpr, ResultExpr, ) from guppylang.tys.arg import ConstArg, TypeArg @@ -355,6 +356,42 @@ def _is_numeric_or_bool_type(ty: Type) -> bool: return isinstance(ty, NumericType) or is_bool_type(ty) +class PanicChecker(CustomCallChecker): + """Call checker for the `panic` function.""" + + @dataclass(frozen=True) + class NoMessageError(Error): + title: ClassVar[str] = "No panic message" + span_label: ClassVar[str] = "Missing message argument to panic call" + + @dataclass(frozen=True) + class Suggestion(Note): + message: ClassVar[str] = 'Add a message: `panic("message")`' + + def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]: + match args: + case []: + err = PanicChecker.NoMessageError(self.node) + err.add_sub_diagnostic(PanicChecker.NoMessageError.Suggestion(None)) + raise GuppyTypeError(err) + case [msg, *rest]: + if not isinstance(msg, ast.Constant) or not isinstance(msg.value, str): + raise GuppyTypeError(ExpectedError(msg, "a string literal")) + + vals = [ExprSynthesizer(self.ctx).synthesize(val)[0] for val in rest] + node = PanicExpr(msg.value, vals) + return with_loc(self.node, node), NoneType() + case args: + return assert_never(args) # type: ignore[arg-type] + + def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]: + # Panic may return any type, so we don't have to check anything. Consequently + # we also can't infer anything in the expected type, so we always return an + # empty substitution + expr, _ = self.synthesize(args) + return expr, {} + + class RangeChecker(CustomCallChecker): """Call checker for the `range` function.""" diff --git a/guppylang/std/_internal/compiler/prelude.py b/guppylang/std/_internal/compiler/prelude.py index 518ea706..06162c18 100644 --- a/guppylang/std/_internal/compiler/prelude.py +++ b/guppylang/std/_internal/compiler/prelude.py @@ -7,6 +7,7 @@ import hugr.std.collections import hugr.std.int +import hugr.std.prelude from hugr import Node, Wire, ops from hugr import tys as ht from hugr import val as hv @@ -63,7 +64,7 @@ def panic(inputs: list[ht.Type], outputs: list[ht.Type]) -> ops.ExtOp: def build_panic( - builder: DfBase[ops.Case], + builder: DfBase[P], in_tys: ht.TypeRow, out_tys: ht.TypeRow, err: Wire, @@ -74,7 +75,7 @@ def build_panic( return builder.add_op(op, err, *args) -def build_error(builder: DfBase[ops.Case], signal: int, msg: str) -> Wire: +def build_error(builder: DfBase[P], signal: int, msg: str) -> Wire: """Constructs and loads a static error value.""" val = ErrorVal(signal, msg) return builder.load(builder.add_const(val)) diff --git a/guppylang/std/builtins.py b/guppylang/std/builtins.py index aaddb45a..bedad29d 100644 --- a/guppylang/std/builtins.py +++ b/guppylang/std/builtins.py @@ -13,6 +13,7 @@ CallableChecker, DunderChecker, NewArrayChecker, + PanicChecker, RangeChecker, ResultChecker, ReversingChecker, @@ -651,6 +652,10 @@ def __iter__(self: "SizedIter[L, n]" @ owned) -> "SizedIter[L, n]": # type: ign def result(tag, value): ... +@guppy.custom(checker=PanicChecker(), higher_order_value=False) +def panic(msg, *args): ... + + @guppy.custom(checker=DunderChecker("__abs__"), higher_order_value=False) def abs(x): ... diff --git a/tests/error/misc_errors/panic_msg_empty.err b/tests/error/misc_errors/panic_msg_empty.err new file mode 100644 index 00000000..ddcab328 --- /dev/null +++ b/tests/error/misc_errors/panic_msg_empty.err @@ -0,0 +1,10 @@ +Error: No panic message (at $FILE:7:4) + | +5 | @compile_guppy +6 | def foo(x: int) -> None: +7 | panic() + | ^^^^^^^ Missing message argument to panic call + +Note: Add a message: `panic("message")` + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/misc_errors/panic_msg_empty.py b/tests/error/misc_errors/panic_msg_empty.py new file mode 100644 index 00000000..38c0ffc8 --- /dev/null +++ b/tests/error/misc_errors/panic_msg_empty.py @@ -0,0 +1,7 @@ +from guppylang.std.builtins import panic +from tests.util import compile_guppy + + +@compile_guppy +def foo(x: int) -> None: + panic() diff --git a/tests/error/misc_errors/panic_msg_not_str.err b/tests/error/misc_errors/panic_msg_not_str.err new file mode 100644 index 00000000..0e354800 --- /dev/null +++ b/tests/error/misc_errors/panic_msg_not_str.err @@ -0,0 +1,8 @@ +Error: Expected a string literal (at $FILE:7:10) + | +5 | @compile_guppy +6 | def foo(x: int) -> None: +7 | panic((), x) + | ^^ Expected a string literal + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/misc_errors/panic_msg_not_str.py b/tests/error/misc_errors/panic_msg_not_str.py new file mode 100644 index 00000000..bdafb447 --- /dev/null +++ b/tests/error/misc_errors/panic_msg_not_str.py @@ -0,0 +1,7 @@ +from guppylang.std.builtins import panic +from tests.util import compile_guppy + + +@compile_guppy +def foo(x: int) -> None: + panic((), x) diff --git a/tests/error/misc_errors/panic_tag_not_static.err b/tests/error/misc_errors/panic_tag_not_static.err new file mode 100644 index 00000000..0d732c94 --- /dev/null +++ b/tests/error/misc_errors/panic_tag_not_static.err @@ -0,0 +1,8 @@ +Error: Expected a string literal (at $FILE:7:10) + | +5 | @compile_guppy +6 | def foo(y: bool) -> None: +7 | panic("foo" + "bar", y) + | ^^^^^^^^^^^^^ Expected a string literal + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/misc_errors/panic_tag_not_static.py b/tests/error/misc_errors/panic_tag_not_static.py new file mode 100644 index 00000000..d7a1d155 --- /dev/null +++ b/tests/error/misc_errors/panic_tag_not_static.py @@ -0,0 +1,7 @@ +from guppylang.std.builtins import panic +from tests.util import compile_guppy + + +@compile_guppy +def foo(y: bool) -> None: + panic("foo" + "bar", y) diff --git a/tests/integration/test_panic.py b/tests/integration/test_panic.py new file mode 100644 index 00000000..a6a342a6 --- /dev/null +++ b/tests/integration/test_panic.py @@ -0,0 +1,38 @@ +from guppylang import GuppyModule, guppy +from guppylang.std.builtins import panic +from tests.util import compile_guppy + + +def test_basic(validate): + @compile_guppy + def main() -> None: + panic("I panicked!") + + validate(main) + + +def test_discard(validate): + @compile_guppy + def main() -> None: + a = 1 + 2 + panic("I panicked!", False, a) + + validate(main) + + +def test_value(validate): + module = GuppyModule("test") + + @guppy(module) + def foo() -> int: + return panic("I panicked!") + + @guppy(module) + def bar() -> tuple[int, float]: + return panic("I panicked!") + + @guppy(module) + def baz() -> None: + return panic("I panicked!") + + validate(module.compile()) diff --git a/tests/integration/test_quantum.py b/tests/integration/test_quantum.py index 17e77ace..1dfe344f 100644 --- a/tests/integration/test_quantum.py +++ b/tests/integration/test_quantum.py @@ -1,5 +1,6 @@ """Various tests for the functions defined in `guppylang.prelude.quantum`.""" +from typing import no_type_check from hugr.package import ModulePointer import guppylang.decorator @@ -154,3 +155,15 @@ def test() -> None: discard_array(qs) validate(test) + + +def test_panic_discard(validate): + """Panic while discarding qubit.""" + + @compile_quantum_guppy + @no_type_check + def test() -> None: + q = qubit() + panic("I panicked!", q) + + validate(test)