Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add method to load pytket circuit without function stub #712

Merged
merged 5 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 58 additions & 3 deletions guppylang/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@
RawFunctionDef,
)
from guppylang.definition.parameter import ConstVarDef, TypeVarDef
from guppylang.definition.pytket_circuits import RawPytketDef
from guppylang.definition.pytket_circuits import (
RawLoadPytketDef,
RawPytketDef,
)
from guppylang.definition.struct import RawStructDef
from guppylang.definition.ty import OpaqueTypeDef, TypeDef
from guppylang.error import MissingModuleError, pretty_errors
Expand All @@ -47,11 +50,13 @@
get_calling_frame,
sphinx_running,
)
from guppylang.span import SourceMap
from guppylang.span import Loc, SourceMap, Span
from guppylang.tys.arg import Argument
from guppylang.tys.param import Parameter
from guppylang.tys.subst import Inst
from guppylang.tys.ty import NumericType
from guppylang.tys.ty import (
NumericType,
)

S = TypeVar("S")
T = TypeVar("T")
Expand Down Expand Up @@ -501,6 +506,27 @@ def func(f: PyFunc) -> RawPytketDef:

return func

@pretty_errors
def load_pytket(
self, name: str, input_circuit: Any, module: GuppyModule | None = None
) -> RawLoadPytketDef:
"""Adds a pytket circuit function definition with implicit signature."""
err_msg = "Only pytket circuits can be passed to guppy.load_pytket"
try:
import pytket

if not isinstance(input_circuit, pytket.circuit.Circuit):
raise TypeError(err_msg) from None

except ImportError:
raise TypeError(err_msg) from None

mod = module or self.get_module()
span = _find_load_call(self._sources)
defn = RawLoadPytketDef(DefId.fresh(module), name, None, span, input_circuit)
mod.register_def(defn)
return defn


class _GuppyDummy:
"""A dummy class with the same interface as `@guppy` that is used during sphinx
Expand Down Expand Up @@ -586,3 +612,32 @@ def _parse_expr_string(ty_str: str, parse_err: str, sources: SourceMap) -> ast.e
node.col_offset = 0
node.end_col_offset = len(source_lines[info.lineno - 1]) - 1
return expr_ast


def _find_load_call(sources: SourceMap) -> Span | None:
"""Helper function to find location where pytket circuit was loaded.

Tries to define a source code span by inspecting the call stack.
"""
# Go back as first frame outside of compiler modules is 'pretty_errors_wrapped'.
if (caller_frame := get_calling_frame()) and (load_frame := caller_frame.f_back):
info = inspect.getframeinfo(load_frame)
sources.add_file(info.filename)

filename = load_frame.f_code.co_filename
# Need to check that none of the information is None.
if (
(positions := info.positions)
and (lineno := positions.lineno)
and (col_offset := positions.col_offset)
and (end_lineno := positions.end_lineno)
and (end_col_offset := positions.end_col_offset)
):
start = Loc(filename, lineno, col_offset)
end = Loc(
filename,
end_lineno,
end_col_offset,
)
return Span(start, end)
return None
119 changes: 78 additions & 41 deletions guppylang/definition/pytket_circuits.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@
)
from guppylang.definition.ty import TypeDef
from guppylang.definition.value import CallableDef, CallReturnWires, CompiledCallableDef
from guppylang.error import GuppyError
from guppylang.error import GuppyError, InternalGuppyError
from guppylang.nodes import GlobalCall
from guppylang.span import SourceMap
from guppylang.span import SourceMap, Span, ToSpan
from guppylang.tys.builtin import bool_type
from guppylang.tys.subst import Inst, Subst
from guppylang.tys.ty import (
Expand Down Expand Up @@ -74,65 +74,73 @@ def parse(self, globals: Globals, sources: SourceMap) -> "ParsedPytketDef":
func_ast, globals.with_python_scope(self.python_scope)
)

# Compare signatures.
# TODO: Allow arrays as arguments.
# Retrieve circuit signature and compare.
try:
import pytket

if isinstance(self.input_circuit, pytket.circuit.Circuit):
try:
import tket2 # type: ignore[import-untyped, import-not-found, unused-ignore] # noqa: F401

qubit = cast(TypeDef, globals["qubit"]).check_instantiate(
[], globals
)

circuit_signature = FunctionType(
[FuncInput(qubit, InputFlags.Inout)]
* self.input_circuit.n_qubits,
row_to_type([bool_type()] * self.input_circuit.n_bits),
)

if not (
circuit_signature.inputs == stub_signature.inputs
and circuit_signature.output == stub_signature.output
):
# TODO: Implement pretty-printing for signatures in order to add
# a note for expected vs. actual types.
raise GuppyError(PytketSignatureMismatch(func_ast, self.name))
except ImportError:
err = Tket2NotInstalled(func_ast)
err.add_sub_diagnostic(Tket2NotInstalled.InstallInstruction(None))
raise GuppyError(err) from None
except ImportError:
pass
circuit_signature = _signature_from_circuit(
self.input_circuit, globals, self.defined_at
)
if not (
circuit_signature.inputs == stub_signature.inputs
and circuit_signature.output == stub_signature.output
):
mark-koch marked this conversation as resolved.
Show resolved Hide resolved
# TODO: Implement pretty-printing for signatures in order to add
# a note for expected vs. actual types.
raise GuppyError(PytketSignatureMismatch(func_ast, self.name))

return ParsedPytketDef(
self.id,
self.name,
func_ast,
stub_signature,
self.python_scope,
self.input_circuit,
)


@dataclass(frozen=True)
class RawLoadPytketDef(ParsableDef):
"""A raw definition for loading pytket circuits without explicit function stub.

