Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Support hugr packages, fix the notebooks #622

Merged
merged 1 commit into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions tket2-py/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ derive_more = { workspace = true }
itertools = { workspace = true }
portmatching = { workspace = true }
strum = { workspace = true }
# Required to acces the `Package` type.
# Remove once https://github.com/CQCL/hugr/issues/1530 is fixed.
hugr-cli = { workspace = true }

[dev-dependencies]
rstest = { workspace = true }
Expand Down
3,204 changes: 2,909 additions & 295 deletions tket2-py/examples/1-Getting-Started.ipynb

Large diffs are not rendered by default.

109 changes: 67 additions & 42 deletions tket2-py/examples/2-Rewriting-Circuits.ipynb

Large diffs are not rendered by default.

32 changes: 22 additions & 10 deletions tket2-py/examples/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,38 @@
"""Some utility functions for the example notebooks."""

from typing import TYPE_CHECKING, Any
from hugr import Hugr
from tket2.passes import lower_to_pytket
from tket2.circuit import Tk2Circuit
from guppylang.definition.function import RawFunctionDef # type: ignore[import-untyped, import-not-found, unused-ignore] # noqa: F401

if TYPE_CHECKING:
try:
from guppylang.definition.function import RawFunctionDef # type: ignore[import-untyped, import-not-found, unused-ignore] # noqa: F401
except ImportError:
RawFunctionDef = Any

def setup_jupyter_rendering():
"""Set up hugr rendering for Jupyter notebooks."""

# We need to define this helper function for now. It will be included in guppy in the future.
def _repr_hugr(
h: Hugr, include=None, exclude=None, **kwargs
) -> dict[str, bytes | str]:
return h.render_dot()._repr_mimebundle_(include, exclude, **kwargs)

def _repr_tk2circ(
circ: Tk2Circuit, include=None, exclude=None, **kwargs
) -> dict[str, bytes | str]:
h = Hugr.load_json(circ.to_hugr_json())
return _repr_hugr(h, include, exclude, **kwargs)

setattr(Hugr, "_repr_mimebundle_", _repr_hugr)
setattr(Tk2Circuit, "_repr_mimebundle_", _repr_tk2circ)


# TODO: Should this be part of the guppy API? Or tket2?
def guppy_to_circuit(func_def: RawFunctionDef) -> Tk2Circuit:
"""Convert a Guppy function definition to a `Tk2Circuit`."""
module = func_def.id.module
assert module is not None, "Function definition must belong to a module"

hugr = module.compile()
assert hugr is not None, "Module must be compilable"
pkg = module.compile()

json = hugr.to_raw().to_json()
json = pkg.to_json()
circ = Tk2Circuit.from_guppy_json(json, func_def.name)

return lower_to_pytket(circ)
18 changes: 16 additions & 2 deletions tket2-py/src/circuit/tk2circuit.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
//! Rust-backed representation of circuits

use std::borrow::{Borrow, Cow};
use std::mem;

