Skip to content

Commit

Permalink
feat!: Track circuit extensions and read/write packages (#680)
Browse files Browse the repository at this point in the history
Removes the guppy-specific and adds supports for loading functions
packages and standalone hugrs.

Temporarily keeps track of the required extensions for the hugr in an
optional `Circuit::required_extensions` field until
CQCL/hugr#1613 gets implemented. Fallbacks to
a default set when loading bare hugrs.

Note that storing a circuit with a non-root parent is currently an
error. We'll need to store some pointer to the entrypoint on the hugr's
metadata, and that'll require some serialization-stable path encoding.
I'll open an issue for that .

blocked-by: CQCL/hugr#1621. I'll remove the
patch in cargo.toml once that gets released.

drive-by: Use `circuit_hash` for the `PartialEq` implementation of
circuits. The derived equality failed on graphs with different node
indices.

BREAKING CHANGE: Removed `load_guppy_*` methods. Use
`Circuit::load_function_reader` instead.
  • Loading branch information
aborgna-q authored Nov 7, 2024
1 parent 5cd934b commit 5e87dd9
Show file tree
Hide file tree
Showing 10 changed files with 975 additions and 311 deletions.
442 changes: 338 additions & 104 deletions Cargo.lock

Large diffs are not rendered by default.

13 changes: 10 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,19 @@ license = "Apache-2.0"
unexpected_cfgs = { level = "warn", check-cfg = ['cfg(ci_run)'] }
missing_docs = "warn"

[patch.crates-io]

# Uncomment to use unreleased versions of hugr
#hugr-core = { git = "https://github.com/CQCL/hugr.git" }
#hugr = { git = "https://github.com/CQCL/hugr.git" }
#hugr-cli = { git = "https://github.com/CQCL/hugr.git" }

[workspace.dependencies]

# Make sure to run `just recompile-eccs` if the hugr serialisation format changes.
hugr = "0.13.2"
hugr-core = "0.13.2"
hugr-cli = "0.13.2"
hugr = "0.13.3"
hugr-core = "0.13.3"
hugr-cli = "0.13.3"
portgraph = "0.12"
pyo3 = "0.22.5"
itertools = "0.13.0"
Expand Down
2 changes: 1 addition & 1 deletion tket2-py/examples/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,6 @@ def guppy_to_circuit(func_def: RawFunctionDef) -> Tk2Circuit:
pkg = module.compile()

json = pkg.to_json()
circ = Tk2Circuit.from_guppy_json(json, func_def.name)
circ = Tk2Circuit.from_package_json(json, func_def.name)

return lower_to_pytket(circ)
59 changes: 41 additions & 18 deletions tket2-py/src/circuit/tk2circuit.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//! Rust-backed representation of circuits
use std::borrow::{Borrow, Cow};
use std::fmt::Display;
use std::mem;

use hugr::builder::{CircuitBuilder, DFGBuilder, Dataflow, DataflowHugr};
Expand Down Expand Up @@ -91,32 +92,54 @@ impl Tk2Circuit {
//
// TODO: Bind a messagepack encoder/decoder too.
pub fn to_hugr_json(&self) -> PyResult<String> {
Ok(serde_json::to_string(self.circ.hugr()).unwrap())
fn err(e: impl Display) -> PyErr {
PyErr::new::<PyAttributeError, _>(format!("Could not encode circuit: {e}"))
};
let mut buf = Vec::new();
self.circ.to_hugr_writer(&mut buf).map_err(err)?;
let res = std::str::from_utf8(&buf).map_err(err)?;
Ok(res.to_string())
}

/// Decode a HUGR json string to a circuit.
/// Encode the circuit as a Hugr Package json string.
//
// TODO: Bind a messagepack encoder/decoder too.
pub fn to_package_json(&self) -> PyResult<String> {
fn err(e: impl Display) -> PyErr {
PyErr::new::<PyAttributeError, _>(format!("Could not encode circuit: {e}"))
};
let mut buf = Vec::new();
self.circ.to_package_writer(&mut buf).map_err(err)?;
let res = std::str::from_utf8(&buf).map_err(err)?;
Ok(res.to_string())
}

/// Decode a HUGR json to a circuit.
#[staticmethod]
pub fn from_hugr_json(json: &str) -> PyResult<Self> {
let mut pkg: Package = serde_json::from_str(json)
.map_err(|e| PyErr::new::<PyAttributeError, _>(format!("Invalid encoded HUGR: {e}")))?;
let mut reg = REGISTRY.clone();
pkg.update_validate(&mut reg).map_err(|e| {
PyErr::new::<PyAttributeError, _>(format!("Invalid encoded circuit: {e}"))
})?;
let Ok(hugr) = pkg.modules.into_iter().exactly_one() else {
return Err(PyValueError::new_err(
"Invalid HUGR json: Package must contain exactly one hugr.",
));
fn err(e: impl Display) -> PyErr {
PyErr::new::<PyAttributeError, _>(format!("Could not read hugr: {e}"))
};
Ok(Tk2Circuit { circ: hugr.into() })
let circ = Circuit::load_hugr_reader(json.as_bytes()).map_err(err)?;
Ok(Tk2Circuit { circ })
}

/// Load a function from a compiled guppy module, encoded as a json string.
/// Decode a HUGR Package json to a circuit.
///
/// Traverses the package's modules in order until it finds one containing a
/// function named `function_name`, and loads it as a circuit.
///
/// If the json is a hugr json, it will be decoded as a `main` function in an empty module.
///
/// When `function_name` is not given, it defaults to `main`.
#[staticmethod]
pub fn from_guppy_json(json: &str, function: &str) -> PyResult<Self> {
let circ = tket2::serialize::load_guppy_json_str(json, function).map_err(|e| {
PyErr::new::<PyAttributeError, _>(format!("Invalid encoded circuit: {e}"))
})?;
#[pyo3(signature = (json, function_name = None))]
pub fn from_package_json(json: &str, function_name: Option<String>) -> PyResult<Self> {
fn err(e: impl Display) -> PyErr {
PyErr::new::<PyAttributeError, _>(format!("Could not read package: {e}"))
};
let name = function_name.unwrap_or_else(|| "main".to_string());
let circ = Circuit::load_function_reader(json.as_bytes(), &name).map_err(err)?;
Ok(Tk2Circuit { circ })
}

Expand Down
25 changes: 19 additions & 6 deletions tket2-py/tket2/_tket2/circuit.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -54,18 +54,31 @@ class Tk2Circuit:
"""The output node of the circuit."""

def to_hugr_json(self) -> str:
"""Encode the circuit as a HUGR json string."""
"""Encode the circuit as a HUGR json."""

def to_package_json(self) -> str:
"""Encode the circuit as a HUGR Package json."""

@staticmethod
def from_hugr_json(json: str) -> Tk2Circuit:
"""Decode a HUGR json string to a Tk2Circuit."""

def to_tket1_json(self) -> str:
"""Encode the circuit as a pytket json string."""

@staticmethod
def from_guppy_json(json: str, function: str) -> Tk2Circuit:
"""Load a function from a compiled guppy module, encoded as a json string."""
def from_package_json(json: str, function_name: str | None = None) -> Tk2Circuit:
"""Decode a HUGR Package json to a circuit.
Traverses the package's modules in order until it finds one containing a
function named `function_name`, and loads it as a circuit.
If the json is a hugr json, it will be decoded as a `main` function in an empty module.
When `function_name` is not given, it defaults to `main`.
"""

def to_tket1_json(
self,
) -> str:
"""Encode the circuit as a pytket json string."""

@staticmethod
def from_tket1_json(json: str) -> Tk2Circuit:
Expand Down
41 changes: 33 additions & 8 deletions tket2-py/tket2/circuit/build.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from __future__ import annotations
from typing import Iterable

from hugr import tys, ops
from hugr import Hugr, tys, ops
from hugr.package import Package
from hugr.ext import Extension
from hugr.ops import ComWire, Command
from hugr.std.float import FLOAT_T
from hugr.build.function import Module
from hugr.build.tracked_dfg import TrackedDfg
from tket2.circuit import Tk2Circuit

Expand All @@ -20,17 +21,32 @@ class CircBuild(TrackedDfg):
def with_nqb(cls, n_qb: int) -> CircBuild:
return cls(*[tys.Qubit] * n_qb, track_inputs=True)

def finish_hugr(self) -> Hugr:
"""Finish building the hugr by setting all the qubits as the output.
Returns:
The finished Hugr.
"""
return self.hugr

def finish_package(
self, other_extensions: Iterable[Extension] | None = None
self,
*,
other_extensions: Iterable[Extension] | None = None,
function_name="main",
) -> Package:
"""Finish building the package by setting all the qubits as the output
and wrap it in a hugr package with the required extensions.
Args:
other_extensions: Other extensions to include in the package.
function_name: The name of the function containing the circuit in
the package's module. Defaults to "main".
Returns:
The finished package.
"""
# TODO: Replace with `finish_hugr` once extensions are included in the hugr itself.
# See https://github.com/CQCL/hugr/pull/1621
import tket2.extensions as ext

extensions = [
Expand All @@ -42,13 +58,26 @@ def finish_package(
*(other_extensions or []),
]

return Package(modules=[self.hugr], extensions=extensions)
# Convert the DFG into a Function definition
dfg_op = self.hugr[self.hugr.root].op
assert type(dfg_op) is ops.DFG, "CircBuild must have a Dfg root"
self.hugr[self.hugr.root].op = ops.FuncDefn(
function_name, inputs=dfg_op.inputs, _outputs=dfg_op.outputs
)

# Insert it into a module, as required by the package.
module = Module()
module.hugr.insert_hugr(self.hugr)

return Package(modules=[module.hugr], extensions=extensions)

def finish(self, other_extensions: list[Extension] | None = None) -> Tk2Circuit:
"""Finish building the circuit by setting all the qubits as the output
and validate."""

return load_hugr_pkg(self.finish_package(other_extensions))
return Tk2Circuit.from_package_json(
self.finish_package(other_extensions=other_extensions).to_json()
)


def from_coms(*args: Command) -> Tk2Circuit:
Expand All @@ -68,10 +97,6 @@ def from_coms(*args: Command) -> Tk2Circuit:
return build.finish()


def load_hugr_pkg(package: Package) -> Tk2Circuit:
return Tk2Circuit.from_hugr_json(package.to_json())


def load_custom(serialized: bytes) -> ops.Custom:
import hugr._serialization.ops as sops
import json
Expand Down
80 changes: 76 additions & 4 deletions tket2/src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ pub mod units;

use std::collections::HashSet;
use std::iter::Sum;
use std::mem;
use std::sync::Arc;

pub use command::{Command, CommandIterator};
pub use hash::CircuitHash;
Expand All @@ -20,7 +22,7 @@ use hugr::hugr::hugrmut::HugrMut;
use hugr::ops::dataflow::IOTrait;
use hugr::ops::{Input, NamedOp, OpName, OpParent, OpTag, OpTrait, Output};
use hugr::types::{PolyFuncType, Signature};
use hugr::{Hugr, PortIndex};
use hugr::{Extension, Hugr, PortIndex};
use hugr::{HugrView, OutgoingPort};
use itertools::Itertools;
use lazy_static::lazy_static;
Expand All @@ -29,24 +31,46 @@ pub use hugr::ops::OpType;
pub use hugr::types::{EdgeKind, Type, TypeRow};
pub use hugr::{Node, Port, Wire};

use crate::extension;

use self::units::{filter, LinearUnit, Units};

/// A quantum circuit, represented as a function in a HUGR.
#[derive(Debug, Clone, PartialEq)]
#[derive(Debug, Clone)]
pub struct Circuit<T = Hugr> {
/// The HUGR containing the circuit.
hugr: T,
/// The parent node of the circuit.
///
/// This is checked at runtime to ensure that the node is a DFG node.
parent: Node,
/// An optional set of extensions required to validate the circuit,
/// not including the prelude.
///
/// Wrapped in an Arc to allow sharing between circuits, specially for borrowed circuits.
///
/// Defaults to an standard set of quantum extensions and Hugr's std set.
required_extensions: Option<Arc<Vec<Extension>>>,
}

impl<T: Default + HugrView> Default for Circuit<T> {
fn default() -> Self {
let hugr = T::default();
let parent = hugr.root();
Self { hugr, parent }
Self {
hugr,
parent,
required_extensions: None,
}
}
}

impl<T: HugrView> PartialEq for Circuit<T> {
fn eq(&self, other: &Self) -> bool {
match (self.circuit_hash(), other.circuit_hash()) {
(Ok(hash1), Ok(hash2)) => hash1 == hash2,
_ => false,
}
}
}

Expand All @@ -64,11 +88,22 @@ lazy_static! {
set.insert(format!("prelude.{}", LiftDef.name()).into());
set
};

/// A default set of required extensions for a circuit,
/// used when loading with hugr with no pre-defined extension set.
///
/// We should be able to drop this once hugrs embed their required extensions.
/// See https://github.com/CQCL/hugr/issues/1613
static ref DEFAULT_REQUIRED_EXTENSIONS: Vec<Extension> = extension::REGISTRY.iter().map(|(_, ext)| ext.clone()).collect();
}
/// The [IGNORED_EXTENSION_OPS] definition depends on the buggy behaviour of [`NamedOp::name`], which returns bare names instead of scoped names on some cases.
/// Once this test starts failing it should be time to drop the `format!("prelude.{}", ...)`.
/// https://github.com/CQCL/hugr/issues/1496
#[test]
fn issue_1496_remains() {
assert_eq!("Noop", NoopDef.name())
}

impl<T: HugrView> Circuit<T> {
/// Create a new circuit from a HUGR and a node.
///
Expand All @@ -77,7 +112,11 @@ impl<T: HugrView> Circuit<T> {
/// Returns an error if the parent node is not a DFG node in the HUGR.
pub fn try_new(hugr: T, parent: Node) -> Result<Self, CircuitError> {
check_hugr(&hugr, parent)?;
Ok(Self { hugr, parent })
Ok(Self {
hugr,
parent,
required_extensions: None,
})
}

/// Create a new circuit from a HUGR and a node.
Expand Down Expand Up @@ -114,12 +153,40 @@ impl<T: HugrView> Circuit<T> {
&mut self.hugr
}

/// Get the required extensions for the circuit.
///
/// If no extension set was defined, returns the default set of quantum extensions and Hugr's std set.
///
/// Note: This API is not currently public. We expect hugrs to embed their required extensions in the future,
/// at which point this method will be removed.
/// See https://github.com/CQCL/hugr/issues/1613
pub(crate) fn required_extensions(&self) -> &[Extension] {
self.required_extensions
.as_deref()
.unwrap_or_else(|| &DEFAULT_REQUIRED_EXTENSIONS)
}

/// Set the required extension set for the circuit.
///
/// Returns the previous set of required extensions, if any.
///
/// Note: This API is not currently public. We expect hugrs to embed their required extensions in the future,
/// at which point this method will be removed.
/// See https://github.com/CQCL/hugr/issues/1613
pub(crate) fn set_required_extensions(
&mut self,
extensions: Arc<Vec<Extension>>,
) -> Option<Arc<Vec<Extension>>> {
mem::replace(&mut self.required_extensions, Some(extensions))
}

/// Ensures the circuit contains an owned HUGR.
pub fn to_owned(&self) -> Circuit<Hugr> {
let hugr = self.hugr.base_hugr().clone();
Circuit {
hugr,
parent: self.parent,
required_extensions: self.required_extensions.clone(),
}
}

Expand Down Expand Up @@ -732,9 +799,14 @@ mod tests {
})
.unwrap();

let orig_circ = circ.clone();
assert_eq!(circ, orig_circ);

assert_eq!(circ.qubit_count(), 2);
assert!(remove_empty_wire(&mut circ, 1).is_ok());
assert_eq!(circ.qubit_count(), 1);
assert_ne!(circ, orig_circ);

assert_eq!(
remove_empty_wire(&mut circ, 0).unwrap_err(),
CircuitMutError::DeleteNonEmptyWire {
Expand Down
Loading

0 comments on commit 5e87dd9

Please sign in to comment.