Args:
id: The unique definition identifier.
name: The name of the circuit function.
defined_at: The AST node of the definition (here always None).
source_span: The source span where the circuit was loaded.
input_circuit: The user-provided pytket circuit.
"""

source_span: Span | None
input_circuit: Any

description: str = field(default="pytket circuit", init=False)

def parse(self, globals: Globals, sources: SourceMap) -> "ParsedPytketDef":
"""Creates a function signature based on the user-provided circuit."""
circuit_signature = _signature_from_circuit(
self.input_circuit, globals, self.source_span
)

return ParsedPytketDef(
self.id,
self.name,
self.defined_at,
circuit_signature,
self.input_circuit,
)


@dataclass(frozen=True)
class ParsedPytketDef(CallableDef, CompilableDef):
"""A circuit definition with parsed and checked signature.
"""A circuit definition with signature.

Args:
id: The unique definition identifier.
name: The name of the function.
defined_at: The AST node where the function was defined.
defined_at: The AST node of the function stub, if there is one.
ty: The type of the function.
python_scope: The Python scope where the function was defined.
input_circuit: The user-provided pytket circuit.
"""

defined_at: ast.FunctionDef
ty: FunctionType
python_scope: PyScope
input_circuit: Any

description: str = field(default="pytket circuit", init=False)
Expand Down Expand Up @@ -181,7 +189,6 @@ def compile_outer(self, module: DefinitionBuilder[OpVar]) -> "CompiledPytketDef"
self.name,
self.defined_at,
self.ty,
self.python_scope,
self.input_circuit,
outer_func,
)
Expand Down Expand Up @@ -214,7 +221,6 @@ class CompiledPytketDef(ParsedPytketDef, CompiledCallableDef):
name: The name of the function.
defined_at: The AST node where the function was defined.
ty: The type of the function.
python_scope: The Python scope where the function was defined.
input_circuit: The user-provided pytket circuit.
func_df: The Hugr function definition.
"""
Expand Down Expand Up @@ -243,3 +249,34 @@ def compile_call(
"""Compiles a call to the function."""
# Use implementation from function definition.
return compile_call(args, type_args, dfg, self.ty, self.func_def)


def _signature_from_circuit(
input_circuit: Any, globals: Globals, defined_at: ToSpan | None
) -> FunctionType:
"""Helper function for inferring a function signature from a pytket circuit."""
try:
import pytket

if isinstance(input_circuit, pytket.circuit.Circuit):
try:
import tket2 # type: ignore[import-untyped, import-not-found, unused-ignore] # noqa: F401

qubit = cast(TypeDef, globals["qubit"]).check_instantiate([], globals)

circuit_signature = FunctionType(
[FuncInput(qubit, InputFlags.Inout)] * input_circuit.n_qubits,
row_to_type([bool_type()] * input_circuit.n_bits),
)
except ImportError:
err = Tket2NotInstalled(defined_at)
err.add_sub_diagnostic(Tket2NotInstalled.InstallInstruction(None))
raise GuppyError(err) from None
else:
pass
except ImportError:
raise InternalGuppyError(
"Pytket error should have been caught earlier"
) from None
else:
return circuit_signature
26 changes: 25 additions & 1 deletion tests/integration/test_pytket_circuits.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@ def foo(q: qubit) -> bool:


@pytest.mark.skipif(not tket2_installed, reason="Tket2 is not installed")
@pytest.mark.skip("Not implemented")
def test_load_circuit(validate):
from pytket import Circuit

Expand All @@ -134,4 +133,29 @@ def test_load_circuit(validate):
def foo(q: qubit) -> None:
guppy_circ(q)

validate(module.compile())


@pytest.mark.skipif(not tket2_installed, reason="Tket2 is not installed")
def test_load_circuits(validate):
from pytket import Circuit

circ1 = Circuit(1)
circ1.H(0)

circ2 = Circuit(2)
circ2.CX(0, 1)
circ2.measure_all()

module = GuppyModule("test")
module.load_all(quantum)

guppy.load_pytket("guppy_circ1", circ1, module)
guppy.load_pytket("guppy_circ2", circ2, module)

@guppy(module)
def foo(q1: qubit, q2: qubit, q3: qubit) -> tuple[bool, bool]:
guppy_circ1(q1)
return guppy_circ2(q2, q3)

validate(module.compile())
Loading