use hugr::builder::{CircuitBuilder, DFGBuilder, Dataflow, DataflowHugr};
use hugr::extension::prelude::QB_T;
use hugr::extension::{ExtensionRegistry, EMPTY_REG};
use hugr::ops::handle::NodeHandle;
use hugr::ops::{ExtensionOp, NamedOp, OpType};
use hugr::types::Type;
use hugr_cli::Package;
use itertools::Itertools;
use pyo3::exceptions::{PyAttributeError, PyValueError};
use pyo3::types::{PyAnyMethods, PyModule, PyString, PyTypeMethods};
Expand Down Expand Up @@ -94,9 +97,20 @@ impl Tk2Circuit {
/// Decode a HUGR json string to a circuit.
#[staticmethod]
pub fn from_hugr_json(json: &str) -> PyResult<Self> {
let hugr: Hugr = serde_json::from_str(json)
let pkg: Package = serde_json::from_str(json)
.map_err(|e| PyErr::new::<PyAttributeError, _>(format!("Invalid encoded HUGR: {e}")))?;
Ok(Tk2Circuit { circ: hugr.into() })
let mut reg = REGISTRY.clone();
let mut hugrs = pkg.validate(&mut reg).map_err(|e| {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is nastiness with mismatched extensions here. What happens when the in-package registry and our REGISTRY have different versions of the same extension? Nothing to do for now I think.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, currently we merge the extensions and keep the latest one but I see how that may cause problems.
https://github.com/CQCL/hugr/blob/836b6656adc4f014f1182a39e403184f7f8d45e7/hugr-core/src/extension.rs#L90-L95

PyErr::new::<PyAttributeError, _>(format!("Invalid encoded circuit: {e}"))
})?;
if hugrs.len() != 1 {
return Err(PyValueError::new_err(
"Invalid HUGR json: Package must contain exactly one hugr.",
));
}
Ok(Tk2Circuit {
circ: mem::take(&mut hugrs[0]).into(),
})
}

/// Load a function from a compiled guppy module, encoded as a json string.
Expand Down
39 changes: 33 additions & 6 deletions tket2-py/tket2/circuit/build.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations
from hugr.hugr import Hugr
from typing import Iterable

from hugr import tys, ops
from hugr.ext import Package, Extension
from hugr.ops import ComWire, Command
from hugr.std.float import FLOAT_T
from hugr.build.tracked_dfg import TrackedDfg
Expand All @@ -17,10 +19,35 @@ class CircBuild(TrackedDfg):
def with_nqb(cls, n_qb: int) -> CircBuild:
return cls(*[tys.Qubit] * n_qb, track_inputs=True)

def finish(self) -> Tk2Circuit:
def finish_package(
self, other_extensions: Iterable[Extension] | None = None
) -> Package:
"""Finish building the package by setting all the qubits as the output
and wrap it in a hugr package with the required extensions.

Args:
other_extensions: Other extensions to include in the package.
Returns:
The finished package.
"""
import tket2.extensions as ext

extensions = [
ext.rotation(),
ext.futures(),
ext.hseries(),
ext.quantum(),
ext.result(),
*(other_extensions or []),
]

return Package(modules=[self.hugr], extensions=extensions)

def finish(self, other_extensions: list[Extension] | None = None) -> Tk2Circuit:
"""Finish building the circuit by setting all the qubits as the output
and validate."""
return load_hugr(self.hugr)

return load_hugr_pkg(self.finish_package(other_extensions))


def from_coms(*args: Command) -> Tk2Circuit:
Expand All @@ -40,8 +67,8 @@ def from_coms(*args: Command) -> Tk2Circuit:
return build.finish()


def load_hugr(h: Hugr) -> Tk2Circuit:
return Tk2Circuit.from_hugr_json(h.to_json())
def load_hugr_pkg(package: Package) -> Tk2Circuit:
return Tk2Circuit.from_hugr_json(package.to_json())


def load_custom(serialized: bytes) -> ops.Custom:
Expand All @@ -61,7 +88,7 @@ def id_circ(n_qb: int) -> Tk2Circuit:

@dataclass(frozen=True)
class QuantumOps(ops.Custom):
extension: tys.ExtensionId = "quantum.tket2"
extension: tys.ExtensionId = "tket2.quantum"


_OneQbSig = tys.FunctionType.endo([tys.Qubit])
Expand Down
42 changes: 42 additions & 0 deletions tket2-py/tket2/extensions/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""Hugr extension definitions for tket2 circuits."""
# This will be moved to a separate python library soon.

import pkgutil
import functools

from hugr._serialization.extension import Extension as PdExtension
from hugr.ext import Extension


@functools.cache
def rotation() -> Extension:
return load_extension("tket2.rotation")


@functools.cache
def futures() -> Extension:
return load_extension("tket2.futures")


@functools.cache
def hseries() -> Extension:
return load_extension("tket2.hseries")


@functools.cache
def quantum() -> Extension:
return load_extension("tket2.quantum")


@functools.cache
def result() -> Extension:
return load_extension("tket2.result")


def load_extension(name: str) -> Extension:
replacement = name.replace(".", "/")
json_str = pkgutil.get_data(__name__, f"_json_defs/{replacement}.json")
assert json_str is not None, f"Could not load json for extension {name}"
# TODO: Replace with `Extension.from_json` once that is implemented
# https://github.com/CQCL/hugr/issues/1523
return PdExtension.model_validate_json(json_str).deserialize()
4 changes: 4 additions & 0 deletions tket2/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ bytemuck = { workspace = true }
crossbeam-channel = { workspace = true }
tracing = { workspace = true }
zstd = { workspace = true, optional = true }
# Required to acces the `Package` type.
# Remove once https://github.com/CQCL/hugr/issues/1530 is fixed.
hugr-cli = { workspace = true }


[dev-dependencies]
rstest = { workspace = true }
Expand Down
13 changes: 5 additions & 8 deletions tket2/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@

use crate::serialize::pytket::OpaqueTk1Op;
use crate::Tk2Op;
use hugr::extension::prelude::PRELUDE;
use hugr::extension::simple_op::MakeOpDef;
use hugr::extension::{
CustomSignatureFunc, ExtensionId, ExtensionRegistry, SignatureError, Version,
};
use hugr::hugr::IdentList;
use hugr::std_extensions::arithmetic::float_types::EXTENSION as FLOAT_TYPES;
use hugr::std_extensions::STD_REG;
use hugr::types::type_param::{TypeArg, TypeParam};
use hugr::types::{CustomType, PolyFuncType, PolyFuncTypeRV};
use hugr::Extension;
Expand Down Expand Up @@ -57,15 +56,13 @@ pub static ref TKET1_EXTENSION: Extension = {
res
};

/// Extension registry including the prelude, TKET1 and Tk2Ops extensions.
pub static ref REGISTRY: ExtensionRegistry = ExtensionRegistry::try_new([
/// Extension registry including the prelude, std, TKET1, and Tk2Ops extensions.
pub static ref REGISTRY: ExtensionRegistry = ExtensionRegistry::try_new(
STD_REG.iter().map(|(_, e)| e.to_owned()).chain([
TKET1_EXTENSION.to_owned(),
PRELUDE.to_owned(),
TKET2_EXTENSION.to_owned(),
FLOAT_TYPES.to_owned(),
rotation::ROTATION_EXTENSION.to_owned()
]).unwrap();

])).unwrap();

}

Expand Down
2 changes: 1 addition & 1 deletion tket2/src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ pub(crate) mod test {
#[test]
fn tk2op_properties() {
for op in Tk2Op::iter() {
// The exposed name should start with "quantum.tket2."
// The exposed name should start with "tket2.quantum."
assert!(op.exposed_name().starts_with(&EXTENSION_ID.to_string()));

let ext_op = op.into_extension_op();
Expand Down
19 changes: 16 additions & 3 deletions tket2/src/serialize/guppy.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
//! Load pre-compiled guppy functions.

use std::path::Path;
use std::{fs, io};
use std::{fs, io, mem};

use hugr::ops::{NamedOp, OpTag, OpTrait, OpType};
use hugr::{Hugr, HugrView};
use hugr_cli::Package;
use itertools::Itertools;
use thiserror::Error;

use crate::extension::REGISTRY;
use crate::{Circuit, CircuitError};

/// Loads a pre-compiled guppy file.
Expand All @@ -31,7 +33,12 @@ pub fn load_guppy_json_reader(
reader: impl io::Read,
function: &str,
) -> Result<Circuit, CircuitLoadError> {
let hugr: Hugr = serde_json::from_reader(reader)?;
let pkg: Package = serde_json::from_reader(reader)?;
let mut hugrs = pkg.validate(&mut REGISTRY.clone())?;
if hugrs.len() != 1 {
return Err(CircuitLoadError::InvalidNumHugrs(hugrs.len()));
}
let hugr = mem::take(&mut hugrs[0]);
find_function(hugr, function)
}

Expand All @@ -48,7 +55,7 @@ pub fn load_guppy_json_reader(
/// - If the root of the HUGR is not a module operation.
/// - If the function is not found in the module.
/// - If the function has control flow primitives.
fn find_function(hugr: Hugr, function_name: &str) -> Result<Circuit, CircuitLoadError> {
pub fn find_function(hugr: Hugr, function_name: &str) -> Result<Circuit, CircuitLoadError> {
// Find the root module.
let module = hugr.root();
if !OpTag::ModuleRoot.is_superset(hugr.get_optype(module).tag()) {
Expand Down Expand Up @@ -139,4 +146,10 @@ pub enum CircuitLoadError {
/// Error loading the circuit.
#[error("Error loading the circuit: {0}")]
CircuitLoadError(#[from] CircuitError),
/// Error validating the loaded circuit.
#[error("{0}")]
ValError(#[from] hugr_cli::validate::ValError),
/// The encoded HUGR package must have a single HUGR.
#[error("The encoded HUGR package must have a single HUGR, but it has {0} HUGRs.")]
InvalidNumHugrs(usize),
}
Loading