Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(py): Allow using Tk2Ops in the builder #436

Merged
merged 6 commits into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 7 additions & 124 deletions tket2-py/src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ use crate::utils::ConvertPyErr;

pub use self::convert::{try_update_hugr, try_with_hugr, update_hugr, with_hugr, CircuitType};
pub use self::cost::PyCircuitCost;
use self::tk2circuit::Dfg;
pub use self::tk2circuit::Tk2Circuit;
use self::tk2circuit::{into_vec, Dfg};
pub use tket2::{Pauli, Tk2Op};

/// The module definition
Expand All @@ -38,13 +38,10 @@ pub fn module(py: Python<'_>) -> PyResult<Bound<'_, PyModule>> {
m.add_class::<PyWire>()?;
m.add_class::<WireIter>()?;
m.add_class::<PyCircuitCost>()?;
m.add_class::<PyCustom>()?;
m.add_class::<PyHugrType>()?;
m.add_class::<PyTypeBound>()?;

m.add_function(wrap_pyfunction!(validate_hugr, &m)?)?;
m.add_function(wrap_pyfunction!(to_hugr_dot, &m)?)?;
m.add_function(wrap_pyfunction!(to_hugr_mermaid, &m)?)?;
m.add_function(wrap_pyfunction!(validate_circuit, &m)?)?;
m.add_function(wrap_pyfunction!(render_circuit_dot, &m)?)?;
m.add_function(wrap_pyfunction!(render_circuit_mermaid, &m)?)?;

m.add("HugrError", py.get_type_bound::<PyHugrError>())?;
m.add("BuildError", py.get_type_bound::<PyBuildError>())?;
Expand Down Expand Up @@ -90,19 +87,19 @@ create_py_exception!(

/// Run the validation checks on a circuit.
#[pyfunction]
pub fn validate_hugr(c: &Bound<PyAny>) -> PyResult<()> {
pub fn validate_circuit(c: &Bound<PyAny>) -> PyResult<()> {
try_with_hugr(c, |hugr, _| hugr.validate(&REGISTRY))
}

/// Return a Graphviz DOT string representation of the circuit.
#[pyfunction]
pub fn to_hugr_dot(c: &Bound<PyAny>) -> PyResult<String> {
pub fn render_circuit_dot(c: &Bound<PyAny>) -> PyResult<String> {
with_hugr(c, |hugr, _| hugr.dot_string())
}

/// Return a Mermaid diagram representation of the circuit.
#[pyfunction]
pub fn to_hugr_mermaid(c: &Bound<PyAny>) -> PyResult<String> {
pub fn render_circuit_mermaid(c: &Bound<PyAny>) -> PyResult<String> {
with_hugr(c, |hugr, _| hugr.mermaid_string())
}

Expand Down Expand Up @@ -210,117 +207,3 @@ impl PyWire {
self.wire.source().index()
}
}

#[pyclass]
#[pyo3(name = "CustomOp")]
#[repr(transparent)]
#[derive(From, Into, PartialEq, Clone)]
struct PyCustom(CustomOp);

impl fmt::Debug for PyCustom {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.fmt(f)
}
}

impl From<PyCustom> for OpType {
fn from(op: PyCustom) -> Self {
op.0.into()
}
}

#[pymethods]
impl PyCustom {
#[new]
fn new(
extension: &str,
op_name: &str,
input_types: Vec<PyHugrType>,
output_types: Vec<PyHugrType>,
) -> PyResult<Self> {
Ok(CustomOp::new_opaque(OpaqueOp::new(
IdentList::new(extension).unwrap(),
op_name,
Default::default(),
[],
FunctionType::new(into_vec(input_types), into_vec(output_types)),
))
.into())
}

fn to_custom(&self) -> Self {
self.clone()
}
pub fn __repr__(&self) -> String {
format!("{:?}", self)
}

fn name(&self) -> String {
self.0.name().to_string()
}
}
#[pyclass]
#[pyo3(name = "TypeBound")]
#[derive(PartialEq, Clone, Debug)]
enum PyTypeBound {
Any,
Copyable,
Eq,
}

impl From<PyTypeBound> for TypeBound {
fn from(bound: PyTypeBound) -> Self {
match bound {
PyTypeBound::Any => TypeBound::Any,
PyTypeBound::Copyable => TypeBound::Copyable,
PyTypeBound::Eq => TypeBound::Eq,
}
}
}

impl From<TypeBound> for PyTypeBound {
fn from(bound: TypeBound) -> Self {
match bound {
TypeBound::Any => PyTypeBound::Any,
TypeBound::Copyable => PyTypeBound::Copyable,
TypeBound::Eq => PyTypeBound::Eq,
}
}
}

#[pyclass]
#[pyo3(name = "HugrType")]
#[repr(transparent)]
#[derive(From, Into, PartialEq, Clone)]
struct PyHugrType(Type);

impl fmt::Debug for PyHugrType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.fmt(f)
}
}

#[pymethods]
impl PyHugrType {
#[new]
fn new(extension: &str, type_name: &str, bound: PyTypeBound) -> Self {
Self(Type::new_extension(CustomType::new_simple(
type_name.into(),
IdentList::new(extension).unwrap(),
bound.into(),
)))
}
#[staticmethod]
fn qubit() -> Self {
Self(QB_T)
}

#[staticmethod]
fn bool() -> Self {
Self(BOOL_T)
}

pub fn __repr__(&self) -> String {
format!("{:?}", self)
}
}
2 changes: 1 addition & 1 deletion tket2-py/src/circuit/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use tket_json_rs::circuit_json::SerialCircuit;
use crate::rewrite::PyCircuitRewrite;
use crate::utils::ConvertPyErr;

use super::{cost, PyCircuitCost, PyCustom, PyHugrType, PyNode, PyWire, Tk2Circuit};
use super::{cost, PyCircuitCost, PyNode, PyWire, Tk2Circuit};

/// A flag to indicate the encoding of a circuit.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
Expand Down
29 changes: 18 additions & 11 deletions tket2-py/src/circuit/tk2circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use itertools::Itertools;
use pyo3::exceptions::{PyAttributeError, PyValueError};
use pyo3::types::{PyAnyMethods, PyModule, PyString, PyTypeMethods};
use pyo3::{
pyclass, pymethods, Bound, FromPyObject, PyAny, PyErr, PyObject, PyRefMut, PyResult,
pyclass, pymethods, Bound, FromPyObject, PyAny, PyErr, PyObject, PyRef, PyRefMut, PyResult,
PyTypeInfo, Python, ToPyObject,
};

Expand All @@ -26,11 +26,12 @@ use tket2::serialize::TKETDecode;
use tket2::{Circuit, Tk2Op};
use tket_json_rs::circuit_json::SerialCircuit;

use crate::ops::PyTk2Op;
use crate::ops::{PyCustomOp, PyTk2Op};
use crate::rewrite::PyCircuitRewrite;
use crate::utils::ConvertPyErr;
use crate::types::PyHugrType;
use crate::utils::{into_vec, ConvertPyErr};

use super::{cost, with_hugr, PyCircuitCost, PyCustom, PyHugrType, PyNode, PyWire};
use super::{cost, with_hugr, PyCircuitCost, PyNode, PyWire};

/// A circuit in tket2 format.
///
Expand Down Expand Up @@ -177,7 +178,7 @@ impl Tk2Circuit {
Ok(self.clone())
}

fn node_op(&self, node: PyNode) -> PyResult<PyCustom> {
fn node_op(&self, node: PyNode) -> PyResult<PyCustomOp> {
let custom: CustomOp = self
.circ
.hugr()
Expand Down Expand Up @@ -248,8 +249,18 @@ impl Dfg {
self.builder.input_wires().map_into().collect()
}

fn add_op(&mut self, op: PyCustom, inputs: Vec<PyWire>) -> PyResult<PyNode> {
let custom: CustomOp = op.into();
fn add_op(&mut self, op: Bound<PyAny>, inputs: Vec<PyWire>) -> PyResult<PyNode> {
// TODO: Once we wrap `Dfg` in a pure python class we can make the conversion there,
// and have a concrete `op: PyCustomOp` argument here.
let custom: PyCustomOp = op
.call_method0("to_custom")
.map_err(|_| {
PyErr::new::<PyValueError, _>(
"The operation must implement the `ToCustomOp` protocol.",
)
})?
.extract()?;
let custom: CustomOp = custom.into();
self.builder
.add_dataflow_op(custom, inputs.into_iter().map_into())
.convert_pyerrs()
Expand All @@ -267,7 +278,3 @@ impl Dfg {
})
}
}

pub(super) fn into_vec<T, S: From<T>>(v: impl IntoIterator<Item = T>) -> Vec<S> {
v.into_iter().map_into().collect()
}
2 changes: 2 additions & 0 deletions tket2-py/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ pub mod optimiser;
pub mod passes;
pub mod pattern;
pub mod rewrite;
pub mod types;
pub mod utils;

use pyo3::prelude::*;
Expand All @@ -18,6 +19,7 @@ fn _tket2(py: Python, m: &Bound<PyModule>) -> PyResult<()> {
add_submodule(py, m, passes::module(py)?)?;
add_submodule(py, m, pattern::module(py)?)?;
add_submodule(py, m, rewrite::module(py)?)?;
add_submodule(py, m, types::module(py)?)?;
Ok(())
}

Expand Down
72 changes: 70 additions & 2 deletions tket2-py/src/ops.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,26 @@
//! Bindings for rust-defined operations

use derive_more::From;
use hugr::ops::NamedOp;
use derive_more::{From, Into};
use hugr::hugr::IdentList;
use hugr::ops::custom::{ExtensionOp, OpaqueOp};
use hugr::types::FunctionType;
use pyo3::prelude::*;
use std::fmt;
use std::str::FromStr;
use strum::IntoEnumIterator;

use hugr::ops::{CustomOp, NamedOp, OpType};
use tket2::{Pauli, Tk2Op};

use crate::types::PyHugrType;
use crate::utils::into_vec;

/// The module definition
pub fn module(py: Python<'_>) -> PyResult<Bound<'_, PyModule>> {
let m = PyModule::new_bound(py, "ops")?;
m.add_class::<PyTk2Op>()?;
m.add_class::<PyPauli>()?;
m.add_class::<PyCustomOp>()?;
Ok(m)
}

Expand Down Expand Up @@ -58,6 +67,12 @@ impl PyTk2Op {
self.op.exposed_name().to_string()
}

/// Wrap the operation as a custom operation.
pub fn to_custom(&self) -> PyCustomOp {
let custom: ExtensionOp = self.op.into_extension_op();
CustomOp::new_extension(custom).into()
}

/// String representation of the operation.
pub fn __repr__(&self) -> String {
self.qualified_name()
Expand Down Expand Up @@ -203,3 +218,56 @@ impl PyPauliIter {
self.it.next().map(|p| PyPauli { p })
}
}

/// A wrapped custom operation.
#[pyclass]
#[pyo3(name = "CustomOp")]
#[repr(transparent)]
#[derive(From, Into, PartialEq, Clone)]
pub struct PyCustomOp(CustomOp);

impl fmt::Debug for PyCustomOp {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.fmt(f)
}
}

impl From<PyCustomOp> for OpType {
fn from(op: PyCustomOp) -> Self {
op.0.into()
}
}

#[pymethods]
impl PyCustomOp {
#[new]
fn new(
extension: &str,
op_name: &str,
input_types: Vec<PyHugrType>,
output_types: Vec<PyHugrType>,
) -> PyResult<Self> {
Ok(CustomOp::new_opaque(OpaqueOp::new(
IdentList::new(extension).unwrap(),
op_name,
Default::default(),
[],
FunctionType::new(into_vec(input_types), into_vec(output_types)),
))
.into())
}

fn to_custom(&self) -> Self {
self.clone()
}

/// String representation of the operation.
pub fn __repr__(&self) -> String {
format!("{:?}", self)
}

#[getter]
fn name(&self) -> String {
self.0.name().to_string()
}
}
Loading
Loading