diff --git a/Cargo.toml b/Cargo.toml index 7dea3fde..1d24746c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,8 +20,8 @@ missing_docs = "warn" [workspace.dependencies] tket2 = { path = "./tket2" } -quantinuum-hugr = "0.2" -portgraph = "0.11" +hugr = "0.3.0" +portgraph = "0.12" pyo3 = "0.21.2" itertools = "0.12.0" tket-json-rs = "0.4.0" diff --git a/badger-optimiser/Cargo.toml b/badger-optimiser/Cargo.toml index 53d613b9..f6b883fa 100644 --- a/badger-optimiser/Cargo.toml +++ b/badger-optimiser/Cargo.toml @@ -12,7 +12,7 @@ license-file = { workspace = true } clap = { workspace = true, features = ["derive"] } serde_json = { workspace = true } tket2 = { workspace = true, features = ["portmatching", "rewrite-tracing"] } -quantinuum-hugr = { workspace = true } +hugr = { workspace = true } itertools = { workspace = true } tket-json-rs = { workspace = true } tracing = { workspace = true } diff --git a/compile-rewriter/Cargo.toml b/compile-rewriter/Cargo.toml index c002b516..491ef215 100644 --- a/compile-rewriter/Cargo.toml +++ b/compile-rewriter/Cargo.toml @@ -1,15 +1,13 @@ [package] -name = "compile-matcher" +name = "compile-rewriter" edition = { workspace = true } version = { workspace = true } rust-version = { workspace = true } homepage = { workspace = true } license-file = { workspace = true } -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - [dependencies] clap = { workspace = true, features = ["derive"] } tket2 = { workspace = true, features = ["portmatching"] } -quantinuum-hugr = { workspace = true } +hugr = { workspace = true } itertools = { workspace = true } diff --git a/test_files/nam_6_3.rwr b/test_files/nam_6_3.rwr index 9feac5e2..0e60c90c 100644 Binary files a/test_files/nam_6_3.rwr and b/test_files/nam_6_3.rwr differ diff --git a/test_files/small_eccs.rwr b/test_files/small_eccs.rwr index 637215e9..3f195e12 100644 Binary files a/test_files/small_eccs.rwr and b/test_files/small_eccs.rwr differ diff --git a/tket2-py/Cargo.toml b/tket2-py/Cargo.toml index fb94a9fb..140b2681 100644 --- a/tket2-py/Cargo.toml +++ b/tket2-py/Cargo.toml @@ -18,7 +18,7 @@ tket2 = { workspace = true, features = ["pyo3", "portmatching"] } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } tket-json-rs = { workspace = true, features = ["pyo3"] } -quantinuum-hugr = { workspace = true } +hugr = { workspace = true } portgraph = { workspace = true, features = ["serde"] } pyo3 = { workspace = true, features = ["extension-module"] } num_cpus = { workspace = true } diff --git a/tket2-py/tket2/data/nam_6_3.rwr b/tket2-py/tket2/data/nam_6_3.rwr index 9feac5e2..0e60c90c 100644 Binary files a/tket2-py/tket2/data/nam_6_3.rwr and b/tket2-py/tket2/data/nam_6_3.rwr differ diff --git a/tket2/Cargo.toml b/tket2/Cargo.toml index 1d70dc92..846aeb48 100644 --- a/tket2/Cargo.toml +++ b/tket2/Cargo.toml @@ -45,7 +45,7 @@ petgraph = { workspace = true } serde_yaml = { workspace = true } portmatching = { workspace = true, optional = true, features = ["serde"] } derive_more = { workspace = true } -quantinuum-hugr = { workspace = true } +hugr = { workspace = true } portgraph = { workspace = true, features = ["serde"] } pyo3 = { workspace = true, optional = true, features = ["multiple-pymethods"] } strum_macros = { workspace = true } diff --git a/tket2/benches/benchmarks/generators.rs b/tket2/benches/benchmarks/generators.rs index e5a05ff7..510f17b8 100644 --- a/tket2/benches/benchmarks/generators.rs +++ b/tket2/benches/benchmarks/generators.rs @@ -17,7 +17,7 @@ pub fn build_simple_circuit( let qbs = h.input_wires(); - let mut circ = h.as_circuit(qbs.into_iter().collect()); + let mut circ = h.as_circuit(qbs); f(&mut circ)?; diff --git a/tket2/src/circuit.rs b/tket2/src/circuit.rs index 9df5a412..348876d4 100644 --- a/tket2/src/circuit.rs +++ b/tket2/src/circuit.rs @@ -16,7 +16,7 @@ use hugr::hugr::hugrmut::HugrMut; use hugr::hugr::NodeType; use hugr::ops::dataflow::IOTrait; use hugr::ops::{Input, Output, DFG}; -use hugr::types::FunctionType; +use hugr::types::PolyFuncType; use hugr::PortIndex; use hugr::{HugrView, OutgoingPort}; use itertools::Itertools; @@ -45,7 +45,7 @@ pub trait Circuit: HugrView { /// /// Equivalent to [`HugrView::get_function_type`]. #[inline] - fn circuit_signature(&self) -> FunctionType { + fn circuit_signature(&self) -> PolyFuncType { self.get_function_type() .expect("Circuit has no function type") } @@ -183,7 +183,7 @@ pub(crate) fn remove_empty_wire( return Err(CircuitMutError::DeleteNonEmptyWire(input_port.index())); } if link.is_some() { - circ.disconnect(inp, input_port)?; + circ.disconnect(inp, input_port); } // Shift ports at input @@ -231,7 +231,7 @@ fn shift_ports( for port in port_range { let links = circ.linked_ports(node, port).collect_vec(); if !links.is_empty() { - circ.disconnect(node, port)?; + circ.disconnect(node, port); } for (other_n, other_p) in links { match other_p.as_directed() { @@ -243,7 +243,7 @@ fn shift_ports( let src_port = free_port.as_outgoing().unwrap(); circ.connect(node, src_port, other_n, other_p) } - }?; + }; } free_port = port; } @@ -308,6 +308,7 @@ impl Circuit for T where T: HugrView {} #[cfg(test)] mod tests { + use hugr::types::FunctionType; use hugr::{ builder::{DFGBuilder, DataflowHugr}, extension::{prelude::BOOL_T, PRELUDE_REGISTRY}, @@ -338,8 +339,8 @@ mod tests { let circ = test_circuit(); assert_eq!(circ.name(), None); - assert_eq!(circ.circuit_signature().input_count(), 3); - assert_eq!(circ.circuit_signature().output_count(), 3); + assert_eq!(circ.circuit_signature().body().input_count(), 3); + assert_eq!(circ.circuit_signature().body().output_count(), 3); assert_eq!(circ.qubit_count(), 2); assert_eq!(circ.num_gates(), 3); diff --git a/tket2/src/circuit/command.rs b/tket2/src/circuit/command.rs index 0c369933..3cc7b7dd 100644 --- a/tket2/src/circuit/command.rs +++ b/tket2/src/circuit/command.rs @@ -552,8 +552,8 @@ mod test { let mut h = DFGBuilder::new(FunctionType::new(qb_row.clone(), qb_row)).unwrap(); let [q_in] = h.input_wires_arr(); - let constant = h.add_constant(ConstF64::new(0.5)).unwrap(); - let loaded_const = h.load_const(&constant).unwrap(); + let constant = h.add_constant(ConstF64::new(0.5)); + let loaded_const = h.load_const(&constant); let rz = h .add_dataflow_op(Tk2Op::RzF64, [q_in, loaded_const]) .unwrap(); diff --git a/tket2/src/circuit/hash.rs b/tket2/src/circuit/hash.rs index 8b5f5e1f..b608d654 100644 --- a/tket2/src/circuit/hash.rs +++ b/tket2/src/circuit/hash.rs @@ -4,7 +4,7 @@ use std::hash::{Hash, Hasher}; use fxhash::{FxHashMap, FxHasher64}; use hugr::hugr::views::{HierarchyView, SiblingGraph}; -use hugr::ops::{LeafOp, OpName, OpType}; +use hugr::ops::{OpName, OpType}; use hugr::{HugrView, Node}; use petgraph::visit::{self as pg, Walker}; use thiserror::Error; @@ -78,7 +78,7 @@ impl HashState { /// Returns a hashable representation of an operation. fn hashable_op(op: &OpType) -> impl Hash { match op { - OpType::LeafOp(LeafOp::CustomOp(op)) if !op.args().is_empty() => { + OpType::CustomOp(op) if !op.args().is_empty() => { // TODO: Require hashing for TypeParams? format!( "{}[{}]", diff --git a/tket2/src/circuit/units.rs b/tket2/src/circuit/units.rs index 5da38f6d..17383aa7 100644 --- a/tket2/src/circuit/units.rs +++ b/tket2/src/circuit/units.rs @@ -140,6 +140,8 @@ where // // TODO: This is quite hacky, but we need it to accept Const static inputs. // We should revisit it once this is reworked on the HUGR side. + // + // TODO: EdgeKind::Function is not currently supported. fn init_types(circuit: &impl Circuit, node: Node, direction: Direction) -> TypeRow { let optype = circuit.get_optype(node); let sig = circuit.signature(node).unwrap_or_default(); @@ -147,10 +149,10 @@ where Direction::Outgoing => sig.output, Direction::Incoming => sig.input, }; - if let Some(EdgeKind::Static(static_type)) = optype.static_port_kind(direction) { + if let Some(EdgeKind::Const(static_type)) = optype.static_port_kind(direction) { types.to_mut().push(static_type); }; - if let Some(EdgeKind::Static(other)) = optype.other_port_kind(direction) { + if let Some(EdgeKind::Const(other)) = optype.other_port_kind(direction) { types.to_mut().push(other); } types diff --git a/tket2/src/extension.rs b/tket2/src/extension.rs index 64842aa5..cb29e1b4 100644 --- a/tket2/src/extension.rs +++ b/tket2/src/extension.rs @@ -8,7 +8,8 @@ use hugr::extension::prelude::PRELUDE; use hugr::extension::simple_op::MakeOpDef; use hugr::extension::{CustomSignatureFunc, ExtensionId, ExtensionRegistry, SignatureError}; use hugr::hugr::IdentList; -use hugr::ops::custom::{ExternalOp, OpaqueOp}; +use hugr::ops::custom::{CustomOp, OpaqueOp}; +use hugr::ops::OpName; use hugr::std_extensions::arithmetic::float_types::{EXTENSION as FLOAT_EXTENSION, FLOAT64_TYPE}; use hugr::types::type_param::{CustomTypeArg, TypeArg, TypeParam}; use hugr::types::{CustomType, FunctionType, PolyFuncType, Type, TypeBound}; @@ -73,7 +74,7 @@ pub static ref REGISTRY: ExtensionRegistry = ExtensionRegistry::try_new([ } /// Create a new opaque operation -pub(crate) fn wrap_json_op(op: &JsonOp) -> ExternalOp { +pub(crate) fn wrap_json_op(op: &JsonOp) -> CustomOp { // TODO: This throws an error //let op = serde_yaml::to_value(op).unwrap(); //let payload = TypeArg::Opaque(CustomTypeArg::new(TKET1_OP_PAYLOAD.clone(), op).unwrap()); @@ -100,7 +101,7 @@ pub(crate) fn wrap_json_op(op: &JsonOp) -> ExternalOp { /// Extract a json-encoded TKET1 operation from an opaque operation, if /// possible. -pub(crate) fn try_unwrap_json_op(ext: &ExternalOp) -> Option { +pub(crate) fn try_unwrap_json_op(ext: &CustomOp) -> Option { // TODO: Check `extensions.contains(&TKET1_EXTENSION_ID)` // (but the ext op extensions are an empty set?) if ext.name() != format!("{TKET1_EXTENSION_ID}.{JSON_OP_NAME}") { diff --git a/tket2/src/extension/angle.rs b/tket2/src/extension/angle.rs index 2f56cf0b..a6f02a3f 100644 --- a/tket2/src/extension/angle.rs +++ b/tket2/src/extension/angle.rs @@ -1,13 +1,13 @@ use std::{cmp::max, num::NonZeroU64}; use hugr::extension::ExtensionSet; +use hugr::ops::constant::{downcast_equal_consts, CustomConst}; use hugr::{ extension::{prelude::ERROR_TYPE, SignatureError, SignatureFromArgs, TypeDef}, types::{ type_param::{TypeArgError, TypeParam}, ConstTypeError, CustomType, FunctionType, PolyFuncType, Type, TypeArg, TypeBound, }, - values::CustomConst, Extension, }; use itertools::Itertools; @@ -116,11 +116,12 @@ impl CustomConst for ConstAngle { format!("a(2π*{}/2^{})", self.value, self.log_denom).into() } - fn custom_type(&self) -> CustomType { - super::angle_custom_type(self.log_denom) + fn get_type(&self) -> Type { + super::angle_custom_type(self.log_denom).into() } + fn equal_consts(&self, other: &dyn CustomConst) -> bool { - hugr::values::downcast_equal_consts(self, other) + downcast_equal_consts(self, other) } fn extension_reqs(&self) -> ExtensionSet { ExtensionSet::singleton(&TKET2_EXTENSION_ID) @@ -205,9 +206,9 @@ pub(super) fn add_to_extension(extension: &mut Extension) { vec![LOG_DENOM_TYPE_PARAM, LOG_DENOM_TYPE_PARAM], FunctionType::new( vec![generic_angle_type(0, &angle_type_def)], - vec![Type::new_sum(vec![ - generic_angle_type(1, &angle_type_def), - ERROR_TYPE, + vec![Type::new_sum([ + generic_angle_type(1, &angle_type_def).into(), + ERROR_TYPE.into(), ])], ), ), @@ -241,6 +242,7 @@ pub(super) fn add_to_extension(extension: &mut Extension) { #[cfg(test)] mod test { use super::*; + use crate::extension::angle_custom_type; use hugr::types::TypeArg; #[test] @@ -264,14 +266,8 @@ mod test { assert_ne!(const_a32_7, const_a32_8); assert_eq!(const_a32_7, ConstAngle::new(5, 7).unwrap()); - assert_eq!( - const_a32_7.custom_type(), - super::super::angle_custom_type(5) - ); - assert_ne!( - const_a32_7.custom_type(), - super::super::angle_custom_type(6) - ); + assert_eq!(const_a32_7.get_type(), angle_custom_type(5).into()); + assert_ne!(const_a32_7.get_type(), angle_custom_type(6).into()); assert!(matches!( ConstAngle::new(3, 256), Err(ConstTypeError::CustomCheckFail(_)) diff --git a/tket2/src/json.rs b/tket2/src/json.rs index a6050af9..54fb093c 100644 --- a/tket2/src/json.rs +++ b/tket2/src/json.rs @@ -12,9 +12,8 @@ use hugr::CircuitUnit; use std::path::Path; use std::{fs, io}; -use hugr::ops::OpType; +use hugr::ops::{Const, OpType}; use hugr::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE}; -use hugr::values::Value; use hugr::Hugr; use stringreader::StringReader; @@ -194,7 +193,7 @@ fn parse_val(n: &str) -> Option { } /// Try to interpret a TKET1 parameter as a constant value. #[inline] -fn try_param_to_constant(param: &str) -> Option { +fn try_param_to_constant(param: &str) -> Option { if let Some(f) = parse_val(param) { Some(ConstF64::new(f).into()) } else if param.split('/').count() == 2 { diff --git a/tket2/src/json/decoder.rs b/tket2/src/json/decoder.rs index dbb2a37f..7cac7360 100644 --- a/tket2/src/json/decoder.rs +++ b/tket2/src/json/decoder.rs @@ -8,8 +8,6 @@ use std::mem; use hugr::builder::{CircuitBuilder, Container, DFGBuilder, Dataflow, DataflowHugr}; use hugr::extension::prelude::QB_T; -use hugr::ops::Const; -use hugr::std_extensions::arithmetic::float_types::FLOAT64_TYPE; use hugr::types::FunctionType; use hugr::CircuitUnit; use hugr::{Hugr, Wire}; @@ -143,11 +141,7 @@ impl JsonDecoder { /// TODO: If the parameter is a variable, returns the corresponding wire from the input. fn create_param_wire(&mut self, param: &str) -> Wire { match try_param_to_constant(param) { - Some(c) => { - let const_type = FLOAT64_TYPE; - let const_op = Const::new(c, const_type).unwrap(); - self.hugr.add_load_const(const_op).unwrap() - } + Some(const_op) => self.hugr.add_load_const(const_op), None => { // store string in custom op. let symb_op = symbolic_constant_op(param); diff --git a/tket2/src/json/encoder.rs b/tket2/src/json/encoder.rs index c86e564a..a8417cea 100644 --- a/tket2/src/json/encoder.rs +++ b/tket2/src/json/encoder.rs @@ -6,7 +6,6 @@ use std::collections::HashMap; use hugr::extension::prelude::QB_T; use hugr::ops::{OpName, OpType}; use hugr::std_extensions::arithmetic::float_types::ConstF64; -use hugr::values::Value; use hugr::Wire; use itertools::{Either, Itertools}; use tket_json_rs::circuit_json::{self, Permutation, Register, SerialCircuit}; @@ -194,16 +193,10 @@ impl JsonEncoder { let param = match optype { OpType::Const(const_op) => { // New constant, register it if it can be interpreted as a parameter. - match const_op.value() { - Value::Extension { c: (val,) } => { - if let Some(f) = val.downcast_ref::() { - f.to_string() - } else { - return false; - } - } - _ => return false, - } + let Some(const_float) = const_op.get_custom_value::() else { + return false; + }; + const_float.to_string() } OpType::LoadConstant(_op_type) => { // Re-use the parameter from the input. @@ -212,15 +205,11 @@ impl JsonEncoder { op if op_matches(op, Tk2Op::AngleAdd) => { format!("{} + {}", inputs[0], inputs[1]) } - OpType::LeafOp(_) => { - if let Some(s) = match_symb_const_op(optype) { - s.to_string() - } else { - return false; - } - } _ => { - return false; + let Some(s) = match_symb_const_op(optype) else { + return false; + }; + s.to_string() } }; diff --git a/tket2/src/json/op.rs b/tket2/src/json/op.rs index 32e77b80..407eaf5a 100644 --- a/tket2/src/json/op.rs +++ b/tket2/src/json/op.rs @@ -8,8 +8,8 @@ use hugr::extension::prelude::QB_T; -use hugr::ops::custom::ExternalOp; -use hugr::ops::{LeafOp, OpTrait, OpType}; +use hugr::ops::custom::CustomOp; +use hugr::ops::{Noop, OpTrait, OpType}; use hugr::std_extensions::arithmetic::float_types::FLOAT64_TYPE; use hugr::types::FunctionType; @@ -150,7 +150,7 @@ impl JsonOp { } /// Wraps the op into a Hugr opaque operation - fn as_opaque_op(&self) -> ExternalOp { + fn as_custom_op(&self) -> CustomOp { crate::extension::wrap_json_op(self) } @@ -177,7 +177,6 @@ impl From<&JsonOp> for OpType { /// Any other operation is wrapped in an `OpaqueOp`. fn from(json_op: &JsonOp) -> Self { match json_op.op.op_type { - // JsonOpType::X => LeafOp::X.into(), JsonOpType::H => Tk2Op::H.into(), JsonOpType::CX => Tk2Op::CX.into(), JsonOpType::T => Tk2Op::T.into(), @@ -193,8 +192,13 @@ impl From<&JsonOp> for OpType { JsonOpType::ZZPhase => Tk2Op::ZZPhase.into(), JsonOpType::CZ => Tk2Op::CZ.into(), JsonOpType::Reset => Tk2Op::Reset.into(), - JsonOpType::noop => LeafOp::Noop { ty: QB_T }.into(), - _ => LeafOp::CustomOp(Box::new(json_op.as_opaque_op())).into(), + JsonOpType::noop => { + // TODO: Replace with `Noop::new` once that is published. + let mut noop = Noop::default(); + noop.ty = QB_T; + noop.into() + } + _ => json_op.as_custom_op().into(), } } } @@ -209,43 +213,41 @@ impl TryFrom<&OpType> for JsonOp { // // Non-supported Hugr operations throw an error. let err = || OpConvertError::UnsupportedOpSerialization(op.clone()); - let Some(leaf) = op.as_leaf_op() else { - return Err(err()); + + let Ok(tk2op) = op.try_into() else { + if let OpType::CustomOp(custom_op) = op { + return try_unwrap_json_op(custom_op).ok_or_else(err); + } else { + return Err(err()); + } }; - let json_optype = if let Ok(tk2op) = leaf.try_into() { - match tk2op { - Tk2Op::H => JsonOpType::H, - Tk2Op::CX => JsonOpType::CX, - Tk2Op::T => JsonOpType::T, - Tk2Op::S => JsonOpType::S, - Tk2Op::X => JsonOpType::X, - Tk2Op::Y => JsonOpType::Y, - Tk2Op::Z => JsonOpType::Z, - Tk2Op::Tdg => JsonOpType::Tdg, - Tk2Op::Sdg => JsonOpType::Sdg, - Tk2Op::ZZMax => JsonOpType::ZZMax, - Tk2Op::Measure => JsonOpType::Measure, - Tk2Op::RzF64 => JsonOpType::Rz, - Tk2Op::RxF64 => JsonOpType::Rx, - // TODO: Use a TK2 opaque op once we update the tket-json-rs dependency. - Tk2Op::AngleAdd => { - unimplemented!("Serialising AngleAdd not supported. Are all constants folded?") - } - Tk2Op::TK1 => JsonOpType::TK1, - Tk2Op::PhasedX => JsonOpType::PhasedX, - Tk2Op::ZZPhase => JsonOpType::ZZPhase, - Tk2Op::CZ => JsonOpType::CZ, - Tk2Op::Reset => JsonOpType::Reset, - Tk2Op::QAlloc | Tk2Op::QFree => { - unimplemented!("TKET1 does not support dynamic qubit allocation/discarding.") - } + let json_optype = match tk2op { + Tk2Op::H => JsonOpType::H, + Tk2Op::CX => JsonOpType::CX, + Tk2Op::T => JsonOpType::T, + Tk2Op::S => JsonOpType::S, + Tk2Op::X => JsonOpType::X, + Tk2Op::Y => JsonOpType::Y, + Tk2Op::Z => JsonOpType::Z, + Tk2Op::Tdg => JsonOpType::Tdg, + Tk2Op::Sdg => JsonOpType::Sdg, + Tk2Op::ZZMax => JsonOpType::ZZMax, + Tk2Op::Measure => JsonOpType::Measure, + Tk2Op::RzF64 => JsonOpType::Rz, + Tk2Op::RxF64 => JsonOpType::Rx, + // TODO: Use a TK2 opaque op once we update the tket-json-rs dependency. + Tk2Op::AngleAdd => { + unimplemented!("Serialising AngleAdd not supported. Are all constants folded?") + } + Tk2Op::TK1 => JsonOpType::TK1, + Tk2Op::PhasedX => JsonOpType::PhasedX, + Tk2Op::ZZPhase => JsonOpType::ZZPhase, + Tk2Op::CZ => JsonOpType::CZ, + Tk2Op::Reset => JsonOpType::Reset, + Tk2Op::QAlloc | Tk2Op::QFree => { + unimplemented!("TKET1 does not support dynamic qubit allocation/discarding.") } - } else if let LeafOp::CustomOp(b) = leaf { - let ext = (*b).as_ref(); - return try_unwrap_json_op(ext).ok_or_else(err); - } else { - return Err(err()); }; let mut num_qubits = 0; diff --git a/tket2/src/json/tests.rs b/tket2/src/json/tests.rs index df9b0a27..6d2d6e0f 100644 --- a/tket2/src/json/tests.rs +++ b/tket2/src/json/tests.rs @@ -109,8 +109,8 @@ fn circ_add_angles_constants() -> Hugr { let qb = h.input_wires().next().unwrap(); - let point2 = h.add_load_const(ConstF64::new(0.2)).unwrap(); - let point3 = h.add_load_const(ConstF64::new(0.3)).unwrap(); + let point2 = h.add_load_const(ConstF64::new(0.2)); + let point3 = h.add_load_const(ConstF64::new(0.3)); let point5 = h .add_dataflow_op(Tk2Op::AngleAdd, [point2, point3]) .unwrap() diff --git a/tket2/src/ops.rs b/tket2/src/ops.rs index d6b535d9..660f12d8 100644 --- a/tket2/src/ops.rs +++ b/tket2/src/ops.rs @@ -8,7 +8,7 @@ use hugr::{ simple_op::{try_from_name, MakeExtensionOp, MakeOpDef, MakeRegisteredOp}, ExtensionId, OpDef, SignatureFunc, }, - ops::{custom::ExternalOp, LeafOp, OpType}, + ops::{CustomOp, OpType}, std_extensions::arithmetic::float_types::FLOAT64_TYPE, type_row, types::{ @@ -181,7 +181,7 @@ impl Tk2Op { /// Initialize a new custom symbolic expression constant op from a string. pub fn symbolic_constant_op(s: &str) -> OpType { let value: serde_yaml::Value = s.into(); - let l: LeafOp = EXTENSION + EXTENSION .instantiate_extension_op( &SYM_OP_ID, vec![TypeArg::Opaque { @@ -190,8 +190,7 @@ pub fn symbolic_constant_op(s: &str) -> OpType { ®ISTRY, ) .unwrap() - .into(); - l.into() + .into() } /// match against a symbolic constant @@ -209,14 +208,14 @@ pub(crate) fn match_symb_const_op(op: &OpType) -> Option { .unwrap_or_else(|| panic!("Found an invalid type arg in a symbolic operation node.")) }; - if let OpType::LeafOp(LeafOp::CustomOp(e)) = op { - match e.as_ref() { - ExternalOp::Extension(e) + if let OpType::CustomOp(custom_op) = op { + match custom_op { + CustomOp::Extension(e) if e.def().name() == &SYM_OP_ID && e.def().extension() == &EXTENSION_ID => { Some(symbol_from_typeargs(e.args())) } - ExternalOp::Opaque(e) if e.name() == &SYM_OP_ID && e.extension() == &EXTENSION_ID => { + CustomOp::Opaque(e) if e.name() == &SYM_OP_ID && e.extension() == &EXTENSION_ID => { Some(symbol_from_typeargs(e.args())) } _ => None, @@ -226,52 +225,31 @@ pub(crate) fn match_symb_const_op(op: &OpType) -> Option { } } -impl From for LeafOp { - fn from(op: Tk2Op) -> Self { - op.to_extension_op().unwrap().into() - } -} - -impl TryFrom for Tk2Op { - type Error = NotTk2Op; - - fn try_from(op: OpType) -> Result { - let leaf: LeafOp = op.try_into().map_err(|_| NotTk2Op)?; - leaf.try_into() - } -} - impl TryFrom<&OpType> for Tk2Op { type Error = NotTk2Op; fn try_from(op: &OpType) -> Result { - let OpType::LeafOp(leaf) = op else { - return Err(NotTk2Op); - }; - leaf.try_into() - } -} - -impl TryFrom<&LeafOp> for Tk2Op { - type Error = NotTk2Op; - - fn try_from(op: &LeafOp) -> Result { - let LeafOp::CustomOp(ext) = op else { + let OpType::CustomOp(custom_op) = op else { return Err(NotTk2Op); }; - match ext.as_ref() { - ExternalOp::Extension(ext) => Tk2Op::from_extension_op(ext), - ExternalOp::Opaque(opaque) => try_from_name(opaque.name()), + match custom_op { + CustomOp::Extension(ext) => Tk2Op::from_extension_op(ext), + CustomOp::Opaque(opaque) => { + if opaque.extension() != &EXTENSION_ID { + return Err(NotTk2Op); + } + try_from_name(opaque.name()) + } } .map_err(|_| NotTk2Op) } } -impl TryFrom for Tk2Op { +impl TryFrom for Tk2Op { type Error = NotTk2Op; - fn try_from(op: LeafOp) -> Result { + fn try_from(op: OpType) -> Result { Self::try_from(&op) } } diff --git a/tket2/src/passes/chunks.rs b/tket2/src/passes/chunks.rs index 278df47e..67661c17 100644 --- a/tket2/src/passes/chunks.rs +++ b/tket2/src/passes/chunks.rs @@ -58,9 +58,7 @@ impl Chunk { checker, ) .expect("Failed to define the chunk subgraph"); - let extracted = subgraph - .extract_subgraph(circ, "Chunk") - .expect("Failed to extract chunk"); + let extracted = subgraph.extract_subgraph(circ, "Chunk"); // Transform the subgraph's input/output sets into wires that can be // matched between different chunks. // @@ -104,9 +102,7 @@ impl Chunk { // Insert the chunk circuit into the original circuit. let subgraph = SiblingSubgraph::try_new_dataflow_subgraph(&chunk_sg) .unwrap_or_else(|e| panic!("The chunk circuit is no longer a dataflow graph: {e}")); - let node_map = circ - .insert_subgraph(root, &self.circ, &subgraph) - .expect("Failed to insert the chunk subgraph"); + let node_map = circ.insert_subgraph(root, &self.circ, &subgraph); let mut input_map = HashMap::with_capacity(self.inputs.len()); let mut output_map = HashMap::with_capacity(self.outputs.len()); @@ -268,7 +264,7 @@ impl CircuitChunks { op_cost: impl Fn(&OpType) -> C, ) -> Self { let root_meta = circ.get_node_metadata(circ.root()).cloned(); - let signature = circ.circuit_signature().clone(); + let signature = circ.circuit_signature().body().clone(); let [circ_input, circ_output] = circ.get_io(circ.root()).unwrap(); let input_connections = circ @@ -421,11 +417,11 @@ impl CircuitChunks { continue; }; for (target, target_port) in tgts { - reassembled.connect(source, source_port, target, target_port)?; + reassembled.connect(source, source_port, target, target_port); } } - reassembled.overwrite_node_metadata(root, self.root_meta)?; + reassembled.overwrite_node_metadata(root, self.root_meta); Ok(reassembled) } diff --git a/tket2/src/passes/commutation.rs b/tket2/src/passes/commutation.rs index 36124566..10ed99e6 100644 --- a/tket2/src/passes/commutation.rs +++ b/tket2/src/passes/commutation.rs @@ -207,10 +207,6 @@ impl Rewrite for PullForward { type ApplyResult = (); - type InvalidationSet<'a> = std::vec::IntoIter - where - Self: 'a; - const UNCHANGED_ON_FAILURE: bool = false; fn verify(&self, _h: &impl HugrView) -> Result<(), Self::Error> { @@ -248,10 +244,10 @@ impl Rewrite for PullForward { // do not need to commute along this qubit. continue; } - h.disconnect(command.node(), in_port)?; - h.disconnect(command.node(), out_port)?; + h.disconnect(command.node(), in_port); + h.disconnect(command.node(), out_port); // connect old source and destination - identity operation. - h.connect(src, src_port.index(), dst, dst_port.index())?; + h.connect(src, src_port.index(), dst, dst_port.index()); let new_dst_port = qb_port(new_neighbour_com, qb, Direction::Incoming)?; let (new_src, new_src_port) = h @@ -260,25 +256,25 @@ impl Rewrite for PullForward { .ok() .unwrap(); // disconnect link which we will insert in to. - h.disconnect(new_neighbour_com.node(), new_dst_port)?; + h.disconnect(new_neighbour_com.node(), new_dst_port); h.connect( new_src, new_src_port.index(), command.node(), in_port.index(), - )?; + ); h.connect( command.node(), out_port.index(), new_neighbour_com.node(), new_dst_port.index(), - )?; + ); } Ok(()) } - fn invalidation_set(&self) -> Self::InvalidationSet<'_> { + fn invalidation_set(&self) -> impl Iterator { // TODO: This could avoid creating a vec, but it'll be easier to do once // return position impl trait is available. // This is done in the Rewrite trait of hugr so once that version diff --git a/tket2/src/portmatching/matcher.rs b/tket2/src/portmatching/matcher.rs index ff9274a2..0d16a54d 100644 --- a/tket2/src/portmatching/matcher.rs +++ b/tket2/src/portmatching/matcher.rs @@ -52,11 +52,11 @@ impl From for MatchOp { // identified by their name. let encoded = match op { OpType::Module(_) => None, - OpType::LeafOp(leaf) - if leaf + OpType::CustomOp(custom_op) + if custom_op .as_extension_op() .map(|ext| ext.args().is_empty()) - .unwrap_or_default() => + .unwrap_or(false) => { None } diff --git a/tket2/src/portmatching/pattern.rs b/tket2/src/portmatching/pattern.rs index fc0ecc21..4020a6bf 100644 --- a/tket2/src/portmatching/pattern.rs +++ b/tket2/src/portmatching/pattern.rs @@ -140,7 +140,7 @@ mod tests { use hugr::builder::{DFGBuilder, Dataflow, DataflowHugr}; use hugr::extension::prelude::QB_T; - use hugr::ops::LeafOp; + use hugr::ops::OpType; use hugr::std_extensions::arithmetic::float_types::FLOAT64_TYPE; use hugr::types::FunctionType; use hugr::Hugr; @@ -253,13 +253,9 @@ mod tests { } fn get_nodes_by_tk2op(circ: &impl Circuit, t2_op: Tk2Op) -> Vec { + let t2_op: OpType = t2_op.into(); circ.nodes() - .filter(|n| { - let Ok(op): Result = circ.get_optype(*n).clone().try_into() else { - return false; - }; - op == t2_op.into() - }) + .filter(|n| circ.get_optype(*n) == &t2_op) .collect() } diff --git a/tket2/src/rewrite/trace.rs b/tket2/src/rewrite/trace.rs index 4d0a4770..f2c17494 100644 --- a/tket2/src/rewrite/trace.rs +++ b/tket2/src/rewrite/trace.rs @@ -78,9 +78,7 @@ pub trait RewriteTracer: Circuit + HugrMut + Sized { if !REWRITE_TRACING_ENABLED { return; } - let meta = self - .get_metadata_mut(self.root(), METADATA_REWRITES) - .unwrap(); + let meta = self.get_metadata_mut(self.root(), METADATA_REWRITES); if *meta == NodeMetadata::Null { *meta = NodeMetadata::Array(vec![]); } @@ -96,8 +94,7 @@ pub trait RewriteTracer: Circuit + HugrMut + Sized { } match self .get_metadata_mut(self.root(), METADATA_REWRITES) - .ok() - .and_then(|m| m.as_array_mut()) + .as_array_mut() { Some(meta) => { let rewrite = rewrite.into(); diff --git a/tket2/src/utils.rs b/tket2/src/utils.rs index 22456624..1c46692d 100644 --- a/tket2/src/utils.rs +++ b/tket2/src/utils.rs @@ -24,7 +24,7 @@ pub(crate) fn build_simple_circuit( let qbs = h.input_wires(); - let mut circ = h.as_circuit(qbs.into_iter().collect()); + let mut circ = h.as_circuit(qbs); f(&mut circ)?;