From 28be8a4fca3a7a4348dfa071f0bc703782b78ab6 Mon Sep 17 00:00:00 2001
From: Seyon Sivarajah <seyon.sivarajah@quantinuum.com>
Date: Wed, 8 Jan 2025 10:52:53 +0000
Subject: [PATCH] feat: add `panic` builtin function (#757)

Closes #756

---------

Co-authored-by: Mark Koch <48097969+mark-koch@users.noreply.github.com>
Co-authored-by: Mark Koch <mark.koch@quantinuum.com>
---
 guppylang/compiler/expr_compiler.py           | 15 ++++++++
 guppylang/nodes.py                            |  9 +++++
 guppylang/std/_internal/checker.py            | 37 ++++++++++++++++++
 guppylang/std/_internal/compiler/prelude.py   |  5 ++-
 guppylang/std/builtins.py                     |  5 +++
 tests/error/misc_errors/panic_msg_empty.err   | 10 +++++
 tests/error/misc_errors/panic_msg_empty.py    |  7 ++++
 tests/error/misc_errors/panic_msg_not_str.err |  8 ++++
 tests/error/misc_errors/panic_msg_not_str.py  |  7 ++++
 .../misc_errors/panic_tag_not_static.err      |  8 ++++
 .../error/misc_errors/panic_tag_not_static.py |  7 ++++
 tests/integration/test_panic.py               | 38 +++++++++++++++++++
 tests/integration/test_quantum.py             | 13 +++++++
 13 files changed, 167 insertions(+), 2 deletions(-)
 create mode 100644 tests/error/misc_errors/panic_msg_empty.err
 create mode 100644 tests/error/misc_errors/panic_msg_empty.py
 create mode 100644 tests/error/misc_errors/panic_msg_not_str.err
 create mode 100644 tests/error/misc_errors/panic_msg_not_str.py
 create mode 100644 tests/error/misc_errors/panic_tag_not_static.err
 create mode 100644 tests/error/misc_errors/panic_tag_not_static.py
 create mode 100644 tests/integration/test_panic.py

diff --git a/guppylang/compiler/expr_compiler.py b/guppylang/compiler/expr_compiler.py
index 524e7fd9..ca25edbb 100644
--- a/guppylang/compiler/expr_compiler.py
+++ b/guppylang/compiler/expr_compiler.py
@@ -41,6 +41,7 @@
     GlobalName,
     InoutReturnSentinel,
     LocalCall,
+    PanicExpr,
     PartialApply,
     PlaceNode,
     ResultExpr,
@@ -53,6 +54,7 @@
 from guppylang.std._internal.compiler.list import (
     list_new,
 )
+from guppylang.std._internal.compiler.prelude import build_error, build_panic
 from guppylang.tys.arg import Argument
 from guppylang.tys.builtin import (
     get_element_type,
@@ -473,6 +475,19 @@ def visit_ResultExpr(self, node: ResultExpr) -> Wire:
         self.builder.add_op(op, self.visit(node.value))
         return self._pack_returns([], NoneType())
 
+    def visit_PanicExpr(self, node: PanicExpr) -> Wire:
+        err = build_error(self.builder, 1, node.msg)
+        in_tys = [get_type(e).to_hugr() for e in node.values]
+        out_tys = [ty.to_hugr() for ty in type_to_row(get_type(node))]
+        outs = build_panic(
+            self.builder,
+            in_tys,
+            out_tys,
+            err,
+            *(self.visit(e) for e in node.values),
+        ).outputs()
+        return self._pack_returns(list(outs), get_type(node))
+
     def visit_DesugaredListComp(self, node: DesugaredListComp) -> Wire:
         # Make up a name for the list under construction and bind it to an empty list
         list_ty = get_type(node)
diff --git a/guppylang/nodes.py b/guppylang/nodes.py
index 89d03787..63383741 100644
--- a/guppylang/nodes.py
+++ b/guppylang/nodes.py
@@ -281,6 +281,15 @@ class ResultExpr(ast.expr):
     _fields = ("value", "base_ty", "array_len", "tag")
 
 
+class PanicExpr(ast.expr):
+    """A `panic(msg, *args)` expression."""
+
+    msg: str
+    values: list[ast.expr]
+
+    _fields = ("msg", "values")
+
+
 class InoutReturnSentinel(ast.expr):
     """An invisible expression corresponding to an implicit use of borrowed vars
     whenever a function returns."""
diff --git a/guppylang/std/_internal/checker.py b/guppylang/std/_internal/checker.py
index 55d5613f..9ac74a57 100644
--- a/guppylang/std/_internal/checker.py
+++ b/guppylang/std/_internal/checker.py
@@ -32,6 +32,7 @@
     GenericParamValue,
     GlobalCall,
     MakeIter,
+    PanicExpr,
     ResultExpr,
 )
 from guppylang.tys.arg import ConstArg, TypeArg
@@ -355,6 +356,42 @@ def _is_numeric_or_bool_type(ty: Type) -> bool:
         return isinstance(ty, NumericType) or is_bool_type(ty)
 
 
+class PanicChecker(CustomCallChecker):
+    """Call checker for the `panic` function."""
+
+    @dataclass(frozen=True)
+    class NoMessageError(Error):
+        title: ClassVar[str] = "No panic message"
+        span_label: ClassVar[str] = "Missing message argument to panic call"
+
+        @dataclass(frozen=True)
+        class Suggestion(Note):
+            message: ClassVar[str] = 'Add a message: `panic("message")`'
+
+    def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
+        match args:
+            case []:
+                err = PanicChecker.NoMessageError(self.node)
+                err.add_sub_diagnostic(PanicChecker.NoMessageError.Suggestion(None))
+                raise GuppyTypeError(err)
+            case [msg, *rest]:
+                if not isinstance(msg, ast.Constant) or not isinstance(msg.value, str):
+                    raise GuppyTypeError(ExpectedError(msg, "a string literal"))
+
+                vals = [ExprSynthesizer(self.ctx).synthesize(val)[0] for val in rest]
+                node = PanicExpr(msg.value, vals)
+                return with_loc(self.node, node), NoneType()
+            case args:
+                return assert_never(args)  # type: ignore[arg-type]
+
+    def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:
+        # Panic may return any type, so we don't have to check anything. Consequently
+        # we also can't infer anything in the expected type, so we always return an
+        # empty substitution
+        expr, _ = self.synthesize(args)
+        return expr, {}
+
+
 class RangeChecker(CustomCallChecker):
     """Call checker for the `range` function."""
 
diff --git a/guppylang/std/_internal/compiler/prelude.py b/guppylang/std/_internal/compiler/prelude.py
index 518ea706..06162c18 100644
--- a/guppylang/std/_internal/compiler/prelude.py
+++ b/guppylang/std/_internal/compiler/prelude.py
@@ -7,6 +7,7 @@
 
 import hugr.std.collections
 import hugr.std.int
+import hugr.std.prelude
 from hugr import Node, Wire, ops
 from hugr import tys as ht
 from hugr import val as hv
@@ -63,7 +64,7 @@ def panic(inputs: list[ht.Type], outputs: list[ht.Type]) -> ops.ExtOp:
 
 
 def build_panic(
-    builder: DfBase[ops.Case],
+    builder: DfBase[P],
     in_tys: ht.TypeRow,
     out_tys: ht.TypeRow,
     err: Wire,
@@ -74,7 +75,7 @@ def build_panic(
     return builder.add_op(op, err, *args)
 
 
-def build_error(builder: DfBase[ops.Case], signal: int, msg: str) -> Wire:
+def build_error(builder: DfBase[P], signal: int, msg: str) -> Wire:
     """Constructs and loads a static error value."""
     val = ErrorVal(signal, msg)
     return builder.load(builder.add_const(val))
diff --git a/guppylang/std/builtins.py b/guppylang/std/builtins.py
index aaddb45a..bedad29d 100644
--- a/guppylang/std/builtins.py
+++ b/guppylang/std/builtins.py
@@ -13,6 +13,7 @@
     CallableChecker,
     DunderChecker,
     NewArrayChecker,
+    PanicChecker,
     RangeChecker,
     ResultChecker,
     ReversingChecker,
@@ -651,6 +652,10 @@ def __iter__(self: "SizedIter[L, n]" @ owned) -> "SizedIter[L, n]":  # type: ign
 def result(tag, value): ...
 
 
+@guppy.custom(checker=PanicChecker(), higher_order_value=False)
+def panic(msg, *args): ...
+
+
 @guppy.custom(checker=DunderChecker("__abs__"), higher_order_value=False)
 def abs(x): ...
 
diff --git a/tests/error/misc_errors/panic_msg_empty.err b/tests/error/misc_errors/panic_msg_empty.err
new file mode 100644
index 00000000..ddcab328
--- /dev/null
+++ b/tests/error/misc_errors/panic_msg_empty.err
@@ -0,0 +1,10 @@
+Error: No panic message (at $FILE:7:4)
+  | 
+5 | @compile_guppy
+6 | def foo(x: int) -> None:
+7 |     panic()
+  |     ^^^^^^^ Missing message argument to panic call
+
+Note: Add a message: `panic("message")`
+
+Guppy compilation failed due to 1 previous error
diff --git a/tests/error/misc_errors/panic_msg_empty.py b/tests/error/misc_errors/panic_msg_empty.py
new file mode 100644
index 00000000..38c0ffc8
--- /dev/null
+++ b/tests/error/misc_errors/panic_msg_empty.py
@@ -0,0 +1,7 @@
+from guppylang.std.builtins import panic
+from tests.util import compile_guppy
+
+
+@compile_guppy
+def foo(x: int) -> None:
+    panic()
diff --git a/tests/error/misc_errors/panic_msg_not_str.err b/tests/error/misc_errors/panic_msg_not_str.err
new file mode 100644
index 00000000..0e354800
--- /dev/null
+++ b/tests/error/misc_errors/panic_msg_not_str.err
@@ -0,0 +1,8 @@
+Error: Expected a string literal (at $FILE:7:10)
+  | 
+5 | @compile_guppy
+6 | def foo(x: int) -> None:
+7 |     panic((), x)
+  |           ^^ Expected a string literal
+
+Guppy compilation failed due to 1 previous error
diff --git a/tests/error/misc_errors/panic_msg_not_str.py b/tests/error/misc_errors/panic_msg_not_str.py
new file mode 100644
index 00000000..bdafb447
--- /dev/null
+++ b/tests/error/misc_errors/panic_msg_not_str.py
@@ -0,0 +1,7 @@
+from guppylang.std.builtins import panic
+from tests.util import compile_guppy
+
+
+@compile_guppy
+def foo(x: int) -> None:
+    panic((), x)
diff --git a/tests/error/misc_errors/panic_tag_not_static.err b/tests/error/misc_errors/panic_tag_not_static.err
new file mode 100644
index 00000000..0d732c94
--- /dev/null
+++ b/tests/error/misc_errors/panic_tag_not_static.err
@@ -0,0 +1,8 @@
+Error: Expected a string literal (at $FILE:7:10)
+  | 
+5 | @compile_guppy
+6 | def foo(y: bool) -> None:
+7 |     panic("foo" + "bar", y)
+  |           ^^^^^^^^^^^^^ Expected a string literal
+
+Guppy compilation failed due to 1 previous error
diff --git a/tests/error/misc_errors/panic_tag_not_static.py b/tests/error/misc_errors/panic_tag_not_static.py
new file mode 100644
index 00000000..d7a1d155
--- /dev/null
+++ b/tests/error/misc_errors/panic_tag_not_static.py
@@ -0,0 +1,7 @@
+from guppylang.std.builtins import panic
+from tests.util import compile_guppy
+
+
+@compile_guppy
+def foo(y: bool) -> None:
+    panic("foo" + "bar", y)
diff --git a/tests/integration/test_panic.py b/tests/integration/test_panic.py
new file mode 100644
index 00000000..a6a342a6
--- /dev/null
+++ b/tests/integration/test_panic.py
@@ -0,0 +1,38 @@
+from guppylang import GuppyModule, guppy
+from guppylang.std.builtins import panic
+from tests.util import compile_guppy
+
+
+def test_basic(validate):
+    @compile_guppy
+    def main() -> None:
+        panic("I panicked!")
+
+    validate(main)
+
+
+def test_discard(validate):
+    @compile_guppy
+    def main() -> None:
+        a = 1 + 2
+        panic("I panicked!", False, a)
+
+    validate(main)
+
+
+def test_value(validate):
+    module = GuppyModule("test")
+
+    @guppy(module)
+    def foo() -> int:
+        return panic("I panicked!")
+
+    @guppy(module)
+    def bar() -> tuple[int, float]:
+        return panic("I panicked!")
+
+    @guppy(module)
+    def baz() -> None:
+        return panic("I panicked!")
+
+    validate(module.compile())
diff --git a/tests/integration/test_quantum.py b/tests/integration/test_quantum.py
index 17e77ace..1dfe344f 100644
--- a/tests/integration/test_quantum.py
+++ b/tests/integration/test_quantum.py
@@ -1,5 +1,6 @@
 """Various tests for the functions defined in `guppylang.prelude.quantum`."""
 
+from typing import no_type_check
 from hugr.package import ModulePointer
 
 import guppylang.decorator
@@ -154,3 +155,15 @@ def test() -> None:
         discard_array(qs)
 
     validate(test)
+
+
+def test_panic_discard(validate):
+    """Panic while discarding qubit."""
+
+    @compile_quantum_guppy
+    @no_type_check
+    def test() -> None:
+        q = qubit()
+        panic("I panicked!", q)
+
+    validate(test)