diff --git a/hugr-cli/src/extensions.rs b/hugr-cli/src/extensions.rs index 792d997e7..d6e160f76 100644 --- a/hugr-cli/src/extensions.rs +++ b/hugr-cli/src/extensions.rs @@ -28,9 +28,9 @@ impl ExtArgs { pub fn run_dump(&self, registry: &ExtensionRegistry) { let base_dir = &self.outdir; - for (name, ext) in registry.iter() { + for ext in registry.iter() { let mut path = base_dir.clone(); - for part in name.split('.') { + for part in ext.name().split('.') { path.push(part); } path.set_extension("json"); diff --git a/hugr-cli/tests/validate.rs b/hugr-cli/tests/validate.rs index baf2e1a3b..f3fc35f83 100644 --- a/hugr-cli/tests/validate.rs +++ b/hugr-cli/tests/validate.rs @@ -190,7 +190,7 @@ fn test_no_std_fail(float_hugr_string: String, mut val_cmd: Command) { val_cmd .assert() .failure() - .stderr(contains(" Extension 'arithmetic.float.types' not found")); + .stderr(contains(" requires extension arithmetic.float.types")); } #[rstest] diff --git a/hugr-core/src/builder.rs b/hugr-core/src/builder.rs index 5f19cc4c5..a02c38816 100644 --- a/hugr-core/src/builder.rs +++ b/hugr-core/src/builder.rs @@ -241,8 +241,8 @@ pub(crate) mod test { use crate::extension::prelude::{bool_t, usize_t}; use crate::hugr::{views::HugrView, HugrMut}; use crate::ops; - use crate::std_extensions::arithmetic::float_ops::FLOAT_OPS_REGISTRY; use crate::types::{PolyFuncType, Signature}; + use crate::utils::test_quantum_extension; use crate::Hugr; use super::handle::BuildHandle; @@ -269,7 +269,7 @@ pub(crate) mod test { f(f_builder)?; - Ok(module_builder.finish_hugr(&FLOAT_OPS_REGISTRY)?) + Ok(module_builder.finish_hugr(&test_quantum_extension::REG)?) } #[fixture] diff --git a/hugr-core/src/builder/circuit.rs b/hugr-core/src/builder/circuit.rs index edf9a9049..5a2b18a04 100644 --- a/hugr-core/src/builder/circuit.rs +++ b/hugr-core/src/builder/circuit.rs @@ -243,6 +243,7 @@ mod test { use super::*; use cool_asserts::assert_matches; + use crate::builder::{Container, HugrBuilder, ModuleBuilder}; use crate::extension::prelude::{qb_t, usize_t}; use crate::extension::{ExtensionId, ExtensionSet, PRELUDE_REGISTRY}; use crate::std_extensions::arithmetic::float_types::{self, ConstF64}; @@ -308,33 +309,44 @@ mod test { .instantiate_extension_op("MyOp", [], &PRELUDE_REGISTRY) .unwrap(); - let build_res = build_main( - Signature::new( - vec![qb_t(), qb_t(), usize_t()], - vec![qb_t(), qb_t(), bool_t()], + let mut module_builder = ModuleBuilder::new(); + let mut f_build = module_builder + .define_function( + "main", + Signature::new( + vec![qb_t(), qb_t(), usize_t()], + vec![qb_t(), qb_t(), bool_t()], + ) + .with_extension_delta(ExtensionSet::from_iter([ + test_quantum_extension::EXTENSION_ID, + my_ext_name, + ])), ) - .with_extension_delta(ExtensionSet::from_iter([ - test_quantum_extension::EXTENSION_ID, - my_ext_name, - ])) - .into(), - |mut f_build| { - let [q0, q1, angle]: [Wire; 3] = f_build.input_wires_arr(); + .unwrap(); - let mut linear = f_build.as_circuit([q0, q1]); + let [q0, q1, angle]: [Wire; 3] = f_build.input_wires_arr(); - let measure_out = linear - .append(cx_gate(), [0, 1])? - .append_and_consume( - my_custom_op, - [CircuitUnit::Linear(0), CircuitUnit::Wire(angle)], - )? - .append_with_outputs(measure(), [0])?; - - let out_qbs = linear.finish(); - f_build.finish_with_outputs(out_qbs.into_iter().chain(measure_out)) - }, - ); + let mut linear = f_build.as_circuit([q0, q1]); + + let measure_out = linear + .append(cx_gate(), [0, 1]) + .unwrap() + .append_and_consume( + my_custom_op, + [CircuitUnit::Linear(0), CircuitUnit::Wire(angle)], + ) + .unwrap() + .append_with_outputs(measure(), [0]) + .unwrap(); + + let out_qbs = linear.finish(); + f_build + .finish_with_outputs(out_qbs.into_iter().chain(measure_out)) + .unwrap(); + + let mut registry = test_quantum_extension::REG.clone(); + registry.register(my_ext).unwrap(); + let build_res = module_builder.finish_hugr(®istry); assert_matches!(build_res, Ok(_)); } diff --git a/hugr-core/src/builder/dataflow.rs b/hugr-core/src/builder/dataflow.rs index af2df0ecf..5d0c7e0c2 100644 --- a/hugr-core/src/builder/dataflow.rs +++ b/hugr-core/src/builder/dataflow.rs @@ -325,7 +325,7 @@ pub(crate) mod test { use crate::std_extensions::logic::test::and_op; use crate::types::type_param::TypeParam; use crate::types::{EdgeKind, FuncValueType, RowVariable, Signature, Type, TypeBound, TypeRV}; - use crate::utils::test_quantum_extension::h_gate; + use crate::utils::test_quantum_extension::{self, h_gate}; use crate::{builder::test::n_identity, type_row, Wire}; use super::super::test::simple_dfg_hugr; @@ -342,8 +342,10 @@ pub(crate) mod test { let inner_builder = outer_builder.dfg_builder_endo([(usize_t(), int)])?; let inner_id = n_identity(inner_builder)?; - outer_builder - .finish_prelude_hugr_with_outputs(inner_id.outputs().chain(q_out.outputs())) + outer_builder.finish_hugr_with_outputs( + inner_id.outputs().chain(q_out.outputs()), + &test_quantum_extension::REG, + ) }; assert_eq!(build_result.err(), None); @@ -361,7 +363,7 @@ pub(crate) mod test { f(&mut builder)?; - builder.finish_hugr(&EMPTY_REG) + builder.finish_hugr(&test_quantum_extension::REG) }; assert_matches!(build_result, Ok(_), "Failed on example: {}", msg); @@ -583,7 +585,7 @@ pub(crate) mod test { let add_c = add_c.finish_with_outputs(wires)?; let [w] = add_c.outputs_arr(); - parent.finish_hugr_with_outputs([w], &EMPTY_REG)?; + parent.finish_hugr_with_outputs([w], &test_quantum_extension::REG)?; Ok(()) } diff --git a/hugr-core/src/extension.rs b/hugr-core/src/extension.rs index afd46b11a..97c971822 100644 --- a/hugr-core/src/extension.rs +++ b/hugr-core/src/extension.rs @@ -21,24 +21,26 @@ use crate::types::RowVariable; use crate::types::{check_typevar_decl, CustomType, Substitution, TypeBound, TypeName}; use crate::types::{Signature, TypeNameRef}; +mod const_fold; mod op_def; +pub mod prelude; +pub mod resolution; +pub mod simple_op; +mod type_def; + +pub use const_fold::{fold_out_row, ConstFold, ConstFoldResult, Folder}; pub use op_def::{ CustomSignatureFunc, CustomValidator, LowerFunc, OpDef, SignatureFromArgs, SignatureFunc, ValidateJustArgs, ValidateTypeArgs, }; -mod type_def; -pub use type_def::{TypeDef, TypeDefBound}; -mod const_fold; -pub mod prelude; -pub mod simple_op; -pub use const_fold::{fold_out_row, ConstFold, ConstFoldResult, Folder}; pub use prelude::{PRELUDE, PRELUDE_REGISTRY}; +pub use type_def::{TypeDef, TypeDefBound}; #[cfg(feature = "declarative")] pub mod declarative; /// Extension Registries store extensions to be looked up e.g. during validation. -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, Default, PartialEq)] pub struct ExtensionRegistry(BTreeMap>); impl ExtensionRegistry { @@ -96,14 +98,15 @@ impl ExtensionRegistry { } } - /// Registers a new extension to the registry, keeping most up to date if extension exists. + /// Registers a new extension to the registry, keeping the one most up to + /// date if the extension already exists. /// /// If extension IDs match, the extension with the higher version is kept. - /// If versions match, the original extension is kept. - /// Returns a reference to the registered extension if successful. + /// If versions match, the original extension is kept. Returns a reference + /// to the registered extension if successful. /// - /// Takes an Arc to the extension. To avoid cloning Arcs unless necessary, see - /// [`ExtensionRegistry::register_updated_ref`]. + /// Takes an Arc to the extension. To avoid cloning Arcs unless necessary, + /// see [`ExtensionRegistry::register_updated_ref`]. pub fn register_updated(&mut self, extension: impl Into>) { let extension = extension.into(); match self.0.entry(extension.name().clone()) { @@ -118,8 +121,8 @@ impl ExtensionRegistry { } } - /// Registers a new extension to the registry, keeping most up to date if - /// extension exists. + /// Registers a new extension to the registry, keeping the one most up to + /// date if the extension already exists. /// /// If extension IDs match, the extension with the higher version is kept. /// If versions match, the original extension is kept. Returns a reference @@ -151,8 +154,8 @@ impl ExtensionRegistry { } /// Returns an iterator over the extensions in the registry. - pub fn iter(&self) -> impl Iterator)> { - self.0.iter() + pub fn iter(&self) -> <&Self as IntoIterator>::IntoIter { + self.0.values() } /// Returns an iterator over the extensions ids in the registry. @@ -167,12 +170,38 @@ impl ExtensionRegistry { } impl IntoIterator for ExtensionRegistry { - type Item = (ExtensionId, Arc); + type Item = Arc; - type IntoIter = > as IntoIterator>::IntoIter; + type IntoIter = std::collections::btree_map::IntoValues>; fn into_iter(self) -> Self::IntoIter { - self.0.into_iter() + self.0.into_values() + } +} + +impl<'a> IntoIterator for &'a ExtensionRegistry { + type Item = &'a Arc; + + type IntoIter = std::collections::btree_map::Values<'a, ExtensionId, Arc>; + + fn into_iter(self) -> Self::IntoIter { + self.0.values() + } +} + +impl<'a> Extend<&'a Arc> for ExtensionRegistry { + fn extend>>(&mut self, iter: T) { + for ext in iter { + self.register_updated_ref(ext); + } + } +} + +impl Extend> for ExtensionRegistry { + fn extend>>(&mut self, iter: T) { + for ext in iter { + self.register_updated(ext); + } } } @@ -197,8 +226,13 @@ pub enum SignatureError { #[error("Invalid type arguments for operation")] InvalidTypeArgs, /// The Extension Registry did not contain an Extension referenced by the Signature - #[error("Extension '{0}' not found")] - ExtensionNotFound(ExtensionId), + #[error("Extension '{missing}' not found. Available extensions: {}", + available.iter().map(|e| e.to_string()).collect::>().join(", ") + )] + ExtensionNotFound { + missing: ExtensionId, + available: Vec, + }, /// The Extension was found in the registry, but did not contain the Type(Def) referenced in the Signature #[error("Extension '{exn}' did not contain expected TypeDef '{typ}'")] ExtensionTypeNotFound { exn: ExtensionId, typ: TypeName }, @@ -537,7 +571,7 @@ impl Extension { ExtensionOp::new(op_def.clone(), args, ext_reg) } - // Validates against a registry, which we can assume includes this extension itself. + /// Validates against a registry, which we can assume includes this extension itself. // (TODO deal with the registry itself containing invalid extensions!) fn validate(&self, all_exts: &ExtensionRegistry) -> Result<(), SignatureError> { // We should validate TypeParams of TypeDefs too - https://github.com/CQCL/hugr/issues/624 diff --git a/hugr-core/src/extension/declarative.rs b/hugr-core/src/extension/declarative.rs index 94d557895..1f6361b3e 100644 --- a/hugr-core/src/extension/declarative.rs +++ b/hugr-core/src/extension/declarative.rs @@ -354,12 +354,9 @@ extensions: let new_exts = new_extensions(®, dependencies).collect_vec(); assert_eq!(new_exts.len(), num_declarations); + assert_eq!(new_exts.iter().flat_map(|e| e.types()).count(), num_types); assert_eq!( - new_exts.iter().flat_map(|(_, e)| e.types()).count(), - num_types - ); - assert_eq!( - new_exts.iter().flat_map(|(_, e)| e.operations()).count(), + new_exts.iter().flat_map(|e| e.operations()).count(), num_operations ); Ok(()) @@ -381,12 +378,9 @@ extensions: let new_exts = new_extensions(®, dependencies).collect_vec(); assert_eq!(new_exts.len(), num_declarations); + assert_eq!(new_exts.iter().flat_map(|e| e.types()).count(), num_types); assert_eq!( - new_exts.iter().flat_map(|(_, e)| e.types()).count(), - num_types - ); - assert_eq!( - new_exts.iter().flat_map(|(_, e)| e.operations()).count(), + new_exts.iter().flat_map(|e| e.operations()).count(), num_operations ); Ok(()) @@ -413,8 +407,8 @@ extensions: fn new_extensions<'a>( reg: &'a ExtensionRegistry, dependencies: &'a ExtensionRegistry, - ) -> impl Iterator)> { + ) -> impl Iterator> { reg.iter() - .filter(move |(id, _)| !dependencies.contains(id) && *id != &PRELUDE_ID) + .filter(move |ext| !dependencies.contains(ext.name()) && ext.name() != &PRELUDE_ID) } } diff --git a/hugr-core/src/extension/prelude.rs b/hugr-core/src/extension/prelude.rs index 8985e5002..bb6708a86 100644 --- a/hugr-core/src/extension/prelude.rs +++ b/hugr-core/src/extension/prelude.rs @@ -960,6 +960,7 @@ mod test { use crate::builder::inout_sig; use crate::std_extensions::arithmetic::float_ops::FLOAT_OPS_REGISTRY; use crate::std_extensions::arithmetic::float_types::{float64_type, ConstF64}; + use crate::utils::test_quantum_extension; use crate::{ builder::{endo_sig, DFGBuilder, Dataflow, DataflowHugr}, utils::test_quantum_extension::cx_gate, @@ -1150,7 +1151,8 @@ mod test { .add_dataflow_op(panic_op, [err, q0, q1]) .unwrap() .outputs_arr(); - b.finish_prelude_hugr_with_outputs([q0, q1]).unwrap(); + b.finish_hugr_with_outputs([q0, q1], &test_quantum_extension::REG) + .unwrap(); } #[test] diff --git a/hugr-core/src/extension/resolution.rs b/hugr-core/src/extension/resolution.rs new file mode 100644 index 000000000..93e1b97f9 --- /dev/null +++ b/hugr-core/src/extension/resolution.rs @@ -0,0 +1,103 @@ +//! Utilities for resolving operations and types present in a HUGR, and updating +//! the list of used extensions. See [`crate::Hugr::resolve_extension_defs`]. +//! +//! When listing "used extensions" we only care about _definitional_ extension +//! requirements, i.e., the operations and types that are required to define the +//! HUGR nodes and wire types. This is computed from the union of all extension +//! required across the HUGR. +//! +//! This is distinct from _runtime_ extension requirements, which are defined +//! more granularly in each function signature by the `required_extensions` +//! field. See the `extension_inference` feature and related modules for that. +//! +//! Note: These procedures are only temporary until `hugr-model` is stabilized. +//! Once that happens, hugrs will no longer be directly deserialized using serde +//! but instead will be created by the methods in `crate::import`. As these +//! (will) automatically resolve extensions as the operations are created, +//! we will no longer require this post-facto resolution step. + +mod ops; +mod types; + +pub(crate) use ops::update_op_extensions; +pub(crate) use types::update_op_types_extensions; + +use derive_more::{Display, Error, From}; + +use super::{Extension, ExtensionId, ExtensionRegistry}; +use crate::ops::custom::OpaqueOpError; +use crate::ops::{NamedOp, OpName, OpType}; +use crate::types::TypeName; +use crate::Node; + +/// Errors that can occur during extension resolution. +#[derive(Debug, Display, Clone, Error, From, PartialEq)] +#[non_exhaustive] +pub enum ExtensionResolutionError { + /// Could not resolve an opaque operation to an extension operation. + #[display("Error resolving opaque operation: {_0}")] + #[from] + OpaqueOpError(OpaqueOpError), + /// An operation requires an extension that is not in the given registry. + #[display( + "{op} ({node}) requires extension {missing_extension}, but it could not be found in the extension list used during resolution. The available extensions are: {}", + available_extensions.join(", ") + )] + MissingOpExtension { + /// The node that requires the extension. + node: Node, + /// The operation that requires the extension. + op: OpName, + /// The missing extension + missing_extension: ExtensionId, + /// A list of available extensions. + available_extensions: Vec, + }, + #[display( + "Type {ty} in {node} requires extension {missing_extension}, but it could not be found in the extension list used during resolution. The available extensions are: {}", + available_extensions.join(", ") + )] + /// A type references an extension that is not in the given registry. + MissingTypeExtension { + /// The node that requires the extension. + node: Node, + /// The type that requires the extension. + ty: TypeName, + /// The missing extension + missing_extension: ExtensionId, + /// A list of available extensions. + available_extensions: Vec, + }, +} + +impl ExtensionResolutionError { + /// Create a new error for missing operation extensions. + pub fn missing_op_extension( + node: Node, + op: &OpType, + missing_extension: &ExtensionId, + extensions: &ExtensionRegistry, + ) -> Self { + Self::MissingOpExtension { + node, + op: NamedOp::name(op), + missing_extension: missing_extension.clone(), + available_extensions: extensions.ids().cloned().collect(), + } + } + + /// Create a new error for missing type extensions. + pub fn missing_type_extension( + node: Node, + ty: &TypeName, + missing_extension: &ExtensionId, + extensions: &ExtensionRegistry, + ) -> Self { + Self::MissingTypeExtension { + node, + ty: ty.clone(), + missing_extension: missing_extension.clone(), + available_extensions: extensions.ids().cloned().collect(), + } + } +} diff --git a/hugr-core/src/extension/resolution/ops.rs b/hugr-core/src/extension/resolution/ops.rs new file mode 100644 index 000000000..7c8fbfc37 --- /dev/null +++ b/hugr-core/src/extension/resolution/ops.rs @@ -0,0 +1,92 @@ +//! Resolve `OpaqueOp`s into `ExtensionOp`s and return an operation's required extension. + +use std::sync::Arc; + +use super::{Extension, ExtensionRegistry, ExtensionResolutionError}; +use crate::ops::custom::OpaqueOpError; +use crate::ops::{DataflowOpTrait, ExtensionOp, NamedOp, OpType}; +use crate::Node; + +/// Compute the required extension for an operation. +/// +/// If the op is a [`OpType::OpaqueOp`], replace it with a resolved +/// [`OpType::ExtensionOp`] by looking searching for the operation in the +/// extension registries. +/// +/// If `op` was an opaque or extension operation, the result contains the +/// extension reference that should be added to the hugr's extension registry. +/// +/// # Errors +/// +/// If the serialized opaque resolves to a definition that conflicts with what +/// was serialized. Or if the operation is not found in the registry. +pub(crate) fn update_op_extensions<'e>( + node: Node, + op: &mut OpType, + extensions: &'e ExtensionRegistry, +) -> Result>, ExtensionResolutionError> { + let extension = operation_extension(node, op, extensions)?; + + let OpType::OpaqueOp(opaque) = op else { + return Ok(extension); + }; + + // Fail if the Extension is not in the registry, or if the Extension was + // found but did not have the expected operation. + let extension = extension.expect("OpaqueOp should have an extension"); + let Some(def) = extension.get_op(opaque.op_name()) else { + return Err(OpaqueOpError::OpNotFoundInExtension { + node, + op: opaque.name().clone(), + extension: extension.name().clone(), + available_ops: extension + .operations() + .map(|(name, _)| name.clone()) + .collect(), + } + .into()); + }; + + let ext_op = + ExtensionOp::new_with_cached(def.clone(), opaque.args().to_vec(), opaque, extensions) + .map_err(|e| OpaqueOpError::SignatureError { + node, + name: opaque.name().clone(), + cause: e, + })?; + + if opaque.signature() != ext_op.signature() { + return Err(OpaqueOpError::SignatureMismatch { + node, + extension: opaque.extension().clone(), + op: def.name().clone(), + computed: ext_op.signature().clone(), + stored: opaque.signature().clone(), + } + .into()); + }; + + // Replace the opaque operation with the resolved extension operation. + *op = ext_op.into(); + + Ok(Some(extension)) +} + +/// Returns the extension in the registry required by the operation. +/// +/// If the operation does not require an extension, returns `None`. +fn operation_extension<'e>( + node: Node, + op: &OpType, + extensions: &'e ExtensionRegistry, +) -> Result>, ExtensionResolutionError> { + let Some(ext) = op.extension_id() else { + return Ok(None); + }; + match extensions.get(ext) { + Some(e) => Ok(Some(e)), + None => Err(ExtensionResolutionError::missing_op_extension( + node, op, ext, extensions, + )), + } +} diff --git a/hugr-core/src/extension/resolution/types.rs b/hugr-core/src/extension/resolution/types.rs new file mode 100644 index 000000000..b249a4dca --- /dev/null +++ b/hugr-core/src/extension/resolution/types.rs @@ -0,0 +1,174 @@ +//! Resolve weak links inside `CustomType`s in an optype's signature. + +use std::sync::Arc; + +use super::{ExtensionRegistry, ExtensionResolutionError}; +use crate::ops::OpType; +use crate::types::type_row::TypeRowBase; +use crate::types::{MaybeRV, Signature, SumType, TypeBase, TypeEnum}; +use crate::Node; + +/// Replace the dangling extension pointer in the [`CustomType`]s inside a +/// signature with a valid pointer to the extension in the `extensions` +/// registry. +/// +/// When a pointer is replaced, the extension is added to the +/// `used_extensions` registry and the new type definition is returned. +/// +/// This is a helper function used right after deserializing a Hugr. +pub fn update_op_types_extensions( + node: Node, + op: &mut OpType, + extensions: &ExtensionRegistry, + used_extensions: &mut ExtensionRegistry, +) -> Result<(), ExtensionResolutionError> { + match op { + OpType::ExtensionOp(ext) => { + update_signature_exts(node, ext.signature_mut(), extensions, used_extensions)? + } + OpType::FuncDefn(f) => { + update_signature_exts(node, f.signature.body_mut(), extensions, used_extensions)? + } + OpType::FuncDecl(f) => { + update_signature_exts(node, f.signature.body_mut(), extensions, used_extensions)? + } + OpType::Const(_c) => { + // TODO: Is it OK to assume that `Value::get_type` returns a well-resolved value? + } + OpType::Input(inp) => { + update_type_row_exts(node, &mut inp.types, extensions, used_extensions)? + } + OpType::Output(out) => { + update_type_row_exts(node, &mut out.types, extensions, used_extensions)? + } + OpType::Call(c) => { + update_signature_exts(node, c.func_sig.body_mut(), extensions, used_extensions)?; + update_signature_exts(node, &mut c.instantiation, extensions, used_extensions)?; + } + OpType::CallIndirect(c) => { + update_signature_exts(node, &mut c.signature, extensions, used_extensions)? + } + OpType::LoadConstant(lc) => { + update_type_exts(node, &mut lc.datatype, extensions, used_extensions)? + } + OpType::LoadFunction(lf) => { + update_signature_exts(node, lf.func_sig.body_mut(), extensions, used_extensions)?; + update_signature_exts(node, &mut lf.signature, extensions, used_extensions)?; + } + OpType::DFG(dfg) => { + update_signature_exts(node, &mut dfg.signature, extensions, used_extensions)? + } + OpType::OpaqueOp(op) => { + update_signature_exts(node, op.signature_mut(), extensions, used_extensions)? + } + OpType::Tag(t) => { + for variant in t.variants.iter_mut() { + update_type_row_exts(node, variant, extensions, used_extensions)? + } + } + OpType::DataflowBlock(db) => { + update_type_row_exts(node, &mut db.inputs, extensions, used_extensions)?; + update_type_row_exts(node, &mut db.other_outputs, extensions, used_extensions)?; + for row in db.sum_rows.iter_mut() { + update_type_row_exts(node, row, extensions, used_extensions)?; + } + } + OpType::ExitBlock(e) => { + update_type_row_exts(node, &mut e.cfg_outputs, extensions, used_extensions)?; + } + OpType::TailLoop(tl) => { + update_type_row_exts(node, &mut tl.just_inputs, extensions, used_extensions)?; + update_type_row_exts(node, &mut tl.just_outputs, extensions, used_extensions)?; + update_type_row_exts(node, &mut tl.rest, extensions, used_extensions)?; + } + OpType::CFG(cfg) => { + update_signature_exts(node, &mut cfg.signature, extensions, used_extensions)?; + } + OpType::Conditional(cond) => { + for row in cond.sum_rows.iter_mut() { + update_type_row_exts(node, row, extensions, used_extensions)?; + } + update_type_row_exts(node, &mut cond.other_inputs, extensions, used_extensions)?; + update_type_row_exts(node, &mut cond.outputs, extensions, used_extensions)?; + } + OpType::Case(case) => { + update_signature_exts(node, &mut case.signature, extensions, used_extensions)?; + } + // Ignore optypes that do not store a signature. + _ => {} + } + Ok(()) +} + +/// Update all weak Extension pointers in the [`CustomType`]s inside a signature. +/// +/// Adds the extensions used in the signature to the `used_extensions` registry. +fn update_signature_exts( + node: Node, + signature: &mut Signature, + extensions: &ExtensionRegistry, + used_extensions: &mut ExtensionRegistry, +) -> Result<(), ExtensionResolutionError> { + // Note that we do not include the signature's `extension_reqs` here, as those refer + // to _runtime_ requirements that may not be currently present. + // See https://github.com/CQCL/hugr/issues/1734 + // TODO: Update comment once that issue gets implemented. + update_type_row_exts(node, &mut signature.input, extensions, used_extensions)?; + update_type_row_exts(node, &mut signature.output, extensions, used_extensions)?; + Ok(()) +} + +/// Update all weak Extension pointers in the [`CustomType`]s inside a type row. +/// +/// Adds the extensions used in the row to the `used_extensions` registry. +fn update_type_row_exts( + node: Node, + row: &mut TypeRowBase, + extensions: &ExtensionRegistry, + used_extensions: &mut ExtensionRegistry, +) -> Result<(), ExtensionResolutionError> { + for ty in row.iter_mut() { + update_type_exts(node, ty, extensions, used_extensions)?; + } + Ok(()) +} + +/// Update all weak Extension pointers in the [`CustomType`]s inside a type. +/// +/// Adds the extensions used in the type to the `used_extensions` registry. +fn update_type_exts( + node: Node, + typ: &mut TypeBase, + extensions: &ExtensionRegistry, + used_extensions: &mut ExtensionRegistry, +) -> Result<(), ExtensionResolutionError> { + match typ.as_type_enum_mut() { + TypeEnum::Extension(custom) => { + let ext_id = custom.extension(); + let ext = extensions.get(ext_id).ok_or_else(|| { + ExtensionResolutionError::missing_type_extension( + node, + custom.name(), + ext_id, + extensions, + ) + })?; + + // Add the extension to the used extensions registry, + // and update the CustomType with the valid pointer. + used_extensions.register_updated_ref(ext); + custom.update_extension(Arc::downgrade(ext)); + } + TypeEnum::Function(f) => { + update_type_row_exts(node, &mut f.input, extensions, used_extensions)?; + update_type_row_exts(node, &mut f.output, extensions, used_extensions)?; + } + TypeEnum::Sum(SumType::General { rows }) => { + for row in rows.iter_mut() { + update_type_row_exts(node, row, extensions, used_extensions)?; + } + } + _ => {} + } + Ok(()) +} diff --git a/hugr-core/src/hugr.rs b/hugr-core/src/hugr.rs index 9aa5d080f..dc8321f2f 100644 --- a/hugr-core/src/hugr.rs +++ b/hugr-core/src/hugr.rs @@ -25,8 +25,10 @@ use thiserror::Error; pub use self::views::{HugrView, RootTagged}; use crate::core::NodeIndex; +use crate::extension::resolution::{ + update_op_extensions, update_op_types_extensions, ExtensionResolutionError, +}; use crate::extension::{ExtensionRegistry, ExtensionSet, TO_BE_INFERRED}; -use crate::ops::custom::resolve_extension_ops; use crate::ops::{OpTag, OpTrait}; pub use crate::ops::{OpType, DEFAULT_OPTYPE}; use crate::{Direction, Node}; @@ -90,7 +92,7 @@ impl Hugr { &mut self, extension_registry: &ExtensionRegistry, ) -> Result<(), ValidationError> { - resolve_extension_ops(self, extension_registry)?; + self.resolve_extension_defs(extension_registry)?; self.validate_no_extensions(extension_registry)?; #[cfg(feature = "extension_inference")] { @@ -170,6 +172,67 @@ impl Hugr { infer(self, self.root(), remove)?; Ok(()) } + + /// Given a Hugr that has been deserialized, collect all extensions used to + /// define the HUGR while resolving all [`OpType::OpaqueOp`] operations into + /// [`OpType::ExtensionOp`]s and updating the extension pointer in all + /// internal [`crate::types::CustomType`]s to point to the extensions in the + /// register. + /// + /// When listing "used extensions" we only care about _definitional_ + /// extension requirements, i.e., the operations and types that are required + /// to define the HUGR nodes and wire types. This is computed from the union + /// of all extension required across the HUGR. + /// + /// This is distinct from _runtime_ extension requirements computed in + /// [`Hugr::infer_extensions`], which are computed more granularly in each + /// function signature by the `required_extensions` field and define the set + /// of capabilities required by the runtime to execute each function. + /// + /// Returns a new extension registry with the extensions used in the Hugr. + /// + /// # Parameters + /// + /// - `extensions`: The extension set considered when resolving opaque + /// operations and types. The original Hugr's internal extension + /// registry is ignored and replaced with the newly computed one. + /// + /// # Errors + /// + /// - If an opaque operation cannot be resolved to an extension operation. + /// - If an extension operation references an extension that is missing from + /// the registry. + /// - If a custom type references an extension that is missing from the + /// registry. + pub fn resolve_extension_defs( + &mut self, + extensions: &ExtensionRegistry, + ) -> Result { + let mut used_extensions = ExtensionRegistry::default(); + + // Here we need to iterate the optypes in the hugr mutably, to avoid + // having to clone and accumulate all replacements before finally + // applying them. + // + // This is not something we want to expose it the API, so we manually + // iterate instead of writing it as a method. + for n in 0..self.node_count() { + let pg_node = portgraph::NodeIndex::new(n); + let node: Node = pg_node.into(); + if !self.contains_node(node) { + continue; + } + + let op = &mut self.op_types[pg_node]; + + if let Some(extension) = update_op_extensions(node, op, extensions)? { + used_extensions.register_updated_ref(extension); + } + update_op_types_extensions(node, op, extensions, &mut used_extensions)?; + } + + Ok(used_extensions) + } } /// Internal API for HUGRs, not intended for use by users. diff --git a/hugr-core/src/hugr/rewrite/insert_identity.rs b/hugr-core/src/hugr/rewrite/insert_identity.rs index cacd59591..0c1dc872a 100644 --- a/hugr-core/src/hugr/rewrite/insert_identity.rs +++ b/hugr-core/src/hugr/rewrite/insert_identity.rs @@ -101,10 +101,8 @@ mod tests { use super::super::simple_replace::test::dfg_hugr; use super::*; - use crate::{ - extension::{prelude::qb_t, PRELUDE_REGISTRY}, - Hugr, - }; + use crate::utils::test_quantum_extension; + use crate::{extension::prelude::qb_t, Hugr}; #[rstest] fn correct_insertion(dfg_hugr: Hugr) { @@ -129,6 +127,6 @@ mod tests { assert_eq!(noop, Noop(qb_t())); - h.update_validate(&PRELUDE_REGISTRY).unwrap(); + h.update_validate(&test_quantum_extension::REG).unwrap(); } } diff --git a/hugr-core/src/hugr/rewrite/replace.rs b/hugr-core/src/hugr/rewrite/replace.rs index ffdbb0743..3ccb4db80 100644 --- a/hugr-core/src/hugr/rewrite/replace.rs +++ b/hugr-core/src/hugr/rewrite/replace.rs @@ -462,7 +462,7 @@ mod test { use crate::ops::{self, Case, DataflowBlock, OpTag, OpType, DFG}; use crate::std_extensions::collections::{self, list_type, ListOp}; use crate::types::{Signature, Type, TypeRow}; - use crate::utils::depth; + use crate::utils::{depth, test_quantum_extension}; use crate::{type_row, Direction, Extension, Hugr, HugrView, OutgoingPort}; use super::{NewEdgeKind, NewEdgeSpec, ReplaceError, Replacement}; @@ -657,6 +657,8 @@ mod test { let baz = ext .instantiate_extension_op("baz", [], &PRELUDE_REGISTRY) .unwrap(); + let mut registry = test_quantum_extension::REG.clone(); + registry.register(ext).unwrap(); let mut h = DFGBuilder::new( Signature::new(vec![usize_t(), bool_t()], vec![usize_t()]) @@ -688,7 +690,7 @@ mod test { let case2 = case2.finish_with_outputs(baz_dfg.outputs()).unwrap().node(); let cond = cond.finish_sub_container().unwrap(); let h = h - .finish_hugr_with_outputs(cond.outputs(), &PRELUDE_REGISTRY) + .finish_hugr_with_outputs(cond.outputs(), ®istry) .unwrap(); let mut r_hugr = Hugr::new(h.get_optype(cond.node()).clone()); diff --git a/hugr-core/src/hugr/rewrite/simple_replace.rs b/hugr-core/src/hugr/rewrite/simple_replace.rs index 422875275..e5dc42841 100644 --- a/hugr-core/src/hugr/rewrite/simple_replace.rs +++ b/hugr-core/src/hugr/rewrite/simple_replace.rs @@ -243,7 +243,7 @@ pub(in crate::hugr::rewrite) mod test { DataflowSubContainer, HugrBuilder, ModuleBuilder, }; use crate::extension::prelude::{bool_t, qb_t}; - use crate::extension::{ExtensionSet, EMPTY_REG, PRELUDE_REGISTRY}; + use crate::extension::{ExtensionSet, PRELUDE_REGISTRY}; use crate::hugr::views::{HugrView, SiblingSubgraph}; use crate::hugr::{Hugr, HugrMut, Rewrite}; use crate::ops::dataflow::DataflowOpTrait; @@ -253,7 +253,7 @@ pub(in crate::hugr::rewrite) mod test { use crate::std_extensions::logic::test::and_op; use crate::std_extensions::logic::LogicOp; use crate::types::{Signature, Type}; - use crate::utils::test_quantum_extension::{cx_gate, h_gate, EXTENSION_ID}; + use crate::utils::test_quantum_extension::{self, cx_gate, h_gate, EXTENSION_ID}; use crate::{IncomingPort, Node}; use super::SimpleReplacement; @@ -297,7 +297,7 @@ pub(in crate::hugr::rewrite) mod test { func_builder.finish_with_outputs(inner_graph.outputs().chain(q_out.outputs()))? }; - Ok(module_builder.finish_prelude_hugr()?) + Ok(module_builder.finish_hugr(&test_quantum_extension::REG)?) } #[fixture] @@ -317,7 +317,7 @@ pub(in crate::hugr::rewrite) mod test { let wire3 = dfg_builder.add_dataflow_op(h_gate(), vec![wire1])?; let wire45 = dfg_builder.add_dataflow_op(cx_gate(), wire2.outputs().chain(wire3.outputs()))?; - dfg_builder.finish_prelude_hugr_with_outputs(wire45.outputs()) + dfg_builder.finish_hugr_with_outputs(wire45.outputs(), &test_quantum_extension::REG) } #[fixture] @@ -337,7 +337,7 @@ pub(in crate::hugr::rewrite) mod test { let wire2 = dfg_builder.add_dataflow_op(h_gate(), vec![wire1])?; let wire2out = wire2.outputs().exactly_one().unwrap(); let wireoutvec = vec![wire0, wire2out]; - dfg_builder.finish_prelude_hugr_with_outputs(wireoutvec) + dfg_builder.finish_hugr_with_outputs(wireoutvec, &test_quantum_extension::REG) } #[fixture] @@ -372,7 +372,7 @@ pub(in crate::hugr::rewrite) mod test { ( dfg_builder - .finish_prelude_hugr_with_outputs([b0, b1]) + .finish_hugr_with_outputs([b0, b1], &test_quantum_extension::REG) .unwrap(), vec![not_inp.node(), not_0.node(), not_1.node()], ) @@ -404,7 +404,7 @@ pub(in crate::hugr::rewrite) mod test { ( dfg_builder - .finish_prelude_hugr_with_outputs([b0, b1]) + .finish_hugr_with_outputs([b0, b1], &test_quantum_extension::REG) .unwrap(), vec![not_inp.node(), not_0.node()], ) @@ -489,7 +489,7 @@ pub(in crate::hugr::rewrite) mod test { // ├───┤├───┤┌─┴─┐ // ┤ H ├┤ H ├┤ X ├ // └───┘└───┘└───┘ - assert_eq!(h.update_validate(&PRELUDE_REGISTRY), Ok(())); + assert_eq!(h.update_validate(&test_quantum_extension::REG), Ok(())); } #[rstest] @@ -561,7 +561,7 @@ pub(in crate::hugr::rewrite) mod test { // ├───┤├───┤┌───┐ // ┤ H ├┤ H ├┤ H ├ // └───┘└───┘└───┘ - assert_eq!(h.update_validate(&PRELUDE_REGISTRY), Ok(())); + assert_eq!(h.update_validate(&test_quantum_extension::REG), Ok(())); } #[test] @@ -573,7 +573,9 @@ pub(in crate::hugr::rewrite) mod test { circ.append(cx_gate(), [1, 0]).unwrap(); let wires = circ.finish(); let [input, output] = builder.io(); - let mut h = builder.finish_prelude_hugr_with_outputs(wires).unwrap(); + let mut h = builder + .finish_hugr_with_outputs(wires, &test_quantum_extension::REG) + .unwrap(); let replacement = h.clone(); let orig = h.clone(); @@ -632,13 +634,17 @@ pub(in crate::hugr::rewrite) mod test { .unwrap() .outputs(); let [input, _] = builder.io(); - let mut h = builder.finish_hugr_with_outputs(outw, &EMPTY_REG).unwrap(); + let mut h = builder + .finish_hugr_with_outputs(outw, &test_quantum_extension::REG) + .unwrap(); let mut builder = DFGBuilder::new(inout_sig(two_bit, one_bit)).unwrap(); let inw = builder.input_wires(); let outw = builder.add_dataflow_op(and_op(), inw).unwrap().outputs(); let [repl_input, repl_output] = builder.io(); - let repl = builder.finish_hugr_with_outputs(outw, &EMPTY_REG).unwrap(); + let repl = builder + .finish_hugr_with_outputs(outw, &test_quantum_extension::REG) + .unwrap(); let orig = h.clone(); @@ -747,7 +753,8 @@ pub(in crate::hugr::rewrite) mod test { let not = b.add_dataflow_op(LogicOp::Not, vec![w]).unwrap(); let [w_not] = not.outputs_arr(); ( - b.finish_prelude_hugr_with_outputs([w, w_not]).unwrap(), + b.finish_hugr_with_outputs([w, w_not], &test_quantum_extension::REG) + .unwrap(), not.node(), ) }; diff --git a/hugr-core/src/hugr/serialize/test.rs b/hugr-core/src/hugr/serialize/test.rs index d5fdd1858..dba525d39 100644 --- a/hugr-core/src/hugr/serialize/test.rs +++ b/hugr-core/src/hugr/serialize/test.rs @@ -8,10 +8,12 @@ use crate::builder::{ use crate::extension::prelude::Noop; use crate::extension::prelude::{bool_t, qb_t, usize_t, PRELUDE_ID}; use crate::extension::simple_op::MakeRegisteredOp; +use crate::extension::ExtensionId; use crate::extension::{test::SimpleOpDef, ExtensionSet, EMPTY_REG, PRELUDE_REGISTRY}; use crate::hugr::internal::HugrMutInternals; use crate::hugr::validate::ValidationError; -use crate::ops::custom::{ExtensionOp, OpaqueOp, OpaqueOpError}; +use crate::hugr::ExtensionResolutionError; +use crate::ops::custom::{ExtensionOp, OpaqueOp}; use crate::ops::{self, dataflow::IOTrait, Input, Module, Output, Value, DFG}; use crate::std_extensions::arithmetic::float_types::float64_type; use crate::std_extensions::arithmetic::int_ops::INT_OPS_REGISTRY; @@ -22,6 +24,7 @@ use crate::types::{ FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature, SumType, Type, TypeArg, TypeBound, TypeRV, }; +use crate::utils::test_quantum_extension; use crate::{type_row, OutgoingPort}; use itertools::Itertools; @@ -331,7 +334,7 @@ fn dfg_roundtrip() -> Result<(), Box> { .unwrap() .out_wire(0); } - let hugr = dfg.finish_hugr_with_outputs(params, &EMPTY_REG)?; + let hugr = dfg.finish_hugr_with_outputs(params, &test_quantum_extension::REG)?; check_hugr_roundtrip(&hugr, true); Ok(()) @@ -350,7 +353,7 @@ fn extension_ops() -> Result<(), Box> { .unwrap() .out_wire(0); - let hugr = dfg.finish_hugr_with_outputs([wire], &PRELUDE_REGISTRY)?; + let hugr = dfg.finish_hugr_with_outputs([wire], &test_quantum_extension::REG)?; check_hugr_roundtrip(&hugr, true); Ok(()) @@ -368,6 +371,7 @@ fn opaque_ops() -> Result<(), Box> { .add_dataflow_op(extension_op.clone(), [wire]) .unwrap() .out_wire(0); + let not_node = wire.node(); // Add an unresolved opaque operation let opaque_op: OpaqueOp = extension_op.into(); @@ -376,11 +380,14 @@ fn opaque_ops() -> Result<(), Box> { assert_eq!( dfg.finish_hugr_with_outputs([wire], &PRELUDE_REGISTRY), - Err(ValidationError::OpaqueOpError(OpaqueOpError::UnresolvedOp( - wire.node(), - "Not".into(), - ext_name - )) + Err(ValidationError::ExtensionResolutionError( + ExtensionResolutionError::MissingOpExtension { + node: not_node, + op: "logic.Not".into(), + missing_extension: ext_name, + available_extensions: vec![ExtensionId::new("prelude").unwrap()] + } + ) .into()) ); @@ -554,7 +561,7 @@ fn roundtrip_optype(#[case] optype: impl Into + std::fmt::Debug) { // test all standard extension serialisations are valid against scheme fn std_extensions_valid() { let std_reg = crate::std_extensions::std_reg(); - for (_, ext) in std_reg.into_iter() { + for ext in std_reg { let val = serde_json::to_value(ext).unwrap(); NamedSchema::check_schemas(&val, get_schemas(true)); // check deserialises correctly, can't check equality because of custom binaries. diff --git a/hugr-core/src/hugr/serialize/upgrade/test.rs b/hugr-core/src/hugr/serialize/upgrade/test.rs index d08dea520..e9837ddc0 100644 --- a/hugr-core/src/hugr/serialize/upgrade/test.rs +++ b/hugr-core/src/hugr/serialize/upgrade/test.rs @@ -4,6 +4,7 @@ use crate::{ hugr::serialize::test::check_hugr_deserialize, std_extensions::logic::LogicOp, types::Signature, + utils::test_quantum_extension, }; use lazy_static::lazy_static; use std::{ @@ -50,7 +51,7 @@ pub fn hugr_with_named_op() -> Hugr { let [a, b] = builder.input_wires_arr(); let x = builder.add_dataflow_op(LogicOp::And, [a, b]).unwrap(); builder - .finish_prelude_hugr_with_outputs(x.outputs()) + .finish_hugr_with_outputs(x.outputs(), &test_quantum_extension::REG) .unwrap() } diff --git a/hugr-core/src/hugr/validate.rs b/hugr-core/src/hugr/validate.rs index 02d7a0fc6..6c9c92753 100644 --- a/hugr-core/src/hugr/validate.rs +++ b/hugr-core/src/hugr/validate.rs @@ -9,12 +9,13 @@ use petgraph::visit::{Topo, Walker}; use portgraph::{LinkView, PortView}; use thiserror::Error; +use crate::extension::resolution::ExtensionResolutionError; use crate::extension::{ExtensionRegistry, SignatureError, TO_BE_INFERRED}; use crate::ops::constant::ConstTypeError; use crate::ops::custom::{ExtensionOp, OpaqueOpError}; use crate::ops::validate::{ChildrenEdgeData, ChildrenValidationError, EdgeValidationError}; -use crate::ops::{FuncDefn, OpParent, OpTag, OpTrait, OpType, ValidateOp}; +use crate::ops::{FuncDefn, NamedOp, OpName, OpParent, OpTag, OpTrait, OpType, ValidateOp}; use crate::types::type_param::TypeParam; use crate::types::{EdgeKind, Signature}; use crate::{Direction, Hugr, Node, Port}; @@ -259,7 +260,11 @@ impl<'a, 'b> ValidationContext<'a, 'b> { } self.validate_port_kind(&port_kind, var_decls) - .map_err(|cause| ValidationError::SignatureError { node, cause })?; + .map_err(|cause| ValidationError::SignatureError { + node, + op: op_type.name(), + cause, + })?; let mut link_cnt = 0; for (_, link) in links { @@ -579,7 +584,11 @@ impl<'a, 'b> ValidationContext<'a, 'b> { ext_op .def() .validate_args(ext_op.args(), self.extension_registry, var_decls) - .map_err(|cause| ValidationError::SignatureError { node, cause }) + .map_err(|cause| ValidationError::SignatureError { + node, + op: op_type.name(), + cause, + }) }; match op_type { OpType::ExtensionOp(ext_op) => validate_ext(ext_op)?, @@ -591,12 +600,22 @@ impl<'a, 'b> ValidationContext<'a, 'b> { ))?; } OpType::Call(c) => { - c.validate(self.extension_registry) - .map_err(|cause| ValidationError::SignatureError { node, cause })?; + c.validate(self.extension_registry).map_err(|cause| { + ValidationError::SignatureError { + node, + op: op_type.name(), + cause, + } + })?; } OpType::LoadFunction(c) => { - c.validate(self.extension_registry) - .map_err(|cause| ValidationError::SignatureError { node, cause })?; + c.validate(self.extension_registry).map_err(|cause| { + ValidationError::SignatureError { + node, + op: op_type.name(), + cause, + } + })?; } _ => (), } @@ -738,9 +757,10 @@ pub enum ValidationError { #[error("Node {node} needs a concrete ExtensionSet - inference will provide this for Case/CFG/Conditional/DataflowBlock/DFG/TailLoop only")] ExtensionsNotInferred { node: Node }, /// Error in a node signature - #[error("Error in signature of node {node}: {cause}")] + #[error("Error in signature of operation {op} at {node}: {cause}")] SignatureError { node: Node, + op: OpName, #[source] cause: SignatureError, }, @@ -757,6 +777,11 @@ pub enum ValidationError { /// [Type]: crate::types::Type #[error(transparent)] ConstTypeError(#[from] ConstTypeError), + /// Some operations or types in the HUGR reference invalid extensions. + // + // TODO: Remove once `hugr::update_validate` is removed. + #[error(transparent)] + ExtensionResolutionError(#[from] ExtensionResolutionError), } /// Errors related to the inter-graph edge validations. diff --git a/hugr-core/src/hugr/validate/test.rs b/hugr-core/src/hugr/validate/test.rs index 9be6103a5..e98744a82 100644 --- a/hugr-core/src/hugr/validate/test.rs +++ b/hugr-core/src/hugr/validate/test.rs @@ -26,6 +26,7 @@ use crate::types::{ CustomType, FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature, Type, TypeBound, TypeRV, TypeRow, }; +use crate::utils::test_quantum_extension; use crate::{ const_extension_ids, test_file, type_row, Direction, IncomingPort, Node, OutgoingPort, }; @@ -127,13 +128,13 @@ fn dfg_root() { let mut b = Hugr::new(dfg_op); let root = b.root(); add_df_children(&mut b, root, 1); - assert_eq!(b.update_validate(&EMPTY_REG), Ok(())); + assert_eq!(b.update_validate(&test_quantum_extension::REG), Ok(())); } #[test] fn simple_hugr() { let mut b = make_simple_hugr(2).0; - assert_eq!(b.update_validate(&EMPTY_REG), Ok(())); + assert_eq!(b.update_validate(&test_quantum_extension::REG), Ok(())); } #[test] @@ -158,7 +159,7 @@ fn children_restrictions() { }, ); assert_matches!( - b.update_validate(&EMPTY_REG), + b.update_validate(&test_quantum_extension::REG), Err(ValidationError::ContainerWithoutChildren { node, .. }) => assert_eq!(node, new_def) ); @@ -166,7 +167,7 @@ fn children_restrictions() { add_df_children(&mut b, new_def, 2); b.set_parent(new_def, copy); assert_matches!( - b.update_validate(&EMPTY_REG), + b.update_validate(&test_quantum_extension::REG), Err(ValidationError::NonContainerWithChildren { node, .. }) => assert_eq!(node, copy) ); b.set_parent(new_def, root); @@ -175,7 +176,7 @@ fn children_restrictions() { // add an input node to the module subgraph let new_input = b.add_node_with_parent(root, ops::Input::new(type_row![])); assert_matches!( - b.validate(&EMPTY_REG), + b.validate(&test_quantum_extension::REG), Err(ValidationError::InvalidParentOp { parent, child, .. }) => {assert_eq!(parent, root); assert_eq!(child, new_input)} ); } @@ -194,7 +195,7 @@ fn df_children_restrictions() { // Replace the output operation of the df subgraph with a copy b.replace_op(output, Noop(usize_t())).unwrap(); assert_matches!( - b.validate(&EMPTY_REG), + b.validate(&test_quantum_extension::REG), Err(ValidationError::InvalidInitialChild { parent, .. }) => assert_eq!(parent, def) ); @@ -202,7 +203,7 @@ fn df_children_restrictions() { b.replace_op(output, ops::Output::new(vec![bool_t()])) .unwrap(); assert_matches!( - b.validate(&EMPTY_REG), + b.validate(&test_quantum_extension::REG), Err(ValidationError::InvalidChildren { parent, source: ChildrenValidationError::IOSignatureMismatch { child, .. }, .. }) => {assert_eq!(parent, def); assert_eq!(child, output.pg_index())} ); @@ -213,7 +214,7 @@ fn df_children_restrictions() { b.replace_op(copy, ops::Output::new(vec![bool_t(), bool_t()])) .unwrap(); assert_matches!( - b.validate(&EMPTY_REG), + b.validate(&test_quantum_extension::REG), Err(ValidationError::InvalidChildren { parent, source: ChildrenValidationError::InternalIOChildren { child, .. }, .. }) => {assert_eq!(parent, def); assert_eq!(child, copy.pg_index())} ); @@ -248,20 +249,20 @@ fn test_ext_edge() { h.connect(sub_dfg, 0, output, 0); assert_matches!( - h.update_validate(&EMPTY_REG), + h.update_validate(&test_quantum_extension::REG), Err(ValidationError::UnconnectedPort { .. }) ); h.connect(input, 1, sub_op, 1); assert_matches!( - h.update_validate(&EMPTY_REG), + h.update_validate(&test_quantum_extension::REG), Err(ValidationError::InterGraphEdgeError( InterGraphEdgeError::MissingOrderEdge { .. } )) ); //Order edge. This will need metadata indicating its purpose. h.add_other_edge(input, sub_dfg); - h.update_validate(&EMPTY_REG).unwrap(); + h.update_validate(&test_quantum_extension::REG).unwrap(); } #[test] @@ -277,7 +278,7 @@ fn no_ext_edge_into_func() -> Result<(), Box> { let func = func.finish_with_outputs(and_op.outputs())?; let loadfn = dfg.load_func(func.handle(), &[], &EMPTY_REG)?; let dfg = dfg.finish_with_outputs([loadfn])?; - let res = h.finish_hugr_with_outputs(dfg.outputs(), &EMPTY_REG); + let res = h.finish_hugr_with_outputs(dfg.outputs(), &test_quantum_extension::REG); assert_eq!( res, Err(BuildError::InvalidHUGR( @@ -302,7 +303,7 @@ fn test_local_const() { h.connect(input, 0, and, 0); h.connect(and, 0, output, 0); assert_eq!( - h.update_validate(&EMPTY_REG), + h.update_validate(&test_quantum_extension::REG), Err(ValidationError::UnconnectedPort { node: and, port: IncomingPort::from(1).into(), @@ -323,7 +324,7 @@ fn test_local_const() { h.connect(lcst, 0, and, 1); assert_eq!(h.static_source(lcst), Some(cst)); // There is no edge from Input to LoadConstant, but that's OK: - h.update_validate(&EMPTY_REG).unwrap(); + h.update_validate(&test_quantum_extension::REG).unwrap(); } #[test] @@ -339,7 +340,10 @@ fn dfg_with_cycles() { h.connect(input, 1, not2, 0); h.connect(not2, 0, output, 0); // The graph contains a cycle: - assert_matches!(h.validate(&EMPTY_REG), Err(ValidationError::NotADag { .. })); + assert_matches!( + h.validate(&test_quantum_extension::REG), + Err(ValidationError::NotADag { .. }) + ); } fn identity_hugr_with_type(t: Type) -> (Hugr, Node) { @@ -361,15 +365,12 @@ fn identity_hugr_with_type(t: Type) -> (Hugr, Node) { } #[test] fn unregistered_extension() { - let (mut h, def) = identity_hugr_with_type(usize_t()); - assert_eq!( + let (mut h, _def) = identity_hugr_with_type(usize_t()); + assert_matches!( h.validate(&EMPTY_REG), - Err(ValidationError::SignatureError { - node: def, - cause: SignatureError::ExtensionNotFound(PRELUDE.name.clone()) - }) + Err(ValidationError::SignatureError { .. }) ); - h.update_validate(&PRELUDE_REGISTRY).unwrap(); + h.update_validate(&test_quantum_extension::REG).unwrap(); } const_extension_ids! { @@ -392,7 +393,7 @@ fn invalid_types() { let validate_to_sig_error = |t: CustomType| { let (h, def) = identity_hugr_with_type(Type::new_extension(t)); match h.validate(®) { - Err(ValidationError::SignatureError { node, cause }) if node == def => cause, + Err(ValidationError::SignatureError { node, cause, .. }) if node == def => cause, e => panic!( "Expected SignatureError at def node, got {}", match e { @@ -831,7 +832,7 @@ fn cfg_children_restrictions() { .unwrap(); // Write Extension annotations into the Hugr while it's still well-formed // enough for us to compute them - b.validate(&EMPTY_REG).unwrap(); + b.validate(&test_quantum_extension::REG).unwrap(); b.replace_op( copy, ops::CFG { @@ -840,7 +841,7 @@ fn cfg_children_restrictions() { ) .unwrap(); assert_matches!( - b.validate(&EMPTY_REG), + b.validate(&test_quantum_extension::REG), Err(ValidationError::ContainerWithoutChildren { .. }) ); let cfg = copy; @@ -876,7 +877,7 @@ fn cfg_children_restrictions() { }, ); b.add_other_edge(block, exit); - assert_eq!(b.update_validate(&EMPTY_REG), Ok(())); + assert_eq!(b.update_validate(&test_quantum_extension::REG), Ok(())); // Test malformed errors @@ -888,7 +889,7 @@ fn cfg_children_restrictions() { }, ); assert_matches!( - b.validate(&EMPTY_REG), + b.validate(&test_quantum_extension::REG), Err(ValidationError::InvalidChildren { parent, source: ChildrenValidationError::InternalExitChildren { child, .. }, .. }) => {assert_eq!(parent, cfg); assert_eq!(child, exit2.pg_index())} ); @@ -923,7 +924,7 @@ fn cfg_children_restrictions() { ) .unwrap(); assert_matches!( - b.validate(&EMPTY_REG), + b.validate(&test_quantum_extension::REG), Err(ValidationError::InvalidEdges { parent, source: EdgeValidationError::CFGEdgeSignatureMismatch { .. }, .. }) => assert_eq!(parent, cfg) ); diff --git a/hugr-core/src/hugr/views/descendants.rs b/hugr-core/src/hugr/views/descendants.rs index c1aa7da1f..f7b893ddf 100644 --- a/hugr-core/src/hugr/views/descendants.rs +++ b/hugr-core/src/hugr/views/descendants.rs @@ -176,7 +176,7 @@ pub(super) mod test { use rstest::rstest; use crate::extension::prelude::{qb_t, usize_t}; - use crate::extension::PRELUDE_REGISTRY; + use crate::utils::test_quantum_extension; use crate::IncomingPort; use crate::{ builder::{Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder}, @@ -214,7 +214,7 @@ pub(super) mod test { func_builder.finish_with_outputs(inner_id.outputs().chain(q_out.outputs()))?; (f_id, inner_id) }; - let hugr = module_builder.finish_prelude_hugr()?; + let hugr = module_builder.finish_hugr(&test_quantum_extension::REG)?; Ok((hugr, f_id.handle().node(), inner_id.handle().node())) } @@ -291,7 +291,7 @@ pub(super) mod test { let region: DescendantsGraph = DescendantsGraph::try_new(&hugr, def)?; let extracted = region.extract_hugr(); - extracted.validate(&PRELUDE_REGISTRY)?; + extracted.validate(&test_quantum_extension::REG)?; let region: DescendantsGraph = DescendantsGraph::try_new(&hugr, def)?; diff --git a/hugr-core/src/hugr/views/sibling.rs b/hugr-core/src/hugr/views/sibling.rs index 6fb4246d0..b9710f00a 100644 --- a/hugr-core/src/hugr/views/sibling.rs +++ b/hugr-core/src/hugr/views/sibling.rs @@ -337,12 +337,11 @@ mod test { use crate::builder::test::simple_dfg_hugr; use crate::builder::{Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder}; use crate::extension::prelude::{qb_t, usize_t}; - use crate::extension::PRELUDE_REGISTRY; use crate::ops::handle::{CfgID, DataflowParentID, DfgID, FuncID}; use crate::ops::{dataflow::IOTrait, Input, OpTag, Output}; use crate::ops::{OpTrait, OpType}; use crate::types::Signature; - use crate::utils::test_quantum_extension::EXTENSION_ID; + use crate::utils::test_quantum_extension::{self, EXTENSION_ID}; use crate::IncomingPort; use super::super::descendants::test::make_module_hgr; @@ -457,7 +456,7 @@ mod test { let ins = dfg.input_wires(); let sub_dfg = dfg.finish_with_outputs(ins)?; let fun = fbuild.finish_with_outputs(sub_dfg.outputs())?; - let h = module_builder.finish_hugr(&PRELUDE_REGISTRY)?; + let h = module_builder.finish_hugr(&test_quantum_extension::REG)?; let sub_dfg = sub_dfg.node(); // We can create a view from a child or grandchild of a hugr: @@ -486,7 +485,9 @@ mod test { /// Mutate a SiblingMut wrapper #[rstest] fn flat_mut(mut simple_dfg_hugr: Hugr) { - simple_dfg_hugr.update_validate(&PRELUDE_REGISTRY).unwrap(); + simple_dfg_hugr + .update_validate(&test_quantum_extension::REG) + .unwrap(); let root = simple_dfg_hugr.root(); let signature = simple_dfg_hugr.inner_function_type().unwrap().clone(); @@ -511,7 +512,9 @@ mod test { // In contrast, performing this on the Hugr (where the allowed root type is 'Any') is only detected by validation simple_dfg_hugr.replace_op(root, bad_nodetype).unwrap(); - assert!(simple_dfg_hugr.validate(&PRELUDE_REGISTRY).is_err()); + assert!(simple_dfg_hugr + .validate(&test_quantum_extension::REG) + .is_err()); } #[rstest] @@ -540,7 +543,7 @@ mod test { let region: SiblingGraph = SiblingGraph::try_new(&hugr, inner)?; let extracted = region.extract_hugr(); - extracted.validate(&PRELUDE_REGISTRY)?; + extracted.validate(&test_quantum_extension::REG)?; let region: SiblingGraph = SiblingGraph::try_new(&hugr, inner)?; diff --git a/hugr-core/src/hugr/views/sibling_subgraph.rs b/hugr-core/src/hugr/views/sibling_subgraph.rs index de82ae591..ea23e17de 100644 --- a/hugr-core/src/hugr/views/sibling_subgraph.rs +++ b/hugr-core/src/hugr/views/sibling_subgraph.rs @@ -795,10 +795,7 @@ mod tests { BuildError, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, ModuleBuilder, }, - extension::{ - prelude::{bool_t, qb_t}, - EMPTY_REG, - }, + extension::prelude::{bool_t, qb_t}, hugr::views::{HierarchyView, SiblingGraph}, ops::handle::{DfgID, FuncID, NodeHandle}, std_extensions::logic::test::and_op, @@ -881,7 +878,7 @@ mod tests { dfg.finish_with_outputs(outs3.outputs())? }; let hugr = mod_builder - .finish_prelude_hugr() + .finish_hugr(&test_quantum_extension::REG) .map_err(|e| -> BuildError { e.into() })?; Ok((hugr, func_id.node())) } @@ -903,7 +900,7 @@ mod tests { dfg.finish_with_outputs([b1, b2])? }; let hugr = mod_builder - .finish_prelude_hugr() + .finish_hugr(&test_quantum_extension::REG) .map_err(|e| -> BuildError { e.into() })?; Ok((hugr, func_id.node())) } @@ -924,7 +921,7 @@ mod tests { dfg.finish_with_outputs(outs.outputs())? }; let hugr = mod_builder - .finish_hugr(&EMPTY_REG) + .finish_hugr(&test_quantum_extension::REG) .map_err(|e| -> BuildError { e.into() })?; Ok((hugr, func_id.node())) } @@ -1170,7 +1167,9 @@ mod tests { .unwrap() .outputs(); let outw = [outw1].into_iter().chain(outw2); - let h = builder.finish_hugr_with_outputs(outw, &EMPTY_REG).unwrap(); + let h = builder + .finish_hugr_with_outputs(outw, &test_quantum_extension::REG) + .unwrap(); let view = SiblingGraph::::try_new(&h, h.root()).unwrap(); let subg = SiblingSubgraph::try_new_dataflow_subgraph(&view).unwrap(); assert_eq!(subg.nodes().len(), 2); diff --git a/hugr-core/src/hugr/views/tests.rs b/hugr-core/src/hugr/views/tests.rs index a703ece54..f69779082 100644 --- a/hugr-core/src/hugr/views/tests.rs +++ b/hugr-core/src/hugr/views/tests.rs @@ -1,6 +1,8 @@ use portgraph::PortOffset; use rstest::{fixture, rstest}; +use crate::std_extensions::logic::LOGIC_REG; +use crate::utils::test_quantum_extension; use crate::{ builder::{ endo_sig, inout_sig, BuildError, BuildHandle, Container, DFGBuilder, Dataflow, DataflowHugr, @@ -32,7 +34,8 @@ pub(crate) fn sample_hugr() -> (Hugr, BuildHandle, BuildHandle Option<&ExtensionId> { + match self { + OpType::OpaqueOp(opaque) => Some(opaque.extension()), + OpType::ExtensionOp(e) => Some(e.def().extension_id()), + _ => None, + } + } } /// Macro used by operations that want their diff --git a/hugr-core/src/ops/custom.rs b/hugr-core/src/ops/custom.rs index 965a8e31c..5d32dee5a 100644 --- a/hugr-core/src/ops/custom.rs +++ b/hugr-core/src/ops/custom.rs @@ -1,5 +1,6 @@ //! Extensible operations. +use itertools::Itertools; use std::sync::Arc; use thiserror::Error; #[cfg(test)] @@ -11,14 +12,12 @@ use { }; use crate::extension::{ConstFoldResult, ExtensionId, ExtensionRegistry, OpDef, SignatureError}; -use crate::hugr::internal::HugrMutInternals; -use crate::hugr::HugrView; use crate::types::{type_param::TypeArg, Signature}; -use crate::{ops, Hugr, IncomingPort, Node}; +use crate::{ops, IncomingPort, Node}; use super::dataflow::DataflowOpTrait; use super::tag::OpTag; -use super::{NamedOp, OpName, OpNameRef, OpTrait, OpType}; +use super::{NamedOp, OpName, OpNameRef}; /// An operation defined by an [OpDef] from a loaded [Extension]. /// @@ -56,13 +55,13 @@ impl ExtensionOp { } /// If OpDef is missing binary computation, trust the cached signature. - fn new_with_cached( + pub(crate) fn new_with_cached( def: Arc, - args: impl Into>, + args: impl IntoIterator, opaque: &OpaqueOp, exts: &ExtensionRegistry, ) -> Result { - let args: Vec = args.into(); + let args: Vec = args.into_iter().collect(); // TODO skip computation depending on config // see https://github.com/CQCL/hugr/issues/1363 let signature = match def.compute_signature(&args, exts) { @@ -99,7 +98,8 @@ impl ExtensionOp { /// [`ExtensionOp`]. /// /// Regenerating the [`ExtensionOp`] back from the [`OpaqueOp`] requires a - /// registry with the appropriate extension. See [`resolve_opaque_op`]. + /// registry with the appropriate extension. See + /// [`crate::Hugr::resolve_extension_defs`]. /// /// For a non-cloning version of this operation, use [`OpaqueOp::from`]. pub fn make_opaque(&self) -> OpaqueOp { @@ -111,6 +111,11 @@ impl ExtensionOp { signature: self.signature.clone(), } } + + /// Returns a mutable reference to the cached signature of the operation. + pub fn signature_mut(&mut self) -> &mut Signature { + &mut self.signature + } } impl From for OpaqueOp { @@ -222,6 +227,11 @@ impl OpaqueOp { signature, } } + + /// Returns a mutable reference to the signature of the operation. + pub fn signature_mut(&mut self) -> &mut Signature { + &mut self.signature + } } impl NamedOp for OpaqueOp { @@ -273,89 +283,25 @@ impl DataflowOpTrait for OpaqueOp { } } -/// Resolve serialized names of operations into concrete implementation (OpDefs) where possible -pub fn resolve_extension_ops( - h: &mut Hugr, - extension_registry: &ExtensionRegistry, -) -> Result<(), OpaqueOpError> { - let mut replacements = Vec::new(); - for n in h.nodes() { - if let OpType::OpaqueOp(opaque) = h.get_optype(n) { - let resolved = resolve_opaque_op(n, opaque, extension_registry)?; - replacements.push((n, resolved)); - } - } - // Only now can we perform the replacements as the 'for' loop was borrowing 'h' preventing use from using it mutably - for (n, op) in replacements { - debug_assert_eq!(h.get_optype(n).tag(), OpTag::Leaf); - debug_assert_eq!(op.tag(), OpTag::Leaf); - h.replace_op(n, op).unwrap(); - } - Ok(()) -} - -/// Try to resolve a [`OpaqueOp`] to a [`ExtensionOp`] by looking the op up in -/// the registry. -/// -/// # Return -/// Some if the serialized opaque resolves to an extension-defined op and all is -/// ok; None if the serialized opaque doesn't identify an extension -/// -/// # Errors -/// If the serialized opaque resolves to a definition that conflicts with what -/// was serialized -pub fn resolve_opaque_op( - node: Node, - opaque: &OpaqueOp, - extension_registry: &ExtensionRegistry, -) -> Result { - if let Some(r) = extension_registry.get(&opaque.extension) { - // Fail if the Extension was found but did not have the expected operation - let Some(def) = r.get_op(&opaque.name) else { - return Err(OpaqueOpError::OpNotFoundInExtension( - node, - opaque.name.clone(), - r.name().clone(), - )); - }; - let ext_op = ExtensionOp::new_with_cached( - def.clone(), - opaque.args.clone(), - opaque, - extension_registry, - ) - .map_err(|e| OpaqueOpError::SignatureError { - node, - name: opaque.name.clone(), - cause: e, - })?; - if opaque.signature() != ext_op.signature() { - return Err(OpaqueOpError::SignatureMismatch { - node, - extension: opaque.extension.clone(), - op: def.name().clone(), - computed: ext_op.signature.clone(), - stored: opaque.signature.clone(), - }); - }; - Ok(ext_op) - } else { - Err(OpaqueOpError::UnresolvedOp( - node, - opaque.name.clone(), - opaque.extension.clone(), - )) - } -} - /// Errors that arise after loading a Hugr containing opaque ops (serialized just as their names) /// when trying to resolve the serialized names against a registry of known Extensions. #[derive(Clone, Debug, Error, PartialEq)] #[non_exhaustive] pub enum OpaqueOpError { /// The Extension was found but did not contain the expected OpDef - #[error("Operation '{1}' in {0} not found in Extension {2}")] - OpNotFoundInExtension(Node, OpName, ExtensionId), + #[error("Operation '{op}' in {node} not found in Extension {extension}. Available operations: {}", + available_ops.iter().join(", ") + )] + OpNotFoundInExtension { + /// The node where the error occurred. + node: Node, + /// The missing operation. + op: OpName, + /// The extension where the operation was expected. + extension: ExtensionId, + /// The available operations in the extension. + available_ops: Vec, + }, /// Extension and OpDef found, but computed signature did not match stored #[error("Conflicting signature: resolved {op} in extension {extension} to a concrete implementation which computed {computed} but stored signature was {stored}")] #[allow(missing_docs)] @@ -383,6 +329,9 @@ pub enum OpaqueOpError { #[cfg(test)] mod test { + use ops::OpType; + + use crate::extension::resolution::update_op_extensions; use crate::std_extensions::arithmetic::conversions::{self, CONVERT_OPS_REGISTRY}; use crate::{ extension::{ @@ -396,6 +345,11 @@ mod test { use super::*; + /// Unwrap the replacement type's `OpDef` from the return type of `resolve_op_definition`. + fn resolve_res_definition(res: &OpType) -> &OpDef { + res.as_extension_op().unwrap().def() + } + #[test] fn new_opaque_op() { let sig = Signature::new_endo(vec![qb_t()]); @@ -426,10 +380,14 @@ mod test { vec![], Signature::new(i0.clone(), bool_t()), ); - let resolved = - super::resolve_opaque_op(Node::from(portgraph::NodeIndex::new(1)), &opaque, registry) - .unwrap(); - assert_eq!(resolved.def().name(), "itobool"); + let mut resolved = opaque.into(); + update_op_extensions( + Node::from(portgraph::NodeIndex::new(1)), + &mut resolved, + registry, + ) + .unwrap(); + assert_eq!(resolve_res_definition(&resolved).name(), "itobool"); } #[test] @@ -466,20 +424,22 @@ mod test { endo_sig.clone(), ); let opaque_comp = OpaqueOp::new(ext_id.clone(), comp_name, "".into(), vec![], endo_sig); - let resolved_val = super::resolve_opaque_op( + let mut resolved_val = opaque_val.into(); + update_op_extensions( Node::from(portgraph::NodeIndex::new(1)), - &opaque_val, + &mut resolved_val, ®istry, ) .unwrap(); - assert_eq!(resolved_val.def().name(), val_name); + assert_eq!(resolve_res_definition(&resolved_val).name(), val_name); - let resolved_comp = super::resolve_opaque_op( + let mut resolved_comp = opaque_comp.into(); + update_op_extensions( Node::from(portgraph::NodeIndex::new(2)), - &opaque_comp, + &mut resolved_comp, ®istry, ) .unwrap(); - assert_eq!(resolved_comp.def().name(), comp_name); + assert_eq!(resolve_res_definition(&resolved_comp).name(), comp_name); } } diff --git a/hugr-core/src/std_extensions/collections.rs b/hugr-core/src/std_extensions/collections.rs index a26adfecb..492edf428 100644 --- a/hugr-core/src/std_extensions/collections.rs +++ b/hugr-core/src/std_extensions/collections.rs @@ -372,7 +372,7 @@ impl ListOpInst { .clone() .into_iter() // ignore self if already in registry - .filter_map(|(_, ext)| (ext.name() != EXTENSION.name()).then_some(ext)) + .filter(|ext| ext.name() != EXTENSION.name()) .chain(std::iter::once(EXTENSION.to_owned())), ) .unwrap(); diff --git a/hugr-core/src/types.rs b/hugr-core/src/types.rs index 331a97b52..a8cd3d9ba 100644 --- a/hugr-core/src/types.rs +++ b/hugr-core/src/types.rs @@ -395,6 +395,12 @@ impl TypeBase { &self.0 } + /// Report a mutable reference to the component TypeEnum. + #[inline(always)] + pub fn as_type_enum_mut(&mut self) -> &mut TypeEnum { + &mut self.0 + } + /// Report if the type is copyable - i.e.the least upper bound of the type /// is contained by the copyable bound. pub const fn copyable(&self) -> bool { diff --git a/hugr-core/src/types/custom.rs b/hugr-core/src/types/custom.rs index c8eed491e..22af9b77a 100644 --- a/hugr-core/src/types/custom.rs +++ b/hugr-core/src/types/custom.rs @@ -98,7 +98,10 @@ impl CustomType { let ex = extension_registry.get(&self.extension); // Even if OpDef's (+binaries) are not available, the part of the Extension definition // describing the TypeDefs can easily be passed around (serialized), so should be available. - let ex = ex.ok_or(SignatureError::ExtensionNotFound(self.extension.clone()))?; + let ex = ex.ok_or(SignatureError::ExtensionNotFound { + missing: self.extension.clone(), + available: extension_registry.ids().cloned().collect(), + })?; ex.get_type(&self.id) .ok_or(SignatureError::ExtensionTypeNotFound { exn: self.extension.clone(), @@ -143,6 +146,11 @@ impl CustomType { pub fn extension_ref(&self) -> Weak { self.extension_ref.clone() } + + /// Update the internal extension reference with a new weak pointer. + pub fn update_extension(&mut self, extension_ref: Weak) { + self.extension_ref = extension_ref; + } } impl Display for CustomType { diff --git a/hugr-core/src/types/poly_func.rs b/hugr-core/src/types/poly_func.rs index 5882d377f..754e32205 100644 --- a/hugr-core/src/types/poly_func.rs +++ b/hugr-core/src/types/poly_func.rs @@ -145,6 +145,11 @@ impl PolyFuncTypeBase { self.params.iter().map(ToString::to_string).join(" ") ) } + + /// Returns a mutable reference to the body of the function type. + pub fn body_mut(&mut self) -> &mut FuncTypeBase { + &mut self.body + } } #[cfg(test)] diff --git a/hugr-core/src/utils.rs b/hugr-core/src/utils.rs index eef582384..dba98384c 100644 --- a/hugr-core/src/utils.rs +++ b/hugr-core/src/utils.rs @@ -106,6 +106,8 @@ pub(crate) mod test_quantum_extension { use std::sync::Arc; use crate::ops::{OpName, OpNameRef}; + use crate::std_extensions::arithmetic::float_ops; + use crate::std_extensions::logic; use crate::types::FuncValueType; use crate::{ extension::{ @@ -190,7 +192,15 @@ pub(crate) mod test_quantum_extension { lazy_static! { /// Quantum extension definition. pub static ref EXTENSION: Arc = extension(); - static ref REG: ExtensionRegistry = ExtensionRegistry::try_new([EXTENSION.clone(), PRELUDE.clone(), float_types::EXTENSION.clone()]).unwrap(); + + /// A registry with all necessary extensions to run tests internally, including the test quantum extension. + pub static ref REG: ExtensionRegistry = ExtensionRegistry::try_new([ + EXTENSION.clone(), + PRELUDE.clone(), + float_types::EXTENSION.clone(), + float_ops::EXTENSION.clone(), + logic::EXTENSION.clone() + ]).unwrap(); } diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index b5e303d43..351f4a19e 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -218,4 +218,4 @@ pub fn constant_fold_pass(h: &mut H, reg: &ExtensionRegistry) { } #[cfg(test)] -mod test; +pub(crate) mod test; diff --git a/hugr-passes/src/const_fold/test.rs b/hugr-passes/src/const_fold/test.rs index 55d53db1c..8b54f7e93 100644 --- a/hugr-passes/src/const_fold/test.rs +++ b/hugr-passes/src/const_fold/test.rs @@ -1,14 +1,13 @@ use crate::const_fold::constant_fold_pass; +use crate::test::TEST_REG; use hugr_core::builder::{DFGBuilder, Dataflow, DataflowHugr}; use hugr_core::extension::prelude::{ bool_t, const_ok, error_type, string_type, sum_with_error, ConstError, ConstString, UnpackTuple, }; -use hugr_core::extension::{ExtensionRegistry, PRELUDE}; use hugr_core::ops::Value; -use hugr_core::std_extensions::arithmetic; use hugr_core::std_extensions::arithmetic::int_ops::IntOpDef; use hugr_core::std_extensions::arithmetic::int_types::{ConstInt, INT_TYPES}; -use hugr_core::std_extensions::logic::{self, LogicOp}; +use hugr_core::std_extensions::logic::LogicOp; use hugr_core::type_row; use hugr_core::types::{Signature, Type, TypeRow, TypeRowRV}; @@ -105,20 +104,12 @@ fn test_big() { .add_dataflow_op(ConvertOpDef::trunc_u.with_log_width(5), sub.outputs()) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - arithmetic::float_types::EXTENSION.to_owned(), - arithmetic::float_ops::EXTENSION.to_owned(), - arithmetic::conversions::EXTENSION.to_owned(), - ]) - .unwrap(); let mut h = build - .finish_hugr_with_outputs(to_int.outputs(), ®) + .finish_hugr_with_outputs(to_int.outputs(), &TEST_REG) .unwrap(); assert_eq!(h.node_count(), 8); - constant_fold_pass(&mut h, ®); + constant_fold_pass(&mut h, &TEST_REG); let expected = const_ok(i2c(2).clone(), error_type()); assert_fully_folded(&h, &expected); @@ -128,14 +119,8 @@ fn test_big() { #[ignore = "Waiting for `unwrap` operation"] // TODO: https://github.com/CQCL/hugr/issues/1486 fn test_list_ops() -> Result<(), Box> { - use hugr_core::std_extensions::collections::{self, ListOp, ListValue}; + use hugr_core::std_extensions::collections::{ListOp, ListValue}; - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - logic::EXTENSION.to_owned(), - collections::EXTENSION.to_owned(), - ]) - .unwrap(); let base_list: Value = ListValue::new(bool_t(), [Value::false_val()]).into(); let mut build = DFGBuilder::new(Signature::new( type_row![], @@ -149,7 +134,7 @@ fn test_list_ops() -> Result<(), Box> { .add_dataflow_op( ListOp::pop .with_type(bool_t()) - .to_extension_op(®) + .to_extension_op(&TEST_REG) .unwrap(), [list], )? @@ -162,15 +147,15 @@ fn test_list_ops() -> Result<(), Box> { .add_dataflow_op( ListOp::push .with_type(bool_t()) - .to_extension_op(®) + .to_extension_op(&TEST_REG) .unwrap(), [list, elem], )? .outputs_arr(); - let mut h = build.finish_hugr_with_outputs([list], ®)?; + let mut h = build.finish_hugr_with_outputs([list], &TEST_REG)?; - constant_fold_pass(&mut h, ®); + constant_fold_pass(&mut h, &TEST_REG); assert_fully_folded(&h, &base_list); Ok(()) @@ -186,10 +171,10 @@ fn test_fold_and() { let x0 = build.add_load_const(Value::true_val()); let x1 = build.add_load_const(Value::true_val()); let x2 = build.add_dataflow_op(LogicOp::And, [x0, x1]).unwrap(); - let reg = - ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::true_val(); assert_fully_folded(&h, &expected); } @@ -204,10 +189,10 @@ fn test_fold_or() { let x0 = build.add_load_const(Value::true_val()); let x1 = build.add_load_const(Value::false_val()); let x2 = build.add_dataflow_op(LogicOp::Or, [x0, x1]).unwrap(); - let reg = - ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::true_val(); assert_fully_folded(&h, &expected); } @@ -221,10 +206,10 @@ fn test_fold_not() { let mut build = DFGBuilder::new(noargfn(bool_t())).unwrap(); let x0 = build.add_load_const(Value::true_val()); let x1 = build.add_dataflow_op(LogicOp::Not, [x0]).unwrap(); - let reg = - ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap(); - let mut h = build.finish_hugr_with_outputs(x1.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x1.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::false_val(); assert_fully_folded(&h, &expected); } @@ -250,9 +235,9 @@ fn orphan_output() { .unwrap(); let or_node = r.node(); let parent = build.container_node(); - let reg = - ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap(); - let mut h = build.finish_hugr_with_outputs(r.outputs(), ®).unwrap(); + let mut h = build + .finish_hugr_with_outputs(r.outputs(), &TEST_REG) + .unwrap(); // we delete the original Not and create a new One. This means it will be // traversed by `constant_fold_pass` after the Or. @@ -261,7 +246,7 @@ fn orphan_output() { h.disconnect(or_node, IncomingPort::from(1)); h.connect(new_not, 0, or_node, 1); h.remove_node(orig_not.node()); - constant_fold_pass(&mut h, ®); + constant_fold_pass(&mut h, &TEST_REG); assert_fully_folded(&h, &Value::true_val()) } @@ -291,33 +276,25 @@ fn test_folding_pass_issue_996() { let x7 = build .add_dataflow_op(LogicOp::Or, x4.outputs().chain(x6.outputs())) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - logic::EXTENSION.to_owned(), - arithmetic::float_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x7.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x7.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::true_val(); assert_fully_folded(&h, &expected); } #[test] fn test_const_fold_to_nonfinite() { - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::float_types::EXTENSION.to_owned(), - ]) - .unwrap(); - // HUGR computing 1.0 / 1.0 let mut build = DFGBuilder::new(noargfn(vec![float64_type()])).unwrap(); let x0 = build.add_load_const(Value::extension(ConstF64::new(1.0))); let x1 = build.add_load_const(Value::extension(ConstF64::new(1.0))); let x2 = build.add_dataflow_op(FloatOps::fdiv, [x0, x1]).unwrap(); - let mut h0 = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h0, ®); + let mut h0 = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h0, &TEST_REG); assert_fully_folded_with(&h0, |v| { v.get_custom_value::().unwrap().value() == 1.0 }); @@ -328,8 +305,10 @@ fn test_const_fold_to_nonfinite() { let x0 = build.add_load_const(Value::extension(ConstF64::new(1.0))); let x1 = build.add_load_const(Value::extension(ConstF64::new(0.0))); let x2 = build.add_dataflow_op(FloatOps::fdiv, [x0, x1]).unwrap(); - let mut h1 = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h1, ®); + let mut h1 = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h1, &TEST_REG); assert_eq!(h1.node_count(), 8); } @@ -345,13 +324,10 @@ fn test_fold_iwiden_u() { let x1 = build .add_dataflow_op(IntOpDef::iwiden_u.with_two_log_widths(4, 5), [x0]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x1.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x1.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstInt::new_u(5, 13).unwrap()); assert_fully_folded(&h, &expected); } @@ -368,13 +344,10 @@ fn test_fold_iwiden_s() { let x1 = build .add_dataflow_op(IntOpDef::iwiden_s.with_two_log_widths(4, 5), [x0]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x1.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x1.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstInt::new_s(5, -3).unwrap()); assert_fully_folded(&h, &expected); } @@ -419,13 +392,10 @@ fn test_fold_inarrow, E: std::fmt::Debug>( [x0], ) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x1.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x1.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); lazy_static! { static ref INARROW_ERROR_VALUE: ConstError = ConstError { signal: 0, @@ -452,13 +422,10 @@ fn test_fold_itobool() { let x1 = build .add_dataflow_op(ConvertOpDef::itobool.without_log_width(), [x0]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x1.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x1.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::true_val(); assert_fully_folded(&h, &expected); } @@ -475,13 +442,10 @@ fn test_fold_ifrombool() { let x1 = build .add_dataflow_op(ConvertOpDef::ifrombool.without_log_width(), [x0]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x1.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x1.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstInt::new_u(0, 0).unwrap()); assert_fully_folded(&h, &expected); } @@ -498,13 +462,10 @@ fn test_fold_ieq() { let x2 = build .add_dataflow_op(IntOpDef::ieq.with_log_width(3), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::true_val(); assert_fully_folded(&h, &expected); } @@ -521,13 +482,10 @@ fn test_fold_ine() { let x2 = build .add_dataflow_op(IntOpDef::ine.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::true_val(); assert_fully_folded(&h, &expected); } @@ -544,13 +502,10 @@ fn test_fold_ilt_u() { let x2 = build .add_dataflow_op(IntOpDef::ilt_u.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::true_val(); assert_fully_folded(&h, &expected); } @@ -567,13 +522,10 @@ fn test_fold_ilt_s() { let x2 = build .add_dataflow_op(IntOpDef::ilt_s.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::false_val(); assert_fully_folded(&h, &expected); } @@ -590,13 +542,10 @@ fn test_fold_igt_u() { let x2 = build .add_dataflow_op(IntOpDef::igt_u.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::false_val(); assert_fully_folded(&h, &expected); } @@ -613,13 +562,10 @@ fn test_fold_igt_s() { let x2 = build .add_dataflow_op(IntOpDef::igt_s.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::true_val(); assert_fully_folded(&h, &expected); } @@ -636,13 +582,10 @@ fn test_fold_ile_u() { let x2 = build .add_dataflow_op(IntOpDef::ile_u.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::true_val(); assert_fully_folded(&h, &expected); } @@ -659,13 +602,10 @@ fn test_fold_ile_s() { let x2 = build .add_dataflow_op(IntOpDef::ile_s.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::true_val(); assert_fully_folded(&h, &expected); } @@ -682,13 +622,10 @@ fn test_fold_ige_u() { let x2 = build .add_dataflow_op(IntOpDef::ige_u.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::false_val(); assert_fully_folded(&h, &expected); } @@ -705,13 +642,10 @@ fn test_fold_ige_s() { let x2 = build .add_dataflow_op(IntOpDef::ige_s.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::true_val(); assert_fully_folded(&h, &expected); } @@ -728,13 +662,10 @@ fn test_fold_imax_u() { let x2 = build .add_dataflow_op(IntOpDef::imax_u.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstInt::new_u(5, 11).unwrap()); assert_fully_folded(&h, &expected); } @@ -751,13 +682,10 @@ fn test_fold_imax_s() { let x2 = build .add_dataflow_op(IntOpDef::imax_s.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstInt::new_s(5, 1).unwrap()); assert_fully_folded(&h, &expected); } @@ -774,13 +702,10 @@ fn test_fold_imin_u() { let x2 = build .add_dataflow_op(IntOpDef::imin_u.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstInt::new_u(5, 7).unwrap()); assert_fully_folded(&h, &expected); } @@ -797,13 +722,10 @@ fn test_fold_imin_s() { let x2 = build .add_dataflow_op(IntOpDef::imin_s.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstInt::new_s(5, -2).unwrap()); assert_fully_folded(&h, &expected); } @@ -820,13 +742,10 @@ fn test_fold_iadd() { let x2 = build .add_dataflow_op(IntOpDef::iadd.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstInt::new_s(5, -1).unwrap()); assert_fully_folded(&h, &expected); } @@ -843,13 +762,10 @@ fn test_fold_isub() { let x2 = build .add_dataflow_op(IntOpDef::isub.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstInt::new_s(5, -3).unwrap()); assert_fully_folded(&h, &expected); } @@ -865,13 +781,10 @@ fn test_fold_ineg() { let x2 = build .add_dataflow_op(IntOpDef::ineg.with_log_width(5), [x0]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstInt::new_s(5, 2).unwrap()); assert_fully_folded(&h, &expected); } @@ -888,13 +801,10 @@ fn test_fold_imul() { let x2 = build .add_dataflow_op(IntOpDef::imul.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstInt::new_s(5, -14).unwrap()); assert_fully_folded(&h, &expected); } @@ -914,13 +824,10 @@ fn test_fold_idivmod_checked_u() { let x2 = build .add_dataflow_op(IntOpDef::idivmod_checked_u.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = ConstError { signal: 0, message: "Division by zero".to_string(), @@ -946,13 +853,10 @@ fn test_fold_idivmod_u() { let x4 = build .add_dataflow_op(IntOpDef::iadd.with_log_width(3), [x2, x3]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x4.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x4.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstInt::new_u(3, 8).unwrap()); assert_fully_folded(&h, &expected); } @@ -972,13 +876,10 @@ fn test_fold_idivmod_checked_s() { let x2 = build .add_dataflow_op(IntOpDef::idivmod_checked_s.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = ConstError { signal: 0, message: "Division by zero".to_string(), @@ -1006,13 +907,10 @@ fn test_fold_idivmod_s(#[case] a: i64, #[case] b: u64, #[case] c: i64) { let x4 = build .add_dataflow_op(IntOpDef::iadd.with_log_width(6), [x2, x3]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x4.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x4.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstInt::new_s(6, c).unwrap()); assert_fully_folded(&h, &expected); } @@ -1030,13 +928,10 @@ fn test_fold_idiv_checked_u() { let x2 = build .add_dataflow_op(IntOpDef::idiv_checked_u.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = ConstError { signal: 0, message: "Division by zero".to_string(), @@ -1057,13 +952,10 @@ fn test_fold_idiv_u() { let x2 = build .add_dataflow_op(IntOpDef::idiv_u.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstInt::new_u(5, 6).unwrap()); assert_fully_folded(&h, &expected); } @@ -1081,13 +973,10 @@ fn test_fold_imod_checked_u() { let x2 = build .add_dataflow_op(IntOpDef::imod_checked_u.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = ConstError { signal: 0, message: "Division by zero".to_string(), @@ -1108,13 +997,10 @@ fn test_fold_imod_u() { let x2 = build .add_dataflow_op(IntOpDef::imod_u.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstInt::new_u(5, 2).unwrap()); assert_fully_folded(&h, &expected); } @@ -1132,13 +1018,10 @@ fn test_fold_idiv_checked_s() { let x2 = build .add_dataflow_op(IntOpDef::idiv_checked_s.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = ConstError { signal: 0, message: "Division by zero".to_string(), @@ -1159,13 +1042,10 @@ fn test_fold_idiv_s() { let x2 = build .add_dataflow_op(IntOpDef::idiv_s.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstInt::new_s(5, -7).unwrap()); assert_fully_folded(&h, &expected); } @@ -1183,13 +1063,10 @@ fn test_fold_imod_checked_s() { let x2 = build .add_dataflow_op(IntOpDef::imod_checked_s.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = ConstError { signal: 0, message: "Division by zero".to_string(), @@ -1210,13 +1087,10 @@ fn test_fold_imod_s() { let x2 = build .add_dataflow_op(IntOpDef::imod_s.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstInt::new_u(5, 1).unwrap()); assert_fully_folded(&h, &expected); } @@ -1232,13 +1106,10 @@ fn test_fold_iabs() { let x2 = build .add_dataflow_op(IntOpDef::iabs.with_log_width(5), [x0]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstInt::new_u(5, 2).unwrap()); assert_fully_folded(&h, &expected); } @@ -1255,13 +1126,10 @@ fn test_fold_iand() { let x2 = build .add_dataflow_op(IntOpDef::iand.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstInt::new_u(5, 4).unwrap()); assert_fully_folded(&h, &expected); } @@ -1278,13 +1146,10 @@ fn test_fold_ior() { let x2 = build .add_dataflow_op(IntOpDef::ior.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstInt::new_u(5, 30).unwrap()); assert_fully_folded(&h, &expected); } @@ -1301,13 +1166,10 @@ fn test_fold_ixor() { let x2 = build .add_dataflow_op(IntOpDef::ixor.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstInt::new_u(5, 26).unwrap()); assert_fully_folded(&h, &expected); } @@ -1323,13 +1185,10 @@ fn test_fold_inot() { let x2 = build .add_dataflow_op(IntOpDef::inot.with_log_width(5), [x0]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstInt::new_u(5, (1u64 << 32) - 15).unwrap()); assert_fully_folded(&h, &expected); } @@ -1346,13 +1205,10 @@ fn test_fold_ishl() { let x2 = build .add_dataflow_op(IntOpDef::ishl.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstInt::new_u(5, 112).unwrap()); assert_fully_folded(&h, &expected); } @@ -1369,13 +1225,10 @@ fn test_fold_ishr() { let x2 = build .add_dataflow_op(IntOpDef::ishr.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstInt::new_u(5, 1).unwrap()); assert_fully_folded(&h, &expected); } @@ -1392,13 +1245,10 @@ fn test_fold_irotl() { let x2 = build .add_dataflow_op(IntOpDef::irotl.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstInt::new_u(5, 3 * (1u64 << 30) + 1).unwrap()); assert_fully_folded(&h, &expected); } @@ -1415,13 +1265,10 @@ fn test_fold_irotr() { let x2 = build .add_dataflow_op(IntOpDef::irotr.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstInt::new_u(5, 3 * (1u64 << 30) + 1).unwrap()); assert_fully_folded(&h, &expected); } @@ -1437,13 +1284,10 @@ fn test_fold_itostring_u() { let x1 = build .add_dataflow_op(ConvertOpDef::itostring_u.with_log_width(5), [x0]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x1.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x1.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstString::new("17".into())); assert_fully_folded(&h, &expected); } @@ -1459,13 +1303,10 @@ fn test_fold_itostring_s() { let x1 = build .add_dataflow_op(ConvertOpDef::itostring_s.with_log_width(5), [x0]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x1.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x1.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstString::new("-17".into())); assert_fully_folded(&h, &expected); } @@ -1502,14 +1343,10 @@ fn test_fold_int_ops() { let x7 = build .add_dataflow_op(LogicOp::Or, x4.outputs().chain(x6.outputs())) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - logic::EXTENSION.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x7.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x7.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::true_val(); assert_fully_folded(&h, &expected); } diff --git a/hugr-passes/src/lib.rs b/hugr-passes/src/lib.rs index 13dd47776..9042850d4 100644 --- a/hugr-passes/src/lib.rs +++ b/hugr-passes/src/lib.rs @@ -11,3 +11,29 @@ pub mod validation; pub use force_order::{force_order, force_order_by_key}; pub use lower::{lower_ops, replace_many_ops}; pub use non_local::{ensure_no_nonlocal_edges, nonlocal_edges}; + +#[cfg(test)] +pub(crate) mod test { + + use lazy_static::lazy_static; + + use hugr_core::extension::{ExtensionRegistry, PRELUDE}; + use hugr_core::std_extensions::arithmetic; + use hugr_core::std_extensions::collections; + use hugr_core::std_extensions::logic; + + lazy_static! { + /// A registry containing various extensions for testing. + pub(crate) static ref TEST_REG: ExtensionRegistry = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_ops::EXTENSION.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + arithmetic::float_types::EXTENSION.to_owned(), + arithmetic::float_ops::EXTENSION.to_owned(), + logic::EXTENSION.to_owned(), + arithmetic::conversions::EXTENSION.to_owned(), + collections::EXTENSION.to_owned(), + ]) + .unwrap(); + } +} diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index 0b23709a9..598622089 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -42,15 +42,14 @@ pub fn ensure_no_nonlocal_edges(hugr: &impl HugrView) -> Result<(), NonLocalEdge mod test { use hugr_core::{ builder::{DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer}, - extension::{ - prelude::{bool_t, Noop}, - EMPTY_REG, - }, + extension::prelude::{bool_t, Noop}, ops::handle::NodeHandle, type_row, types::Signature, }; + use crate::test::TEST_REG; + use super::*; #[test] @@ -64,7 +63,7 @@ mod test { .unwrap() .outputs_arr(); builder - .finish_hugr_with_outputs([out_w], &EMPTY_REG) + .finish_hugr_with_outputs([out_w], &TEST_REG) .unwrap() }; ensure_no_nonlocal_edges(&hugr).unwrap(); @@ -94,7 +93,7 @@ mod test { }; ( builder - .finish_hugr_with_outputs([out_w], &EMPTY_REG) + .finish_hugr_with_outputs([out_w], &TEST_REG) .unwrap(), edge, ) diff --git a/hugr/src/lib.rs b/hugr/src/lib.rs index b5899c7ec..3a96696e8 100644 --- a/hugr/src/lib.rs +++ b/hugr/src/lib.rs @@ -81,7 +81,7 @@ //! lazy_static! { //! /// Quantum extension definition. //! pub static ref EXTENSION: Arc = extension(); -//! static ref REG: ExtensionRegistry = +//! pub static ref REG: ExtensionRegistry = //! ExtensionRegistry::try_new([EXTENSION.clone(), PRELUDE.clone()]).unwrap(); //! } //! fn get_gate(gate_name: impl Into) -> ExtensionOp { @@ -103,7 +103,7 @@ //! } //! } //! -//! use mini_quantum_extension::{cx_gate, h_gate, measure}; +//! use mini_quantum_extension::{cx_gate, h_gate, measure, REG}; //! //! // ┌───┐ //! // q_0: ┤ H ├──■───── @@ -121,7 +121,7 @@ //! let h1 = dfg_builder.add_dataflow_op(h_gate(), vec![wire1])?; //! let cx = dfg_builder.add_dataflow_op(cx_gate(), h0.outputs().chain(h1.outputs()))?; //! let measure = dfg_builder.add_dataflow_op(measure(), cx.outputs().last())?; -//! dfg_builder.finish_prelude_hugr_with_outputs(cx.outputs().take(1).chain(measure.outputs())) +//! dfg_builder.finish_hugr_with_outputs(cx.outputs().take(1).chain(measure.outputs()), ®) //! } //! //! let h: Hugr = make_dfg_hugr().unwrap();