From e6abfe75670e41bc803d8586191e9d384249bc7e Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Thu, 5 Sep 2024 18:56:31 +0200 Subject: [PATCH 1/2] feat: Add implicit importing of modules --- README.md | 10 ++-- examples/random_walk_qpe.py | 14 ++---- examples/t_factory.py | 14 ++---- guppylang/decorator.py | 47 ++++++++++--------- quickstart.md | 10 ++-- .../misc_errors/implicit_module_error.err | 7 +++ .../misc_errors/implicit_module_error.py | 9 ++++ tests/integration/test_decorator.py | 4 +- tests/integration/test_docstring.py | 2 +- 9 files changed, 64 insertions(+), 53 deletions(-) create mode 100644 tests/error/misc_errors/implicit_module_error.err create mode 100644 tests/error/misc_errors/implicit_module_error.py diff --git a/README.md b/README.md index fae568ec..04ca059a 100644 --- a/README.md +++ b/README.md @@ -12,14 +12,12 @@ Guppy is a quantum programming language that is fully embedded into Python. It allows you to write high-level hybrid quantum programs with classical control flow and mid-circuit measurements using Pythonic syntax: ```python -from guppylang import guppy, qubit, quantum +from guppylang import guppy +from guppylang.prelude.quantum import cx, h, measure, qubit, x, z -guppy.load_all(quantum) - - -# Teleports the state in `src` to `tgt`. @guppy def teleport(src: qubit, tgt: qubit) -> qubit: + """Teleports the state in `src` to `tgt`.""" # Create ancilla and entangle it with src and tgt tmp = qubit() tmp, tgt = cx(h(tmp), tgt) @@ -31,6 +29,8 @@ def teleport(src: qubit, tgt: qubit) -> qubit: if measure(tmp): tgt = x(tgt) return tgt + +guppy.compile_module() ``` More examples and tutorials are available [here][examples]. diff --git a/examples/random_walk_qpe.py b/examples/random_walk_qpe.py index 2862875b..5b0900b0 100644 --- a/examples/random_walk_qpe.py +++ b/examples/random_walk_qpe.py @@ -8,20 +8,16 @@ import math from collections.abc import Callable -import guppylang.prelude.quantum as quantum from guppylang.decorator import guppy -from guppylang.module import GuppyModule from guppylang.prelude.builtins import py, result from guppylang.prelude.quantum import cx, discard, h, measure, qubit, rz, x -module = GuppyModule("test") -module.load_all(quantum) sqrt_e = math.sqrt(math.e) sqrt_e_div = math.sqrt((math.e - 1) / math.e) -@guppy(module) +@guppy def random_walk_phase_estimation( eigenstate: Callable[[], qubit], controlled_oracle: Callable[[qubit, qubit, float], tuple[qubit, qubit]], @@ -64,7 +60,7 @@ def random_walk_phase_estimation( return mu -@guppy(module) +@guppy def example_controlled_oracle(q1: qubit, q2: qubit, t: float) -> tuple[qubit, qubit]: """A controlled e^itH gate for the example Hamiltonian H = -0.5 * Z""" # This is just a controlled rz gate @@ -75,14 +71,14 @@ def example_controlled_oracle(q1: qubit, q2: qubit, t: float) -> tuple[qubit, qu return cx(q1, q2) -@guppy(module) +@guppy def example_eigenstate() -> qubit: """The eigenstate of e^itH for the example Hamiltonian H = -0.5 * Z""" # This is just |1> return x(qubit()) -@guppy(module) +@guppy def main() -> int: num_iters = 24 # To avoid underflows reset_rate = 8 @@ -100,4 +96,4 @@ def main() -> int: return 0 -hugr = module.compile() +hugr = guppy.compile_module() diff --git a/examples/t_factory.py b/examples/t_factory.py index 1fa9c0ee..99453a29 100644 --- a/examples/t_factory.py +++ b/examples/t_factory.py @@ -1,27 +1,23 @@ import numpy as np from guppylang.decorator import guppy -from guppylang.module import GuppyModule from guppylang.prelude.builtins import linst, py from guppylang.prelude.quantum import ( cz, discard, h, measure, - quantum, qubit, rx, rz, ) -module = GuppyModule("t_factory") -module.load_all(quantum) phi = np.arccos(1 / 3) pi = np.pi -@guppy(module) +@guppy def ry(q: qubit, theta: float) -> qubit: q = rx(q, py(pi / 2)) q = rz(q, theta + py(pi)) @@ -30,7 +26,7 @@ def ry(q: qubit, theta: float) -> qubit: # Preparation of approximate T state, from https://arxiv.org/abs/2310.12106 -@guppy(module) +@guppy def prepare_approx(q: qubit) -> qubit: phi_ = py(phi) pi_ = py(pi) @@ -40,7 +36,7 @@ def prepare_approx(q: qubit) -> qubit: # The inverse of the [[5,3,1]] encoder in figure 3 of https://arxiv.org/abs/2208.01863 -@guppy(module) +@guppy def distill( target: qubit, q0: qubit, q1: qubit, q2: qubit, q3: qubit ) -> tuple[qubit, bool]: @@ -63,7 +59,7 @@ def distill( return target, success -@guppy(module) +@guppy def t_state(timeout: int) -> tuple[linst[qubit], bool]: """Create a T state using magic state distillation with `timeout` attempts. @@ -93,4 +89,4 @@ def t_state(timeout: int) -> tuple[linst[qubit], bool]: return [], False -hugr = module.compile() +hugr = guppy.compile_module() diff --git a/guppylang/decorator.py b/guppylang/decorator.py index e1811e2f..855bdbd7 100644 --- a/guppylang/decorator.py +++ b/guppylang/decorator.py @@ -9,8 +9,9 @@ from hugr import Hugr, ops from hugr import tys as ht +import guppylang from guppylang.ast_util import annotate_location, has_empty_body -from guppylang.definition.common import DefId +from guppylang.definition.common import DefId, Definition from guppylang.definition.custom import ( CustomCallChecker, CustomCallCompiler, @@ -49,13 +50,14 @@ class ModuleIdentifier: #: module, we only take the module path into account. name: str = field(compare=False) + #: A reference to the python module + module: ModuleType | None = field(compare=False) + class _Guppy: """Class for the `@guppy` decorator.""" - # The currently-alive GuppyModules, associated with a Python file/module. - # - # Only contains **uncompiled** modules. + # The currently-alive GuppyModules, associated with a Python file/module _modules: dict[ModuleIdentifier, GuppyModule] def __init__(self) -> None: @@ -72,11 +74,7 @@ def __call__(self, arg: PyFunc | GuppyModule) -> FuncDefDecorator | RawFunctionD # Decorator used without any arguments. # We default to a module associated with the caller of the decorator. f = arg - - caller = self._get_python_caller(f) - if caller not in self._modules: - self._modules[caller] = GuppyModule(caller.name) - module = self._modules[caller] + module = self.get_module() return module.register_func_def(f) if isinstance(arg, GuppyModule): @@ -101,12 +99,14 @@ def _get_python_caller(self, fn: PyFunc | None = None) -> ModuleIdentifier: if s.filename != __file__: filename = s.filename module = inspect.getmodule(s.frame) - break + # Skip frames from the `pretty_error` decorator + if module != guppylang.error: + break else: raise GuppyError("Could not find a caller for the `@guppy` decorator") module_path = Path(filename) return ModuleIdentifier( - module_path, module.__name__ if module else module_path.name + module_path, module.__name__ if module else module_path.name, module ) @pretty_errors @@ -294,23 +294,26 @@ def load(self, m: ModuleType | GuppyModule) -> None: module = self._modules[caller] module.load_all(m) - def take_module(self, id: ModuleIdentifier | None = None) -> GuppyModule: - """Returns the local GuppyModule, removing it from the local state.""" - orig_id = id + def get_module(self, id: ModuleIdentifier | None = None) -> GuppyModule: + """Returns the local GuppyModule.""" if id is None: id = self._get_python_caller() if id not in self._modules: - err = ( - f"Module {orig_id.name} not found." - if orig_id - else "No Guppy functions or types defined in this module." - ) - raise MissingModuleError(err) - return self._modules.pop(id) + self._modules[id] = GuppyModule(id.name) + module = self._modules[id] + # Update implicit imports + if id.module: + defs = { + x: v + for x, v in id.module.__dict__.items() + if isinstance(v, Definition) and v.id.module != module + } + module.load(**defs) + return module def compile_module(self, id: ModuleIdentifier | None = None) -> Hugr[ops.Module]: """Compiles the local module into a Hugr.""" - module = self.take_module(id) + module = self.get_module(id) if not module: err = ( f"Module {id.name} not found." diff --git a/quickstart.md b/quickstart.md index 40e398ab..3cdca383 100644 --- a/quickstart.md +++ b/quickstart.md @@ -3,14 +3,12 @@ allows you to write high-level hybrid quantum programs with classical control flow and mid-circuit measurements using Pythonic syntax: ```python -from guppylang import guppy, qubit, quantum +from guppylang import guppy +from guppylang.prelude.quantum import cx, h, measure, qubit, x, z -guppy.load_all(quantum) - - -# Teleports the state in `src` to `tgt`. @guppy def teleport(src: qubit, tgt: qubit) -> qubit: + """Teleports the state in `src` to `tgt`.""" # Create ancilla and entangle it with src and tgt tmp = qubit() tmp, tgt = cx(h(tmp), tgt) @@ -22,4 +20,6 @@ def teleport(src: qubit, tgt: qubit) -> qubit: if measure(tmp): tgt = x(tgt) return tgt + +guppy.compile_module() ``` diff --git a/tests/error/misc_errors/implicit_module_error.err b/tests/error/misc_errors/implicit_module_error.err new file mode 100644 index 00000000..c2afc3f8 --- /dev/null +++ b/tests/error/misc_errors/implicit_module_error.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:6 + +4: @guppy +5: def foo() -> int: +6: return 1.0 + ^^^ +GuppyTypeError: Expected return value of type `int`, got `float` diff --git a/tests/error/misc_errors/implicit_module_error.py b/tests/error/misc_errors/implicit_module_error.py new file mode 100644 index 00000000..1821b448 --- /dev/null +++ b/tests/error/misc_errors/implicit_module_error.py @@ -0,0 +1,9 @@ +from guppylang import guppy + + +@guppy +def foo() -> int: + return 1.0 + + +guppy.compile_module() diff --git a/tests/integration/test_decorator.py b/tests/integration/test_decorator.py index 09d9a073..80dc48b0 100644 --- a/tests/integration/test_decorator.py +++ b/tests/integration/test_decorator.py @@ -17,7 +17,7 @@ def b() -> None: def c() -> None: pass - default_module = guppy.take_module() + default_module = guppy.get_module() assert not module.contains("a") assert module.contains("b") @@ -34,7 +34,7 @@ def make_module() -> GuppyModule: def a() -> None: pass - return guppy.take_module() + return guppy.get_module() module_a = make_module() module_b = make_module() diff --git a/tests/integration/test_docstring.py b/tests/integration/test_docstring.py index 80227f30..c9312177 100644 --- a/tests/integration/test_docstring.py +++ b/tests/integration/test_docstring.py @@ -32,7 +32,7 @@ def g_nested() -> None: string. """ - default_module = guppy.take_module() + default_module = guppy.get_module() validate(default_module.compile()) From 52b7b51bd84fd32a70d7ac5e7e48aba736b2f430 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Thu, 5 Sep 2024 19:51:28 +0200 Subject: [PATCH 2/2] Allow imports from implicit modules --- guppylang/decorator.py | 26 ++++++++++++++--------- guppylang/module.py | 9 ++++++++ tests/integration/modules/implicit_mod.py | 6 ++++++ tests/integration/test_imports.py | 26 +++++++++++++++++++++++ 4 files changed, 57 insertions(+), 10 deletions(-) create mode 100644 tests/integration/modules/implicit_mod.py diff --git a/guppylang/decorator.py b/guppylang/decorator.py index 855bdbd7..35d7ef35 100644 --- a/guppylang/decorator.py +++ b/guppylang/decorator.py @@ -1,6 +1,6 @@ import ast import inspect -from collections.abc import Callable +from collections.abc import Callable, KeysView from dataclasses import dataclass, field from pathlib import Path from types import ModuleType @@ -27,7 +27,7 @@ from guppylang.definition.struct import RawStructDef from guppylang.definition.ty import OpaqueTypeDef, TypeDef from guppylang.error import GuppyError, MissingModuleError, pretty_errors -from guppylang.module import GuppyModule, PyFunc +from guppylang.module import GuppyModule, PyFunc, find_guppy_module_in_py_module from guppylang.tys.subst import Inst from guppylang.tys.ty import NumericType @@ -299,15 +299,21 @@ def get_module(self, id: ModuleIdentifier | None = None) -> GuppyModule: if id is None: id = self._get_python_caller() if id not in self._modules: - self._modules[id] = GuppyModule(id.name) + self._modules[id] = GuppyModule(id.name.split(".")[-1]) module = self._modules[id] # Update implicit imports if id.module: - defs = { - x: v - for x, v in id.module.__dict__.items() - if isinstance(v, Definition) and v.id.module != module - } + defs: dict[str, Definition | ModuleType] = {} + for x, value in id.module.__dict__.items(): + if isinstance(value, Definition) and value.id.module != module: + defs[x] = value + elif isinstance(value, ModuleType): + try: + other_module = find_guppy_module_in_py_module(value) + if other_module and other_module != module: + defs[x] = value + except GuppyError: + pass module.load(**defs) return module @@ -323,9 +329,9 @@ def compile_module(self, id: ModuleIdentifier | None = None) -> Hugr[ops.Module] raise MissingModuleError(err) return module.compile() - def registered_modules(self) -> list[ModuleIdentifier]: + def registered_modules(self) -> KeysView[ModuleIdentifier]: """Returns a list of all currently registered modules for local contexts.""" - return list(self._modules.keys()) + return self._modules.keys() guppy = _Guppy() diff --git a/guppylang/module.py b/guppylang/module.py index 841a535f..f6942701 100644 --- a/guppylang/module.py +++ b/guppylang/module.py @@ -1,6 +1,7 @@ import inspect import sys from collections.abc import Callable, Mapping +from pathlib import Path from types import ModuleType from typing import Any @@ -351,6 +352,14 @@ def find_guppy_module_in_py_module(module: ModuleType) -> GuppyModule: Raises a user-error if no unique module can be found. """ mods = [val for val in module.__dict__.values() if isinstance(val, GuppyModule)] + # Also include implicit modules + from guppylang.decorator import ModuleIdentifier, guppy + + if hasattr(module, "__file__") and module.__file__: + module_id = ModuleIdentifier(Path(module.__file__), module.__name__, module) + if module_id in guppy.registered_modules(): + mods.append(guppy.get_module(module_id)) + if not mods: msg = f"No Guppy modules found in `{module.__name__}`" raise GuppyError(msg) diff --git a/tests/integration/modules/implicit_mod.py b/tests/integration/modules/implicit_mod.py new file mode 100644 index 00000000..232f8a3b --- /dev/null +++ b/tests/integration/modules/implicit_mod.py @@ -0,0 +1,6 @@ +from guppylang import guppy + + +@guppy +def foo(x: int) -> int: + return x + 1 diff --git a/tests/integration/test_imports.py b/tests/integration/test_imports.py index 38020603..49443740 100644 --- a/tests/integration/test_imports.py +++ b/tests/integration/test_imports.py @@ -28,6 +28,19 @@ def test(x: MyType) -> MyType: validate(module.compile()) +def test_import_implicit(validate): + from tests.integration.modules.implicit_mod import foo + + module = GuppyModule("test") + module.load(foo) + + @guppy(module) + def test(x: int) -> int: + return foo(x) + + validate(module.compile()) + + def test_func_alias(validate): from tests.integration.modules.mod_a import f as g @@ -151,3 +164,16 @@ def test(x: mod_a.MyType, y: mod_b.MyType) -> tuple[mod_a.MyType, mod_b.MyType]: return -x, +y validate(module.compile()) + + +def test_qualified_implicit(validate): + import tests.integration.modules.implicit_mod as implicit_mod + + module = GuppyModule("test") + module.load(implicit_mod) + + @guppy(module) + def test(x: int) -> int: + return implicit_mod.foo(x) + + validate(module.compile())