diff --git a/Cargo.toml b/Cargo.toml index b61351b29..c842a3d1f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -65,7 +65,7 @@ urlencoding = "2.1.2" webbrowser = "1.0.0" clap = { version = "4.5.4" } clio = "0.3.5" -clap-verbosity-flag = "2.2.0" +clap-verbosity-flag = "3.0.1" assert_cmd = "2.0.14" assert_fs = "1.1.1" predicates = "3.1.0" diff --git a/hugr-cli/src/extensions.rs b/hugr-cli/src/extensions.rs index 792d997e7..d6e160f76 100644 --- a/hugr-cli/src/extensions.rs +++ b/hugr-cli/src/extensions.rs @@ -28,9 +28,9 @@ impl ExtArgs { pub fn run_dump(&self, registry: &ExtensionRegistry) { let base_dir = &self.outdir; - for (name, ext) in registry.iter() { + for ext in registry.iter() { let mut path = base_dir.clone(); - for part in name.split('.') { + for part in ext.name().split('.') { path.push(part); } path.set_extension("json"); diff --git a/hugr-cli/src/main.rs b/hugr-cli/src/main.rs index 20f602fed..445930b24 100644 --- a/hugr-cli/src/main.rs +++ b/hugr-cli/src/main.rs @@ -4,7 +4,7 @@ use clap::Parser as _; use hugr_cli::{validate, CliArgs}; -use clap_verbosity_flag::Level; +use clap_verbosity_flag::log::Level; fn main() { match CliArgs::parse() { diff --git a/hugr-cli/src/validate.rs b/hugr-cli/src/validate.rs index 996799b77..5205006d7 100644 --- a/hugr-cli/src/validate.rs +++ b/hugr-cli/src/validate.rs @@ -1,8 +1,7 @@ //! The `validate` subcommand. use clap::Parser; -use clap_verbosity_flag::Level; -use hugr::package::PackageValidationError; +use clap_verbosity_flag::log::Level; use hugr::{extension::ExtensionRegistry, Extension, Hugr}; use crate::{CliError, HugrArgs}; @@ -64,8 +63,7 @@ impl HugrArgs { for ext in &self.extensions { let f = std::fs::File::open(ext)?; let ext: Extension = serde_json::from_reader(f)?; - reg.register_updated(ext) - .map_err(PackageValidationError::Extension)?; + reg.register_updated(ext); } package.update_validate(&mut reg)?; diff --git a/hugr-cli/tests/validate.rs b/hugr-cli/tests/validate.rs index cd9a3cb79..f3fc35f83 100644 --- a/hugr-cli/tests/validate.rs +++ b/hugr-cli/tests/validate.rs @@ -12,9 +12,8 @@ use hugr::builder::{DFGBuilder, DataflowSubContainer, ModuleBuilder}; use hugr::types::Type; use hugr::{ builder::{Container, Dataflow}, - extension::prelude::{BOOL_T, QB_T}, - std_extensions::arithmetic::float_types::FLOAT64_TYPE, - type_row, + extension::prelude::{bool_t, qb_t}, + std_extensions::arithmetic::float_types::float64_type, types::Signature, Hugr, }; @@ -41,7 +40,7 @@ const FLOAT_EXT_FILE: &str = concat!( /// A test package, containing a module-rooted HUGR. #[fixture] -fn test_package(#[default(BOOL_T)] id_type: Type) -> Package { +fn test_package(#[default(bool_t())] id_type: Type) -> Package { let mut module = ModuleBuilder::new(); let df = module .define_function("test", Signature::new_endo(id_type)) @@ -57,7 +56,7 @@ fn test_package(#[default(BOOL_T)] id_type: Type) -> Package { /// A DFG-rooted HUGR. #[fixture] -fn test_hugr(#[default(BOOL_T)] id_type: Type) -> Hugr { +fn test_hugr(#[default(bool_t())] id_type: Type) -> Hugr { let mut df = DFGBuilder::new(Signature::new_endo(id_type)).unwrap(); let [i] = df.input_wires_arr(); df.set_outputs([i]).unwrap(); @@ -120,7 +119,7 @@ fn test_mermaid(test_hugr_file: NamedTempFile, mut cmd: Command) { #[fixture] fn bad_hugr_string() -> String { - let df = DFGBuilder::new(Signature::new_endo(type_row![QB_T])).unwrap(); + let df = DFGBuilder::new(Signature::new_endo(vec![qb_t()])).unwrap(); let bad_hugr = df.hugr().clone(); serde_json::to_string(&bad_hugr).unwrap() @@ -178,7 +177,7 @@ fn test_no_std(test_hugr_string: String, mut val_cmd: Command) { } #[fixture] -fn float_hugr_string(#[with(FLOAT64_TYPE)] test_hugr: Hugr) -> String { +fn float_hugr_string(#[with(float64_type())] test_hugr: Hugr) -> String { serde_json::to_string(&test_hugr).unwrap() } @@ -191,7 +190,7 @@ fn test_no_std_fail(float_hugr_string: String, mut val_cmd: Command) { val_cmd .assert() .failure() - .stderr(contains(" Extension 'arithmetic.float.types' not found")); + .stderr(contains(" requires extension arithmetic.float.types")); } #[rstest] @@ -205,7 +204,7 @@ fn test_float_extension(float_hugr_string: String, mut val_cmd: Command) { val_cmd.assert().success().stderr(contains(VALID_PRINT)); } #[fixture] -fn package_string(#[with(FLOAT64_TYPE)] test_package: Package) -> String { +fn package_string(#[with(float64_type())] test_package: Package) -> String { serde_json::to_string(&test_package).unwrap() } diff --git a/hugr-core/src/builder.rs b/hugr-core/src/builder.rs index f40703243..a02c38816 100644 --- a/hugr-core/src/builder.rs +++ b/hugr-core/src/builder.rs @@ -28,7 +28,7 @@ //! ```rust //! # use hugr::Hugr; //! # use hugr::builder::{BuildError, BuildHandle, Container, DFGBuilder, Dataflow, DataflowHugr, ModuleBuilder, DataflowSubContainer, HugrBuilder}; -//! use hugr::extension::prelude::BOOL_T; +//! use hugr::extension::prelude::bool_t; //! use hugr::std_extensions::logic::{EXTENSION_ID, LOGIC_REG, LogicOp}; //! use hugr::types::Signature; //! @@ -42,7 +42,7 @@ //! let _dfg_handle = { //! let mut dfg = module_builder.define_function( //! "main", -//! Signature::new_endo(BOOL_T).with_extension_delta(EXTENSION_ID), +//! Signature::new_endo(bool_t()).with_extension_delta(EXTENSION_ID), //! )?; //! //! // Get the wires from the function inputs. @@ -59,7 +59,7 @@ //! let _circuit_handle = { //! let mut dfg = module_builder.define_function( //! "circuit", -//! Signature::new_endo(vec![BOOL_T, BOOL_T]) +//! Signature::new_endo(vec![bool_t(), bool_t()]) //! .with_extension_delta(EXTENSION_ID), //! )?; //! let mut circuit = dfg.as_circuit(dfg.input_wires()); @@ -238,11 +238,12 @@ pub enum BuilderWiringError { pub(crate) mod test { use rstest::fixture; + use crate::extension::prelude::{bool_t, usize_t}; use crate::hugr::{views::HugrView, HugrMut}; use crate::ops; - use crate::std_extensions::arithmetic::float_ops::FLOAT_OPS_REGISTRY; - use crate::types::{PolyFuncType, Signature, Type}; - use crate::{type_row, Hugr}; + use crate::types::{PolyFuncType, Signature}; + use crate::utils::test_quantum_extension; + use crate::Hugr; use super::handle::BuildHandle; use super::{ @@ -251,12 +252,8 @@ pub(crate) mod test { }; use super::{DataflowSubContainer, HugrBuilder}; - pub(super) const NAT: Type = crate::extension::prelude::USIZE_T; - pub(super) const BIT: Type = crate::extension::prelude::BOOL_T; - pub(super) const QB: Type = crate::extension::prelude::QB_T; - /// Wire up inputs of a Dataflow container to the outputs. - pub(super) fn n_identity( + pub(crate) fn n_identity( dataflow_builder: T, ) -> Result { let w = dataflow_builder.input_wires(); @@ -272,12 +269,12 @@ pub(crate) mod test { f(f_builder)?; - Ok(module_builder.finish_hugr(&FLOAT_OPS_REGISTRY)?) + Ok(module_builder.finish_hugr(&test_quantum_extension::REG)?) } #[fixture] pub(crate) fn simple_dfg_hugr() -> Hugr { - let dfg_builder = DFGBuilder::new(Signature::new(type_row![BIT], type_row![BIT])).unwrap(); + let dfg_builder = DFGBuilder::new(Signature::new(vec![bool_t()], vec![bool_t()])).unwrap(); let [i1] = dfg_builder.input_wires_arr(); dfg_builder.finish_prelude_hugr_with_outputs([i1]).unwrap() } @@ -285,7 +282,7 @@ pub(crate) mod test { #[fixture] pub(crate) fn simple_funcdef_hugr() -> Hugr { let fn_builder = - FunctionBuilder::new("test", Signature::new(type_row![BIT], type_row![BIT])).unwrap(); + FunctionBuilder::new("test", Signature::new(vec![bool_t()], vec![bool_t()])).unwrap(); let [i1] = fn_builder.input_wires_arr(); fn_builder.finish_prelude_hugr_with_outputs([i1]).unwrap() } @@ -293,7 +290,7 @@ pub(crate) mod test { #[fixture] pub(crate) fn simple_module_hugr() -> Hugr { let mut builder = ModuleBuilder::new(); - let sig = Signature::new(type_row![BIT], type_row![BIT]); + let sig = Signature::new(vec![bool_t()], vec![bool_t()]); builder.declare("test", sig.into()).unwrap(); builder.finish_prelude_hugr().unwrap() } @@ -301,7 +298,7 @@ pub(crate) mod test { #[fixture] pub(crate) fn simple_cfg_hugr() -> Hugr { let mut cfg_builder = - CFGBuilder::new(Signature::new(type_row![NAT], type_row![NAT])).unwrap(); + CFGBuilder::new(Signature::new(vec![usize_t()], vec![usize_t()])).unwrap(); super::cfg::test::build_basic_cfg(&mut cfg_builder).unwrap(); cfg_builder.finish_prelude_hugr().unwrap() } diff --git a/hugr-core/src/builder/cfg.rs b/hugr-core/src/builder/cfg.rs index bcb6aff3b..fb9199df5 100644 --- a/hugr-core/src/builder/cfg.rs +++ b/hugr-core/src/builder/cfg.rs @@ -51,21 +51,20 @@ use crate::{hugr::HugrMut, type_row, Hugr}; /// ops, type_row, /// types::{Signature, SumType, Type}, /// Hugr, +/// extension::prelude::usize_t, /// }; /// -/// const NAT: Type = prelude::USIZE_T; -/// /// fn make_cfg() -> Result { -/// let mut cfg_builder = CFGBuilder::new(Signature::new_endo(NAT))?; +/// let mut cfg_builder = CFGBuilder::new(Signature::new_endo(usize_t()))?; /// /// // Outputs from basic blocks must be packed in a sum which corresponds to /// // which successor to pick. We'll either choose the first branch and pass -/// // it a NAT, or the second branch and pass it nothing. -/// let sum_variants = vec![type_row![NAT], type_row![]]; +/// // it a usize, or the second branch and pass it nothing. +/// let sum_variants = vec![vec![usize_t()].into(), type_row![]]; /// /// // The second argument says what types will be passed through to every /// // successor, in addition to the appropriate `sum_variants` type. -/// let mut entry_b = cfg_builder.entry_builder(sum_variants.clone(), type_row![NAT])?; +/// let mut entry_b = cfg_builder.entry_builder(sum_variants.clone(), vec![usize_t()].into())?; /// /// let [inw] = entry_b.input_wires_arr(); /// let entry = { @@ -81,10 +80,10 @@ use crate::{hugr::HugrMut, type_row, Hugr}; /// }; /// /// // This block will be the first successor of the entry node. It takes two -/// // `NAT` arguments: one from the `sum_variants` type, and another from the +/// // `usize` arguments: one from the `sum_variants` type, and another from the /// // entry node's `other_outputs`. /// let mut successor_builder = cfg_builder.simple_block_builder( -/// inout_sig(type_row![NAT, NAT], NAT), +/// inout_sig(vec![usize_t(), usize_t()], usize_t()), /// 1, // only one successor to this block /// )?; /// let successor_a = { @@ -98,7 +97,7 @@ use crate::{hugr::HugrMut, type_row, Hugr}; /// }; /// /// // The only argument to this block is the entry node's `other_outputs`. -/// let mut successor_builder = cfg_builder.simple_block_builder(endo_sig(NAT), 1)?; +/// let mut successor_builder = cfg_builder.simple_block_builder(endo_sig(usize_t()), 1)?; /// let successor_b = { /// let sum_unary = successor_builder.add_load_value(ops::Value::unary_unit_sum()); /// let [in_wire] = successor_builder.input_wires_arr(); @@ -392,7 +391,11 @@ impl + AsRef> BlockBuilder { Dataflow::set_outputs(self, [branch_wire].into_iter().chain(outputs)) } fn create(base: B, block_n: Node) -> Result { - let block_op = base.get_optype(block_n).as_dataflow_block().unwrap(); + let block_op = base + .as_ref() + .get_optype(block_n) + .as_dataflow_block() + .unwrap(); let signature = block_op.inner_signature(); let db = DFGBuilder::create_with_io(base, block_n, signature)?; Ok(BlockBuilder::from_dfg_builder(db)) @@ -464,9 +467,10 @@ impl BlockBuilder { pub(crate) mod test { use crate::builder::{DataflowSubContainer, ModuleBuilder}; + use crate::extension::prelude::usize_t; use crate::hugr::validate::InterGraphEdgeError; use crate::hugr::ValidationError; - use crate::{builder::test::NAT, type_row}; + use crate::type_row; use cool_asserts::assert_matches; use super::*; @@ -475,13 +479,13 @@ pub(crate) mod test { let build_result = { let mut module_builder = ModuleBuilder::new(); let mut func_builder = module_builder - .define_function("main", Signature::new(vec![NAT], type_row![NAT]))?; + .define_function("main", Signature::new(vec![usize_t()], vec![usize_t()]))?; let _f_id = { let [int] = func_builder.input_wires_arr(); let cfg_id = { let mut cfg_builder = - func_builder.cfg_builder(vec![(NAT, int)], type_row![NAT])?; + func_builder.cfg_builder(vec![(usize_t(), int)], vec![usize_t()].into())?; build_basic_cfg(&mut cfg_builder)?; cfg_builder.finish_sub_container()? @@ -498,7 +502,7 @@ pub(crate) mod test { } #[test] fn basic_cfg_hugr() -> Result<(), BuildError> { - let mut cfg_builder = CFGBuilder::new(Signature::new(type_row![NAT], type_row![NAT]))?; + let mut cfg_builder = CFGBuilder::new(Signature::new(vec![usize_t()], vec![usize_t()]))?; build_basic_cfg(&mut cfg_builder)?; assert_matches!(cfg_builder.finish_prelude_hugr(), Ok(_)); @@ -508,7 +512,8 @@ pub(crate) mod test { pub(crate) fn build_basic_cfg + AsRef>( cfg_builder: &mut CFGBuilder, ) -> Result<(), BuildError> { - let sum2_variants = vec![type_row![NAT], type_row![NAT]]; + let usize_row: TypeRow = vec![usize_t()].into(); + let sum2_variants = vec![usize_row.clone(), usize_row]; let mut entry_b = cfg_builder.entry_builder_exts( sum2_variants.clone(), type_row![], @@ -520,8 +525,8 @@ pub(crate) mod test { let sum = entry_b.make_sum(1, sum2_variants, [inw])?; entry_b.finish_with_outputs(sum, [])? }; - let mut middle_b = - cfg_builder.simple_block_builder(Signature::new(type_row![NAT], type_row![NAT]), 1)?; + let mut middle_b = cfg_builder + .simple_block_builder(Signature::new(vec![usize_t()], vec![usize_t()]), 1)?; let middle = { let c = middle_b.add_load_const(ops::Value::unary_unit_sum()); let [inw] = middle_b.input_wires_arr(); @@ -535,7 +540,7 @@ pub(crate) mod test { } #[test] fn test_dom_edge() -> Result<(), BuildError> { - let mut cfg_builder = CFGBuilder::new(Signature::new(type_row![NAT], type_row![NAT]))?; + let mut cfg_builder = CFGBuilder::new(Signature::new(vec![usize_t()], vec![usize_t()]))?; let sum_tuple_const = cfg_builder.add_constant(ops::Value::unary_unit_sum()); let sum_variants = vec![type_row![]]; @@ -551,7 +556,7 @@ pub(crate) mod test { entry_b.finish_with_outputs(sum, [])? }; let mut middle_b = - cfg_builder.simple_block_builder(Signature::new(type_row![], type_row![NAT]), 1)?; + cfg_builder.simple_block_builder(Signature::new(type_row![], vec![usize_t()]), 1)?; let middle = { let c = middle_b.load_const(&sum_tuple_const); middle_b.finish_with_outputs(c, [inw])? @@ -566,18 +571,19 @@ pub(crate) mod test { #[test] fn test_non_dom_edge() -> Result<(), BuildError> { - let mut cfg_builder = CFGBuilder::new(Signature::new(type_row![NAT], type_row![NAT]))?; + let mut cfg_builder = CFGBuilder::new(Signature::new(vec![usize_t()], vec![usize_t()]))?; let sum_tuple_const = cfg_builder.add_constant(ops::Value::unary_unit_sum()); let sum_variants = vec![type_row![]]; - let mut middle_b = - cfg_builder.simple_block_builder(Signature::new(type_row![NAT], type_row![NAT]), 1)?; + let mut middle_b = cfg_builder + .simple_block_builder(Signature::new(vec![usize_t()], vec![usize_t()]), 1)?; let [inw] = middle_b.input_wires_arr(); let middle = { let c = middle_b.load_const(&sum_tuple_const); middle_b.finish_with_outputs(c, [inw])? }; - let mut entry_b = cfg_builder.entry_builder(sum_variants.clone(), type_row![NAT])?; + let mut entry_b = + cfg_builder.entry_builder(sum_variants.clone(), vec![usize_t()].into())?; let entry = { let sum = entry_b.load_const(&sum_tuple_const); // entry block uses wire from middle block even though middle block diff --git a/hugr-core/src/builder/circuit.rs b/hugr-core/src/builder/circuit.rs index 112cb83fb..5a2b18a04 100644 --- a/hugr-core/src/builder/circuit.rs +++ b/hugr-core/src/builder/circuit.rs @@ -243,26 +243,24 @@ mod test { use super::*; use cool_asserts::assert_matches; - use crate::extension::{ExtensionId, ExtensionSet}; + use crate::builder::{Container, HugrBuilder, ModuleBuilder}; + use crate::extension::prelude::{qb_t, usize_t}; + use crate::extension::{ExtensionId, ExtensionSet, PRELUDE_REGISTRY}; use crate::std_extensions::arithmetic::float_types::{self, ConstF64}; use crate::utils::test_quantum_extension::{ self, cx_gate, h_gate, measure, q_alloc, q_discard, rz_f64, }; use crate::Extension; use crate::{ - builder::{ - test::{build_main, NAT, QB}, - DataflowSubContainer, - }, - extension::prelude::BOOL_T, - type_row, + builder::{test::build_main, DataflowSubContainer}, + extension::prelude::bool_t, types::Signature, }; #[test] fn simple_linear() { let build_res = build_main( - Signature::new(type_row![QB, QB], type_row![QB, QB]) + Signature::new(vec![qb_t(), qb_t()], vec![qb_t(), qb_t()]) .with_extension_delta(test_quantum_extension::EXTENSION_ID) .with_extension_delta(float_types::EXTENSION_ID) .into(), @@ -298,33 +296,57 @@ mod test { #[test] fn with_nonlinear_and_outputs() { let my_ext_name: ExtensionId = "MyExt".try_into().unwrap(); - let mut my_ext = Extension::new_test(my_ext_name.clone()); - let my_custom_op = my_ext.simple_ext_op("MyOp", Signature::new(vec![QB, NAT], vec![QB])); - - let build_res = build_main( - Signature::new(type_row![QB, QB, NAT], type_row![QB, QB, BOOL_T]) + let my_ext = Extension::new_test_arc(my_ext_name.clone(), |ext, extension_ref| { + ext.add_op( + "MyOp".into(), + "".to_string(), + Signature::new(vec![qb_t(), usize_t()], vec![qb_t()]), + extension_ref, + ) + .unwrap(); + }); + let my_custom_op = my_ext + .instantiate_extension_op("MyOp", [], &PRELUDE_REGISTRY) + .unwrap(); + + let mut module_builder = ModuleBuilder::new(); + let mut f_build = module_builder + .define_function( + "main", + Signature::new( + vec![qb_t(), qb_t(), usize_t()], + vec![qb_t(), qb_t(), bool_t()], + ) .with_extension_delta(ExtensionSet::from_iter([ test_quantum_extension::EXTENSION_ID, my_ext_name, - ])) - .into(), - |mut f_build| { - let [q0, q1, angle]: [Wire; 3] = f_build.input_wires_arr(); - - let mut linear = f_build.as_circuit([q0, q1]); - - let measure_out = linear - .append(cx_gate(), [0, 1])? - .append_and_consume( - my_custom_op, - [CircuitUnit::Linear(0), CircuitUnit::Wire(angle)], - )? - .append_with_outputs(measure(), [0])?; - - let out_qbs = linear.finish(); - f_build.finish_with_outputs(out_qbs.into_iter().chain(measure_out)) - }, - ); + ])), + ) + .unwrap(); + + let [q0, q1, angle]: [Wire; 3] = f_build.input_wires_arr(); + + let mut linear = f_build.as_circuit([q0, q1]); + + let measure_out = linear + .append(cx_gate(), [0, 1]) + .unwrap() + .append_and_consume( + my_custom_op, + [CircuitUnit::Linear(0), CircuitUnit::Wire(angle)], + ) + .unwrap() + .append_with_outputs(measure(), [0]) + .unwrap(); + + let out_qbs = linear.finish(); + f_build + .finish_with_outputs(out_qbs.into_iter().chain(measure_out)) + .unwrap(); + + let mut registry = test_quantum_extension::REG.clone(); + registry.register(my_ext).unwrap(); + let build_res = module_builder.finish_hugr(®istry); assert_matches!(build_res, Ok(_)); } @@ -332,7 +354,7 @@ mod test { #[test] fn ancillae() { let build_res = build_main( - Signature::new_endo(QB) + Signature::new_endo(qb_t()) .with_extension_delta(test_quantum_extension::EXTENSION_ID) .into(), |mut f_build| { @@ -370,7 +392,7 @@ mod test { #[test] fn circuit_builder_errors() { let _build_res = build_main( - Signature::new_endo(type_row![QB, QB]).into(), + Signature::new_endo(vec![qb_t(), qb_t()]).into(), |mut f_build| { let mut circ = f_build.as_circuit(f_build.input_wires()); let [q0, q1] = circ.tracked_units_arr(); diff --git a/hugr-core/src/builder/conditional.rs b/hugr-core/src/builder/conditional.rs index 685d2c889..4e8039992 100644 --- a/hugr-core/src/builder/conditional.rs +++ b/hugr-core/src/builder/conditional.rs @@ -214,11 +214,9 @@ mod test { use crate::builder::{DataflowSubContainer, ModuleBuilder}; + use crate::extension::prelude::usize_t; use crate::{ - builder::{ - test::{n_identity, NAT}, - Dataflow, - }, + builder::{test::n_identity, Dataflow}, ops::Value, type_row, }; @@ -229,8 +227,8 @@ mod test { fn basic_conditional() -> Result<(), BuildError> { let mut conditional_b = ConditionalBuilder::new_exts( [type_row![], type_row![]], - type_row![NAT], - type_row![NAT], + vec![usize_t()], + vec![usize_t()], ExtensionSet::new(), )?; @@ -244,14 +242,14 @@ mod test { let build_result: Result = { let mut module_builder = ModuleBuilder::new(); let mut fbuild = module_builder - .define_function("main", Signature::new(type_row![NAT], type_row![NAT]))?; + .define_function("main", Signature::new(vec![usize_t()], vec![usize_t()]))?; let tru_const = fbuild.add_constant(Value::true_val()); let _fdef = { let const_wire = fbuild.load_const(&tru_const); let [int] = fbuild.input_wires_arr(); let conditional_id = { - let other_inputs = vec![(NAT, int)]; - let outputs = vec![NAT].into(); + let other_inputs = vec![(usize_t(), int)]; + let outputs = vec![usize_t()].into(); let mut conditional_b = fbuild.conditional_builder( ([type_row![], type_row![]], const_wire), other_inputs, diff --git a/hugr-core/src/builder/dataflow.rs b/hugr-core/src/builder/dataflow.rs index 04ec38b4b..5d0c7e0c2 100644 --- a/hugr-core/src/builder/dataflow.rs +++ b/hugr-core/src/builder/dataflow.rs @@ -315,8 +315,8 @@ pub(crate) mod test { use crate::builder::{ endo_sig, inout_sig, BuilderWiringError, DataflowSubContainer, ModuleBuilder, }; + use crate::extension::prelude::{bool_t, qb_t, usize_t}; use crate::extension::prelude::{Lift, Noop}; - use crate::extension::prelude::{BOOL_T, USIZE_T}; use crate::extension::{ExtensionId, SignatureError, EMPTY_REG, PRELUDE_REGISTRY}; use crate::hugr::validate::InterGraphEdgeError; use crate::ops::{handle::NodeHandle, OpTag}; @@ -325,28 +325,27 @@ pub(crate) mod test { use crate::std_extensions::logic::test::and_op; use crate::types::type_param::TypeParam; use crate::types::{EdgeKind, FuncValueType, RowVariable, Signature, Type, TypeBound, TypeRV}; - use crate::utils::test_quantum_extension::h_gate; - use crate::{ - builder::test::{n_identity, BIT, NAT, QB}, - type_row, Wire, - }; + use crate::utils::test_quantum_extension::{self, h_gate}; + use crate::{builder::test::n_identity, type_row, Wire}; use super::super::test::simple_dfg_hugr; use super::*; #[test] fn nested_identity() -> Result<(), BuildError> { let build_result = { - let mut outer_builder = DFGBuilder::new(endo_sig(type_row![NAT, QB]))?; + let mut outer_builder = DFGBuilder::new(endo_sig(vec![usize_t(), qb_t()]))?; let [int, qb] = outer_builder.input_wires_arr(); let q_out = outer_builder.add_dataflow_op(h_gate(), vec![qb])?; - let inner_builder = outer_builder.dfg_builder_endo([(NAT, int)])?; + let inner_builder = outer_builder.dfg_builder_endo([(usize_t(), int)])?; let inner_id = n_identity(inner_builder)?; - outer_builder - .finish_prelude_hugr_with_outputs(inner_id.outputs().chain(q_out.outputs())) + outer_builder.finish_hugr_with_outputs( + inner_id.outputs().chain(q_out.outputs()), + &test_quantum_extension::REG, + ) }; assert_eq!(build_result.err(), None); @@ -360,11 +359,11 @@ pub(crate) mod test { F: FnOnce(&mut DFGBuilder) -> Result<(), BuildError>, { let build_result = { - let mut builder = DFGBuilder::new(inout_sig(BOOL_T, type_row![BOOL_T, BOOL_T]))?; + let mut builder = DFGBuilder::new(inout_sig(bool_t(), vec![bool_t(), bool_t()]))?; f(&mut builder)?; - builder.finish_hugr(&EMPTY_REG) + builder.finish_hugr(&test_quantum_extension::REG) }; assert_matches!(build_result, Ok(_), "Failed on example: {}", msg); @@ -408,7 +407,7 @@ pub(crate) mod test { let mut module_builder = ModuleBuilder::new(); let f_build = module_builder - .define_function("main", Signature::new(type_row![QB], type_row![QB, QB]))?; + .define_function("main", Signature::new(vec![qb_t()], vec![qb_t(), qb_t()]))?; let [q1] = f_build.input_wires_arr(); f_build.finish_with_outputs([q1, q1])?; @@ -422,7 +421,7 @@ pub(crate) mod test { error: BuilderWiringError::NoCopyLinear { typ, .. }, .. }) - if typ == QB + if typ == qb_t() ); } @@ -431,19 +430,19 @@ pub(crate) mod test { let builder = || -> Result { let mut f_build = FunctionBuilder::new( "main", - Signature::new(type_row![BIT], type_row![BIT]).with_prelude(), + Signature::new(vec![bool_t()], vec![bool_t()]).with_prelude(), )?; let [i1] = f_build.input_wires_arr(); - let noop = f_build.add_dataflow_op(Noop(BIT), [i1])?; + let noop = f_build.add_dataflow_op(Noop(bool_t()), [i1])?; let i1 = noop.out_wire(0); let mut nested = f_build.dfg_builder( - Signature::new(type_row![], type_row![BIT]).with_prelude(), + Signature::new(type_row![], vec![bool_t()]).with_prelude(), [], )?; - let id = nested.add_dataflow_op(Noop(BIT), [i1])?; + let id = nested.add_dataflow_op(Noop(bool_t()), [i1])?; let nested = nested.finish_with_outputs([id.out_wire(0)])?; @@ -458,21 +457,21 @@ pub(crate) mod test { let builder = || -> Result<(Hugr, Node), BuildError> { let mut f_build = FunctionBuilder::new( "main", - Signature::new(type_row![BIT], type_row![BIT]).with_prelude(), + Signature::new(vec![bool_t()], vec![bool_t()]).with_prelude(), )?; let f_node = f_build.container_node(); let [i0] = f_build.input_wires_arr(); - let noop0 = f_build.add_dataflow_op(Noop(BIT), [i0])?; + let noop0 = f_build.add_dataflow_op(Noop(bool_t()), [i0])?; // Some some order edges f_build.set_order(&f_build.io()[0], &noop0.node()); f_build.set_order(&noop0.node(), &f_build.io()[1]); // Add a new input and output, and connect them with a noop in between - f_build.add_output(QB); - let i1 = f_build.add_input(QB); - let noop1 = f_build.add_dataflow_op(Noop(QB), [i1])?; + f_build.add_output(qb_t()); + let i1 = f_build.add_input(qb_t()); + let noop1 = f_build.add_dataflow_op(Noop(qb_t()), [i1])?; let hugr = f_build.finish_prelude_hugr_with_outputs([noop0.out_wire(0), noop1.out_wire(0)])?; @@ -482,21 +481,26 @@ pub(crate) mod test { let (hugr, f_node) = builder().unwrap_or_else(|e| panic!("{e}")); let func_sig = hugr.get_optype(f_node).inner_function_type().unwrap(); - assert_eq!(func_sig.io(), (&type_row![BIT, QB], &type_row![BIT, QB])); + assert_eq!( + func_sig.io(), + ( + &vec![bool_t(), qb_t()].into(), + &vec![bool_t(), qb_t()].into() + ) + ); } #[test] fn error_on_linear_inter_graph_edge() -> Result<(), BuildError> { - let mut f_build = - FunctionBuilder::new("main", Signature::new(type_row![QB], type_row![QB]))?; + let mut f_build = FunctionBuilder::new("main", Signature::new(vec![qb_t()], vec![qb_t()]))?; let [i1] = f_build.input_wires_arr(); - let noop = f_build.add_dataflow_op(Noop(QB), [i1])?; + let noop = f_build.add_dataflow_op(Noop(qb_t()), [i1])?; let i1 = noop.out_wire(0); - let mut nested = f_build.dfg_builder(Signature::new(type_row![], type_row![QB]), [])?; + let mut nested = f_build.dfg_builder(Signature::new(type_row![], vec![qb_t()]), [])?; - let id_res = nested.add_dataflow_op(Noop(QB), [i1]); + let id_res = nested.add_dataflow_op(Noop(qb_t()), [i1]); // The error would anyway be caught in validation when we finish the Hugr, // but the builder catches it earlier @@ -520,7 +524,7 @@ pub(crate) mod test { #[test] fn insert_hugr() -> Result<(), BuildError> { // Create a simple DFG - let mut dfg_builder = DFGBuilder::new(Signature::new(type_row![BIT], type_row![BIT]))?; + let mut dfg_builder = DFGBuilder::new(Signature::new(vec![bool_t()], vec![bool_t()]))?; let [i1] = dfg_builder.input_wires_arr(); dfg_builder.set_metadata("x", 42); let dfg_hugr = dfg_builder.finish_hugr_with_outputs([i1], &EMPTY_REG)?; @@ -530,7 +534,7 @@ pub(crate) mod test { let (dfg_node, f_node) = { let mut f_build = module_builder - .define_function("main", Signature::new(type_row![BIT], type_row![BIT]))?; + .define_function("main", Signature::new(vec![bool_t()], vec![bool_t()]))?; let [i1] = f_build.input_wires_arr(); let dfg = f_build.add_hugr_with_wires(dfg_hugr, [i1])?; @@ -555,18 +559,18 @@ pub(crate) mod test { let xb: ExtensionId = "B".try_into().unwrap(); let xc: ExtensionId = "C".try_into().unwrap(); - let mut parent = DFGBuilder::new(endo_sig(BIT))?; + let mut parent = DFGBuilder::new(endo_sig(bool_t()))?; let [w] = parent.input_wires_arr(); // A box which adds extensions A and B, via child Lift nodes - let mut add_ab = parent.dfg_builder(endo_sig(BIT), [w])?; + let mut add_ab = parent.dfg_builder(endo_sig(bool_t()), [w])?; let [w] = add_ab.input_wires_arr(); - let lift_a = add_ab.add_dataflow_op(Lift::new(type_row![BIT], xa.clone()), [w])?; + let lift_a = add_ab.add_dataflow_op(Lift::new(vec![bool_t()].into(), xa.clone()), [w])?; let [w] = lift_a.outputs_arr(); - let lift_b = add_ab.add_dataflow_op(Lift::new(type_row![BIT], xb), [w])?; + let lift_b = add_ab.add_dataflow_op(Lift::new(vec![bool_t()].into(), xb), [w])?; let [w] = lift_b.outputs_arr(); let add_ab = add_ab.finish_with_outputs([w])?; @@ -574,14 +578,14 @@ pub(crate) mod test { // Add another node (a sibling to add_ab) which adds extension C // via a child lift node - let mut add_c = parent.dfg_builder(endo_sig(BIT), [w])?; + let mut add_c = parent.dfg_builder(endo_sig(bool_t()), [w])?; let [w] = add_c.input_wires_arr(); - let lift_c = add_c.add_dataflow_op(Lift::new(type_row![BIT], xc), [w])?; + let lift_c = add_c.add_dataflow_op(Lift::new(vec![bool_t()].into(), xc), [w])?; let wires: Vec = lift_c.outputs().collect(); let add_c = add_c.finish_with_outputs(wires)?; let [w] = add_c.outputs_arr(); - parent.finish_hugr_with_outputs([w], &EMPTY_REG)?; + parent.finish_hugr_with_outputs([w], &test_quantum_extension::REG)?; Ok(()) } @@ -647,7 +651,7 @@ pub(crate) mod test { PolyFuncType::new( [TypeParam::new_list(TypeBound::Copyable)], Signature::new( - Type::new_function(FuncValueType::new(USIZE_T, tv.clone())), + Type::new_function(FuncValueType::new(usize_t(), tv.clone())), vec![], ), ), @@ -656,7 +660,7 @@ pub(crate) mod test { // But cannot eval it... let ev = e.instantiate_extension_op( "eval", - [vec![USIZE_T.into()].into(), vec![tv.into()].into()], + [vec![usize_t().into()].into(), vec![tv.into()].into()], &PRELUDE_REGISTRY, ); assert_eq!( @@ -673,11 +677,11 @@ pub(crate) mod test { let (mut hugr, load_constant, call) = { let mut builder = ModuleBuilder::new(); let func = builder - .declare("func", Signature::new_endo(BOOL_T).into()) + .declare("func", Signature::new_endo(bool_t()).into()) .unwrap(); let (load_constant, call) = { let mut builder = builder - .define_function("main", Signature::new(Type::EMPTY_TYPEROW, BOOL_T)) + .define_function("main", Signature::new(Type::EMPTY_TYPEROW, bool_t())) .unwrap(); let load_constant = builder.add_load_value(Value::true_val()); let [r] = builder diff --git a/hugr-core/src/builder/module.rs b/hugr-core/src/builder/module.rs index d2f144e04..1df328d83 100644 --- a/hugr-core/src/builder/module.rs +++ b/hugr-core/src/builder/module.rs @@ -159,13 +159,10 @@ impl + AsRef> ModuleBuilder { mod test { use cool_asserts::assert_matches; + use crate::extension::prelude::usize_t; use crate::{ - builder::{ - test::{n_identity, NAT}, - Dataflow, DataflowSubContainer, - }, + builder::{test::n_identity, Dataflow, DataflowSubContainer}, extension::{EMPTY_REG, PRELUDE_REGISTRY}, - type_row, types::Signature, }; @@ -177,7 +174,7 @@ mod test { let f_id = module_builder.declare( "main", - Signature::new(type_row![NAT], type_row![NAT]).into(), + Signature::new(vec![usize_t()], vec![usize_t()]).into(), )?; let mut f_build = module_builder.define_declaration(&f_id)?; @@ -217,10 +214,14 @@ mod test { let build_result = { let mut module_builder = ModuleBuilder::new(); - let mut f_build = module_builder - .define_function("main", Signature::new(type_row![NAT], type_row![NAT, NAT]))?; - let local_build = f_build - .define_function("local", Signature::new(type_row![NAT], type_row![NAT, NAT]))?; + let mut f_build = module_builder.define_function( + "main", + Signature::new(vec![usize_t()], vec![usize_t(), usize_t()]), + )?; + let local_build = f_build.define_function( + "local", + Signature::new(vec![usize_t()], vec![usize_t(), usize_t()]), + )?; let [wire] = local_build.input_wires_arr(); let f_id = local_build.finish_with_outputs([wire, wire])?; diff --git a/hugr-core/src/builder/tail_loop.rs b/hugr-core/src/builder/tail_loop.rs index 29134ce03..b3051c72e 100644 --- a/hugr-core/src/builder/tail_loop.rs +++ b/hugr-core/src/builder/tail_loop.rs @@ -106,12 +106,10 @@ impl TailLoopBuilder { mod test { use cool_asserts::assert_matches; + use crate::extension::prelude::bool_t; use crate::{ - builder::{ - test::{BIT, NAT}, - DataflowSubContainer, HugrBuilder, ModuleBuilder, SubContainer, - }, - extension::prelude::{ConstUsize, Lift, PRELUDE_ID, USIZE_T}, + builder::{DataflowSubContainer, HugrBuilder, ModuleBuilder, SubContainer}, + extension::prelude::{usize_t, ConstUsize, Lift, PRELUDE_ID}, hugr::ValidationError, ops::Value, type_row, @@ -123,7 +121,7 @@ mod test { fn basic_loop() -> Result<(), BuildError> { let build_result: Result = { let mut loop_b = - TailLoopBuilder::new_exts(vec![], vec![BIT], vec![USIZE_T], PRELUDE_ID)?; + TailLoopBuilder::new_exts(vec![], vec![bool_t()], vec![usize_t()], PRELUDE_ID)?; let [i1] = loop_b.input_wires_arr(); let const_wire = loop_b.add_load_value(ConstUsize::new(1)); @@ -142,15 +140,21 @@ mod test { let mut module_builder = ModuleBuilder::new(); let mut fbuild = module_builder.define_function( "main", - Signature::new(type_row![BIT], type_row![NAT]).with_prelude(), + Signature::new(vec![bool_t()], vec![usize_t()]).with_prelude(), )?; let _fdef = { let [b1] = fbuild - .add_dataflow_op(Lift::new(type_row![BIT], PRELUDE_ID), fbuild.input_wires())? + .add_dataflow_op( + Lift::new(vec![bool_t()].into(), PRELUDE_ID), + fbuild.input_wires(), + )? .outputs_arr(); let loop_id = { - let mut loop_b = - fbuild.tail_loop_builder(vec![(BIT, b1)], vec![], type_row![NAT])?; + let mut loop_b = fbuild.tail_loop_builder( + vec![(bool_t(), b1)], + vec![], + vec![usize_t()].into(), + )?; let signature = loop_b.loop_signature()?.clone(); let const_val = Value::true_val(); let const_wire = loop_b.add_load_const(Value::true_val()); @@ -164,7 +168,7 @@ mod test { let output_row = loop_b.internal_output_row()?; let mut conditional_b = loop_b.conditional_builder( ([type_row![], type_row![]], const_wire), - vec![(BIT, b1)], + vec![(bool_t(), b1)], output_row, )?; diff --git a/hugr-core/src/export.rs b/hugr-core/src/export.rs index 093368b60..06be24fa9 100644 --- a/hugr-core/src/export.rs +++ b/hugr-core/src/export.rs @@ -443,10 +443,10 @@ impl<'a> Context<'a> { let poly_func_type = match opdef.signature_func() { SignatureFunc::PolyFuncType(poly_func_type) => poly_func_type, - _ => return self.make_named_global_ref(opdef.extension(), opdef.name()), + _ => return self.make_named_global_ref(opdef.extension_id(), opdef.name()), }; - let key = (opdef.extension().clone(), opdef.name().clone()); + let key = (opdef.extension_id().clone(), opdef.name().clone()); let entry = self.decl_operations.entry(key); let node = match entry { @@ -467,7 +467,7 @@ impl<'a> Context<'a> { }; let decl = self.with_local_scope(node, |this| { - let name = this.make_qualified_name(opdef.extension(), opdef.name()); + let name = this.make_qualified_name(opdef.extension_id(), opdef.name()); let (params, constraints, r#type) = this.export_poly_func_type(poly_func_type); let decl = this.bump.alloc(model::OperationDecl { name, @@ -509,48 +509,24 @@ impl<'a> Context<'a> { /// like for the other nodes since the ports are control flow ports. pub fn export_block_signature(&mut self, block: &DataflowBlock) -> model::TermId { let inputs = { - let mut inputs = BumpVec::with_capacity_in(block.inputs.len(), self.bump); - for input in block.inputs.iter() { - inputs.push(self.export_type(input)); - } - let inputs = self.make_term(model::Term::List { - items: inputs.into_bump_slice(), - tail: None, - }); + let inputs = self.export_type_row(&block.inputs); let inputs = self.make_term(model::Term::Control { values: inputs }); self.make_term(model::Term::List { - items: self.bump.alloc_slice_copy(&[inputs]), - tail: None, + parts: self.bump.alloc_slice_copy(&[model::ListPart::Item(inputs)]), }) }; - let tail = { - let mut tail = BumpVec::with_capacity_in(block.other_outputs.len(), self.bump); - for other_output in block.other_outputs.iter() { - tail.push(self.export_type(other_output)); - } - self.make_term(model::Term::List { - items: tail.into_bump_slice(), - tail: None, - }) - }; + let tail = self.export_type_row(&block.other_outputs); let outputs = { let mut outputs = BumpVec::with_capacity_in(block.sum_rows.len(), self.bump); for sum_row in block.sum_rows.iter() { - let mut variant = BumpVec::with_capacity_in(sum_row.len(), self.bump); - for typ in sum_row.iter() { - variant.push(self.export_type(typ)); - } - let variant = self.make_term(model::Term::List { - items: variant.into_bump_slice(), - tail: Some(tail), - }); - outputs.push(self.make_term(model::Term::Control { values: variant })); + let variant = self.export_type_row_with_tail(sum_row, Some(tail)); + let control = self.make_term(model::Term::Control { values: variant }); + outputs.push(model::ListPart::Item(control)); } self.make_term(model::Term::List { - items: outputs.into_bump_slice(), - tail: None, + parts: outputs.into_bump_slice(), }) }; @@ -772,10 +748,12 @@ impl<'a> Context<'a> { TypeArg::String { arg } => self.make_term(model::Term::Str(self.bump.alloc_str(arg))), TypeArg::Sequence { elems } => { // For now we assume that the sequence is meant to be a list. - let items = self - .bump - .alloc_slice_fill_iter(elems.iter().map(|elem| self.export_type_arg(elem))); - self.make_term(model::Term::List { items, tail: None }) + let parts = self.bump.alloc_slice_fill_iter( + elems + .iter() + .map(|elem| model::ListPart::Item(self.export_type_arg(elem))), + ); + self.make_term(model::Term::List { parts }) } TypeArg::Extensions { es } => self.export_ext_set(es), TypeArg::Variable { v } => self.export_type_arg_var(v), @@ -798,32 +776,53 @@ impl<'a> Context<'a> { pub fn export_sum_type(&mut self, t: &SumType) -> model::TermId { match t { SumType::Unit { size } => { - let items = self.bump.alloc_slice_fill_iter((0..*size).map(|_| { - self.make_term(model::Term::List { - items: &[], - tail: None, - }) + let parts = self.bump.alloc_slice_fill_iter((0..*size).map(|_| { + model::ListPart::Item(self.make_term(model::Term::List { parts: &[] })) })); - let list = model::Term::List { items, tail: None }; - let variants = self.make_term(list); + let variants = self.make_term(model::Term::List { parts }); self.make_term(model::Term::Adt { variants }) } SumType::General { rows } => { - let items = self - .bump - .alloc_slice_fill_iter(rows.iter().map(|row| self.export_type_row(row))); - let list = model::Term::List { items, tail: None }; + let parts = self.bump.alloc_slice_fill_iter( + rows.iter() + .map(|row| model::ListPart::Item(self.export_type_row(row))), + ); + let list = model::Term::List { parts }; let variants = { self.make_term(list) }; self.make_term(model::Term::Adt { variants }) } } } - pub fn export_type_row(&mut self, t: &TypeRowBase) -> model::TermId { - let mut items = BumpVec::with_capacity_in(t.len(), self.bump); - items.extend(t.iter().map(|row| self.export_type(row))); - let items = items.into_bump_slice(); - self.make_term(model::Term::List { items, tail: None }) + #[inline] + pub fn export_type_row(&mut self, row: &TypeRowBase) -> model::TermId { + self.export_type_row_with_tail(row, None) + } + + pub fn export_type_row_with_tail( + &mut self, + row: &TypeRowBase, + tail: Option, + ) -> model::TermId { + let mut parts = BumpVec::with_capacity_in(row.len() + tail.is_some() as usize, self.bump); + + for t in row.iter() { + match t.as_type_enum() { + TypeEnum::RowVar(var) => { + parts.push(model::ListPart::Splice(self.export_row_var(var.as_rv()))); + } + _ => { + parts.push(model::ListPart::Item(self.export_type(t))); + } + } + } + + if let Some(tail) = tail { + parts.push(model::ListPart::Splice(tail)); + } + + let parts = parts.into_bump_slice(); + self.make_term(model::Term::List { parts }) } /// Exports a `TypeParam` to a term. @@ -855,12 +854,12 @@ impl<'a> Context<'a> { self.make_term(model::Term::ListType { item_type }) } TypeParam::Tuple { params } => { - let items = self.bump.alloc_slice_fill_iter( + let parts = self.bump.alloc_slice_fill_iter( params .iter() - .map(|param| self.export_type_param(param, None)), + .map(|param| model::ListPart::Item(self.export_type_param(param, None))), ); - let types = self.make_term(model::Term::List { items, tail: None }); + let types = self.make_term(model::Term::List { parts }); self.make_term(model::Term::ApplyFull { global: model::GlobalRef::Named(TERM_PARAM_TUPLE), args: self.bump.alloc_slice_copy(&[types]), @@ -873,54 +872,26 @@ impl<'a> Context<'a> { } } - pub fn export_ext_set(&mut self, t: &ExtensionSet) -> model::TermId { - // Extension sets with variables are encoded using a hack: a variable in the - // extension set is represented by converting its index into a string. - // Until we have a better representation for extension sets, we therefore - // need to try and parse each extension as a number to determine if it is - // a variable or an extension. - - // NOTE: This overprovisions the capacity since some of the entries of the row - // may be variables. Since we panic when there is more than one variable, this - // may at most waste one slot. That is way better than having to allocate - // a temporary vector. - // - // Also `ExtensionSet` has no way of reporting its size, so we have to count - // the elements by iterating over them... - let capacity = t.iter().count(); - let mut extensions = BumpVec::with_capacity_in(capacity, self.bump); - let mut rest = None; - - for ext in t.iter() { - if let Ok(index) = ext.parse::() { - // Extension sets in the model support at most one variable. This is a - // deliberate limitation so that extension sets behave like polymorphic rows. - // The type theory of such rows and how to apply them to model (co)effects - // is well understood. - // - // Extension sets in `hugr-core` at this point have no such restriction. - // However, it appears that so far we never actually use extension sets with - // multiple variables, except for extension sets that are generated through - // property testing. - if rest.is_some() { - // TODO: We won't need this anymore once we have a core representation - // that ensures that extension sets have at most one variable. - panic!("Extension set with multiple variables") - } + pub fn export_ext_set(&mut self, ext_set: &ExtensionSet) -> model::TermId { + let capacity = ext_set.iter().size_hint().0; + let mut parts = BumpVec::with_capacity_in(capacity, self.bump); - let node = self.local_scope.expect("local variable out of scope"); - rest = Some( - self.module - .insert_term(model::Term::Var(model::LocalRef::Index(node, index as _))), - ); - } else { - extensions.push(self.bump.alloc_str(ext) as &str); + for ext in ext_set.iter() { + // `ExtensionSet`s represent variables by extension names that parse to integers. + match ext.parse::() { + Ok(var) => { + let node = self.local_scope.expect("local variable out of scope"); + let local_ref = model::LocalRef::Index(node, var); + let term = self.make_term(model::Term::Var(local_ref)); + parts.push(model::ExtSetPart::Splice(term)); + } + Err(_) => parts.push(model::ExtSetPart::Extension(self.bump.alloc_str(ext))), } } - let extensions = extensions.into_bump_slice(); - - self.make_term(model::Term::ExtSet { extensions, rest }) + self.make_term(model::Term::ExtSet { + parts: parts.into_bump_slice(), + }) } pub fn export_node_metadata( @@ -955,9 +926,8 @@ mod test { use crate::{ builder::{Dataflow, DataflowSubContainer}, - extension::prelude::QB_T, + extension::prelude::qb_t, std_extensions::arithmetic::float_types, - type_row, types::Signature, utils::test_quantum_extension::{self, cx_gate, h_gate}, Hugr, @@ -966,7 +936,7 @@ mod test { #[fixture] fn test_simple_circuit() -> Hugr { crate::builder::test::build_main( - Signature::new_endo(type_row![QB_T, QB_T]) + Signature::new_endo(vec![qb_t(), qb_t()]) .with_extension_delta(test_quantum_extension::EXTENSION_ID) .with_extension_delta(float_types::EXTENSION_ID) .into(), diff --git a/hugr-core/src/extension.rs b/hugr-core/src/extension.rs index 4d22ba7f0..97c971822 100644 --- a/hugr-core/src/extension.rs +++ b/hugr-core/src/extension.rs @@ -7,7 +7,8 @@ pub use semver::Version; use std::collections::btree_map; use std::collections::{BTreeMap, BTreeSet}; use std::fmt::{Debug, Display, Formatter}; -use std::sync::Arc; +use std::mem; +use std::sync::{Arc, Weak}; use thiserror::Error; @@ -20,24 +21,26 @@ use crate::types::RowVariable; use crate::types::{check_typevar_decl, CustomType, Substitution, TypeBound, TypeName}; use crate::types::{Signature, TypeNameRef}; +mod const_fold; mod op_def; +pub mod prelude; +pub mod resolution; +pub mod simple_op; +mod type_def; + +pub use const_fold::{fold_out_row, ConstFold, ConstFoldResult, Folder}; pub use op_def::{ CustomSignatureFunc, CustomValidator, LowerFunc, OpDef, SignatureFromArgs, SignatureFunc, ValidateJustArgs, ValidateTypeArgs, }; -mod type_def; -pub use type_def::{TypeDef, TypeDefBound}; -mod const_fold; -pub mod prelude; -pub mod simple_op; -pub use const_fold::{fold_out_row, ConstFold, ConstFoldResult, Folder}; pub use prelude::{PRELUDE, PRELUDE_REGISTRY}; +pub use type_def::{TypeDef, TypeDefBound}; #[cfg(feature = "declarative")] pub mod declarative; /// Extension Registries store extensions to be looked up e.g. during validation. -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, Default, PartialEq)] pub struct ExtensionRegistry(BTreeMap>); impl ExtensionRegistry { @@ -95,18 +98,16 @@ impl ExtensionRegistry { } } - /// Registers a new extension to the registry, keeping most up to date if extension exists. + /// Registers a new extension to the registry, keeping the one most up to + /// date if the extension already exists. /// /// If extension IDs match, the extension with the higher version is kept. - /// If versions match, the original extension is kept. - /// Returns a reference to the registered extension if successful. + /// If versions match, the original extension is kept. Returns a reference + /// to the registered extension if successful. /// - /// Takes an Arc to the extension. To avoid cloning Arcs unless necessary, see - /// [`ExtensionRegistry::register_updated_ref`]. - pub fn register_updated( - &mut self, - extension: impl Into>, - ) -> Result<(), ExtensionRegistryError> { + /// Takes an Arc to the extension. To avoid cloning Arcs unless necessary, + /// see [`ExtensionRegistry::register_updated_ref`]. + pub fn register_updated(&mut self, extension: impl Into>) { let extension = extension.into(); match self.0.entry(extension.name().clone()) { btree_map::Entry::Occupied(mut prev) => { @@ -118,11 +119,10 @@ impl ExtensionRegistry { ve.insert(extension); } } - Ok(()) } - /// Registers a new extension to the registry, keeping most up to date if - /// extension exists. + /// Registers a new extension to the registry, keeping the one most up to + /// date if the extension already exists. /// /// If extension IDs match, the extension with the higher version is kept. /// If versions match, the original extension is kept. Returns a reference @@ -130,10 +130,7 @@ impl ExtensionRegistry { /// /// Clones the Arc only when required. For no-cloning version see /// [`ExtensionRegistry::register_updated`]. - pub fn register_updated_ref( - &mut self, - extension: &Arc, - ) -> Result<(), ExtensionRegistryError> { + pub fn register_updated_ref(&mut self, extension: &Arc) { match self.0.entry(extension.name().clone()) { btree_map::Entry::Occupied(mut prev) => { if prev.get().version() < extension.version() { @@ -144,7 +141,6 @@ impl ExtensionRegistry { ve.insert(extension.clone()); } } - Ok(()) } /// Returns the number of extensions in the registry. @@ -158,8 +154,13 @@ impl ExtensionRegistry { } /// Returns an iterator over the extensions in the registry. - pub fn iter(&self) -> impl Iterator)> { - self.0.iter() + pub fn iter(&self) -> <&Self as IntoIterator>::IntoIter { + self.0.values() + } + + /// Returns an iterator over the extensions ids in the registry. + pub fn ids(&self) -> impl Iterator { + self.0.keys() } /// Delete an extension from the registry and return it if it was present. @@ -169,12 +170,38 @@ impl ExtensionRegistry { } impl IntoIterator for ExtensionRegistry { - type Item = (ExtensionId, Arc); + type Item = Arc; - type IntoIter = > as IntoIterator>::IntoIter; + type IntoIter = std::collections::btree_map::IntoValues>; fn into_iter(self) -> Self::IntoIter { - self.0.into_iter() + self.0.into_values() + } +} + +impl<'a> IntoIterator for &'a ExtensionRegistry { + type Item = &'a Arc; + + type IntoIter = std::collections::btree_map::Values<'a, ExtensionId, Arc>; + + fn into_iter(self) -> Self::IntoIter { + self.0.values() + } +} + +impl<'a> Extend<&'a Arc> for ExtensionRegistry { + fn extend>>(&mut self, iter: T) { + for ext in iter { + self.register_updated_ref(ext); + } + } +} + +impl Extend> for ExtensionRegistry { + fn extend>>(&mut self, iter: T) { + for ext in iter { + self.register_updated(ext); + } } } @@ -199,8 +226,13 @@ pub enum SignatureError { #[error("Invalid type arguments for operation")] InvalidTypeArgs, /// The Extension Registry did not contain an Extension referenced by the Signature - #[error("Extension '{0}' not found")] - ExtensionNotFound(ExtensionId), + #[error("Extension '{missing}' not found. Available extensions: {}", + available.iter().map(|e| e.to_string()).collect::>().join(", ") + )] + ExtensionNotFound { + missing: ExtensionId, + available: Vec, + }, /// The Extension was found in the registry, but did not contain the Type(Def) referenced in the Signature #[error("Extension '{exn}' did not contain expected TypeDef '{typ}'")] ExtensionTypeNotFound { exn: ExtensionId, typ: TypeName }, @@ -335,6 +367,45 @@ impl ExtensionValue { pub type ExtensionId = IdentList; /// A extension is a set of capabilities required to execute a graph. +/// +/// These are normally defined once and shared across multiple graphs and +/// operations wrapped in [`Arc`]s inside [`ExtensionRegistry`]. +/// +/// # Example +/// +/// The following example demonstrates how to define a new extension with a +/// custom operation and a custom type. +/// +/// When using `arc`s, the extension can only be modified at creation time. The +/// defined operations and types keep a [`Weak`] reference to their extension. We provide a +/// helper method [`Extension::new_arc`] to aid their definition. +/// +/// ``` +/// # use hugr_core::types::Signature; +/// # use hugr_core::extension::{Extension, ExtensionId, Version}; +/// # use hugr_core::extension::{TypeDefBound}; +/// Extension::new_arc( +/// ExtensionId::new_unchecked("my.extension"), +/// Version::new(0, 1, 0), +/// |ext, extension_ref| { +/// // Add a custom type definition +/// ext.add_type( +/// "MyType".into(), +/// vec![], // No type parameters +/// "Some type".into(), +/// TypeDefBound::any(), +/// extension_ref, +/// ); +/// // Add a custom operation +/// ext.add_op( +/// "MyOp".into(), +/// "Some operation".into(), +/// Signature::new_endo(vec![]), +/// extension_ref, +/// ); +/// }, +/// ); +/// ``` #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] pub struct Extension { /// Extension version, follows semver. @@ -361,6 +432,12 @@ pub struct Extension { impl Extension { /// Creates a new extension with the given name. + /// + /// In most cases extensions are contained inside an [`Arc`] so that they + /// can be shared across hugr instances and operation definitions. + /// + /// See [`Extension::new_arc`] for a more ergonomic way to create boxed + /// extensions. pub fn new(name: ExtensionId, version: Version) -> Self { Self { name, @@ -372,14 +449,63 @@ impl Extension { } } - /// Extend the requirements of this extension with another set of extensions. - pub fn with_reqs(self, extension_reqs: impl Into) -> Self { - Self { - extension_reqs: self.extension_reqs.union(extension_reqs.into()), - ..self + /// Creates a new extension wrapped in an [`Arc`]. + /// + /// The closure lets us use a weak reference to the arc while the extension + /// is being built. This is necessary for calling [`Extension::add_op`] and + /// [`Extension::add_type`]. + pub fn new_arc( + name: ExtensionId, + version: Version, + init: impl FnOnce(&mut Extension, &Weak), + ) -> Arc { + Arc::new_cyclic(|extension_ref| { + let mut ext = Self::new(name, version); + init(&mut ext, extension_ref); + ext + }) + } + + /// Creates a new extension wrapped in an [`Arc`], using a fallible + /// initialization function. + /// + /// The closure lets us use a weak reference to the arc while the extension + /// is being built. This is necessary for calling [`Extension::add_op`] and + /// [`Extension::add_type`]. + pub fn try_new_arc( + name: ExtensionId, + version: Version, + init: impl FnOnce(&mut Extension, &Weak) -> Result<(), E>, + ) -> Result, E> { + // Annoying hack around not having `Arc::try_new_cyclic` that can return + // a Result. + // https://github.com/rust-lang/rust/issues/75861#issuecomment-980455381 + // + // When there is an error, we store it in `error` and return it at the + // end instead of the partially-initialized extension. + let mut error = None; + let ext = Arc::new_cyclic(|extension_ref| { + let mut ext = Self::new(name, version); + match init(&mut ext, extension_ref) { + Ok(_) => ext, + Err(e) => { + error = Some(e); + ext + } + } + }); + match error { + Some(e) => Err(e), + None => Ok(ext), } } + /// Extend the requirements of this extension with another set of extensions. + pub fn add_requirements(&mut self, extension_reqs: impl Into) { + let reqs = mem::take(&mut self.extension_reqs); + self.extension_reqs = reqs.union(extension_reqs.into()); + } + /// Allows read-only access to the operations in this Extension pub fn get_op(&self, name: &OpNameRef) -> Option<&Arc> { self.operations.get(name) @@ -445,7 +571,7 @@ impl Extension { ExtensionOp::new(op_def.clone(), args, ext_reg) } - // Validates against a registry, which we can assume includes this extension itself. + /// Validates against a registry, which we can assume includes this extension itself. // (TODO deal with the registry itself containing invalid extensions!) fn validate(&self, all_exts: &ExtensionRegistry) -> Result<(), SignatureError> { // We should validate TypeParams of TypeDefs too - https://github.com/CQCL/hugr/issues/624 @@ -634,20 +760,22 @@ pub mod test { impl Extension { /// Create a new extension for testing, with a 0 version. - pub(crate) fn new_test(name: ExtensionId) -> Self { - Self::new(name, Version::new(0, 0, 0)) + pub(crate) fn new_test_arc( + name: ExtensionId, + init: impl FnOnce(&mut Extension, &Weak), + ) -> Arc { + Self::new_arc(name, Version::new(0, 0, 0), init) } - /// Add a simple OpDef to the extension and return an extension op for it. - /// No description, no type parameters. - pub(crate) fn simple_ext_op( - &mut self, - name: &str, - signature: impl Into, - ) -> ExtensionOp { - self.add_op(name.into(), "".to_string(), signature).unwrap(); - self.instantiate_extension_op(name, [], &PRELUDE_REGISTRY) - .unwrap() + /// Create a new extension for testing, with a 0 version. + pub(crate) fn try_new_test_arc( + name: ExtensionId, + init: impl FnOnce( + &mut Extension, + &Weak, + ) -> Result<(), Box>, + ) -> Result, Box> { + Self::try_new_arc(name, Version::new(0, 0, 0), init) } } @@ -680,14 +808,14 @@ pub mod test { ); // register with update works - reg_ref.register_updated_ref(&ext1_1).unwrap(); - reg.register_updated(ext1_1.clone()).unwrap(); + reg_ref.register_updated_ref(&ext1_1); + reg.register_updated(ext1_1.clone()); assert_eq!(reg.get("ext1").unwrap().version(), &Version::new(1, 1, 0)); assert_eq!(®, ®_ref); // register with lower version does not change version - reg_ref.register_updated_ref(&ext1_2).unwrap(); - reg.register_updated(ext1_2.clone()).unwrap(); + reg_ref.register_updated_ref(&ext1_2); + reg.register_updated(ext1_2.clone()); assert_eq!(reg.get("ext1").unwrap().version(), &Version::new(1, 1, 0)); assert_eq!(®, ®_ref); diff --git a/hugr-core/src/extension/declarative.rs b/hugr-core/src/extension/declarative.rs index c81414c9f..1f6361b3e 100644 --- a/hugr-core/src/extension/declarative.rs +++ b/hugr-core/src/extension/declarative.rs @@ -29,6 +29,7 @@ mod types; use std::fs::File; use std::path::Path; +use std::sync::Arc; use crate::extension::prelude::PRELUDE_ID; use crate::ops::OpName; @@ -150,19 +151,24 @@ impl ExtensionDeclaration { &self, imports: &ExtensionSet, ctx: DeclarationContext<'_>, - ) -> Result { - let mut ext = Extension::new(self.name.clone(), crate::extension::Version::new(0, 0, 0)) - .with_reqs(imports.clone()); - - for t in &self.types { - t.register(&mut ext, ctx)?; - } - - for o in &self.operations { - o.register(&mut ext, ctx)?; - } - - Ok(ext) + ) -> Result, ExtensionDeclarationError> { + Extension::try_new_arc( + self.name.clone(), + // TODO: Get the version as a parameter. + crate::extension::Version::new(0, 0, 0), + |ext, extension_ref| { + for t in &self.types { + t.register(ext, ctx, extension_ref)?; + } + + for o in &self.operations { + o.register(ext, ctx, extension_ref)?; + } + ext.add_requirements(imports.clone()); + + Ok(()) + }, + ) } } @@ -348,12 +354,9 @@ extensions: let new_exts = new_extensions(®, dependencies).collect_vec(); assert_eq!(new_exts.len(), num_declarations); + assert_eq!(new_exts.iter().flat_map(|e| e.types()).count(), num_types); assert_eq!( - new_exts.iter().flat_map(|(_, e)| e.types()).count(), - num_types - ); - assert_eq!( - new_exts.iter().flat_map(|(_, e)| e.operations()).count(), + new_exts.iter().flat_map(|e| e.operations()).count(), num_operations ); Ok(()) @@ -375,12 +378,9 @@ extensions: let new_exts = new_extensions(®, dependencies).collect_vec(); assert_eq!(new_exts.len(), num_declarations); + assert_eq!(new_exts.iter().flat_map(|e| e.types()).count(), num_types); assert_eq!( - new_exts.iter().flat_map(|(_, e)| e.types()).count(), - num_types - ); - assert_eq!( - new_exts.iter().flat_map(|(_, e)| e.operations()).count(), + new_exts.iter().flat_map(|e| e.operations()).count(), num_operations ); Ok(()) @@ -407,8 +407,8 @@ extensions: fn new_extensions<'a>( reg: &'a ExtensionRegistry, dependencies: &'a ExtensionRegistry, - ) -> impl Iterator)> { + ) -> impl Iterator> { reg.iter() - .filter(move |(id, _)| !dependencies.contains(id) && *id != &PRELUDE_ID) + .filter(move |ext| !dependencies.contains(ext.name()) && ext.name() != &PRELUDE_ID) } } diff --git a/hugr-core/src/extension/declarative/ops.rs b/hugr-core/src/extension/declarative/ops.rs index 8bd769e10..39e688a6b 100644 --- a/hugr-core/src/extension/declarative/ops.rs +++ b/hugr-core/src/extension/declarative/ops.rs @@ -8,6 +8,7 @@ //! [`ExtensionSetDeclaration`]: super::ExtensionSetDeclaration use std::collections::HashMap; +use std::sync::Weak; use serde::{Deserialize, Serialize}; use smol_str::SmolStr; @@ -55,10 +56,14 @@ pub(super) struct OperationDeclaration { impl OperationDeclaration { /// Register this operation in the given extension. + /// + /// Requires a [`Weak`] reference to the extension defining the operation. + /// This method is intended to be used inside the closure passed to [`Extension::new_arc`]. pub fn register<'ext>( &self, ext: &'ext mut Extension, ctx: DeclarationContext<'_>, + extension_ref: &Weak, ) -> Result<&'ext mut OpDef, ExtensionDeclarationError> { // We currently only support explicit signatures. // @@ -88,7 +93,12 @@ impl OperationDeclaration { let signature_func: SignatureFunc = signature.make_signature(ext, ctx, ¶ms)?; - let op_def = ext.add_op(self.name.clone(), self.description.clone(), signature_func)?; + let op_def = ext.add_op( + self.name.clone(), + self.description.clone(), + signature_func, + extension_ref, + )?; for (k, v) in &self.misc { op_def.add_misc(k, v.clone()); diff --git a/hugr-core/src/extension/declarative/types.rs b/hugr-core/src/extension/declarative/types.rs index 10b6e41a0..e426c69f2 100644 --- a/hugr-core/src/extension/declarative/types.rs +++ b/hugr-core/src/extension/declarative/types.rs @@ -7,6 +7,8 @@ //! [specification]: https://github.com/CQCL/hugr/blob/main/specification/hugr.md#declarative-format //! [`ExtensionSetDeclaration`]: super::ExtensionSetDeclaration +use std::sync::Weak; + use crate::extension::{TypeDef, TypeDefBound}; use crate::types::type_param::TypeParam; use crate::types::{TypeBound, TypeName}; @@ -49,10 +51,14 @@ impl TypeDeclaration { /// /// Types in the definition will be resolved using the extensions in `scope` /// and the current extension. + /// + /// Requires a [`Weak`] reference to the extension defining the operation. + /// This method is intended to be used inside the closure passed to [`Extension::new_arc`]. pub fn register<'ext>( &self, ext: &'ext mut Extension, ctx: DeclarationContext<'_>, + extension_ref: &Weak, ) -> Result<&'ext TypeDef, ExtensionDeclarationError> { let params = self .params @@ -64,6 +70,7 @@ impl TypeDeclaration { params, self.description.clone(), self.bound.into(), + extension_ref, )?; Ok(type_def) } diff --git a/hugr-core/src/extension/op_def.rs b/hugr-core/src/extension/op_def.rs index 5e74b9e9c..d9a3900fa 100644 --- a/hugr-core/src/extension/op_def.rs +++ b/hugr-core/src/extension/op_def.rs @@ -2,7 +2,7 @@ use std::cmp::min; use std::collections::btree_map::Entry; use std::collections::HashMap; use std::fmt::{Debug, Formatter}; -use std::sync::Arc; +use std::sync::{Arc, Weak}; use super::{ ConstFold, ConstFoldResult, Extension, ExtensionBuildError, ExtensionId, ExtensionRegistry, @@ -302,6 +302,9 @@ impl Debug for LowerFunc { pub struct OpDef { /// The unique Extension owning this OpDef (of which this OpDef is a member) extension: ExtensionId, + /// A weak reference to the extension defining this operation. + #[serde(skip)] + extension_ref: Weak, /// Unique identifier of the operation. Used to look up OpDefs in the registry /// when deserializing nodes (which store only the name). name: OpName, @@ -394,11 +397,16 @@ impl OpDef { &self.name } - /// Returns a reference to the extension of this [`OpDef`]. - pub fn extension(&self) -> &ExtensionId { + /// Returns a reference to the extension id of this [`OpDef`]. + pub fn extension_id(&self) -> &ExtensionId { &self.extension } + /// Returns a weak reference to the extension defining this operation. + pub fn extension(&self) -> Weak { + self.extension_ref.clone() + } + /// Returns a reference to the description of this [`OpDef`]. pub fn description(&self) -> &str { self.description.as_ref() @@ -467,15 +475,41 @@ impl Extension { /// Add an operation definition to the extension. Must be a type scheme /// (defined by a [`PolyFuncTypeRV`]), a type scheme along with binary /// validation for type arguments ([`CustomValidator`]), or a custom binary - /// function for computing the signature given type arguments (`impl [CustomSignatureFunc]`). + /// function for computing the signature given type arguments (implementing + /// `[CustomSignatureFunc]`). + /// + /// This method requires a [`Weak`] reference to the [`Arc`] containing the + /// extension being defined. The intended way to call this method is inside + /// the closure passed to [`Extension::new_arc`] when defining the extension. + /// + /// # Example + /// + /// ``` + /// # use hugr_core::types::Signature; + /// # use hugr_core::extension::{Extension, ExtensionId, Version}; + /// Extension::new_arc( + /// ExtensionId::new_unchecked("my.extension"), + /// Version::new(0, 1, 0), + /// |ext, extension_ref| { + /// ext.add_op( + /// "MyOp".into(), + /// "Some operation".into(), + /// Signature::new_endo(vec![]), + /// extension_ref, + /// ); + /// }, + /// ); + /// ``` pub fn add_op( &mut self, name: OpName, description: String, signature_func: impl Into, + extension_ref: &Weak, ) -> Result<&mut OpDef, ExtensionBuildError> { let op = OpDef { extension: self.name.clone(), + extension_ref: extension_ref.clone(), name, description, signature_func: signature_func.into(), @@ -501,7 +535,7 @@ pub(super) mod test { use super::SignatureFromArgs; use crate::builder::{endo_sig, DFGBuilder, Dataflow, DataflowHugr}; use crate::extension::op_def::{CustomValidator, LowerFunc, OpDef, SignatureFunc}; - use crate::extension::prelude::USIZE_T; + use crate::extension::prelude::usize_t; use crate::extension::{ExtensionRegistry, ExtensionSet, PRELUDE}; use crate::extension::{SignatureError, EMPTY_REG, PRELUDE_REGISTRY}; use crate::ops::OpName; @@ -544,6 +578,7 @@ pub(super) mod test { fn eq(&self, other: &Self) -> bool { let OpDef { extension, + extension_ref: _, name, description, misc, @@ -553,6 +588,7 @@ pub(super) mod test { } = &self.0; let OpDef { extension: other_extension, + extension_ref: _, name: other_name, description: other_description, misc: other_misc, @@ -601,32 +637,35 @@ pub(super) mod test { #[test] fn op_def_with_type_scheme() -> Result<(), Box> { let list_def = EXTENSION.get_type(&LIST_TYPENAME).unwrap(); - let mut e = Extension::new_test(EXT_ID); - const TP: TypeParam = TypeParam::Type { b: TypeBound::Any }; - let list_of_var = - Type::new_extension(list_def.instantiate(vec![TypeArg::new_var_use(0, TP)])?); const OP_NAME: OpName = OpName::new_inline("Reverse"); - let type_scheme = PolyFuncTypeRV::new(vec![TP], Signature::new_endo(vec![list_of_var])); - - let def = e.add_op(OP_NAME, "desc".into(), type_scheme)?; - def.add_lower_func(LowerFunc::FixedHugr { - extensions: ExtensionSet::new(), - hugr: crate::builder::test::simple_dfg_hugr(), // this is nonsense, but we are not testing the actual lowering here - }); - def.add_misc("key", Default::default()); - assert_eq!(def.description(), "desc"); - assert_eq!(def.lower_funcs.len(), 1); - assert_eq!(def.misc.len(), 1); - - let reg = - ExtensionRegistry::try_new([PRELUDE.clone(), EXTENSION.clone(), e.into()]).unwrap(); + + let ext = Extension::try_new_test_arc(EXT_ID, |ext, extension_ref| { + const TP: TypeParam = TypeParam::Type { b: TypeBound::Any }; + let list_of_var = + Type::new_extension(list_def.instantiate(vec![TypeArg::new_var_use(0, TP)])?); + let type_scheme = PolyFuncTypeRV::new(vec![TP], Signature::new_endo(vec![list_of_var])); + + let def = ext.add_op(OP_NAME, "desc".into(), type_scheme, extension_ref)?; + def.add_lower_func(LowerFunc::FixedHugr { + extensions: ExtensionSet::new(), + hugr: crate::builder::test::simple_dfg_hugr(), // this is nonsense, but we are not testing the actual lowering here + }); + def.add_misc("key", Default::default()); + assert_eq!(def.description(), "desc"); + assert_eq!(def.lower_funcs.len(), 1); + assert_eq!(def.misc.len(), 1); + + Ok(()) + })?; + + let reg = ExtensionRegistry::try_new([PRELUDE.clone(), EXTENSION.clone(), ext]).unwrap(); let e = reg.get(&EXT_ID).unwrap(); let list_usize = - Type::new_extension(list_def.instantiate(vec![TypeArg::Type { ty: USIZE_T }])?); + Type::new_extension(list_def.instantiate(vec![TypeArg::Type { ty: usize_t() }])?); let mut dfg = DFGBuilder::new(endo_sig(vec![list_usize]))?; let rev = dfg.add_dataflow_op( - e.instantiate_extension_op(&OP_NAME, vec![TypeArg::Type { ty: USIZE_T }], ®) + e.instantiate_extension_op(&OP_NAME, vec![TypeArg::Type { ty: usize_t() }], ®) .unwrap(), dfg.input_wires(), )?; @@ -666,60 +705,64 @@ pub(super) mod test { MAX_NAT } } - let mut e = Extension::new_test(EXT_ID); - let def: &mut crate::extension::OpDef = - e.add_op("MyOp".into(), "".to_string(), SigFun())?; - - // Base case, no type variables: - let args = [TypeArg::BoundedNat { n: 3 }, USIZE_T.into()]; - assert_eq!( - def.compute_signature(&args, &PRELUDE_REGISTRY), - Ok( - Signature::new(vec![USIZE_T; 3], vec![Type::new_tuple(vec![USIZE_T; 3])]) - .with_extension_delta(EXT_ID) - ) - ); - assert_eq!(def.validate_args(&args, &PRELUDE_REGISTRY, &[]), Ok(())); - - // Second arg may be a variable (substitutable) - let tyvar = Type::new_var_use(0, TypeBound::Copyable); - let tyvars: Vec = vec![tyvar.clone(); 3]; - let args = [TypeArg::BoundedNat { n: 3 }, tyvar.clone().into()]; - assert_eq!( - def.compute_signature(&args, &PRELUDE_REGISTRY), - Ok( - Signature::new(tyvars.clone(), vec![Type::new_tuple(tyvars)]) - .with_extension_delta(EXT_ID) - ) - ); - def.validate_args(&args, &PRELUDE_REGISTRY, &[TypeBound::Copyable.into()]) - .unwrap(); - - // quick sanity check that we are validating the args - note changed bound: - assert_eq!( - def.validate_args(&args, &PRELUDE_REGISTRY, &[TypeBound::Any.into()]), - Err(SignatureError::TypeVarDoesNotMatchDeclaration { - actual: TypeBound::Any.into(), - cached: TypeBound::Copyable.into() - }) - ); - - // First arg must be concrete, not a variable - let kind = TypeParam::bounded_nat(NonZeroU64::new(5).unwrap()); - let args = [TypeArg::new_var_use(0, kind.clone()), USIZE_T.into()]; - // We can't prevent this from getting into our compute_signature implementation: - assert_eq!( - def.compute_signature(&args, &PRELUDE_REGISTRY), - Err(SignatureError::InvalidTypeArgs) - ); - // But validation rules it out, even when the variable is declared: - assert_eq!( - def.validate_args(&args, &PRELUDE_REGISTRY, &[kind]), - Err(SignatureError::FreeTypeVar { - idx: 0, - num_decls: 0 - }) - ); + let _ext = Extension::try_new_test_arc(EXT_ID, |ext, extension_ref| { + let def: &mut crate::extension::OpDef = + ext.add_op("MyOp".into(), "".to_string(), SigFun(), extension_ref)?; + + // Base case, no type variables: + let args = [TypeArg::BoundedNat { n: 3 }, usize_t().into()]; + assert_eq!( + def.compute_signature(&args, &PRELUDE_REGISTRY), + Ok(Signature::new( + vec![usize_t(); 3], + vec![Type::new_tuple(vec![usize_t(); 3])] + ) + .with_extension_delta(EXT_ID)) + ); + assert_eq!(def.validate_args(&args, &PRELUDE_REGISTRY, &[]), Ok(())); + + // Second arg may be a variable (substitutable) + let tyvar = Type::new_var_use(0, TypeBound::Copyable); + let tyvars: Vec = vec![tyvar.clone(); 3]; + let args = [TypeArg::BoundedNat { n: 3 }, tyvar.clone().into()]; + assert_eq!( + def.compute_signature(&args, &PRELUDE_REGISTRY), + Ok( + Signature::new(tyvars.clone(), vec![Type::new_tuple(tyvars)]) + .with_extension_delta(EXT_ID) + ) + ); + def.validate_args(&args, &PRELUDE_REGISTRY, &[TypeBound::Copyable.into()]) + .unwrap(); + + // quick sanity check that we are validating the args - note changed bound: + assert_eq!( + def.validate_args(&args, &PRELUDE_REGISTRY, &[TypeBound::Any.into()]), + Err(SignatureError::TypeVarDoesNotMatchDeclaration { + actual: TypeBound::Any.into(), + cached: TypeBound::Copyable.into() + }) + ); + + // First arg must be concrete, not a variable + let kind = TypeParam::bounded_nat(NonZeroU64::new(5).unwrap()); + let args = [TypeArg::new_var_use(0, kind.clone()), usize_t().into()]; + // We can't prevent this from getting into our compute_signature implementation: + assert_eq!( + def.compute_signature(&args, &PRELUDE_REGISTRY), + Err(SignatureError::InvalidTypeArgs) + ); + // But validation rules it out, even when the variable is declared: + assert_eq!( + def.validate_args(&args, &PRELUDE_REGISTRY, &[kind]), + Err(SignatureError::FreeTypeVar { + idx: 0, + num_decls: 0 + }) + ); + + Ok(()) + })?; Ok(()) } @@ -728,68 +771,77 @@ pub(super) mod test { fn type_scheme_instantiate_var() -> Result<(), Box> { // Check that we can instantiate a PolyFuncTypeRV-scheme with an (external) // type variable - let mut e = Extension::new_test(EXT_ID); - let def = e.add_op( - "SimpleOp".into(), - "".into(), - PolyFuncTypeRV::new( - vec![TypeBound::Any.into()], - Signature::new_endo(vec![Type::new_var_use(0, TypeBound::Any)]), - ), - )?; - let tv = Type::new_var_use(1, TypeBound::Copyable); - let args = [TypeArg::Type { ty: tv.clone() }]; - let decls = [TypeParam::Extensions, TypeBound::Copyable.into()]; - def.validate_args(&args, &EMPTY_REG, &decls).unwrap(); - assert_eq!( - def.compute_signature(&args, &EMPTY_REG), - Ok(Signature::new_endo(tv).with_extension_delta(EXT_ID)) - ); - // But not with an external row variable - let arg: TypeArg = TypeRV::new_row_var_use(0, TypeBound::Copyable).into(); - assert_eq!( - def.compute_signature(&[arg.clone()], &EMPTY_REG), - Err(SignatureError::TypeArgMismatch( - TypeArgError::TypeMismatch { - param: TypeBound::Any.into(), - arg - } - )) - ); + let _ext = Extension::try_new_test_arc(EXT_ID, |ext, extension_ref| { + let def = ext.add_op( + "SimpleOp".into(), + "".into(), + PolyFuncTypeRV::new( + vec![TypeBound::Any.into()], + Signature::new_endo(vec![Type::new_var_use(0, TypeBound::Any)]), + ), + extension_ref, + )?; + let tv = Type::new_var_use(1, TypeBound::Copyable); + let args = [TypeArg::Type { ty: tv.clone() }]; + let decls = [TypeParam::Extensions, TypeBound::Copyable.into()]; + def.validate_args(&args, &EMPTY_REG, &decls).unwrap(); + assert_eq!( + def.compute_signature(&args, &EMPTY_REG), + Ok(Signature::new_endo(tv).with_extension_delta(EXT_ID)) + ); + // But not with an external row variable + let arg: TypeArg = TypeRV::new_row_var_use(0, TypeBound::Copyable).into(); + assert_eq!( + def.compute_signature(&[arg.clone()], &EMPTY_REG), + Err(SignatureError::TypeArgMismatch( + TypeArgError::TypeMismatch { + param: TypeBound::Any.into(), + arg + } + )) + ); + Ok(()) + })?; Ok(()) } #[test] fn instantiate_extension_delta() -> Result<(), Box> { - use crate::extension::prelude::{BOOL_T, PRELUDE_REGISTRY}; + use crate::extension::prelude::{bool_t, PRELUDE_REGISTRY}; + + let _ext = Extension::try_new_test_arc(EXT_ID, |ext, extension_ref| { + let params: Vec = vec![TypeParam::Extensions]; + let db_set = ExtensionSet::type_var(0); + let fun_ty = Signature::new_endo(bool_t()).with_extension_delta(db_set); + + let def = ext.add_op( + "SimpleOp".into(), + "".into(), + PolyFuncTypeRV::new(params.clone(), fun_ty), + extension_ref, + )?; + + // Concrete extension set + let es = ExtensionSet::singleton(&EXT_ID); + let exp_fun_ty = Signature::new_endo(bool_t()).with_extension_delta(es.clone()); + let args = [TypeArg::Extensions { es }]; + + def.validate_args(&args, &PRELUDE_REGISTRY, ¶ms) + .unwrap(); + assert_eq!( + def.compute_signature(&args, &PRELUDE_REGISTRY), + Ok(exp_fun_ty) + ); + + Ok(()) + })?; - let mut e = Extension::new_test(EXT_ID); - - let params: Vec = vec![TypeParam::Extensions]; - let db_set = ExtensionSet::type_var(0); - let fun_ty = Signature::new_endo(BOOL_T).with_extension_delta(db_set); - - let def = e.add_op( - "SimpleOp".into(), - "".into(), - PolyFuncTypeRV::new(params.clone(), fun_ty), - )?; - - // Concrete extension set - let es = ExtensionSet::singleton(&EXT_ID); - let exp_fun_ty = Signature::new_endo(BOOL_T).with_extension_delta(es.clone()); - let args = [TypeArg::Extensions { es }]; - - def.validate_args(&args, &PRELUDE_REGISTRY, ¶ms) - .unwrap(); - assert_eq!( - def.compute_signature(&args, &PRELUDE_REGISTRY), - Ok(exp_fun_ty) - ); Ok(()) } mod proptest { + use std::sync::Weak; + use super::SimpleOpDef; use ::proptest::prelude::*; @@ -846,6 +898,8 @@ pub(super) mod test { |(extension, name, description, misc, signature_func, lower_funcs)| { Self::new(OpDef { extension, + // Use a dead weak reference. Trying to access the extension will always return None. + extension_ref: Weak::default(), name, description, misc, diff --git a/hugr-core/src/extension/op_def/serialize_signature_func.rs b/hugr-core/src/extension/op_def/serialize_signature_func.rs index 88c8c30de..6c189cc84 100644 --- a/hugr-core/src/extension/op_def/serialize_signature_func.rs +++ b/hugr-core/src/extension/op_def/serialize_signature_func.rs @@ -57,7 +57,7 @@ mod test { use super::*; use crate::{ extension::{ - prelude::USIZE_T, CustomSignatureFunc, CustomValidator, ExtensionRegistry, OpDef, + prelude::usize_t, CustomSignatureFunc, CustomValidator, ExtensionRegistry, OpDef, SignatureError, ValidateTypeArgs, }, types::{FuncValueType, Signature, TypeArg}, @@ -121,7 +121,7 @@ mod test { #[test] fn test_serial_sig_func() { // test round-trip - let sig: FuncValueType = Signature::new_endo(USIZE_T.clone()).into(); + let sig: FuncValueType = Signature::new_endo(usize_t().clone()).into(); let simple: SignatureFunc = sig.clone().into(); let ser: SerSignatureFunc = simple.into(); let expected_ser = SerSignatureFunc { diff --git a/hugr-core/src/extension/prelude.rs b/hugr-core/src/extension/prelude.rs index e7056bb5d..786b0379e 100644 --- a/hugr-core/src/extension/prelude.rs +++ b/hugr-core/src/extension/prelude.rs @@ -1,6 +1,6 @@ //! Prelude extension - available in all contexts, defining common types, //! operations and constants. -use std::sync::Arc; +use std::sync::{Arc, Weak}; use itertools::Itertools; use lazy_static::lazy_static; @@ -40,102 +40,130 @@ pub const PRELUDE_ID: ExtensionId = ExtensionId::new_unchecked("prelude"); /// Extension version. pub const VERSION: semver::Version = semver::Version::new(0, 1, 0); lazy_static! { - static ref PRELUDE_DEF: Arc = { - let mut prelude = Extension::new(PRELUDE_ID, VERSION); - prelude - .add_type( - TypeName::new_inline("usize"), - vec![], - "usize".into(), - TypeDefBound::copyable(), - ) - .unwrap(); - prelude.add_type( - STRING_TYPE_NAME, - vec![], - "string".into(), - TypeDefBound::copyable(), - ) - .unwrap(); - prelude.add_op( - PRINT_OP_ID, - "Print the string to standard output".to_string(), - Signature::new(type_row![STRING_TYPE], type_row![]), - ) - .unwrap(); - prelude.add_type( - TypeName::new_inline(ARRAY_TYPE_NAME), - vec![ TypeParam::max_nat(), TypeBound::Any.into()], - "array".into(), - TypeDefBound::from_params(vec![1] ), - ) - .unwrap(); - - prelude - .add_type( - TypeName::new_inline("qubit"), - vec![], - "qubit".into(), - TypeDefBound::any(), - ) - .unwrap(); - prelude - .add_type( - ERROR_TYPE_NAME, - vec![], - "Simple opaque error type.".into(), - TypeDefBound::copyable(), - ) - .unwrap(); - - - prelude - .add_op( - PANIC_OP_ID, - "Panic with input error".to_string(), - PolyFuncTypeRV::new( - [TypeParam::new_list(TypeBound::Any), TypeParam::new_list(TypeBound::Any)], - FuncValueType::new( - vec![TypeRV::new_extension(ERROR_CUSTOM_TYPE), TypeRV::new_row_var_use(0, TypeBound::Any)], - vec![TypeRV::new_row_var_use(1, TypeBound::Any)], - ), - ), - ) - .unwrap(); - - TupleOpDef::load_all_ops(&mut prelude).unwrap(); - NoopDef.add_to_extension(&mut prelude).unwrap(); - LiftDef.add_to_extension(&mut prelude).unwrap(); - array::ArrayOpDef::load_all_ops(&mut prelude).unwrap(); - array::ArrayRepeatDef.add_to_extension(&mut prelude).unwrap(); - array::ArrayScanDef.add_to_extension(&mut prelude).unwrap(); - - Arc::new(prelude) + /// Prelude extension, containing common types and operations. + pub static ref PRELUDE: Arc = { + Extension::new_arc(PRELUDE_ID, VERSION, |prelude, extension_ref| { + + // Construct the list and error types using the passed extension + // reference. + // + // If we tried to use `string_type()` or `error_type()` directly it + // would try to access the `PRELUDE` lazy static recursively, + // causing a deadlock. + let string_type: Type = string_custom_type(extension_ref).into(); + let error_type: CustomType = error_custom_type(extension_ref); + + prelude + .add_type( + TypeName::new_inline("usize"), + vec![], + "usize".into(), + TypeDefBound::copyable(), + extension_ref, + ) + .unwrap(); + prelude.add_type( + STRING_TYPE_NAME, + vec![], + "string".into(), + TypeDefBound::copyable(), + extension_ref, + ) + .unwrap(); + prelude.add_op( + PRINT_OP_ID, + "Print the string to standard output".to_string(), + Signature::new(vec![string_type], type_row![]), + extension_ref, + ) + .unwrap(); + prelude.add_type( + TypeName::new_inline(ARRAY_TYPE_NAME), + vec![ TypeParam::max_nat(), TypeBound::Any.into()], + "array".into(), + TypeDefBound::from_params(vec![1] ), + extension_ref, + ) + .unwrap(); + prelude + .add_type( + TypeName::new_inline("qubit"), + vec![], + "qubit".into(), + TypeDefBound::any(), + extension_ref, + ) + .unwrap(); + prelude + .add_type( + ERROR_TYPE_NAME, + vec![], + "Simple opaque error type.".into(), + TypeDefBound::copyable(), + extension_ref, + ) + .unwrap(); + prelude + .add_op( + PANIC_OP_ID, + "Panic with input error".to_string(), + PolyFuncTypeRV::new( + [TypeParam::new_list(TypeBound::Any), TypeParam::new_list(TypeBound::Any)], + FuncValueType::new( + vec![TypeRV::new_extension(error_type), TypeRV::new_row_var_use(0, TypeBound::Any)], + vec![TypeRV::new_row_var_use(1, TypeBound::Any)], + ), + ), + extension_ref, + ) + .unwrap(); + + TupleOpDef::load_all_ops(prelude, extension_ref).unwrap(); + NoopDef.add_to_extension(prelude, extension_ref).unwrap(); + LiftDef.add_to_extension(prelude, extension_ref).unwrap(); + array::ArrayOpDef::load_all_ops(prelude, extension_ref).unwrap(); + array::ArrayRepeatDef.add_to_extension(prelude, extension_ref).unwrap(); + array::ArrayScanDef.add_to_extension(prelude, extension_ref).unwrap(); + }) }; + /// An extension registry containing only the prelude pub static ref PRELUDE_REGISTRY: ExtensionRegistry = - ExtensionRegistry::try_new([PRELUDE_DEF.clone()]).unwrap(); - - /// Prelude extension - pub static ref PRELUDE: Arc = PRELUDE_REGISTRY.get(&PRELUDE_ID).unwrap().clone(); - + ExtensionRegistry::try_new([PRELUDE.clone()]).unwrap(); } -pub(crate) const USIZE_CUSTOM_T: CustomType = CustomType::new_simple( - TypeName::new_inline("usize"), - PRELUDE_ID, - TypeBound::Copyable, -); +pub(crate) fn usize_custom_t(extension_ref: &Weak) -> CustomType { + CustomType::new( + TypeName::new_inline("usize"), + vec![], + PRELUDE_ID, + TypeBound::Copyable, + extension_ref, + ) +} -pub(crate) const QB_CUSTOM_T: CustomType = - CustomType::new_simple(TypeName::new_inline("qubit"), PRELUDE_ID, TypeBound::Any); +pub(crate) fn qb_custom_t(extension_ref: &Weak) -> CustomType { + CustomType::new( + TypeName::new_inline("qubit"), + vec![], + PRELUDE_ID, + TypeBound::Any, + extension_ref, + ) +} /// Qubit type. -pub const QB_T: Type = Type::new_extension(QB_CUSTOM_T); +pub fn qb_t() -> Type { + qb_custom_t(&Arc::downgrade(&PRELUDE)).into() +} /// Unsigned size type. -pub const USIZE_T: Type = Type::new_extension(USIZE_CUSTOM_T); +pub fn usize_t() -> Type { + usize_custom_t(&Arc::downgrade(&PRELUDE)).into() +} /// Boolean type - Sum of two units. -pub const BOOL_T: Type = Type::new_unit_sum(2); +pub fn bool_t() -> Type { + Type::new_unit_sum(2) +} /// Name of the prelude panic operation. /// @@ -152,11 +180,23 @@ pub const PANIC_OP_ID: OpName = OpName::new_inline("panic"); pub const STRING_TYPE_NAME: TypeName = TypeName::new_inline("string"); /// Custom type for strings. -pub const STRING_CUSTOM_TYPE: CustomType = - CustomType::new_simple(STRING_TYPE_NAME, PRELUDE_ID, TypeBound::Copyable); +/// +/// Receives a reference to the prelude extensions as a parameter. +/// This avoids deadlocks when we are in the process of creating the prelude. +fn string_custom_type(extension_ref: &Weak) -> CustomType { + CustomType::new( + STRING_TYPE_NAME, + vec![], + PRELUDE_ID, + TypeBound::Copyable, + extension_ref, + ) +} /// String type. -pub const STRING_TYPE: Type = Type::new_extension(STRING_CUSTOM_TYPE); +pub fn string_type() -> Type { + string_custom_type(&Arc::downgrade(&PRELUDE)).into() +} #[derive(Debug, Clone, PartialEq, Hash, serde::Serialize, serde::Deserialize)] /// Structure for holding constant string values. @@ -189,7 +229,7 @@ impl CustomConst for ConstString { } fn get_type(&self) -> Type { - STRING_TYPE + string_type() } } @@ -197,17 +237,30 @@ impl CustomConst for ConstString { pub const PRINT_OP_ID: OpName = OpName::new_inline("print"); /// The custom type for Errors. -pub const ERROR_CUSTOM_TYPE: CustomType = - CustomType::new_simple(ERROR_TYPE_NAME, PRELUDE_ID, TypeBound::Copyable); +/// +/// Receives a reference to the prelude extensions as a parameter. +/// This avoids deadlocks when we are in the process of creating the prelude. +fn error_custom_type(extension_ref: &Weak) -> CustomType { + CustomType::new( + ERROR_TYPE_NAME, + vec![], + PRELUDE_ID, + TypeBound::Copyable, + extension_ref, + ) +} + /// Unspecified opaque error type. -pub const ERROR_TYPE: Type = Type::new_extension(ERROR_CUSTOM_TYPE); +pub fn error_type() -> Type { + error_custom_type(&Arc::downgrade(&PRELUDE)).into() +} /// The string name of the error type. pub const ERROR_TYPE_NAME: TypeName = TypeName::new_inline("error"); /// Return a Sum type with the second variant as the given type and the first an Error. pub fn sum_with_error(ty: impl Into) -> SumType { - either_type(ERROR_TYPE, ty) + either_type(error_type(), ty) } /// An optional type, i.e. a Sum type with the second variant as the given type and the first as an empty tuple. @@ -369,7 +422,7 @@ impl CustomConst for ConstUsize { } fn get_type(&self) -> Type { - USIZE_T + usize_t() } } @@ -414,7 +467,7 @@ impl CustomConst for ConstError { ExtensionSet::singleton(&PRELUDE_ID) } fn get_type(&self) -> Type { - ERROR_TYPE + error_type() } } @@ -504,7 +557,7 @@ impl ConstFold for TupleOpDef { } } impl MakeOpDef for TupleOpDef { - fn signature(&self) -> SignatureFunc { + fn init_signature(&self, _extension_ref: &Weak) -> SignatureFunc { let rv = TypeRV::new_row_var_use(0, TypeBound::Any); let tuple_type = TypeRV::new_tuple(vec![rv.clone()]); @@ -529,13 +582,17 @@ impl MakeOpDef for TupleOpDef { } fn from_def(op_def: &OpDef) -> Result { - try_from_name(op_def.name(), op_def.extension()) + try_from_name(op_def.name(), op_def.extension_id()) } fn extension(&self) -> ExtensionId { PRELUDE_ID.to_owned() } + fn extension_ref(&self) -> Weak { + Arc::downgrade(&PRELUDE) + } + fn post_opdef(&self, def: &mut OpDef) { def.set_constant_folder(*self); } @@ -686,7 +743,7 @@ impl std::str::FromStr for NoopDef { } } impl MakeOpDef for NoopDef { - fn signature(&self) -> SignatureFunc { + fn init_signature(&self, _extension_ref: &Weak) -> SignatureFunc { let tv = Type::new_var_use(0, TypeBound::Any); PolyFuncType::new([TypeBound::Any.into()], Signature::new_endo(tv)).into() } @@ -696,13 +753,17 @@ impl MakeOpDef for NoopDef { } fn from_def(op_def: &OpDef) -> Result { - try_from_name(op_def.name(), op_def.extension()) + try_from_name(op_def.name(), op_def.extension_id()) } fn extension(&self) -> ExtensionId { PRELUDE_ID.to_owned() } + fn extension_ref(&self) -> Weak { + Arc::downgrade(&PRELUDE) + } + fn post_opdef(&self, def: &mut OpDef) { def.set_constant_folder(*self); } @@ -792,7 +853,7 @@ impl std::str::FromStr for LiftDef { } impl MakeOpDef for LiftDef { - fn signature(&self) -> SignatureFunc { + fn init_signature(&self, _extension_ref: &Weak) -> SignatureFunc { PolyFuncTypeRV::new( vec![TypeParam::Extensions, TypeParam::new_list(TypeBound::Any)], FuncValueType::new_endo(TypeRV::new_row_var_use(1, TypeBound::Any)) @@ -806,12 +867,16 @@ impl MakeOpDef for LiftDef { } fn from_def(op_def: &OpDef) -> Result { - try_from_name(op_def.name(), op_def.extension()) + try_from_name(op_def.name(), op_def.extension_id()) } fn extension(&self) -> ExtensionId { PRELUDE_ID.to_owned() } + + fn extension_ref(&self) -> Weak { + Arc::downgrade(&PRELUDE) + } } /// A node which adds a extension req to the types of the wires it is passed @@ -895,7 +960,8 @@ impl MakeRegisteredOp for Lift { mod test { use crate::builder::inout_sig; use crate::std_extensions::arithmetic::float_ops::FLOAT_OPS_REGISTRY; - use crate::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE}; + use crate::std_extensions::arithmetic::float_types::{float64_type, ConstF64}; + use crate::utils::test_quantum_extension; use crate::{ builder::{endo_sig, DFGBuilder, Dataflow, DataflowHugr}, utils::test_quantum_extension::cx_gate, @@ -973,14 +1039,14 @@ mod test { /// Test building a HUGR involving a new_array operation. fn test_new_array() { let mut b = DFGBuilder::new(inout_sig( - vec![QB_T, QB_T], - array_type(TypeArg::BoundedNat { n: 2 }, QB_T), + vec![qb_t(), qb_t()], + array_type(TypeArg::BoundedNat { n: 2 }, qb_t()), )) .unwrap(); let [q1, q2] = b.input_wires_arr(); - let op = new_array_op(QB_T, 2); + let op = new_array_op(qb_t(), 2); let out = b.add_dataflow_op(op, [q1, q2]).unwrap(); @@ -989,9 +1055,9 @@ mod test { #[test] fn test_option() { - let typ: Type = option_type(BOOL_T).into(); + let typ: Type = option_type(bool_t()).into(); let const_val1 = const_some(Value::true_val()); - let const_val2 = const_none(BOOL_T); + let const_val2 = const_none(bool_t()); let mut b = DFGBuilder::new(inout_sig(type_row![], vec![typ.clone(), typ])).unwrap(); @@ -1003,9 +1069,9 @@ mod test { #[test] fn test_result() { - let typ: Type = either_type(BOOL_T, FLOAT64_TYPE).into(); - let const_bool = const_left(Value::true_val(), FLOAT64_TYPE); - let const_float = const_right(BOOL_T, ConstF64::new(0.5).into()); + let typ: Type = either_type(bool_t(), float64_type()).into(); + let const_bool = const_left(Value::true_val(), float64_type()); + let const_float = const_right(bool_t(), ConstF64::new(0.5).into()); let mut b = DFGBuilder::new(inout_sig(type_row![], vec![typ.clone(), typ])).unwrap(); @@ -1026,7 +1092,7 @@ mod test { .unwrap(); let ext_type = Type::new_extension(ext_def); - assert_eq!(ext_type, ERROR_TYPE); + assert_eq!(ext_type, error_type()); let error_val = ConstError::new(2, "my message"); @@ -1063,9 +1129,9 @@ mod test { /// test the panic operation with input and output wires fn test_panic_with_io() { let error_val = ConstError::new(42, "PANIC"); - const TYPE_ARG_Q: TypeArg = TypeArg::Type { ty: QB_T }; + let type_arg_q: TypeArg = TypeArg::Type { ty: qb_t() }; let type_arg_2q: TypeArg = TypeArg::Sequence { - elems: vec![TYPE_ARG_Q, TYPE_ARG_Q], + elems: vec![type_arg_q.clone(), type_arg_q], }; let panic_op = PRELUDE .instantiate_extension_op( @@ -1075,7 +1141,7 @@ mod test { ) .unwrap(); - let mut b = DFGBuilder::new(endo_sig(type_row![QB_T, QB_T])).unwrap(); + let mut b = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()])).unwrap(); let [q0, q1] = b.input_wires_arr(); let [q0, q1] = b .add_dataflow_op(cx_gate(), [q0, q1]) @@ -1086,7 +1152,8 @@ mod test { .add_dataflow_op(panic_op, [err, q0, q1]) .unwrap() .outputs_arr(); - b.finish_prelude_hugr_with_outputs([q0, q1]).unwrap(); + b.finish_hugr_with_outputs([q0, q1], &test_quantum_extension::REG) + .unwrap(); } #[test] @@ -1097,8 +1164,8 @@ mod test { .unwrap() .instantiate([]) .unwrap(); - let string_type: Type = Type::new_extension(string_custom_type); - assert_eq!(string_type, STRING_TYPE); + let string_ty: Type = Type::new_extension(string_custom_type); + assert_eq!(string_ty, string_type()); let string_const: ConstString = ConstString::new("Lorem ipsum".into()); assert_eq!(string_const.name(), "ConstString(\"Lorem ipsum\")"); assert!(string_const.validate().is_ok()); @@ -1135,7 +1202,7 @@ mod test { ); assert!(subject.equal_consts(&ConstExternalSymbol::new("foo", Type::UNIT, false))); assert!(!subject.equal_consts(&ConstExternalSymbol::new("bar", Type::UNIT, false))); - assert!(!subject.equal_consts(&ConstExternalSymbol::new("foo", STRING_TYPE, false))); + assert!(!subject.equal_consts(&ConstExternalSymbol::new("foo", string_type(), false))); assert!(!subject.equal_consts(&ConstExternalSymbol::new("foo", Type::UNIT, true))); assert!(ConstExternalSymbol::new("", Type::UNIT, true) diff --git a/hugr-core/src/extension/prelude/array.rs b/hugr-core/src/extension/prelude/array.rs index 025a31431..3b8512fef 100644 --- a/hugr-core/src/extension/prelude/array.rs +++ b/hugr-core/src/extension/prelude/array.rs @@ -1,13 +1,13 @@ use std::str::FromStr; +use std::sync::{Arc, Weak}; use itertools::Itertools; use strum_macros::EnumIter; use strum_macros::EnumString; use strum_macros::IntoStaticStr; -use crate::extension::prelude::either_type; use crate::extension::prelude::option_type; -use crate::extension::prelude::USIZE_T; +use crate::extension::prelude::{either_type, usize_custom_t}; use crate::extension::simple_op::{ HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError, }; @@ -113,7 +113,11 @@ impl ArrayOpDef { } /// To avoid recursion when defining the extension, take the type definition as an argument. - fn signature_from_def(&self, array_def: &TypeDef) -> SignatureFunc { + fn signature_from_def( + &self, + array_def: &TypeDef, + extension_ref: &Weak, + ) -> SignatureFunc { use ArrayOpDef::*; if let new_array | pop_left | pop_right = self { // implements SignatureFromArgs @@ -125,6 +129,12 @@ impl ArrayOpDef { let array_ty = instantiate(array_def, size_var.clone(), elem_ty_var.clone()); let standard_params = vec![TypeParam::max_nat(), TypeBound::Any.into()]; + // Construct the usize type using the passed extension reference. + // + // If we tried to use `usize_t()` directly it would try to access + // the `PRELUDE` lazy static recursively, causing a deadlock. + let usize_t: Type = usize_custom_t(extension_ref).into(); + match self { get => { let params = vec![TypeParam::max_nat(), TypeBound::Copyable.into()]; @@ -133,7 +143,7 @@ impl ArrayOpDef { let option_type: Type = option_type(copy_elem_ty).into(); PolyFuncTypeRV::new( params, - FuncValueType::new(vec![copy_array_ty, USIZE_T], option_type), + FuncValueType::new(vec![copy_array_ty, usize_t], option_type), ) } set => { @@ -142,7 +152,7 @@ impl ArrayOpDef { PolyFuncTypeRV::new( standard_params, FuncValueType::new( - vec![array_ty.clone(), USIZE_T, elem_ty_var], + vec![array_ty.clone(), usize_t, elem_ty_var], result_type, ), ) @@ -151,7 +161,7 @@ impl ArrayOpDef { let result_type: Type = either_type(array_ty.clone(), array_ty.clone()).into(); PolyFuncTypeRV::new( standard_params, - FuncValueType::new(vec![array_ty, USIZE_T, USIZE_T], result_type), + FuncValueType::new(vec![array_ty, usize_t.clone(), usize_t], result_type), ) } discard_empty => PolyFuncTypeRV::new( @@ -173,11 +183,15 @@ impl MakeOpDef for ArrayOpDef { where Self: Sized, { - crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension()) + crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id()) } - fn signature(&self) -> SignatureFunc { - self.signature_from_def(array_type_def()) + fn init_signature(&self, extension_ref: &Weak) -> SignatureFunc { + self.signature_from_def(array_type_def(), extension_ref) + } + + fn extension_ref(&self) -> Weak { + Arc::downgrade(&PRELUDE) } fn extension(&self) -> ExtensionId { @@ -205,9 +219,11 @@ impl MakeOpDef for ArrayOpDef { fn add_to_extension( &self, extension: &mut Extension, + extension_ref: &Weak, ) -> Result<(), crate::extension::ExtensionBuildError> { - let sig = self.signature_from_def(extension.get_type(ARRAY_TYPE_NAME).unwrap()); - let def = extension.add_op(self.name(), self.description(), sig)?; + let sig = + self.signature_from_def(extension.get_type(ARRAY_TYPE_NAME).unwrap(), extension_ref); + let def = extension.add_op(self.name(), self.description(), sig, extension_ref)?; self.post_opdef(def); @@ -368,13 +384,17 @@ impl MakeOpDef for ArrayRepeatDef { where Self: Sized, { - crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension()) + crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id()) } - fn signature(&self) -> SignatureFunc { + fn init_signature(&self, _extension_ref: &Weak) -> SignatureFunc { self.signature_from_def(array_type_def()) } + fn extension_ref(&self) -> Weak { + Arc::downgrade(&PRELUDE) + } + fn extension(&self) -> ExtensionId { PRELUDE_ID } @@ -393,9 +413,10 @@ impl MakeOpDef for ArrayRepeatDef { fn add_to_extension( &self, extension: &mut Extension, + extension_ref: &Weak, ) -> Result<(), crate::extension::ExtensionBuildError> { let sig = self.signature_from_def(extension.get_type(ARRAY_TYPE_NAME).unwrap()); - let def = extension.add_op(self.name(), self.description(), sig)?; + let def = extension.add_op(self.name(), self.description(), sig, extension_ref)?; self.post_opdef(def); @@ -504,7 +525,8 @@ impl FromStr for ArrayScanDef { } impl ArrayScanDef { - /// To avoid recursion when defining the extension, take the type definition as an argument. + /// To avoid recursion when defining the extension, take the type definition + /// and a reference to the extension as an argument. fn signature_from_def(&self, array_def: &TypeDef) -> SignatureFunc { // array, (T1, *A -> T2, *A), -> array, *A let params = vec![ @@ -546,13 +568,17 @@ impl MakeOpDef for ArrayScanDef { where Self: Sized, { - crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension()) + crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id()) } - fn signature(&self) -> SignatureFunc { + fn init_signature(&self, _extension_ref: &Weak) -> SignatureFunc { self.signature_from_def(array_type_def()) } + fn extension_ref(&self) -> Weak { + Arc::downgrade(&PRELUDE) + } + fn extension(&self) -> ExtensionId { PRELUDE_ID } @@ -573,9 +599,10 @@ impl MakeOpDef for ArrayScanDef { fn add_to_extension( &self, extension: &mut Extension, + extension_ref: &Weak, ) -> Result<(), crate::extension::ExtensionBuildError> { let sig = self.signature_from_def(extension.get_type(ARRAY_TYPE_NAME).unwrap()); - let def = extension.add_op(self.name(), self.description(), sig)?; + let def = extension.add_op(self.name(), self.description(), sig, extension_ref)?; self.post_opdef(def); @@ -692,9 +719,10 @@ impl HasConcrete for ArrayScanDef { mod tests { use strum::IntoEnumIterator; + use crate::extension::prelude::usize_t; use crate::{ builder::{inout_sig, DFGBuilder, Dataflow, DataflowHugr}, - extension::prelude::{BOOL_T, QB_T}, + extension::prelude::{bool_t, qb_t}, ops::{OpTrait, OpType}, types::Signature, }; @@ -704,7 +732,11 @@ mod tests { #[test] fn test_array_ops() { for def in ArrayOpDef::iter() { - let ty = if def == ArrayOpDef::get { BOOL_T } else { QB_T }; + let ty = if def == ArrayOpDef::get { + bool_t() + } else { + qb_t() + }; let size = if def == ArrayOpDef::discard_empty { 0 } else { @@ -721,14 +753,14 @@ mod tests { /// Test building a HUGR involving a new_array operation. fn test_new_array() { let mut b = DFGBuilder::new(inout_sig( - vec![QB_T, QB_T], - array_type(TypeArg::BoundedNat { n: 2 }, QB_T), + vec![qb_t(), qb_t()], + array_type(TypeArg::BoundedNat { n: 2 }, qb_t()), )) .unwrap(); let [q1, q2] = b.input_wires_arr(); - let op = new_array_op(QB_T, 2); + let op = new_array_op(qb_t(), 2); let out = b.add_dataflow_op(op, [q1, q2]).unwrap(); @@ -738,7 +770,7 @@ mod tests { #[test] fn test_get() { let size = 2; - let element_ty = BOOL_T; + let element_ty = bool_t(); let op = ArrayOpDef::get.to_concrete(element_ty.clone(), size); let optype: OpType = op.into(); @@ -748,7 +780,7 @@ mod tests { assert_eq!( sig.io(), ( - &vec![array_type(size, element_ty.clone()), USIZE_T].into(), + &vec![array_type(size, element_ty.clone()), usize_t()].into(), &vec![option_type(element_ty.clone()).into()].into() ) ); @@ -757,7 +789,7 @@ mod tests { #[test] fn test_set() { let size = 2; - let element_ty = BOOL_T; + let element_ty = bool_t(); let op = ArrayOpDef::set.to_concrete(element_ty.clone(), size); let optype: OpType = op.into(); @@ -768,7 +800,7 @@ mod tests { assert_eq!( sig.io(), ( - &vec![array_ty.clone(), USIZE_T, element_ty.clone()].into(), + &vec![array_ty.clone(), usize_t(), element_ty.clone()].into(), &vec![either_type(result_row.clone(), result_row).into()].into() ) ); @@ -777,7 +809,7 @@ mod tests { #[test] fn test_swap() { let size = 2; - let element_ty = BOOL_T; + let element_ty = bool_t(); let op = ArrayOpDef::swap.to_concrete(element_ty.clone(), size); let optype: OpType = op.into(); @@ -787,7 +819,7 @@ mod tests { assert_eq!( sig.io(), ( - &vec![array_ty.clone(), USIZE_T, USIZE_T].into(), + &vec![array_ty.clone(), usize_t(), usize_t()].into(), &vec![either_type(array_ty.clone(), array_ty).into()].into() ) ); @@ -796,7 +828,7 @@ mod tests { #[test] fn test_pops() { let size = 2; - let element_ty = BOOL_T; + let element_ty = bool_t(); for op in [ArrayOpDef::pop_left, ArrayOpDef::pop_right].iter() { let op = op.to_concrete(element_ty.clone(), size); @@ -821,7 +853,7 @@ mod tests { #[test] fn test_discard_empty() { let size = 0; - let element_ty = BOOL_T; + let element_ty = bool_t(); let op = ArrayOpDef::discard_empty.to_concrete(element_ty.clone(), size); let optype: OpType = op.into(); @@ -839,7 +871,7 @@ mod tests { #[test] fn test_repeat_def() { - let op = ArrayRepeat::new(QB_T, 2, ExtensionSet::singleton(&PRELUDE_ID)); + let op = ArrayRepeat::new(qb_t(), 2, ExtensionSet::singleton(&PRELUDE_ID)); let optype: OpType = op.clone().into(); let new_op: ArrayRepeat = optype.cast().unwrap(); assert_eq!(new_op, op); @@ -848,7 +880,7 @@ mod tests { #[test] fn test_repeat() { let size = 2; - let element_ty = QB_T; + let element_ty = qb_t(); let es = ExtensionSet::singleton(&PRELUDE_ID); let op = ArrayRepeat::new(element_ty.clone(), size, es.clone()); @@ -860,7 +892,7 @@ mod tests { sig.io(), ( &vec![Type::new_function( - Signature::new(vec![], vec![QB_T]).with_extension_delta(es) + Signature::new(vec![], vec![qb_t()]).with_extension_delta(es) )] .into(), &vec![array_type(size, element_ty.clone())].into(), @@ -871,9 +903,9 @@ mod tests { #[test] fn test_scan_def() { let op = ArrayScan::new( - BOOL_T, - QB_T, - vec![USIZE_T], + bool_t(), + qb_t(), + vec![usize_t()], 2, ExtensionSet::singleton(&PRELUDE_ID), ); @@ -885,8 +917,8 @@ mod tests { #[test] fn test_scan_map() { let size = 2; - let src_ty = QB_T; - let tgt_ty = BOOL_T; + let src_ty = qb_t(); + let tgt_ty = bool_t(); let es = ExtensionSet::singleton(&PRELUDE_ID); let op = ArrayScan::new(src_ty.clone(), tgt_ty.clone(), vec![], size, es.clone()); @@ -911,10 +943,10 @@ mod tests { #[test] fn test_scan_accs() { let size = 2; - let src_ty = QB_T; - let tgt_ty = BOOL_T; - let acc_ty1 = USIZE_T; - let acc_ty2 = QB_T; + let src_ty = qb_t(); + let tgt_ty = bool_t(); + let acc_ty1 = usize_t(); + let acc_ty2 = qb_t(); let es = ExtensionSet::singleton(&PRELUDE_ID); let op = ArrayScan::new( diff --git a/hugr-core/src/extension/prelude/unwrap_builder.rs b/hugr-core/src/extension/prelude/unwrap_builder.rs index 7d7f3f5b3..753533f1f 100644 --- a/hugr-core/src/extension/prelude/unwrap_builder.rs +++ b/hugr-core/src/extension/prelude/unwrap_builder.rs @@ -86,7 +86,7 @@ mod tests { use crate::{ builder::{DFGBuilder, DataflowHugr}, extension::{ - prelude::{option_type, BOOL_T}, + prelude::{bool_t, option_type}, PRELUDE_REGISTRY, }, types::Signature, @@ -94,14 +94,15 @@ mod tests { #[test] fn test_build_unwrap() { - let mut builder = - DFGBuilder::new(Signature::new(Type::from(option_type(BOOL_T)), BOOL_T).with_prelude()) - .unwrap(); + let mut builder = DFGBuilder::new( + Signature::new(Type::from(option_type(bool_t())), bool_t()).with_prelude(), + ) + .unwrap(); let [opt] = builder.input_wires_arr(); let [res] = builder - .build_unwrap_sum(&PRELUDE_REGISTRY, 1, option_type(BOOL_T), opt) + .build_unwrap_sum(&PRELUDE_REGISTRY, 1, option_type(bool_t()), opt) .unwrap(); builder.finish_prelude_hugr_with_outputs([res]).unwrap(); } diff --git a/hugr-core/src/extension/resolution.rs b/hugr-core/src/extension/resolution.rs new file mode 100644 index 000000000..93e1b97f9 --- /dev/null +++ b/hugr-core/src/extension/resolution.rs @@ -0,0 +1,103 @@ +//! Utilities for resolving operations and types present in a HUGR, and updating +//! the list of used extensions. See [`crate::Hugr::resolve_extension_defs`]. +//! +//! When listing "used extensions" we only care about _definitional_ extension +//! requirements, i.e., the operations and types that are required to define the +//! HUGR nodes and wire types. This is computed from the union of all extension +//! required across the HUGR. +//! +//! This is distinct from _runtime_ extension requirements, which are defined +//! more granularly in each function signature by the `required_extensions` +//! field. See the `extension_inference` feature and related modules for that. +//! +//! Note: These procedures are only temporary until `hugr-model` is stabilized. +//! Once that happens, hugrs will no longer be directly deserialized using serde +//! but instead will be created by the methods in `crate::import`. As these +//! (will) automatically resolve extensions as the operations are created, +//! we will no longer require this post-facto resolution step. + +mod ops; +mod types; + +pub(crate) use ops::update_op_extensions; +pub(crate) use types::update_op_types_extensions; + +use derive_more::{Display, Error, From}; + +use super::{Extension, ExtensionId, ExtensionRegistry}; +use crate::ops::custom::OpaqueOpError; +use crate::ops::{NamedOp, OpName, OpType}; +use crate::types::TypeName; +use crate::Node; + +/// Errors that can occur during extension resolution. +#[derive(Debug, Display, Clone, Error, From, PartialEq)] +#[non_exhaustive] +pub enum ExtensionResolutionError { + /// Could not resolve an opaque operation to an extension operation. + #[display("Error resolving opaque operation: {_0}")] + #[from] + OpaqueOpError(OpaqueOpError), + /// An operation requires an extension that is not in the given registry. + #[display( + "{op} ({node}) requires extension {missing_extension}, but it could not be found in the extension list used during resolution. The available extensions are: {}", + available_extensions.join(", ") + )] + MissingOpExtension { + /// The node that requires the extension. + node: Node, + /// The operation that requires the extension. + op: OpName, + /// The missing extension + missing_extension: ExtensionId, + /// A list of available extensions. + available_extensions: Vec, + }, + #[display( + "Type {ty} in {node} requires extension {missing_extension}, but it could not be found in the extension list used during resolution. The available extensions are: {}", + available_extensions.join(", ") + )] + /// A type references an extension that is not in the given registry. + MissingTypeExtension { + /// The node that requires the extension. + node: Node, + /// The type that requires the extension. + ty: TypeName, + /// The missing extension + missing_extension: ExtensionId, + /// A list of available extensions. + available_extensions: Vec, + }, +} + +impl ExtensionResolutionError { + /// Create a new error for missing operation extensions. + pub fn missing_op_extension( + node: Node, + op: &OpType, + missing_extension: &ExtensionId, + extensions: &ExtensionRegistry, + ) -> Self { + Self::MissingOpExtension { + node, + op: NamedOp::name(op), + missing_extension: missing_extension.clone(), + available_extensions: extensions.ids().cloned().collect(), + } + } + + /// Create a new error for missing type extensions. + pub fn missing_type_extension( + node: Node, + ty: &TypeName, + missing_extension: &ExtensionId, + extensions: &ExtensionRegistry, + ) -> Self { + Self::MissingTypeExtension { + node, + ty: ty.clone(), + missing_extension: missing_extension.clone(), + available_extensions: extensions.ids().cloned().collect(), + } + } +} diff --git a/hugr-core/src/extension/resolution/ops.rs b/hugr-core/src/extension/resolution/ops.rs new file mode 100644 index 000000000..7c8fbfc37 --- /dev/null +++ b/hugr-core/src/extension/resolution/ops.rs @@ -0,0 +1,92 @@ +//! Resolve `OpaqueOp`s into `ExtensionOp`s and return an operation's required extension. + +use std::sync::Arc; + +use super::{Extension, ExtensionRegistry, ExtensionResolutionError}; +use crate::ops::custom::OpaqueOpError; +use crate::ops::{DataflowOpTrait, ExtensionOp, NamedOp, OpType}; +use crate::Node; + +/// Compute the required extension for an operation. +/// +/// If the op is a [`OpType::OpaqueOp`], replace it with a resolved +/// [`OpType::ExtensionOp`] by looking searching for the operation in the +/// extension registries. +/// +/// If `op` was an opaque or extension operation, the result contains the +/// extension reference that should be added to the hugr's extension registry. +/// +/// # Errors +/// +/// If the serialized opaque resolves to a definition that conflicts with what +/// was serialized. Or if the operation is not found in the registry. +pub(crate) fn update_op_extensions<'e>( + node: Node, + op: &mut OpType, + extensions: &'e ExtensionRegistry, +) -> Result>, ExtensionResolutionError> { + let extension = operation_extension(node, op, extensions)?; + + let OpType::OpaqueOp(opaque) = op else { + return Ok(extension); + }; + + // Fail if the Extension is not in the registry, or if the Extension was + // found but did not have the expected operation. + let extension = extension.expect("OpaqueOp should have an extension"); + let Some(def) = extension.get_op(opaque.op_name()) else { + return Err(OpaqueOpError::OpNotFoundInExtension { + node, + op: opaque.name().clone(), + extension: extension.name().clone(), + available_ops: extension + .operations() + .map(|(name, _)| name.clone()) + .collect(), + } + .into()); + }; + + let ext_op = + ExtensionOp::new_with_cached(def.clone(), opaque.args().to_vec(), opaque, extensions) + .map_err(|e| OpaqueOpError::SignatureError { + node, + name: opaque.name().clone(), + cause: e, + })?; + + if opaque.signature() != ext_op.signature() { + return Err(OpaqueOpError::SignatureMismatch { + node, + extension: opaque.extension().clone(), + op: def.name().clone(), + computed: ext_op.signature().clone(), + stored: opaque.signature().clone(), + } + .into()); + }; + + // Replace the opaque operation with the resolved extension operation. + *op = ext_op.into(); + + Ok(Some(extension)) +} + +/// Returns the extension in the registry required by the operation. +/// +/// If the operation does not require an extension, returns `None`. +fn operation_extension<'e>( + node: Node, + op: &OpType, + extensions: &'e ExtensionRegistry, +) -> Result>, ExtensionResolutionError> { + let Some(ext) = op.extension_id() else { + return Ok(None); + }; + match extensions.get(ext) { + Some(e) => Ok(Some(e)), + None => Err(ExtensionResolutionError::missing_op_extension( + node, op, ext, extensions, + )), + } +} diff --git a/hugr-core/src/extension/resolution/types.rs b/hugr-core/src/extension/resolution/types.rs new file mode 100644 index 000000000..b249a4dca --- /dev/null +++ b/hugr-core/src/extension/resolution/types.rs @@ -0,0 +1,174 @@ +//! Resolve weak links inside `CustomType`s in an optype's signature. + +use std::sync::Arc; + +use super::{ExtensionRegistry, ExtensionResolutionError}; +use crate::ops::OpType; +use crate::types::type_row::TypeRowBase; +use crate::types::{MaybeRV, Signature, SumType, TypeBase, TypeEnum}; +use crate::Node; + +/// Replace the dangling extension pointer in the [`CustomType`]s inside a +/// signature with a valid pointer to the extension in the `extensions` +/// registry. +/// +/// When a pointer is replaced, the extension is added to the +/// `used_extensions` registry and the new type definition is returned. +/// +/// This is a helper function used right after deserializing a Hugr. +pub fn update_op_types_extensions( + node: Node, + op: &mut OpType, + extensions: &ExtensionRegistry, + used_extensions: &mut ExtensionRegistry, +) -> Result<(), ExtensionResolutionError> { + match op { + OpType::ExtensionOp(ext) => { + update_signature_exts(node, ext.signature_mut(), extensions, used_extensions)? + } + OpType::FuncDefn(f) => { + update_signature_exts(node, f.signature.body_mut(), extensions, used_extensions)? + } + OpType::FuncDecl(f) => { + update_signature_exts(node, f.signature.body_mut(), extensions, used_extensions)? + } + OpType::Const(_c) => { + // TODO: Is it OK to assume that `Value::get_type` returns a well-resolved value? + } + OpType::Input(inp) => { + update_type_row_exts(node, &mut inp.types, extensions, used_extensions)? + } + OpType::Output(out) => { + update_type_row_exts(node, &mut out.types, extensions, used_extensions)? + } + OpType::Call(c) => { + update_signature_exts(node, c.func_sig.body_mut(), extensions, used_extensions)?; + update_signature_exts(node, &mut c.instantiation, extensions, used_extensions)?; + } + OpType::CallIndirect(c) => { + update_signature_exts(node, &mut c.signature, extensions, used_extensions)? + } + OpType::LoadConstant(lc) => { + update_type_exts(node, &mut lc.datatype, extensions, used_extensions)? + } + OpType::LoadFunction(lf) => { + update_signature_exts(node, lf.func_sig.body_mut(), extensions, used_extensions)?; + update_signature_exts(node, &mut lf.signature, extensions, used_extensions)?; + } + OpType::DFG(dfg) => { + update_signature_exts(node, &mut dfg.signature, extensions, used_extensions)? + } + OpType::OpaqueOp(op) => { + update_signature_exts(node, op.signature_mut(), extensions, used_extensions)? + } + OpType::Tag(t) => { + for variant in t.variants.iter_mut() { + update_type_row_exts(node, variant, extensions, used_extensions)? + } + } + OpType::DataflowBlock(db) => { + update_type_row_exts(node, &mut db.inputs, extensions, used_extensions)?; + update_type_row_exts(node, &mut db.other_outputs, extensions, used_extensions)?; + for row in db.sum_rows.iter_mut() { + update_type_row_exts(node, row, extensions, used_extensions)?; + } + } + OpType::ExitBlock(e) => { + update_type_row_exts(node, &mut e.cfg_outputs, extensions, used_extensions)?; + } + OpType::TailLoop(tl) => { + update_type_row_exts(node, &mut tl.just_inputs, extensions, used_extensions)?; + update_type_row_exts(node, &mut tl.just_outputs, extensions, used_extensions)?; + update_type_row_exts(node, &mut tl.rest, extensions, used_extensions)?; + } + OpType::CFG(cfg) => { + update_signature_exts(node, &mut cfg.signature, extensions, used_extensions)?; + } + OpType::Conditional(cond) => { + for row in cond.sum_rows.iter_mut() { + update_type_row_exts(node, row, extensions, used_extensions)?; + } + update_type_row_exts(node, &mut cond.other_inputs, extensions, used_extensions)?; + update_type_row_exts(node, &mut cond.outputs, extensions, used_extensions)?; + } + OpType::Case(case) => { + update_signature_exts(node, &mut case.signature, extensions, used_extensions)?; + } + // Ignore optypes that do not store a signature. + _ => {} + } + Ok(()) +} + +/// Update all weak Extension pointers in the [`CustomType`]s inside a signature. +/// +/// Adds the extensions used in the signature to the `used_extensions` registry. +fn update_signature_exts( + node: Node, + signature: &mut Signature, + extensions: &ExtensionRegistry, + used_extensions: &mut ExtensionRegistry, +) -> Result<(), ExtensionResolutionError> { + // Note that we do not include the signature's `extension_reqs` here, as those refer + // to _runtime_ requirements that may not be currently present. + // See https://github.com/CQCL/hugr/issues/1734 + // TODO: Update comment once that issue gets implemented. + update_type_row_exts(node, &mut signature.input, extensions, used_extensions)?; + update_type_row_exts(node, &mut signature.output, extensions, used_extensions)?; + Ok(()) +} + +/// Update all weak Extension pointers in the [`CustomType`]s inside a type row. +/// +/// Adds the extensions used in the row to the `used_extensions` registry. +fn update_type_row_exts( + node: Node, + row: &mut TypeRowBase, + extensions: &ExtensionRegistry, + used_extensions: &mut ExtensionRegistry, +) -> Result<(), ExtensionResolutionError> { + for ty in row.iter_mut() { + update_type_exts(node, ty, extensions, used_extensions)?; + } + Ok(()) +} + +/// Update all weak Extension pointers in the [`CustomType`]s inside a type. +/// +/// Adds the extensions used in the type to the `used_extensions` registry. +fn update_type_exts( + node: Node, + typ: &mut TypeBase, + extensions: &ExtensionRegistry, + used_extensions: &mut ExtensionRegistry, +) -> Result<(), ExtensionResolutionError> { + match typ.as_type_enum_mut() { + TypeEnum::Extension(custom) => { + let ext_id = custom.extension(); + let ext = extensions.get(ext_id).ok_or_else(|| { + ExtensionResolutionError::missing_type_extension( + node, + custom.name(), + ext_id, + extensions, + ) + })?; + + // Add the extension to the used extensions registry, + // and update the CustomType with the valid pointer. + used_extensions.register_updated_ref(ext); + custom.update_extension(Arc::downgrade(ext)); + } + TypeEnum::Function(f) => { + update_type_row_exts(node, &mut f.input, extensions, used_extensions)?; + update_type_row_exts(node, &mut f.output, extensions, used_extensions)?; + } + TypeEnum::Sum(SumType::General { rows }) => { + for row in rows.iter_mut() { + update_type_row_exts(node, row, extensions, used_extensions)?; + } + } + _ => {} + } + Ok(()) +} diff --git a/hugr-core/src/extension/simple_op.rs b/hugr-core/src/extension/simple_op.rs index c338a693d..d48f596ea 100644 --- a/hugr-core/src/extension/simple_op.rs +++ b/hugr-core/src/extension/simple_op.rs @@ -1,5 +1,7 @@ //! A trait that enum for op definitions that gathers up some shared functionality. +use std::sync::Weak; + use strum::IntoEnumIterator; use crate::ops::{ExtensionOp, OpName, OpNameRef}; @@ -51,12 +53,24 @@ pub trait MakeOpDef: NamedOp { where Self: Sized; - /// Return the signature (polymorphic function type) of the operation. - fn signature(&self) -> SignatureFunc; - /// The ID of the extension this operation is defined in. fn extension(&self) -> ExtensionId; + /// Returns a weak reference to the extension this operation is defined in. + fn extension_ref(&self) -> Weak; + + /// Compute the signature of the operation while the extension definition is being built. + /// + /// Requires a [`Weak`] reference to the extension defining the operation. + /// This method is intended to be used inside the closure passed to [`Extension::new_arc`], + /// and it is normally internally called by [`MakeOpDef::add_to_extension`]. + fn init_signature(&self, extension_ref: &Weak) -> SignatureFunc; + + /// Return the signature (polymorphic function type) of the operation. + fn signature(&self) -> SignatureFunc { + self.init_signature(&self.extension_ref()) + } + /// Description of the operation. By default, the same as `self.name()`. fn description(&self) -> String { self.name().to_string() @@ -67,8 +81,20 @@ pub trait MakeOpDef: NamedOp { /// Add an operation implemented as an [MakeOpDef], which can provide the data /// required to define an [OpDef], to an extension. - fn add_to_extension(&self, extension: &mut Extension) -> Result<(), ExtensionBuildError> { - let def = extension.add_op(self.name(), self.description(), self.signature())?; + /// + /// Requires a [`Weak`] reference to the extension defining the operation. + /// This method is intended to be used inside the closure passed to [`Extension::new_arc`]. + fn add_to_extension( + &self, + extension: &mut Extension, + extension_ref: &Weak, + ) -> Result<(), ExtensionBuildError> { + let def = extension.add_op( + self.name(), + self.description(), + self.init_signature(extension_ref), + extension_ref, + )?; self.post_opdef(def); @@ -77,12 +103,18 @@ pub trait MakeOpDef: NamedOp { /// Load all variants of an enum of op definitions in to an extension as op defs. /// See [strum::IntoEnumIterator]. - fn load_all_ops(extension: &mut Extension) -> Result<(), ExtensionBuildError> + /// + /// Requires a [`Weak`] reference to the extension defining the operation. + /// This method is intended to be used inside the closure passed to [`Extension::new_arc`]. + fn load_all_ops( + extension: &mut Extension, + extension_ref: &Weak, + ) -> Result<(), ExtensionBuildError> where Self: IntoEnumIterator, { for op in Self::iter() { - op.add_to_extension(extension)?; + op.add_to_extension(extension, extension_ref)?; } Ok(()) } @@ -286,10 +318,14 @@ mod test { } impl MakeOpDef for DummyEnum { - fn signature(&self) -> SignatureFunc { + fn init_signature(&self, _extension_ref: &Weak) -> SignatureFunc { Signature::new_endo(type_row![]).into() } + fn extension_ref(&self) -> Weak { + Arc::downgrade(&EXT) + } + fn from_def(_op_def: &OpDef) -> Result { Ok(Self::Dumb) } @@ -316,9 +352,11 @@ mod test { lazy_static! { static ref EXT: Arc = { - let mut e = Extension::new_test(EXT_ID.clone()); - DummyEnum::Dumb.add_to_extension(&mut e).unwrap(); - Arc::new(e) + Extension::new_test_arc(EXT_ID.clone(), |ext, extension_ref| { + DummyEnum::Dumb + .add_to_extension(ext, extension_ref) + .unwrap(); + }) }; static ref DUMMY_REG: ExtensionRegistry = ExtensionRegistry::try_new([EXT.clone()]).unwrap(); diff --git a/hugr-core/src/extension/type_def.rs b/hugr-core/src/extension/type_def.rs index 1affe68f0..b15a8d3c3 100644 --- a/hugr-core/src/extension/type_def.rs +++ b/hugr-core/src/extension/type_def.rs @@ -1,4 +1,5 @@ use std::collections::btree_map::Entry; +use std::sync::Weak; use super::{CustomConcrete, ExtensionBuildError}; use super::{Extension, ExtensionId, SignatureError}; @@ -56,6 +57,9 @@ impl TypeDefBound { pub struct TypeDef { /// The unique Extension owning this TypeDef (of which this TypeDef is a member) extension: ExtensionId, + /// A weak reference to the extension defining this operation. + #[serde(skip)] + extension_ref: Weak, /// The unique name of the type name: TypeName, /// Declaration of type parameters. The TypeDef must be instantiated @@ -82,9 +86,9 @@ impl TypeDef { /// This function will return an error if the type of the instance does not /// match the definition. pub fn check_custom(&self, custom: &CustomType) -> Result<(), SignatureError> { - if self.extension() != custom.parent_extension() { + if self.extension_id() != custom.parent_extension() { return Err(SignatureError::ExtensionMismatch( - self.extension().clone(), + self.extension_id().clone(), custom.parent_extension().clone(), )); } @@ -121,8 +125,9 @@ impl TypeDef { Ok(CustomType::new( self.name().clone(), args, - self.extension().clone(), + self.extension_id().clone(), bound, + &self.extension_ref, )) } /// The [`TypeBound`] of the definition. @@ -156,22 +161,55 @@ impl TypeDef { &self.name } - fn extension(&self) -> &ExtensionId { + /// Returns a reference to the extension id of this [`TypeDef`]. + pub fn extension_id(&self) -> &ExtensionId { &self.extension } + + /// Returns a weak reference to the extension defining this type. + pub fn extension(&self) -> Weak { + self.extension_ref.clone() + } } impl Extension { /// Add an exported type to the extension. + /// + /// This method requires a [`Weak`] reference to the [`std::sync::Arc`] containing the + /// extension being defined. The intended way to call this method is inside + /// the closure passed to [`Extension::new_arc`] when defining the extension. + /// + /// # Example + /// + /// ``` + /// # use hugr_core::types::Signature; + /// # use hugr_core::extension::{Extension, ExtensionId, Version}; + /// # use hugr_core::extension::{TypeDefBound}; + /// Extension::new_arc( + /// ExtensionId::new_unchecked("my.extension"), + /// Version::new(0, 1, 0), + /// |ext, extension_ref| { + /// ext.add_type( + /// "MyType".into(), + /// vec![], // No type parameters + /// "Some type".into(), + /// TypeDefBound::any(), + /// extension_ref, + /// ); + /// }, + /// ); + /// ``` pub fn add_type( &mut self, name: TypeName, params: Vec, description: String, bound: TypeDefBound, + extension_ref: &Weak, ) -> Result<&TypeDef, ExtensionBuildError> { let ty = TypeDef { extension: self.name.clone(), + extension_ref: extension_ref.clone(), name, params, description, @@ -186,9 +224,9 @@ impl Extension { #[cfg(test)] mod test { - use crate::extension::prelude::{QB_T, USIZE_T}; + use crate::extension::prelude::{qb_t, usize_t}; use crate::extension::SignatureError; - use crate::std_extensions::arithmetic::float_types::FLOAT64_TYPE; + use crate::std_extensions::arithmetic::float_types::float64_type; use crate::types::type_param::{TypeArg, TypeArgError, TypeParam}; use crate::types::{Signature, Type, TypeBound}; @@ -202,6 +240,8 @@ mod test { b: TypeBound::Copyable, }], extension: "MyRsrc".try_into().unwrap(), + // Dummy extension. Will return `None` when trying to upgrade it into an `Arc`. + extension_ref: Default::default(), description: "Some parametrised type".into(), bound: TypeDefBound::FromParams { indices: vec![0] }, }; @@ -212,15 +252,15 @@ mod test { .unwrap(), ); assert_eq!(typ.least_upper_bound(), TypeBound::Copyable); - let typ2 = Type::new_extension(def.instantiate([USIZE_T.into()]).unwrap()); + let typ2 = Type::new_extension(def.instantiate([usize_t().into()]).unwrap()); assert_eq!(typ2.least_upper_bound(), TypeBound::Copyable); // And some bad arguments...firstly, wrong kind of TypeArg: assert_eq!( - def.instantiate([TypeArg::Type { ty: QB_T }]), + def.instantiate([TypeArg::Type { ty: qb_t() }]), Err(SignatureError::TypeArgMismatch( TypeArgError::TypeMismatch { - arg: TypeArg::Type { ty: QB_T }, + arg: TypeArg::Type { ty: qb_t() }, param: TypeBound::Copyable.into() } )) @@ -233,8 +273,8 @@ mod test { // Too many arguments: assert_eq!( def.instantiate([ - TypeArg::Type { ty: FLOAT64_TYPE }, - TypeArg::Type { ty: FLOAT64_TYPE }, + TypeArg::Type { ty: float64_type() }, + TypeArg::Type { ty: float64_type() }, ]) .unwrap_err(), SignatureError::TypeArgMismatch(TypeArgError::WrongNumberArgs(2, 1)) diff --git a/hugr-core/src/hugr.rs b/hugr-core/src/hugr.rs index ed0fcca0a..b42622745 100644 --- a/hugr-core/src/hugr.rs +++ b/hugr-core/src/hugr.rs @@ -24,8 +24,10 @@ use thiserror::Error; pub use self::views::{HugrView, RootTagged}; use crate::core::NodeIndex; +use crate::extension::resolution::{ + update_op_extensions, update_op_types_extensions, ExtensionResolutionError, +}; use crate::extension::{ExtensionRegistry, ExtensionSet, TO_BE_INFERRED}; -use crate::ops::custom::resolve_extension_ops; use crate::ops::{OpTag, OpTrait}; pub use crate::ops::{OpType, DEFAULT_OPTYPE}; use crate::{Direction, Node}; @@ -87,7 +89,7 @@ impl Hugr { &mut self, extension_registry: &ExtensionRegistry, ) -> Result<(), ValidationError> { - resolve_extension_ops(self, extension_registry)?; + self.resolve_extension_defs(extension_registry)?; self.validate_no_extensions(extension_registry)?; #[cfg(feature = "extension_inference")] { @@ -167,6 +169,67 @@ impl Hugr { infer(self, self.root(), remove)?; Ok(()) } + + /// Given a Hugr that has been deserialized, collect all extensions used to + /// define the HUGR while resolving all [`OpType::OpaqueOp`] operations into + /// [`OpType::ExtensionOp`]s and updating the extension pointer in all + /// internal [`crate::types::CustomType`]s to point to the extensions in the + /// register. + /// + /// When listing "used extensions" we only care about _definitional_ + /// extension requirements, i.e., the operations and types that are required + /// to define the HUGR nodes and wire types. This is computed from the union + /// of all extension required across the HUGR. + /// + /// This is distinct from _runtime_ extension requirements computed in + /// [`Hugr::infer_extensions`], which are computed more granularly in each + /// function signature by the `required_extensions` field and define the set + /// of capabilities required by the runtime to execute each function. + /// + /// Returns a new extension registry with the extensions used in the Hugr. + /// + /// # Parameters + /// + /// - `extensions`: The extension set considered when resolving opaque + /// operations and types. The original Hugr's internal extension + /// registry is ignored and replaced with the newly computed one. + /// + /// # Errors + /// + /// - If an opaque operation cannot be resolved to an extension operation. + /// - If an extension operation references an extension that is missing from + /// the registry. + /// - If a custom type references an extension that is missing from the + /// registry. + pub fn resolve_extension_defs( + &mut self, + extensions: &ExtensionRegistry, + ) -> Result { + let mut used_extensions = ExtensionRegistry::default(); + + // Here we need to iterate the optypes in the hugr mutably, to avoid + // having to clone and accumulate all replacements before finally + // applying them. + // + // This is not something we want to expose it the API, so we manually + // iterate instead of writing it as a method. + for n in 0..self.node_count() { + let pg_node = portgraph::NodeIndex::new(n); + let node: Node = pg_node.into(); + if !self.contains_node(node) { + continue; + } + + let op = &mut self.op_types[pg_node]; + + if let Some(extension) = update_op_extensions(node, op, extensions)? { + used_extensions.register_updated_ref(extension); + } + update_op_types_extensions(node, op, extensions, &mut used_extensions)?; + } + + Ok(used_extensions) + } } /// Internal API for HUGRs, not intended for use by users. diff --git a/hugr-core/src/hugr/hugrmut.rs b/hugr-core/src/hugr/hugrmut.rs index 3c538c357..4753e1dec 100644 --- a/hugr-core/src/hugr/hugrmut.rs +++ b/hugr-core/src/hugr/hugrmut.rs @@ -119,6 +119,8 @@ pub trait HugrMut: HugrMutInternals { } /// Remove a node from the graph and return the node weight. + /// Note that if the node has children, they are not removed; this leaves + /// the Hugr in an invalid state. See [Self::remove_subtree]. /// /// # Panics /// @@ -129,6 +131,19 @@ pub trait HugrMut: HugrMutInternals { self.hugr_mut().remove_node(node) } + /// Remove a node from the graph, along with all its descendants in the hierarchy. + /// + /// # Panics + /// + /// If the node is not in the graph, or is the root (this would leave an empty Hugr). + fn remove_subtree(&mut self, node: Node) { + panic_invalid_non_root(self, node); + while let Some(ch) = self.first_child(node) { + self.remove_subtree(ch) + } + self.hugr_mut().remove_node(node); + } + /// Connect two nodes at the given ports. /// /// # Panics @@ -520,18 +535,15 @@ pub(super) fn panic_invalid_port( mod test { use crate::{ extension::{ - prelude::{Noop, USIZE_T}, + prelude::{usize_t, Noop}, PRELUDE_REGISTRY, }, - macros::type_row, - ops::{self, dataflow::IOTrait}, - types::{Signature, Type}, + ops::{self, dataflow::IOTrait, FuncDefn, Input, Output}, + types::Signature, }; use super::*; - const NAT: Type = USIZE_T; - #[test] fn simple_function() -> Result<(), Box> { let mut hugr = Hugr::default(); @@ -544,16 +556,16 @@ mod test { module, ops::FuncDefn { name: "main".into(), - signature: Signature::new(type_row![NAT], type_row![NAT, NAT]) + signature: Signature::new(vec![usize_t()], vec![usize_t(), usize_t()]) .with_prelude() .into(), }, ); { - let f_in = hugr.add_node_with_parent(f, ops::Input::new(type_row![NAT])); - let f_out = hugr.add_node_with_parent(f, ops::Output::new(type_row![NAT, NAT])); - let noop = hugr.add_node_with_parent(f, Noop(NAT)); + let f_in = hugr.add_node_with_parent(f, ops::Input::new(vec![usize_t()])); + let f_out = hugr.add_node_with_parent(f, ops::Output::new(vec![usize_t(), usize_t()])); + let noop = hugr.add_node_with_parent(f, Noop(usize_t())); hugr.connect(f_in, 0, noop, 0); hugr.connect(noop, 0, f_out, 0); @@ -583,4 +595,33 @@ mod test { hugr.remove_metadata(root, "meta"); assert_eq!(hugr.get_metadata(root, "meta"), None); } + + #[test] + fn remove_subtree() { + let mut hugr = Hugr::default(); + let root = hugr.root(); + let [foo, bar] = ["foo", "bar"].map(|name| { + let fd = hugr.add_node_with_parent( + root, + FuncDefn { + name: name.to_string(), + signature: Signature::new_endo(usize_t()).into(), + }, + ); + let inp = hugr.add_node_with_parent(fd, Input::new(usize_t())); + let out = hugr.add_node_with_parent(fd, Output::new(usize_t())); + hugr.connect(inp, 0, out, 0); + fd + }); + hugr.validate(&PRELUDE_REGISTRY).unwrap(); + assert_eq!(hugr.node_count(), 7); + + hugr.remove_subtree(foo); + hugr.validate(&PRELUDE_REGISTRY).unwrap(); + assert_eq!(hugr.node_count(), 4); + + hugr.remove_subtree(bar); + hugr.validate(&PRELUDE_REGISTRY).unwrap(); + assert_eq!(hugr.node_count(), 1); + } } diff --git a/hugr-core/src/hugr/internal.rs b/hugr-core/src/hugr/internal.rs index 47bbd657b..3f1c6b6ff 100644 --- a/hugr-core/src/hugr/internal.rs +++ b/hugr-core/src/hugr/internal.rs @@ -1,7 +1,11 @@ //! Internal traits, not exposed in the public `hugr` API. +use std::borrow::Cow; use std::ops::Range; +use std::rc::Rc; +use std::sync::Arc; +use delegate::delegate; use portgraph::{LinkView, MultiPortGraph, PortMut, PortView}; use crate::ops::handle::NodeHandle; @@ -31,7 +35,7 @@ pub trait HugrInternals { fn root_node(&self) -> Node; } -impl> HugrInternals for T { +impl HugrInternals for Hugr { type Portgraph<'p> = &'p MultiPortGraph where @@ -39,20 +43,103 @@ impl> HugrInternals for T { #[inline] fn portgraph(&self) -> Self::Portgraph<'_> { - &self.as_ref().graph + &self.graph } #[inline] fn base_hugr(&self) -> &Hugr { - self.as_ref() + self } #[inline] fn root_node(&self) -> Node { - self.as_ref().root.into() + self.root.into() } } +impl HugrInternals for &T { + type Portgraph<'p> + = T::Portgraph<'p> + where + Self: 'p; + delegate! { + to (**self) { + fn portgraph(&self) -> Self::Portgraph<'_>; + fn base_hugr(&self) -> &Hugr; + fn root_node(&self) -> Node; + } + } +} + +impl HugrInternals for &mut T { + type Portgraph<'p> + = T::Portgraph<'p> + where + Self: 'p; + delegate! { + to (**self) { + fn portgraph(&self) -> Self::Portgraph<'_>; + fn base_hugr(&self) -> &Hugr; + fn root_node(&self) -> Node; + } + } +} + +impl HugrInternals for Rc { + type Portgraph<'p> + = T::Portgraph<'p> + where + Self: 'p; + delegate! { + to (**self) { + fn portgraph(&self) -> Self::Portgraph<'_>; + fn base_hugr(&self) -> &Hugr; + fn root_node(&self) -> Node; + } + } +} + +impl HugrInternals for Arc { + type Portgraph<'p> + = T::Portgraph<'p> + where + Self: 'p; + delegate! { + to (**self) { + fn portgraph(&self) -> Self::Portgraph<'_>; + fn base_hugr(&self) -> &Hugr; + fn root_node(&self) -> Node; + } + } +} + +impl HugrInternals for Box { + type Portgraph<'p> + = T::Portgraph<'p> + where + Self: 'p; + delegate! { + to (**self) { + fn portgraph(&self) -> Self::Portgraph<'_>; + fn base_hugr(&self) -> &Hugr; + fn root_node(&self) -> Node; + } + } +} + +impl HugrInternals for Cow<'_, T> { + type Portgraph<'p> + = T::Portgraph<'p> + where + Self: 'p; + delegate! { + to self.as_ref() { + fn portgraph(&self) -> Self::Portgraph<'_>; + fn base_hugr(&self) -> &Hugr; + fn root_node(&self) -> Node; + } + } +} /// Trait for accessing the mutable internals of a Hugr(Mut). /// /// Specifically, this trait lets you apply arbitrary modifications that may diff --git a/hugr-core/src/hugr/rewrite.rs b/hugr-core/src/hugr/rewrite.rs index dd26b1ac2..3354fc820 100644 --- a/hugr-core/src/hugr/rewrite.rs +++ b/hugr-core/src/hugr/rewrite.rs @@ -70,14 +70,11 @@ impl Rewrite for Transactional { let mut backup = Hugr::new(h.root_type().clone()); backup.insert_from_view(backup.root(), h); let r = self.underlying.apply(h); - fn first_child(h: &impl HugrView) -> Option { - h.children(h.root()).next() - } if r.is_err() { // Try to restore backup. h.replace_op(h.root(), backup.root_type().clone()) .expect("The root replacement should always match the old root type"); - while let Some(child) = first_child(h) { + while let Some(child) = h.first_child(h.root()) { h.remove_node(child); } h.insert_from_view(h.root(), &backup); diff --git a/hugr-core/src/hugr/rewrite/inline_dfg.rs b/hugr-core/src/hugr/rewrite/inline_dfg.rs index ad0b91dc3..6981c7277 100644 --- a/hugr-core/src/hugr/rewrite/inline_dfg.rs +++ b/hugr-core/src/hugr/rewrite/inline_dfg.rs @@ -136,7 +136,7 @@ mod test { endo_sig, inout_sig, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, SubContainer, }; - use crate::extension::prelude::QB_T; + use crate::extension::prelude::qb_t; use crate::extension::{ExtensionRegistry, ExtensionSet, PRELUDE}; use crate::hugr::rewrite::inline_dfg::InlineDFGError; use crate::hugr::HugrMut; @@ -244,13 +244,13 @@ mod test { #[test] fn permutation() -> Result<(), Box> { - let mut h = DFGBuilder::new(endo_sig(type_row![QB_T, QB_T]))?; + let mut h = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()]))?; let [p, q] = h.input_wires_arr(); let [p_h] = h .add_dataflow_op(test_quantum_extension::h_gate(), [p])? .outputs_arr(); let swap = { - let swap = h.dfg_builder(Signature::new_endo(type_row![QB_T, QB_T]), [p_h, q])?; + let swap = h.dfg_builder(Signature::new_endo(vec![qb_t(), qb_t()]), [p_h, q])?; let [a, b] = swap.input_wires_arr(); swap.finish_with_outputs([b, a])? }; @@ -339,11 +339,11 @@ mod test { PRELUDE.to_owned(), ]) .unwrap(); - let mut outer = DFGBuilder::new(endo_sig(type_row![QB_T, QB_T]))?; + let mut outer = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()]))?; let [a, b] = outer.input_wires_arr(); let h_a = outer.add_dataflow_op(test_quantum_extension::h_gate(), [a])?; let h_b = outer.add_dataflow_op(test_quantum_extension::h_gate(), [b])?; - let mut inner = outer.dfg_builder(endo_sig(QB_T), h_b.outputs())?; + let mut inner = outer.dfg_builder(endo_sig(qb_t()), h_b.outputs())?; let [i] = inner.input_wires_arr(); let f = inner.add_load_value(float_types::ConstF64::new(1.0)); inner.add_other_wire(inner.input().node(), f.node()); diff --git a/hugr-core/src/hugr/rewrite/insert_identity.rs b/hugr-core/src/hugr/rewrite/insert_identity.rs index 6d8108b0c..0c1dc872a 100644 --- a/hugr-core/src/hugr/rewrite/insert_identity.rs +++ b/hugr-core/src/hugr/rewrite/insert_identity.rs @@ -101,10 +101,8 @@ mod tests { use super::super::simple_replace::test::dfg_hugr; use super::*; - use crate::{ - extension::{prelude::QB_T, PRELUDE_REGISTRY}, - Hugr, - }; + use crate::utils::test_quantum_extension; + use crate::{extension::prelude::qb_t, Hugr}; #[rstest] fn correct_insertion(dfg_hugr: Hugr) { @@ -127,8 +125,8 @@ mod tests { let noop: Noop = h.get_optype(noop_node).cast().unwrap(); - assert_eq!(noop, Noop(QB_T)); + assert_eq!(noop, Noop(qb_t())); - h.update_validate(&PRELUDE_REGISTRY).unwrap(); + h.update_validate(&test_quantum_extension::REG).unwrap(); } } diff --git a/hugr-core/src/hugr/rewrite/outline_cfg.rs b/hugr-core/src/hugr/rewrite/outline_cfg.rs index 7dd181f92..1b9a47a1a 100644 --- a/hugr-core/src/hugr/rewrite/outline_cfg.rs +++ b/hugr-core/src/hugr/rewrite/outline_cfg.rs @@ -251,14 +251,14 @@ mod test { BlockBuilder, BuildError, CFGBuilder, Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder, }; - use crate::extension::prelude::USIZE_T; + use crate::extension::prelude::usize_t; use crate::extension::PRELUDE_REGISTRY; use crate::hugr::views::sibling::SiblingMut; use crate::hugr::HugrMut; use crate::ops::constant::Value; use crate::ops::handle::{BasicBlockID, CfgID, ConstID, NodeHandle}; use crate::types::Signature; - use crate::{type_row, Hugr, HugrView, Node}; + use crate::{Hugr, HugrView, Node}; use cool_asserts::assert_matches; use itertools::Itertools; use rstest::rstest; @@ -278,7 +278,7 @@ mod test { } impl CondThenLoopCfg { fn new() -> Result { - let block_ty = Signature::new_endo(USIZE_T); + let block_ty = Signature::new_endo(usize_t()); let mut cfg_builder = CFGBuilder::new(block_ty.clone())?; let pred_const = cfg_builder.add_constant(Value::unit_sum(0, 2).expect("0 < 2")); let const_unit = cfg_builder.add_constant(Value::unary_unit_sum()); @@ -295,7 +295,7 @@ mod test { }; let entry = n_identity( - cfg_builder.simple_entry_builder(USIZE_T.into(), 2)?, + cfg_builder.simple_entry_builder(usize_t().into(), 2)?, &pred_const, )?; @@ -311,7 +311,7 @@ mod test { let head = id_block(&mut cfg_builder)?; cfg_builder.branch(&merge, 0, &head)?; let tail = n_identity( - cfg_builder.simple_block_builder(Signature::new_endo(USIZE_T), 2)?, + cfg_builder.simple_block_builder(Signature::new_endo(usize_t()), 2)?, &pred_const, )?; cfg_builder.branch(&tail, 1, &head)?; @@ -439,10 +439,7 @@ mod test { // operating via a SiblingMut let mut module_builder = ModuleBuilder::new(); let mut fbuild = module_builder - .define_function( - "main", - Signature::new(type_row![USIZE_T], type_row![USIZE_T]), - ) + .define_function("main", Signature::new(vec![usize_t()], vec![usize_t()])) .unwrap(); let [i1] = fbuild.input_wires_arr(); let cfg = fbuild diff --git a/hugr-core/src/hugr/rewrite/replace.rs b/hugr-core/src/hugr/rewrite/replace.rs index 8967df9a5..5d770af4b 100644 --- a/hugr-core/src/hugr/rewrite/replace.rs +++ b/hugr-core/src/hugr/rewrite/replace.rs @@ -314,11 +314,7 @@ impl Rewrite for Replacement { for (new_parent, &old_parent) in self.adoptions.iter() { let new_parent = node_map.get(new_parent).unwrap(); debug_assert!(h.children(old_parent).next().is_some()); - loop { - let ch = match h.children(old_parent).next() { - None => break, - Some(c) => c, - }; + while let Some(ch) = h.first_child(old_parent) { h.set_parent(ch, *new_parent); } } @@ -451,7 +447,7 @@ mod test { endo_sig, BuildError, CFGBuilder, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, SubContainer, }; - use crate::extension::prelude::{BOOL_T, USIZE_T}; + use crate::extension::prelude::{bool_t, usize_t}; use crate::extension::{ExtensionRegistry, PRELUDE, PRELUDE_REGISTRY}; use crate::hugr::internal::HugrMutInternals; use crate::hugr::rewrite::replace::WhichHugr; @@ -462,8 +458,8 @@ mod test { use crate::ops::{self, Case, DataflowBlock, OpTag, OpType, DFG}; use crate::std_extensions::collections::{self, list_type, ListOp}; use crate::types::{Signature, Type, TypeRow}; - use crate::utils::depth; - use crate::{type_row, Direction, Hugr, HugrView, OutgoingPort}; + use crate::utils::{depth, test_quantum_extension}; + use crate::{type_row, Direction, Extension, Hugr, HugrView, OutgoingPort}; use super::{NewEdgeKind, NewEdgeSpec, ReplaceError, Replacement}; @@ -473,17 +469,17 @@ mod test { let reg = ExtensionRegistry::try_new([PRELUDE.to_owned(), collections::EXTENSION.to_owned()]) .unwrap(); - let listy = list_type(USIZE_T); + let listy = list_type(usize_t()); let pop: ExtensionOp = ListOp::pop - .with_type(USIZE_T) + .with_type(usize_t()) .to_extension_op(®) .unwrap(); let push: ExtensionOp = ListOp::push - .with_type(USIZE_T) + .with_type(usize_t()) .to_extension_op(®) .unwrap(); let just_list = TypeRow::from(vec![listy.clone()]); - let intermed = TypeRow::from(vec![listy.clone(), USIZE_T]); + let intermed = TypeRow::from(vec![listy.clone(), usize_t()]); let mut cfg = CFGBuilder::new(endo_sig(just_list.clone()))?; @@ -638,12 +634,30 @@ mod test { #[test] fn test_invalid() { - let mut new_ext = crate::Extension::new_test("new_ext".try_into().unwrap()); - let ext_name = new_ext.name().clone(); - let utou = Signature::new_endo(vec![USIZE_T]); - let mut mk_op = |s| new_ext.simple_ext_op(s, utou.clone()); + let utou = Signature::new_endo(vec![usize_t()]); + let ext = Extension::new_test_arc("new_ext".try_into().unwrap(), |ext, extension_ref| { + ext.add_op("foo".into(), "".to_string(), utou.clone(), extension_ref) + .unwrap(); + ext.add_op("bar".into(), "".to_string(), utou.clone(), extension_ref) + .unwrap(); + ext.add_op("baz".into(), "".to_string(), utou.clone(), extension_ref) + .unwrap(); + }); + let ext_name = ext.name().clone(); + let foo = ext + .instantiate_extension_op("foo", [], &PRELUDE_REGISTRY) + .unwrap(); + let bar = ext + .instantiate_extension_op("bar", [], &PRELUDE_REGISTRY) + .unwrap(); + let baz = ext + .instantiate_extension_op("baz", [], &PRELUDE_REGISTRY) + .unwrap(); + let mut registry = test_quantum_extension::REG.clone(); + registry.register(ext).unwrap(); + let mut h = DFGBuilder::new( - Signature::new(type_row![USIZE_T, BOOL_T], type_row![USIZE_T]) + Signature::new(vec![usize_t(), bool_t()], vec![usize_t()]) .with_extension_delta(ext_name.clone()), ) .unwrap(); @@ -651,34 +665,28 @@ mod test { let mut cond = h .conditional_builder_exts( (vec![type_row![]; 2], b), - [(USIZE_T, i)], - type_row![USIZE_T], + [(usize_t(), i)], + vec![usize_t()].into(), ext_name.clone(), ) .unwrap(); let mut case1 = cond.case_builder(0).unwrap(); - let foo = case1 - .add_dataflow_op(mk_op("foo"), case1.input_wires()) - .unwrap(); + let foo = case1.add_dataflow_op(foo, case1.input_wires()).unwrap(); let case1 = case1.finish_with_outputs(foo.outputs()).unwrap().node(); let mut case2 = cond.case_builder(1).unwrap(); - let bar = case2 - .add_dataflow_op(mk_op("bar"), case2.input_wires()) - .unwrap(); + let bar = case2.add_dataflow_op(bar, case2.input_wires()).unwrap(); let mut baz_dfg = case2 .dfg_builder( utou.clone().with_extension_delta(ext_name.clone()), bar.outputs(), ) .unwrap(); - let baz = baz_dfg - .add_dataflow_op(mk_op("baz"), baz_dfg.input_wires()) - .unwrap(); + let baz = baz_dfg.add_dataflow_op(baz, baz_dfg.input_wires()).unwrap(); let baz_dfg = baz_dfg.finish_with_outputs(baz.outputs()).unwrap(); let case2 = case2.finish_with_outputs(baz_dfg.outputs()).unwrap().node(); let cond = cond.finish_sub_container().unwrap(); let h = h - .finish_hugr_with_outputs(cond.outputs(), &PRELUDE_REGISTRY) + .finish_hugr_with_outputs(cond.outputs(), ®istry) .unwrap(); let mut r_hugr = Hugr::new(h.get_optype(cond.node()).clone()); diff --git a/hugr-core/src/hugr/rewrite/simple_replace.rs b/hugr-core/src/hugr/rewrite/simple_replace.rs index 3018adce0..e5dc42841 100644 --- a/hugr-core/src/hugr/rewrite/simple_replace.rs +++ b/hugr-core/src/hugr/rewrite/simple_replace.rs @@ -1,13 +1,17 @@ //! Implementation of the `SimpleReplace` operation. -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; +use crate::hugr::hugrmut::InsertionResult; +pub use crate::hugr::internal::HugrMutInternals; use crate::hugr::views::SiblingSubgraph; -use crate::hugr::{HugrMut, HugrView, NodeMetadataMap, Rewrite}; +use crate::hugr::{HugrMut, HugrView, Rewrite}; use crate::ops::{OpTag, OpTrait, OpType}; -use crate::{Hugr, IncomingPort, Node, OutgoingPort}; +use crate::{Hugr, IncomingPort, Node}; use thiserror::Error; +use super::inline_dfg::InlineDFGError; + /// Specification of a simple replacement operation. #[derive(Debug, Clone)] pub struct SimpleReplacement { @@ -62,130 +66,139 @@ impl Rewrite for SimpleReplacement { unimplemented!() } - fn apply(mut self, h: &mut impl HugrMut) -> Result { - let parent = self.subgraph.get_parent(h); + fn apply(self, h: &mut impl HugrMut) -> Result { + let Self { + subgraph, + replacement, + nu_inp, + nu_out, + } = self; + let parent = subgraph.get_parent(h); // 1. Check the parent node exists and is a DataflowParent. if !OpTag::DataflowParent.is_superset(h.get_optype(parent).tag()) { return Err(SimpleReplacementError::InvalidParentNode()); } // 2. Check that all the to-be-removed nodes are children of it and are leaves. - for node in self.subgraph.nodes() { + for node in subgraph.nodes() { if h.get_parent(*node) != Some(parent) || h.children(*node).next().is_some() { return Err(SimpleReplacementError::InvalidRemovedNode()); } } - // 3. Do the replacement. - // 3.1. Add copies of all replacement nodes and edges to h. Exclude Input/Output nodes. - // Create map from old NodeIndex (in self.replacement) to new NodeIndex (in self). - let mut index_map: HashMap = HashMap::new(); - let replacement_nodes = self - .replacement - .children(self.replacement.root()) - .collect::>(); - // slice of nodes omitting Input and Output: - let replacement_inner_nodes = &replacement_nodes[2..]; - let self_output_node = h.children(parent).nth(1).unwrap(); - let replacement_output_node = *replacement_nodes.get(1).unwrap(); - for &node in replacement_inner_nodes { - // Add the nodes. - let op: &OpType = self.replacement.get_optype(node); - let new_node = h.add_node_after(self_output_node, op.clone()); - index_map.insert(node, new_node); - - // Move the metadata - let meta: Option = self.replacement.take_node_metadata(node); - h.overwrite_node_metadata(new_node, meta); - } - // Add edges between all newly added nodes matching those in replacement. - for &node in replacement_inner_nodes { - let new_node = index_map.get(&node).unwrap(); - for outport in self.replacement.node_outputs(node) { - for target in self.replacement.linked_inputs(node, outport) { - if self.replacement.get_optype(target.0).tag() != OpTag::Output { - let new_target = index_map.get(&target.0).unwrap(); - h.connect(*new_node, outport, *new_target, target.1); - } - } - } - } + let replacement_output_node = replacement + .get_io(replacement.root()) + .expect("parent already checked.")[1]; + + // 3. Do the replacement. // Now we proceed to connect the edges between the newly inserted // replacement and the rest of the graph. // - // We delay creating these connections to avoid them getting mixed with - // the pre-existing ones in the following logic. - // // Existing connections to the removed subgraph will be automatically // removed when the nodes are removed. - let mut connect: HashSet<(Node, OutgoingPort, Node, IncomingPort)> = HashSet::new(); - // 3.2. For each p = self.nu_inp[q] such that q is not an Output port, add an edge from the + // 3.1. For each p = self.nu_inp[q] such that q is not an Output port, add an edge from the // predecessor of p to (the new copy of) q. - for ((rep_inp_node, rep_inp_port), (rem_inp_node, rem_inp_port)) in &self.nu_inp { - if self.replacement.get_optype(*rep_inp_node).tag() != OpTag::Output { - // add edge from predecessor of (s_inp_node, s_inp_port) to (new_inp_node, n_inp_port) - let (rem_inp_pred_node, rem_inp_pred_port) = h - .single_linked_output(*rem_inp_node, *rem_inp_port) + let nu_inp_connects: Vec<_> = nu_inp + .iter() + .filter(|&((rep_inp_node, _), _)| { + replacement.get_optype(*rep_inp_node).tag() != OpTag::Output + }) + .map( + |((rep_inp_node, rep_inp_port), (rem_inp_node, rem_inp_port))| { + // add edge from predecessor of (s_inp_node, s_inp_port) to (new_inp_node, n_inp_port) + let (rem_inp_pred_node, rem_inp_pred_port) = h + .single_linked_output(*rem_inp_node, *rem_inp_port) + .unwrap(); + ( + rem_inp_pred_node, + rem_inp_pred_port, + // the new input node will be updated after insertion + rep_inp_node, + rep_inp_port, + ) + }, + ) + .collect(); + + // 3.2. For each q = self.nu_out[p] such that the predecessor of q is not an Input port, add an + // edge from (the new copy of) the predecessor of q to p. + let nu_out_connects: Vec<_> = nu_out + .iter() + .filter_map(|((rem_out_node, rem_out_port), rep_out_port)| { + let (rep_out_pred_node, rep_out_pred_port) = replacement + .single_linked_output(replacement_output_node, *rep_out_port) .unwrap(); - let new_inp_node = index_map.get(rep_inp_node).unwrap(); - connect.insert(( - rem_inp_pred_node, - rem_inp_pred_port, - *new_inp_node, - *rep_inp_port, - )); - } + (replacement.get_optype(rep_out_pred_node).tag() != OpTag::Input).then_some({ + ( + // the new output node will be updated after insertion + rep_out_pred_node, + rep_out_pred_port, + rem_out_node, + rem_out_port, + ) + }) + }) + .collect(); + + // 3.3. Insert the replacement as a whole. + let InsertionResult { + new_root, + node_map: index_map, + } = h.insert_hugr(parent, replacement); + + // remove the Input and Output nodes from the replacement graph + let replace_children = h.children(new_root).collect::>(); + for &io in &replace_children[..2] { + h.remove_node(io); } - // 3.3. For each q = self.nu_out[p] such that the predecessor of q is not an Input port, add an - // edge from (the new copy of) the predecessor of q to p. - for ((rem_out_node, rem_out_port), rep_out_port) in &self.nu_out { - let (rep_out_pred_node, rep_out_pred_port) = self - .replacement - .single_linked_output(replacement_output_node, *rep_out_port) - .unwrap(); - if self.replacement.get_optype(rep_out_pred_node).tag() != OpTag::Input { - let new_out_node = index_map.get(&rep_out_pred_node).unwrap(); - connect.insert(( - *new_out_node, - rep_out_pred_port, - *rem_out_node, - *rem_out_port, - )); - } + // make all replacement top level children children of the parent + for &child in &replace_children[2..] { + h.set_parent(child, parent); + } + // remove the replacement root (which now has no children and no edges) + h.remove_node(new_root); + + // 3.4. Update replacement nodes according to insertion mapping and connect + for (src_node, src_port, tgt_node, tgt_port) in nu_inp_connects { + h.connect( + src_node, + src_port, + *index_map.get(tgt_node).unwrap(), + *tgt_port, + ) } - // 3.4. For each q = self.nu_out[p1], p0 = self.nu_inp[q], add an edge from the predecessor of p0 + + for (src_node, src_port, tgt_node, tgt_port) in nu_out_connects { + h.connect( + *index_map.get(&src_node).unwrap(), + src_port, + *tgt_node, + *tgt_port, + ) + } + // 3.5. For each q = self.nu_out[p1], p0 = self.nu_inp[q], add an edge from the predecessor of p0 // to p1. // // i.e. the replacement graph has direct edges between the input and output nodes. - for ((rem_out_node, rem_out_port), &rep_out_port) in &self.nu_out { - let rem_inp_nodeport = self.nu_inp.get(&(replacement_output_node, rep_out_port)); + for ((rem_out_node, rem_out_port), &rep_out_port) in &nu_out { + let rem_inp_nodeport = nu_inp.get(&(replacement_output_node, rep_out_port)); if let Some((rem_inp_node, rem_inp_port)) = rem_inp_nodeport { // add edge from predecessor of (rem_inp_node, rem_inp_port) to (rem_out_node, rem_out_port): let (rem_inp_pred_node, rem_inp_pred_port) = h .single_linked_output(*rem_inp_node, *rem_inp_port) .unwrap(); - // Delay connecting the nodes until after processing all nu_out - // entries. - // - // Otherwise, we might disconnect other wires in `rem_inp_node` - // that are needed for the following iterations. - connect.insert(( + + h.connect( rem_inp_pred_node, rem_inp_pred_port, *rem_out_node, *rem_out_port, - )); + ); } } - connect - .into_iter() - .for_each(|(src_node, src_port, tgt_node, tgt_port)| { - h.connect(src_node, src_port, tgt_node, tgt_port); - }); - - // 3.5. Remove all nodes in self.removal and edges between them. - Ok(self - .subgraph + + // 3.6. Remove all nodes in subgraph and edges between them. + Ok(subgraph .nodes() .iter() .map(|&node| (node, h.remove_node(node))) @@ -213,6 +226,9 @@ pub enum SimpleReplacementError { /// Node in replacement graph is invalid. #[error("A node in the replacement graph is invalid.")] InvalidReplacementNode(), + /// Inlining replacement failed. + #[error("Inlining replacement failed: {0}")] + InliningFailed(#[from] InlineDFGError), } #[cfg(test)] @@ -221,12 +237,13 @@ pub(in crate::hugr::rewrite) mod test { use rstest::{fixture, rstest}; use std::collections::{HashMap, HashSet}; + use crate::builder::test::n_identity; use crate::builder::{ endo_sig, inout_sig, BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, ModuleBuilder, }; - use crate::extension::prelude::BOOL_T; - use crate::extension::{ExtensionSet, EMPTY_REG, PRELUDE_REGISTRY}; + use crate::extension::prelude::{bool_t, qb_t}; + use crate::extension::{ExtensionSet, PRELUDE_REGISTRY}; use crate::hugr::views::{HugrView, SiblingSubgraph}; use crate::hugr::{Hugr, HugrMut, Rewrite}; use crate::ops::dataflow::DataflowOpTrait; @@ -235,15 +252,12 @@ pub(in crate::hugr::rewrite) mod test { use crate::ops::OpTrait; use crate::std_extensions::logic::test::and_op; use crate::std_extensions::logic::LogicOp; - use crate::type_row; use crate::types::{Signature, Type}; - use crate::utils::test_quantum_extension::{cx_gate, h_gate, EXTENSION_ID}; + use crate::utils::test_quantum_extension::{self, cx_gate, h_gate, EXTENSION_ID}; use crate::{IncomingPort, Node}; use super::SimpleReplacement; - const QB: Type = crate::extension::prelude::QB_T; - /// Creates a hugr like the following: /// -- H -- /// -- [DFG] -- @@ -259,14 +273,16 @@ pub(in crate::hugr::rewrite) mod test { let just_q: ExtensionSet = EXTENSION_ID.into(); let mut func_builder = module_builder.define_function( "main", - Signature::new_endo(type_row![QB, QB, QB]).with_extension_delta(just_q.clone()), + Signature::new_endo(vec![qb_t(), qb_t(), qb_t()]) + .with_extension_delta(just_q.clone()), )?; let [qb0, qb1, qb2] = func_builder.input_wires_arr(); let q_out = func_builder.add_dataflow_op(h_gate(), vec![qb2])?; - let mut inner_builder = func_builder.dfg_builder_endo([(QB, qb0), (QB, qb1)])?; + let mut inner_builder = + func_builder.dfg_builder_endo([(qb_t(), qb0), (qb_t(), qb1)])?; let inner_graph = { let [wire0, wire1] = inner_builder.input_wires_arr(); let wire2 = inner_builder.add_dataflow_op(h_gate(), vec![wire0])?; @@ -281,7 +297,7 @@ pub(in crate::hugr::rewrite) mod test { func_builder.finish_with_outputs(inner_graph.outputs().chain(q_out.outputs()))? }; - Ok(module_builder.finish_prelude_hugr()?) + Ok(module_builder.finish_hugr(&test_quantum_extension::REG)?) } #[fixture] @@ -295,13 +311,13 @@ pub(in crate::hugr::rewrite) mod test { /// ┤ H ├┤ X ├ /// └───┘└───┘ fn make_dfg_hugr() -> Result { - let mut dfg_builder = DFGBuilder::new(endo_sig(type_row![QB, QB]).with_prelude())?; + let mut dfg_builder = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()]).with_prelude())?; let [wire0, wire1] = dfg_builder.input_wires_arr(); let wire2 = dfg_builder.add_dataflow_op(h_gate(), vec![wire0])?; let wire3 = dfg_builder.add_dataflow_op(h_gate(), vec![wire1])?; let wire45 = dfg_builder.add_dataflow_op(cx_gate(), wire2.outputs().chain(wire3.outputs()))?; - dfg_builder.finish_prelude_hugr_with_outputs(wire45.outputs()) + dfg_builder.finish_hugr_with_outputs(wire45.outputs(), &test_quantum_extension::REG) } #[fixture] @@ -315,13 +331,13 @@ pub(in crate::hugr::rewrite) mod test { /// ┤ H ├ /// └───┘ fn make_dfg_hugr2() -> Result { - let mut dfg_builder = DFGBuilder::new(endo_sig(type_row![QB, QB]))?; + let mut dfg_builder = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()]))?; let [wire0, wire1] = dfg_builder.input_wires_arr(); let wire2 = dfg_builder.add_dataflow_op(h_gate(), vec![wire1])?; let wire2out = wire2.outputs().exactly_one().unwrap(); let wireoutvec = vec![wire0, wire2out]; - dfg_builder.finish_prelude_hugr_with_outputs(wireoutvec) + dfg_builder.finish_hugr_with_outputs(wireoutvec, &test_quantum_extension::REG) } #[fixture] @@ -329,7 +345,7 @@ pub(in crate::hugr::rewrite) mod test { make_dfg_hugr2().unwrap() } - /// A hugr with a DFG root mapping BOOL_T to (BOOL_T, BOOL_T) + /// A hugr with a DFG root mapping bool_t() to (bool_t(), bool_t()) /// ┌─────────┐ /// ┌────┤ (1) NOT ├── /// ┌─────────┐ │ └─────────┘ @@ -343,7 +359,7 @@ pub(in crate::hugr::rewrite) mod test { #[fixture] pub(in crate::hugr::rewrite) fn dfg_hugr_copy_bools() -> (Hugr, Vec) { let mut dfg_builder = - DFGBuilder::new(inout_sig(type_row![BOOL_T], type_row![BOOL_T, BOOL_T])).unwrap(); + DFGBuilder::new(inout_sig(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap(); let [b] = dfg_builder.input_wires_arr(); let not_inp = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap(); @@ -356,13 +372,13 @@ pub(in crate::hugr::rewrite) mod test { ( dfg_builder - .finish_prelude_hugr_with_outputs([b0, b1]) + .finish_hugr_with_outputs([b0, b1], &test_quantum_extension::REG) .unwrap(), vec![not_inp.node(), not_0.node(), not_1.node()], ) } - /// A hugr with a DFG root mapping BOOL_T to (BOOL_T, BOOL_T) + /// A hugr with a DFG root mapping bool_t() to (bool_t(), bool_t()) /// ┌─────────┐ /// ┌────┤ (1) NOT ├── /// ┌─────────┐ │ └─────────┘ @@ -376,7 +392,7 @@ pub(in crate::hugr::rewrite) mod test { #[fixture] pub(in crate::hugr::rewrite) fn dfg_hugr_half_not_bools() -> (Hugr, Vec) { let mut dfg_builder = - DFGBuilder::new(inout_sig(type_row![BOOL_T], type_row![BOOL_T, BOOL_T])).unwrap(); + DFGBuilder::new(inout_sig(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap(); let [b] = dfg_builder.input_wires_arr(); let not_inp = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap(); @@ -388,7 +404,7 @@ pub(in crate::hugr::rewrite) mod test { ( dfg_builder - .finish_prelude_hugr_with_outputs([b0, b1]) + .finish_hugr_with_outputs([b0, b1], &test_quantum_extension::REG) .unwrap(), vec![not_inp.node(), not_0.node()], ) @@ -473,7 +489,7 @@ pub(in crate::hugr::rewrite) mod test { // ├───┤├───┤┌─┴─┐ // ┤ H ├┤ H ├┤ X ├ // └───┘└───┘└───┘ - assert_eq!(h.update_validate(&PRELUDE_REGISTRY), Ok(())); + assert_eq!(h.update_validate(&test_quantum_extension::REG), Ok(())); } #[rstest] @@ -545,19 +561,21 @@ pub(in crate::hugr::rewrite) mod test { // ├───┤├───┤┌───┐ // ┤ H ├┤ H ├┤ H ├ // └───┘└───┘└───┘ - assert_eq!(h.update_validate(&PRELUDE_REGISTRY), Ok(())); + assert_eq!(h.update_validate(&test_quantum_extension::REG), Ok(())); } #[test] fn test_replace_cx_cross() { - let q_row: Vec = vec![QB, QB]; + let q_row: Vec = vec![qb_t(), qb_t()]; let mut builder = DFGBuilder::new(endo_sig(q_row)).unwrap(); let mut circ = builder.as_circuit(builder.input_wires()); circ.append(cx_gate(), [0, 1]).unwrap(); circ.append(cx_gate(), [1, 0]).unwrap(); let wires = circ.finish(); let [input, output] = builder.io(); - let mut h = builder.finish_prelude_hugr_with_outputs(wires).unwrap(); + let mut h = builder + .finish_hugr_with_outputs(wires, &test_quantum_extension::REG) + .unwrap(); let replacement = h.clone(); let orig = h.clone(); @@ -606,8 +624,8 @@ pub(in crate::hugr::rewrite) mod test { #[test] fn test_replace_after_copy() { - let one_bit = type_row![BOOL_T]; - let two_bit = type_row![BOOL_T, BOOL_T]; + let one_bit = vec![bool_t()]; + let two_bit = vec![bool_t(), bool_t()]; let mut builder = DFGBuilder::new(endo_sig(one_bit.clone())).unwrap(); let inw = builder.input_wires().exactly_one().unwrap(); @@ -616,13 +634,17 @@ pub(in crate::hugr::rewrite) mod test { .unwrap() .outputs(); let [input, _] = builder.io(); - let mut h = builder.finish_hugr_with_outputs(outw, &EMPTY_REG).unwrap(); + let mut h = builder + .finish_hugr_with_outputs(outw, &test_quantum_extension::REG) + .unwrap(); let mut builder = DFGBuilder::new(inout_sig(two_bit, one_bit)).unwrap(); let inw = builder.input_wires(); let outw = builder.add_dataflow_op(and_op(), inw).unwrap().outputs(); let [repl_input, repl_output] = builder.io(); - let repl = builder.finish_hugr_with_outputs(outw, &EMPTY_REG).unwrap(); + let repl = builder + .finish_hugr_with_outputs(outw, &test_quantum_extension::REG) + .unwrap(); let orig = h.clone(); @@ -668,8 +690,8 @@ pub(in crate::hugr::rewrite) mod test { let [_input, output] = hugr.get_io(hugr.root()).unwrap(); let replacement = { - let b = DFGBuilder::new(Signature::new(type_row![BOOL_T], type_row![BOOL_T, BOOL_T])) - .unwrap(); + let b = + DFGBuilder::new(Signature::new(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap(); let [w] = b.input_wires_arr(); b.finish_prelude_hugr_with_outputs([w, w]).unwrap() }; @@ -726,12 +748,13 @@ pub(in crate::hugr::rewrite) mod test { let (replacement, repl_not) = { let mut b = - DFGBuilder::new(inout_sig(type_row![BOOL_T], type_row![BOOL_T, BOOL_T])).unwrap(); + DFGBuilder::new(inout_sig(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap(); let [w] = b.input_wires_arr(); let not = b.add_dataflow_op(LogicOp::Not, vec![w]).unwrap(); let [w_not] = not.outputs_arr(); ( - b.finish_prelude_hugr_with_outputs([w, w_not]).unwrap(), + b.finish_hugr_with_outputs([w, w_not], &test_quantum_extension::REG) + .unwrap(), not.node(), ) }; @@ -774,6 +797,51 @@ pub(in crate::hugr::rewrite) mod test { assert_eq!(hugr.node_count(), 4); } + #[rstest] + fn test_nested_replace(dfg_hugr2: Hugr) { + // replace a node with a hugr with children + + let mut h = dfg_hugr2; + let h_node = h + .nodes() + .find(|node: &Node| *h.get_optype(*node) == h_gate().into()) + .unwrap(); + + // build a nested identity dfg + let mut nest_build = DFGBuilder::new(Signature::new_endo(qb_t())).unwrap(); + let [input] = nest_build.input_wires_arr(); + let inner_build = nest_build.dfg_builder_endo([(qb_t(), input)]).unwrap(); + let inner_dfg = n_identity(inner_build).unwrap(); + let inner_dfg_node = inner_dfg.node(); + let replacement = nest_build + .finish_prelude_hugr_with_outputs([inner_dfg.out_wire(0)]) + .unwrap(); + let subgraph = SiblingSubgraph::try_from_nodes(vec![h_node], &h).unwrap(); + let nu_inp = vec![( + (inner_dfg_node, IncomingPort::from(0)), + (h_node, IncomingPort::from(0)), + )] + .into_iter() + .collect(); + + let nu_out = vec![( + (h.get_io(h.root()).unwrap()[1], IncomingPort::from(1)), + IncomingPort::from(0), + )] + .into_iter() + .collect(); + + let rewrite = SimpleReplacement::new(subgraph, replacement, nu_inp, nu_out); + + assert_eq!(h.node_count(), 4); + + rewrite.apply(&mut h).unwrap_or_else(|e| panic!("{e}")); + h.update_validate(&PRELUDE_REGISTRY) + .unwrap_or_else(|e| panic!("{e}")); + + assert_eq!(h.node_count(), 6); + } + use crate::hugr::rewrite::replace::Replacement; fn to_replace(h: &impl HugrView, s: SimpleReplacement) -> Replacement { use crate::hugr::rewrite::replace::{NewEdgeKind, NewEdgeSpec}; diff --git a/hugr-core/src/hugr/serialize/test.rs b/hugr-core/src/hugr/serialize/test.rs index 5afd19cd4..dba525d39 100644 --- a/hugr-core/src/hugr/serialize/test.rs +++ b/hugr-core/src/hugr/serialize/test.rs @@ -6,14 +6,16 @@ use crate::builder::{ DataflowSubContainer, HugrBuilder, ModuleBuilder, }; use crate::extension::prelude::Noop; -use crate::extension::prelude::{BOOL_T, PRELUDE_ID, QB_T, USIZE_T}; +use crate::extension::prelude::{bool_t, qb_t, usize_t, PRELUDE_ID}; use crate::extension::simple_op::MakeRegisteredOp; +use crate::extension::ExtensionId; use crate::extension::{test::SimpleOpDef, ExtensionSet, EMPTY_REG, PRELUDE_REGISTRY}; use crate::hugr::internal::HugrMutInternals; use crate::hugr::validate::ValidationError; -use crate::ops::custom::{ExtensionOp, OpaqueOp, OpaqueOpError}; +use crate::hugr::ExtensionResolutionError; +use crate::ops::custom::{ExtensionOp, OpaqueOp}; use crate::ops::{self, dataflow::IOTrait, Input, Module, Output, Value, DFG}; -use crate::std_extensions::arithmetic::float_types::FLOAT64_TYPE; +use crate::std_extensions::arithmetic::float_types::float64_type; use crate::std_extensions::arithmetic::int_ops::INT_OPS_REGISTRY; use crate::std_extensions::arithmetic::int_types::{ConstInt, INT_TYPES}; use crate::std_extensions::logic::LogicOp; @@ -22,6 +24,7 @@ use crate::types::{ FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature, SumType, Type, TypeArg, TypeBound, TypeRV, }; +use crate::utils::test_quantum_extension; use crate::{type_row, OutgoingPort}; use itertools::Itertools; @@ -31,9 +34,6 @@ use portgraph::LinkView; use portgraph::{multiportgraph::MultiPortGraph, Hierarchy, LinkMut, PortMut, UnmanagedDenseMap}; use rstest::rstest; -const NAT: Type = crate::extension::prelude::USIZE_T; -const QB: Type = crate::extension::prelude::QB_T; - /// Version 1 of the Testing HUGR serialization format, see `testing_hugr.py`. #[derive(Serialize, Deserialize, PartialEq, Debug, Default)] struct SerTestingLatest { @@ -248,11 +248,11 @@ fn gen_optype(g: &MultiPortGraph, node: portgraph::NodeIndex) -> OpType { let outputs = g.num_outputs(node); match (inputs == 0, outputs == 0) { (false, false) => DFG { - signature: Signature::new(vec![NAT; inputs - 1], vec![NAT; outputs - 1]), + signature: Signature::new(vec![usize_t(); inputs - 1], vec![usize_t(); outputs - 1]), } .into(), - (true, false) => Input::new(vec![NAT; outputs - 1]).into(), - (false, true) => Output::new(vec![NAT; inputs - 1]).into(), + (true, false) => Input::new(vec![usize_t(); outputs - 1]).into(), + (false, true) => Output::new(vec![usize_t(); inputs - 1]).into(), (true, true) => Module::new().into(), } } @@ -300,7 +300,7 @@ fn weighted_hugr_ser() { let mut module_builder = ModuleBuilder::new(); module_builder.set_metadata("name", "test"); - let t_row = vec![Type::new_sum([type_row![NAT], type_row![QB]])]; + let t_row = vec![Type::new_sum([vec![usize_t()], vec![qb_t()]])]; let mut f_build = module_builder .define_function("main", Signature::new(t_row.clone(), t_row).with_prelude()) .unwrap(); @@ -325,13 +325,16 @@ fn weighted_hugr_ser() { #[test] fn dfg_roundtrip() -> Result<(), Box> { - let tp: Vec = vec![BOOL_T; 2]; + let tp: Vec = vec![bool_t(); 2]; let mut dfg = DFGBuilder::new(Signature::new(tp.clone(), tp).with_prelude())?; let mut params: [_; 2] = dfg.input_wires_arr(); for p in params.iter_mut() { - *p = dfg.add_dataflow_op(Noop(BOOL_T), [*p]).unwrap().out_wire(0); + *p = dfg + .add_dataflow_op(Noop(bool_t()), [*p]) + .unwrap() + .out_wire(0); } - let hugr = dfg.finish_hugr_with_outputs(params, &EMPTY_REG)?; + let hugr = dfg.finish_hugr_with_outputs(params, &test_quantum_extension::REG)?; check_hugr_roundtrip(&hugr, true); Ok(()) @@ -339,7 +342,7 @@ fn dfg_roundtrip() -> Result<(), Box> { #[test] fn extension_ops() -> Result<(), Box> { - let tp: Vec = vec![BOOL_T; 1]; + let tp: Vec = vec![bool_t(); 1]; let mut dfg = DFGBuilder::new(endo_sig(tp))?; let [wire] = dfg.input_wires_arr(); @@ -350,7 +353,7 @@ fn extension_ops() -> Result<(), Box> { .unwrap() .out_wire(0); - let hugr = dfg.finish_hugr_with_outputs([wire], &PRELUDE_REGISTRY)?; + let hugr = dfg.finish_hugr_with_outputs([wire], &test_quantum_extension::REG)?; check_hugr_roundtrip(&hugr, true); Ok(()) @@ -358,7 +361,7 @@ fn extension_ops() -> Result<(), Box> { #[test] fn opaque_ops() -> Result<(), Box> { - let tp: Vec = vec![BOOL_T; 1]; + let tp: Vec = vec![bool_t(); 1]; let mut dfg = DFGBuilder::new(endo_sig(tp))?; let [wire] = dfg.input_wires_arr(); @@ -368,6 +371,7 @@ fn opaque_ops() -> Result<(), Box> { .add_dataflow_op(extension_op.clone(), [wire]) .unwrap() .out_wire(0); + let not_node = wire.node(); // Add an unresolved opaque operation let opaque_op: OpaqueOp = extension_op.into(); @@ -376,11 +380,14 @@ fn opaque_ops() -> Result<(), Box> { assert_eq!( dfg.finish_hugr_with_outputs([wire], &PRELUDE_REGISTRY), - Err(ValidationError::OpaqueOpError(OpaqueOpError::UnresolvedOp( - wire.node(), - "Not".into(), - ext_name - )) + Err(ValidationError::ExtensionResolutionError( + ExtensionResolutionError::MissingOpExtension { + node: not_node, + op: "logic.Not".into(), + missing_extension: ext_name, + available_extensions: vec![ExtensionId::new("prelude").unwrap()] + } + ) .into()) ); @@ -389,7 +396,7 @@ fn opaque_ops() -> Result<(), Box> { #[test] fn function_type() -> Result<(), Box> { - let fn_ty = Type::new_function(Signature::new_endo(type_row![BOOL_T]).with_prelude()); + let fn_ty = Type::new_function(Signature::new_endo(vec![bool_t()]).with_prelude()); let mut bldr = DFGBuilder::new(Signature::new_endo(vec![fn_ty.clone()]).with_prelude())?; let op = bldr.add_dataflow_op(Noop(fn_ty), bldr.input_wires())?; let h = bldr.finish_prelude_hugr_with_outputs(op.outputs())?; @@ -400,12 +407,12 @@ fn function_type() -> Result<(), Box> { #[test] fn hierarchy_order() -> Result<(), Box> { - let mut hugr = closed_dfg_root_hugr(Signature::new(vec![QB], vec![QB])); + let mut hugr = closed_dfg_root_hugr(Signature::new(vec![qb_t()], vec![qb_t()])); let [old_in, out] = hugr.get_io(hugr.root()).unwrap(); hugr.connect(old_in, 0, out, 0); // Now add a new input - let new_in = hugr.add_node(Input::new([QB].to_vec()).into()); + let new_in = hugr.add_node(Input::new([qb_t()].to_vec()).into()); hugr.disconnect(old_in, OutgoingPort::from(0)); hugr.connect(new_in, 0, out, 0); hugr.move_before_sibling(new_in, old_in); @@ -438,11 +445,11 @@ fn serialize_types_roundtrip() { check_testing_roundtrip(g.clone()); // A Simple tuple - let t = Type::new_tuple(vec![USIZE_T, g]); + let t = Type::new_tuple(vec![usize_t(), g]); check_testing_roundtrip(t); // A Classic sum - let t = TypeRV::new_sum([type_row![USIZE_T], type_row![FLOAT64_TYPE]]); + let t = TypeRV::new_sum([vec![usize_t()], vec![float64_type()]]); check_testing_roundtrip(t); let t = Type::new_unit_sum(4); @@ -450,21 +457,21 @@ fn serialize_types_roundtrip() { } #[rstest] -#[case(BOOL_T)] -#[case(USIZE_T)] +#[case(bool_t())] +#[case(usize_t())] #[case(INT_TYPES[2].clone())] #[case(Type::new_alias(crate::ops::AliasDecl::new("t", TypeBound::Any)))] #[case(Type::new_var_use(2, TypeBound::Copyable))] -#[case(Type::new_tuple(type_row![BOOL_T,QB_T]))] -#[case(Type::new_sum([type_row![BOOL_T,QB_T], type_row![Type::new_unit_sum(4)]]))] -#[case(Type::new_function(Signature::new_endo(type_row![QB_T,BOOL_T,USIZE_T])))] +#[case(Type::new_tuple(vec![bool_t(),qb_t()]))] +#[case(Type::new_sum([vec![bool_t(),qb_t()], vec![Type::new_unit_sum(4)]]))] +#[case(Type::new_function(Signature::new_endo(vec![qb_t(),bool_t(),usize_t()])))] fn roundtrip_type(#[case] typ: Type) { check_testing_roundtrip(typ); } #[rstest] #[case(SumType::new_unary(2))] -#[case(SumType::new([type_row![USIZE_T, QB_T], type_row![]]))] +#[case(SumType::new([vec![usize_t(), qb_t()].into(), type_row![]]))] fn roundtrip_sumtype(#[case] sum_type: SumType) { check_testing_roundtrip(sum_type); } @@ -506,8 +513,8 @@ fn polyfunctype2() -> PolyFuncTypeRV { #[rstest] #[case(Signature::new_endo(type_row![]).into())] #[case(polyfunctype1())] -#[case(PolyFuncType::new([TypeParam::String], Signature::new_endo(type_row![Type::new_var_use(0, TypeBound::Copyable)])))] -#[case(PolyFuncType::new([TypeBound::Copyable.into()], Signature::new_endo(type_row![Type::new_var_use(0, TypeBound::Copyable)])))] +#[case(PolyFuncType::new([TypeParam::String], Signature::new_endo(vec![Type::new_var_use(0, TypeBound::Copyable)])))] +#[case(PolyFuncType::new([TypeBound::Copyable.into()], Signature::new_endo(vec![Type::new_var_use(0, TypeBound::Copyable)])))] #[case(PolyFuncType::new([TypeParam::new_list(TypeBound::Any)], Signature::new_endo(type_row![])))] #[case(PolyFuncType::new([TypeParam::Tuple { params: [TypeBound::Any.into(), TypeParam::bounded_nat(2.try_into().unwrap())].into() }], Signature::new_endo(type_row![])))] #[case(PolyFuncType::new( @@ -519,8 +526,8 @@ fn roundtrip_polyfunctype_fixedlen(#[case] poly_func_type: PolyFuncType) { #[rstest] #[case(FuncValueType::new_endo(type_row![]).into())] -#[case(PolyFuncTypeRV::new([TypeParam::String], FuncValueType::new_endo(type_row![Type::new_var_use(0, TypeBound::Copyable)])))] -#[case(PolyFuncTypeRV::new([TypeBound::Copyable.into()], FuncValueType::new_endo(type_row![Type::new_var_use(0, TypeBound::Copyable)])))] +#[case(PolyFuncTypeRV::new([TypeParam::String], FuncValueType::new_endo(vec![Type::new_var_use(0, TypeBound::Copyable)])))] +#[case(PolyFuncTypeRV::new([TypeBound::Copyable.into()], FuncValueType::new_endo(vec![Type::new_var_use(0, TypeBound::Copyable)])))] #[case(PolyFuncTypeRV::new([TypeParam::new_list(TypeBound::Any)], FuncValueType::new_endo(type_row![])))] #[case(PolyFuncTypeRV::new([TypeParam::Tuple { params: [TypeBound::Any.into(), TypeParam::bounded_nat(2.try_into().unwrap())].into() }], FuncValueType::new_endo(type_row![])))] #[case(PolyFuncTypeRV::new( @@ -539,10 +546,10 @@ fn roundtrip_polyfunctype_varlen(#[case] poly_func_type: PolyFuncTypeRV) { #[case(ops::AliasDecl { name: "aliasdecl".into(), bound: TypeBound::Any})] #[case(ops::Const::new(Value::false_val()))] #[case(ops::Const::new(Value::function(crate::builder::test::simple_dfg_hugr()).unwrap()))] -#[case(ops::Input::new(type_row![Type::new_var_use(3,TypeBound::Copyable)]))] +#[case(ops::Input::new(vec![Type::new_var_use(3,TypeBound::Copyable)]))] #[case(ops::Output::new(vec![Type::new_function(FuncValueType::new_endo(type_row![]))]))] #[case(ops::Call::try_new(polyfunctype1(), [TypeArg::BoundedNat{n: 1}, TypeArg::Extensions{ es: ExtensionSet::singleton(&PRELUDE_ID)} ], &EMPTY_REG).unwrap())] -#[case(ops::CallIndirect { signature : Signature::new_endo(type_row![BOOL_T]) })] +#[case(ops::CallIndirect { signature : Signature::new_endo(vec![bool_t()]) })] fn roundtrip_optype(#[case] optype: impl Into + std::fmt::Debug) { check_testing_roundtrip(NodeSer { parent: portgraph::NodeIndex::new(0).into(), @@ -554,7 +561,7 @@ fn roundtrip_optype(#[case] optype: impl Into + std::fmt::Debug) { // test all standard extension serialisations are valid against scheme fn std_extensions_valid() { let std_reg = crate::std_extensions::std_reg(); - for (_, ext) in std_reg.into_iter() { + for ext in std_reg { let val = serde_json::to_value(ext).unwrap(); NamedSchema::check_schemas(&val, get_schemas(true)); // check deserialises correctly, can't check equality because of custom binaries. diff --git a/hugr-core/src/hugr/serialize/upgrade/test.rs b/hugr-core/src/hugr/serialize/upgrade/test.rs index 16449afdc..e9837ddc0 100644 --- a/hugr-core/src/hugr/serialize/upgrade/test.rs +++ b/hugr-core/src/hugr/serialize/upgrade/test.rs @@ -1,10 +1,10 @@ use crate::{ builder::{DFGBuilder, Dataflow, DataflowHugr}, - extension::prelude::BOOL_T, + extension::prelude::bool_t, hugr::serialize::test::check_hugr_deserialize, std_extensions::logic::LogicOp, - type_row, types::Signature, + utils::test_quantum_extension, }; use lazy_static::lazy_static; use std::{ @@ -47,11 +47,11 @@ pub fn empty_hugr() -> Hugr { #[once] pub fn hugr_with_named_op() -> Hugr { let mut builder = - DFGBuilder::new(Signature::new(type_row![BOOL_T, BOOL_T], type_row![BOOL_T])).unwrap(); + DFGBuilder::new(Signature::new(vec![bool_t(), bool_t()], vec![bool_t()])).unwrap(); let [a, b] = builder.input_wires_arr(); let x = builder.add_dataflow_op(LogicOp::And, [a, b]).unwrap(); builder - .finish_prelude_hugr_with_outputs(x.outputs()) + .finish_hugr_with_outputs(x.outputs(), &test_quantum_extension::REG) .unwrap() } diff --git a/hugr-core/src/hugr/validate.rs b/hugr-core/src/hugr/validate.rs index 02d7a0fc6..6c9c92753 100644 --- a/hugr-core/src/hugr/validate.rs +++ b/hugr-core/src/hugr/validate.rs @@ -9,12 +9,13 @@ use petgraph::visit::{Topo, Walker}; use portgraph::{LinkView, PortView}; use thiserror::Error; +use crate::extension::resolution::ExtensionResolutionError; use crate::extension::{ExtensionRegistry, SignatureError, TO_BE_INFERRED}; use crate::ops::constant::ConstTypeError; use crate::ops::custom::{ExtensionOp, OpaqueOpError}; use crate::ops::validate::{ChildrenEdgeData, ChildrenValidationError, EdgeValidationError}; -use crate::ops::{FuncDefn, OpParent, OpTag, OpTrait, OpType, ValidateOp}; +use crate::ops::{FuncDefn, NamedOp, OpName, OpParent, OpTag, OpTrait, OpType, ValidateOp}; use crate::types::type_param::TypeParam; use crate::types::{EdgeKind, Signature}; use crate::{Direction, Hugr, Node, Port}; @@ -259,7 +260,11 @@ impl<'a, 'b> ValidationContext<'a, 'b> { } self.validate_port_kind(&port_kind, var_decls) - .map_err(|cause| ValidationError::SignatureError { node, cause })?; + .map_err(|cause| ValidationError::SignatureError { + node, + op: op_type.name(), + cause, + })?; let mut link_cnt = 0; for (_, link) in links { @@ -579,7 +584,11 @@ impl<'a, 'b> ValidationContext<'a, 'b> { ext_op .def() .validate_args(ext_op.args(), self.extension_registry, var_decls) - .map_err(|cause| ValidationError::SignatureError { node, cause }) + .map_err(|cause| ValidationError::SignatureError { + node, + op: op_type.name(), + cause, + }) }; match op_type { OpType::ExtensionOp(ext_op) => validate_ext(ext_op)?, @@ -591,12 +600,22 @@ impl<'a, 'b> ValidationContext<'a, 'b> { ))?; } OpType::Call(c) => { - c.validate(self.extension_registry) - .map_err(|cause| ValidationError::SignatureError { node, cause })?; + c.validate(self.extension_registry).map_err(|cause| { + ValidationError::SignatureError { + node, + op: op_type.name(), + cause, + } + })?; } OpType::LoadFunction(c) => { - c.validate(self.extension_registry) - .map_err(|cause| ValidationError::SignatureError { node, cause })?; + c.validate(self.extension_registry).map_err(|cause| { + ValidationError::SignatureError { + node, + op: op_type.name(), + cause, + } + })?; } _ => (), } @@ -738,9 +757,10 @@ pub enum ValidationError { #[error("Node {node} needs a concrete ExtensionSet - inference will provide this for Case/CFG/Conditional/DataflowBlock/DFG/TailLoop only")] ExtensionsNotInferred { node: Node }, /// Error in a node signature - #[error("Error in signature of node {node}: {cause}")] + #[error("Error in signature of operation {op} at {node}: {cause}")] SignatureError { node: Node, + op: OpName, #[source] cause: SignatureError, }, @@ -757,6 +777,11 @@ pub enum ValidationError { /// [Type]: crate::types::Type #[error(transparent)] ConstTypeError(#[from] ConstTypeError), + /// Some operations or types in the HUGR reference invalid extensions. + // + // TODO: Remove once `hugr::update_validate` is removed. + #[error(transparent)] + ExtensionResolutionError(#[from] ExtensionResolutionError), } /// Errors related to the inter-graph edge validations. diff --git a/hugr-core/src/hugr/validate/test.rs b/hugr-core/src/hugr/validate/test.rs index cf934e18b..e98744a82 100644 --- a/hugr-core/src/hugr/validate/test.rs +++ b/hugr-core/src/hugr/validate/test.rs @@ -1,5 +1,6 @@ use std::fs::File; use std::io::BufReader; +use std::sync::Arc; use cool_asserts::assert_matches; @@ -10,7 +11,7 @@ use crate::builder::{ FunctionBuilder, HugrBuilder, ModuleBuilder, SubContainer, }; use crate::extension::prelude::Noop; -use crate::extension::prelude::{BOOL_T, PRELUDE, PRELUDE_ID, QB_T, USIZE_T}; +use crate::extension::prelude::{bool_t, qb_t, usize_t, PRELUDE, PRELUDE_ID}; use crate::extension::{Extension, ExtensionSet, TypeDefBound, EMPTY_REG, PRELUDE_REGISTRY}; use crate::hugr::internal::HugrMutInternals; use crate::hugr::HugrMut; @@ -25,19 +26,18 @@ use crate::types::{ CustomType, FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature, Type, TypeBound, TypeRV, TypeRow, }; +use crate::utils::test_quantum_extension; use crate::{ const_extension_ids, test_file, type_row, Direction, IncomingPort, Node, OutgoingPort, }; -const NAT: Type = crate::extension::prelude::USIZE_T; - /// Creates a hugr with a single function definition that copies a bit `copies` times. /// /// Returns the hugr and the node index of the definition. fn make_simple_hugr(copies: usize) -> (Hugr, Node) { let def_op: OpType = ops::FuncDefn { name: "main".into(), - signature: Signature::new(type_row![BOOL_T], vec![BOOL_T; copies]) + signature: Signature::new(vec![bool_t()], vec![bool_t(); copies]) .with_prelude() .into(), } @@ -52,13 +52,13 @@ fn make_simple_hugr(copies: usize) -> (Hugr, Node) { (b, def) } -/// Adds an input{BOOL_T}, copy{BOOL_T -> BOOL_T^copies}, and output{BOOL_T^copies} operation to a dataflow container. +/// Adds an input{bool_t()}, copy{bool_t() -> bool_t()^copies}, and output{bool_t()^copies} operation to a dataflow container. /// /// Returns the node indices of each of the operations. fn add_df_children(b: &mut Hugr, parent: Node, copies: usize) -> (Node, Node, Node) { - let input = b.add_node_with_parent(parent, ops::Input::new(type_row![BOOL_T])); - let output = b.add_node_with_parent(parent, ops::Output::new(vec![BOOL_T; copies])); - let copy = b.add_node_with_parent(parent, Noop(BOOL_T)); + let input = b.add_node_with_parent(parent, ops::Input::new(vec![bool_t()])); + let output = b.add_node_with_parent(parent, ops::Output::new(vec![bool_t(); copies])); + let copy = b.add_node_with_parent(parent, Noop(bool_t())); b.connect(input, 0, copy, 0); for i in 0..copies { @@ -112,7 +112,7 @@ fn invalid_root() { #[test] fn leaf_root() { - let leaf_op: OpType = Noop(USIZE_T).into(); + let leaf_op: OpType = Noop(usize_t()).into(); let b = Hugr::new(leaf_op); assert_eq!(b.validate(&PRELUDE_REGISTRY), Ok(())); @@ -121,20 +121,20 @@ fn leaf_root() { #[test] fn dfg_root() { let dfg_op: OpType = ops::DFG { - signature: Signature::new_endo(type_row![BOOL_T]).with_prelude(), + signature: Signature::new_endo(vec![bool_t()]).with_prelude(), } .into(); let mut b = Hugr::new(dfg_op); let root = b.root(); add_df_children(&mut b, root, 1); - assert_eq!(b.update_validate(&EMPTY_REG), Ok(())); + assert_eq!(b.update_validate(&test_quantum_extension::REG), Ok(())); } #[test] fn simple_hugr() { let mut b = make_simple_hugr(2).0; - assert_eq!(b.update_validate(&EMPTY_REG), Ok(())); + assert_eq!(b.update_validate(&test_quantum_extension::REG), Ok(())); } #[test] @@ -150,7 +150,7 @@ fn children_restrictions() { .unwrap(); // Add a definition without children - let def_sig = Signature::new(type_row![BOOL_T], type_row![BOOL_T, BOOL_T]); + let def_sig = Signature::new(vec![bool_t()], vec![bool_t(), bool_t()]); let new_def = b.add_node_with_parent( root, ops::FuncDefn { @@ -159,7 +159,7 @@ fn children_restrictions() { }, ); assert_matches!( - b.update_validate(&EMPTY_REG), + b.update_validate(&test_quantum_extension::REG), Err(ValidationError::ContainerWithoutChildren { node, .. }) => assert_eq!(node, new_def) ); @@ -167,7 +167,7 @@ fn children_restrictions() { add_df_children(&mut b, new_def, 2); b.set_parent(new_def, copy); assert_matches!( - b.update_validate(&EMPTY_REG), + b.update_validate(&test_quantum_extension::REG), Err(ValidationError::NonContainerWithChildren { node, .. }) => assert_eq!(node, copy) ); b.set_parent(new_def, root); @@ -176,7 +176,7 @@ fn children_restrictions() { // add an input node to the module subgraph let new_input = b.add_node_with_parent(root, ops::Input::new(type_row![])); assert_matches!( - b.validate(&EMPTY_REG), + b.validate(&test_quantum_extension::REG), Err(ValidationError::InvalidParentOp { parent, child, .. }) => {assert_eq!(parent, root); assert_eq!(child, new_input)} ); } @@ -193,28 +193,28 @@ fn df_children_restrictions() { .unwrap(); // Replace the output operation of the df subgraph with a copy - b.replace_op(output, Noop(NAT)).unwrap(); + b.replace_op(output, Noop(usize_t())).unwrap(); assert_matches!( - b.validate(&EMPTY_REG), + b.validate(&test_quantum_extension::REG), Err(ValidationError::InvalidInitialChild { parent, .. }) => assert_eq!(parent, def) ); // Revert it back to an output, but with the wrong number of ports - b.replace_op(output, ops::Output::new(type_row![BOOL_T])) + b.replace_op(output, ops::Output::new(vec![bool_t()])) .unwrap(); assert_matches!( - b.validate(&EMPTY_REG), + b.validate(&test_quantum_extension::REG), Err(ValidationError::InvalidChildren { parent, source: ChildrenValidationError::IOSignatureMismatch { child, .. }, .. }) => {assert_eq!(parent, def); assert_eq!(child, output.pg_index())} ); - b.replace_op(output, ops::Output::new(type_row![BOOL_T, BOOL_T])) + b.replace_op(output, ops::Output::new(vec![bool_t(), bool_t()])) .unwrap(); // After fixing the output back, replace the copy with an output op - b.replace_op(copy, ops::Output::new(type_row![BOOL_T, BOOL_T])) + b.replace_op(copy, ops::Output::new(vec![bool_t(), bool_t()])) .unwrap(); assert_matches!( - b.validate(&EMPTY_REG), + b.validate(&test_quantum_extension::REG), Err(ValidationError::InvalidChildren { parent, source: ChildrenValidationError::InternalIOChildren { child, .. }, .. }) => {assert_eq!(parent, def); assert_eq!(child, copy.pg_index())} ); @@ -223,22 +223,22 @@ fn df_children_restrictions() { #[test] fn test_ext_edge() { let mut h = closed_dfg_root_hugr( - Signature::new(type_row![BOOL_T, BOOL_T], type_row![BOOL_T]) + Signature::new(vec![bool_t(), bool_t()], vec![bool_t()]) .with_extension_delta(TO_BE_INFERRED), ); let [input, output] = h.get_io(h.root()).unwrap(); - // Nested DFG BOOL_T -> BOOL_T + // Nested DFG bool_t() -> bool_t() let sub_dfg = h.add_node_with_parent( h.root(), ops::DFG { - signature: Signature::new_endo(type_row![BOOL_T]).with_extension_delta(TO_BE_INFERRED), + signature: Signature::new_endo(vec![bool_t()]).with_extension_delta(TO_BE_INFERRED), }, ); // this Xor has its 2nd input unconnected let sub_op = { - let sub_input = h.add_node_with_parent(sub_dfg, ops::Input::new(type_row![BOOL_T])); - let sub_output = h.add_node_with_parent(sub_dfg, ops::Output::new(type_row![BOOL_T])); + let sub_input = h.add_node_with_parent(sub_dfg, ops::Input::new(vec![bool_t()])); + let sub_output = h.add_node_with_parent(sub_dfg, ops::Output::new(vec![bool_t()])); let sub_op = h.add_node_with_parent(sub_dfg, and_op()); h.connect(sub_input, 0, sub_op, 0); h.connect(sub_op, 0, sub_output, 0); @@ -249,26 +249,26 @@ fn test_ext_edge() { h.connect(sub_dfg, 0, output, 0); assert_matches!( - h.update_validate(&EMPTY_REG), + h.update_validate(&test_quantum_extension::REG), Err(ValidationError::UnconnectedPort { .. }) ); h.connect(input, 1, sub_op, 1); assert_matches!( - h.update_validate(&EMPTY_REG), + h.update_validate(&test_quantum_extension::REG), Err(ValidationError::InterGraphEdgeError( InterGraphEdgeError::MissingOrderEdge { .. } )) ); //Order edge. This will need metadata indicating its purpose. h.add_other_edge(input, sub_dfg); - h.update_validate(&EMPTY_REG).unwrap(); + h.update_validate(&test_quantum_extension::REG).unwrap(); } #[test] fn no_ext_edge_into_func() -> Result<(), Box> { - let b2b = Signature::new_endo(BOOL_T); - let mut h = DFGBuilder::new(Signature::new(BOOL_T, Type::new_function(b2b.clone())))?; + let b2b = Signature::new_endo(bool_t()); + let mut h = DFGBuilder::new(Signature::new(bool_t(), Type::new_function(b2b.clone())))?; let [input] = h.input_wires_arr(); let mut dfg = h.dfg_builder(Signature::new(vec![], Type::new_function(b2b.clone())), [])?; @@ -278,7 +278,7 @@ fn no_ext_edge_into_func() -> Result<(), Box> { let func = func.finish_with_outputs(and_op.outputs())?; let loadfn = dfg.load_func(func.handle(), &[], &EMPTY_REG)?; let dfg = dfg.finish_with_outputs([loadfn])?; - let res = h.finish_hugr_with_outputs(dfg.outputs(), &EMPTY_REG); + let res = h.finish_hugr_with_outputs(dfg.outputs(), &test_quantum_extension::REG); assert_eq!( res, Err(BuildError::InvalidHUGR( @@ -297,17 +297,17 @@ fn no_ext_edge_into_func() -> Result<(), Box> { #[test] fn test_local_const() { let mut h = - closed_dfg_root_hugr(Signature::new_endo(BOOL_T).with_extension_delta(TO_BE_INFERRED)); + closed_dfg_root_hugr(Signature::new_endo(bool_t()).with_extension_delta(TO_BE_INFERRED)); let [input, output] = h.get_io(h.root()).unwrap(); let and = h.add_node_with_parent(h.root(), and_op()); h.connect(input, 0, and, 0); h.connect(and, 0, output, 0); assert_eq!( - h.update_validate(&EMPTY_REG), + h.update_validate(&test_quantum_extension::REG), Err(ValidationError::UnconnectedPort { node: and, port: IncomingPort::from(1).into(), - port_kind: EdgeKind::Value(BOOL_T) + port_kind: EdgeKind::Value(bool_t()) }) ); let const_op: ops::Const = logic::EXTENSION @@ -318,18 +318,18 @@ fn test_local_const() { .into(); // Second input of Xor from a constant let cst = h.add_node_with_parent(h.root(), const_op); - let lcst = h.add_node_with_parent(h.root(), ops::LoadConstant { datatype: BOOL_T }); + let lcst = h.add_node_with_parent(h.root(), ops::LoadConstant { datatype: bool_t() }); h.connect(cst, 0, lcst, 0); h.connect(lcst, 0, and, 1); assert_eq!(h.static_source(lcst), Some(cst)); // There is no edge from Input to LoadConstant, but that's OK: - h.update_validate(&EMPTY_REG).unwrap(); + h.update_validate(&test_quantum_extension::REG).unwrap(); } #[test] fn dfg_with_cycles() { - let mut h = closed_dfg_root_hugr(Signature::new(type_row![BOOL_T, BOOL_T], type_row![BOOL_T])); + let mut h = closed_dfg_root_hugr(Signature::new(vec![bool_t(), bool_t()], vec![bool_t()])); let [input, output] = h.get_io(h.root()).unwrap(); let or = h.add_node_with_parent(h.root(), or_op()); let not1 = h.add_node_with_parent(h.root(), LogicOp::Not); @@ -340,7 +340,10 @@ fn dfg_with_cycles() { h.connect(input, 1, not2, 0); h.connect(not2, 0, output, 0); // The graph contains a cycle: - assert_matches!(h.validate(&EMPTY_REG), Err(ValidationError::NotADag { .. })); + assert_matches!( + h.validate(&test_quantum_extension::REG), + Err(ValidationError::NotADag { .. }) + ); } fn identity_hugr_with_type(t: Type) -> (Hugr, Node) { @@ -362,15 +365,12 @@ fn identity_hugr_with_type(t: Type) -> (Hugr, Node) { } #[test] fn unregistered_extension() { - let (mut h, def) = identity_hugr_with_type(USIZE_T); - assert_eq!( + let (mut h, _def) = identity_hugr_with_type(usize_t()); + assert_matches!( h.validate(&EMPTY_REG), - Err(ValidationError::SignatureError { - node: def, - cause: SignatureError::ExtensionNotFound(PRELUDE.name.clone()) - }) + Err(ValidationError::SignatureError { .. }) ); - h.update_validate(&PRELUDE_REGISTRY).unwrap(); + h.update_validate(&test_quantum_extension::REG).unwrap(); } const_extension_ids! { @@ -378,20 +378,22 @@ const_extension_ids! { } #[test] fn invalid_types() { - let mut e = Extension::new_test(EXT_ID); - e.add_type( - "MyContainer".into(), - vec![TypeBound::Copyable.into()], - "".into(), - TypeDefBound::any(), - ) - .unwrap(); - let reg = ExtensionRegistry::try_new([e.into(), PRELUDE.clone()]).unwrap(); + let ext = Extension::new_test_arc(EXT_ID, |ext, extension_ref| { + ext.add_type( + "MyContainer".into(), + vec![TypeBound::Copyable.into()], + "".into(), + TypeDefBound::any(), + extension_ref, + ) + .unwrap(); + }); + let reg = ExtensionRegistry::try_new([ext.clone(), PRELUDE.clone()]).unwrap(); let validate_to_sig_error = |t: CustomType| { let (h, def) = identity_hugr_with_type(Type::new_extension(t)); match h.validate(®) { - Err(ValidationError::SignatureError { node, cause }) if node == def => cause, + Err(ValidationError::SignatureError { node, cause, .. }) if node == def => cause, e => panic!( "Expected SignatureError at def node, got {}", match e { @@ -404,9 +406,10 @@ fn invalid_types() { let valid = Type::new_extension(CustomType::new( "MyContainer", - vec![TypeArg::Type { ty: USIZE_T }], + vec![TypeArg::Type { ty: usize_t() }], EXT_ID, TypeBound::Any, + &Arc::downgrade(&ext), )); assert_eq!( identity_hugr_with_type(valid.clone()) @@ -421,6 +424,7 @@ fn invalid_types() { vec![TypeArg::Type { ty: valid.clone() }], EXT_ID, TypeBound::Any, + &Arc::downgrade(&ext), ); assert_eq!( validate_to_sig_error(element_outside_bound), @@ -432,9 +436,10 @@ fn invalid_types() { let bad_bound = CustomType::new( "MyContainer", - vec![TypeArg::Type { ty: USIZE_T }], + vec![TypeArg::Type { ty: usize_t() }], EXT_ID, TypeBound::Copyable, + &Arc::downgrade(&ext), ); assert_eq!( validate_to_sig_error(bad_bound.clone()), @@ -452,6 +457,7 @@ fn invalid_types() { }], EXT_ID, TypeBound::Any, + &Arc::downgrade(&ext), ); assert_eq!( validate_to_sig_error(nested), @@ -463,9 +469,13 @@ fn invalid_types() { let too_many_type_args = CustomType::new( "MyContainer", - vec![TypeArg::Type { ty: USIZE_T }, TypeArg::BoundedNat { n: 3 }], + vec![ + TypeArg::Type { ty: usize_t() }, + TypeArg::BoundedNat { n: 3 }, + ], EXT_ID, TypeBound::Any, + &Arc::downgrade(&ext), ); assert_eq!( validate_to_sig_error(too_many_type_args), @@ -587,47 +597,47 @@ fn no_polymorphic_consts() -> Result<(), Box> { Ok(()) } -pub(crate) fn extension_with_eval_parallel() -> Extension { +pub(crate) fn extension_with_eval_parallel() -> Arc { let rowp = TypeParam::new_list(TypeBound::Any); - let mut e = Extension::new_test(EXT_ID); - - let inputs = TypeRV::new_row_var_use(0, TypeBound::Any); - let outputs = TypeRV::new_row_var_use(1, TypeBound::Any); - let evaled_fn = TypeRV::new_function(FuncValueType::new(inputs.clone(), outputs.clone())); - let pf = PolyFuncTypeRV::new( - [rowp.clone(), rowp.clone()], - FuncValueType::new(vec![evaled_fn, inputs], outputs), - ); - e.add_op("eval".into(), "".into(), pf).unwrap(); - - let rv = |idx| TypeRV::new_row_var_use(idx, TypeBound::Any); - let pf = PolyFuncTypeRV::new( - [rowp.clone(), rowp.clone(), rowp.clone(), rowp.clone()], - Signature::new( - vec![ - Type::new_function(FuncValueType::new(rv(0), rv(2))), - Type::new_function(FuncValueType::new(rv(1), rv(3))), - ], - Type::new_function(FuncValueType::new(vec![rv(0), rv(1)], vec![rv(2), rv(3)])), - ), - ); - e.add_op("parallel".into(), "".into(), pf).unwrap(); + Extension::new_test_arc(EXT_ID, |ext, extension_ref| { + let inputs = TypeRV::new_row_var_use(0, TypeBound::Any); + let outputs = TypeRV::new_row_var_use(1, TypeBound::Any); + let evaled_fn = TypeRV::new_function(FuncValueType::new(inputs.clone(), outputs.clone())); + let pf = PolyFuncTypeRV::new( + [rowp.clone(), rowp.clone()], + FuncValueType::new(vec![evaled_fn, inputs], outputs), + ); + ext.add_op("eval".into(), "".into(), pf, extension_ref) + .unwrap(); - e + let rv = |idx| TypeRV::new_row_var_use(idx, TypeBound::Any); + let pf = PolyFuncTypeRV::new( + [rowp.clone(), rowp.clone(), rowp.clone(), rowp.clone()], + Signature::new( + vec![ + Type::new_function(FuncValueType::new(rv(0), rv(2))), + Type::new_function(FuncValueType::new(rv(1), rv(3))), + ], + Type::new_function(FuncValueType::new(vec![rv(0), rv(1)], vec![rv(2), rv(3)])), + ), + ); + ext.add_op("parallel".into(), "".into(), pf, extension_ref) + .unwrap(); + }) } #[test] fn instantiate_row_variables() -> Result<(), Box> { fn uint_seq(i: usize) -> TypeArg { - vec![TypeArg::Type { ty: USIZE_T }; i].into() + vec![TypeArg::Type { ty: usize_t() }; i].into() } let e = extension_with_eval_parallel(); let mut dfb = DFGBuilder::new(inout_sig( vec![ - Type::new_function(Signature::new(USIZE_T, vec![USIZE_T, USIZE_T])), - USIZE_T, + Type::new_function(Signature::new(usize_t(), vec![usize_t(), usize_t()])), + usize_t(), ], // inputs: function + its argument - vec![USIZE_T; 4], // outputs (*2^2, three calls) + vec![usize_t(); 4], // outputs (*2^2, three calls) ))?; let [func, int] = dfb.input_wires_arr(); let eval = e.instantiate_extension_op("eval", [uint_seq(1), uint_seq(2)], &PRELUDE_REGISTRY)?; @@ -643,7 +653,7 @@ fn instantiate_row_variables() -> Result<(), Box> { let eval2 = dfb.add_dataflow_op(eval2, [par_func, a, b])?; dfb.finish_hugr_with_outputs( eval2.outputs(), - &ExtensionRegistry::try_new([PRELUDE.clone(), e.into()]).unwrap(), + &ExtensionRegistry::try_new([PRELUDE.clone(), e]).unwrap(), )?; Ok(()) } @@ -659,7 +669,7 @@ fn row_variables() -> Result<(), Box> { let e = extension_with_eval_parallel(); let tv = TypeRV::new_row_var_use(0, TypeBound::Any); let inner_ft = Type::new_function(FuncValueType::new_endo(tv.clone())); - let ft_usz = Type::new_function(FuncValueType::new_endo(vec![tv.clone(), USIZE_T.into()])); + let ft_usz = Type::new_function(FuncValueType::new_endo(vec![tv.clone(), usize_t().into()])); let mut fb = FunctionBuilder::new( "id", PolyFuncType::new( @@ -670,60 +680,63 @@ fn row_variables() -> Result<(), Box> { // All the wires here are carrying higher-order Function values let [func_arg] = fb.input_wires_arr(); let id_usz = { - let bldr = fb.define_function("id_usz", Signature::new_endo(USIZE_T))?; + let bldr = fb.define_function("id_usz", Signature::new_endo(usize_t()))?; let vals = bldr.input_wires(); let inner_def = bldr.finish_with_outputs(vals)?; fb.load_func(inner_def.handle(), &[], &PRELUDE_REGISTRY)? }; let par = e.instantiate_extension_op( "parallel", - [tv.clone(), USIZE_T.into(), tv.clone(), USIZE_T.into()].map(seq1ty), + [tv.clone(), usize_t().into(), tv.clone(), usize_t().into()].map(seq1ty), &PRELUDE_REGISTRY, )?; let par_func = fb.add_dataflow_op(par, [func_arg, id_usz])?; fb.finish_hugr_with_outputs( par_func.outputs(), - &ExtensionRegistry::try_new([PRELUDE.clone(), e.into()]).unwrap(), + &ExtensionRegistry::try_new([PRELUDE.clone(), e]).unwrap(), )?; Ok(()) } #[test] fn test_polymorphic_call() -> Result<(), Box> { - let mut e = Extension::new_test(EXT_ID); - - let params: Vec = vec![ - TypeBound::Any.into(), - TypeParam::Extensions, - TypeBound::Any.into(), - ]; - let evaled_fn = Type::new_function( - Signature::new( - Type::new_var_use(0, TypeBound::Any), - Type::new_var_use(2, TypeBound::Any), - ) - .with_extension_delta(ExtensionSet::type_var(1)), - ); - // Single-input/output version of the higher-order "eval" operation, with extension param. - // Note the extension-delta of the eval node includes that of the input function. - e.add_op( - "eval".into(), - "".into(), - PolyFuncTypeRV::new( - params.clone(), + let e = Extension::try_new_test_arc(EXT_ID, |ext, extension_ref| { + let params: Vec = vec![ + TypeBound::Any.into(), + TypeParam::Extensions, + TypeBound::Any.into(), + ]; + let evaled_fn = Type::new_function( Signature::new( - vec![evaled_fn, Type::new_var_use(0, TypeBound::Any)], + Type::new_var_use(0, TypeBound::Any), Type::new_var_use(2, TypeBound::Any), ) .with_extension_delta(ExtensionSet::type_var(1)), - ), - )?; + ); + // Single-input/output version of the higher-order "eval" operation, with extension param. + // Note the extension-delta of the eval node includes that of the input function. + ext.add_op( + "eval".into(), + "".into(), + PolyFuncTypeRV::new( + params.clone(), + Signature::new( + vec![evaled_fn, Type::new_var_use(0, TypeBound::Any)], + Type::new_var_use(2, TypeBound::Any), + ) + .with_extension_delta(ExtensionSet::type_var(1)), + ), + extension_ref, + )?; + + Ok(()) + })?; fn utou(e: impl Into) -> Type { - Type::new_function(Signature::new_endo(USIZE_T).with_extension_delta(e.into())) + Type::new_function(Signature::new_endo(usize_t()).with_extension_delta(e.into())) } - let int_pair = Type::new_tuple(type_row![USIZE_T; 2]); + let int_pair = Type::new_tuple(vec![usize_t(); 2]); // Root DFG: applies a function int--PRELUDE-->int to each element of a pair of two ints let mut d = DFGBuilder::new(inout_sig( vec![utou(PRELUDE_ID), int_pair.clone()], @@ -744,15 +757,19 @@ fn test_polymorphic_call() -> Result<(), Box> { )?; let [func, tup] = f.input_wires_arr(); let mut c = f.conditional_builder( - (vec![type_row![USIZE_T; 2]], tup), + (vec![vec![usize_t(); 2].into()], tup), vec![], - type_row![USIZE_T;2], + vec![usize_t(); 2].into(), )?; let mut cc = c.case_builder(0)?; let [i1, i2] = cc.input_wires_arr(); let op = e.instantiate_extension_op( "eval", - vec![USIZE_T.into(), TypeArg::Extensions { es }, USIZE_T.into()], + vec![ + usize_t().into(), + TypeArg::Extensions { es }, + usize_t().into(), + ], &PRELUDE_REGISTRY, )?; let [f1] = cc.add_dataflow_op(op.clone(), [func, i1])?.outputs_arr(); @@ -763,7 +780,7 @@ fn test_polymorphic_call() -> Result<(), Box> { f.finish_with_outputs([tup])? }; - let reg = ExtensionRegistry::try_new([e.into(), PRELUDE.clone()])?; + let reg = ExtensionRegistry::try_new([e, PRELUDE.clone()])?; let [func, tup] = d.input_wires_arr(); let call = d.call( f.handle(), @@ -794,10 +811,10 @@ fn test_polymorphic_load() -> Result<(), Box> { )?; let sig = Signature::new( vec![], - vec![Type::new_function(Signature::new_endo(vec![USIZE_T]))], + vec![Type::new_function(Signature::new_endo(vec![usize_t()]))], ); let mut f = m.define_function("main", sig)?; - let l = f.load_func(&id, &[USIZE_T.into()], &PRELUDE_REGISTRY)?; + let l = f.load_func(&id, &[usize_t().into()], &PRELUDE_REGISTRY)?; f.finish_with_outputs([l])?; let _ = m.finish_prelude_hugr()?; Ok(()) @@ -815,16 +832,16 @@ fn cfg_children_restrictions() { .unwrap(); // Write Extension annotations into the Hugr while it's still well-formed // enough for us to compute them - b.validate(&EMPTY_REG).unwrap(); + b.validate(&test_quantum_extension::REG).unwrap(); b.replace_op( copy, ops::CFG { - signature: Signature::new(type_row![BOOL_T], type_row![BOOL_T]), + signature: Signature::new(vec![bool_t()], vec![bool_t()]), }, ) .unwrap(); assert_matches!( - b.validate(&EMPTY_REG), + b.validate(&test_quantum_extension::REG), Err(ValidationError::ContainerWithoutChildren { .. }) ); let cfg = copy; @@ -833,18 +850,18 @@ fn cfg_children_restrictions() { let block = b.add_node_with_parent( cfg, ops::DataflowBlock { - inputs: type_row![BOOL_T], + inputs: vec![bool_t()].into(), sum_rows: vec![type_row![]], - other_outputs: type_row![BOOL_T], + other_outputs: vec![bool_t()].into(), extension_delta: ExtensionSet::new(), }, ); let const_op: ops::Const = ops::Value::unit_sum(0, 1).unwrap().into(); let tag_type = Type::new_unit_sum(1); { - let input = b.add_node_with_parent(block, ops::Input::new(type_row![BOOL_T])); + let input = b.add_node_with_parent(block, ops::Input::new(vec![bool_t()])); let output = - b.add_node_with_parent(block, ops::Output::new(vec![tag_type.clone(), BOOL_T])); + b.add_node_with_parent(block, ops::Output::new(vec![tag_type.clone(), bool_t()])); let tag_def = b.add_node_with_parent(b.root(), const_op); let tag = b.add_node_with_parent(block, ops::LoadConstant { datatype: tag_type }); @@ -856,11 +873,11 @@ fn cfg_children_restrictions() { let exit = b.add_node_with_parent( cfg, ops::ExitBlock { - cfg_outputs: type_row![BOOL_T], + cfg_outputs: vec![bool_t()].into(), }, ); b.add_other_edge(block, exit); - assert_eq!(b.update_validate(&EMPTY_REG), Ok(())); + assert_eq!(b.update_validate(&test_quantum_extension::REG), Ok(())); // Test malformed errors @@ -868,11 +885,11 @@ fn cfg_children_restrictions() { let exit2 = b.add_node_after( exit, ops::ExitBlock { - cfg_outputs: type_row![BOOL_T], + cfg_outputs: vec![bool_t()].into(), }, ); assert_matches!( - b.validate(&EMPTY_REG), + b.validate(&test_quantum_extension::REG), Err(ValidationError::InvalidChildren { parent, source: ChildrenValidationError::InternalExitChildren { child, .. }, .. }) => {assert_eq!(parent, cfg); assert_eq!(child, exit2.pg_index())} ); @@ -882,16 +899,16 @@ fn cfg_children_restrictions() { b.replace_op( cfg, ops::CFG { - signature: Signature::new(type_row![QB_T], type_row![BOOL_T]), + signature: Signature::new(vec![qb_t()], vec![bool_t()]), }, ) .unwrap(); b.replace_op( block, ops::DataflowBlock { - inputs: type_row![QB_T], + inputs: vec![qb_t()].into(), sum_rows: vec![type_row![]], - other_outputs: type_row![QB_T], + other_outputs: vec![qb_t()].into(), extension_delta: ExtensionSet::new(), }, ) @@ -899,15 +916,15 @@ fn cfg_children_restrictions() { let mut block_children = b.hierarchy.children(block.pg_index()); let block_input = block_children.next().unwrap().into(); let block_output = block_children.next_back().unwrap().into(); - b.replace_op(block_input, ops::Input::new(type_row![QB_T])) + b.replace_op(block_input, ops::Input::new(vec![qb_t()])) .unwrap(); b.replace_op( block_output, - ops::Output::new(type_row![Type::new_unit_sum(1), QB_T]), + ops::Output::new(vec![Type::new_unit_sum(1), qb_t()]), ) .unwrap(); assert_matches!( - b.validate(&EMPTY_REG), + b.validate(&test_quantum_extension::REG), Err(ValidationError::InvalidEdges { parent, source: EdgeValidationError::CFGEdgeSignatureMismatch { .. }, .. }) => assert_eq!(parent, cfg) ); @@ -920,14 +937,15 @@ fn cfg_children_restrictions() { fn cfg_connections() -> Result<(), Box> { use crate::builder::CFGBuilder; - let mut hugr = CFGBuilder::new(Signature::new_endo(USIZE_T))?; + let mut hugr = CFGBuilder::new(Signature::new_endo(usize_t()))?; let unary_pred = hugr.add_constant(Value::unary_unit_sum()); - let mut entry = hugr.simple_entry_builder_exts(type_row![USIZE_T], 1, ExtensionSet::new())?; + let mut entry = + hugr.simple_entry_builder_exts(vec![usize_t()].into(), 1, ExtensionSet::new())?; let p = entry.load_const(&unary_pred); let ins = entry.input_wires(); let entry = entry.finish_with_outputs(p, ins)?; - let mut middle = hugr.simple_block_builder(Signature::new_endo(USIZE_T), 1)?; + let mut middle = hugr.simple_block_builder(Signature::new_endo(usize_t()), 1)?; let p = middle.load_const(&unary_pred); let ins = middle.input_wires(); let middle = middle.finish_with_outputs(p, ins)?; @@ -995,24 +1013,25 @@ mod extension_tests { ) { // Child graph adds extension "XB", but the parent (in all cases) // declares a different delta, causing a mismatch. - let parent = - parent_f(Signature::new_endo(USIZE_T).with_extension_delta(parent_extensions.clone())); + let parent = parent_f( + Signature::new_endo(usize_t()).with_extension_delta(parent_extensions.clone()), + ); let mut hugr = Hugr::new(parent); let input = hugr.add_node_with_parent( hugr.root(), ops::Input { - types: type_row![USIZE_T], + types: vec![usize_t()].into(), }, ); let output = hugr.add_node_with_parent( hugr.root(), ops::Output { - types: type_row![USIZE_T], + types: vec![usize_t()].into(), }, ); - let lift = hugr.add_node_with_parent(hugr.root(), Lift::new(type_row![USIZE_T], XB)); + let lift = hugr.add_node_with_parent(hugr.root(), Lift::new(vec![usize_t()].into(), XB)); hugr.connect(input, 0, lift, 0); hugr.connect(lift, 0, output, 0); @@ -1038,9 +1057,9 @@ mod extension_tests { #[case] success: bool, ) -> Result<(), BuildError> { let mut cfg = CFGBuilder::new( - Signature::new_endo(USIZE_T).with_extension_delta(parent_extensions.clone()), + Signature::new_endo(usize_t()).with_extension_delta(parent_extensions.clone()), )?; - let mut bb = cfg.simple_entry_builder_exts(USIZE_T.into(), 1, XB)?; + let mut bb = cfg.simple_entry_builder_exts(usize_t().into(), 1, XB)?; let pred = bb.add_load_value(Value::unary_unit_sum()); let inputs = bb.input_wires(); let blk = bb.finish_with_outputs(pred, inputs)?; @@ -1076,8 +1095,8 @@ mod extension_tests { // declares a different delta, in same cases causing a mismatch. let parent = ops::Conditional { sum_rows: vec![type_row![], type_row![]], - other_inputs: type_row![USIZE_T], - outputs: type_row![USIZE_T], + other_inputs: vec![usize_t()].into(), + outputs: vec![usize_t()].into(), extension_delta: parent_extensions.clone(), }; let mut hugr = Hugr::new(parent); @@ -1092,27 +1111,27 @@ mod extension_tests { let case = hugr.add_node_with_parent( hugr.root(), ops::Case { - signature: Signature::new_endo(USIZE_T).with_extension_delta(case_exts), + signature: Signature::new_endo(usize_t()).with_extension_delta(case_exts), }, ); let input = hugr.add_node_with_parent( case, ops::Input { - types: type_row![USIZE_T], + types: vec![usize_t()].into(), }, ); let output = hugr.add_node_with_parent( case, ops::Output { - types: type_row![USIZE_T], + types: vec![usize_t()].into(), }, ); let res = match case_ext { None => input, Some(new_ext) => { let lift = - hugr.add_node_with_parent(case, Lift::new(type_row![USIZE_T], new_ext)); + hugr.add_node_with_parent(case, Lift::new(vec![usize_t()].into(), new_ext)); hugr.connect(input, 0, lift, 0); lift } @@ -1145,8 +1164,8 @@ mod extension_tests { parent_exts_success: (ExtensionSet, bool), ) -> Result<(), BuildError> { let (parent_extensions, success) = parent_exts_success; - let mut dfg = dfg_fn(USIZE_T, parent_extensions.clone()); - let lift = dfg.add_dataflow_op(Lift::new(USIZE_T.into(), XB), dfg.input_wires())?; + let mut dfg = dfg_fn(usize_t(), parent_extensions.clone()); + let lift = dfg.add_dataflow_op(Lift::new(usize_t().into(), XB), dfg.input_wires())?; let pred = make_pred(&mut dfg, lift.outputs())?; let root = dfg.hugr().root(); let res = dfg.finish_prelude_hugr_with_outputs([pred]); diff --git a/hugr-core/src/hugr/views.rs b/hugr-core/src/hugr/views.rs index 7d744c150..e0b68ffaa 100644 --- a/hugr-core/src/hugr/views.rs +++ b/hugr-core/src/hugr/views.rs @@ -1,6 +1,7 @@ //! Read-only access into HUGR graphs and subgraphs. pub mod descendants; +mod impls; pub mod petgraph; pub mod render; mod root_checked; @@ -261,6 +262,12 @@ pub trait HugrView: HugrInternals { /// Return iterator over the direct children of node. fn children(&self, node: Node) -> impl DoubleEndedIterator + Clone; + /// Returns the first child of the specified node (if it is a parent). + /// Useful because `x.children().next()` leaves x borrowed. + fn first_child(&self, node: Node) -> Option { + self.children(node).next() + } + /// Iterates over neighbour nodes in the given direction. /// May contain duplicates if the graph has multiple links between nodes. fn neighbours(&self, node: Node, dir: Direction) -> impl Iterator + Clone; @@ -513,41 +520,35 @@ impl ExtractHugr for &mut Hugr { } } -impl> HugrView for T { +impl HugrView for Hugr { #[inline] fn contains_node(&self, node: Node) -> bool { - self.as_ref().graph.contains_node(node.pg_index()) + self.graph.contains_node(node.pg_index()) } #[inline] fn node_count(&self) -> usize { - self.as_ref().graph.node_count() + self.graph.node_count() } #[inline] fn edge_count(&self) -> usize { - self.as_ref().graph.link_count() + self.graph.link_count() } #[inline] fn nodes(&self) -> impl Iterator + Clone { - self.as_ref().graph.nodes_iter().map_into() + self.graph.nodes_iter().map_into() } #[inline] fn node_ports(&self, node: Node, dir: Direction) -> impl Iterator + Clone { - self.as_ref() - .graph - .port_offsets(node.pg_index(), dir) - .map_into() + self.graph.port_offsets(node.pg_index(), dir).map_into() } #[inline] fn all_node_ports(&self, node: Node) -> impl Iterator + Clone { - self.as_ref() - .graph - .all_port_offsets(node.pg_index()) - .map_into() + self.graph.all_port_offsets(node.pg_index()).map_into() } #[inline] @@ -557,54 +558,46 @@ impl> HugrView for T { port: impl Into, ) -> impl Iterator + Clone { let port = port.into(); - let hugr = self.as_ref(); - let port = hugr + + let port = self .graph .port_index(node.pg_index(), port.pg_offset()) .unwrap(); - hugr.graph.port_links(port).map(|(_, link)| { + self.graph.port_links(port).map(|(_, link)| { let port = link.port(); - let node = hugr.graph.port_node(port).unwrap(); - let offset = hugr.graph.port_offset(port).unwrap(); + let node = self.graph.port_node(port).unwrap(); + let offset = self.graph.port_offset(port).unwrap(); (node.into(), offset.into()) }) } #[inline] fn node_connections(&self, node: Node, other: Node) -> impl Iterator + Clone { - let hugr = self.as_ref(); - - hugr.graph + self.graph .get_connections(node.pg_index(), other.pg_index()) .map(|(p1, p2)| { - [p1, p2].map(|link| hugr.graph.port_offset(link.port()).unwrap().into()) + [p1, p2].map(|link| self.graph.port_offset(link.port()).unwrap().into()) }) } #[inline] fn num_ports(&self, node: Node, dir: Direction) -> usize { - self.as_ref().graph.num_ports(node.pg_index(), dir) + self.graph.num_ports(node.pg_index(), dir) } #[inline] fn children(&self, node: Node) -> impl DoubleEndedIterator + Clone { - self.as_ref().hierarchy.children(node.pg_index()).map_into() + self.hierarchy.children(node.pg_index()).map_into() } #[inline] fn neighbours(&self, node: Node, dir: Direction) -> impl Iterator + Clone { - self.as_ref() - .graph - .neighbours(node.pg_index(), dir) - .map_into() + self.graph.neighbours(node.pg_index(), dir).map_into() } #[inline] fn all_neighbours(&self, node: Node) -> impl Iterator + Clone { - self.as_ref() - .graph - .all_neighbours(node.pg_index()) - .map_into() + self.graph.all_neighbours(node.pg_index()).map_into() } } diff --git a/hugr-core/src/hugr/views/descendants.rs b/hugr-core/src/hugr/views/descendants.rs index 64082eb7e..f7b893ddf 100644 --- a/hugr-core/src/hugr/views/descendants.rs +++ b/hugr-core/src/hugr/views/descendants.rs @@ -175,20 +175,17 @@ where pub(super) mod test { use rstest::rstest; - use crate::extension::PRELUDE_REGISTRY; + use crate::extension::prelude::{qb_t, usize_t}; + use crate::utils::test_quantum_extension; use crate::IncomingPort; use crate::{ builder::{Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder}, - type_row, - types::{Signature, Type}, + types::Signature, utils::test_quantum_extension::{h_gate, EXTENSION_ID}, }; use super::*; - const NAT: Type = crate::extension::prelude::USIZE_T; - const QB: Type = crate::extension::prelude::QB_T; - /// Make a module hugr with a fn definition containing an inner dfg node. /// /// Returns the hugr, the fn node id, and the nested dgf node id. @@ -199,7 +196,7 @@ pub(super) mod test { let (f_id, inner_id) = { let mut func_builder = module_builder.define_function( "main", - Signature::new_endo(type_row![NAT, QB]).with_extension_delta(EXTENSION_ID), + Signature::new_endo(vec![usize_t(), qb_t()]).with_extension_delta(EXTENSION_ID), )?; let [int, qb] = func_builder.input_wires_arr(); @@ -208,7 +205,7 @@ pub(super) mod test { let inner_id = { let inner_builder = func_builder - .dfg_builder(Signature::new(type_row![NAT], type_row![NAT]), [int])?; + .dfg_builder(Signature::new(vec![usize_t()], vec![usize_t()]), [int])?; let w = inner_builder.input_wires(); inner_builder.finish_with_outputs(w) }?; @@ -217,7 +214,7 @@ pub(super) mod test { func_builder.finish_with_outputs(inner_id.outputs().chain(q_out.outputs()))?; (f_id, inner_id) }; - let hugr = module_builder.finish_prelude_hugr()?; + let hugr = module_builder.finish_hugr(&test_quantum_extension::REG)?; Ok((hugr, f_id.handle().node(), inner_id.handle().node())) } @@ -237,7 +234,7 @@ pub(super) mod test { assert_eq!( region.poly_func_type(), Some( - Signature::new_endo(type_row![NAT, QB]) + Signature::new_endo(vec![usize_t(), qb_t()]) .with_extension_delta(EXTENSION_ID) .into() ) @@ -246,7 +243,7 @@ pub(super) mod test { let inner_region: DescendantsGraph = DescendantsGraph::try_new(&hugr, inner)?; assert_eq!( inner_region.inner_function_type(), - Some(Signature::new(type_row![NAT], type_row![NAT])) + Some(Signature::new(vec![usize_t()], vec![usize_t()])) ); assert_eq!(inner_region.node_count(), 3); assert_eq!(inner_region.edge_count(), 2); @@ -294,7 +291,7 @@ pub(super) mod test { let region: DescendantsGraph = DescendantsGraph::try_new(&hugr, def)?; let extracted = region.extract_hugr(); - extracted.validate(&PRELUDE_REGISTRY)?; + extracted.validate(&test_quantum_extension::REG)?; let region: DescendantsGraph = DescendantsGraph::try_new(&hugr, def)?; diff --git a/hugr-core/src/hugr/views/impls.rs b/hugr-core/src/hugr/views/impls.rs new file mode 100644 index 000000000..7f3f45386 --- /dev/null +++ b/hugr-core/src/hugr/views/impls.rs @@ -0,0 +1,101 @@ +use std::{borrow::Cow, rc::Rc, sync::Arc}; + +use delegate::delegate; + +use super::{HugrView, RootChecked}; +use crate::{Direction, Hugr, Node, Port}; + +macro_rules! hugr_view_methods { + // The extra ident here is because invocations of the macro cannot pass `self` as argument + ($arg:ident, $e:expr) => { + delegate! { + to ({let $arg=self; $e}) { + fn contains_node(&self, node: Node) -> bool; + fn node_count(&self) -> usize; + fn edge_count(&self) -> usize; + fn nodes(&self) -> impl Iterator + Clone; + fn node_ports(&self, node: Node, dir: Direction) -> impl Iterator + Clone; + fn all_node_ports(&self, node: Node) -> impl Iterator + Clone; + fn linked_ports( + &self, + node: Node, + port: impl Into, + ) -> impl Iterator + Clone; + fn node_connections(&self, node: Node, other: Node) -> impl Iterator + Clone; + fn num_ports(&self, node: Node, dir: Direction) -> usize; + fn children(&self, node: Node) -> impl DoubleEndedIterator + Clone; + fn neighbours(&self, node: Node, dir: Direction) -> impl Iterator + Clone; + fn all_neighbours(&self, node: Node) -> impl Iterator + Clone; + } + } + } +} + +impl HugrView for &T { + hugr_view_methods! {this, *this} +} + +impl HugrView for &mut T { + hugr_view_methods! {this, &**this} +} + +impl HugrView for Rc { + hugr_view_methods! {this, this.as_ref()} +} + +impl HugrView for Arc { + hugr_view_methods! {this, this.as_ref()} +} + +impl HugrView for Box { + hugr_view_methods! {this, this.as_ref()} +} + +impl HugrView for Cow<'_, T> { + hugr_view_methods! {this, this.as_ref()} +} + +impl, Root> HugrView for RootChecked { + hugr_view_methods! {this, this.as_ref()} +} + +#[cfg(test)] +mod test { + use std::{rc::Rc, sync::Arc}; + + use crate::hugr::views::{DescendantsGraph, HierarchyView}; + use crate::{Hugr, HugrView, Node}; + + struct ViewWrapper(H); + impl ViewWrapper { + fn nodes(&self) -> impl Iterator + '_ { + self.0.nodes() + } + } + + #[test] + fn test_refs_to_view() { + let h = Hugr::default(); + let v = ViewWrapper(&h); + let c = h.nodes().count(); + assert_eq!(v.nodes().count(), c); + let v2 = ViewWrapper(DescendantsGraph::::try_new(&h, h.root()).unwrap()); + // v2 owns the DescendantsGraph, but that only borrows `h`, so we still have both + assert_eq!(v2.nodes().count(), v.nodes().count()); + // And we can borrow the DescendantsGraph, even just a reference to that counts as a HugrView + assert_eq!(ViewWrapper(&v2.0).nodes().count(), v.nodes().count()); + + let vh = ViewWrapper(h); + assert_eq!(vh.nodes().count(), c); + let h: Hugr = vh.0; + assert_eq!(h.nodes().count(), c); + + let vb = ViewWrapper(Box::new(&h)); + assert_eq!(vb.nodes().count(), c); + let va = ViewWrapper(Arc::new(h)); + assert_eq!(va.nodes().count(), c); + let h = Arc::try_unwrap(va.0).unwrap(); + let vr = Rc::new(&h); + assert_eq!(ViewWrapper(&vr).nodes().count(), h.nodes().count()); + } +} diff --git a/hugr-core/src/hugr/views/root_checked.rs b/hugr-core/src/hugr/views/root_checked.rs index c98943966..88b0ad020 100644 --- a/hugr-core/src/hugr/views/root_checked.rs +++ b/hugr-core/src/hugr/views/root_checked.rs @@ -1,6 +1,9 @@ use std::marker::PhantomData; -use crate::hugr::internal::HugrMutInternals; +use delegate::delegate; +use portgraph::MultiPortGraph; + +use crate::hugr::internal::{HugrInternals, HugrMutInternals}; use crate::hugr::{HugrError, HugrMut}; use crate::ops::handle::NodeHandle; use crate::{Hugr, Node}; @@ -45,6 +48,20 @@ impl RootChecked<&mut Hugr, Root> { } } +impl, Root> HugrInternals for RootChecked { + type Portgraph<'p> + = &'p MultiPortGraph + where + Self: 'p; + delegate! { + to self.as_ref() { + fn portgraph(&self) -> Self::Portgraph<'_>; + fn base_hugr(&self) -> &Hugr; + fn root_node(&self) -> Node; + } + } +} + impl, Root: NodeHandle> RootTagged for RootChecked { type RootHandle = Root; } diff --git a/hugr-core/src/hugr/views/sibling.rs b/hugr-core/src/hugr/views/sibling.rs index 45ed1f759..b9710f00a 100644 --- a/hugr-core/src/hugr/views/sibling.rs +++ b/hugr-core/src/hugr/views/sibling.rs @@ -336,16 +336,13 @@ mod test { use crate::builder::test::simple_dfg_hugr; use crate::builder::{Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder}; - use crate::extension::PRELUDE_REGISTRY; + use crate::extension::prelude::{qb_t, usize_t}; use crate::ops::handle::{CfgID, DataflowParentID, DfgID, FuncID}; use crate::ops::{dataflow::IOTrait, Input, OpTag, Output}; use crate::ops::{OpTrait, OpType}; - use crate::types::{Signature, Type}; - use crate::utils::test_quantum_extension::EXTENSION_ID; - use crate::{type_row, IncomingPort}; - - const NAT: Type = crate::extension::prelude::USIZE_T; - const QB: Type = crate::extension::prelude::QB_T; + use crate::types::Signature; + use crate::utils::test_quantum_extension::{self, EXTENSION_ID}; + use crate::IncomingPort; use super::super::descendants::test::make_module_hgr; use super::*; @@ -372,7 +369,7 @@ mod test { assert_eq!( region.poly_func_type(), Some( - Signature::new_endo(type_row![NAT, QB]) + Signature::new_endo(vec![usize_t(), qb_t()]) .with_extension_delta(EXTENSION_ID) .into() ) @@ -380,7 +377,7 @@ mod test { assert_eq!( inner_region.inner_function_type(), - Some(Signature::new(type_row![NAT], type_row![NAT])) + Some(Signature::new(vec![usize_t()], vec![usize_t()])) ); assert_eq!(inner_region.node_count(), 3); assert_eq!(inner_region.edge_count(), 1); @@ -453,13 +450,13 @@ mod test { #[test] fn nested_flat() -> Result<(), Box> { let mut module_builder = ModuleBuilder::new(); - let fty = Signature::new(type_row![NAT], type_row![NAT]); + let fty = Signature::new(vec![usize_t()], vec![usize_t()]); let mut fbuild = module_builder.define_function("main", fty.clone())?; let dfg = fbuild.dfg_builder(fty, fbuild.input_wires())?; let ins = dfg.input_wires(); let sub_dfg = dfg.finish_with_outputs(ins)?; let fun = fbuild.finish_with_outputs(sub_dfg.outputs())?; - let h = module_builder.finish_hugr(&PRELUDE_REGISTRY)?; + let h = module_builder.finish_hugr(&test_quantum_extension::REG)?; let sub_dfg = sub_dfg.node(); // We can create a view from a child or grandchild of a hugr: @@ -472,8 +469,8 @@ mod test { // Both ways work: let just_io = vec![ - Input::new(type_row![NAT]).into(), - Output::new(type_row![NAT]).into(), + Input::new(vec![usize_t()]).into(), + Output::new(vec![usize_t()]).into(), ]; for d in [dfg_view, nested_dfg_view] { assert_eq!( @@ -488,7 +485,9 @@ mod test { /// Mutate a SiblingMut wrapper #[rstest] fn flat_mut(mut simple_dfg_hugr: Hugr) { - simple_dfg_hugr.update_validate(&PRELUDE_REGISTRY).unwrap(); + simple_dfg_hugr + .update_validate(&test_quantum_extension::REG) + .unwrap(); let root = simple_dfg_hugr.root(); let signature = simple_dfg_hugr.inner_function_type().unwrap().clone(); @@ -513,7 +512,9 @@ mod test { // In contrast, performing this on the Hugr (where the allowed root type is 'Any') is only detected by validation simple_dfg_hugr.replace_op(root, bad_nodetype).unwrap(); - assert!(simple_dfg_hugr.validate(&PRELUDE_REGISTRY).is_err()); + assert!(simple_dfg_hugr + .validate(&test_quantum_extension::REG) + .is_err()); } #[rstest] @@ -542,7 +543,7 @@ mod test { let region: SiblingGraph = SiblingGraph::try_new(&hugr, inner)?; let extracted = region.extract_hugr(); - extracted.validate(&PRELUDE_REGISTRY)?; + extracted.validate(&test_quantum_extension::REG)?; let region: SiblingGraph = SiblingGraph::try_new(&hugr, inner)?; diff --git a/hugr-core/src/hugr/views/sibling_subgraph.rs b/hugr-core/src/hugr/views/sibling_subgraph.rs index fad4d835c..ea23e17de 100644 --- a/hugr-core/src/hugr/views/sibling_subgraph.rs +++ b/hugr-core/src/hugr/views/sibling_subgraph.rs @@ -795,14 +795,10 @@ mod tests { BuildError, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, ModuleBuilder, }, - extension::{ - prelude::{BOOL_T, QB_T}, - EMPTY_REG, - }, + extension::prelude::{bool_t, qb_t}, hugr::views::{HierarchyView, SiblingGraph}, ops::handle::{DfgID, FuncID, NodeHandle}, std_extensions::logic::test::and_op, - type_row, }; use super::*; @@ -837,7 +833,7 @@ mod tests { let mut mod_builder = ModuleBuilder::new(); let func = mod_builder.declare( "test", - Signature::new_endo(type_row![QB_T, QB_T, QB_T]) + Signature::new_endo(vec![qb_t(), qb_t(), qb_t()]) .with_extension_delta(ExtensionSet::from_iter([ test_quantum_extension::EXTENSION_ID, float_types::EXTENSION_ID, @@ -870,7 +866,7 @@ mod tests { let mut mod_builder = ModuleBuilder::new(); let func = mod_builder.declare( "test", - Signature::new_endo(type_row![BOOL_T]) + Signature::new_endo(vec![bool_t()]) .with_extension_delta(logic::EXTENSION_ID) .into(), )?; @@ -882,7 +878,7 @@ mod tests { dfg.finish_with_outputs(outs3.outputs())? }; let hugr = mod_builder - .finish_prelude_hugr() + .finish_hugr(&test_quantum_extension::REG) .map_err(|e| -> BuildError { e.into() })?; Ok((hugr, func_id.node())) } @@ -892,7 +888,7 @@ mod tests { let mut mod_builder = ModuleBuilder::new(); let func = mod_builder.declare( "test", - Signature::new(BOOL_T, type_row![BOOL_T, BOOL_T]) + Signature::new(bool_t(), vec![bool_t(), bool_t()]) .with_extension_delta(logic::EXTENSION_ID) .into(), )?; @@ -904,7 +900,7 @@ mod tests { dfg.finish_with_outputs([b1, b2])? }; let hugr = mod_builder - .finish_prelude_hugr() + .finish_hugr(&test_quantum_extension::REG) .map_err(|e| -> BuildError { e.into() })?; Ok((hugr, func_id.node())) } @@ -914,7 +910,7 @@ mod tests { let mut mod_builder = ModuleBuilder::new(); let func = mod_builder.declare( "test", - Signature::new_endo(BOOL_T) + Signature::new_endo(bool_t()) .with_extension_delta(logic::EXTENSION_ID) .into(), )?; @@ -925,7 +921,7 @@ mod tests { dfg.finish_with_outputs(outs.outputs())? }; let hugr = mod_builder - .finish_hugr(&EMPTY_REG) + .finish_hugr(&test_quantum_extension::REG) .map_err(|e| -> BuildError { e.into() })?; Ok((hugr, func_id.node())) } @@ -956,7 +952,7 @@ mod tests { let empty_dfg = { let builder = - DFGBuilder::new(Signature::new_endo(type_row![QB_T, QB_T, QB_T])).unwrap(); + DFGBuilder::new(Signature::new_endo(vec![qb_t(), qb_t(), qb_t()])).unwrap(); let inputs = builder.input_wires(); builder.finish_prelude_hugr_with_outputs(inputs).unwrap() }; @@ -979,7 +975,7 @@ mod tests { let sub = SiblingSubgraph::try_new_dataflow_subgraph(&func)?; assert_eq!( sub.signature(&func), - Signature::new_endo(type_row![QB_T, QB_T, QB_T]).with_extension_delta( + Signature::new_endo(vec![qb_t(), qb_t(), qb_t()]).with_extension_delta( ExtensionSet::from_iter([ test_quantum_extension::EXTENSION_ID, float_types::EXTENSION_ID, @@ -996,7 +992,7 @@ mod tests { let sub = SiblingSubgraph::from_sibling_graph(&func)?; let empty_dfg = { - let builder = DFGBuilder::new(Signature::new_endo(type_row![QB_T])).unwrap(); + let builder = DFGBuilder::new(Signature::new_endo(vec![qb_t()])).unwrap(); let inputs = builder.input_wires(); builder.finish_prelude_hugr_with_outputs(inputs).unwrap() }; @@ -1157,8 +1153,8 @@ mod tests { #[test] fn edge_both_output_and_copy() { // https://github.com/CQCL/hugr/issues/518 - let one_bit = type_row![BOOL_T]; - let two_bit = type_row![BOOL_T, BOOL_T]; + let one_bit = vec![bool_t()]; + let two_bit = vec![bool_t(), bool_t()]; let mut builder = DFGBuilder::new(inout_sig(one_bit.clone(), two_bit.clone())).unwrap(); let inw = builder.input_wires().exactly_one().unwrap(); @@ -1171,7 +1167,9 @@ mod tests { .unwrap() .outputs(); let outw = [outw1].into_iter().chain(outw2); - let h = builder.finish_hugr_with_outputs(outw, &EMPTY_REG).unwrap(); + let h = builder + .finish_hugr_with_outputs(outw, &test_quantum_extension::REG) + .unwrap(); let view = SiblingGraph::::try_new(&h, h.root()).unwrap(); let subg = SiblingSubgraph::try_new_dataflow_subgraph(&view).unwrap(); assert_eq!(subg.nodes().len(), 2); diff --git a/hugr-core/src/hugr/views/tests.rs b/hugr-core/src/hugr/views/tests.rs index c958c8b9f..f69779082 100644 --- a/hugr-core/src/hugr/views/tests.rs +++ b/hugr-core/src/hugr/views/tests.rs @@ -1,11 +1,13 @@ use portgraph::PortOffset; use rstest::{fixture, rstest}; +use crate::std_extensions::logic::LOGIC_REG; +use crate::utils::test_quantum_extension; use crate::{ builder::{ endo_sig, inout_sig, BuildError, BuildHandle, Container, DFGBuilder, Dataflow, DataflowHugr, }, - extension::prelude::QB_T, + extension::prelude::qb_t, ops::{ handle::{DataflowOpID, NodeHandle}, Value, @@ -22,7 +24,7 @@ use crate::{ /// Returns the Hugr and the two CX node ids. #[fixture] pub(crate) fn sample_hugr() -> (Hugr, BuildHandle, BuildHandle) { - let mut dfg = DFGBuilder::new(endo_sig(type_row![QB_T, QB_T])).unwrap(); + let mut dfg = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()])).unwrap(); let [q1, q2] = dfg.input_wires_arr(); @@ -32,7 +34,8 @@ pub(crate) fn sample_hugr() -> (Hugr, BuildHandle, BuildHandle, BuildHandle>().join(", "))] + Extension { + /// The missing extension. + missing_ext: ExtensionId, + /// The available extensions in the registry. + available: Vec, + }, /// The model is not well-formed. #[error("validate error: {0}")] Model(#[from] model::ModelError), @@ -1038,32 +1049,39 @@ impl<'a> Context<'a> { &mut self, term_id: model::TermId, ) -> Result { - match self.get_term(term_id)? { - model::Term::Wildcard => Err(error_uninferred!("wildcard")), - - model::Term::Var(var) => { - let mut es = ExtensionSet::new(); - let (index, _) = self.resolve_local_ref(var)?; - es.insert_type_var(index); - Ok(es) - } + let mut es = ExtensionSet::new(); + let mut stack = vec![term_id]; - model::Term::ExtSet { extensions, rest } => { - let mut es = match rest { - Some(rest) => self.import_extension_set(*rest)?, - None => ExtensionSet::new(), - }; + while let Some(term_id) = stack.pop() { + match self.get_term(term_id)? { + model::Term::Wildcard => return Err(error_uninferred!("wildcard")), - for ext in extensions.iter() { - let ext_ident = IdentList::new(*ext) - .map_err(|_| model::ModelError::MalformedName(ext.to_smolstr()))?; - es.insert(&ext_ident); + model::Term::Var(var) => { + let (index, _) = self.resolve_local_ref(var)?; + es.insert_type_var(index); } - Ok(es) + model::Term::ExtSet { parts } => { + for part in *parts { + match part { + model::ExtSetPart::Extension(ext) => { + let ext_ident = IdentList::new(*ext).map_err(|_| { + model::ModelError::MalformedName(ext.to_smolstr()) + })?; + es.insert(&ext_ident); + } + model::ExtSetPart::Splice(term_id) => { + // The order in an extension set does not matter. + stack.push(*term_id); + } + } + } + } + _ => return Err(model::ModelError::TypeError(term_id).into()), } - _ => Err(model::ModelError::TypeError(term_id).into()), } + + Ok(es) } /// Import a `Type` from a term that represents a runtime type. @@ -1086,6 +1104,14 @@ impl<'a> Context<'a> { let name = self.get_global_name(*name)?; let (extension, id) = self.import_custom_name(name)?; + let extension_ref = + self.extensions.get(&extension.to_string()).ok_or_else(|| { + ImportError::Extension { + missing_ext: extension.clone(), + available: self.extensions.ids().cloned().collect(), + } + })?; + Ok(TypeBase::new_extension(CustomType::new( id, args, @@ -1093,6 +1119,7 @@ impl<'a> Context<'a> { // As part of the migration from `TypeBound`s to constraints, we pretend that all // `TypeBound`s are copyable. TypeBound::Copyable, + &Arc::downgrade(extension_ref), ))) } @@ -1103,7 +1130,7 @@ impl<'a> Context<'a> { } model::Term::FuncType { .. } => { - let func_type = self.import_func_type::(term_id)?; + let func_type = self.import_func_type::(term_id)?; Ok(TypeBase::new_function(func_type)) } @@ -1157,39 +1184,45 @@ impl<'a> Context<'a> { term_id: model::TermId, ) -> Result, ImportError> { let (inputs, outputs, extensions) = self.get_func_type(term_id)?; - let inputs = self.import_type_row::(inputs)?; - let outputs = self.import_type_row::(outputs)?; + let inputs = self.import_type_row(inputs)?; + let outputs = self.import_type_row(outputs)?; let extensions = self.import_extension_set(extensions)?; Ok(FuncTypeBase::new(inputs, outputs).with_extension_delta(extensions)) } fn import_closed_list( &mut self, - mut term_id: model::TermId, + term_id: model::TermId, ) -> Result, ImportError> { - // PERFORMANCE: We currently allocate a Vec here to collect list items - // into, in order to handle the case where the tail of the list is another - // list. We should avoid this. - let mut list_items = Vec::new(); - - loop { - match self.get_term(term_id)? { - model::Term::Var(_) => return Err(error_unsupported!("open lists")), - model::Term::List { items, tail } => { - list_items.extend(items.iter()); - - match tail { - Some(tail) => term_id = *tail, - None => break, + fn import_into( + ctx: &mut Context, + term_id: model::TermId, + types: &mut Vec, + ) -> Result<(), ImportError> { + match ctx.get_term(term_id)? { + model::Term::List { parts } => { + types.reserve(parts.len()); + + for part in *parts { + match part { + model::ListPart::Item(term_id) => { + types.push(*term_id); + } + model::ListPart::Splice(term_id) => { + import_into(ctx, *term_id, types)?; + } + } } } - _ => { - return Err(model::ModelError::TypeError(term_id).into()); - } + _ => return Err(model::ModelError::TypeError(term_id).into()), } + + Ok(()) } - Ok(list_items) + let mut types = Vec::new(); + import_into(self, term_id, &mut types)?; + Ok(types) } fn import_type_rows( @@ -1197,8 +1230,8 @@ impl<'a> Context<'a> { term_id: model::TermId, ) -> Result>, ImportError> { self.import_closed_list(term_id)? - .iter() - .map(|row| self.import_type_row::(*row)) + .into_iter() + .map(|term_id| self.import_type_row::(term_id)) .collect() } @@ -1206,13 +1239,41 @@ impl<'a> Context<'a> { &mut self, term_id: model::TermId, ) -> Result, ImportError> { - let items = self - .import_closed_list(term_id)? - .iter() - .map(|item| self.import_type(*item)) - .collect::, _>>()?; + fn import_into( + ctx: &mut Context, + term_id: model::TermId, + types: &mut Vec>, + ) -> Result<(), ImportError> { + match ctx.get_term(term_id)? { + model::Term::List { parts } => { + types.reserve(parts.len()); + + for item in *parts { + match item { + model::ListPart::Item(term_id) => { + types.push(ctx.import_type::(*term_id)?); + } + model::ListPart::Splice(term_id) => { + import_into(ctx, *term_id, types)?; + } + } + } + } + model::Term::Var(var) => { + let (index, _) = ctx.resolve_local_ref(var)?; + let var = RV::try_from_rv(RowVariable(index, TypeBound::Any)) + .map_err(|_| model::ModelError::TypeError(term_id))?; + types.push(TypeBase::new(TypeEnum::RowVar(var))); + } + _ => return Err(model::ModelError::TypeError(term_id).into()), + } + + Ok(()) + } - Ok(items.into()) + let mut types = Vec::new(); + import_into(self, term_id, &mut types)?; + Ok(types.into()) } fn import_custom_name( diff --git a/hugr-core/src/ops.rs b/hugr-core/src/ops.rs index 24ce8492e..d3c24a89a 100644 --- a/hugr-core/src/ops.rs +++ b/hugr-core/src/ops.rs @@ -10,7 +10,7 @@ pub mod sum; pub mod tag; pub mod validate; use crate::extension::simple_op::MakeExtensionOp; -use crate::extension::ExtensionSet; +use crate::extension::{ExtensionId, ExtensionSet}; use crate::types::{EdgeKind, Signature}; use crate::{Direction, OutgoingPort, Port}; use crate::{IncomingPort, PortIndex}; @@ -300,6 +300,15 @@ impl OpType { self.as_extension_op() .and_then(|o| T::from_extension_op(o).ok()) } + + /// Returns the extension where the operation is defined, if any. + pub fn extension_id(&self) -> Option<&ExtensionId> { + match self { + OpType::OpaqueOp(opaque) => Some(opaque.extension()), + OpType::ExtensionOp(e) => Some(e.def().extension_id()), + _ => None, + } + } } /// Macro used by operations that want their diff --git a/hugr-core/src/ops/constant.rs b/hugr-core/src/ops/constant.rs index 3522c61b3..f79f05d12 100644 --- a/hugr-core/src/ops/constant.rs +++ b/hugr-core/src/ops/constant.rs @@ -243,23 +243,23 @@ pub enum Value { /// use serde::{Serialize,Deserialize}; /// use hugr::{ /// types::Type,ops::constant::{OpaqueValue, ValueName, CustomConst, CustomSerialized}, -/// extension::{ExtensionSet, prelude::{USIZE_T, ConstUsize}}, +/// extension::{ExtensionSet, prelude::{usize_t, ConstUsize}}, /// std_extensions::arithmetic::int_types}; /// use serde_json::json; /// /// let expected_json = json!({ /// "extensions": ["prelude"], -/// "typ": USIZE_T, +/// "typ": usize_t(), /// "value": {'c': "ConstUsize", 'v': 1} /// }); /// let ev = OpaqueValue::new(ConstUsize::new(1)); /// assert_eq!(&serde_json::to_value(&ev).unwrap(), &expected_json); /// assert_eq!(ev, serde_json::from_value(expected_json).unwrap()); /// -/// let ev = OpaqueValue::new(CustomSerialized::new(USIZE_T.clone(), serde_json::Value::Null, ExtensionSet::default())); +/// let ev = OpaqueValue::new(CustomSerialized::new(usize_t().clone(), serde_json::Value::Null, ExtensionSet::default())); /// let expected_json = json!({ /// "extensions": [], -/// "typ": USIZE_T, +/// "typ": usize_t(), /// "value": null /// }); /// @@ -560,18 +560,20 @@ pub type ValueNameRef = str; #[cfg(test)] mod test { use std::collections::HashSet; + use std::sync::{Arc, Weak}; use super::Value; use crate::builder::inout_sig; use crate::builder::test::simple_dfg_hugr; + use crate::extension::prelude::{bool_t, usize_custom_t}; use crate::std_extensions::arithmetic::int_types::ConstInt; use crate::{ builder::{BuildError, DFGBuilder, Dataflow, DataflowHugr}, extension::{ - prelude::{ConstUsize, USIZE_CUSTOM_T, USIZE_T}, + prelude::{usize_t, ConstUsize}, ExtensionId, ExtensionRegistry, PRELUDE, }, - std_extensions::arithmetic::float_types::{self, ConstF64, FLOAT64_TYPE}, + std_extensions::arithmetic::float_types::{self, float64_type, ConstF64}, type_row, types::type_param::TypeArg, types::{Type, TypeBound, TypeRow}, @@ -604,7 +606,7 @@ mod test { } } - /// A [`CustomSerialized`] encoding a [`FLOAT64_TYPE`] float constant used in testing. + /// A [`CustomSerialized`] encoding a [`float64_type()`] float constant used in testing. pub(crate) fn serialized_float(f: f64) -> Value { CustomSerialized::try_from_custom_const(ConstF64::new(f)) .unwrap() @@ -619,17 +621,18 @@ mod test { #[test] fn test_sum() -> Result<(), BuildError> { use crate::builder::Container; - let pred_rows = vec![type_row![USIZE_T, FLOAT64_TYPE], Type::EMPTY_TYPEROW]; + let pred_rows = vec![vec![usize_t(), float64_type()].into(), Type::EMPTY_TYPEROW]; let pred_ty = SumType::new(pred_rows.clone()); let mut b = DFGBuilder::new(inout_sig( type_row![], TypeRow::from(vec![pred_ty.clone().into()]), ))?; + let usize_custom_t = usize_custom_t(&Arc::downgrade(&PRELUDE)); let c = b.add_constant(Value::sum( 0, [ - CustomTestValue(USIZE_CUSTOM_T).into(), + CustomTestValue(usize_custom_t.clone()).into(), ConstF64::new(5.1).into(), ], pred_ty.clone(), @@ -650,7 +653,7 @@ mod test { #[test] fn test_bad_sum() { - let pred_ty = SumType::new([type_row![USIZE_T, FLOAT64_TYPE], type_row![]]); + let pred_ty = SumType::new([vec![usize_t(), float64_type()].into(), type_row![]]); let good_sum = const_usize(); println!("{}", serde_json::to_string_pretty(&good_sum).unwrap()); @@ -686,7 +689,7 @@ mod test { index: 1, expected, found, - })) if expected == FLOAT64_TYPE && found == const_usize() + })) if expected == float64_type() && found == const_usize() ); } @@ -694,9 +697,7 @@ mod test { fn function_value(simple_dfg_hugr: Hugr) { let v = Value::function(simple_dfg_hugr).unwrap(); - let correct_type = Type::new_function(Signature::new_endo(type_row![ - crate::extension::prelude::BOOL_T - ])); + let correct_type = Type::new_function(Signature::new_endo(vec![bool_t()])); assert_eq!(v.get_type(), correct_type); assert!(v.name().starts_with("const:function:")) @@ -714,9 +715,9 @@ mod test { #[rstest] #[case(Value::unit(), Type::UNIT, "const:seq:{}")] - #[case(const_usize(), USIZE_T, "const:custom:ConstUsize(")] - #[case(serialized_float(17.4), FLOAT64_TYPE, "const:custom:json:Object")] - #[case(const_tuple(), Type::new_tuple(type_row![USIZE_T, FLOAT64_TYPE]), "const:seq:{")] + #[case(const_usize(), usize_t(), "const:custom:ConstUsize(")] + #[case(serialized_float(17.4), float64_type(), "const:custom:json:Object")] + #[case(const_tuple(), Type::new_tuple(vec![usize_t(), float64_type()]), "const:seq:{")] fn const_type( #[case] const_value: Value, #[case] expected_type: Type, @@ -749,6 +750,8 @@ mod test { vec![TypeArg::BoundedNat { n: 8 }], ex_id.clone(), TypeBound::Copyable, + // Dummy extension reference. + &Weak::default(), ); let json_const: Value = CustomSerialized::new(typ_int.clone(), 6.into(), ex_id.clone()).into(); @@ -756,7 +759,13 @@ mod test { assert_matches!(classic_t.least_upper_bound(), TypeBound::Copyable); assert_eq!(json_const.get_type(), classic_t); - let typ_qb = CustomType::new("my_type", vec![], ex_id, TypeBound::Copyable); + let typ_qb = CustomType::new( + "my_type", + vec![], + ex_id, + TypeBound::Copyable, + &Weak::default(), + ); let t = Type::new_extension(typ_qb.clone()); assert_ne!(json_const.get_type(), t); } @@ -937,7 +946,10 @@ mod test { Value::sum( 1, [Value::true_val()], - SumType::new([vec![Type::UNIT], vec![Value::true_val().get_type()]]), + SumType::new([ + type_row![Type::UNIT], + vec![Value::true_val().get_type()].into() + ]), ) .unwrap() ]) diff --git a/hugr-core/src/ops/constant/custom.rs b/hugr-core/src/ops/constant/custom.rs index a95b4b9ba..b6755d5c6 100644 --- a/hugr-core/src/ops/constant/custom.rs +++ b/hugr-core/src/ops/constant/custom.rs @@ -62,7 +62,7 @@ pub trait CustomConst: /// (a set to allow, say, a [List] of [USize]) /// /// [List]: crate::std_extensions::collections::LIST_TYPENAME - /// [USize]: crate::extension::prelude::USIZE_T + /// [USize]: crate::extension::prelude::usize_t fn extension_reqs(&self) -> ExtensionSet; /// Check the value. @@ -360,7 +360,7 @@ mod test { use rstest::rstest; use crate::{ - extension::prelude::{ConstUsize, USIZE_T}, + extension::prelude::{usize_t, ConstUsize}, ops::{constant::custom::serialize_custom_const, Value}, std_extensions::collections::ListValue, }; @@ -386,7 +386,7 @@ mod test { fn scce_list() -> SerializeCustomConstExample { let cc = ListValue::new( - USIZE_T, + usize_t(), [ConstUsize::new(1), ConstUsize::new(2)] .into_iter() .map(Value::extension), diff --git a/hugr-core/src/ops/custom.rs b/hugr-core/src/ops/custom.rs index eec5f4d34..be4ec01b9 100644 --- a/hugr-core/src/ops/custom.rs +++ b/hugr-core/src/ops/custom.rs @@ -1,5 +1,6 @@ //! Extensible operations. +use itertools::Itertools; use std::sync::Arc; use thiserror::Error; #[cfg(test)] @@ -11,14 +12,12 @@ use { }; use crate::extension::{ConstFoldResult, ExtensionId, ExtensionRegistry, OpDef, SignatureError}; -use crate::hugr::internal::HugrMutInternals; -use crate::hugr::HugrView; use crate::types::{type_param::TypeArg, Signature}; -use crate::{ops, Hugr, IncomingPort, Node}; +use crate::{ops, IncomingPort, Node}; use super::dataflow::DataflowOpTrait; use super::tag::OpTag; -use super::{NamedOp, OpName, OpNameRef, OpTrait, OpType}; +use super::{NamedOp, OpName, OpNameRef}; /// An operation defined by an [OpDef] from a loaded [Extension]. /// @@ -56,13 +55,13 @@ impl ExtensionOp { } /// If OpDef is missing binary computation, trust the cached signature. - fn new_with_cached( + pub(crate) fn new_with_cached( def: Arc, - args: impl Into>, + args: impl IntoIterator, opaque: &OpaqueOp, exts: &ExtensionRegistry, ) -> Result { - let args: Vec = args.into(); + let args: Vec = args.into_iter().collect(); // TODO skip computation depending on config // see https://github.com/CQCL/hugr/issues/1363 let signature = match def.compute_signature(&args, exts) { @@ -99,18 +98,24 @@ impl ExtensionOp { /// [`ExtensionOp`]. /// /// Regenerating the [`ExtensionOp`] back from the [`OpaqueOp`] requires a - /// registry with the appropriate extension. See [`resolve_opaque_op`]. + /// registry with the appropriate extension. See + /// [`crate::Hugr::resolve_extension_defs`]. /// /// For a non-cloning version of this operation, use [`OpaqueOp::from`]. pub fn make_opaque(&self) -> OpaqueOp { OpaqueOp { - extension: self.def.extension().clone(), + extension: self.def.extension_id().clone(), name: self.def.name().clone(), description: self.def.description().into(), args: self.args.clone(), signature: self.signature.clone(), } } + + /// Returns a mutable reference to the cached signature of the operation. + pub fn signature_mut(&mut self) -> &mut Signature { + &mut self.signature + } } impl From for OpaqueOp { @@ -121,7 +126,7 @@ impl From for OpaqueOp { signature, } = op; OpaqueOp { - extension: def.extension().clone(), + extension: def.extension_id().clone(), name: def.name().clone(), description: def.description().into(), args, @@ -141,7 +146,7 @@ impl Eq for ExtensionOp {} impl NamedOp for ExtensionOp { /// The name of the operation. fn name(&self) -> OpName { - qualify_name(self.def.extension(), self.def.name()) + qualify_name(self.def.extension_id(), self.def.name()) } } @@ -202,6 +207,11 @@ impl OpaqueOp { signature, } } + + /// Returns a mutable reference to the signature of the operation. + pub fn signature_mut(&mut self) -> &mut Signature { + &mut self.signature + } } impl NamedOp for OpaqueOp { @@ -241,89 +251,25 @@ impl DataflowOpTrait for OpaqueOp { } } -/// Resolve serialized names of operations into concrete implementation (OpDefs) where possible -pub fn resolve_extension_ops( - h: &mut Hugr, - extension_registry: &ExtensionRegistry, -) -> Result<(), OpaqueOpError> { - let mut replacements = Vec::new(); - for n in h.nodes() { - if let OpType::OpaqueOp(opaque) = h.get_optype(n) { - let resolved = resolve_opaque_op(n, opaque, extension_registry)?; - replacements.push((n, resolved)); - } - } - // Only now can we perform the replacements as the 'for' loop was borrowing 'h' preventing use from using it mutably - for (n, op) in replacements { - debug_assert_eq!(h.get_optype(n).tag(), OpTag::Leaf); - debug_assert_eq!(op.tag(), OpTag::Leaf); - h.replace_op(n, op).unwrap(); - } - Ok(()) -} - -/// Try to resolve a [`OpaqueOp`] to a [`ExtensionOp`] by looking the op up in -/// the registry. -/// -/// # Return -/// Some if the serialized opaque resolves to an extension-defined op and all is -/// ok; None if the serialized opaque doesn't identify an extension -/// -/// # Errors -/// If the serialized opaque resolves to a definition that conflicts with what -/// was serialized -pub fn resolve_opaque_op( - node: Node, - opaque: &OpaqueOp, - extension_registry: &ExtensionRegistry, -) -> Result { - if let Some(r) = extension_registry.get(&opaque.extension) { - // Fail if the Extension was found but did not have the expected operation - let Some(def) = r.get_op(&opaque.name) else { - return Err(OpaqueOpError::OpNotFoundInExtension( - node, - opaque.name.clone(), - r.name().clone(), - )); - }; - let ext_op = ExtensionOp::new_with_cached( - def.clone(), - opaque.args.clone(), - opaque, - extension_registry, - ) - .map_err(|e| OpaqueOpError::SignatureError { - node, - name: opaque.name.clone(), - cause: e, - })?; - if opaque.signature() != ext_op.signature() { - return Err(OpaqueOpError::SignatureMismatch { - node, - extension: opaque.extension.clone(), - op: def.name().clone(), - computed: ext_op.signature.clone(), - stored: opaque.signature.clone(), - }); - }; - Ok(ext_op) - } else { - Err(OpaqueOpError::UnresolvedOp( - node, - opaque.name.clone(), - opaque.extension.clone(), - )) - } -} - /// Errors that arise after loading a Hugr containing opaque ops (serialized just as their names) /// when trying to resolve the serialized names against a registry of known Extensions. #[derive(Clone, Debug, Error, PartialEq)] #[non_exhaustive] pub enum OpaqueOpError { /// The Extension was found but did not contain the expected OpDef - #[error("Operation '{1}' in {0} not found in Extension {2}")] - OpNotFoundInExtension(Node, OpName, ExtensionId), + #[error("Operation '{op}' in {node} not found in Extension {extension}. Available operations: {}", + available_ops.iter().join(", ") + )] + OpNotFoundInExtension { + /// The node where the error occurred. + node: Node, + /// The missing operation. + op: OpName, + /// The extension where the operation was expected. + extension: ExtensionId, + /// The available operations in the extension. + available_ops: Vec, + }, /// Extension and OpDef found, but computed signature did not match stored #[error("Conflicting signature: resolved {op} in extension {extension} to a concrete implementation which computed {computed} but stored signature was {stored}")] #[allow(missing_docs)] @@ -351,10 +297,13 @@ pub enum OpaqueOpError { #[cfg(test)] mod test { + use ops::OpType; + + use crate::extension::resolution::update_op_extensions; use crate::std_extensions::arithmetic::conversions::{self, CONVERT_OPS_REGISTRY}; use crate::{ extension::{ - prelude::{BOOL_T, QB_T, USIZE_T}, + prelude::{bool_t, qb_t, usize_t}, SignatureFunc, }, std_extensions::arithmetic::int_types::INT_TYPES, @@ -364,19 +313,24 @@ mod test { use super::*; + /// Unwrap the replacement type's `OpDef` from the return type of `resolve_op_definition`. + fn resolve_res_definition(res: &OpType) -> &OpDef { + res.as_extension_op().unwrap().def() + } + #[test] fn new_opaque_op() { - let sig = Signature::new_endo(vec![QB_T]); + let sig = Signature::new_endo(vec![qb_t()]); let op = OpaqueOp::new( "res".try_into().unwrap(), "op", "desc".into(), - vec![TypeArg::Type { ty: USIZE_T }], + vec![TypeArg::Type { ty: usize_t() }], sig.clone(), ); assert_eq!(op.name(), "res.op"); assert_eq!(DataflowOpTrait::description(&op), "desc"); - assert_eq!(op.args(), &[TypeArg::Type { ty: USIZE_T }]); + assert_eq!(op.args(), &[TypeArg::Type { ty: usize_t() }]); assert_eq!( op.signature(), sig.with_extension_delta(op.extension().clone()) @@ -392,36 +346,44 @@ mod test { "itobool", "description".into(), vec![], - Signature::new(i0.clone(), BOOL_T), + Signature::new(i0.clone(), bool_t()), ); - let resolved = - super::resolve_opaque_op(Node::from(portgraph::NodeIndex::new(1)), &opaque, registry) - .unwrap(); - assert_eq!(resolved.def().name(), "itobool"); + let mut resolved = opaque.into(); + update_op_extensions( + Node::from(portgraph::NodeIndex::new(1)), + &mut resolved, + registry, + ) + .unwrap(); + assert_eq!(resolve_res_definition(&resolved).name(), "itobool"); } #[test] fn resolve_missing() { - let mut ext = Extension::new_test("ext".try_into().unwrap()); - let ext_id = ext.name().clone(); let val_name = "missing_val"; let comp_name = "missing_comp"; + let endo_sig = Signature::new_endo(bool_t()); + + let ext = Extension::new_test_arc("ext".try_into().unwrap(), |ext, extension_ref| { + ext.add_op( + val_name.into(), + "".to_string(), + SignatureFunc::MissingValidateFunc(FuncValueType::from(endo_sig.clone()).into()), + extension_ref, + ) + .unwrap(); + + ext.add_op( + comp_name.into(), + "".to_string(), + SignatureFunc::MissingComputeFunc, + extension_ref, + ) + .unwrap(); + }); + let ext_id = ext.name().clone(); - let endo_sig = Signature::new_endo(BOOL_T); - ext.add_op( - val_name.into(), - "".to_string(), - SignatureFunc::MissingValidateFunc(FuncValueType::from(endo_sig.clone()).into()), - ) - .unwrap(); - - ext.add_op( - comp_name.into(), - "".to_string(), - SignatureFunc::MissingComputeFunc, - ) - .unwrap(); - let registry = ExtensionRegistry::try_new([ext.into()]).unwrap(); + let registry = ExtensionRegistry::try_new([ext]).unwrap(); let opaque_val = OpaqueOp::new( ext_id.clone(), val_name, @@ -430,20 +392,22 @@ mod test { endo_sig.clone(), ); let opaque_comp = OpaqueOp::new(ext_id.clone(), comp_name, "".into(), vec![], endo_sig); - let resolved_val = super::resolve_opaque_op( + let mut resolved_val = opaque_val.into(); + update_op_extensions( Node::from(portgraph::NodeIndex::new(1)), - &opaque_val, + &mut resolved_val, ®istry, ) .unwrap(); - assert_eq!(resolved_val.def().name(), val_name); + assert_eq!(resolve_res_definition(&resolved_val).name(), val_name); - let resolved_comp = super::resolve_opaque_op( + let mut resolved_comp = opaque_comp.into(); + update_op_extensions( Node::from(portgraph::NodeIndex::new(2)), - &opaque_comp, + &mut resolved_comp, ®istry, ) .unwrap(); - assert_eq!(resolved_comp.def().name(), comp_name); + assert_eq!(resolve_res_definition(&resolved_comp).name(), comp_name); } } diff --git a/hugr-core/src/ops/dataflow.rs b/hugr-core/src/ops/dataflow.rs index 364429784..ed9a9e118 100644 --- a/hugr-core/src/ops/dataflow.rs +++ b/hugr-core/src/ops/dataflow.rs @@ -221,9 +221,9 @@ impl Call { /// # use hugr::ops::dataflow::Call; /// # use hugr::ops::OpType; /// # use hugr::types::Signature; - /// # use hugr::extension::prelude::QB_T; + /// # use hugr::extension::prelude::qb_t; /// # use hugr::extension::PRELUDE_REGISTRY; - /// let signature = Signature::new(vec![QB_T, QB_T], vec![QB_T, QB_T]); + /// let signature = Signature::new(vec![qb_t(), qb_t()], vec![qb_t(), qb_t()]); /// let call = Call::try_new(signature.into(), &[], &PRELUDE_REGISTRY).unwrap(); /// let op = OpType::Call(call.clone()); /// assert_eq!(op.static_input_port(), Some(call.called_function_port())); diff --git a/hugr-core/src/ops/validate.rs b/hugr-core/src/ops/validate.rs index ce67bb9a0..16eed59e6 100644 --- a/hugr-core/src/ops/validate.rs +++ b/hugr-core/src/ops/validate.rs @@ -351,21 +351,21 @@ fn validate_cfg_edge(edge: ChildrenEdgeData) -> Result<(), EdgeValidationError> #[cfg(test)] mod test { - use crate::extension::prelude::{Noop, USIZE_T}; + use crate::extension::prelude::{usize_t, Noop}; + use crate::ops; use crate::ops::dataflow::IOTrait; - use crate::{ops, type_row}; use cool_asserts::assert_matches; use super::*; #[test] fn test_validate_io_nodes() { - let in_types: TypeRow = type_row![USIZE_T]; - let out_types: TypeRow = type_row![USIZE_T, USIZE_T]; + let in_types: TypeRow = vec![usize_t()].into(); + let out_types: TypeRow = vec![usize_t(), usize_t()].into(); let input_node: OpType = ops::Input::new(in_types.clone()).into(); let output_node = ops::Output::new(out_types.clone()).into(); - let leaf_node = Noop(USIZE_T).into(); + let leaf_node = Noop(usize_t()).into(); // Well-formed dataflow sibling nodes. Check the input and output node signatures. let children = vec![ diff --git a/hugr-core/src/package.rs b/hugr-core/src/package.rs index cec1b2c85..aa32f24d5 100644 --- a/hugr-core/src/package.rs +++ b/hugr-core/src/package.rs @@ -99,7 +99,7 @@ impl Package { reg: &mut ExtensionRegistry, ) -> Result<(), PackageValidationError> { for ext in &self.extensions { - reg.register_updated_ref(ext)?; + reg.register_updated_ref(ext); } for hugr in self.modules.iter_mut() { hugr.update_validate(reg)?; diff --git a/hugr-core/src/std_extensions/arithmetic/conversions.rs b/hugr-core/src/std_extensions/arithmetic/conversions.rs index deb93f8c2..b8b2771c2 100644 --- a/hugr-core/src/std_extensions/arithmetic/conversions.rs +++ b/hugr-core/src/std_extensions/arithmetic/conversions.rs @@ -1,28 +1,24 @@ //! Conversions between integer and floating-point values. -use std::sync::Arc; +use std::sync::{Arc, Weak}; use strum_macros::{EnumIter, EnumString, IntoStaticStr}; -use crate::extension::prelude::{BOOL_T, STRING_TYPE, USIZE_T}; +use crate::extension::prelude::sum_with_error; +use crate::extension::prelude::{bool_t, string_type, usize_t}; use crate::extension::simple_op::{HasConcrete, HasDef}; +use crate::extension::simple_op::{MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError}; +use crate::extension::{ + ExtensionId, ExtensionRegistry, ExtensionSet, OpDef, SignatureError, SignatureFunc, PRELUDE, +}; use crate::ops::OpName; +use crate::ops::{custom::ExtensionOp, NamedOp}; use crate::std_extensions::arithmetic::int_ops::int_polytype; use crate::std_extensions::arithmetic::int_types::int_type; -use crate::{ - extension::{ - prelude::sum_with_error, - simple_op::{MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError}, - ExtensionId, ExtensionRegistry, ExtensionSet, OpDef, SignatureError, SignatureFunc, - PRELUDE, - }, - ops::{custom::ExtensionOp, NamedOp}, - type_row, - types::{TypeArg, TypeRV}, - Extension, -}; +use crate::types::{TypeArg, TypeRV}; +use crate::Extension; -use super::float_types::FLOAT64_TYPE; +use super::float_types::float64_type; use super::int_types::{get_log_width, int_tv}; use lazy_static::lazy_static; mod const_fold; @@ -50,27 +46,31 @@ pub enum ConvertOpDef { impl MakeOpDef for ConvertOpDef { fn from_def(op_def: &OpDef) -> Result { - crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension()) + crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id()) } fn extension(&self) -> ExtensionId { EXTENSION_ID.to_owned() } - fn signature(&self) -> SignatureFunc { + fn extension_ref(&self) -> Weak { + Arc::downgrade(&EXTENSION) + } + + fn init_signature(&self, _extension_ref: &Weak) -> SignatureFunc { use ConvertOpDef::*; match self { trunc_s | trunc_u => int_polytype( 1, - type_row![FLOAT64_TYPE], + vec![float64_type()], TypeRV::from(sum_with_error(int_tv(0))), ), - convert_s | convert_u => int_polytype(1, vec![int_tv(0)], type_row![FLOAT64_TYPE]), - itobool => int_polytype(0, vec![int_type(0)], vec![BOOL_T]), - ifrombool => int_polytype(0, vec![BOOL_T], vec![int_type(0)]), - itostring_u | itostring_s => int_polytype(1, vec![int_tv(0)], vec![STRING_TYPE]), - itousize => int_polytype(0, vec![int_type(6)], vec![USIZE_T]), - ifromusize => int_polytype(0, vec![USIZE_T], vec![int_type(6)]), + convert_s | convert_u => int_polytype(1, vec![int_tv(0)], vec![float64_type()]), + itobool => int_polytype(0, vec![int_type(0)], vec![bool_t()]), + ifrombool => int_polytype(0, vec![bool_t()], vec![int_type(0)]), + itostring_u | itostring_s => int_polytype(1, vec![int_tv(0)], vec![string_type()]), + itousize => int_polytype(0, vec![int_type(6)], vec![usize_t()]), + ifromusize => int_polytype(0, vec![usize_t()], vec![int_type(6)]), } .into() } @@ -158,18 +158,15 @@ impl MakeExtensionOp for ConvertOpType { lazy_static! { /// Extension for conversions between integers and floats. pub static ref EXTENSION: Arc = { - let mut extension = Extension::new( - EXTENSION_ID, - VERSION).with_reqs( - ExtensionSet::from_iter(vec![ - super::int_types::EXTENSION_ID, - super::float_types::EXTENSION_ID, - ]), - ); - - ConvertOpDef::load_all_ops(&mut extension).unwrap(); - - Arc::new(extension) + Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| { + extension.add_requirements( + ExtensionSet::from_iter(vec![ + super::int_types::EXTENSION_ID, + super::float_types::EXTENSION_ID, + ])); + + ConvertOpDef::load_all_ops(extension, extension_ref).unwrap(); + }) }; /// Registry of extensions required to validate integer operations. diff --git a/hugr-core/src/std_extensions/arithmetic/conversions/const_fold.rs b/hugr-core/src/std_extensions/arithmetic/conversions/const_fold.rs index e9c736858..8ef7bb63d 100644 --- a/hugr-core/src/std_extensions/arithmetic/conversions/const_fold.rs +++ b/hugr-core/src/std_extensions/arithmetic/conversions/const_fold.rs @@ -4,7 +4,7 @@ use crate::ops::Value; use crate::std_extensions::arithmetic::int_types::INT_TYPES; use crate::{ extension::{ - prelude::{const_ok, ConstError, ERROR_TYPE}, + prelude::{const_ok, error_type, ConstError}, ConstFold, ConstFoldResult, OpDef, }, ops, @@ -59,7 +59,7 @@ fn fold_trunc( } else { let cv = convert(f, log_width); if let Ok(cv) = cv { - const_ok(cv, ERROR_TYPE) + const_ok(cv, error_type()) } else { err_value() } diff --git a/hugr-core/src/std_extensions/arithmetic/float_ops.rs b/hugr-core/src/std_extensions/arithmetic/float_ops.rs index 7d353e71a..dad35d3c7 100644 --- a/hugr-core/src/std_extensions/arithmetic/float_ops.rs +++ b/hugr-core/src/std_extensions/arithmetic/float_ops.rs @@ -1,17 +1,16 @@ //! Basic floating-point operations. -use std::sync::Arc; +use std::sync::{Arc, Weak}; use strum_macros::{EnumIter, EnumString, IntoStaticStr}; -use super::float_types::FLOAT64_TYPE; +use super::float_types::float64_type; use crate::{ extension::{ - prelude::{BOOL_T, STRING_TYPE}, + prelude::{bool_t, string_type}, simple_op::{MakeOpDef, MakeRegisteredOp, OpLoadError}, ExtensionId, ExtensionRegistry, ExtensionSet, OpDef, SignatureFunc, PRELUDE, }, - type_row, types::Signature, Extension, }; @@ -50,25 +49,29 @@ pub enum FloatOps { impl MakeOpDef for FloatOps { fn from_def(op_def: &OpDef) -> Result { - crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension()) + crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id()) } fn extension(&self) -> ExtensionId { EXTENSION_ID.to_owned() } - fn signature(&self) -> SignatureFunc { + fn extension_ref(&self) -> Weak { + Arc::downgrade(&EXTENSION) + } + + fn init_signature(&self, _extension_ref: &Weak) -> SignatureFunc { use FloatOps::*; match self { feq | fne | flt | fgt | fle | fge => { - Signature::new(type_row![FLOAT64_TYPE; 2], type_row![BOOL_T]) + Signature::new(vec![float64_type(); 2], vec![bool_t()]) } fmax | fmin | fadd | fsub | fmul | fdiv | fpow => { - Signature::new(type_row![FLOAT64_TYPE; 2], type_row![FLOAT64_TYPE]) + Signature::new(vec![float64_type(); 2], vec![float64_type()]) } - fneg | fabs | ffloor | fceil | fround => Signature::new_endo(type_row![FLOAT64_TYPE]), - ftostring => Signature::new(type_row![FLOAT64_TYPE], STRING_TYPE), + fneg | fabs | ffloor | fceil | fround => Signature::new_endo(vec![float64_type()]), + ftostring => Signature::new(vec![float64_type()], string_type()), } .into() } @@ -107,15 +110,10 @@ impl MakeOpDef for FloatOps { lazy_static! { /// Extension for basic float operations. pub static ref EXTENSION: Arc = { - let mut extension = Extension::new( - EXTENSION_ID, - VERSION).with_reqs( - ExtensionSet::singleton(&super::int_types::EXTENSION_ID), - ); - - FloatOps::load_all_ops(&mut extension).unwrap(); - - Arc::new(extension) + Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| { + extension.add_requirements(ExtensionSet::singleton(&super::int_types::EXTENSION_ID)); + FloatOps::load_all_ops(extension, extension_ref).unwrap(); + }) }; /// Registry of extensions required to validate float operations. diff --git a/hugr-core/src/std_extensions/arithmetic/float_types.rs b/hugr-core/src/std_extensions/arithmetic/float_types.rs index ec145008f..304f44899 100644 --- a/hugr-core/src/std_extensions/arithmetic/float_types.rs +++ b/hugr-core/src/std_extensions/arithmetic/float_types.rs @@ -1,6 +1,6 @@ //! Basic floating-point types -use std::sync::Arc; +use std::sync::{Arc, Weak}; use crate::ops::constant::{TryHash, ValueName}; use crate::types::TypeName; @@ -18,14 +18,23 @@ pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.flo pub const VERSION: semver::Version = semver::Version::new(0, 1, 0); /// Identifier for the 64-bit IEEE 754-2019 floating-point type. -const FLOAT_TYPE_ID: TypeName = TypeName::new_inline("float64"); +pub const FLOAT_TYPE_ID: TypeName = TypeName::new_inline("float64"); /// 64-bit IEEE 754-2019 floating-point type (as [CustomType]) -pub const FLOAT64_CUSTOM_TYPE: CustomType = - CustomType::new_simple(FLOAT_TYPE_ID, EXTENSION_ID, TypeBound::Copyable); +pub fn float64_custom_type(extension_ref: &Weak) -> CustomType { + CustomType::new( + FLOAT_TYPE_ID, + vec![], + EXTENSION_ID, + TypeBound::Copyable, + extension_ref, + ) +} /// 64-bit IEEE 754-2019 floating-point type (as [Type]) -pub const FLOAT64_TYPE: Type = Type::new_extension(FLOAT64_CUSTOM_TYPE); +pub fn float64_type() -> Type { + float64_custom_type(&Arc::downgrade(&EXTENSION)).into() +} #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] /// A floating-point value. @@ -67,7 +76,7 @@ impl CustomConst for ConstF64 { } fn get_type(&self) -> Type { - FLOAT64_TYPE + float64_type() } fn equal_consts(&self, _: &dyn CustomConst) -> bool { @@ -82,18 +91,17 @@ impl CustomConst for ConstF64 { lazy_static! { /// Extension defining the float type. pub static ref EXTENSION: Arc = { - let mut extension = Extension::new(EXTENSION_ID, VERSION); - - extension - .add_type( - FLOAT_TYPE_ID, - vec![], - "64-bit IEEE 754-2019 floating-point value".to_owned(), - TypeBound::Copyable.into(), - ) - .unwrap(); - - Arc::new(extension) + Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| { + extension + .add_type( + FLOAT_TYPE_ID, + vec![], + "64-bit IEEE 754-2019 floating-point value".to_owned(), + TypeBound::Copyable.into(), + extension_ref, + ) + .unwrap(); + }) }; } #[cfg(test)] diff --git a/hugr-core/src/std_extensions/arithmetic/int_ops.rs b/hugr-core/src/std_extensions/arithmetic/int_ops.rs index 97bb247a2..4f6332767 100644 --- a/hugr-core/src/std_extensions/arithmetic/int_ops.rs +++ b/hugr-core/src/std_extensions/arithmetic/int_ops.rs @@ -1,9 +1,9 @@ //! Basic integer operations. -use std::sync::Arc; +use std::sync::{Arc, Weak}; use super::int_types::{get_log_width, int_tv, LOG_WIDTH_TYPE_PARAM}; -use crate::extension::prelude::{sum_with_error, BOOL_T}; +use crate::extension::prelude::{bool_t, sum_with_error}; use crate::extension::simple_op::{ HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError, }; @@ -12,7 +12,6 @@ use crate::extension::{ }; use crate::ops::custom::ExtensionOp; use crate::ops::{NamedOp, OpName}; -use crate::type_row; use crate::types::{FuncValueType, PolyFuncTypeRV, TypeRowRV}; use crate::utils::collect_array; @@ -104,14 +103,18 @@ pub enum IntOpDef { impl MakeOpDef for IntOpDef { fn from_def(op_def: &OpDef) -> Result { - crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension()) + crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id()) } fn extension(&self) -> ExtensionId { EXTENSION_ID.to_owned() } - fn signature(&self) -> SignatureFunc { + fn extension_ref(&self) -> Weak { + Arc::downgrade(&EXTENSION) + } + + fn init_signature(&self, _extension_ref: &Weak) -> SignatureFunc { use IntOpDef::*; let tv0 = int_tv(0); match self { @@ -126,7 +129,7 @@ impl MakeOpDef for IntOpDef { ) .into(), ieq | ine | ilt_u | ilt_s | igt_u | igt_s | ile_u | ile_s | ige_u | ige_s => { - int_polytype(1, vec![tv0; 2], type_row![BOOL_T]).into() + int_polytype(1, vec![tv0; 2], vec![bool_t()]).into() } imax_u | imax_s | imin_u | imin_s | iadd | isub | imul | iand | ior | ixor | ipow => { ibinop_sig().into() @@ -250,15 +253,10 @@ fn iunop_sig() -> PolyFuncTypeRV { lazy_static! { /// Extension for basic integer operations. pub static ref EXTENSION: Arc = { - let mut extension = Extension::new( - EXTENSION_ID, - VERSION).with_reqs( - ExtensionSet::singleton(&super::int_types::EXTENSION_ID) - ); - - IntOpDef::load_all_ops(&mut extension).unwrap(); - - Arc::new(extension) + Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| { + extension.add_requirements(ExtensionSet::singleton(&super::int_types::EXTENSION_ID)); + IntOpDef::load_all_ops(extension, extension_ref).unwrap(); + }) }; /// Registry of extensions required to validate integer operations. diff --git a/hugr-core/src/std_extensions/arithmetic/int_types.rs b/hugr-core/src/std_extensions/arithmetic/int_types.rs index 3d257b9d0..db52c6576 100644 --- a/hugr-core/src/std_extensions/arithmetic/int_types.rs +++ b/hugr-core/src/std_extensions/arithmetic/int_types.rs @@ -1,7 +1,7 @@ //! Basic integer types use std::num::NonZeroU64; -use std::sync::Arc; +use std::sync::{Arc, Weak}; use crate::ops::constant::ValueName; use crate::types::TypeName; @@ -26,12 +26,16 @@ pub const INT_TYPE_ID: TypeName = TypeName::new_inline("int"); /// Integer type of a given bit width (specified by the TypeArg). Depending on /// the operation, the semantic interpretation may be unsigned integer, signed /// integer or bit string. -pub fn int_custom_type(width_arg: impl Into) -> CustomType { +pub fn int_custom_type( + width_arg: impl Into, + extension_ref: &Weak, +) -> CustomType { CustomType::new( INT_TYPE_ID, [width_arg.into()], EXTENSION_ID, TypeBound::Copyable, + extension_ref, ) } @@ -39,7 +43,7 @@ pub fn int_custom_type(width_arg: impl Into) -> CustomType { /// /// Constructed from [int_custom_type]. pub fn int_type(width_arg: impl Into) -> Type { - Type::new_extension(int_custom_type(width_arg.into())) + int_custom_type(width_arg.into(), &Arc::::downgrade(&EXTENSION)).into() } lazy_static! { @@ -187,19 +191,18 @@ impl CustomConst for ConstInt { } /// Extension for basic integer types. -pub fn extension() -> Arc { - let mut extension = Extension::new(EXTENSION_ID, VERSION); - - extension - .add_type( - INT_TYPE_ID, - vec![LOG_WIDTH_TYPE_PARAM], - "integral value of a given bit width".to_owned(), - TypeBound::Copyable.into(), - ) - .unwrap(); - - Arc::new(extension) +fn extension() -> Arc { + Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| { + extension + .add_type( + INT_TYPE_ID, + vec![LOG_WIDTH_TYPE_PARAM], + "integral value of a given bit width".to_owned(), + TypeBound::Copyable.into(), + extension_ref, + ) + .unwrap(); + }) } lazy_static! { diff --git a/hugr-core/src/std_extensions/collections.rs b/hugr-core/src/std_extensions/collections.rs index 17a1b0d03..492edf428 100644 --- a/hugr-core/src/std_extensions/collections.rs +++ b/hugr-core/src/std_extensions/collections.rs @@ -5,14 +5,14 @@ use std::hash::{Hash, Hasher}; mod list_fold; use std::str::FromStr; -use std::sync::Arc; +use std::sync::{Arc, Weak}; use itertools::Itertools; use lazy_static::lazy_static; use serde::{Deserialize, Serialize}; use strum_macros::{EnumIter, EnumString, IntoStaticStr}; -use crate::extension::prelude::{either_type, option_type, USIZE_T}; +use crate::extension::prelude::{either_type, option_type, usize_t}; use crate::extension::simple_op::{MakeOpDef, MakeRegisteredOp}; use crate::extension::{ExtensionBuildError, OpDef, SignatureFunc, PRELUDE}; use crate::ops::constant::{maybe_hash_values, TryHash, ValueName}; @@ -59,6 +59,16 @@ impl ListValue { pub fn custom_type(&self) -> CustomType { list_custom_type(self.1.clone()) } + + /// Returns the type of values inside the `[ListValue]`. + pub fn get_element_type(&self) -> &Type { + &self.1 + } + + /// Returns the values contained inside the `[ListValue]`. + pub fn get_contents(&self) -> &[Value] { + &self.0 + } } impl TryHash for ListValue { @@ -165,21 +175,23 @@ impl ListOp { .into(), push => self.list_polytype(vec![l.clone(), e], vec![l]).into(), get => self - .list_polytype(vec![l, USIZE_T], vec![Type::from(option_type(e))]) + .list_polytype(vec![l, usize_t()], vec![Type::from(option_type(e))]) .into(), set => self .list_polytype( - vec![l.clone(), USIZE_T, e.clone()], + vec![l.clone(), usize_t(), e.clone()], vec![l, Type::from(either_type(e.clone(), e))], ) .into(), insert => self .list_polytype( - vec![l.clone(), USIZE_T, e.clone()], + vec![l.clone(), usize_t(), e.clone()], vec![l, either_type(e, Type::UNIT).into()], ) .into(), - length => self.list_polytype(vec![l.clone()], vec![l, USIZE_T]).into(), + length => self + .list_polytype(vec![l.clone()], vec![l, usize_t()]) + .into(), } } @@ -204,28 +216,36 @@ impl ListOp { impl MakeOpDef for ListOp { fn from_def(op_def: &OpDef) -> Result { - crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension()) + crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id()) } fn extension(&self) -> ExtensionId { EXTENSION_ID.to_owned() } + fn extension_ref(&self) -> Weak { + Arc::downgrade(&EXTENSION) + } + /// Add an operation implemented as an [MakeOpDef], which can provide the data /// required to define an [OpDef], to an extension. // // This method is re-defined here since we need to pass the list type def while computing the signature, // to avoid recursive loops initializing the extension. - fn add_to_extension(&self, extension: &mut Extension) -> Result<(), ExtensionBuildError> { + fn add_to_extension( + &self, + extension: &mut Extension, + extension_ref: &Weak, + ) -> Result<(), ExtensionBuildError> { let sig = self.compute_signature(extension.get_type(&LIST_TYPENAME).unwrap()); - let def = extension.add_op(self.name(), self.description(), sig)?; + let def = extension.add_op(self.name(), self.description(), sig, extension_ref)?; self.post_opdef(def); Ok(()) } - fn signature(&self) -> SignatureFunc { + fn init_signature(&self, _extension_ref: &Weak) -> SignatureFunc { self.compute_signature(list_type_def()) } @@ -251,20 +271,19 @@ impl MakeOpDef for ListOp { lazy_static! { /// Extension for list operations. pub static ref EXTENSION: Arc = { - let mut extension = Extension::new(EXTENSION_ID, VERSION); - - // The list type must be defined before the operations are added. - extension.add_type( - LIST_TYPENAME, - vec![ListOp::TP], - "Generic dynamically sized list of type T.".into(), - TypeDefBound::from_params(vec![0]), - ) - .unwrap(); - - ListOp::load_all_ops(&mut extension).unwrap(); + Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| { + extension.add_type( + LIST_TYPENAME, + vec![ListOp::TP], + "Generic dynamically sized list of type T.".into(), + TypeDefBound::from_params(vec![0]), + extension_ref + ) + .unwrap(); - Arc::new(extension) + // The list type must be defined before the operations are added. + ListOp::load_all_ops(extension, extension_ref).unwrap(); + }) }; /// Registry of extensions required to validate list operations. @@ -353,7 +372,7 @@ impl ListOpInst { .clone() .into_iter() // ignore self if already in registry - .filter_map(|(_, ext)| (ext.name() != EXTENSION.name()).then_some(ext)) + .filter(|ext| ext.name() != EXTENSION.name()) .chain(std::iter::once(EXTENSION.to_owned())), ) .unwrap(); @@ -377,10 +396,10 @@ mod test { use crate::PortIndex; use crate::{ extension::{ - prelude::{ConstUsize, QB_T, USIZE_T}, + prelude::{qb_t, usize_t, ConstUsize}, PRELUDE, }, - std_extensions::arithmetic::float_types::{self, ConstF64, FLOAT64_TYPE}, + std_extensions::arithmetic::float_types::{self, float64_type, ConstF64}, types::TypeRow, }; @@ -392,7 +411,7 @@ mod test { assert_eq!(&ListOp::push.extension(), EXTENSION.name()); assert!(ListOp::pop.registry().contains(EXTENSION.name())); for (_, op_def) in EXTENSION.operations() { - assert_eq!(op_def.extension(), &EXTENSION_ID); + assert_eq!(op_def.extension_id(), &EXTENSION_ID); } } @@ -401,7 +420,7 @@ mod test { let list_def = list_type_def(); let list_type = list_def - .instantiate([TypeArg::Type { ty: USIZE_T }]) + .instantiate([TypeArg::Type { ty: usize_t() }]) .unwrap(); assert!(list_def @@ -409,11 +428,11 @@ mod test { .is_err()); list_def.check_custom(&list_type).unwrap(); - let list_value = ListValue(vec![ConstUsize::new(3).into()], USIZE_T); + let list_value = ListValue(vec![ConstUsize::new(3).into()], usize_t()); list_value.validate().unwrap(); - let wrong_list_value = ListValue(vec![ConstF64::new(1.2).into()], USIZE_T); + let wrong_list_value = ListValue(vec![ConstF64::new(1.2).into()], usize_t()); assert!(wrong_list_value.validate().is_err()); } @@ -422,26 +441,26 @@ mod test { let reg = ExtensionRegistry::try_new([PRELUDE.to_owned(), float_types::EXTENSION.to_owned()]) .unwrap(); - let pop_op = ListOp::pop.with_type(QB_T); + let pop_op = ListOp::pop.with_type(qb_t()); let pop_ext = pop_op.clone().to_extension_op(®).unwrap(); assert_eq!(ListOpInst::from_extension_op(&pop_ext).unwrap(), pop_op); let pop_sig = pop_ext.dataflow_signature().unwrap(); - let list_t = list_type(QB_T); + let list_t = list_type(qb_t()); - let both_row: TypeRow = vec![list_t.clone(), option_type(QB_T).into()].into(); + let both_row: TypeRow = vec![list_t.clone(), option_type(qb_t()).into()].into(); let just_list_row: TypeRow = vec![list_t].into(); assert_eq!(pop_sig.input(), &just_list_row); assert_eq!(pop_sig.output(), &both_row); - let push_op = ListOp::push.with_type(FLOAT64_TYPE); + let push_op = ListOp::push.with_type(float64_type()); let push_ext = push_op.clone().to_extension_op(®).unwrap(); assert_eq!(ListOpInst::from_extension_op(&push_ext).unwrap(), push_op); let push_sig = push_ext.dataflow_signature().unwrap(); - let list_t = list_type(FLOAT64_TYPE); + let list_t = list_type(float64_type()); - let both_row: TypeRow = vec![list_t.clone(), FLOAT64_TYPE].into(); + let both_row: TypeRow = vec![list_t.clone(), float64_type()].into(); let just_list_row: TypeRow = vec![list_t].into(); assert_eq!(push_sig.input(), &both_row); @@ -470,7 +489,7 @@ mod test { .iter() .map(|&i| Value::extension(ConstUsize::new(i as u64))) .collect(); - Value::extension(ListValue(elems, USIZE_T)) + Value::extension(ListValue(elems, usize_t())) } TestVal::Some(l) => { let elems = l.iter().map(TestVal::to_value); @@ -491,13 +510,13 @@ mod test { #[rstest] #[case::pop(ListOp::pop, &[TestVal::List(vec![77,88, 42])], &[TestVal::List(vec![77,88]), TestVal::Some(vec![TestVal::Elem(42)])])] - #[case::pop_empty(ListOp::pop, &[TestVal::List(vec![])], &[TestVal::List(vec![]), TestVal::None(vec![USIZE_T].into())])] + #[case::pop_empty(ListOp::pop, &[TestVal::List(vec![])], &[TestVal::List(vec![]), TestVal::None(vec![usize_t()].into())])] #[case::push(ListOp::push, &[TestVal::List(vec![77,88]), TestVal::Elem(42)], &[TestVal::List(vec![77,88,42])])] - #[case::set(ListOp::set, &[TestVal::List(vec![77,88,42]), TestVal::Idx(1), TestVal::Elem(99)], &[TestVal::List(vec![77,99,42]), TestVal::Ok(vec![TestVal::Elem(88)], vec![USIZE_T].into())])] - #[case::set_invalid(ListOp::set, &[TestVal::List(vec![77,88,42]), TestVal::Idx(123), TestVal::Elem(99)], &[TestVal::List(vec![77,88,42]), TestVal::Err(vec![USIZE_T].into(), vec![TestVal::Elem(99)])])] + #[case::set(ListOp::set, &[TestVal::List(vec![77,88,42]), TestVal::Idx(1), TestVal::Elem(99)], &[TestVal::List(vec![77,99,42]), TestVal::Ok(vec![TestVal::Elem(88)], vec![usize_t()].into())])] + #[case::set_invalid(ListOp::set, &[TestVal::List(vec![77,88,42]), TestVal::Idx(123), TestVal::Elem(99)], &[TestVal::List(vec![77,88,42]), TestVal::Err(vec![usize_t()].into(), vec![TestVal::Elem(99)])])] #[case::get(ListOp::get, &[TestVal::List(vec![77,88,42]), TestVal::Idx(1)], &[TestVal::Some(vec![TestVal::Elem(88)])])] - #[case::get_invalid(ListOp::get, &[TestVal::List(vec![77,88,42]), TestVal::Idx(99)], &[TestVal::None(vec![USIZE_T].into())])] - #[case::insert(ListOp::insert, &[TestVal::List(vec![77,88,42]), TestVal::Idx(1), TestVal::Elem(99)], &[TestVal::List(vec![77,99,88,42]), TestVal::Ok(vec![], vec![USIZE_T].into())])] + #[case::get_invalid(ListOp::get, &[TestVal::List(vec![77,88,42]), TestVal::Idx(99)], &[TestVal::None(vec![usize_t()].into())])] + #[case::insert(ListOp::insert, &[TestVal::List(vec![77,88,42]), TestVal::Idx(1), TestVal::Elem(99)], &[TestVal::List(vec![77,99,88,42]), TestVal::Ok(vec![], vec![usize_t()].into())])] #[case::insert_invalid(ListOp::insert, &[TestVal::List(vec![77,88,42]), TestVal::Idx(52), TestVal::Elem(99)], &[TestVal::List(vec![77,88,42]), TestVal::Err(Type::UNIT.into(), vec![TestVal::Elem(99)])])] #[case::length(ListOp::length, &[TestVal::List(vec![77,88,42])], &[TestVal::Elem(3)])] fn list_fold(#[case] op: ListOp, #[case] inputs: &[TestVal], #[case] outputs: &[TestVal]) { @@ -508,7 +527,7 @@ mod test { .collect(); let res = op - .with_type(USIZE_T) + .with_type(usize_t()) .to_extension_op(&COLLECTIONS_REGISTRY) .unwrap() .constant_fold(&consts) diff --git a/hugr-core/src/std_extensions/logic.rs b/hugr-core/src/std_extensions/logic.rs index 89f9dfa8b..a4f73ed4f 100644 --- a/hugr-core/src/std_extensions/logic.rs +++ b/hugr-core/src/std_extensions/logic.rs @@ -1,6 +1,6 @@ //! Basic logical operations. -use std::sync::Arc; +use std::sync::{Arc, Weak}; use strum_macros::{EnumIter, EnumString, IntoStaticStr}; @@ -10,11 +10,11 @@ use crate::ops::Value; use crate::types::Signature; use crate::{ extension::{ - prelude::BOOL_T, + prelude::bool_t, simple_op::{try_from_name, MakeOpDef, MakeRegisteredOp, OpLoadError}, ExtensionId, ExtensionRegistry, OpDef, SignatureFunc, }, - ops, type_row, + ops, types::type_param::TypeArg, utils::sorted_consts, Extension, IncomingPort, @@ -70,16 +70,20 @@ pub enum LogicOp { } impl MakeOpDef for LogicOp { - fn signature(&self) -> SignatureFunc { + fn init_signature(&self, _extension_ref: &Weak) -> SignatureFunc { match self { LogicOp::And | LogicOp::Or | LogicOp::Eq => { - Signature::new(type_row![BOOL_T; 2], type_row![BOOL_T]) + Signature::new(vec![bool_t(); 2], vec![bool_t()]) } - LogicOp::Not => Signature::new_endo(type_row![BOOL_T]), + LogicOp::Not => Signature::new_endo(vec![bool_t()]), } .into() } + fn extension_ref(&self) -> Weak { + Arc::downgrade(&EXTENSION) + } + fn description(&self) -> String { match self { LogicOp::And => "logical 'and'", @@ -91,7 +95,7 @@ impl MakeOpDef for LogicOp { } fn from_def(op_def: &OpDef) -> Result { - try_from_name(op_def.name(), op_def.extension()) + try_from_name(op_def.name(), op_def.extension_id()) } fn extension(&self) -> ExtensionId { @@ -110,16 +114,16 @@ pub const VERSION: semver::Version = semver::Version::new(0, 1, 0); /// Extension for basic logical operations. fn extension() -> Arc { - let mut extension = Extension::new(EXTENSION_ID, VERSION); - LogicOp::load_all_ops(&mut extension).unwrap(); + Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| { + LogicOp::load_all_ops(extension, extension_ref).unwrap(); - extension - .add_value(FALSE_NAME, ops::Value::false_val()) - .unwrap(); - extension - .add_value(TRUE_NAME, ops::Value::true_val()) - .unwrap(); - Arc::new(extension) + extension + .add_value(FALSE_NAME, ops::Value::false_val()) + .unwrap(); + extension + .add_value(TRUE_NAME, ops::Value::true_val()) + .unwrap(); + }) } lazy_static! { @@ -166,7 +170,7 @@ pub(crate) mod test { use super::{extension, LogicOp, FALSE_NAME, TRUE_NAME}; use crate::{ extension::{ - prelude::BOOL_T, + prelude::bool_t, simple_op::{MakeOpDef, MakeRegisteredOp}, }, ops::{NamedOp, Value}, @@ -206,16 +210,16 @@ pub(crate) mod test { for v in [false_val, true_val] { let simpl = v.typed_value().get_type(); - assert_eq!(simpl, BOOL_T); + assert_eq!(simpl, bool_t()); } } - /// Generate a logic extension "and" operation over [`crate::prelude::BOOL_T`] + /// Generate a logic extension "and" operation over [`crate::prelude::bool_t()`] pub(crate) fn and_op() -> LogicOp { LogicOp::And } - /// Generate a logic extension "or" operation over [`crate::prelude::BOOL_T`] + /// Generate a logic extension "or" operation over [`crate::prelude::bool_t()`] pub(crate) fn or_op() -> LogicOp { LogicOp::Or } diff --git a/hugr-core/src/std_extensions/ptr.rs b/hugr-core/src/std_extensions/ptr.rs index e3023e4b5..774ea10eb 100644 --- a/hugr-core/src/std_extensions/ptr.rs +++ b/hugr-core/src/std_extensions/ptr.rs @@ -1,6 +1,6 @@ //! Pointer type and operations. -use std::sync::Arc; +use std::sync::{Arc, Weak}; use strum_macros::{EnumIter, EnumString, IntoStaticStr}; @@ -47,11 +47,12 @@ impl MakeOpDef for PtrOpDef { where Self: Sized, { - crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension()) + crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id()) } - fn signature(&self) -> SignatureFunc { - let ptr_t = ptr_type(Type::new_var_use(0, TypeBound::Copyable)); + fn init_signature(&self, extension_ref: &Weak) -> SignatureFunc { + let ptr_t: Type = + ptr_custom_type(Type::new_var_use(0, TypeBound::Copyable), extension_ref).into(); let inner_t = Type::new_var_use(0, TypeBound::Copyable); let body = match self { PtrOpDef::New => Signature::new(inner_t, ptr_t), @@ -66,6 +67,10 @@ impl MakeOpDef for PtrOpDef { EXTENSION_ID } + fn extension_ref(&self) -> Weak { + Arc::downgrade(&EXTENSION) + } + fn description(&self) -> String { match self { PtrOpDef::New => "Create a new pointer from a value.".into(), @@ -87,17 +92,18 @@ pub const VERSION: semver::Version = semver::Version::new(0, 1, 0); /// Extension for pointer operations. fn extension() -> Arc { - let mut extension = Extension::new(EXTENSION_ID, VERSION); - extension - .add_type( - PTR_TYPE_ID, - TYPE_PARAMS.into(), - "Standard extension pointer type.".into(), - TypeDefBound::copyable(), - ) - .unwrap(); - PtrOpDef::load_all_ops(&mut extension).unwrap(); - Arc::new(extension) + Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| { + extension + .add_type( + PTR_TYPE_ID, + TYPE_PARAMS.into(), + "Standard extension pointer type.".into(), + TypeDefBound::copyable(), + extension_ref, + ) + .unwrap(); + PtrOpDef::load_all_ops(extension, extension_ref).unwrap(); + }) } lazy_static! { @@ -111,16 +117,20 @@ lazy_static! { /// Integer type of a given bit width (specified by the TypeArg). Depending on /// the operation, the semantic interpretation may be unsigned integer, signed /// integer or bit string. -pub fn ptr_custom_type(ty: impl Into) -> CustomType { +fn ptr_custom_type(ty: impl Into, extension_ref: &Weak) -> CustomType { let ty = ty.into(); - CustomType::new(PTR_TYPE_ID, [ty.into()], EXTENSION_ID, TypeBound::Copyable) + CustomType::new( + PTR_TYPE_ID, + [ty.into()], + EXTENSION_ID, + TypeBound::Copyable, + extension_ref, + ) } /// Integer type of a given bit width (specified by the TypeArg). -/// -/// Constructed from [ptr_custom_type]. pub fn ptr_type(ty: impl Into) -> Type { - Type::new_extension(ptr_custom_type(ty)) + ptr_custom_type(ty, &Arc::::downgrade(&EXTENSION)).into() } #[derive(Clone, Debug, PartialEq)] @@ -214,7 +224,7 @@ impl HasDef for PtrOp { #[cfg(test)] pub(crate) mod test { use crate::builder::DFGBuilder; - use crate::extension::prelude::BOOL_T; + use crate::extension::prelude::bool_t; use crate::ops::ExtensionOp; use crate::{ builder::{Dataflow, DataflowHugr}, @@ -227,7 +237,7 @@ pub(crate) mod test { use super::*; use crate::std_extensions::arithmetic::float_types::{ - EXTENSION as FLOAT_EXTENSION, FLOAT64_TYPE, + float64_type, EXTENSION as FLOAT_EXTENSION, }; fn get_opdef(op: impl NamedOp) -> Option<&'static Arc> { EXTENSION.get_op(&op.name()) @@ -245,8 +255,8 @@ pub(crate) mod test { #[test] fn test_ops() { let ops = [ - PtrOp::new(PtrOpDef::New, BOOL_T.clone()), - PtrOp::new(PtrOpDef::Read, FLOAT64_TYPE.clone()), + PtrOp::new(PtrOpDef::New, bool_t().clone()), + PtrOp::new(PtrOpDef::Read, float64_type()), PtrOp::new(PtrOpDef::Write, INT_TYPES[5].clone()), ]; for op in ops { @@ -260,7 +270,7 @@ pub(crate) mod test { #[test] fn test_build() { - let in_row = vec![BOOL_T, FLOAT64_TYPE]; + let in_row = vec![bool_t(), float64_type()]; let reg = ExtensionRegistry::try_new([EXTENSION.to_owned(), FLOAT_EXTENSION.to_owned()]).unwrap(); diff --git a/hugr-core/src/types.rs b/hugr-core/src/types.rs index 72e752614..655531ab5 100644 --- a/hugr-core/src/types.rs +++ b/hugr-core/src/types.rs @@ -364,7 +364,7 @@ impl TypeBase { Self::new(TypeEnum::Alias(alias)) } - fn new(type_e: TypeEnum) -> Self { + pub(crate) fn new(type_e: TypeEnum) -> Self { let bound = type_e.least_upper_bound(); Self(type_e, bound) } @@ -395,6 +395,12 @@ impl TypeBase { &self.0 } + /// Report a mutable reference to the component TypeEnum. + #[inline(always)] + pub fn as_type_enum_mut(&mut self) -> &mut TypeEnum { + &mut self.0 + } + /// Report if the type is copyable - i.e.the least upper bound of the type /// is contained by the copyable bound. pub const fn copyable(&self) -> bool { @@ -613,20 +619,24 @@ pub(crate) fn check_typevar_decl( #[cfg(test)] pub(crate) mod test { + use std::sync::Weak; + use super::*; - use crate::extension::prelude::USIZE_T; + use crate::extension::prelude::usize_t; use crate::type_row; #[test] fn construct() { let t: Type = Type::new_tuple(vec![ - USIZE_T, + usize_t(), Type::new_function(Signature::new_endo(vec![])), Type::new_extension(CustomType::new( "my_custom", [], "my_extension".try_into().unwrap(), TypeBound::Copyable, + // Dummy extension reference. + &Weak::default(), )), Type::new_alias(AliasDecl::new("my_alias", TypeBound::Copyable)), ]); diff --git a/hugr-core/src/types/custom.rs b/hugr-core/src/types/custom.rs index e38bde9cf..22af9b77a 100644 --- a/hugr-core/src/types/custom.rs +++ b/hugr-core/src/types/custom.rs @@ -2,8 +2,10 @@ //! //! [`Type`]: super::Type use std::fmt::{self, Display}; +use std::sync::Weak; use crate::extension::{ExtensionId, ExtensionRegistry, SignatureError, TypeDef}; +use crate::Extension; use super::{ type_param::{TypeArg, TypeParam}, @@ -12,9 +14,13 @@ use super::{ use super::{Type, TypeName}; /// An opaque type element. Contains the unique identifier of its definition. -#[derive(Debug, PartialEq, Eq, Clone, Hash, serde::Serialize, serde::Deserialize)] +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub struct CustomType { + /// The identifier for the extension owning this type. extension: ExtensionId, + /// A weak reference to the extension defining this type. + #[serde(skip)] + extension_ref: Weak, /// Unique identifier of the opaque type. /// Same as the corresponding [`TypeDef`] /// @@ -28,6 +34,26 @@ pub struct CustomType { bound: TypeBound, } +impl std::hash::Hash for CustomType { + fn hash(&self, state: &mut H) { + self.extension.hash(state); + self.id.hash(state); + self.args.hash(state); + self.bound.hash(state); + } +} + +impl PartialEq for CustomType { + fn eq(&self, other: &Self) -> bool { + self.extension == other.extension + && self.id == other.id + && self.args == other.args + && self.bound == other.bound + } +} + +impl Eq for CustomType {} + impl CustomType { /// Creates a new opaque type. pub fn new( @@ -35,22 +61,14 @@ impl CustomType { args: impl Into>, extension: ExtensionId, bound: TypeBound, + extension_ref: &Weak, ) -> Self { Self { id: id.into(), args: args.into(), extension, bound, - } - } - - /// Creates a new opaque type (constant version, no type arguments) - pub const fn new_simple(id: TypeName, extension: ExtensionId, bound: TypeBound) -> Self { - Self { - id, - args: vec![], - extension, - bound, + extension_ref: extension_ref.clone(), } } @@ -80,7 +98,10 @@ impl CustomType { let ex = extension_registry.get(&self.extension); // Even if OpDef's (+binaries) are not available, the part of the Extension definition // describing the TypeDefs can easily be passed around (serialized), so should be available. - let ex = ex.ok_or(SignatureError::ExtensionNotFound(self.extension.clone()))?; + let ex = ex.ok_or(SignatureError::ExtensionNotFound { + missing: self.extension.clone(), + available: extension_registry.ids().cloned().collect(), + })?; ex.get_type(&self.id) .ok_or(SignatureError::ExtensionTypeNotFound { exn: self.extension.clone(), @@ -120,6 +141,16 @@ impl CustomType { pub fn extension(&self) -> &ExtensionId { &self.extension } + + /// Returns a weak reference to the extension defining this type. + pub fn extension_ref(&self) -> Weak { + self.extension_ref.clone() + } + + /// Update the internal extension reference with a new weak pointer. + pub fn update_extension(&mut self, extension_ref: Weak) { + self.extension_ref = extension_ref; + } } impl Display for CustomType { @@ -142,6 +173,8 @@ impl From for Type { mod test { pub mod proptest { + use std::sync::Weak; + use crate::extension::ExtensionId; use crate::proptest::any_nonempty_string; use crate::proptest::RecursionDepth; @@ -184,7 +217,9 @@ mod test { vec(any_with::(depth.descend()), 0..3).boxed() }; (any_nonempty_string(), args, any::(), bound) - .prop_map(|(id, args, extension, bound)| Self::new(id, args, extension, bound)) + .prop_map(|(id, args, extension, bound)| { + Self::new(id, args, extension, bound, &Weak::default()) + }) .boxed() } } diff --git a/hugr-core/src/types/poly_func.rs b/hugr-core/src/types/poly_func.rs index 7e3f4f664..754e32205 100644 --- a/hugr-core/src/types/poly_func.rs +++ b/hugr-core/src/types/poly_func.rs @@ -145,16 +145,22 @@ impl PolyFuncTypeBase { self.params.iter().map(ToString::to_string).join(" ") ) } + + /// Returns a mutable reference to the body of the function type. + pub fn body_mut(&mut self) -> &mut FuncTypeBase { + &mut self.body + } } #[cfg(test)] pub(crate) mod test { use std::num::NonZeroU64; + use std::sync::Arc; use cool_asserts::assert_matches; use lazy_static::lazy_static; - use crate::extension::prelude::{BOOL_T, PRELUDE_ID, USIZE_T}; + use crate::extension::prelude::{bool_t, usize_t, PRELUDE_ID}; use crate::extension::{ ExtensionId, ExtensionRegistry, SignatureError, TypeDefBound, EMPTY_REG, PRELUDE, PRELUDE_REGISTRY, @@ -193,20 +199,20 @@ pub(crate) mod test { let list_of_var = Type::new_extension(list_def.instantiate([tyvar.clone()])?); let list_len = PolyFuncTypeBase::new_validated( [TypeBound::Any.into()], - Signature::new(vec![list_of_var], vec![USIZE_T]), + Signature::new(vec![list_of_var], vec![usize_t()]), ®ISTRY, )?; - let t = list_len.instantiate(&[TypeArg::Type { ty: USIZE_T }], ®ISTRY)?; + let t = list_len.instantiate(&[TypeArg::Type { ty: usize_t() }], ®ISTRY)?; assert_eq!( t, Signature::new( vec![Type::new_extension( list_def - .instantiate([TypeArg::Type { ty: USIZE_T }]) + .instantiate([TypeArg::Type { ty: usize_t() }]) .unwrap() )], - vec![USIZE_T] + vec![usize_t()] ) ); @@ -230,12 +236,18 @@ pub(crate) mod test { // Sanity check (good args) good_ts.instantiate( - &[TypeArg::BoundedNat { n: 5 }, TypeArg::Type { ty: USIZE_T }], + &[ + TypeArg::BoundedNat { n: 5 }, + TypeArg::Type { ty: usize_t() }, + ], &PRELUDE_REGISTRY, )?; let wrong_args = good_ts.instantiate( - &[TypeArg::Type { ty: USIZE_T }, TypeArg::BoundedNat { n: 5 }], + &[ + TypeArg::Type { ty: usize_t() }, + TypeArg::BoundedNat { n: 5 }, + ], &PRELUDE_REGISTRY, ); assert_eq!( @@ -243,7 +255,7 @@ pub(crate) mod test { Err(SignatureError::TypeArgMismatch( TypeArgError::TypeMismatch { param: typarams[0].clone(), - arg: TypeArg::Type { ty: USIZE_T } + arg: TypeArg::Type { ty: usize_t() } } )) ); @@ -263,6 +275,7 @@ pub(crate) mod test { [szvar, tyvar], PRELUDE_ID, TypeBound::Any, + &Arc::downgrade(&PRELUDE), )); let bad_ts = PolyFuncTypeBase::new_validated( typarams.clone(), @@ -321,16 +334,18 @@ pub(crate) mod test { const EXT_ID: ExtensionId = ExtensionId::new_unchecked("my_ext"); const TYPE_NAME: TypeName = TypeName::new_inline("MyType"); - let mut e = Extension::new_test(EXT_ID); - e.add_type( - TYPE_NAME, - vec![bound.clone()], - "".into(), - TypeDefBound::any(), - ) - .unwrap(); + let ext = Extension::new_test_arc(EXT_ID, |ext, extension_ref| { + ext.add_type( + TYPE_NAME, + vec![bound.clone()], + "".into(), + TypeDefBound::any(), + extension_ref, + ) + .unwrap(); + }); - let reg = ExtensionRegistry::try_new([e.into()]).unwrap(); + let reg = ExtensionRegistry::try_new([ext.clone()]).unwrap(); let make_scheme = |tp: TypeParam| { PolyFuncTypeBase::new_validated( @@ -340,6 +355,7 @@ pub(crate) mod test { [TypeArg::new_var_use(0, tp)], EXT_ID, TypeBound::Any, + &Arc::downgrade(&ext), ))), ®, ) @@ -401,7 +417,7 @@ pub(crate) mod test { let e = PolyFuncTypeBase::new_validated( [decl.clone()], FuncValueType::new( - vec![USIZE_T], + vec![usize_t()], vec![TypeRV::new_row_var_use(0, TypeBound::Copyable)], ), &PRELUDE_REGISTRY, @@ -430,7 +446,7 @@ pub(crate) mod test { let pf = PolyFuncTypeBase::new_validated( [TypeParam::new_list(TP_ANY)], FuncValueType::new( - vec![USIZE_T.into(), rty.clone()], + vec![usize_t().into(), rty.clone()], vec![TypeRV::new_tuple(rty)], ), &PRELUDE_REGISTRY, @@ -438,13 +454,13 @@ pub(crate) mod test { .unwrap(); fn seq2() -> Vec { - vec![USIZE_T.into(), BOOL_T.into()] + vec![usize_t().into(), bool_t().into()] } - pf.instantiate(&[TypeArg::Type { ty: USIZE_T }], &PRELUDE_REGISTRY) + pf.instantiate(&[TypeArg::Type { ty: usize_t() }], &PRELUDE_REGISTRY) .unwrap_err(); pf.instantiate( &[TypeArg::Sequence { - elems: vec![USIZE_T.into(), TypeArg::Sequence { elems: seq2() }], + elems: vec![usize_t().into(), TypeArg::Sequence { elems: seq2() }], }], &PRELUDE_REGISTRY, ) @@ -456,8 +472,8 @@ pub(crate) mod test { assert_eq!( t2, Signature::new( - vec![USIZE_T, USIZE_T, BOOL_T], - vec![Type::new_tuple(vec![USIZE_T, BOOL_T])] + vec![usize_t(), usize_t(), bool_t()], + vec![Type::new_tuple(vec![usize_t(), bool_t()])] ) ); } @@ -474,23 +490,23 @@ pub(crate) mod test { b: TypeBound::Copyable, }), }], - Signature::new(vec![USIZE_T, inner_fty.clone()], vec![inner_fty]), + Signature::new(vec![usize_t(), inner_fty.clone()], vec![inner_fty]), &PRELUDE_REGISTRY, ) .unwrap(); - let inner3 = Type::new_function(Signature::new_endo(vec![USIZE_T, BOOL_T, USIZE_T])); + let inner3 = Type::new_function(Signature::new_endo(vec![usize_t(), bool_t(), usize_t()])); let t3 = pf .instantiate( &[TypeArg::Sequence { - elems: vec![USIZE_T.into(), BOOL_T.into(), USIZE_T.into()], + elems: vec![usize_t().into(), bool_t().into(), usize_t().into()], }], &PRELUDE_REGISTRY, ) .unwrap(); assert_eq!( t3, - Signature::new(vec![USIZE_T, inner3.clone()], vec![inner3]) + Signature::new(vec![usize_t(), inner3.clone()], vec![inner3]) ); } } diff --git a/hugr-core/src/types/serialize.rs b/hugr-core/src/types/serialize.rs index de1f94da1..74d1be6e8 100644 --- a/hugr-core/src/types/serialize.rs +++ b/hugr-core/src/types/serialize.rs @@ -2,7 +2,7 @@ use super::{FuncValueType, MaybeRV, RowVariable, SumType, TypeArg, TypeBase, Typ use super::custom::CustomType; -use crate::extension::prelude::{array_type, QB_T, USIZE_T}; +use crate::extension::prelude::{array_type, qb_t, usize_t}; use crate::extension::SignatureError; use crate::ops::AliasDecl; @@ -22,10 +22,10 @@ pub(super) enum SerSimpleType { impl From> for SerSimpleType { fn from(value: TypeBase) -> Self { - if value == QB_T { + if value == qb_t() { return SerSimpleType::Q; }; - if value == USIZE_T { + if value == usize_t() { return SerSimpleType::I; }; match value.0 { @@ -46,8 +46,8 @@ impl TryFrom for TypeBase { type Error = SignatureError; fn try_from(value: SerSimpleType) -> Result { Ok(match value { - SerSimpleType::Q => QB_T.into_(), - SerSimpleType::I => USIZE_T.into_(), + SerSimpleType::Q => qb_t().into_(), + SerSimpleType::I => usize_t().into_(), SerSimpleType::G(sig) => TypeBase::new_function(*sig), SerSimpleType::Sum(st) => st.into(), SerSimpleType::Array { inner, len } => { diff --git a/hugr-core/src/types/signature.rs b/hugr-core/src/types/signature.rs index ee0164087..faf56abd2 100644 --- a/hugr-core/src/types/signature.rs +++ b/hugr-core/src/types/signature.rs @@ -298,7 +298,7 @@ impl PartialEq> for FuncTypeBase PolyFuncTypeRV { - FuncValueType::new_endo(QB_T).into() + FuncValueType::new_endo(qb_t()).into() } fn two_qb_func() -> PolyFuncTypeRV { - FuncValueType::new_endo(type_row![QB_T, QB_T]).into() + FuncValueType::new_endo(vec![qb_t(), qb_t()]).into() } /// The extension identifier. pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("test.quantum"); fn extension() -> Arc { - let mut extension = Extension::new_test(EXTENSION_ID); - - extension - .add_op(OpName::new_inline("H"), "Hadamard".into(), one_qb_func()) - .unwrap(); - extension - .add_op( - OpName::new_inline("RzF64"), - "Rotation specified by float".into(), - Signature::new(type_row![QB_T, float_types::FLOAT64_TYPE], type_row![QB_T]), - ) - .unwrap(); - - extension - .add_op(OpName::new_inline("CX"), "CX".into(), two_qb_func()) - .unwrap(); - - extension - .add_op( - OpName::new_inline("Measure"), - "Measure a qubit, returning the qubit and the measurement result.".into(), - Signature::new(type_row![QB_T], type_row![QB_T, BOOL_T]), - ) - .unwrap(); - - extension - .add_op( - OpName::new_inline("QAlloc"), - "Allocate a new qubit.".into(), - Signature::new(type_row![], type_row![QB_T]), - ) - .unwrap(); - - extension - .add_op( - OpName::new_inline("QDiscard"), - "Discard a qubit.".into(), - Signature::new(type_row![QB_T], type_row![]), - ) - .unwrap(); - - Arc::new(extension) + Extension::new_test_arc(EXTENSION_ID, |extension, extension_ref| { + extension + .add_op( + OpName::new_inline("H"), + "Hadamard".into(), + one_qb_func(), + extension_ref, + ) + .unwrap(); + extension + .add_op( + OpName::new_inline("RzF64"), + "Rotation specified by float".into(), + Signature::new(vec![qb_t(), float_types::float64_type()], vec![qb_t()]), + extension_ref, + ) + .unwrap(); + + extension + .add_op( + OpName::new_inline("CX"), + "CX".into(), + two_qb_func(), + extension_ref, + ) + .unwrap(); + + extension + .add_op( + OpName::new_inline("Measure"), + "Measure a qubit, returning the qubit and the measurement result.".into(), + Signature::new(vec![qb_t()], vec![qb_t(), bool_t()]), + extension_ref, + ) + .unwrap(); + + extension + .add_op( + OpName::new_inline("QAlloc"), + "Allocate a new qubit.".into(), + Signature::new(type_row![], vec![qb_t()]), + extension_ref, + ) + .unwrap(); + + extension + .add_op( + OpName::new_inline("QDiscard"), + "Discard a qubit.".into(), + Signature::new(vec![qb_t()], type_row![]), + extension_ref, + ) + .unwrap(); + }) } lazy_static! { /// Quantum extension definition. pub static ref EXTENSION: Arc = extension(); - static ref REG: ExtensionRegistry = ExtensionRegistry::try_new([EXTENSION.clone(), PRELUDE.clone(), float_types::EXTENSION.clone()]).unwrap(); + + /// A registry with all necessary extensions to run tests internally, including the test quantum extension. + pub static ref REG: ExtensionRegistry = ExtensionRegistry::try_new([ + EXTENSION.clone(), + PRELUDE.clone(), + float_types::EXTENSION.clone(), + float_ops::EXTENSION.clone(), + logic::EXTENSION.clone() + ]).unwrap(); } diff --git a/hugr-core/tests/snapshots/model__roundtrip_call.snap b/hugr-core/tests/snapshots/model__roundtrip_call.snap index 460d8f4c0..5ddc4eb32 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_call.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_call.snap @@ -8,7 +8,7 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-call (forall ?0 ext-set) [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] - (ext arithmetic.int . ?0) + (ext ?0 ... arithmetic.int) (meta doc.description (@ prelude.json "\"This is a function declaration.\"")) (meta doc.title (@ prelude.json "\"Callee\""))) diff --git a/hugr-core/tests/snapshots/model__roundtrip_cfg.snap b/hugr-core/tests/snapshots/model__roundtrip_cfg.snap index f3c0f0acc..41a8f0d62 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_cfg.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_cfg.snap @@ -1,6 +1,6 @@ --- source: hugr-core/tests/model.rs -expression: "roundtrip(include_str!(\"fixtures/model-cfg.edn\"))" +expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-cfg.edn\"))" --- (hugr 0) @@ -16,13 +16,13 @@ expression: "roundtrip(include_str!(\"fixtures/model-cfg.edn\"))" [%2] [%8] (signature (fn [?0] [?0] (ext))) (block [%2] [%5] - (signature (fn [(ctrl [?0])] [(ctrl [?0 . []])] (ext))) + (signature (fn [(ctrl [?0])] [(ctrl [?0])] (ext))) (dfg [%3] [%4] (signature (fn [?0] [(adt [[?0]])] (ext))) (tag 0 [%3] [%4] (signature (fn [?0] [(adt [[?0]])] (ext)))))) (block [%5] [%8] - (signature (fn [(ctrl [?0])] [(ctrl [?0 . []])] (ext))) + (signature (fn [(ctrl [?0])] [(ctrl [?0])] (ext))) (dfg [%6] [%7] (signature (fn [?0] [(adt [[?0]])] (ext))) diff --git a/hugr-core/tests/snapshots/model__roundtrip_constraints.snap b/hugr-core/tests/snapshots/model__roundtrip_constraints.snap index f085c4785..291c2de48 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_constraints.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_constraints.snap @@ -8,9 +8,9 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-cons (forall ?0 type) (forall ?1 nat) (where (nonlinear ?0)) - [?0] [(@ array.Array ?0 ?1)] (ext)) + [?0] [(@ prelude.Array ?0 ?1)] (ext)) (declare-func array.copy (forall ?0 type) (where (nonlinear ?0)) - [(@ array.Array ?0)] [(@ array.Array ?0) (@ array.Array ?0)] (ext)) + [(@ prelude.Array ?0)] [(@ prelude.Array ?0) (@ prelude.Array ?0)] (ext)) diff --git a/hugr-llvm/Cargo.toml b/hugr-llvm/Cargo.toml index ce5cd278d..fd8035457 100644 --- a/hugr-llvm/Cargo.toml +++ b/hugr-llvm/Cargo.toml @@ -31,7 +31,7 @@ llvm14-0 = ["inkwell/llvm14-0"] [dependencies] -inkwell = { version = "0.4.0", default-features = false } +inkwell = { version = "0.5.0", default-features = false } hugr-core = { path = "../hugr-core", version = "0.13.3" } anyhow = "1.0.83" itertools.workspace = true diff --git a/hugr-llvm/src/custom.rs b/hugr-llvm/src/custom.rs index c5e384d29..3c4525954 100644 --- a/hugr-llvm/src/custom.rs +++ b/hugr-llvm/src/custom.rs @@ -156,7 +156,7 @@ pub struct CodegenExtsMap<'a, H> { #[cfg(test)] mod test { use hugr_core::{ - extension::prelude::{ConstString, PRELUDE_ID, PRINT_OP_ID, STRING_TYPE, STRING_TYPE_NAME}, + extension::prelude::{string_type, ConstString, PRELUDE_ID, PRINT_OP_ID, STRING_TYPE_NAME}, Hugr, }; use inkwell::{ @@ -187,7 +187,7 @@ mod test { let ty = cem .type_converter .session(&ctx) - .llvm_type(&STRING_TYPE) + .llvm_type(&string_type()) .unwrap() .into_struct_type(); let ty_n = ty.get_name().unwrap().to_str().unwrap(); diff --git a/hugr-llvm/src/custom/extension_op.rs b/hugr-llvm/src/custom/extension_op.rs index 08b392036..cd3c3b6e7 100644 --- a/hugr-llvm/src/custom/extension_op.rs +++ b/hugr-llvm/src/custom/extension_op.rs @@ -100,7 +100,7 @@ impl<'a, H: HugrView> ExtensionOpMap<'a, H> { args: EmitOpArgs<'c, '_, ExtensionOp, H>, ) -> Result<()> { let node = args.node(); - let key = (node.def().extension().clone(), node.def().name().clone()); + let key = (node.def().extension_id().clone(), node.def().name().clone()); let Some(handler) = self.0.get(&key) else { bail!("No extension could emit extension op: {key:?}") }; diff --git a/hugr-llvm/src/emit/func.rs b/hugr-llvm/src/emit/func.rs index e9d861337..0a09dc9b3 100644 --- a/hugr-llvm/src/emit/func.rs +++ b/hugr-llvm/src/emit/func.rs @@ -2,6 +2,7 @@ use std::{collections::HashMap, rc::Rc}; use anyhow::{anyhow, Result}; use hugr_core::{ + extension::prelude::{either_type, option_type}, ops::{constant::CustomConst, ExtensionOp, FuncDecl, FuncDefn}, types::Type, HugrView, NodeIndex, PortIndex, Wire, @@ -12,7 +13,7 @@ use inkwell::{ context::Context, module::Module, types::{BasicType, BasicTypeEnum, FunctionType}, - values::{BasicValueEnum, FunctionValue, GlobalValue}, + values::{BasicValueEnum, FunctionValue, GlobalValue, IntValue}, }; use itertools::zip_eq; @@ -314,3 +315,36 @@ impl<'c, 'a, H: HugrView> EmitFuncContext<'c, 'a, H> { Ok((self.emit_context, self.todo)) } } + +/// Builds an optional value wrapping `some_value` conditioned on the provided `is_some` flag. +pub fn build_option<'c, H: HugrView>( + ctx: &mut EmitFuncContext<'c, '_, H>, + is_some: IntValue<'c>, + some_value: BasicValueEnum<'c>, + hugr_ty: HugrType, +) -> Result> { + let option_ty = ctx.llvm_sum_type(option_type(hugr_ty))?; + let builder = ctx.builder(); + let some = option_ty.build_tag(builder, 1, vec![some_value])?; + let none = option_ty.build_tag(builder, 0, vec![])?; + let option = builder.build_select(is_some, some, none, "")?; + Ok(option) +} + +/// Builds a result value wrapping either `ok_value` or `else_value` depending on the provided +/// `is_ok` flag. +pub fn build_ok_or_else<'c, H: HugrView>( + ctx: &mut EmitFuncContext<'c, '_, H>, + is_ok: IntValue<'c>, + ok_value: BasicValueEnum<'c>, + ok_hugr_ty: HugrType, + else_value: BasicValueEnum<'c>, + else_hugr_ty: HugrType, +) -> Result> { + let either_ty = ctx.llvm_sum_type(either_type(else_hugr_ty, ok_hugr_ty))?; + let builder = ctx.builder(); + let left = either_ty.build_tag(builder, 0, vec![else_value])?; + let right = either_ty.build_tag(builder, 1, vec![ok_value])?; + let either = builder.build_select(is_ok, right, left, "")?; + Ok(either) +} diff --git a/hugr-llvm/src/emit/ops/cfg.rs b/hugr-llvm/src/emit/ops/cfg.rs index fe66fb312..44fe2d410 100644 --- a/hugr-llvm/src/emit/ops/cfg.rs +++ b/hugr-llvm/src/emit/ops/cfg.rs @@ -218,7 +218,7 @@ impl<'c, 'hugr, H: HugrView> CfgEmitter<'c, 'hugr, H> { #[cfg(test)] mod test { use hugr_core::builder::{Dataflow, DataflowSubContainer, SubContainer}; - use hugr_core::extension::prelude::{self, BOOL_T}; + use hugr_core::extension::prelude::{self, bool_t}; use hugr_core::extension::{ExtensionRegistry, ExtensionSet}; use hugr_core::ops::Value; use hugr_core::std_extensions::arithmetic::int_types::{self, INT_TYPES}; @@ -245,8 +245,11 @@ mod test { .with_ins(vec![t1.clone(), t2.clone()]) .with_outs(t2.clone()) .with_extensions( - ExtensionRegistry::try_new([int_types::extension(), prelude::PRELUDE.to_owned()]) - .unwrap(), + ExtensionRegistry::try_new([ + int_types::EXTENSION.to_owned(), + prelude::PRELUDE.to_owned(), + ]) + .unwrap(), ) .finish(|mut builder| { let [in1, in2] = builder.input_wires_arr(); @@ -295,14 +298,14 @@ mod test { fn nested(llvm_ctx: TestContext) { let t1 = HugrType::new_unit_sum(3); let hugr = SimpleHugrConfig::new() - .with_ins(vec![t1.clone(), BOOL_T]) - .with_outs(BOOL_T) + .with_ins(vec![t1.clone(), bool_t()]) + .with_outs(bool_t()) .finish(|mut builder| { let [in1, in2] = builder.input_wires_arr(); let unit_val = builder.add_load_value(Value::unit()); let [outer_cfg_out] = { let mut outer_cfg_builder = builder - .cfg_builder([(t1.clone(), in1), (BOOL_T, in2)], BOOL_T.into()) + .cfg_builder([(t1.clone(), in1), (bool_t(), in2)], bool_t().into()) .unwrap(); let outer_entry_block = { @@ -312,8 +315,9 @@ mod test { let [outer_entry_in1, outer_entry_in2] = outer_entry_builder.input_wires_arr(); let [outer_entry_out] = { - let mut inner_cfg_builder = - outer_entry_builder.cfg_builder([], BOOL_T.into()).unwrap(); + let mut inner_cfg_builder = outer_entry_builder + .cfg_builder([], bool_t().into()) + .unwrap(); let inner_exit_block = inner_cfg_builder.exit_block(); let inner_entry_block = { let inner_entry_builder = inner_cfg_builder @@ -333,7 +337,7 @@ mod test { .block_builder( type_row![], vec![type_row![]], - BOOL_T.into(), + bool_t().into(), ) .unwrap(); let output = match i { @@ -373,7 +377,7 @@ mod test { let [b1, b2] = (0..2) .map(|i| { let mut b_builder = outer_cfg_builder - .block_builder(type_row![], vec![type_row![]], BOOL_T.into()) + .block_builder(type_row![], vec![type_row![]], bool_t().into()) .unwrap(); let output = match i { 0 => b_builder.add_load_value(Value::true_val()), diff --git a/hugr-llvm/src/emit/test.rs b/hugr-llvm/src/emit/test.rs index 647636c95..f1c9eda8c 100644 --- a/hugr-llvm/src/emit/test.rs +++ b/hugr-llvm/src/emit/test.rs @@ -10,7 +10,7 @@ use hugr_core::ops::handle::FuncID; use hugr_core::std_extensions::arithmetic::{ conversions, float_ops, float_types, int_ops, int_types, }; -use hugr_core::std_extensions::logic; +use hugr_core::std_extensions::{collections, logic}; use hugr_core::types::TypeRow; use hugr_core::{Hugr, HugrView}; use inkwell::module::Module; @@ -153,6 +153,7 @@ impl SimpleHugrConfig { float_ops::EXTENSION_ID, conversions::EXTENSION_ID, logic::EXTENSION_ID, + collections::EXTENSION_ID, ]), ), ) @@ -249,7 +250,7 @@ mod test_fns { use hugr_core::builder::DataflowSubContainer; use hugr_core::builder::{Container, Dataflow, HugrBuilder, ModuleBuilder, SubContainer}; - use hugr_core::extension::prelude::{ConstUsize, BOOL_T, USIZE_T}; + use hugr_core::extension::prelude::{bool_t, usize_t, ConstUsize}; use hugr_core::extension::{EMPTY_REG, PRELUDE_REGISTRY}; use hugr_core::ops::constant::CustomConst; @@ -377,8 +378,8 @@ mod test_fns { let mut mod_b = ModuleBuilder::new(); build_recursive(&mut mod_b, "main_void", type_row![]); - build_recursive(&mut mod_b, "main_unary", type_row![BOOL_T]); - build_recursive(&mut mod_b, "main_binary", type_row![BOOL_T, BOOL_T]); + build_recursive(&mut mod_b, "main_unary", vec![bool_t()].into()); + build_recursive(&mut mod_b, "main_binary", vec![bool_t(), bool_t()].into()); let hugr = mod_b.finish_hugr(&EMPTY_REG).unwrap(); check_emission!(hugr, llvm_ctx); } @@ -399,8 +400,8 @@ mod test_fns { let mut mod_b = ModuleBuilder::new(); build_recursive(&mut mod_b, "main_void", type_row![]); - build_recursive(&mut mod_b, "main_unary", type_row![BOOL_T]); - build_recursive(&mut mod_b, "main_binary", type_row![BOOL_T, BOOL_T]); + build_recursive(&mut mod_b, "main_unary", vec![bool_t()].into()); + build_recursive(&mut mod_b, "main_binary", vec![bool_t(), bool_t()].into()); let hugr = mod_b.finish_hugr(&EMPTY_REG).unwrap(); check_emission!(hugr, llvm_ctx); } @@ -466,16 +467,19 @@ mod test_fns { #[rstest] fn diverse_dfg_children(llvm_ctx: TestContext) { let hugr = SimpleHugrConfig::new() - .with_outs(BOOL_T) + .with_outs(bool_t()) .finish(|mut builder: DFGW| { let [r] = { let mut builder = builder - .dfg_builder(HugrFuncType::new(type_row![], BOOL_T), []) + .dfg_builder(HugrFuncType::new(type_row![], bool_t()), []) .unwrap(); let konst = builder.add_constant(Value::false_val()); let func = { let mut builder = builder - .define_function("scoped_func", HugrFuncType::new(type_row![], BOOL_T)) + .define_function( + "scoped_func", + HugrFuncType::new(type_row![], bool_t()), + ) .unwrap(); let w = builder.load_const(&konst); builder.finish_with_outputs([w]).unwrap() @@ -494,21 +498,24 @@ mod test_fns { #[rstest] fn diverse_cfg_children(llvm_ctx: TestContext) { let hugr = SimpleHugrConfig::new() - .with_outs(BOOL_T) + .with_outs(bool_t()) .finish(|mut builder: DFGW| { let [r] = { - let mut builder = builder.cfg_builder([], type_row![BOOL_T]).unwrap(); + let mut builder = builder.cfg_builder([], vec![bool_t()].into()).unwrap(); let konst = builder.add_constant(Value::false_val()); let func = { let mut builder = builder - .define_function("scoped_func", HugrFuncType::new(type_row![], BOOL_T)) + .define_function( + "scoped_func", + HugrFuncType::new(type_row![], bool_t()), + ) .unwrap(); let w = builder.load_const(&konst); builder.finish_with_outputs([w]).unwrap() }; let entry = { let mut builder = builder - .entry_builder([type_row![]], type_row![BOOL_T]) + .entry_builder([type_row![]], vec![bool_t()].into()) .unwrap(); let control = builder.add_load_value(Value::unary_unit_sum()); let [r] = builder @@ -553,7 +560,7 @@ mod test_fns { #[rstest] fn test_exec(mut exec_ctx: TestContext) { let hugr = SimpleHugrConfig::new() - .with_outs(USIZE_T) + .with_outs(usize_t()) .with_extensions(PRELUDE_REGISTRY.to_owned()) .finish(|mut builder: DFGW| { let konst = builder.add_load_value(ConstUsize::new(42)); diff --git a/hugr-llvm/src/extension.rs b/hugr-llvm/src/extension.rs index fe6db7ed5..4757926b3 100644 --- a/hugr-llvm/src/extension.rs +++ b/hugr-llvm/src/extension.rs @@ -1,3 +1,4 @@ +pub mod collections; pub mod conversions; pub mod float; pub mod int; diff --git a/hugr-llvm/src/extension/collections.rs b/hugr-llvm/src/extension/collections.rs new file mode 100644 index 000000000..60be7e3cd --- /dev/null +++ b/hugr-llvm/src/extension/collections.rs @@ -0,0 +1,445 @@ +use anyhow::{bail, Ok, Result}; +use hugr_core::{ + ops::{ExtensionOp, NamedOp}, + std_extensions::collections::{self, ListOp, ListValue}, + types::{SumType, Type, TypeArg}, + HugrView, +}; +use inkwell::values::FunctionValue; +use inkwell::{ + types::{BasicType, BasicTypeEnum, FunctionType}, + values::{BasicValueEnum, PointerValue}, + AddressSpace, +}; + +use crate::emit::func::{build_ok_or_else, build_option}; +use crate::{ + custom::{CodegenExtension, CodegenExtsBuilder}, + emit::{emit_value, func::EmitFuncContext, EmitOpArgs}, + types::TypingSession, +}; + +/// Runtime functions that implement operations on lists. +#[derive(Clone, Copy, Debug, PartialEq, Hash)] +#[non_exhaustive] +pub enum CollectionsRtFunc { + New, + Push, + Pop, + Get, + Set, + Insert, + Length, +} + +impl CollectionsRtFunc { + /// The signature of a given [CollectionsRtFunc]. + /// + /// Requires a [CollectionsCodegen] to determine the type of lists. + pub fn signature<'c>( + self, + ts: TypingSession<'c, '_>, + ccg: &(impl CollectionsCodegen + 'c), + ) -> FunctionType<'c> { + let iwc = ts.iw_context(); + match self { + CollectionsRtFunc::New => ccg.list_type(ts).fn_type( + &[ + iwc.i64_type().into(), // Capacity + iwc.i64_type().into(), // Single element size in bytes + iwc.i64_type().into(), // Element alignment + // Pointer to element destructor + iwc.i8_type().ptr_type(AddressSpace::default()).into(), + ], + false, + ), + CollectionsRtFunc::Push => iwc.void_type().fn_type( + &[ + ccg.list_type(ts).into(), + iwc.i8_type().ptr_type(AddressSpace::default()).into(), + ], + false, + ), + CollectionsRtFunc::Pop => iwc.bool_type().fn_type( + &[ + ccg.list_type(ts).into(), + iwc.i8_type().ptr_type(AddressSpace::default()).into(), + ], + false, + ), + CollectionsRtFunc::Get | CollectionsRtFunc::Set | CollectionsRtFunc::Insert => { + iwc.bool_type().fn_type( + &[ + ccg.list_type(ts).into(), + iwc.i64_type().into(), + iwc.i8_type().ptr_type(AddressSpace::default()).into(), + ], + false, + ) + } + CollectionsRtFunc::Length => iwc.i64_type().fn_type(&[ccg.list_type(ts).into()], false), + } + } + + /// Returns the extern function corresponding to this [CollectionsRtFunc]. + /// + /// Requires a [CollectionsCodegen] to determine the function signature. + pub fn get_extern<'c, H: HugrView>( + self, + ctx: &EmitFuncContext<'c, '_, H>, + ccg: &(impl CollectionsCodegen + 'c), + ) -> Result> { + ctx.get_extern_func( + ccg.rt_func_name(self), + self.signature(ctx.typing_session(), ccg), + ) + } +} + +impl From for CollectionsRtFunc { + fn from(op: ListOp) -> Self { + match op { + ListOp::get => CollectionsRtFunc::Get, + ListOp::set => CollectionsRtFunc::Set, + ListOp::push => CollectionsRtFunc::Push, + ListOp::pop => CollectionsRtFunc::Pop, + ListOp::insert => CollectionsRtFunc::Insert, + ListOp::length => CollectionsRtFunc::Length, + _ => todo!(), + } + } +} + +/// A helper trait for customising the lowering of [hugr_core::std_extensions::collections] +/// types, [hugr_core::ops::constant::CustomConst]s, and ops. +pub trait CollectionsCodegen: Clone { + /// Return the llvm type of [hugr_core::std_extensions::collections::LIST_TYPENAME]. + fn list_type<'c>(&self, session: TypingSession<'c, '_>) -> BasicTypeEnum<'c> { + session + .iw_context() + .i8_type() + .ptr_type(AddressSpace::default()) + .into() + } + + /// Return the name of a given [CollectionsRtFunc]. + fn rt_func_name(&self, func: CollectionsRtFunc) -> String { + match func { + CollectionsRtFunc::New => "__rt__list__new", + CollectionsRtFunc::Push => "__rt__list__push", + CollectionsRtFunc::Pop => "__rt__list__pop", + CollectionsRtFunc::Get => "__rt__list__get", + CollectionsRtFunc::Set => "__rt__list__set", + CollectionsRtFunc::Insert => "__rt__list__insert", + CollectionsRtFunc::Length => "__rt__list__length", + } + .into() + } +} + +/// A trivial implementation of [CollectionsCodegen] which passes all methods +/// through to their default implementations. +#[derive(Default, Clone)] +pub struct DefaultCollectionsCodegen; + +impl CollectionsCodegen for DefaultCollectionsCodegen {} + +#[derive(Clone, Debug, Default)] +pub struct CollectionsCodegenExtension(CCG); + +impl CollectionsCodegenExtension { + pub fn new(ccg: CCG) -> Self { + Self(ccg) + } +} + +impl From for CollectionsCodegenExtension { + fn from(ccg: CCG) -> Self { + Self::new(ccg) + } +} + +impl CodegenExtension for CollectionsCodegenExtension { + fn add_extension<'a, H: HugrView + 'a>( + self, + builder: CodegenExtsBuilder<'a, H>, + ) -> CodegenExtsBuilder<'a, H> + where + Self: 'a, + { + builder + .custom_type((collections::EXTENSION_ID, collections::LIST_TYPENAME), { + let ccg = self.0.clone(); + move |ts, _hugr_type| Ok(ccg.list_type(ts).as_basic_type_enum()) + }) + .custom_const::({ + let ccg = self.0.clone(); + move |ctx, k| emit_list_value(ctx, &ccg, k) + }) + .simple_extension_op::(move |ctx, args, op| { + emit_list_op(ctx, &self.0, args, op) + }) + } +} + +impl<'a, H: HugrView + 'a> CodegenExtsBuilder<'a, H> { + /// Add a [CollectionsCodegenExtension] to the given [CodegenExtsBuilder] using `ccg` + /// as the implementation. + pub fn add_default_collections_extensions(self) -> Self { + self.add_collections_extensions(DefaultCollectionsCodegen) + } + + /// Add a [CollectionsCodegenExtension] to the given [CodegenExtsBuilder] using + /// [DefaultCollectionsCodegen] as the implementation. + pub fn add_collections_extensions(self, ccg: impl CollectionsCodegen + 'a) -> Self { + self.add_extension(CollectionsCodegenExtension::from(ccg)) + } +} + +fn emit_list_op<'c, H: HugrView>( + ctx: &mut EmitFuncContext<'c, '_, H>, + ccg: &(impl CollectionsCodegen + 'c), + args: EmitOpArgs<'c, '_, ExtensionOp, H>, + op: ListOp, +) -> Result<()> { + let hugr_elem_ty = match args.node().args() { + [TypeArg::Type { ty }] => ty.clone(), + _ => { + bail!("Collections: invalid type args for list op"); + } + }; + let elem_ty = ctx.llvm_type(&hugr_elem_ty)?; + let func = CollectionsRtFunc::get_extern(op.into(), ctx, ccg)?; + match op { + ListOp::push => { + let [list, elem] = args.inputs.try_into().unwrap(); + let elem_ptr = build_alloca_i8_ptr(ctx, elem_ty, Some(elem))?; + ctx.builder() + .build_call(func, &[list.into(), elem_ptr.into()], "")?; + args.outputs.finish(ctx.builder(), vec![list])?; + } + ListOp::pop => { + let [list] = args.inputs.try_into().unwrap(); + let out_ptr = build_alloca_i8_ptr(ctx, elem_ty, None)?; + let ok = ctx + .builder() + .build_call(func, &[list.into(), out_ptr.into()], "")? + .try_as_basic_value() + .unwrap_left() + .into_int_value(); + let elem = build_load_i8_ptr(ctx, out_ptr, elem_ty)?; + let elem_opt = build_option(ctx, ok, elem, hugr_elem_ty)?; + args.outputs.finish(ctx.builder(), vec![list, elem_opt])?; + } + ListOp::get => { + let [list, idx] = args.inputs.try_into().unwrap(); + let out_ptr = build_alloca_i8_ptr(ctx, elem_ty, None)?; + let ok = ctx + .builder() + .build_call(func, &[list.into(), idx.into(), out_ptr.into()], "")? + .try_as_basic_value() + .unwrap_left() + .into_int_value(); + let elem = build_load_i8_ptr(ctx, out_ptr, elem_ty)?; + let elem_opt = build_option(ctx, ok, elem, hugr_elem_ty)?; + args.outputs.finish(ctx.builder(), vec![elem_opt])?; + } + ListOp::set => { + let [list, idx, elem] = args.inputs.try_into().unwrap(); + let elem_ptr = build_alloca_i8_ptr(ctx, elem_ty, Some(elem))?; + let ok = ctx + .builder() + .build_call(func, &[list.into(), idx.into(), elem_ptr.into()], "")? + .try_as_basic_value() + .unwrap_left() + .into_int_value(); + let old_elem = build_load_i8_ptr(ctx, elem_ptr, elem.get_type())?; + let ok_or = + build_ok_or_else(ctx, ok, elem, hugr_elem_ty.clone(), old_elem, hugr_elem_ty)?; + args.outputs.finish(ctx.builder(), vec![list, ok_or])?; + } + ListOp::insert => { + let [list, idx, elem] = args.inputs.try_into().unwrap(); + let elem_ptr = build_alloca_i8_ptr(ctx, elem_ty, Some(elem))?; + let ok = ctx + .builder() + .build_call(func, &[list.into(), idx.into(), elem_ptr.into()], "")? + .try_as_basic_value() + .unwrap_left() + .into_int_value(); + let unit = + ctx.llvm_sum_type(SumType::new_unary(1))? + .build_tag(ctx.builder(), 0, vec![])?; + let ok_or = build_ok_or_else(ctx, ok, unit, Type::UNIT, elem, hugr_elem_ty)?; + args.outputs.finish(ctx.builder(), vec![list, ok_or])?; + } + ListOp::length => { + let [list] = args.inputs.try_into().unwrap(); + let length = ctx + .builder() + .build_call(func, &[list.into()], "")? + .try_as_basic_value() + .unwrap_left() + .into_int_value(); + args.outputs + .finish(ctx.builder(), vec![list, length.into()])?; + } + _ => bail!("Collections: unimplemented op: {}", op.name()), + } + Ok(()) +} + +fn emit_list_value<'c, H: HugrView>( + ctx: &mut EmitFuncContext<'c, '_, H>, + ccg: &(impl CollectionsCodegen + 'c), + val: &ListValue, +) -> Result> { + let elem_ty = ctx.llvm_type(val.get_element_type())?; + let iwc = ctx.typing_session().iw_context(); + let capacity = iwc + .i64_type() + .const_int(val.get_contents().len() as u64, false); + let elem_size = elem_ty.size_of().unwrap(); + let alignment = iwc.i64_type().const_int(8, false); + // TODO: Lookup destructor for elem_ty + let destructor = iwc.i8_type().ptr_type(AddressSpace::default()).const_null(); + let list = ctx + .builder() + .build_call( + CollectionsRtFunc::New.get_extern(ctx, ccg)?, + &[ + capacity.into(), + elem_size.into(), + alignment.into(), + destructor.into(), + ], + "", + )? + .try_as_basic_value() + .unwrap_left(); + // Push elements onto the list + let rt_push = CollectionsRtFunc::Push.get_extern(ctx, ccg)?; + for v in val.get_contents() { + let elem = emit_value(ctx, v)?; + let elem_ptr = build_alloca_i8_ptr(ctx, elem_ty, Some(elem))?; + ctx.builder() + .build_call(rt_push, &[list.into(), elem_ptr.into()], "")?; + } + Ok(list) +} + +/// Helper function to allocate space on the stack for a given type. +/// +/// Optionally also stores a value at that location. +/// +/// Returns an i8 pointer to the allocated memory. +fn build_alloca_i8_ptr<'c, H: HugrView>( + ctx: &mut EmitFuncContext<'c, '_, H>, + ty: BasicTypeEnum<'c>, + value: Option>, +) -> Result> { + let builder = ctx.builder(); + let ptr = builder.build_alloca(ty, "")?; + if let Some(val) = value { + builder.build_store(ptr, val)?; + } + let i8_ptr = builder.build_pointer_cast( + ptr, + ctx.iw_context().i8_type().ptr_type(AddressSpace::default()), + "", + )?; + Ok(i8_ptr) +} + +/// Helper function to load a value from an i8 pointer. +fn build_load_i8_ptr<'c, H: HugrView>( + ctx: &mut EmitFuncContext<'c, '_, H>, + i8_ptr: PointerValue<'c>, + ty: BasicTypeEnum<'c>, +) -> Result> { + let builder = ctx.builder(); + let ptr = builder.build_pointer_cast(i8_ptr, ty.ptr_type(AddressSpace::default()), "")?; + let val = builder.build_load(ptr, "")?; + Ok(val) +} + +#[cfg(test)] +mod test { + use hugr_core::{ + builder::{Dataflow, DataflowSubContainer}, + extension::{ + prelude::{self, qb_t, usize_t, ConstUsize}, + ExtensionRegistry, + }, + ops::{DataflowOpTrait, NamedOp, Value}, + std_extensions::collections::{self, list_type, ListOp, ListValue}, + }; + use rstest::rstest; + + use crate::{ + check_emission, + custom::CodegenExtsBuilder, + emit::test::SimpleHugrConfig, + test::{llvm_ctx, TestContext}, + }; + + #[rstest] + #[case::push(ListOp::push)] + #[case::pop(ListOp::pop)] + #[case::get(ListOp::get)] + #[case::set(ListOp::set)] + #[case::insert(ListOp::insert)] + #[case::length(ListOp::length)] + fn test_collections_emission(mut llvm_ctx: TestContext, #[case] op: ListOp) { + let ext_op = collections::EXTENSION + .instantiate_extension_op( + op.name().as_ref(), + [qb_t().into()], + &collections::COLLECTIONS_REGISTRY, + ) + .unwrap(); + let es = ExtensionRegistry::try_new([ + collections::EXTENSION.to_owned(), + prelude::PRELUDE.to_owned(), + ]) + .unwrap(); + let hugr = SimpleHugrConfig::new() + .with_ins(ext_op.signature().input().clone()) + .with_outs(ext_op.signature().output().clone()) + .with_extensions(es) + .finish(|mut hugr_builder| { + let outputs = hugr_builder + .add_dataflow_op(ext_op, hugr_builder.input_wires()) + .unwrap() + .outputs(); + hugr_builder.finish_with_outputs(outputs).unwrap() + }); + llvm_ctx.add_extensions(CodegenExtsBuilder::add_default_prelude_extensions); + llvm_ctx.add_extensions(CodegenExtsBuilder::add_default_collections_extensions); + check_emission!(op.name().as_str(), hugr, llvm_ctx); + } + + #[rstest] + fn test_const_list_emmission(mut llvm_ctx: TestContext) { + let elem_ty = usize_t(); + let contents = (1..4).map(|i| Value::extension(ConstUsize::new(i))); + let es = ExtensionRegistry::try_new([ + collections::EXTENSION.to_owned(), + prelude::PRELUDE.to_owned(), + ]) + .unwrap(); + + let hugr = SimpleHugrConfig::new() + .with_ins(vec![]) + .with_outs(vec![list_type(elem_ty.clone())]) + .with_extensions(es) + .finish(|mut hugr_builder| { + let list = hugr_builder.add_load_value(ListValue::new(elem_ty, contents)); + hugr_builder.finish_with_outputs(vec![list]).unwrap() + }); + + llvm_ctx.add_extensions(CodegenExtsBuilder::add_default_prelude_extensions); + llvm_ctx.add_extensions(CodegenExtsBuilder::add_default_collections_extensions); + check_emission!("const", hugr, llvm_ctx); + } +} diff --git a/hugr-llvm/src/extension/conversions.rs b/hugr-llvm/src/extension/conversions.rs index 990715094..45e6fd485 100644 --- a/hugr-llvm/src/extension/conversions.rs +++ b/hugr-llvm/src/extension/conversions.rs @@ -2,7 +2,7 @@ use anyhow::{anyhow, bail, ensure, Result}; use hugr_core::{ extension::{ - prelude::{sum_with_error, ConstError, BOOL_T}, + prelude::{bool_t, sum_with_error, ConstError}, simple_op::MakeExtensionOp, }, ops::{constant::Value, custom::ExtensionOp, DataflowOpTrait as _}, @@ -179,9 +179,9 @@ fn emit_conversion_op<'c, H: HugrView>( .into_int_type(); let sum_ty = context .typing_session() - .llvm_sum_type(match BOOL_T.as_type_enum() { + .llvm_sum_type(match bool_t().as_type_enum() { TypeEnum::Sum(st) => st.clone(), - _ => panic!("Hugr prelude BOOL_T not a Sum"), + _ => panic!("Hugr prelude bool_t() not a Sum"), })?; emit_custom_unary_op(context, args, |ctx, arg, _| { @@ -254,10 +254,10 @@ mod test { use hugr_core::std_extensions::arithmetic::int_types::ConstInt; use hugr_core::{ builder::{Dataflow, DataflowSubContainer}, - extension::prelude::{ConstUsize, PRELUDE_REGISTRY, USIZE_T}, + extension::prelude::{usize_t, ConstUsize, PRELUDE_REGISTRY}, std_extensions::arithmetic::{ conversions::{ConvertOpDef, CONVERT_OPS_REGISTRY, EXTENSION}, - float_types::FLOAT64_TYPE, + float_types::float64_type, int_types::INT_TYPES, }, types::Type, @@ -302,7 +302,7 @@ mod test { .add_conversion_extensions() }); let in_ty = INT_TYPES[log_width as usize].clone(); - let out_ty = FLOAT64_TYPE; + let out_ty = float64_type(); let hugr = test_conversion_op(op_name, in_ty, out_ty, log_width); check_emission!(op_name, hugr, llvm_ctx); } @@ -322,7 +322,7 @@ mod test { .add_conversion_extensions() .add_default_prelude_extensions() }); - let in_ty = FLOAT64_TYPE; + let in_ty = float64_type(); let out_ty = sum_with_error(INT_TYPES[log_width as usize].clone()); let hugr = test_conversion_op(op_name, in_ty, out_ty.into(), log_width); check_emission!(op_name, hugr, llvm_ctx); @@ -336,7 +336,7 @@ mod test { #[case] op_name: &str, #[case] input_int: bool, ) { - let mut tys = [INT_TYPES[0].clone(), BOOL_T]; + let mut tys = [INT_TYPES[0].clone(), bool_t()]; if !input_int { tys.reverse() }; @@ -368,7 +368,7 @@ mod test { #[rstest] fn my_test_exec(mut exec_ctx: TestContext) { let hugr = SimpleHugrConfig::new() - .with_outs(USIZE_T) + .with_outs(usize_t()) .with_extensions(PRELUDE_REGISTRY.to_owned()) .finish(|mut builder: DFGW| { let konst = builder.add_load_value(ConstUsize::new(42)); @@ -384,7 +384,7 @@ mod test { #[case(18_446_744_073_709_551_615)] fn usize_roundtrip(mut exec_ctx: TestContext, #[case] val: u64) -> () { let hugr = SimpleHugrConfig::new() - .with_outs(USIZE_T) + .with_outs(usize_t()) .with_extensions(CONVERT_OPS_REGISTRY.clone()) .finish(|mut builder: DFGW| { let k = builder.add_load_value(ConstUsize::new(val)); @@ -410,7 +410,7 @@ mod test { fn roundtrip_hugr(val: u64, signed: bool) -> Hugr { let int64 = INT_TYPES[6].clone(); SimpleHugrConfig::new() - .with_outs(USIZE_T) + .with_outs(usize_t()) .with_extensions(CONVERT_OPS_REGISTRY.clone()) .finish(|mut builder| { let k = builder.add_load_value(ConstUsize::new(val)); @@ -572,7 +572,7 @@ mod test { use hugr_core::type_row; let hugr = SimpleHugrConfig::new() - .with_outs(vec![USIZE_T]) + .with_outs(vec![usize_t()]) .with_extensions(CONVERT_OPS_REGISTRY.to_owned()) .finish(|mut builder| { let i = builder.add_load_value(ConstInt::new_u(0, i).unwrap()); @@ -581,7 +581,11 @@ mod test { .unwrap(); let [b] = builder.add_dataflow_op(ext_op, [i]).unwrap().outputs_arr(); let mut cond = builder - .conditional_builder(([type_row![], type_row![]], b), [], type_row![USIZE_T]) + .conditional_builder( + ([type_row![], type_row![]], b), + [], + vec![usize_t()].into(), + ) .unwrap(); let mut case_false = cond.case_builder(0).unwrap(); let false_result = case_false.add_load_value(ConstUsize::new(1)); diff --git a/hugr-llvm/src/extension/float.rs b/hugr-llvm/src/extension/float.rs index 4aa9eae2b..7cb694cb0 100644 --- a/hugr-llvm/src/extension/float.rs +++ b/hugr-llvm/src/extension/float.rs @@ -34,7 +34,7 @@ fn emit_fcmp<'c, H: HugrView>( rhs.into_float_value(), "", )?; - // convert to whatever BOOL_T is + // convert to whatever bool_t is Ok(vec![ctx .builder() .build_select(r, true_val, false_val, "")?]) @@ -114,7 +114,7 @@ pub fn add_float_extensions<'a, H: HugrView + 'a>( cem.custom_type( ( float_types::EXTENSION_ID, - float_types::FLOAT64_CUSTOM_TYPE.name().clone(), + float_types::FLOAT_TYPE_ID.clone(), ), |ts, _custom_type| Ok(ts.iw_context().f64_type().as_basic_type_enum()), ) @@ -139,7 +139,7 @@ mod test { builder::{Dataflow, DataflowSubContainer}, std_extensions::arithmetic::{ float_ops::FLOAT_OPS_REGISTRY, - float_types::{ConstF64, FLOAT64_TYPE}, + float_types::{float64_type, ConstF64}, }, }; use rstest::rstest; @@ -176,7 +176,7 @@ mod test { fn const_float(mut llvm_ctx: TestContext) { llvm_ctx.add_extensions(add_float_extensions); let hugr = SimpleHugrConfig::new() - .with_outs(FLOAT64_TYPE) + .with_outs(float64_type()) .with_extensions(FLOAT_OPS_REGISTRY.to_owned()) .finish(|mut builder| { let c = builder.add_load_value(ConstF64::new(3.12)); diff --git a/hugr-llvm/src/extension/int.rs b/hugr-llvm/src/extension/int.rs index aa30a2a85..e6d045ceb 100644 --- a/hugr-llvm/src/extension/int.rs +++ b/hugr-llvm/src/extension/int.rs @@ -40,7 +40,7 @@ fn emit_icmp<'c, H: HugrView>( rhs.into_int_value(), "", )?; - // convert to whatever BOOL_T is + // convert to whatever bool_t is Ok(vec![ctx .builder() .build_select(r, true_val, false_val, "")?]) @@ -167,7 +167,7 @@ impl<'a, H: HugrView + 'a> CodegenExtsBuilder<'a, H> { mod test { use hugr_core::{ builder::{Dataflow, DataflowSubContainer}, - extension::prelude::BOOL_T, + extension::prelude::bool_t, std_extensions::arithmetic::{int_ops, int_types::INT_TYPES}, types::TypeRow, Hugr, @@ -187,7 +187,7 @@ mod test { } fn test_binary_icmp_op(name: impl AsRef, log_width: u8) -> Hugr { - test_binary_int_op_with_results(name, log_width, vec![BOOL_T]) + test_binary_int_op_with_results(name, log_width, vec![bool_t()]) } fn test_binary_int_op_with_results( name: impl AsRef, diff --git a/hugr-llvm/src/extension/logic.rs b/hugr-llvm/src/extension/logic.rs index 0e05e4927..88dc77a2f 100644 --- a/hugr-llvm/src/extension/logic.rs +++ b/hugr-llvm/src/extension/logic.rs @@ -93,7 +93,7 @@ impl<'a, H: HugrView + 'a> CodegenExtsBuilder<'a, H> { mod test { use hugr_core::{ builder::{Dataflow, DataflowSubContainer}, - extension::{prelude::BOOL_T, ExtensionRegistry}, + extension::{prelude::bool_t, ExtensionRegistry}, std_extensions::logic::{self, LogicOp}, Hugr, }; @@ -108,8 +108,8 @@ mod test { fn test_logic_op(op: LogicOp, arity: usize) -> Hugr { SimpleHugrConfig::new() - .with_ins(vec![BOOL_T; arity]) - .with_outs(vec![BOOL_T]) + .with_ins(vec![bool_t(); arity]) + .with_outs(vec![bool_t()]) .with_extensions(ExtensionRegistry::try_new(vec![logic::EXTENSION.to_owned()]).unwrap()) .finish(|mut builder| { let outputs = builder diff --git a/hugr-llvm/src/extension/prelude.rs b/hugr-llvm/src/extension/prelude.rs index dd5242784..445317912 100644 --- a/hugr-llvm/src/extension/prelude.rs +++ b/hugr-llvm/src/extension/prelude.rs @@ -1,10 +1,10 @@ use anyhow::{anyhow, bail, ensure, Ok, Result}; +use hugr_core::extension::prelude::{ERROR_TYPE_NAME, STRING_TYPE_NAME}; use hugr_core::{ extension::{ prelude::{ - self, ArrayOp, ArrayOpDef, ConstError, ConstExternalSymbol, ConstString, ConstUsize, - MakeTuple, TupleOpDef, UnpackTuple, ARRAY_TYPE_NAME, ERROR_CUSTOM_TYPE, ERROR_TYPE, - STRING_CUSTOM_TYPE, + self, error_type, ArrayOp, ArrayOpDef, ConstError, ConstExternalSymbol, ConstString, + ConstUsize, MakeTuple, TupleOpDef, UnpackTuple, ARRAY_TYPE_NAME, }, simple_op::MakeExtensionOp as _, }, @@ -39,18 +39,18 @@ pub mod array; /// a trivial implementation of this trait which delegates everything to those /// default implementations. pub trait PreludeCodegen: Clone { - /// Return the llvm type of [hugr_core::extension::prelude::USIZE_T]. That type + /// Return the llvm type of [hugr_core::extension::prelude::usize_t]. That type /// must be an [IntType]. fn usize_type<'c>(&self, session: &TypingSession<'c, '_>) -> IntType<'c> { session.iw_context().i64_type() } - /// Return the llvm type of [hugr_core::extension::prelude::QB_T]. + /// Return the llvm type of [hugr_core::extension::prelude::qb_t]. fn qubit_type<'c>(&self, session: &TypingSession<'c, '_>) -> impl BasicType<'c> { session.iw_context().i16_type() } - /// Return the llvm type of [hugr_core::extension::prelude::ERROR_TYPE]. + /// Return the llvm type of [hugr_core::extension::prelude::error_type()]. /// /// The returned type must always match the type of the returned value of /// [Self::emit_const_error], and the `err` argument of [Self::emit_panic]. @@ -114,7 +114,7 @@ pub trait PreludeCodegen: Clone { err: &ConstError, ) -> Result> { let builder = ctx.builder(); - let err_ty = ctx.llvm_type(&ERROR_TYPE)?.into_struct_type(); + let err_ty = ctx.llvm_type(&error_type())?.into_struct_type(); let signal = err_ty .get_field_type_at_index(0) .unwrap() @@ -222,7 +222,7 @@ fn add_prelude_extensions<'a, H: HugrView + 'a>( let pcg = pcg.clone(); move |ts, _| Ok(pcg.usize_type(&ts).as_basic_type_enum()) }) - .custom_type((prelude::PRELUDE_ID, STRING_CUSTOM_TYPE.name().clone()), { + .custom_type((prelude::PRELUDE_ID, STRING_TYPE_NAME.clone()), { move |ts, _| { // TODO allow customising string type Ok(ts @@ -232,7 +232,7 @@ fn add_prelude_extensions<'a, H: HugrView + 'a>( .into()) } }) - .custom_type((prelude::PRELUDE_ID, ERROR_CUSTOM_TYPE.name().clone()), { + .custom_type((prelude::PRELUDE_ID, ERROR_TYPE_NAME.clone()), { let pcg = pcg.clone(); move |ts, _| Ok(pcg.error_type(&ts)?.as_basic_type_enum()) }) @@ -347,7 +347,7 @@ mod test { use hugr_core::extension::{PRELUDE, PRELUDE_REGISTRY}; use hugr_core::types::{Type, TypeArg}; use hugr_core::{type_row, Hugr}; - use prelude::{BOOL_T, PANIC_OP_ID, PRINT_OP_ID, QB_T, USIZE_T}; + use prelude::{bool_t, qb_t, usize_t, PANIC_OP_ID, PRINT_OP_ID}; use rstest::rstest; use crate::check_emission; @@ -381,11 +381,11 @@ mod test { assert_eq!( iw_context.i32_type().as_basic_type_enum(), - session.llvm_type(&USIZE_T).unwrap() + session.llvm_type(&usize_t()).unwrap() ); assert_eq!( iw_context.f64_type().as_basic_type_enum(), - session.llvm_type(&QB_T).unwrap() + session.llvm_type(&qb_t()).unwrap() ); } @@ -395,11 +395,11 @@ mod test { let tc = llvm_ctx.get_typing_session(); assert_eq!( llvm_ctx.iw_context().i32_type().as_basic_type_enum(), - tc.llvm_type(&USIZE_T).unwrap() + tc.llvm_type(&usize_t()).unwrap() ); assert_eq!( llvm_ctx.iw_context().f64_type().as_basic_type_enum(), - tc.llvm_type(&QB_T).unwrap() + tc.llvm_type(&qb_t()).unwrap() ); } @@ -412,7 +412,7 @@ mod test { #[rstest] fn prelude_const_usize(prelude_llvm_ctx: TestContext) { let hugr = SimpleHugrConfig::new() - .with_outs(USIZE_T) + .with_outs(usize_t()) .with_extensions(prelude::PRELUDE_REGISTRY.to_owned()) .finish(|mut builder| { let k = builder.add_load_value(ConstUsize::new(17)); @@ -423,10 +423,13 @@ mod test { #[rstest] fn prelude_const_external_symbol(prelude_llvm_ctx: TestContext) { - let konst1 = ConstExternalSymbol::new("sym1", USIZE_T, true); + let konst1 = ConstExternalSymbol::new("sym1", usize_t(), true); let konst2 = ConstExternalSymbol::new( "sym2", - HugrType::new_sum([type_row![USIZE_T, HugrType::new_unit_sum(3)], type_row![]]), + HugrType::new_sum([ + vec![usize_t(), HugrType::new_unit_sum(3)].into(), + type_row![], + ]), false, ); @@ -444,8 +447,8 @@ mod test { #[rstest] fn prelude_make_tuple(prelude_llvm_ctx: TestContext) { let hugr = SimpleHugrConfig::new() - .with_ins(vec![BOOL_T, BOOL_T]) - .with_outs(Type::new_tuple(vec![BOOL_T, BOOL_T])) + .with_ins(vec![bool_t(), bool_t()]) + .with_outs(Type::new_tuple(vec![bool_t(), bool_t()])) .with_extensions(prelude::PRELUDE_REGISTRY.to_owned()) .finish(|mut builder| { let in_wires = builder.input_wires(); @@ -458,13 +461,13 @@ mod test { #[rstest] fn prelude_unpack_tuple(prelude_llvm_ctx: TestContext) { let hugr = SimpleHugrConfig::new() - .with_ins(Type::new_tuple(vec![BOOL_T, BOOL_T])) - .with_outs(vec![BOOL_T, BOOL_T]) + .with_ins(Type::new_tuple(vec![bool_t(), bool_t()])) + .with_outs(vec![bool_t(), bool_t()]) .with_extensions(prelude::PRELUDE_REGISTRY.to_owned()) .finish(|mut builder| { let unpack = builder .add_dataflow_op( - UnpackTuple::new(vec![BOOL_T, BOOL_T].into()), + UnpackTuple::new(vec![bool_t(), bool_t()].into()), builder.input_wires(), ) .unwrap(); @@ -476,9 +479,9 @@ mod test { #[rstest] fn prelude_panic(prelude_llvm_ctx: TestContext) { let error_val = ConstError::new(42, "PANIC"); - const TYPE_ARG_Q: TypeArg = TypeArg::Type { ty: QB_T }; + let type_arg_q: TypeArg = TypeArg::Type { ty: qb_t() }; let type_arg_2q: TypeArg = TypeArg::Sequence { - elems: vec![TYPE_ARG_Q, TYPE_ARG_Q], + elems: vec![type_arg_q.clone(), type_arg_q], }; let panic_op = PRELUDE .instantiate_extension_op( @@ -489,8 +492,8 @@ mod test { .unwrap(); let hugr = SimpleHugrConfig::new() - .with_ins(vec![QB_T, QB_T]) - .with_outs(vec![QB_T, QB_T]) + .with_ins(vec![qb_t(), qb_t()]) + .with_outs(vec![qb_t(), qb_t()]) .with_extensions(prelude::PRELUDE_REGISTRY.to_owned()) .finish(|mut builder| { let [q0, q1] = builder.input_wires_arr(); diff --git a/hugr-llvm/src/extension/prelude/array.rs b/hugr-llvm/src/extension/prelude/array.rs index 96f5d8c24..fa33b7407 100644 --- a/hugr-llvm/src/extension/prelude/array.rs +++ b/hugr-llvm/src/extension/prelude/array.rs @@ -42,7 +42,7 @@ fn with_array_alloca<'c, T, E: From>( }; let ptr = builder.build_array_alloca(array_ty.get_element_type(), array_len, "")?; let array_ptr = builder - .build_bitcast(ptr, array_ty.ptr_type(Default::default()), "")? + .build_bit_cast(ptr, array_ty.ptr_type(Default::default()), "")? .into_pointer_value(); builder.build_store(array_ptr, array)?; go(ptr) @@ -176,7 +176,7 @@ pub fn emit_array_op<'c, H: HugrView>( let elem_v = builder.build_load(elem_addr, "")?; builder.build_store(elem_addr, value_v)?; let ptr = builder - .build_bitcast( + .build_bit_cast( ptr, array_v.get_type().ptr_type(Default::default()), "", @@ -261,7 +261,7 @@ pub fn emit_array_op<'c, H: HugrView>( builder.build_store(elem1_addr, elem2_v)?; builder.build_store(elem2_addr, elem1_v)?; let ptr = builder - .build_bitcast( + .build_bit_cast( ptr, array_v.get_type().ptr_type(Default::default()), "", @@ -376,7 +376,7 @@ fn emit_pop_op<'c>( .get_element_type() .array_type(size as u32 - 1); let ptr = builder - .build_bitcast(ptr, new_array_ty.ptr_type(Default::default()), "")? + .build_bit_cast(ptr, new_array_ty.ptr_type(Default::default()), "")? .into_pointer_value(); let array_v = builder.build_load(ptr, "")?; Ok((elem_v, array_v)) @@ -390,7 +390,7 @@ mod test { builder::{Dataflow, DataflowSubContainer, SubContainer}, extension::{ prelude::{ - self, array_type, option_type, ConstUsize, UnwrapBuilder as _, BOOL_T, USIZE_T, + self, array_type, bool_t, option_type, usize_t, ConstUsize, UnwrapBuilder as _, }, ExtensionRegistry, }, @@ -436,8 +436,8 @@ mod test { .finish(|mut builder| { let us1 = builder.add_load_value(ConstUsize::new(1)); let us2 = builder.add_load_value(ConstUsize::new(2)); - let arr = builder.add_new_array(USIZE_T, [us1, us2]).unwrap(); - builder.add_array_get(USIZE_T, 2, arr, us1).unwrap(); + let arr = builder.add_new_array(usize_t(), [us1, us2]).unwrap(); + builder.add_array_get(usize_t(), 2, arr, us1).unwrap(); builder.finish_with_outputs([]).unwrap() }); llvm_ctx.add_extensions(CodegenExtsBuilder::add_default_prelude_extensions); @@ -465,22 +465,22 @@ mod test { // - Gets the element at the given index // - Returns the element if the index is in bounds, otherwise 0 let hugr = SimpleHugrConfig::new() - .with_outs(USIZE_T) + .with_outs(usize_t()) .with_extensions(exec_registry()) .finish(|mut builder| { let us0 = builder.add_load_value(ConstUsize::new(0)); let us1 = builder.add_load_value(ConstUsize::new(1)); let us2 = builder.add_load_value(ConstUsize::new(2)); - let arr = builder.add_new_array(USIZE_T, [us1, us2]).unwrap(); + let arr = builder.add_new_array(usize_t(), [us1, us2]).unwrap(); let i = builder.add_load_value(ConstUsize::new(index)); - let get_r = builder.add_array_get(USIZE_T, 2, arr, i).unwrap(); + let get_r = builder.add_array_get(usize_t(), 2, arr, i).unwrap(); let r = { - let ot = option_type(USIZE_T); + let ot = option_type(usize_t()); let variants = (0..ot.num_variants()) .map(|i| ot.get_variant(i).cloned().unwrap().try_into().unwrap()) .collect_vec(); let mut builder = builder - .conditional_builder((variants, get_r), [], USIZE_T.into()) + .conditional_builder((variants, get_r), [], usize_t().into()) .unwrap(); { let failure_case = builder.case_builder(0).unwrap(); @@ -521,7 +521,7 @@ mod test { use hugr_core::extension::prelude::either_type; let int_ty = int_type(3); let hugr = SimpleHugrConfig::new() - .with_outs(USIZE_T) + .with_outs(usize_t()) .with_extensions(exec_registry()) .finish_with_exts(|mut builder, reg| { let us0 = builder.add_load_value(ConstUsize::new(0)); @@ -550,7 +550,7 @@ mod test { }) .collect_vec(); let mut builder = builder - .conditional_builder((variants, get_r), [], BOOL_T.into()) + .conditional_builder((variants, get_r), [], bool_t().into()) .unwrap(); for i in 0..2 { let mut builder = builder.case_builder(i).unwrap(); @@ -584,7 +584,7 @@ mod test { }; let r = { let mut conditional = builder - .conditional_builder(([type_row![], type_row![]], r), [], USIZE_T.into()) + .conditional_builder(([type_row![], type_row![]], r), [], usize_t().into()) .unwrap(); conditional .case_builder(0) @@ -631,7 +631,7 @@ mod test { let int_ty = int_type(3); let arr_ty = array_type(2, int_ty.clone()); let hugr = SimpleHugrConfig::new() - .with_outs(USIZE_T) + .with_outs(usize_t()) .with_extensions(exec_registry()) .finish_with_exts(|mut builder, reg| { let us0 = builder.add_load_value(ConstUsize::new(0)); @@ -653,7 +653,7 @@ mod test { r, ), [], - vec![arr_ty, BOOL_T].into(), + vec![arr_ty, bool_t()].into(), ) .unwrap(); for i in 0..2 { @@ -692,7 +692,7 @@ mod test { let r = builder.add_and(r, elem_1_ok).unwrap(); let r = { let mut conditional = builder - .conditional_builder(([type_row![], type_row![]], r), [], USIZE_T.into()) + .conditional_builder(([type_row![], type_row![]], r), [], usize_t().into()) .unwrap(); conditional .case_builder(0) diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__const@llvm14.snap b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__const@llvm14.snap new file mode 100644 index 000000000..8ad058cf3 --- /dev/null +++ b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__const@llvm14.snap @@ -0,0 +1,31 @@ +--- +source: hugr-llvm/src/extension/collections.rs +expression: mod_str +--- +; ModuleID = 'test_context' +source_filename = "test_context" + +define i8* @_hl.main.1() { +alloca_block: + br label %entry_block + +entry_block: ; preds = %alloca_block + %0 = call i8* @__rt__list__new(i64 3, i64 ptrtoint (i64* getelementptr (i64, i64* null, i32 1) to i64), i64 8, i8* null) + %1 = alloca i64, align 8 + store i64 1, i64* %1, align 4 + %2 = bitcast i64* %1 to i8* + call void @__rt__list__push(i8* %0, i8* %2) + %3 = alloca i64, align 8 + store i64 2, i64* %3, align 4 + %4 = bitcast i64* %3 to i8* + call void @__rt__list__push(i8* %0, i8* %4) + %5 = alloca i64, align 8 + store i64 3, i64* %5, align 4 + %6 = bitcast i64* %5 to i8* + call void @__rt__list__push(i8* %0, i8* %6) + ret i8* %0 +} + +declare i8* @__rt__list__new(i64, i64, i64, i8*) + +declare void @__rt__list__push(i8*, i8*) diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__const@pre-mem2reg@llvm14.snap b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__const@pre-mem2reg@llvm14.snap new file mode 100644 index 000000000..5522be9ad --- /dev/null +++ b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__const@pre-mem2reg@llvm14.snap @@ -0,0 +1,37 @@ +--- +source: hugr-llvm/src/extension/collections.rs +expression: mod_str +--- +; ModuleID = 'test_context' +source_filename = "test_context" + +define i8* @_hl.main.1() { +alloca_block: + %"0" = alloca i8*, align 8 + %"5_0" = alloca i8*, align 8 + br label %entry_block + +entry_block: ; preds = %alloca_block + %0 = call i8* @__rt__list__new(i64 3, i64 ptrtoint (i64* getelementptr (i64, i64* null, i32 1) to i64), i64 8, i8* null) + %1 = alloca i64, align 8 + store i64 1, i64* %1, align 4 + %2 = bitcast i64* %1 to i8* + call void @__rt__list__push(i8* %0, i8* %2) + %3 = alloca i64, align 8 + store i64 2, i64* %3, align 4 + %4 = bitcast i64* %3 to i8* + call void @__rt__list__push(i8* %0, i8* %4) + %5 = alloca i64, align 8 + store i64 3, i64* %5, align 4 + %6 = bitcast i64* %5 to i8* + call void @__rt__list__push(i8* %0, i8* %6) + store i8* %0, i8** %"5_0", align 8 + %"5_01" = load i8*, i8** %"5_0", align 8 + store i8* %"5_01", i8** %"0", align 8 + %"02" = load i8*, i8** %"0", align 8 + ret i8* %"02" +} + +declare i8* @__rt__list__new(i64, i64, i64, i8*) + +declare void @__rt__list__push(i8*, i8*) diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__get@llvm14.snap b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__get@llvm14.snap new file mode 100644 index 000000000..5d7d0d381 --- /dev/null +++ b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__get@llvm14.snap @@ -0,0 +1,24 @@ +--- +source: hugr-llvm/src/extension/collections.rs +expression: mod_str +--- +; ModuleID = 'test_context' +source_filename = "test_context" + +define { i32, {}, { i16 } } @_hl.main.1(i8* %0, i64 %1) { +alloca_block: + br label %entry_block + +entry_block: ; preds = %alloca_block + %2 = alloca i16, align 2 + %3 = bitcast i16* %2 to i8* + %4 = call i1 @__rt__list__get(i8* %0, i64 %1, i8* %3) + %5 = bitcast i8* %3 to i16* + %6 = load i16, i16* %5, align 2 + %7 = insertvalue { i16 } undef, i16 %6, 0 + %8 = insertvalue { i32, {}, { i16 } } { i32 1, {} poison, { i16 } poison }, { i16 } %7, 2 + %9 = select i1 %4, { i32, {}, { i16 } } %8, { i32, {}, { i16 } } { i32 0, {} undef, { i16 } poison } + ret { i32, {}, { i16 } } %9 +} + +declare i1 @__rt__list__get(i8*, i64, i8*) diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__get@pre-mem2reg@llvm14.snap b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__get@pre-mem2reg@llvm14.snap new file mode 100644 index 000000000..a7eee4d03 --- /dev/null +++ b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__get@pre-mem2reg@llvm14.snap @@ -0,0 +1,36 @@ +--- +source: hugr-llvm/src/extension/collections.rs +expression: mod_str +--- +; ModuleID = 'test_context' +source_filename = "test_context" + +define { i32, {}, { i16 } } @_hl.main.1(i8* %0, i64 %1) { +alloca_block: + %"0" = alloca { i32, {}, { i16 } }, align 8 + %"2_0" = alloca i8*, align 8 + %"2_1" = alloca i64, align 8 + %"4_0" = alloca { i32, {}, { i16 } }, align 8 + br label %entry_block + +entry_block: ; preds = %alloca_block + store i8* %0, i8** %"2_0", align 8 + store i64 %1, i64* %"2_1", align 4 + %"2_01" = load i8*, i8** %"2_0", align 8 + %"2_12" = load i64, i64* %"2_1", align 4 + %2 = alloca i16, align 2 + %3 = bitcast i16* %2 to i8* + %4 = call i1 @__rt__list__get(i8* %"2_01", i64 %"2_12", i8* %3) + %5 = bitcast i8* %3 to i16* + %6 = load i16, i16* %5, align 2 + %7 = insertvalue { i16 } undef, i16 %6, 0 + %8 = insertvalue { i32, {}, { i16 } } { i32 1, {} poison, { i16 } poison }, { i16 } %7, 2 + %9 = select i1 %4, { i32, {}, { i16 } } %8, { i32, {}, { i16 } } { i32 0, {} undef, { i16 } poison } + store { i32, {}, { i16 } } %9, { i32, {}, { i16 } }* %"4_0", align 4 + %"4_03" = load { i32, {}, { i16 } }, { i32, {}, { i16 } }* %"4_0", align 4 + store { i32, {}, { i16 } } %"4_03", { i32, {}, { i16 } }* %"0", align 4 + %"04" = load { i32, {}, { i16 } }, { i32, {}, { i16 } }* %"0", align 4 + ret { i32, {}, { i16 } } %"04" +} + +declare i1 @__rt__list__get(i8*, i64, i8*) diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__insert@llvm14.snap b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__insert@llvm14.snap new file mode 100644 index 000000000..deb84f1b5 --- /dev/null +++ b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__insert@llvm14.snap @@ -0,0 +1,25 @@ +--- +source: hugr-llvm/src/extension/collections.rs +expression: mod_str +--- +; ModuleID = 'test_context' +source_filename = "test_context" + +define { i8*, { i32, { i16 }, { { {} } } } } @_hl.main.1(i8* %0, i64 %1, i16 %2) { +alloca_block: + br label %entry_block + +entry_block: ; preds = %alloca_block + %3 = alloca i16, align 2 + store i16 %2, i16* %3, align 2 + %4 = bitcast i16* %3 to i8* + %5 = call i1 @__rt__list__insert(i8* %0, i64 %1, i8* %4) + %6 = insertvalue { i16 } undef, i16 %2, 0 + %7 = insertvalue { i32, { i16 }, { { {} } } } { i32 0, { i16 } poison, { { {} } } poison }, { i16 } %6, 1 + %8 = select i1 %5, { i32, { i16 }, { { {} } } } { i32 1, { i16 } poison, { { {} } } undef }, { i32, { i16 }, { { {} } } } %7 + %mrv = insertvalue { i8*, { i32, { i16 }, { { {} } } } } undef, i8* %0, 0 + %mrv8 = insertvalue { i8*, { i32, { i16 }, { { {} } } } } %mrv, { i32, { i16 }, { { {} } } } %8, 1 + ret { i8*, { i32, { i16 }, { { {} } } } } %mrv8 +} + +declare i1 @__rt__list__insert(i8*, i64, i8*) diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__insert@pre-mem2reg@llvm14.snap b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__insert@pre-mem2reg@llvm14.snap new file mode 100644 index 000000000..9f92cf9a6 --- /dev/null +++ b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__insert@pre-mem2reg@llvm14.snap @@ -0,0 +1,46 @@ +--- +source: hugr-llvm/src/extension/collections.rs +expression: mod_str +--- +; ModuleID = 'test_context' +source_filename = "test_context" + +define { i8*, { i32, { i16 }, { { {} } } } } @_hl.main.1(i8* %0, i64 %1, i16 %2) { +alloca_block: + %"0" = alloca i8*, align 8 + %"1" = alloca { i32, { i16 }, { { {} } } }, align 8 + %"2_0" = alloca i8*, align 8 + %"2_1" = alloca i64, align 8 + %"2_2" = alloca i16, align 2 + %"4_0" = alloca i8*, align 8 + %"4_1" = alloca { i32, { i16 }, { { {} } } }, align 8 + br label %entry_block + +entry_block: ; preds = %alloca_block + store i8* %0, i8** %"2_0", align 8 + store i64 %1, i64* %"2_1", align 4 + store i16 %2, i16* %"2_2", align 2 + %"2_01" = load i8*, i8** %"2_0", align 8 + %"2_12" = load i64, i64* %"2_1", align 4 + %"2_23" = load i16, i16* %"2_2", align 2 + %3 = alloca i16, align 2 + store i16 %"2_23", i16* %3, align 2 + %4 = bitcast i16* %3 to i8* + %5 = call i1 @__rt__list__insert(i8* %"2_01", i64 %"2_12", i8* %4) + %6 = insertvalue { i16 } undef, i16 %"2_23", 0 + %7 = insertvalue { i32, { i16 }, { { {} } } } { i32 0, { i16 } poison, { { {} } } poison }, { i16 } %6, 1 + %8 = select i1 %5, { i32, { i16 }, { { {} } } } { i32 1, { i16 } poison, { { {} } } undef }, { i32, { i16 }, { { {} } } } %7 + store i8* %"2_01", i8** %"4_0", align 8 + store { i32, { i16 }, { { {} } } } %8, { i32, { i16 }, { { {} } } }* %"4_1", align 4 + %"4_04" = load i8*, i8** %"4_0", align 8 + %"4_15" = load { i32, { i16 }, { { {} } } }, { i32, { i16 }, { { {} } } }* %"4_1", align 4 + store i8* %"4_04", i8** %"0", align 8 + store { i32, { i16 }, { { {} } } } %"4_15", { i32, { i16 }, { { {} } } }* %"1", align 4 + %"06" = load i8*, i8** %"0", align 8 + %"17" = load { i32, { i16 }, { { {} } } }, { i32, { i16 }, { { {} } } }* %"1", align 4 + %mrv = insertvalue { i8*, { i32, { i16 }, { { {} } } } } undef, i8* %"06", 0 + %mrv8 = insertvalue { i8*, { i32, { i16 }, { { {} } } } } %mrv, { i32, { i16 }, { { {} } } } %"17", 1 + ret { i8*, { i32, { i16 }, { { {} } } } } %mrv8 +} + +declare i1 @__rt__list__insert(i8*, i64, i8*) diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__length@llvm14.snap b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__length@llvm14.snap new file mode 100644 index 000000000..61ddae3a3 --- /dev/null +++ b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__length@llvm14.snap @@ -0,0 +1,19 @@ +--- +source: hugr-llvm/src/extension/collections.rs +expression: mod_str +--- +; ModuleID = 'test_context' +source_filename = "test_context" + +define { i8*, i64 } @_hl.main.1(i8* %0) { +alloca_block: + br label %entry_block + +entry_block: ; preds = %alloca_block + %1 = call i64 @__rt__list__length(i8* %0) + %mrv = insertvalue { i8*, i64 } undef, i8* %0, 0 + %mrv6 = insertvalue { i8*, i64 } %mrv, i64 %1, 1 + ret { i8*, i64 } %mrv6 +} + +declare i64 @__rt__list__length(i8*) diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__length@pre-mem2reg@llvm14.snap b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__length@pre-mem2reg@llvm14.snap new file mode 100644 index 000000000..e993956bb --- /dev/null +++ b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__length@pre-mem2reg@llvm14.snap @@ -0,0 +1,34 @@ +--- +source: hugr-llvm/src/extension/collections.rs +expression: mod_str +--- +; ModuleID = 'test_context' +source_filename = "test_context" + +define { i8*, i64 } @_hl.main.1(i8* %0) { +alloca_block: + %"0" = alloca i8*, align 8 + %"1" = alloca i64, align 8 + %"2_0" = alloca i8*, align 8 + %"4_0" = alloca i8*, align 8 + %"4_1" = alloca i64, align 8 + br label %entry_block + +entry_block: ; preds = %alloca_block + store i8* %0, i8** %"2_0", align 8 + %"2_01" = load i8*, i8** %"2_0", align 8 + %1 = call i64 @__rt__list__length(i8* %"2_01") + store i8* %"2_01", i8** %"4_0", align 8 + store i64 %1, i64* %"4_1", align 4 + %"4_02" = load i8*, i8** %"4_0", align 8 + %"4_13" = load i64, i64* %"4_1", align 4 + store i8* %"4_02", i8** %"0", align 8 + store i64 %"4_13", i64* %"1", align 4 + %"04" = load i8*, i8** %"0", align 8 + %"15" = load i64, i64* %"1", align 4 + %mrv = insertvalue { i8*, i64 } undef, i8* %"04", 0 + %mrv6 = insertvalue { i8*, i64 } %mrv, i64 %"15", 1 + ret { i8*, i64 } %mrv6 +} + +declare i64 @__rt__list__length(i8*) diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__pop@llvm14.snap b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__pop@llvm14.snap new file mode 100644 index 000000000..e011b3dfe --- /dev/null +++ b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__pop@llvm14.snap @@ -0,0 +1,26 @@ +--- +source: hugr-llvm/src/extension/collections.rs +expression: mod_str +--- +; ModuleID = 'test_context' +source_filename = "test_context" + +define { i8*, { i32, {}, { i16 } } } @_hl.main.1(i8* %0) { +alloca_block: + br label %entry_block + +entry_block: ; preds = %alloca_block + %1 = alloca i16, align 2 + %2 = bitcast i16* %1 to i8* + %3 = call i1 @__rt__list__pop(i8* %0, i8* %2) + %4 = bitcast i8* %2 to i16* + %5 = load i16, i16* %4, align 2 + %6 = insertvalue { i16 } undef, i16 %5, 0 + %7 = insertvalue { i32, {}, { i16 } } { i32 1, {} poison, { i16 } poison }, { i16 } %6, 2 + %8 = select i1 %3, { i32, {}, { i16 } } %7, { i32, {}, { i16 } } { i32 0, {} undef, { i16 } poison } + %mrv = insertvalue { i8*, { i32, {}, { i16 } } } undef, i8* %0, 0 + %mrv6 = insertvalue { i8*, { i32, {}, { i16 } } } %mrv, { i32, {}, { i16 } } %8, 1 + ret { i8*, { i32, {}, { i16 } } } %mrv6 +} + +declare i1 @__rt__list__pop(i8*, i8*) diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__pop@pre-mem2reg@llvm14.snap b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__pop@pre-mem2reg@llvm14.snap new file mode 100644 index 000000000..4b677b1a8 --- /dev/null +++ b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__pop@pre-mem2reg@llvm14.snap @@ -0,0 +1,41 @@ +--- +source: hugr-llvm/src/extension/collections.rs +expression: mod_str +--- +; ModuleID = 'test_context' +source_filename = "test_context" + +define { i8*, { i32, {}, { i16 } } } @_hl.main.1(i8* %0) { +alloca_block: + %"0" = alloca i8*, align 8 + %"1" = alloca { i32, {}, { i16 } }, align 8 + %"2_0" = alloca i8*, align 8 + %"4_0" = alloca i8*, align 8 + %"4_1" = alloca { i32, {}, { i16 } }, align 8 + br label %entry_block + +entry_block: ; preds = %alloca_block + store i8* %0, i8** %"2_0", align 8 + %"2_01" = load i8*, i8** %"2_0", align 8 + %1 = alloca i16, align 2 + %2 = bitcast i16* %1 to i8* + %3 = call i1 @__rt__list__pop(i8* %"2_01", i8* %2) + %4 = bitcast i8* %2 to i16* + %5 = load i16, i16* %4, align 2 + %6 = insertvalue { i16 } undef, i16 %5, 0 + %7 = insertvalue { i32, {}, { i16 } } { i32 1, {} poison, { i16 } poison }, { i16 } %6, 2 + %8 = select i1 %3, { i32, {}, { i16 } } %7, { i32, {}, { i16 } } { i32 0, {} undef, { i16 } poison } + store i8* %"2_01", i8** %"4_0", align 8 + store { i32, {}, { i16 } } %8, { i32, {}, { i16 } }* %"4_1", align 4 + %"4_02" = load i8*, i8** %"4_0", align 8 + %"4_13" = load { i32, {}, { i16 } }, { i32, {}, { i16 } }* %"4_1", align 4 + store i8* %"4_02", i8** %"0", align 8 + store { i32, {}, { i16 } } %"4_13", { i32, {}, { i16 } }* %"1", align 4 + %"04" = load i8*, i8** %"0", align 8 + %"15" = load { i32, {}, { i16 } }, { i32, {}, { i16 } }* %"1", align 4 + %mrv = insertvalue { i8*, { i32, {}, { i16 } } } undef, i8* %"04", 0 + %mrv6 = insertvalue { i8*, { i32, {}, { i16 } } } %mrv, { i32, {}, { i16 } } %"15", 1 + ret { i8*, { i32, {}, { i16 } } } %mrv6 +} + +declare i1 @__rt__list__pop(i8*, i8*) diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__push@llvm14.snap b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__push@llvm14.snap new file mode 100644 index 000000000..6e9be48bc --- /dev/null +++ b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__push@llvm14.snap @@ -0,0 +1,20 @@ +--- +source: hugr-llvm/src/extension/collections.rs +expression: mod_str +--- +; ModuleID = 'test_context' +source_filename = "test_context" + +define i8* @_hl.main.1(i8* %0, i16 %1) { +alloca_block: + br label %entry_block + +entry_block: ; preds = %alloca_block + %2 = alloca i16, align 2 + store i16 %1, i16* %2, align 2 + %3 = bitcast i16* %2 to i8* + call void @__rt__list__push(i8* %0, i8* %3) + ret i8* %0 +} + +declare void @__rt__list__push(i8*, i8*) diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__push@pre-mem2reg@llvm14.snap b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__push@pre-mem2reg@llvm14.snap new file mode 100644 index 000000000..5e88ddf5a --- /dev/null +++ b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__push@pre-mem2reg@llvm14.snap @@ -0,0 +1,32 @@ +--- +source: hugr-llvm/src/extension/collections.rs +expression: mod_str +--- +; ModuleID = 'test_context' +source_filename = "test_context" + +define i8* @_hl.main.1(i8* %0, i16 %1) { +alloca_block: + %"0" = alloca i8*, align 8 + %"2_0" = alloca i8*, align 8 + %"2_1" = alloca i16, align 2 + %"4_0" = alloca i8*, align 8 + br label %entry_block + +entry_block: ; preds = %alloca_block + store i8* %0, i8** %"2_0", align 8 + store i16 %1, i16* %"2_1", align 2 + %"2_01" = load i8*, i8** %"2_0", align 8 + %"2_12" = load i16, i16* %"2_1", align 2 + %2 = alloca i16, align 2 + store i16 %"2_12", i16* %2, align 2 + %3 = bitcast i16* %2 to i8* + call void @__rt__list__push(i8* %"2_01", i8* %3) + store i8* %"2_01", i8** %"4_0", align 8 + %"4_03" = load i8*, i8** %"4_0", align 8 + store i8* %"4_03", i8** %"0", align 8 + %"04" = load i8*, i8** %"0", align 8 + ret i8* %"04" +} + +declare void @__rt__list__push(i8*, i8*) diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__set@llvm14.snap b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__set@llvm14.snap new file mode 100644 index 000000000..f2b0ac21a --- /dev/null +++ b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__set@llvm14.snap @@ -0,0 +1,29 @@ +--- +source: hugr-llvm/src/extension/collections.rs +expression: mod_str +--- +; ModuleID = 'test_context' +source_filename = "test_context" + +define { i8*, { i32, { i16 }, { i16 } } } @_hl.main.1(i8* %0, i64 %1, i16 %2) { +alloca_block: + br label %entry_block + +entry_block: ; preds = %alloca_block + %3 = alloca i16, align 2 + store i16 %2, i16* %3, align 2 + %4 = bitcast i16* %3 to i8* + %5 = call i1 @__rt__list__set(i8* %0, i64 %1, i8* %4) + %6 = bitcast i8* %4 to i16* + %7 = load i16, i16* %6, align 2 + %8 = insertvalue { i16 } undef, i16 %7, 0 + %9 = insertvalue { i32, { i16 }, { i16 } } { i32 0, { i16 } poison, { i16 } poison }, { i16 } %8, 1 + %10 = insertvalue { i16 } undef, i16 %2, 0 + %11 = insertvalue { i32, { i16 }, { i16 } } { i32 1, { i16 } poison, { i16 } poison }, { i16 } %10, 2 + %12 = select i1 %5, { i32, { i16 }, { i16 } } %11, { i32, { i16 }, { i16 } } %9 + %mrv = insertvalue { i8*, { i32, { i16 }, { i16 } } } undef, i8* %0, 0 + %mrv8 = insertvalue { i8*, { i32, { i16 }, { i16 } } } %mrv, { i32, { i16 }, { i16 } } %12, 1 + ret { i8*, { i32, { i16 }, { i16 } } } %mrv8 +} + +declare i1 @__rt__list__set(i8*, i64, i8*) diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__set@pre-mem2reg@llvm14.snap b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__set@pre-mem2reg@llvm14.snap new file mode 100644 index 000000000..ba89dc6cc --- /dev/null +++ b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__collections__test__set@pre-mem2reg@llvm14.snap @@ -0,0 +1,50 @@ +--- +source: hugr-llvm/src/extension/collections.rs +expression: mod_str +--- +; ModuleID = 'test_context' +source_filename = "test_context" + +define { i8*, { i32, { i16 }, { i16 } } } @_hl.main.1(i8* %0, i64 %1, i16 %2) { +alloca_block: + %"0" = alloca i8*, align 8 + %"1" = alloca { i32, { i16 }, { i16 } }, align 8 + %"2_0" = alloca i8*, align 8 + %"2_1" = alloca i64, align 8 + %"2_2" = alloca i16, align 2 + %"4_0" = alloca i8*, align 8 + %"4_1" = alloca { i32, { i16 }, { i16 } }, align 8 + br label %entry_block + +entry_block: ; preds = %alloca_block + store i8* %0, i8** %"2_0", align 8 + store i64 %1, i64* %"2_1", align 4 + store i16 %2, i16* %"2_2", align 2 + %"2_01" = load i8*, i8** %"2_0", align 8 + %"2_12" = load i64, i64* %"2_1", align 4 + %"2_23" = load i16, i16* %"2_2", align 2 + %3 = alloca i16, align 2 + store i16 %"2_23", i16* %3, align 2 + %4 = bitcast i16* %3 to i8* + %5 = call i1 @__rt__list__set(i8* %"2_01", i64 %"2_12", i8* %4) + %6 = bitcast i8* %4 to i16* + %7 = load i16, i16* %6, align 2 + %8 = insertvalue { i16 } undef, i16 %7, 0 + %9 = insertvalue { i32, { i16 }, { i16 } } { i32 0, { i16 } poison, { i16 } poison }, { i16 } %8, 1 + %10 = insertvalue { i16 } undef, i16 %"2_23", 0 + %11 = insertvalue { i32, { i16 }, { i16 } } { i32 1, { i16 } poison, { i16 } poison }, { i16 } %10, 2 + %12 = select i1 %5, { i32, { i16 }, { i16 } } %11, { i32, { i16 }, { i16 } } %9 + store i8* %"2_01", i8** %"4_0", align 8 + store { i32, { i16 }, { i16 } } %12, { i32, { i16 }, { i16 } }* %"4_1", align 4 + %"4_04" = load i8*, i8** %"4_0", align 8 + %"4_15" = load { i32, { i16 }, { i16 } }, { i32, { i16 }, { i16 } }* %"4_1", align 4 + store i8* %"4_04", i8** %"0", align 8 + store { i32, { i16 }, { i16 } } %"4_15", { i32, { i16 }, { i16 } }* %"1", align 4 + %"06" = load i8*, i8** %"0", align 8 + %"17" = load { i32, { i16 }, { i16 } }, { i32, { i16 }, { i16 } }* %"1", align 4 + %mrv = insertvalue { i8*, { i32, { i16 }, { i16 } } } undef, i8* %"06", 0 + %mrv8 = insertvalue { i8*, { i32, { i16 }, { i16 } } } %mrv, { i32, { i16 }, { i16 } } %"17", 1 + ret { i8*, { i32, { i16 }, { i16 } } } %mrv8 +} + +declare i1 @__rt__list__set(i8*, i64, i8*) diff --git a/hugr-llvm/src/lib.rs b/hugr-llvm/src/lib.rs index 98c161658..75cb4ff58 100644 --- a/hugr-llvm/src/lib.rs +++ b/hugr-llvm/src/lib.rs @@ -79,3 +79,6 @@ pub fn llvm_version() -> &'static str { pub mod test; pub use custom::{CodegenExtension, CodegenExtsBuilder}; + +pub use inkwell; +pub use inkwell::llvm_sys; diff --git a/hugr-llvm/src/utils/array_op_builder.rs b/hugr-llvm/src/utils/array_op_builder.rs index d41a0d05d..167584806 100644 --- a/hugr-llvm/src/utils/array_op_builder.rs +++ b/hugr-llvm/src/utils/array_op_builder.rs @@ -119,7 +119,7 @@ pub mod test { builder::{DFGBuilder, HugrBuilder}, extension::{ prelude::{ - array_type, either_type, option_type, ConstUsize, UnwrapBuilder as _, USIZE_T, + array_type, either_type, option_type, usize_t, ConstUsize, UnwrapBuilder as _, }, PRELUDE_REGISTRY, }, @@ -139,11 +139,11 @@ pub mod test { let us0 = builder.add_load_value(ConstUsize::new(0)); let us1 = builder.add_load_value(ConstUsize::new(1)); let us2 = builder.add_load_value(ConstUsize::new(2)); - let arr = builder.add_new_array(USIZE_T, [us1, us2]).unwrap(); + let arr = builder.add_new_array(usize_t(), [us1, us2]).unwrap(); let [arr] = { - let r = builder.add_array_swap(USIZE_T, 2, arr, us0, us1).unwrap(); + let r = builder.add_array_swap(usize_t(), 2, arr, us0, us1).unwrap(); let res_sum_ty = { - let array_type = array_type(2, USIZE_T); + let array_type = array_type(2, usize_t()); either_type(array_type.clone(), array_type) }; builder @@ -152,16 +152,18 @@ pub mod test { }; let [elem_0] = { - let r = builder.add_array_get(USIZE_T, 2, arr, us0).unwrap(); + let r = builder.add_array_get(usize_t(), 2, arr, us0).unwrap(); builder - .build_unwrap_sum(&PRELUDE_REGISTRY, 1, option_type(USIZE_T), r) + .build_unwrap_sum(&PRELUDE_REGISTRY, 1, option_type(usize_t()), r) .unwrap() }; let [_elem_1, arr] = { - let r = builder.add_array_set(USIZE_T, 2, arr, us1, elem_0).unwrap(); + let r = builder + .add_array_set(usize_t(), 2, arr, us1, elem_0) + .unwrap(); let res_sum_ty = { - let row = vec![USIZE_T, array_type(2, USIZE_T)]; + let row = vec![usize_t(), array_type(2, usize_t())]; either_type(row.clone(), row) }; builder @@ -170,29 +172,29 @@ pub mod test { }; let [_elem_left, arr] = { - let r = builder.add_array_pop_left(USIZE_T, 2, arr).unwrap(); + let r = builder.add_array_pop_left(usize_t(), 2, arr).unwrap(); builder .build_unwrap_sum( &PRELUDE_REGISTRY, 1, - option_type(vec![USIZE_T, array_type(1, USIZE_T)]), + option_type(vec![usize_t(), array_type(1, usize_t())]), r, ) .unwrap() }; let [_elem_right, arr] = { - let r = builder.add_array_pop_right(USIZE_T, 1, arr).unwrap(); + let r = builder.add_array_pop_right(usize_t(), 1, arr).unwrap(); builder .build_unwrap_sum( &PRELUDE_REGISTRY, 1, - option_type(vec![USIZE_T, array_type(0, USIZE_T)]), + option_type(vec![usize_t(), array_type(0, usize_t())]), r, ) .unwrap() }; - builder.add_array_discard_empty(USIZE_T, arr).unwrap(); + builder.add_array_discard_empty(usize_t(), arr).unwrap(); builder } diff --git a/hugr-llvm/src/utils/inline_constant_functions.rs b/hugr-llvm/src/utils/inline_constant_functions.rs index 638676194..aef52cd4d 100644 --- a/hugr-llvm/src/utils/inline_constant_functions.rs +++ b/hugr-llvm/src/utils/inline_constant_functions.rs @@ -94,7 +94,7 @@ mod test { Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, ModuleBuilder, }, - extension::{prelude::QB_T, PRELUDE_REGISTRY}, + extension::{prelude::qb_t, PRELUDE_REGISTRY}, ops::{CallIndirect, Const, Value}, types::Signature, Hugr, HugrView, Wire, @@ -104,7 +104,7 @@ mod test { fn build_const(go: impl FnOnce(&mut DFGBuilder) -> Wire) -> Const { Value::function({ - let mut builder = DFGBuilder::new(Signature::new_endo(QB_T)).unwrap(); + let mut builder = DFGBuilder::new(Signature::new_endo(qb_t())).unwrap(); let r = go(&mut builder); builder .finish_hugr_with_outputs([r], &PRELUDE_REGISTRY) @@ -116,7 +116,7 @@ mod test { #[test] fn simple() { - let qb_sig: Signature = Signature::new_endo(QB_T); + let qb_sig: Signature = Signature::new_endo(qb_t()); let mut hugr = { let mut builder = ModuleBuilder::new(); let const_node = builder.add_constant(build_const(|builder| { @@ -152,7 +152,7 @@ mod test { #[test] fn nested() { - let qb_sig: Signature = Signature::new_endo(QB_T); + let qb_sig: Signature = Signature::new_endo(qb_t()); let mut hugr = { let mut builder = ModuleBuilder::new(); let const_node = builder.add_constant(build_const(|builder| { diff --git a/hugr-model/capnp/hugr-v0.capnp b/hugr-model/capnp/hugr-v0.capnp index 94341beba..366de92eb 100644 --- a/hugr-model/capnp/hugr-v0.capnp +++ b/hugr-model/capnp/hugr-v0.capnp @@ -175,13 +175,25 @@ struct Term { } struct ListTerm { - items @0 :List(TermId); - tail @1 :OptionalTermId; + items @0 :List(ListPart); + } + + struct ListPart { + union { + item @0 :TermId; + splice @1 :TermId; + } } struct ExtSet { - extensions @0 :List(Text); - rest @1 :OptionalTermId; + items @0 :List(ExtSetPart); + } + + struct ExtSetPart { + union { + extension @0 :Text; + splice @1 :TermId; + } } struct FuncType { diff --git a/hugr-model/src/lib.rs b/hugr-model/src/lib.rs index c0bf1536f..6d50ff8a0 100644 --- a/hugr-model/src/lib.rs +++ b/hugr-model/src/lib.rs @@ -4,6 +4,7 @@ //! are not designed for efficient traversal or modification, but for simplicity and serialization. pub mod v0; +#[allow(clippy::needless_lifetimes)] pub(crate) mod hugr_v0_capnp { include!(concat!(env!("OUT_DIR"), "/hugr_v0_capnp.rs")); } diff --git a/hugr-model/src/v0/binary/read.rs b/hugr-model/src/v0/binary/read.rs index 5381a7dc8..2dfe67efc 100644 --- a/hugr-model/src/v0/binary/read.rs +++ b/hugr-model/src/v0/binary/read.rs @@ -296,9 +296,8 @@ fn read_term<'a>(bump: &'a Bump, reader: hugr_capnp::term::Reader) -> ReadResult Which::List(reader) => { let reader = reader?; - let items = read_scalar_list!(bump, reader, get_items, model::TermId); - let tail = reader.get_tail().checked_sub(1).map(model::TermId); - model::Term::List { items, tail } + let parts = read_list!(bump, reader, get_items, read_list_part); + model::Term::List { parts } } Which::ListType(item_type) => model::Term::ListType { @@ -307,18 +306,8 @@ fn read_term<'a>(bump: &'a Bump, reader: hugr_capnp::term::Reader) -> ReadResult Which::ExtSet(reader) => { let reader = reader?; - - let extensions = { - let extensions_reader = reader.get_extensions()?; - let mut extensions = BumpVec::with_capacity_in(extensions_reader.len() as _, bump); - for extension_reader in extensions_reader.iter() { - extensions.push(bump.alloc_str(extension_reader?.to_str()?) as &str); - } - extensions.into_bump_slice() - }; - - let rest = reader.get_rest().checked_sub(1).map(model::TermId); - model::Term::ExtSet { extensions, rest } + let parts = read_list!(bump, reader, get_items, read_ext_set_part); + model::Term::ExtSet { parts } } Which::Adt(variants) => model::Term::Adt { @@ -356,6 +345,28 @@ fn read_meta_item<'a>( Ok(model::MetaItem { name, value }) } +fn read_list_part( + _: &Bump, + reader: hugr_capnp::term::list_part::Reader, +) -> ReadResult { + use hugr_capnp::term::list_part::Which; + Ok(match reader.which()? { + Which::Item(term) => model::ListPart::Item(model::TermId(term)), + Which::Splice(list) => model::ListPart::Splice(model::TermId(list)), + }) +} + +fn read_ext_set_part<'a>( + bump: &'a Bump, + reader: hugr_capnp::term::ext_set_part::Reader, +) -> ReadResult> { + use hugr_capnp::term::ext_set_part::Which; + Ok(match reader.which()? { + Which::Extension(ext) => model::ExtSetPart::Extension(bump.alloc_str(ext?.to_str()?)), + Which::Splice(list) => model::ExtSetPart::Splice(model::TermId(list)), + }) +} + fn read_param<'a>( bump: &'a Bump, reader: hugr_capnp::param::Reader, diff --git a/hugr-model/src/v0/binary/write.rs b/hugr-model/src/v0/binary/write.rs index f3a0a14d2..aa377e2ec 100644 --- a/hugr-model/src/v0/binary/write.rs +++ b/hugr-model/src/v0/binary/write.rs @@ -187,16 +187,14 @@ fn write_term(mut builder: hugr_capnp::term::Builder, term: &model::Term) { let _ = builder.set_args(model::TermId::unwrap_slice(args)); } - model::Term::List { items, tail } => { + model::Term::List { parts } => { let mut builder = builder.init_list(); - let _ = builder.set_items(model::TermId::unwrap_slice(items)); - builder.set_tail(tail.map_or(0, |t| t.0 + 1)); + write_list!(builder, init_items, write_list_item, parts); } - model::Term::ExtSet { extensions, rest } => { + model::Term::ExtSet { parts } => { let mut builder = builder.init_ext_set(); - let _ = builder.set_extensions(*extensions); - builder.set_rest(rest.map_or(0, |t| t.0 + 1)); + write_list!(builder, init_items, write_ext_set_item, parts); } model::Term::FuncType { @@ -215,3 +213,20 @@ fn write_term(mut builder: hugr_capnp::term::Builder, term: &model::Term) { } } } + +fn write_list_item(mut builder: hugr_capnp::term::list_part::Builder, item: &model::ListPart) { + match item { + model::ListPart::Item(term_id) => builder.set_item(term_id.0), + model::ListPart::Splice(term_id) => builder.set_splice(term_id.0), + } +} + +fn write_ext_set_item( + mut builder: hugr_capnp::term::ext_set_part::Builder, + item: &model::ExtSetPart, +) { + match item { + model::ExtSetPart::Extension(ext) => builder.set_extension(ext), + model::ExtSetPart::Splice(term_id) => builder.set_splice(term_id.0), + } +} diff --git a/hugr-model/src/v0/mod.rs b/hugr-model/src/v0/mod.rs index 16c7cb6c6..2b0dc1eaf 100644 --- a/hugr-model/src/v0/mod.rs +++ b/hugr-model/src/v0/mod.rs @@ -569,19 +569,10 @@ pub enum Term<'a> { r#type: TermId, }, - /// A list, with an optional tail. - /// - /// - `[ITEM-0 ... ITEM-n] : (list T)` where `T : static`, `ITEM-i : T`. - /// - `[ITEM-0 ... ITEM-n . TAIL] : (list item-type)` where `T : static`, `ITEM-i : T`, `TAIL : (list T)`. + /// A list. May include individual items or other lists to be spliced in. List { - /// The items in the list. - /// - /// `item-i : item-type` - items: &'a [TermId], - /// The tail of the list. - /// - /// `tail : (list item-type)` - tail: Option, + /// The parts of the list. + parts: &'a [ListPart], }, /// The type of lists, given a type for the items. @@ -615,14 +606,11 @@ pub enum Term<'a> { NatType, /// Extension set. - /// - /// - `(ext EXT-0 ... EXT-n) : ext-set` - /// - `(ext EXT-0 ... EXT-n . REST) : ext-set` where `REST : ext-set`. ExtSet { - /// The items in the extension set. - extensions: &'a [&'a str], - /// The rest of the extension set. - rest: Option, + /// The parts of the extension set. + /// + /// Since extension sets are unordered, the parts may occur in any order. + parts: &'a [ExtSetPart<'a>], }, /// The type of extension sets. @@ -676,6 +664,24 @@ pub enum Term<'a> { }, } +/// A part of a list term. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum ListPart { + /// A single item. + Item(TermId), + /// A list to be spliced into the parent list. + Splice(TermId), +} + +/// A part of an extension set term. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum ExtSetPart<'a> { + /// An extension. + Extension(&'a str), + /// An extension set to be spliced into the parent extension set. + Splice(TermId), +} + /// A parameter to a function or alias. /// /// Parameter names must be unique within a parameter list. diff --git a/hugr-model/src/v0/text/hugr.pest b/hugr-model/src/v0/text/hugr.pest index d05e3d774..fc52b8271 100644 --- a/hugr-model/src/v0/text/hugr.pest +++ b/hugr-model/src/v0/text/hugr.pest @@ -10,8 +10,6 @@ string_raw = @{ (!("\\" | "\"") ~ ANY)+ } string_escape = @{ "\\" ~ ("\"" | "\\" | "n" | "r" | "t") } string_unicode = @{ "\\u" ~ "{" ~ ASCII_HEX_DIGIT+ ~ "}" } -list_tail = { "." } - module = { "(" ~ "hugr" ~ "0" ~ ")" ~ meta* ~ node* ~ EOI } meta = { "(" ~ "meta" ~ symbol ~ term ~ ")" } @@ -103,16 +101,18 @@ term_var = { "?" ~ identifier } term_apply_full = { ("(" ~ "@" ~ symbol ~ term* ~ ")") } term_apply = { symbol | ("(" ~ symbol ~ term* ~ ")") } term_quote = { "(" ~ "quote" ~ term ~ ")" } -term_list = { "[" ~ term* ~ (list_tail ~ term)? ~ "]" } +term_list = { "[" ~ (spliced_term | term)* ~ "]" } term_list_type = { "(" ~ "list" ~ term ~ ")" } term_str = { string } term_str_type = { "str" } term_nat = { (ASCII_DIGIT)+ } term_nat_type = { "nat" } -term_ext_set = { "(" ~ "ext" ~ ext_name* ~ (list_tail ~ term)? ~ ")" } +term_ext_set = { "(" ~ "ext" ~ (spliced_term | ext_name)* ~ ")" } term_ext_set_type = { "ext-set" } term_adt = { "(" ~ "adt" ~ term ~ ")" } term_func_type = { "(" ~ "fn" ~ term ~ term ~ term ~ ")" } term_ctrl = { "(" ~ "ctrl" ~ term ~ ")" } term_ctrl_type = { "ctrl" } term_non_linear = { "(" ~ "nonlinear" ~ term ~ ")" } + +spliced_term = { term ~ "..." } diff --git a/hugr-model/src/v0/text/parse.rs b/hugr-model/src/v0/text/parse.rs index 370dbeac0..8527f1a00 100644 --- a/hugr-model/src/v0/text/parse.rs +++ b/hugr-model/src/v0/text/parse.rs @@ -1,4 +1,4 @@ -use bumpalo::{collections::String as BumpString, Bump}; +use bumpalo::{collections::String as BumpString, collections::Vec as BumpVec, Bump}; use pest::{ iterators::{Pair, Pairs}, Parser, RuleType, @@ -6,8 +6,9 @@ use pest::{ use thiserror::Error; use crate::v0::{ - AliasDecl, ConstructorDecl, FuncDecl, GlobalRef, LinkRef, LocalRef, MetaItem, Module, Node, - NodeId, Operation, OperationDecl, Param, ParamSort, Region, RegionId, RegionKind, Term, TermId, + AliasDecl, ConstructorDecl, ExtSetPart, FuncDecl, GlobalRef, LinkRef, ListPart, LocalRef, + MetaItem, Module, Node, NodeId, Operation, OperationDecl, Param, ParamSort, Region, RegionId, + RegionKind, Term, TermId, }; mod pest_parser { @@ -136,21 +137,21 @@ impl<'a> ParseContext<'a> { } Rule::term_list => { - let mut items = Vec::new(); - let mut tail = None; + let mut parts = BumpVec::with_capacity_in(inner.len(), self.bump); - for token in filter_rule(&mut inner, Rule::term) { - items.push(self.parse_term(token)?); - } - - if inner.next().is_some() { - let token = inner.next().unwrap(); - tail = Some(self.parse_term(token)?); + for token in inner { + match token.as_rule() { + Rule::term => parts.push(ListPart::Item(self.parse_term(token)?)), + Rule::spliced_term => { + let term_token = token.into_inner().next().unwrap(); + parts.push(ListPart::Splice(self.parse_term(term_token)?)) + } + _ => unreachable!(), + } } Term::List { - items: self.bump.alloc_slice_copy(&items), - tail, + parts: parts.into_bump_slice(), } } @@ -170,21 +171,23 @@ impl<'a> ParseContext<'a> { } Rule::term_ext_set => { - let mut extensions = Vec::new(); - let mut rest = None; + let mut parts = BumpVec::with_capacity_in(inner.len(), self.bump); - for token in filter_rule(&mut inner, Rule::ext_name) { - extensions.push(token.as_str()); - } - - if inner.next().is_some() { - let token = inner.next().unwrap(); - rest = Some(self.parse_term(token)?); + for token in inner { + match token.as_rule() { + Rule::ext_name => { + parts.push(ExtSetPart::Extension(self.bump.alloc_str(token.as_str()))) + } + Rule::spliced_term => { + let term_token = token.into_inner().next().unwrap(); + parts.push(ExtSetPart::Splice(self.parse_term(term_token)?)) + } + _ => unreachable!(), + } } Term::ExtSet { - extensions: self.bump.alloc_slice_copy(&extensions), - rest, + parts: parts.into_bump_slice(), } } diff --git a/hugr-model/src/v0/text/print.rs b/hugr-model/src/v0/text/print.rs index a320d4664..ac35b4cd4 100644 --- a/hugr-model/src/v0/text/print.rs +++ b/hugr-model/src/v0/text/print.rs @@ -2,8 +2,8 @@ use pretty::{Arena, DocAllocator, RefDoc}; use std::borrow::Cow; use crate::v0::{ - GlobalRef, LinkRef, LocalRef, MetaItem, ModelError, Module, NodeId, Operation, Param, - ParamSort, RegionId, RegionKind, Term, TermId, + ExtSetPart, GlobalRef, LinkRef, ListPart, LocalRef, MetaItem, ModelError, Module, NodeId, + Operation, Param, ParamSort, RegionId, RegionKind, Term, TermId, }; type PrintError = ModelError; @@ -521,16 +521,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_text("quote"); this.print_term(*r#type) }), - Term::List { items, tail } => self.print_brackets(|this| { - for item in items.iter() { - this.print_term(*item)?; - } - if let Some(tail) = tail { - this.print_text("."); - this.print_term(*tail)?; - } - Ok(()) - }), + Term::List { .. } => self.print_brackets(|this| this.print_list_parts(term_id)), Term::ListType { item_type } => self.print_parens(|this| { this.print_text("list"); this.print_term(*item_type) @@ -551,15 +542,9 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { self.print_text("nat"); Ok(()) } - Term::ExtSet { extensions, rest } => self.print_parens(|this| { + Term::ExtSet { .. } => self.print_parens(|this| { this.print_text("ext"); - for extension in *extensions { - this.print_text(*extension); - } - if let Some(rest) = rest { - this.print_text("."); - this.print_term(*rest)?; - } + this.print_ext_set_parts(term_id)?; Ok(()) }), Term::ExtSetType => { @@ -595,6 +580,54 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { } } + /// Prints the contents of a list. + /// + /// This is used so that spliced lists are merged into the parent list. + fn print_list_parts(&mut self, term_id: TermId) -> PrintResult<()> { + let term_data = self + .module + .get_term(term_id) + .ok_or(PrintError::TermNotFound(term_id))?; + + if let Term::List { parts } = term_data { + for part in *parts { + match part { + ListPart::Item(term) => self.print_term(*term)?, + ListPart::Splice(list) => self.print_list_parts(*list)?, + } + } + } else { + self.print_term(term_id)?; + self.print_text("..."); + } + + Ok(()) + } + + /// Prints the contents of an extension set. + /// + /// This is used so that spliced extension sets are merged into the parent extension set. + fn print_ext_set_parts(&mut self, term_id: TermId) -> PrintResult<()> { + let term_data = self + .module + .get_term(term_id) + .ok_or(PrintError::TermNotFound(term_id))?; + + if let Term::ExtSet { parts } = term_data { + for part in *parts { + match part { + ExtSetPart::Extension(ext) => self.print_text(*ext), + ExtSetPart::Splice(list) => self.print_ext_set_parts(*list)?, + } + } + } else { + self.print_term(term_id)?; + self.print_text("..."); + } + + Ok(()) + } + fn print_local_ref(&mut self, local_ref: LocalRef<'a>) -> PrintResult<()> { let name = match local_ref { LocalRef::Index(_, i) => { diff --git a/hugr-model/tests/binary.rs b/hugr-model/tests/binary.rs index 17609c9e4..80157c23e 100644 --- a/hugr-model/tests/binary.rs +++ b/hugr-model/tests/binary.rs @@ -58,3 +58,8 @@ pub fn test_decl_exts() { pub fn test_constraints() { binary_roundtrip(include_str!("fixtures/model-constraints.edn")); } + +#[test] +pub fn test_lists() { + binary_roundtrip(include_str!("fixtures/model-lists.edn")); +} diff --git a/hugr-model/tests/fixtures/model-call.edn b/hugr-model/tests/fixtures/model-call.edn index ce849a772..87c6f7a3a 100644 --- a/hugr-model/tests/fixtures/model-call.edn +++ b/hugr-model/tests/fixtures/model-call.edn @@ -2,7 +2,7 @@ (declare-func example.callee (forall ?ext ext-set) - [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext arithmetic.int . ?ext) + [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext arithmetic.int ?ext ...) (meta doc.title (prelude.json "\"Callee\"")) (meta doc.description (prelude.json "\"This is a function declaration.\""))) diff --git a/hugr-model/tests/fixtures/model-constraints.edn b/hugr-model/tests/fixtures/model-constraints.edn index 5db6b9886..ddfbd659f 100644 --- a/hugr-model/tests/fixtures/model-constraints.edn +++ b/hugr-model/tests/fixtures/model-constraints.edn @@ -4,10 +4,10 @@ (forall ?t type) (forall ?n nat) (where (nonlinear ?t)) - [?t] [(@ array.Array ?t ?n)] + [?t] [(@ prelude.Array ?t ?n)] (ext)) (declare-func array.copy (forall ?t type) (where (nonlinear ?t)) - [(@ array.Array ?t)] [(@ array.Array ?t) (@ array.Array ?t)] (ext)) + [(@ prelude.Array ?t)] [(@ prelude.Array ?t) (@ prelude.Array ?t)] (ext)) diff --git a/hugr-model/tests/fixtures/model-lists.edn b/hugr-model/tests/fixtures/model-lists.edn new file mode 100644 index 000000000..1385a0e2a --- /dev/null +++ b/hugr-model/tests/fixtures/model-lists.edn @@ -0,0 +1,21 @@ +(hugr 0) + +(declare-operation core.call-indirect + (forall ?inputs (list type)) + (forall ?outputs (list type)) + (forall ?exts ext-set) + (fn [(fn ?inputs ?outputs ?exts) ?inputs ...] ?outputs ?exts)) + +(declare-operation core.compose-parallel + (forall ?inputs-0 (list type)) + (forall ?inputs-1 (list type)) + (forall ?outputs-0 (list type)) + (forall ?outputs-1 (list type)) + (forall ?exts ext-set) + (fn + [(fn ?inputs-0 ?outputs-0 ?exts) + (fn ?inputs-1 ?outputs-1 ?exts) + ?inputs-0 ... + ?inputs-1 ...] + [?outputs-0 ... ?outputs-1 ...] + ?exts)) diff --git a/hugr-passes/Cargo.toml b/hugr-passes/Cargo.toml index 24a674051..67136329b 100644 --- a/hugr-passes/Cargo.toml +++ b/hugr-passes/Cargo.toml @@ -17,6 +17,7 @@ bench = false [dependencies] hugr-core = { path = "../hugr-core", version = "0.13.3" } +ascent = { version = "0.7.0" } itertools = { workspace = true } lazy_static = { workspace = true } paste = { workspace = true } @@ -28,3 +29,6 @@ extension_inference = ["hugr-core/extension_inference"] [dev-dependencies] rstest = { workspace = true } +proptest = { workspace = true } +proptest-derive = { workspace = true } +proptest-recurse = { version = "0.5.0" } diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index b5e303d43..351f4a19e 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -218,4 +218,4 @@ pub fn constant_fold_pass(h: &mut H, reg: &ExtensionRegistry) { } #[cfg(test)] -mod test; +pub(crate) mod test; diff --git a/hugr-passes/src/const_fold/test.rs b/hugr-passes/src/const_fold/test.rs index 473a37ac9..8b54f7e93 100644 --- a/hugr-passes/src/const_fold/test.rs +++ b/hugr-passes/src/const_fold/test.rs @@ -1,14 +1,13 @@ use crate::const_fold::constant_fold_pass; +use crate::test::TEST_REG; use hugr_core::builder::{DFGBuilder, Dataflow, DataflowHugr}; use hugr_core::extension::prelude::{ - const_ok, sum_with_error, ConstError, ConstString, UnpackTuple, BOOL_T, ERROR_TYPE, STRING_TYPE, + bool_t, const_ok, error_type, string_type, sum_with_error, ConstError, ConstString, UnpackTuple, }; -use hugr_core::extension::{ExtensionRegistry, PRELUDE}; use hugr_core::ops::Value; -use hugr_core::std_extensions::arithmetic; use hugr_core::std_extensions::arithmetic::int_ops::IntOpDef; use hugr_core::std_extensions::arithmetic::int_types::{ConstInt, INT_TYPES}; -use hugr_core::std_extensions::logic::{self, LogicOp}; +use hugr_core::std_extensions::logic::LogicOp; use hugr_core::type_row; use hugr_core::types::{Signature, Type, TypeRow, TypeRowRV}; @@ -21,7 +20,7 @@ use hugr_core::builder::Container; use hugr_core::ops::OpType; use hugr_core::std_extensions::arithmetic::conversions::ConvertOpDef; use hugr_core::std_extensions::arithmetic::float_ops::FloatOps; -use hugr_core::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE}; +use hugr_core::std_extensions::arithmetic::float_types::{float64_type, ConstF64}; /// Check that a hugr just loads and returns a single expected constant. pub fn assert_fully_folded(h: &Hugr, expected_value: &Value) { @@ -93,7 +92,7 @@ fn test_big() { let unpack = build .add_dataflow_op( - UnpackTuple::new(type_row![FLOAT64_TYPE, FLOAT64_TYPE]), + UnpackTuple::new(vec![float64_type(), float64_type()].into()), [tup], ) .unwrap(); @@ -105,22 +104,14 @@ fn test_big() { .add_dataflow_op(ConvertOpDef::trunc_u.with_log_width(5), sub.outputs()) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - arithmetic::float_types::EXTENSION.to_owned(), - arithmetic::float_ops::EXTENSION.to_owned(), - arithmetic::conversions::EXTENSION.to_owned(), - ]) - .unwrap(); let mut h = build - .finish_hugr_with_outputs(to_int.outputs(), ®) + .finish_hugr_with_outputs(to_int.outputs(), &TEST_REG) .unwrap(); assert_eq!(h.node_count(), 8); - constant_fold_pass(&mut h, ®); + constant_fold_pass(&mut h, &TEST_REG); - let expected = const_ok(i2c(2).clone(), ERROR_TYPE); + let expected = const_ok(i2c(2).clone(), error_type()); assert_fully_folded(&h, &expected); } @@ -128,15 +119,9 @@ fn test_big() { #[ignore = "Waiting for `unwrap` operation"] // TODO: https://github.com/CQCL/hugr/issues/1486 fn test_list_ops() -> Result<(), Box> { - use hugr_core::std_extensions::collections::{self, ListOp, ListValue}; + use hugr_core::std_extensions::collections::{ListOp, ListValue}; - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - logic::EXTENSION.to_owned(), - collections::EXTENSION.to_owned(), - ]) - .unwrap(); - let base_list: Value = ListValue::new(BOOL_T, [Value::false_val()]).into(); + let base_list: Value = ListValue::new(bool_t(), [Value::false_val()]).into(); let mut build = DFGBuilder::new(Signature::new( type_row![], vec![base_list.get_type().clone()], @@ -147,27 +132,30 @@ fn test_list_ops() -> Result<(), Box> { let [list, maybe_elem] = build .add_dataflow_op( - ListOp::pop.with_type(BOOL_T).to_extension_op(®).unwrap(), + ListOp::pop + .with_type(bool_t()) + .to_extension_op(&TEST_REG) + .unwrap(), [list], )? .outputs_arr(); - // FIXME: Unwrap the Option + // FIXME: Unwrap the Option let elem = maybe_elem; let [list] = build .add_dataflow_op( ListOp::push - .with_type(BOOL_T) - .to_extension_op(®) + .with_type(bool_t()) + .to_extension_op(&TEST_REG) .unwrap(), [list, elem], )? .outputs_arr(); - let mut h = build.finish_hugr_with_outputs([list], ®)?; + let mut h = build.finish_hugr_with_outputs([list], &TEST_REG)?; - constant_fold_pass(&mut h, ®); + constant_fold_pass(&mut h, &TEST_REG); assert_fully_folded(&h, &base_list); Ok(()) @@ -179,14 +167,14 @@ fn test_fold_and() { // x0, x1 := bool(true), bool(true) // x2 := and(x0, x1) // output x2 == true; - let mut build = DFGBuilder::new(noargfn(BOOL_T)).unwrap(); + let mut build = DFGBuilder::new(noargfn(bool_t())).unwrap(); let x0 = build.add_load_const(Value::true_val()); let x1 = build.add_load_const(Value::true_val()); let x2 = build.add_dataflow_op(LogicOp::And, [x0, x1]).unwrap(); - let reg = - ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::true_val(); assert_fully_folded(&h, &expected); } @@ -197,14 +185,14 @@ fn test_fold_or() { // x0, x1 := bool(true), bool(false) // x2 := or(x0, x1) // output x2 == true; - let mut build = DFGBuilder::new(noargfn(BOOL_T)).unwrap(); + let mut build = DFGBuilder::new(noargfn(bool_t())).unwrap(); let x0 = build.add_load_const(Value::true_val()); let x1 = build.add_load_const(Value::false_val()); let x2 = build.add_dataflow_op(LogicOp::Or, [x0, x1]).unwrap(); - let reg = - ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::true_val(); assert_fully_folded(&h, &expected); } @@ -215,13 +203,13 @@ fn test_fold_not() { // x0 := bool(true) // x1 := not(x0) // output x1 == false; - let mut build = DFGBuilder::new(noargfn(BOOL_T)).unwrap(); + let mut build = DFGBuilder::new(noargfn(bool_t())).unwrap(); let x0 = build.add_load_const(Value::true_val()); let x1 = build.add_dataflow_op(LogicOp::Not, [x0]).unwrap(); - let reg = - ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap(); - let mut h = build.finish_hugr_with_outputs(x1.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x1.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::false_val(); assert_fully_folded(&h, &expected); } @@ -238,7 +226,7 @@ fn orphan_output() { // with no outputs. use hugr_core::ops::handle::NodeHandle; - let mut build = DFGBuilder::new(noargfn(vec![BOOL_T])).unwrap(); + let mut build = DFGBuilder::new(noargfn(vec![bool_t()])).unwrap(); let true_wire = build.add_load_value(Value::true_val()); // this Not will be manually replaced let orig_not = build.add_dataflow_op(LogicOp::Not, [true_wire]).unwrap(); @@ -247,9 +235,9 @@ fn orphan_output() { .unwrap(); let or_node = r.node(); let parent = build.container_node(); - let reg = - ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap(); - let mut h = build.finish_hugr_with_outputs(r.outputs(), ®).unwrap(); + let mut h = build + .finish_hugr_with_outputs(r.outputs(), &TEST_REG) + .unwrap(); // we delete the original Not and create a new One. This means it will be // traversed by `constant_fold_pass` after the Or. @@ -258,7 +246,7 @@ fn orphan_output() { h.disconnect(or_node, IncomingPort::from(1)); h.connect(new_not, 0, or_node, 1); h.remove_node(orig_not.node()); - constant_fold_pass(&mut h, ®); + constant_fold_pass(&mut h, &TEST_REG); assert_fully_folded(&h, &Value::true_val()) } @@ -275,7 +263,7 @@ fn test_folding_pass_issue_996() { // x6 := flt(x0, x5) // false // x7 := or(x4, x6) // true // output x7 - let mut build = DFGBuilder::new(noargfn(vec![BOOL_T])).unwrap(); + let mut build = DFGBuilder::new(noargfn(vec![bool_t()])).unwrap(); let x0 = build.add_load_const(Value::extension(ConstF64::new(3.0))); let x1 = build.add_load_const(Value::extension(ConstF64::new(4.0))); let x2 = build.add_dataflow_op(FloatOps::fne, [x0, x1]).unwrap(); @@ -288,45 +276,39 @@ fn test_folding_pass_issue_996() { let x7 = build .add_dataflow_op(LogicOp::Or, x4.outputs().chain(x6.outputs())) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - logic::EXTENSION.to_owned(), - arithmetic::float_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x7.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x7.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::true_val(); assert_fully_folded(&h, &expected); } #[test] fn test_const_fold_to_nonfinite() { - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::float_types::EXTENSION.to_owned(), - ]) - .unwrap(); - // HUGR computing 1.0 / 1.0 - let mut build = DFGBuilder::new(noargfn(vec![FLOAT64_TYPE])).unwrap(); + let mut build = DFGBuilder::new(noargfn(vec![float64_type()])).unwrap(); let x0 = build.add_load_const(Value::extension(ConstF64::new(1.0))); let x1 = build.add_load_const(Value::extension(ConstF64::new(1.0))); let x2 = build.add_dataflow_op(FloatOps::fdiv, [x0, x1]).unwrap(); - let mut h0 = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h0, ®); + let mut h0 = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h0, &TEST_REG); assert_fully_folded_with(&h0, |v| { v.get_custom_value::().unwrap().value() == 1.0 }); assert_eq!(h0.node_count(), 5); // HUGR computing 1.0 / 0.0 - let mut build = DFGBuilder::new(noargfn(vec![FLOAT64_TYPE])).unwrap(); + let mut build = DFGBuilder::new(noargfn(vec![float64_type()])).unwrap(); let x0 = build.add_load_const(Value::extension(ConstF64::new(1.0))); let x1 = build.add_load_const(Value::extension(ConstF64::new(0.0))); let x2 = build.add_dataflow_op(FloatOps::fdiv, [x0, x1]).unwrap(); - let mut h1 = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h1, ®); + let mut h1 = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h1, &TEST_REG); assert_eq!(h1.node_count(), 8); } @@ -342,13 +324,10 @@ fn test_fold_iwiden_u() { let x1 = build .add_dataflow_op(IntOpDef::iwiden_u.with_two_log_widths(4, 5), [x0]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x1.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x1.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstInt::new_u(5, 13).unwrap()); assert_fully_folded(&h, &expected); } @@ -365,13 +344,10 @@ fn test_fold_iwiden_s() { let x1 = build .add_dataflow_op(IntOpDef::iwiden_s.with_two_log_widths(4, 5), [x0]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x1.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x1.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstInt::new_s(5, -3).unwrap()); assert_fully_folded(&h, &expected); } @@ -416,13 +392,10 @@ fn test_fold_inarrow, E: std::fmt::Debug>( [x0], ) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x1.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x1.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); lazy_static! { static ref INARROW_ERROR_VALUE: ConstError = ConstError { signal: 0, @@ -430,7 +403,7 @@ fn test_fold_inarrow, E: std::fmt::Debug>( }; } let expected = if succeeds { - const_ok(mk_const(to_log_width, val).unwrap().into(), ERROR_TYPE) + const_ok(mk_const(to_log_width, val).unwrap().into(), error_type()) } else { INARROW_ERROR_VALUE.clone().as_either(elem_type) }; @@ -444,18 +417,15 @@ fn test_fold_itobool() { // x0 := int_u<0>(1); // x1 := itobool(x0); // output x1 == true; - let mut build = DFGBuilder::new(noargfn(vec![BOOL_T])).unwrap(); + let mut build = DFGBuilder::new(noargfn(vec![bool_t()])).unwrap(); let x0 = build.add_load_const(Value::extension(ConstInt::new_u(0, 1).unwrap())); let x1 = build .add_dataflow_op(ConvertOpDef::itobool.without_log_width(), [x0]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x1.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x1.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::true_val(); assert_fully_folded(&h, &expected); } @@ -472,13 +442,10 @@ fn test_fold_ifrombool() { let x1 = build .add_dataflow_op(ConvertOpDef::ifrombool.without_log_width(), [x0]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x1.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x1.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstInt::new_u(0, 0).unwrap()); assert_fully_folded(&h, &expected); } @@ -489,19 +456,16 @@ fn test_fold_ieq() { // x0, x1 := int_s<3>(-1), int_u<3>(255) // x2 := ieq(x0, x1) // output x2 == true; - let mut build = DFGBuilder::new(noargfn(vec![BOOL_T])).unwrap(); + let mut build = DFGBuilder::new(noargfn(vec![bool_t()])).unwrap(); let x0 = build.add_load_const(Value::extension(ConstInt::new_s(3, -1).unwrap())); let x1 = build.add_load_const(Value::extension(ConstInt::new_u(3, 255).unwrap())); let x2 = build .add_dataflow_op(IntOpDef::ieq.with_log_width(3), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::true_val(); assert_fully_folded(&h, &expected); } @@ -512,19 +476,16 @@ fn test_fold_ine() { // x0, x1 := int_u<5>(3), int_u<5>(4) // x2 := ine(x0, x1) // output x2 == true; - let mut build = DFGBuilder::new(noargfn(vec![BOOL_T])).unwrap(); + let mut build = DFGBuilder::new(noargfn(vec![bool_t()])).unwrap(); let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 3).unwrap())); let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 4).unwrap())); let x2 = build .add_dataflow_op(IntOpDef::ine.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::true_val(); assert_fully_folded(&h, &expected); } @@ -535,19 +496,16 @@ fn test_fold_ilt_u() { // x0, x1 := int_u<5>(3), int_u<5>(4) // x2 := ilt_u(x0, x1) // output x2 == true; - let mut build = DFGBuilder::new(noargfn(vec![BOOL_T])).unwrap(); + let mut build = DFGBuilder::new(noargfn(vec![bool_t()])).unwrap(); let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 3).unwrap())); let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 4).unwrap())); let x2 = build .add_dataflow_op(IntOpDef::ilt_u.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::true_val(); assert_fully_folded(&h, &expected); } @@ -558,19 +516,16 @@ fn test_fold_ilt_s() { // x0, x1 := int_s<5>(3), int_s<5>(-4) // x2 := ilt_s(x0, x1) // output x2 == false; - let mut build = DFGBuilder::new(noargfn(vec![BOOL_T])).unwrap(); + let mut build = DFGBuilder::new(noargfn(vec![bool_t()])).unwrap(); let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, 3).unwrap())); let x1 = build.add_load_const(Value::extension(ConstInt::new_s(5, -4).unwrap())); let x2 = build .add_dataflow_op(IntOpDef::ilt_s.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::false_val(); assert_fully_folded(&h, &expected); } @@ -581,19 +536,16 @@ fn test_fold_igt_u() { // x0, x1 := int_u<5>(3), int_u<5>(4) // x2 := ilt_u(x0, x1) // output x2 == false; - let mut build = DFGBuilder::new(noargfn(vec![BOOL_T])).unwrap(); + let mut build = DFGBuilder::new(noargfn(vec![bool_t()])).unwrap(); let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 3).unwrap())); let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 4).unwrap())); let x2 = build .add_dataflow_op(IntOpDef::igt_u.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::false_val(); assert_fully_folded(&h, &expected); } @@ -604,19 +556,16 @@ fn test_fold_igt_s() { // x0, x1 := int_s<5>(3), int_s<5>(-4) // x2 := ilt_s(x0, x1) // output x2 == true; - let mut build = DFGBuilder::new(noargfn(vec![BOOL_T])).unwrap(); + let mut build = DFGBuilder::new(noargfn(vec![bool_t()])).unwrap(); let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, 3).unwrap())); let x1 = build.add_load_const(Value::extension(ConstInt::new_s(5, -4).unwrap())); let x2 = build .add_dataflow_op(IntOpDef::igt_s.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::true_val(); assert_fully_folded(&h, &expected); } @@ -627,19 +576,16 @@ fn test_fold_ile_u() { // x0, x1 := int_u<5>(3), int_u<5>(3) // x2 := ile_u(x0, x1) // output x2 == true; - let mut build = DFGBuilder::new(noargfn(vec![BOOL_T])).unwrap(); + let mut build = DFGBuilder::new(noargfn(vec![bool_t()])).unwrap(); let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 3).unwrap())); let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 3).unwrap())); let x2 = build .add_dataflow_op(IntOpDef::ile_u.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::true_val(); assert_fully_folded(&h, &expected); } @@ -650,19 +596,16 @@ fn test_fold_ile_s() { // x0, x1 := int_s<5>(-4), int_s<5>(-4) // x2 := ile_s(x0, x1) // output x2 == true; - let mut build = DFGBuilder::new(noargfn(vec![BOOL_T])).unwrap(); + let mut build = DFGBuilder::new(noargfn(vec![bool_t()])).unwrap(); let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -4).unwrap())); let x1 = build.add_load_const(Value::extension(ConstInt::new_s(5, -4).unwrap())); let x2 = build .add_dataflow_op(IntOpDef::ile_s.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::true_val(); assert_fully_folded(&h, &expected); } @@ -673,19 +616,16 @@ fn test_fold_ige_u() { // x0, x1 := int_u<5>(3), int_u<5>(4) // x2 := ilt_u(x0, x1) // output x2 == false; - let mut build = DFGBuilder::new(noargfn(vec![BOOL_T])).unwrap(); + let mut build = DFGBuilder::new(noargfn(vec![bool_t()])).unwrap(); let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 3).unwrap())); let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 4).unwrap())); let x2 = build .add_dataflow_op(IntOpDef::ige_u.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::false_val(); assert_fully_folded(&h, &expected); } @@ -696,19 +636,16 @@ fn test_fold_ige_s() { // x0, x1 := int_s<5>(3), int_s<5>(-4) // x2 := ilt_s(x0, x1) // output x2 == true; - let mut build = DFGBuilder::new(noargfn(vec![BOOL_T])).unwrap(); + let mut build = DFGBuilder::new(noargfn(vec![bool_t()])).unwrap(); let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, 3).unwrap())); let x1 = build.add_load_const(Value::extension(ConstInt::new_s(5, -4).unwrap())); let x2 = build .add_dataflow_op(IntOpDef::ige_s.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::true_val(); assert_fully_folded(&h, &expected); } @@ -725,13 +662,10 @@ fn test_fold_imax_u() { let x2 = build .add_dataflow_op(IntOpDef::imax_u.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstInt::new_u(5, 11).unwrap()); assert_fully_folded(&h, &expected); } @@ -748,13 +682,10 @@ fn test_fold_imax_s() { let x2 = build .add_dataflow_op(IntOpDef::imax_s.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstInt::new_s(5, 1).unwrap()); assert_fully_folded(&h, &expected); } @@ -771,13 +702,10 @@ fn test_fold_imin_u() { let x2 = build .add_dataflow_op(IntOpDef::imin_u.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstInt::new_u(5, 7).unwrap()); assert_fully_folded(&h, &expected); } @@ -794,13 +722,10 @@ fn test_fold_imin_s() { let x2 = build .add_dataflow_op(IntOpDef::imin_s.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstInt::new_s(5, -2).unwrap()); assert_fully_folded(&h, &expected); } @@ -817,13 +742,10 @@ fn test_fold_iadd() { let x2 = build .add_dataflow_op(IntOpDef::iadd.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstInt::new_s(5, -1).unwrap()); assert_fully_folded(&h, &expected); } @@ -840,13 +762,10 @@ fn test_fold_isub() { let x2 = build .add_dataflow_op(IntOpDef::isub.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstInt::new_s(5, -3).unwrap()); assert_fully_folded(&h, &expected); } @@ -862,13 +781,10 @@ fn test_fold_ineg() { let x2 = build .add_dataflow_op(IntOpDef::ineg.with_log_width(5), [x0]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstInt::new_s(5, 2).unwrap()); assert_fully_folded(&h, &expected); } @@ -885,13 +801,10 @@ fn test_fold_imul() { let x2 = build .add_dataflow_op(IntOpDef::imul.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstInt::new_s(5, -14).unwrap()); assert_fully_folded(&h, &expected); } @@ -911,13 +824,10 @@ fn test_fold_idivmod_checked_u() { let x2 = build .add_dataflow_op(IntOpDef::idivmod_checked_u.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = ConstError { signal: 0, message: "Division by zero".to_string(), @@ -943,13 +853,10 @@ fn test_fold_idivmod_u() { let x4 = build .add_dataflow_op(IntOpDef::iadd.with_log_width(3), [x2, x3]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x4.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x4.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstInt::new_u(3, 8).unwrap()); assert_fully_folded(&h, &expected); } @@ -969,13 +876,10 @@ fn test_fold_idivmod_checked_s() { let x2 = build .add_dataflow_op(IntOpDef::idivmod_checked_s.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = ConstError { signal: 0, message: "Division by zero".to_string(), @@ -1003,13 +907,10 @@ fn test_fold_idivmod_s(#[case] a: i64, #[case] b: u64, #[case] c: i64) { let x4 = build .add_dataflow_op(IntOpDef::iadd.with_log_width(6), [x2, x3]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x4.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x4.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstInt::new_s(6, c).unwrap()); assert_fully_folded(&h, &expected); } @@ -1027,13 +928,10 @@ fn test_fold_idiv_checked_u() { let x2 = build .add_dataflow_op(IntOpDef::idiv_checked_u.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = ConstError { signal: 0, message: "Division by zero".to_string(), @@ -1054,13 +952,10 @@ fn test_fold_idiv_u() { let x2 = build .add_dataflow_op(IntOpDef::idiv_u.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstInt::new_u(5, 6).unwrap()); assert_fully_folded(&h, &expected); } @@ -1078,13 +973,10 @@ fn test_fold_imod_checked_u() { let x2 = build .add_dataflow_op(IntOpDef::imod_checked_u.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = ConstError { signal: 0, message: "Division by zero".to_string(), @@ -1105,13 +997,10 @@ fn test_fold_imod_u() { let x2 = build .add_dataflow_op(IntOpDef::imod_u.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstInt::new_u(5, 2).unwrap()); assert_fully_folded(&h, &expected); } @@ -1129,13 +1018,10 @@ fn test_fold_idiv_checked_s() { let x2 = build .add_dataflow_op(IntOpDef::idiv_checked_s.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = ConstError { signal: 0, message: "Division by zero".to_string(), @@ -1156,13 +1042,10 @@ fn test_fold_idiv_s() { let x2 = build .add_dataflow_op(IntOpDef::idiv_s.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstInt::new_s(5, -7).unwrap()); assert_fully_folded(&h, &expected); } @@ -1180,13 +1063,10 @@ fn test_fold_imod_checked_s() { let x2 = build .add_dataflow_op(IntOpDef::imod_checked_s.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = ConstError { signal: 0, message: "Division by zero".to_string(), @@ -1207,13 +1087,10 @@ fn test_fold_imod_s() { let x2 = build .add_dataflow_op(IntOpDef::imod_s.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstInt::new_u(5, 1).unwrap()); assert_fully_folded(&h, &expected); } @@ -1229,13 +1106,10 @@ fn test_fold_iabs() { let x2 = build .add_dataflow_op(IntOpDef::iabs.with_log_width(5), [x0]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstInt::new_u(5, 2).unwrap()); assert_fully_folded(&h, &expected); } @@ -1252,13 +1126,10 @@ fn test_fold_iand() { let x2 = build .add_dataflow_op(IntOpDef::iand.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstInt::new_u(5, 4).unwrap()); assert_fully_folded(&h, &expected); } @@ -1275,13 +1146,10 @@ fn test_fold_ior() { let x2 = build .add_dataflow_op(IntOpDef::ior.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstInt::new_u(5, 30).unwrap()); assert_fully_folded(&h, &expected); } @@ -1298,13 +1166,10 @@ fn test_fold_ixor() { let x2 = build .add_dataflow_op(IntOpDef::ixor.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstInt::new_u(5, 26).unwrap()); assert_fully_folded(&h, &expected); } @@ -1320,13 +1185,10 @@ fn test_fold_inot() { let x2 = build .add_dataflow_op(IntOpDef::inot.with_log_width(5), [x0]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstInt::new_u(5, (1u64 << 32) - 15).unwrap()); assert_fully_folded(&h, &expected); } @@ -1343,13 +1205,10 @@ fn test_fold_ishl() { let x2 = build .add_dataflow_op(IntOpDef::ishl.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstInt::new_u(5, 112).unwrap()); assert_fully_folded(&h, &expected); } @@ -1366,13 +1225,10 @@ fn test_fold_ishr() { let x2 = build .add_dataflow_op(IntOpDef::ishr.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstInt::new_u(5, 1).unwrap()); assert_fully_folded(&h, &expected); } @@ -1389,13 +1245,10 @@ fn test_fold_irotl() { let x2 = build .add_dataflow_op(IntOpDef::irotl.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstInt::new_u(5, 3 * (1u64 << 30) + 1).unwrap()); assert_fully_folded(&h, &expected); } @@ -1412,13 +1265,10 @@ fn test_fold_irotr() { let x2 = build .add_dataflow_op(IntOpDef::irotr.with_log_width(5), [x0, x1]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstInt::new_u(5, 3 * (1u64 << 30) + 1).unwrap()); assert_fully_folded(&h, &expected); } @@ -1429,18 +1279,15 @@ fn test_fold_itostring_u() { // x0 := int_u<5>(17); // x1 := itostring_u(x0); // output x2 := "17"; - let mut build = DFGBuilder::new(noargfn(vec![STRING_TYPE])).unwrap(); + let mut build = DFGBuilder::new(noargfn(vec![string_type()])).unwrap(); let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 17).unwrap())); let x1 = build .add_dataflow_op(ConvertOpDef::itostring_u.with_log_width(5), [x0]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x1.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x1.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstString::new("17".into())); assert_fully_folded(&h, &expected); } @@ -1451,18 +1298,15 @@ fn test_fold_itostring_s() { // x0 := int_s<5>(-17); // x1 := itostring_s(x0); // output x2 := "-17"; - let mut build = DFGBuilder::new(noargfn(vec![STRING_TYPE])).unwrap(); + let mut build = DFGBuilder::new(noargfn(vec![string_type()])).unwrap(); let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -17).unwrap())); let x1 = build .add_dataflow_op(ConvertOpDef::itostring_s.with_log_width(5), [x0]) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x1.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x1.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::extension(ConstString::new("-17".into())); assert_fully_folded(&h, &expected); } @@ -1480,7 +1324,7 @@ fn test_fold_int_ops() { // x6 := ilt_s(x0, x5) // false // x7 := or(x4, x6) // true // output x7 - let mut build = DFGBuilder::new(noargfn(vec![BOOL_T])).unwrap(); + let mut build = DFGBuilder::new(noargfn(vec![bool_t()])).unwrap(); let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 3).unwrap())); let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 4).unwrap())); let x2 = build @@ -1499,14 +1343,10 @@ fn test_fold_int_ops() { let x7 = build .add_dataflow_op(LogicOp::Or, x4.outputs().chain(x6.outputs())) .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - logic::EXTENSION.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x7.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); + let mut h = build + .finish_hugr_with_outputs(x7.outputs(), &TEST_REG) + .unwrap(); + constant_fold_pass(&mut h, &TEST_REG); let expected = Value::true_val(); assert_fully_folded(&h, &expected); } diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs new file mode 100644 index 000000000..bb3023c38 --- /dev/null +++ b/hugr-passes/src/dataflow.rs @@ -0,0 +1,124 @@ +#![warn(missing_docs)] +//! Dataflow analysis of Hugrs. + +mod datalog; +pub use datalog::Machine; +mod value_row; + +mod results; +pub use results::{AnalysisResults, TailLoopTermination}; + +mod partial_value; +pub use partial_value::{AbstractValue, PartialSum, PartialValue, Sum}; + +use hugr_core::ops::constant::OpaqueValue; +use hugr_core::ops::{ExtensionOp, Value}; +use hugr_core::types::TypeArg; +use hugr_core::{Hugr, Node}; + +/// Clients of the dataflow framework (particular analyses, such as constant folding) +/// must implement this trait (including providing an appropriate domain type `V`). +pub trait DFContext: ConstLoader { + /// Given lattice values for each input, update lattice values for the (dataflow) outputs. + /// For extension ops only, excluding [MakeTuple] and [UnpackTuple] which are handled automatically. + /// `_outs` is an array with one element per dataflow output, each initialized to [PartialValue::Top] + /// which is the correct value to leave if nothing can be deduced about that output. + /// (The default does nothing, i.e. leaves `Top` for all outputs.) + /// + /// [MakeTuple]: hugr_core::extension::prelude::MakeTuple + /// [UnpackTuple]: hugr_core::extension::prelude::UnpackTuple + fn interpret_leaf_op( + &mut self, + _node: Node, + _e: &ExtensionOp, + _ins: &[PartialValue], + _outs: &mut [PartialValue], + ) { + } +} + +/// A location where a [Value] could be find in a Hugr. That is, +/// (perhaps deeply nested within [Value::Sum]s) within a [Node] +/// that is a [Const](hugr_core::ops::Const). +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub enum ConstLocation<'a> { + /// The specified-index'th field of the [Value::Sum] constant identified by the RHS + Field(usize, &'a ConstLocation<'a>), + /// The entire ([Const::value](hugr_core::ops::Const::value)) of the node. + Node(Node), +} + +impl From for ConstLocation<'_> { + fn from(value: Node) -> Self { + ConstLocation::Node(value) + } +} + +/// Trait for loading [PartialValue]s from constant [Value]s in a Hugr. +/// Implementors will likely want to override some/all of [Self::value_from_opaque], +/// [Self::value_from_const_hugr], and [Self::value_from_function]: the defaults +/// are "correct" but maximally conservative (minimally informative). +pub trait ConstLoader { + /// Produces an abstract value from an [OpaqueValue], if possible. + /// The default just returns `None`, which will be interpreted as [PartialValue::Top]. + fn value_from_opaque(&self, _loc: ConstLocation, _val: &OpaqueValue) -> Option { + None + } + + /// Produces an abstract value from a Hugr in a [Value::Function], if possible. + /// The default just returns `None`, which will be interpreted as [PartialValue::Top]. + fn value_from_const_hugr(&self, _loc: ConstLocation, _h: &Hugr) -> Option { + None + } + + /// Produces an abstract value from a [FuncDefn] or [FuncDecl] node + /// (that has been loaded via a [LoadFunction]), if possible. + /// The default just returns `None`, which will be interpreted as [PartialValue::Top]. + /// + /// [FuncDefn]: hugr_core::ops::FuncDefn + /// [FuncDecl]: hugr_core::ops::FuncDecl + /// [LoadFunction]: hugr_core::ops::LoadFunction + fn value_from_function(&self, _node: Node, _type_args: &[TypeArg]) -> Option { + None + } +} + +/// Produces a [PartialValue] from a constant. Traverses [Sum](Value::Sum) constants +/// to their leaves ([Value::Extension] and [Value::Function]), +/// converts these using [ConstLoader::value_from_opaque] and [ConstLoader::value_from_const_hugr], +/// and builds nested [PartialValue::new_variant] to represent the structure. +fn partial_from_const<'a, V>( + cl: &impl ConstLoader, + loc: impl Into>, + cst: &Value, +) -> PartialValue { + let loc = loc.into(); + match cst { + Value::Sum(hugr_core::ops::constant::Sum { tag, values, .. }) => { + let elems = values + .iter() + .enumerate() + .map(|(idx, elem)| partial_from_const(cl, ConstLocation::Field(idx, &loc), elem)); + PartialValue::new_variant(*tag, elems) + } + Value::Extension { e } => cl + .value_from_opaque(loc, e) + .map(PartialValue::from) + .unwrap_or(PartialValue::Top), + Value::Function { hugr } => cl + .value_from_const_hugr(loc, hugr) + .map(PartialValue::from) + .unwrap_or(PartialValue::Top), + } +} + +/// A row of inputs to a node contains bottom (can't happen, the node +/// can't execute) if any element [contains_bottom](PartialValue::contains_bottom). +pub fn row_contains_bottom<'a, V: AbstractValue + 'a>( + elements: impl IntoIterator>, +) -> bool { + elements.into_iter().any(PartialValue::contains_bottom) +} + +#[cfg(test)] +mod test; diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs new file mode 100644 index 000000000..172d87c26 --- /dev/null +++ b/hugr-passes/src/dataflow/datalog.rs @@ -0,0 +1,397 @@ +//! [ascent] datalog implementation of analysis. + +use std::collections::hash_map::RandomState; +use std::collections::HashSet; // Moves to std::hash in Rust 1.76 + +use ascent::lattice::BoundedLattice; +use itertools::Itertools; + +use hugr_core::extension::prelude::{MakeTuple, UnpackTuple}; +use hugr_core::ops::{OpTrait, OpType, TailLoop}; +use hugr_core::{HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; + +use super::value_row::ValueRow; +use super::{ + partial_from_const, row_contains_bottom, AbstractValue, AnalysisResults, DFContext, + PartialValue, +}; + +type PV = PartialValue; + +/// Basic structure for performing an analysis. Usage: +/// 1. Make a new instance via [Self::new()] +/// 2. (Optionally) zero or more calls to [Self::prepopulate_wire] and/or +/// [Self::prepopulate_df_inputs] with initial values. +/// For example, to analyse a [Module](OpType::Module)-rooted Hugr as a library, +/// [Self::prepopulate_df_inputs] can be used on each externally-callable +/// [FuncDefn](OpType::FuncDefn) to set all inputs to [PartialValue::Top]. +/// 3. Call [Self::run] to produce [AnalysisResults] +pub struct Machine(H, Vec<(Node, IncomingPort, PartialValue)>); + +impl Machine { + /// Create a new Machine to analyse the given Hugr(View) + pub fn new(hugr: H) -> Self { + Self(hugr, Default::default()) + } +} + +impl Machine { + /// Provide initial values for a wire - these will be `join`d with any computed. + pub fn prepopulate_wire(&mut self, w: Wire, v: PartialValue) { + self.1.extend( + self.0 + .linked_inputs(w.node(), w.source()) + .map(|(n, inp)| (n, inp, v.clone())), + ); + } + + /// Provide initial values for the inputs to a [DataflowParent](hugr_core::ops::OpTag::DataflowParent) + /// (that is, values on the wires leaving the [Input](OpType::Input) child thereof). + /// Any out-ports of said same `Input` node, not given values by `in_values`, are set to [PartialValue::Top]. + pub fn prepopulate_df_inputs( + &mut self, + parent: Node, + in_values: impl IntoIterator)>, + ) { + // Put values onto out-wires of Input node + let [inp, _] = self.0.get_io(parent).unwrap(); + let mut vals = vec![PartialValue::Top; self.0.signature(inp).unwrap().output_types().len()]; + for (ip, v) in in_values { + vals[ip.index()] = v; + } + for (i, v) in vals.into_iter().enumerate() { + self.prepopulate_wire(Wire::new(inp, i), v); + } + } + + /// Run the analysis (iterate until a lattice fixpoint is reached), + /// given initial values for some of the root node inputs. For a + /// [Module](OpType::Module)-rooted Hugr, these are input to the function `"main"`. + /// The context passed in allows interpretation of leaf operations. + /// + /// # Panics + /// May panic in various ways if the Hugr is invalid; + /// or if any `in_values` are provided for a module-rooted Hugr without a function `"main"`. + pub fn run( + mut self, + context: impl DFContext, + in_values: impl IntoIterator)>, + ) -> AnalysisResults { + let mut in_values = in_values.into_iter(); + let root = self.0.root(); + // Some nodes do not accept values as dataflow inputs - for these + // we must find the corresponding Input node. + let input_node_parent = match self.0.get_optype(root) { + OpType::Module(_) => { + let main = self.0.children(root).find(|n| { + self.0 + .get_optype(*n) + .as_func_defn() + .is_some_and(|f| f.name == "main") + }); + if main.is_none() && in_values.next().is_some() { + panic!("Cannot give inputs to module with no 'main'"); + } + main + } + OpType::DataflowBlock(_) | OpType::Case(_) | OpType::FuncDefn(_) => Some(root), + // Could also do Dfg above, but ok here too: + _ => None, // Just feed into node inputs + }; + // Any inputs we don't have values for, we must assume `Top` to ensure safety of analysis + // (Consider: for a conditional that selects *either* the unknown input *or* value V, + // analysis must produce Top == we-know-nothing, not `V` !) + if let Some(p) = input_node_parent { + self.prepopulate_df_inputs( + p, + in_values.map(|(p, v)| (OutgoingPort::from(p.index()), v)), + ); + } else { + // Put values onto in-wires of root node, datalog will do the rest + self.1.extend(in_values.map(|(p, v)| (root, p, v))); + let got_inputs: HashSet<_, RandomState> = self + .1 + .iter() + .filter_map(|(n, p, _)| (n == &root).then_some(*p)) + .collect(); + for p in self.0.signature(root).unwrap_or_default().input_ports() { + if !got_inputs.contains(&p) { + self.1.push((root, p, PartialValue::Top)); + } + } + } + // Note/TODO, if analysis is running on a subregion then we should do similar + // for any nonlocal edges providing values from outside the region. + run_datalog(context, self.0, self.1) + } +} + +pub(super) fn run_datalog( + mut ctx: impl DFContext, + hugr: H, + in_wire_value_proto: Vec<(Node, IncomingPort, PV)>, +) -> AnalysisResults { + // ascent-(macro-)generated code generates a bunch of warnings, + // keep code in here to a minimum. + #![allow( + clippy::clone_on_copy, + clippy::unused_enumerate_index, + clippy::collapsible_if + )] + let all_results = ascent::ascent_run! { + pub(super) struct AscentProgram; + relation node(Node); // exists in the hugr + relation in_wire(Node, IncomingPort); // has an of `EdgeKind::Value` + relation out_wire(Node, OutgoingPort); // has an of `EdgeKind::Value` + relation parent_of_node(Node, Node); // is parent of + relation input_child(Node, Node); // has 1st child that is its `Input` + relation output_child(Node, Node); // has 2nd child that is its `Output` + lattice out_wire_value(Node, OutgoingPort, PV); // produces, on , the value + lattice in_wire_value(Node, IncomingPort, PV); // receives, on , the value + lattice node_in_value_row(Node, ValueRow); // 's inputs are + + node(n) <-- for n in hugr.nodes(); + + in_wire(n, p) <-- node(n), for (p,_) in hugr.in_value_types(*n); // Note, gets connected inports only + out_wire(n, p) <-- node(n), for (p,_) in hugr.out_value_types(*n); // (and likewise) + + parent_of_node(parent, child) <-- + node(child), if let Some(parent) = hugr.get_parent(*child); + + input_child(parent, input) <-- node(parent), if let Some([input, _output]) = hugr.get_io(*parent); + output_child(parent, output) <-- node(parent), if let Some([_input, output]) = hugr.get_io(*parent); + + // Initialize all wires to bottom + out_wire_value(n, p, PV::bottom()) <-- out_wire(n, p); + + // Outputs to inputs + in_wire_value(n, ip, v) <-- in_wire(n, ip), + if let Some((m, op)) = hugr.single_linked_output(*n, *ip), + out_wire_value(m, op, v); + + // Prepopulate in_wire_value from in_wire_value_proto. + in_wire_value(n, p, PV::bottom()) <-- in_wire(n, p); + in_wire_value(n, p, v) <-- for (n, p, v) in in_wire_value_proto.iter(), + node(n), + if let Some(sig) = hugr.signature(*n), + if sig.input_ports().contains(p); + + // Assemble node_in_value_row from in_wire_value's + node_in_value_row(n, ValueRow::new(sig.input_count())) <-- node(n), if let Some(sig) = hugr.signature(*n); + node_in_value_row(n, ValueRow::new(hugr.signature(*n).unwrap().input_count()).set(p.index(), v.clone())) <-- in_wire_value(n, p, v); + + // Interpret leaf ops + out_wire_value(n, p, v) <-- + node(n), + let op_t = hugr.get_optype(*n), + if !op_t.is_container(), + if let Some(sig) = op_t.dataflow_signature(), + node_in_value_row(n, vs), + if let Some(outs) = propagate_leaf_op(&mut ctx, &hugr, *n, &vs[..], sig.output_count()), + for (p, v) in (0..).map(OutgoingPort::from).zip(outs); + + // DFG -------------------- + relation dfg_node(Node); // is a `DFG` + dfg_node(n) <-- node(n), if hugr.get_optype(*n).is_dfg(); + + out_wire_value(i, OutgoingPort::from(p.index()), v) <-- dfg_node(dfg), + input_child(dfg, i), in_wire_value(dfg, p, v); + + out_wire_value(dfg, OutgoingPort::from(p.index()), v) <-- dfg_node(dfg), + output_child(dfg, o), in_wire_value(o, p, v); + + // TailLoop -------------------- + // inputs of tail loop propagate to Input node of child region + out_wire_value(i, OutgoingPort::from(p.index()), v) <-- node(tl), + if hugr.get_optype(*tl).is_tail_loop(), + input_child(tl, i), + in_wire_value(tl, p, v); + + // Output node of child region propagate to Input node of child region + out_wire_value(in_n, OutgoingPort::from(out_p), v) <-- node(tl), + if let Some(tailloop) = hugr.get_optype(*tl).as_tail_loop(), + input_child(tl, in_n), + output_child(tl, out_n), + node_in_value_row(out_n, out_in_row), // get the whole input row for the output node... + // ...and select just what's possible for CONTINUE_TAG, if anything + if let Some(fields) = out_in_row.unpack_first(TailLoop::CONTINUE_TAG, tailloop.just_inputs.len()), + for (out_p, v) in fields.enumerate(); + + // Output node of child region propagate to outputs of tail loop + out_wire_value(tl, OutgoingPort::from(out_p), v) <-- node(tl), + if let Some(tailloop) = hugr.get_optype(*tl).as_tail_loop(), + output_child(tl, out_n), + node_in_value_row(out_n, out_in_row), // get the whole input row for the output node... + // ... and select just what's possible for BREAK_TAG, if anything + if let Some(fields) = out_in_row.unpack_first(TailLoop::BREAK_TAG, tailloop.just_outputs.len()), + for (out_p, v) in fields.enumerate(); + + // Conditional -------------------- + // is a `Conditional` and its 'th child (a `Case`) is : + relation case_node(Node, usize, Node); + case_node(cond, i, case) <-- node(cond), + if hugr.get_optype(*cond).is_conditional(), + for (i, case) in hugr.children(*cond).enumerate(), + if hugr.get_optype(case).is_case(); + + // inputs of conditional propagate into case nodes + out_wire_value(i_node, OutgoingPort::from(out_p), v) <-- + case_node(cond, case_index, case), + input_child(case, i_node), + node_in_value_row(cond, in_row), + let conditional = hugr.get_optype(*cond).as_conditional().unwrap(), + if let Some(fields) = in_row.unpack_first(*case_index, conditional.sum_rows[*case_index].len()), + for (out_p, v) in fields.enumerate(); + + // outputs of case nodes propagate to outputs of conditional *if* case reachable + out_wire_value(cond, OutgoingPort::from(o_p.index()), v) <-- + case_node(cond, _i, case), + case_reachable(cond, case), + output_child(case, o), + in_wire_value(o, o_p, v); + + // In `Conditional` , child `Case` is reachable given our knowledge of predicate: + relation case_reachable(Node, Node); + case_reachable(cond, case) <-- case_node(cond, i, case), + in_wire_value(cond, IncomingPort::from(0), v), + if v.supports_tag(*i); + + // CFG -------------------- + relation cfg_node(Node); // is a `CFG` + cfg_node(n) <-- node(n), if hugr.get_optype(*n).is_cfg(); + + // In `CFG` , basic block is reachable given our knowledge of predicates: + relation bb_reachable(Node, Node); + bb_reachable(cfg, entry) <-- cfg_node(cfg), if let Some(entry) = hugr.children(*cfg).next(); + bb_reachable(cfg, bb) <-- cfg_node(cfg), + bb_reachable(cfg, pred), + output_child(pred, pred_out), + in_wire_value(pred_out, IncomingPort::from(0), predicate), + for (tag, bb) in hugr.output_neighbours(*pred).enumerate(), + if predicate.supports_tag(tag); + + // Inputs of CFG propagate to entry block + out_wire_value(i_node, OutgoingPort::from(p.index()), v) <-- + cfg_node(cfg), + if let Some(entry) = hugr.children(*cfg).next(), + input_child(entry, i_node), + in_wire_value(cfg, p, v); + + // In `CFG` , values fed along a control-flow edge to + // come out of Value outports of : + relation _cfg_succ_dest(Node, Node, Node); + _cfg_succ_dest(cfg, exit, cfg) <-- cfg_node(cfg), if let Some(exit) = hugr.children(*cfg).nth(1); + _cfg_succ_dest(cfg, blk, inp) <-- cfg_node(cfg), + for blk in hugr.children(*cfg), + if hugr.get_optype(blk).is_dataflow_block(), + input_child(blk, inp); + + // Outputs of each reachable block propagated to successor block or CFG itself + out_wire_value(dest, OutgoingPort::from(out_p), v) <-- + bb_reachable(cfg, pred), + if let Some(df_block) = hugr.get_optype(*pred).as_dataflow_block(), + for (succ_n, succ) in hugr.output_neighbours(*pred).enumerate(), + output_child(pred, out_n), + _cfg_succ_dest(cfg, succ, dest), + node_in_value_row(out_n, out_in_row), + if let Some(fields) = out_in_row.unpack_first(succ_n, df_block.sum_rows.get(succ_n).unwrap().len()), + for (out_p, v) in fields.enumerate(); + + // Call -------------------- + relation func_call(Node, Node); // is a `Call` to `FuncDefn` + func_call(call, func_defn) <-- + node(call), + if hugr.get_optype(*call).is_call(), + if let Some(func_defn) = hugr.static_source(*call); + + out_wire_value(inp, OutgoingPort::from(p.index()), v) <-- + func_call(call, func), + input_child(func, inp), + in_wire_value(call, p, v); + + out_wire_value(call, OutgoingPort::from(p.index()), v) <-- + func_call(call, func), + output_child(func, outp), + in_wire_value(outp, p, v); + }; + let out_wire_values = all_results + .out_wire_value + .iter() + .map(|(n, p, v)| (Wire::new(*n, *p), v.clone())) + .collect(); + AnalysisResults { + hugr, + out_wire_values, + in_wire_value: all_results.in_wire_value, + case_reachable: all_results.case_reachable, + bb_reachable: all_results.bb_reachable, + } +} + +fn propagate_leaf_op( + ctx: &mut impl DFContext, + hugr: &impl HugrView, + n: Node, + ins: &[PV], + num_outs: usize, +) -> Option> { + match hugr.get_optype(n) { + // Handle basics here. We could instead leave these to DFContext, + // but at least we'd want these impls to be easily reusable. + op if op.cast::().is_some() => Some(ValueRow::from_iter([PV::new_variant( + 0, + ins.iter().cloned(), + )])), + op if op.cast::().is_some() => { + let elem_tys = op.cast::().unwrap().0; + let tup = ins.iter().exactly_one().unwrap(); + tup.variant_values(0, elem_tys.len()) + .map(ValueRow::from_iter) + } + OpType::Tag(t) => Some(ValueRow::from_iter([PV::new_variant( + t.tag, + ins.iter().cloned(), + )])), + OpType::Input(_) | OpType::Output(_) | OpType::ExitBlock(_) => None, // handled by parent + OpType::Call(_) => None, // handled via Input/Output of FuncDefn + OpType::Const(_) => None, // handled by LoadConstant: + OpType::LoadConstant(load_op) => { + assert!(ins.is_empty()); // static edge, so need to find constant + let const_node = hugr + .single_linked_output(n, load_op.constant_port()) + .unwrap() + .0; + let const_val = hugr.get_optype(const_node).as_const().unwrap().value(); + Some(ValueRow::singleton(partial_from_const(ctx, n, const_val))) + } + OpType::LoadFunction(load_op) => { + assert!(ins.is_empty()); // static edge + let func_node = hugr + .single_linked_output(n, load_op.function_port()) + .unwrap() + .0; + // Node could be a FuncDefn or a FuncDecl, so do not pass the node itself + Some(ValueRow::singleton( + ctx.value_from_function(func_node, &load_op.type_args) + .map_or(PV::Top, PV::Value), + )) + } + OpType::ExtensionOp(e) => { + Some(ValueRow::from_iter(if row_contains_bottom(ins) { + // So far we think one or more inputs can't happen. + // So, don't pollute outputs with Top, and wait for better knowledge of inputs. + vec![PartialValue::Bottom; num_outs] + } else { + // Interpret op using DFContext + // Default to Top i.e. can't figure out anything about the outputs + let mut outs = vec![PartialValue::Top; num_outs]; + // It might be nice to convert `ins` to [(IncomingPort, Value)], or some + // other concrete value, for the context, but PV contains more information, + // and try_into_concrete may fail. + ctx.interpret_leaf_op(n, e, ins, &mut outs[..]); + outs + })) + } + o => todo!("Unhandled: {:?}", o), // At least CallIndirect, and OpType is "non-exhaustive" + } +} diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs new file mode 100644 index 000000000..f2a497806 --- /dev/null +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -0,0 +1,706 @@ +use ascent::lattice::BoundedLattice; +use ascent::Lattice; +use hugr_core::ops::Value; +use hugr_core::types::{ConstTypeError, SumType, Type, TypeEnum, TypeRow}; +use itertools::{zip_eq, Itertools}; +use std::cmp::Ordering; +use std::collections::HashMap; +use std::hash::{Hash, Hasher}; +use thiserror::Error; + +use super::row_contains_bottom; + +/// Trait for an underlying domain of abstract values which can form the *elements* of a +/// [PartialValue] and thus be used in dataflow analysis. +pub trait AbstractValue: Clone + std::fmt::Debug + PartialEq + Eq + Hash { + /// Computes the join of two values (i.e. towards `Top``), if this is representable + /// within the underlying domain. Return the new value, and whether this is different from + /// the old `self`. + /// + /// If the join is not representable, return `None` - i.e., we should use [PartialValue::Top]. + /// + /// The default checks equality between `self` and `other` and returns `(self,false)` if + /// the two are identical, otherwise `None`. + fn try_join(self, other: Self) -> Option<(Self, bool)> { + (self == other).then_some((self, false)) + } + + /// Computes the meet of two values (i.e. towards `Bottom`), if this is representable + /// within the underlying domain. Return the new value, and whether this is different from + /// the old `self`. + /// If the meet is not representable, return `None` - i.e., we should use [PartialValue::Bottom]. + /// + /// The default checks equality between `self` and `other` and returns `(self, false)` if + /// the two are identical, otherwise `None`. + fn try_meet(self, other: Self) -> Option<(Self, bool)> { + (self == other).then_some((self, false)) + } +} + +/// Represents a sum with a single/known tag, abstracted over the representation of the elements. +/// (Identical to [Sum](hugr_core::ops::constant::Sum) except for the type abstraction.) +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Sum { + /// The tag index of the variant. + pub tag: usize, + /// The value of the variant. + /// + /// Sum variants are always a row of values, hence the Vec. + pub values: Vec, + /// The full type of the Sum, including the other variants. + pub st: SumType, +} + +/// A representation of a value of [SumType], that may have one or more possible tags, +/// with a [PartialValue] representation of each element-value of each possible tag. +#[derive(PartialEq, Clone, Eq)] +pub struct PartialSum(pub HashMap>>); + +impl PartialSum { + /// New instance for a single known tag. + /// (Multi-tag instances can be created via [Self::try_join_mut].) + pub fn new_variant(tag: usize, values: impl IntoIterator>) -> Self { + Self(HashMap::from([(tag, Vec::from_iter(values))])) + } + + /// The number of possible variants we know about. (NOT the number + /// of tags possible for the value's type, whatever [SumType] that might be.) + pub fn num_variants(&self) -> usize { + self.0.len() + } + + fn assert_invariants(&self) { + assert_ne!(self.num_variants(), 0); + for pv in self.0.values().flat_map(|x| x.iter()) { + pv.assert_invariants(); + } + } +} + +impl PartialSum { + /// Joins (towards `Top`) self with another [PartialSum]. If successful, returns + /// whether `self` has changed. + /// + /// Fails (without mutation) with the conflicting tag if any common rows have different lengths. + pub fn try_join_mut(&mut self, other: Self) -> Result { + for (k, v) in &other.0 { + if self.0.get(k).is_some_and(|row| row.len() != v.len()) { + return Err(*k); + } + } + let mut changed = false; + + for (k, v) in other.0 { + if let Some(row) = self.0.get_mut(&k) { + for (lhs, rhs) in zip_eq(row.iter_mut(), v.into_iter()) { + changed |= lhs.join_mut(rhs); + } + } else { + self.0.insert(k, v); + changed = true; + } + } + Ok(changed) + } + + /// Mutates self according to lattice meet operation (towards `Bottom`). If successful, + /// returns whether `self` has changed. + /// + /// # Errors + /// Fails without mutation, either: + /// * `Some(tag)` if the two [PartialSum]s both had rows with that `tag` but of different lengths + /// * `None` if the two instances had no rows in common (i.e., the result is "Bottom") + pub fn try_meet_mut(&mut self, other: Self) -> Result> { + let mut changed = false; + let mut keys_to_remove = vec![]; + for (k, v) in self.0.iter() { + match other.0.get(k) { + None => keys_to_remove.push(*k), + Some(o_v) => { + if v.len() != o_v.len() { + return Err(Some(*k)); + } + } + } + } + if keys_to_remove.len() == self.0.len() { + return Err(None); + } + for (k, v) in other.0 { + if let Some(row) = self.0.get_mut(&k) { + for (lhs, rhs) in zip_eq(row.iter_mut(), v.into_iter()) { + changed |= lhs.meet_mut(rhs); + } + } else { + keys_to_remove.push(k); + } + } + for k in keys_to_remove { + self.0.remove(&k); + changed = true; + } + Ok(changed) + } + + /// Whether this sum might have the specified tag + pub fn supports_tag(&self, tag: usize) -> bool { + self.0.contains_key(&tag) + } + + /// Turns this instance into a [Sum] of some "concrete" value type `C`, + /// *if* this PartialSum has exactly one possible tag. + /// + /// # Errors + /// + /// If this PartialSum had multiple possible tags; or if `typ` was not a [TypeEnum::Sum] + /// supporting the single possible tag with the correct number of elements and no row variables; + /// or if converting a child element failed via [PartialValue::try_into_concrete]. + pub fn try_into_sum(self, typ: &Type) -> Result, ExtractValueError> + where + V: TryInto, + Sum: TryInto, + { + if self.0.len() != 1 { + return Err(ExtractValueError::MultipleVariants(self)); + } + let (tag, v) = self.0.into_iter().exactly_one().unwrap(); + if let TypeEnum::Sum(st) = typ.as_type_enum() { + if let Some(r) = st.get_variant(tag) { + if let Ok(r) = TypeRow::try_from(r.clone()) { + if v.len() == r.len() { + return Ok(Sum { + tag, + values: zip_eq(v, r.iter()) + .map(|(v, t)| v.try_into_concrete(t)) + .collect::, _>>()?, + st: st.clone(), + }); + } + } + } + } + Err(ExtractValueError::BadSumType { + typ: typ.clone(), + tag, + num_elements: v.len(), + }) + } + + /// Can this ever occur at runtime? See [PartialValue::contains_bottom] + pub fn contains_bottom(&self) -> bool { + self.0 + .iter() + .all(|(_tag, elements)| row_contains_bottom(elements)) + } +} + +/// An error converting a [PartialValue] or [PartialSum] into a concrete value type +/// via [PartialValue::try_into_concrete] or [PartialSum::try_into_sum] +#[derive(Clone, Debug, PartialEq, Eq, Error)] +#[allow(missing_docs)] +pub enum ExtractValueError { + #[error("PartialSum value had multiple possible tags: {0}")] + MultipleVariants(PartialSum), + #[error("Value contained `Bottom`")] + ValueIsBottom, + #[error("Value contained `Top`")] + ValueIsTop, + #[error("Could not convert element from abstract value into concrete: {0}")] + CouldNotConvert(V, #[source] VE), + #[error("Could not build Sum from concrete element values")] + CouldNotBuildSum(#[source] SE), + #[error("Expected a SumType with tag {tag} having {num_elements} elements, found {typ}")] + BadSumType { + typ: Type, + tag: usize, + num_elements: usize, + }, +} + +impl PartialSum { + /// If this Sum might have the specified `tag`, get the elements inside that tag. + pub fn variant_values(&self, variant: usize) -> Option>> { + self.0.get(&variant).cloned() + } +} + +impl PartialOrd for PartialSum { + fn partial_cmp(&self, other: &Self) -> Option { + let max_key = self.0.keys().chain(other.0.keys()).copied().max().unwrap(); + let (mut keys1, mut keys2) = (vec![0; max_key + 1], vec![0; max_key + 1]); + for k in self.0.keys() { + keys1[*k] = 1; + } + + for k in other.0.keys() { + keys2[*k] = 1; + } + + Some(match keys1.cmp(&keys2) { + ord @ Ordering::Greater | ord @ Ordering::Less => ord, + Ordering::Equal => { + for (k, lhs) in &self.0 { + let Some(rhs) = other.0.get(k) else { + unreachable!() + }; + let key_cmp = lhs.partial_cmp(rhs); + if key_cmp != Some(Ordering::Equal) { + return key_cmp; + } + } + Ordering::Equal + } + }) + } +} + +impl std::fmt::Debug for PartialSum { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} + +impl Hash for PartialSum { + fn hash(&self, state: &mut H) { + for (k, v) in &self.0 { + k.hash(state); + v.hash(state); + } + } +} + +/// Wraps some underlying representation (knowledge) of values into a lattice +/// for use in dataflow analysis, including that an instance may be a [PartialSum] +/// of values of the underlying representation +#[derive(PartialEq, Clone, Eq, Hash, Debug)] +pub enum PartialValue { + /// No possibilities known (so far) + Bottom, + /// A single value (of the underlying representation) + Value(V), + /// Sum (with at least one, perhaps several, possible tags) of underlying values + PartialSum(PartialSum), + /// Might be more than one distinct value of the underlying type `V` + Top, +} + +impl From for PartialValue { + fn from(v: V) -> Self { + Self::Value(v) + } +} + +impl From> for PartialValue { + fn from(v: PartialSum) -> Self { + Self::PartialSum(v) + } +} + +impl PartialValue { + fn assert_invariants(&self) { + if let Self::PartialSum(ps) = self { + ps.assert_invariants(); + } + } + + /// New instance of a sum with a single known tag. + pub fn new_variant(tag: usize, values: impl IntoIterator) -> Self { + PartialSum::new_variant(tag, values).into() + } + + /// New instance of unit type (i.e. the only possible value, with no contents) + pub fn new_unit() -> Self { + Self::new_variant(0, []) + } +} + +impl PartialValue { + /// If this value might be a Sum with the specified `tag`, get the elements inside that tag. + /// + /// # Panics + /// + /// if the value is believed, for that tag, to have a number of values other than `len` + pub fn variant_values(&self, tag: usize, len: usize) -> Option>> { + let vals = match self { + PartialValue::Bottom | PartialValue::Value(_) => return None, + PartialValue::PartialSum(ps) => ps.variant_values(tag)?, + PartialValue::Top => vec![PartialValue::Top; len], + }; + assert_eq!(vals.len(), len); + Some(vals) + } + + /// Tells us whether this value might be a Sum with the specified `tag` + pub fn supports_tag(&self, tag: usize) -> bool { + match self { + PartialValue::Bottom | PartialValue::Value(_) => false, + PartialValue::PartialSum(ps) => ps.supports_tag(tag), + PartialValue::Top => true, + } + } + + /// Turns this instance into some "concrete" value type `C`, *if* it is a single value, + /// or a [Sum](PartialValue::PartialSum) (of a single tag) convertible by + /// [PartialSum::try_into_sum]. + /// + /// # Errors + /// + /// If this PartialValue was `Top` or `Bottom`, or was a [PartialSum](PartialValue::PartialSum) + /// that could not be converted into a [Sum] by [PartialSum::try_into_sum] (e.g. if `typ` is + /// incorrect), or if that [Sum] could not be converted into a `V2`. + pub fn try_into_concrete(self, typ: &Type) -> Result> + where + V: TryInto, + Sum: TryInto, + { + match self { + Self::Value(v) => v + .clone() + .try_into() + .map_err(|e| ExtractValueError::CouldNotConvert(v.clone(), e)), + Self::PartialSum(ps) => ps + .try_into_sum(typ)? + .try_into() + .map_err(ExtractValueError::CouldNotBuildSum), + Self::Top => Err(ExtractValueError::ValueIsTop), + Self::Bottom => Err(ExtractValueError::ValueIsBottom), + } + } + + /// A value contains bottom means that it cannot occur during execution: + /// it may be an artefact during bootstrapping of the analysis, or else + /// the value depends upon a `panic` or a loop that + /// [never terminates](super::TailLoopTermination::NeverBreaks). + pub fn contains_bottom(&self) -> bool { + match self { + PartialValue::Bottom => true, + PartialValue::Top | PartialValue::Value(_) => false, + PartialValue::PartialSum(ps) => ps.contains_bottom(), + } + } +} + +impl TryFrom> for Value { + type Error = ConstTypeError; + + fn try_from(value: Sum) -> Result { + Self::sum(value.tag, value.values, value.st) + } +} + +impl Lattice for PartialValue { + fn join_mut(&mut self, other: Self) -> bool { + self.assert_invariants(); + let mut old_self = Self::Top; + std::mem::swap(self, &mut old_self); + let (res, ch) = match (old_self, other) { + (old @ Self::Top, _) | (old, Self::Bottom) => (old, false), + (_, other @ Self::Top) | (Self::Bottom, other) => (other, true), + (Self::Value(h1), Self::Value(h2)) => match h1.clone().try_join(h2) { + Some((h3, b)) => (Self::Value(h3), b), + None => (Self::Top, true), + }, + (Self::PartialSum(mut ps1), Self::PartialSum(ps2)) => match ps1.try_join_mut(ps2) { + Ok(ch) => (Self::PartialSum(ps1), ch), + Err(_) => (Self::Top, true), + }, + (Self::Value(_), Self::PartialSum(_)) | (Self::PartialSum(_), Self::Value(_)) => { + (Self::Top, true) + } + }; + *self = res; + ch + } + + fn meet_mut(&mut self, other: Self) -> bool { + self.assert_invariants(); + let mut old_self = Self::Bottom; + std::mem::swap(self, &mut old_self); + let (res, ch) = match (old_self, other) { + (old @ Self::Bottom, _) | (old, Self::Top) => (old, false), + (_, other @ Self::Bottom) | (Self::Top, other) => (other, true), + (Self::Value(h1), Self::Value(h2)) => match h1.try_meet(h2) { + Some((h3, ch)) => (Self::Value(h3), ch), + None => (Self::Bottom, true), + }, + (Self::PartialSum(mut ps1), Self::PartialSum(ps2)) => match ps1.try_meet_mut(ps2) { + Ok(ch) => (Self::PartialSum(ps1), ch), + Err(_) => (Self::Bottom, true), + }, + (Self::Value(_), Self::PartialSum(_)) | (Self::PartialSum(_), Self::Value(_)) => { + (Self::Bottom, true) + } + }; + *self = res; + ch + } +} + +impl BoundedLattice for PartialValue { + fn top() -> Self { + Self::Top + } + + fn bottom() -> Self { + Self::Bottom + } +} + +impl PartialOrd for PartialValue { + fn partial_cmp(&self, other: &Self) -> Option { + use std::cmp::Ordering; + match (self, other) { + (Self::Bottom, Self::Bottom) => Some(Ordering::Equal), + (Self::Top, Self::Top) => Some(Ordering::Equal), + (Self::Bottom, _) => Some(Ordering::Less), + (_, Self::Bottom) => Some(Ordering::Greater), + (Self::Top, _) => Some(Ordering::Greater), + (_, Self::Top) => Some(Ordering::Less), + (Self::Value(v1), Self::Value(v2)) => (v1 == v2).then_some(Ordering::Equal), + (Self::PartialSum(ps1), Self::PartialSum(ps2)) => ps1.partial_cmp(ps2), + _ => None, + } + } +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use ascent::{lattice::BoundedLattice, Lattice}; + use itertools::{zip_eq, Itertools as _}; + use prop::sample::subsequence; + use proptest::prelude::*; + + use proptest_recurse::{StrategyExt, StrategySet}; + + use super::{AbstractValue, PartialSum, PartialValue}; + + #[derive(Debug, PartialEq, Eq, Clone)] + enum TestSumType { + Branch(Vec>>), + /// None => unit, Some => TestValue <= this *usize* + Leaf(Option), + } + + #[derive(Clone, Debug, PartialEq, Eq, Hash)] + struct TestValue(usize); + + impl AbstractValue for TestValue {} + + #[derive(Clone)] + struct SumTypeParams { + depth: usize, + desired_size: usize, + expected_branch_size: usize, + } + + impl Default for SumTypeParams { + fn default() -> Self { + Self { + depth: 5, + desired_size: 20, + expected_branch_size: 5, + } + } + } + + impl TestSumType { + fn check_value(&self, pv: &PartialValue) -> bool { + match (self, pv) { + (_, PartialValue::Bottom) | (_, PartialValue::Top) => true, + (Self::Leaf(None), _) => pv == &PartialValue::new_unit(), + (Self::Leaf(Some(max)), PartialValue::Value(TestValue(val))) => val <= max, + (Self::Branch(sop), PartialValue::PartialSum(ps)) => { + for (k, v) in &ps.0 { + if *k >= sop.len() { + return false; + } + let prod = &sop[*k]; + if prod.len() != v.len() { + return false; + } + if !zip_eq(prod, v).all(|(lhs, rhs)| lhs.check_value(rhs)) { + return false; + } + } + true + } + _ => false, + } + } + } + + impl Arbitrary for TestSumType { + type Parameters = SumTypeParams; + type Strategy = SBoxedStrategy; + fn arbitrary_with(params: Self::Parameters) -> Self::Strategy { + fn arb(params: SumTypeParams, set: &mut StrategySet) -> SBoxedStrategy { + use proptest::collection::vec; + let int_strat = (0..usize::MAX).prop_map(|i| TestSumType::Leaf(Some(i))); + let leaf_strat = prop_oneof![Just(TestSumType::Leaf(None)), int_strat]; + leaf_strat.prop_mutually_recursive( + params.depth as u32, + params.desired_size as u32, + params.expected_branch_size as u32, + set, + move |set| { + let params2 = params.clone(); + vec( + vec( + set.get::(move |set| arb(params2, set)) + .prop_map(Arc::new), + 1..=params.expected_branch_size, + ), + 1..=params.expected_branch_size, + ) + .prop_map(TestSumType::Branch) + .sboxed() + }, + ) + } + + arb(params, &mut StrategySet::default()) + } + } + + fn single_sum_strat( + tag: usize, + elems: Vec>, + ) -> impl Strategy> { + elems + .iter() + .map(Arc::as_ref) + .map(any_partial_value_of_type) + .collect::>() + .prop_map(move |elems| PartialSum::new_variant(tag, elems)) + } + + fn partial_sum_strat( + variants: &[Vec>], + ) -> impl Strategy> { + // We have to clone the `variants` here but only as far as the Vec>> + let tagged_variants = variants.iter().cloned().enumerate().collect::>(); + // The type annotation here (and the .boxed() enabling it) are just for documentation + let sum_variants_strat: BoxedStrategy>> = + subsequence(tagged_variants, 1..=variants.len()) + .prop_flat_map(|selected_variants| { + selected_variants + .into_iter() + .map(|(tag, elems)| single_sum_strat(tag, elems)) + .collect::>() + }) + .boxed(); + sum_variants_strat.prop_map(|psums: Vec>| { + let mut psums = psums.into_iter(); + let first = psums.next().unwrap(); + psums.fold(first, |mut a, b| { + a.try_join_mut(b).unwrap(); + a + }) + }) + } + + fn any_partial_value_of_type( + ust: &TestSumType, + ) -> impl Strategy> { + match ust { + TestSumType::Leaf(None) => Just(PartialValue::new_unit()).boxed(), + TestSumType::Leaf(Some(i)) => (0..*i) + .prop_map(TestValue) + .prop_map(PartialValue::from) + .boxed(), + TestSumType::Branch(sop) => partial_sum_strat(sop).prop_map(PartialValue::from).boxed(), + } + } + + fn any_partial_value_with( + params: ::Parameters, + ) -> impl Strategy> { + any_with::(params).prop_flat_map(|t| any_partial_value_of_type(&t)) + } + + fn any_partial_value() -> impl Strategy> { + any_partial_value_with(Default::default()) + } + + fn any_partial_values() -> impl Strategy; N]> { + any::().prop_flat_map(|ust| { + TryInto::<[_; N]>::try_into( + (0..N) + .map(|_| any_partial_value_of_type(&ust)) + .collect_vec(), + ) + .unwrap() + }) + } + + fn any_typed_partial_value() -> impl Strategy)> { + any::() + .prop_flat_map(|t| any_partial_value_of_type(&t).prop_map(move |v| (t.clone(), v))) + } + + proptest! { + #[test] + fn partial_value_type((tst, pv) in any_typed_partial_value()) { + prop_assert!(tst.check_value(&pv)) + } + + // todo: ValidHandle is valid + // todo: ValidHandle eq is an equivalence relation + + // todo: PartialValue PartialOrd is transitive + // todo: PartialValue eq is an equivalence relation + #[test] + fn partial_value_valid(pv in any_partial_value()) { + pv.assert_invariants(); + } + + #[test] + fn bounded_lattice(v in any_partial_value()) { + prop_assert!(v <= PartialValue::top()); + prop_assert!(v >= PartialValue::bottom()); + } + + #[test] + fn meet_join_self_noop(v1 in any_partial_value()) { + let mut subject = v1.clone(); + + assert_eq!(v1.clone(), v1.clone().join(v1.clone())); + assert!(!subject.join_mut(v1.clone())); + assert_eq!(subject, v1); + + assert_eq!(v1.clone(), v1.clone().meet(v1.clone())); + assert!(!subject.meet_mut(v1.clone())); + assert_eq!(subject, v1); + } + + #[test] + fn lattice([v1,v2] in any_partial_values()) { + let meet = v1.clone().meet(v2.clone()); + prop_assert!(meet <= v1, "meet not less <=: {:#?}", &meet); + prop_assert!(meet <= v2, "meet not less <=: {:#?}", &meet); + prop_assert!(meet == v2.clone().meet(v1.clone()), "meet not symmetric"); + prop_assert!(meet == meet.clone().meet(v1.clone()), "repeated meet should be a no-op"); + prop_assert!(meet == meet.clone().meet(v2.clone()), "repeated meet should be a no-op"); + + let join = v1.clone().join(v2.clone()); + prop_assert!(join >= v1, "join not >=: {:#?}", &join); + prop_assert!(join >= v2, "join not >=: {:#?}", &join); + prop_assert!(join == v2.clone().join(v1.clone()), "join not symmetric"); + prop_assert!(join == join.clone().join(v1.clone()), "repeated join should be a no-op"); + prop_assert!(join == join.clone().join(v2.clone()), "repeated join should be a no-op"); + } + + #[test] + fn lattice_associative([v1, v2, v3] in any_partial_values()) { + let a = v1.clone().meet(v2.clone()).meet(v3.clone()); + let b = v1.clone().meet(v2.clone().meet(v3.clone())); + prop_assert!(a==b, "meet not associative"); + + let a = v1.clone().join(v2.clone()).join(v3.clone()); + let b = v1.clone().join(v2.clone().join(v3.clone())); + prop_assert!(a==b, "join not associative") + } + } +} diff --git a/hugr-passes/src/dataflow/results.rs b/hugr-passes/src/dataflow/results.rs new file mode 100644 index 000000000..0f4704b42 --- /dev/null +++ b/hugr-passes/src/dataflow/results.rs @@ -0,0 +1,126 @@ +use std::collections::HashMap; + +use hugr_core::{HugrView, IncomingPort, Node, PortIndex, Wire}; + +use super::{partial_value::ExtractValueError, AbstractValue, PartialValue, Sum}; + +/// Results of a dataflow analysis, packaged with the Hugr for easy inspection. +/// Methods allow inspection, specifically [read_out_wire](Self::read_out_wire). +pub struct AnalysisResults { + pub(super) hugr: H, + pub(super) in_wire_value: Vec<(Node, IncomingPort, PartialValue)>, + pub(super) case_reachable: Vec<(Node, Node)>, + pub(super) bb_reachable: Vec<(Node, Node)>, + pub(super) out_wire_values: HashMap>, +} + +impl AnalysisResults { + /// Gets the lattice value computed for the given wire + pub fn read_out_wire(&self, w: Wire) -> Option> { + self.out_wire_values.get(&w).cloned() + } + + /// Tells whether a [TailLoop] node can terminate, i.e. whether + /// `Break` and/or `Continue` tags may be returned by the nested DFG. + /// Returns `None` if the specified `node` is not a [TailLoop]. + /// + /// [TailLoop]: hugr_core::ops::TailLoop + pub fn tail_loop_terminates(&self, node: Node) -> Option { + self.hugr.get_optype(node).as_tail_loop()?; + let [_, out] = self.hugr.get_io(node).unwrap(); + Some(TailLoopTermination::from_control_value( + self.in_wire_value + .iter() + .find_map(|(n, p, v)| (*n == out && p.index() == 0).then_some(v)) + .unwrap(), + )) + } + + /// Tells whether a [Case] node is reachable, i.e. whether the predicate + /// to its parent [Conditional] may possibly have the tag corresponding to the [Case]. + /// Returns `None` if the specified `case` is not a [Case], or is not within a [Conditional] + /// (e.g. a [Case]-rooted Hugr). + /// + /// [Case]: hugr_core::ops::Case + /// [Conditional]: hugr_core::ops::Conditional + pub fn case_reachable(&self, case: Node) -> Option { + self.hugr.get_optype(case).as_case()?; + let cond = self.hugr.get_parent(case)?; + self.hugr.get_optype(cond).as_conditional()?; + Some( + self.case_reachable + .iter() + .any(|(cond2, case2)| &cond == cond2 && &case == case2), + ) + } + + /// Tells us if a block ([DataflowBlock] or [ExitBlock]) in a [CFG] is known + /// to be reachable. (Returns `None` if argument is not a child of a CFG.) + /// + /// [CFG]: hugr_core::ops::CFG + /// [DataflowBlock]: hugr_core::ops::DataflowBlock + /// [ExitBlock]: hugr_core::ops::ExitBlock + pub fn bb_reachable(&self, bb: Node) -> Option { + let cfg = self.hugr.get_parent(bb)?; // Not really required...?? + self.hugr.get_optype(cfg).as_cfg()?; + let t = self.hugr.get_optype(bb); + (t.is_dataflow_block() || t.is_exit_block()).then(|| { + self.bb_reachable + .iter() + .any(|(cfg2, bb2)| *cfg2 == cfg && *bb2 == bb) + }) + } + + /// Reads a concrete representation of the value on an output wire, if the lattice value + /// computed for the wire can be turned into such. (The lattice value must be either a + /// [PartialValue::Value] or a [PartialValue::PartialSum] with a single possible tag.) + /// + /// # Errors + /// `None` if the analysis did not produce a result for that wire, or if + /// the Hugr did not have a [Type](hugr_core::types::Type) for the specified wire + /// `Some(e)` if [conversion to a concrete value](PartialValue::try_into_concrete) failed with error `e` + pub fn try_read_wire_concrete( + &self, + w: Wire, + ) -> Result>> + where + V2: TryFrom + TryFrom, Error = SE>, + { + let v = self.read_out_wire(w).ok_or(None)?; + let (_, typ) = self + .hugr + .out_value_types(w.node()) + .find(|(p, _)| *p == w.source()) + .ok_or(None)?; + v.try_into_concrete(&typ).map_err(Some) + } +} + +/// Tells whether a loop iterates (never, always, sometimes) +#[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)] +pub enum TailLoopTermination { + /// The loop never exits (is an infinite loop); no value is ever + /// returned out of the loop. (aka, Bottom.) + // TODO what about a loop that never exits OR continues because of a nested infinite loop? + NeverBreaks, + /// The loop never iterates (so is equivalent to a [DFG](hugr_core::ops::DFG), + /// modulo untupling of the control value) + NeverContinues, + /// The loop might iterate and/or exit. (aka, Top) + BreaksAndContinues, +} + +impl TailLoopTermination { + fn from_control_value(v: &PartialValue) -> Self { + let (may_continue, may_break) = (v.supports_tag(0), v.supports_tag(1)); + if may_break { + if may_continue { + Self::BreaksAndContinues + } else { + Self::NeverContinues + } + } else { + Self::NeverBreaks + } + } +} diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs new file mode 100644 index 000000000..13815d186 --- /dev/null +++ b/hugr-passes/src/dataflow/test.rs @@ -0,0 +1,548 @@ +use ascent::{lattice::BoundedLattice, Lattice}; + +use hugr_core::builder::{CFGBuilder, Container, DataflowHugr, ModuleBuilder}; +use hugr_core::extension::PRELUDE_REGISTRY; +use hugr_core::hugr::views::{DescendantsGraph, HierarchyView}; +use hugr_core::ops::handle::DfgID; +use hugr_core::ops::TailLoop; +use hugr_core::types::TypeRow; +use hugr_core::{ + builder::{endo_sig, DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer}, + extension::{ + prelude::{bool_t, UnpackTuple}, + ExtensionSet, EMPTY_REG, + }, + ops::{handle::NodeHandle, DataflowOpTrait, Tag, Value}, + type_row, + types::{Signature, SumType, Type}, + HugrView, +}; +use hugr_core::{Hugr, Wire}; +use rstest::{fixture, rstest}; + +use super::{AbstractValue, ConstLoader, DFContext, Machine, PartialValue, TailLoopTermination}; + +// ------- Minimal implementation of DFContext and AbstractValue ------- +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +enum Void {} + +impl AbstractValue for Void {} + +struct TestContext; + +impl ConstLoader for TestContext {} +impl DFContext for TestContext {} + +// This allows testing creation of tuple/sum Values (only) +impl From for Value { + fn from(v: Void) -> Self { + match v {} + } +} + +fn pv_false() -> PartialValue { + PartialValue::new_variant(0, []) +} + +fn pv_true() -> PartialValue { + PartialValue::new_variant(1, []) +} + +fn pv_true_or_false() -> PartialValue { + pv_true().join(pv_false()) +} + +#[test] +fn test_make_tuple() { + let mut builder = DFGBuilder::new(endo_sig(vec![])).unwrap(); + let v1 = builder.add_load_value(Value::false_val()); + let v2 = builder.add_load_value(Value::true_val()); + let v3 = builder.make_tuple([v1, v2]).unwrap(); + let hugr = builder.finish_hugr(&PRELUDE_REGISTRY).unwrap(); + + let results = Machine::new(&hugr).run(TestContext, []); + + let x: Value = results.try_read_wire_concrete(v3).unwrap(); + assert_eq!(x, Value::tuple([Value::false_val(), Value::true_val()])); +} + +#[test] +fn test_unpack_tuple_const() { + let mut builder = DFGBuilder::new(endo_sig(vec![])).unwrap(); + let v = builder.add_load_value(Value::tuple([Value::false_val(), Value::true_val()])); + let [o1, o2] = builder + .add_dataflow_op(UnpackTuple::new(vec![bool_t(); 2].into()), [v]) + .unwrap() + .outputs_arr(); + let hugr = builder.finish_hugr(&PRELUDE_REGISTRY).unwrap(); + + let results = Machine::new(&hugr).run(TestContext, []); + + let o1_r: Value = results.try_read_wire_concrete(o1).unwrap(); + assert_eq!(o1_r, Value::false_val()); + let o2_r: Value = results.try_read_wire_concrete(o2).unwrap(); + assert_eq!(o2_r, Value::true_val()); +} + +#[test] +fn test_tail_loop_never_iterates() { + let mut builder = DFGBuilder::new(Signature::new_endo(vec![])).unwrap(); + let r_v = Value::unit_sum(3, 6).unwrap(); + let r_w = builder.add_load_value(r_v.clone()); + let tag = Tag::new( + TailLoop::BREAK_TAG, + vec![type_row![], r_v.get_type().into()], + ); + let tagged = builder.add_dataflow_op(tag, [r_w]).unwrap(); + + let tlb = builder + .tail_loop_builder([], [], vec![r_v.get_type()].into()) + .unwrap(); + let tail_loop = tlb.finish_with_outputs(tagged.out_wire(0), []).unwrap(); + let [tl_o] = tail_loop.outputs_arr(); + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + + let results = Machine::new(&hugr).run(TestContext, []); + + let o_r: Value = results.try_read_wire_concrete(tl_o).unwrap(); + assert_eq!(o_r, r_v); + assert_eq!( + Some(TailLoopTermination::NeverContinues), + results.tail_loop_terminates(tail_loop.node()) + ) +} + +#[test] +fn test_tail_loop_always_iterates() { + let mut builder = DFGBuilder::new(Signature::new_endo(vec![])).unwrap(); + let r_w = builder.add_load_value( + Value::sum( + TailLoop::CONTINUE_TAG, + [], + SumType::new([type_row![], bool_t().into()]), + ) + .unwrap(), + ); + let true_w = builder.add_load_value(Value::true_val()); + + let tlb = builder + .tail_loop_builder([], [(bool_t(), true_w)], vec![bool_t()].into()) + .unwrap(); + + // r_w has tag 0, so we always continue; + // we put true in our "other_output", but we should not propagate this to + // output because r_w never supports 1. + let tail_loop = tlb.finish_with_outputs(r_w, [true_w]).unwrap(); + + let [tl_o1, tl_o2] = tail_loop.outputs_arr(); + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + + let results = Machine::new(&hugr).run(TestContext, []); + + let o_r1 = results.read_out_wire(tl_o1).unwrap(); + assert_eq!(o_r1, PartialValue::bottom()); + let o_r2 = results.read_out_wire(tl_o2).unwrap(); + assert_eq!(o_r2, PartialValue::bottom()); + assert_eq!( + Some(TailLoopTermination::NeverBreaks), + results.tail_loop_terminates(tail_loop.node()) + ); + assert_eq!(results.tail_loop_terminates(hugr.root()), None); +} + +#[test] +fn test_tail_loop_two_iters() { + let mut builder = DFGBuilder::new(Signature::new_endo(vec![])).unwrap(); + + let true_w = builder.add_load_value(Value::true_val()); + let false_w = builder.add_load_value(Value::false_val()); + + let tlb = builder + .tail_loop_builder_exts( + [], + [(bool_t(), false_w), (bool_t(), true_w)], + type_row![], + ExtensionSet::new(), + ) + .unwrap(); + assert_eq!( + tlb.loop_signature().unwrap().signature(), + Signature::new_endo(vec![bool_t(); 2]) + ); + let [in_w1, in_w2] = tlb.input_wires_arr(); + let tail_loop = tlb.finish_with_outputs(in_w1, [in_w2, in_w1]).unwrap(); + + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + let [o_w1, o_w2] = tail_loop.outputs_arr(); + + let results = Machine::new(&hugr).run(TestContext, []); + + let o_r1 = results.read_out_wire(o_w1).unwrap(); + assert_eq!(o_r1, pv_true_or_false()); + let o_r2 = results.read_out_wire(o_w2).unwrap(); + assert_eq!(o_r2, pv_true_or_false()); + assert_eq!( + Some(TailLoopTermination::BreaksAndContinues), + results.tail_loop_terminates(tail_loop.node()) + ); + assert_eq!(results.tail_loop_terminates(hugr.root()), None); +} + +#[test] +fn test_tail_loop_containing_conditional() { + let mut builder = DFGBuilder::new(Signature::new_endo(vec![])).unwrap(); + let control_variants = vec![vec![bool_t(); 2].into(); 2]; + let control_t = Type::new_sum(control_variants.clone()); + let body_out_variants = vec![TypeRow::from(control_t.clone()), vec![bool_t(); 2].into()]; + + let init = builder.add_load_value( + Value::sum( + 0, + [Value::false_val(), Value::true_val()], + SumType::new(control_variants.clone()), + ) + .unwrap(), + ); + + let mut tlb = builder + .tail_loop_builder([(control_t, init)], [], vec![bool_t(); 2].into()) + .unwrap(); + let tl = tlb.loop_signature().unwrap().clone(); + let [in_w] = tlb.input_wires_arr(); + + // Branch on in_wire, so first iter 0(false, true)... + let mut cond = tlb + .conditional_builder( + (control_variants.clone(), in_w), + [], + Type::new_sum(body_out_variants.clone()).into(), + ) + .unwrap(); + let mut case0_b = cond.case_builder(0).unwrap(); + let [a, b] = case0_b.input_wires_arr(); + // Builds value for next iter as 1(true, false) by flipping arguments + let [next_input] = case0_b + .add_dataflow_op(Tag::new(1, control_variants), [b, a]) + .unwrap() + .outputs_arr(); + let cont = case0_b.make_continue(tl.clone(), [next_input]).unwrap(); + case0_b.finish_with_outputs([cont]).unwrap(); + // Second iter 1(true, false) => exit with (true, false) + let mut case1_b = cond.case_builder(1).unwrap(); + let loop_res = case1_b.make_break(tl, case1_b.input_wires()).unwrap(); + case1_b.finish_with_outputs([loop_res]).unwrap(); + let [r] = cond.finish_sub_container().unwrap().outputs_arr(); + + let tail_loop = tlb.finish_with_outputs(r, []).unwrap(); + + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + let [o_w1, o_w2] = tail_loop.outputs_arr(); + + let results = Machine::new(&hugr).run(TestContext, []); + + let o_r1 = results.read_out_wire(o_w1).unwrap(); + assert_eq!(o_r1, pv_true()); + let o_r2 = results.read_out_wire(o_w2).unwrap(); + assert_eq!(o_r2, pv_false()); + assert_eq!( + Some(TailLoopTermination::BreaksAndContinues), + results.tail_loop_terminates(tail_loop.node()) + ); + assert_eq!(results.tail_loop_terminates(hugr.root()), None); +} + +#[test] +fn test_conditional() { + let variants = vec![type_row![], type_row![], bool_t().into()]; + let cond_t = Type::new_sum(variants.clone()); + let mut builder = DFGBuilder::new(Signature::new(cond_t, type_row![])).unwrap(); + let [arg_w] = builder.input_wires_arr(); + + let true_w = builder.add_load_value(Value::true_val()); + let false_w = builder.add_load_value(Value::false_val()); + + let mut cond_builder = builder + .conditional_builder( + (variants, arg_w), + [(bool_t(), true_w)], + vec![bool_t(); 2].into(), + ) + .unwrap(); + // will be unreachable + let case1_b = cond_builder.case_builder(0).unwrap(); + let case1 = case1_b.finish_with_outputs([false_w, false_w]).unwrap(); + + let case2_b = cond_builder.case_builder(1).unwrap(); + let [c2a] = case2_b.input_wires_arr(); + let case2 = case2_b.finish_with_outputs([false_w, c2a]).unwrap(); + + let case3_b = cond_builder.case_builder(2).unwrap(); + let [c3_1, _c3_2] = case3_b.input_wires_arr(); + let case3 = case3_b.finish_with_outputs([c3_1, false_w]).unwrap(); + + let cond = cond_builder.finish_sub_container().unwrap(); + + let [cond_o1, cond_o2] = cond.outputs_arr(); + + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + + let arg_pv = PartialValue::new_variant(1, []).join(PartialValue::new_variant( + 2, + [PartialValue::new_variant(0, [])], + )); + let results = Machine::new(&hugr).run(TestContext, [(0.into(), arg_pv)]); + + let cond_r1: Value = results.try_read_wire_concrete(cond_o1).unwrap(); + assert_eq!(cond_r1, Value::false_val()); + assert!(results + .try_read_wire_concrete::(cond_o2) + .is_err()); + + assert_eq!(results.case_reachable(case1.node()), Some(false)); // arg_pv is variant 1 or 2 only + assert_eq!(results.case_reachable(case2.node()), Some(true)); + assert_eq!(results.case_reachable(case3.node()), Some(true)); + assert_eq!(results.case_reachable(cond.node()), None); +} + +// A Hugr being a function on bools: (x, y) => (x XOR y, x AND y) +#[fixture] +fn xor_and_cfg() -> Hugr { + // Entry branch on first arg, passes arguments on unchanged + // /T F\ + // A --T-> B A(x=true, y) branch on second arg, passing (first arg == true, false) + // \F / B(w,v) => X(v,w) + // > X < + // Inputs received: + // Entry A B X + // F,F - F,F F,F + // F,T - F,T T,F + // T,F T,F - T,F + // T,T T,T T,F F,T + let mut builder = + CFGBuilder::new(Signature::new(vec![bool_t(); 2], vec![bool_t(); 2])).unwrap(); + + // entry (x, y) => (if x then A else B)(x=true, y) + let entry = builder + .entry_builder(vec![type_row![]; 2], vec![bool_t(); 2].into()) + .unwrap(); + let [in_x, in_y] = entry.input_wires_arr(); + let entry = entry.finish_with_outputs(in_x, [in_x, in_y]).unwrap(); + + // A(x==true, y) => (if y then B else X)(x, false) + let mut a = builder + .block_builder( + vec![bool_t(); 2].into(), + vec![type_row![]; 2], + vec![bool_t(); 2].into(), + ) + .unwrap(); + let [in_x, in_y] = a.input_wires_arr(); + let false_w1 = a.add_load_value(Value::false_val()); + let a = a.finish_with_outputs(in_y, [in_x, false_w1]).unwrap(); + + // B(w, v) => X(v, w) + let mut b = builder + .block_builder( + vec![bool_t(); 2].into(), + [type_row![]], + vec![bool_t(); 2].into(), + ) + .unwrap(); + let [in_w, in_v] = b.input_wires_arr(); + let [control] = b + .add_dataflow_op(Tag::new(0, vec![type_row![]]), []) + .unwrap() + .outputs_arr(); + let b = b.finish_with_outputs(control, [in_v, in_w]).unwrap(); + + let x = builder.exit_block(); + + let [fals, tru]: [usize; 2] = [0, 1]; + builder.branch(&entry, tru, &a).unwrap(); // if true + builder.branch(&entry, fals, &b).unwrap(); // if false + builder.branch(&a, tru, &b).unwrap(); // if true + builder.branch(&a, fals, &x).unwrap(); // if false + builder.branch(&b, 0, &x).unwrap(); + builder.finish_hugr(&EMPTY_REG).unwrap() +} + +#[rstest] +#[case(pv_true(), pv_true(), pv_false(), pv_true())] +#[case(pv_true(), pv_false(), pv_true(), pv_false())] +#[case(pv_true(), pv_true_or_false(), pv_true_or_false(), pv_true_or_false())] +#[case(pv_true(), PartialValue::Top, pv_true_or_false(), pv_true_or_false())] +#[case(pv_false(), pv_true(), pv_true(), pv_false())] +#[case(pv_false(), pv_false(), pv_false(), pv_false())] +#[case(pv_false(), pv_true_or_false(), pv_true_or_false(), pv_false())] +#[case(pv_false(), PartialValue::Top, PartialValue::Top, pv_false())] // if !inp0 then out0=inp1 +#[case(pv_true_or_false(), pv_true(), pv_true_or_false(), pv_true_or_false())] +#[case(pv_true_or_false(), pv_false(), pv_true_or_false(), pv_true_or_false())] +#[case(PartialValue::Top, pv_true(), pv_true_or_false(), PartialValue::Top)] +#[case(PartialValue::Top, pv_false(), PartialValue::Top, PartialValue::Top)] +fn test_cfg( + #[case] inp0: PartialValue, + #[case] inp1: PartialValue, + #[case] out0: PartialValue, + #[case] out1: PartialValue, + xor_and_cfg: Hugr, +) { + let root = xor_and_cfg.root(); + let results = Machine::new(&xor_and_cfg).run(TestContext, [(0.into(), inp0), (1.into(), inp1)]); + + assert_eq!(results.read_out_wire(Wire::new(root, 0)).unwrap(), out0); + assert_eq!(results.read_out_wire(Wire::new(root, 1)).unwrap(), out1); +} + +#[rstest] +#[case(pv_true(), pv_true(), pv_true())] +#[case(pv_false(), pv_false(), pv_false())] +#[case(pv_true(), pv_false(), pv_true_or_false())] // Two calls alias +fn test_call( + #[case] inp0: PartialValue, + #[case] inp1: PartialValue, + #[case] out: PartialValue, +) { + let mut builder = DFGBuilder::new(Signature::new_endo(vec![bool_t(); 2])).unwrap(); + let func_bldr = builder + .define_function("id", Signature::new_endo(bool_t())) + .unwrap(); + let [v] = func_bldr.input_wires_arr(); + let func_defn = func_bldr.finish_with_outputs([v]).unwrap(); + let [a, b] = builder.input_wires_arr(); + let [a2] = builder + .call(func_defn.handle(), &[], [a], &EMPTY_REG) + .unwrap() + .outputs_arr(); + let [b2] = builder + .call(func_defn.handle(), &[], [b], &EMPTY_REG) + .unwrap() + .outputs_arr(); + let hugr = builder + .finish_hugr_with_outputs([a2, b2], &EMPTY_REG) + .unwrap(); + + let results = Machine::new(&hugr).run(TestContext, [(0.into(), inp0), (1.into(), inp1)]); + + let [res0, res1] = [0, 1].map(|i| results.read_out_wire(Wire::new(hugr.root(), i)).unwrap()); + // The two calls alias so both results will be the same: + assert_eq!(res0, out); + assert_eq!(res1, out); +} + +#[test] +fn test_region() { + let mut builder = DFGBuilder::new(Signature::new(vec![bool_t()], vec![bool_t(); 2])).unwrap(); + let [in_w] = builder.input_wires_arr(); + let cst_w = builder.add_load_const(Value::false_val()); + let nested = builder + .dfg_builder(Signature::new_endo(vec![bool_t(); 2]), [in_w, cst_w]) + .unwrap(); + let nested_ins = nested.input_wires(); + let nested = nested.finish_with_outputs(nested_ins).unwrap(); + let hugr = builder + .finish_prelude_hugr_with_outputs(nested.outputs()) + .unwrap(); + let [nested_input, _] = hugr.get_io(nested.node()).unwrap(); + let whole_hugr_results = Machine::new(&hugr).run(TestContext, [(0.into(), pv_true())]); + assert_eq!( + whole_hugr_results.read_out_wire(Wire::new(nested_input, 0)), + Some(pv_true()) + ); + assert_eq!( + whole_hugr_results.read_out_wire(Wire::new(nested_input, 1)), + Some(pv_false()) + ); + assert_eq!( + whole_hugr_results.read_out_wire(Wire::new(hugr.root(), 0)), + Some(pv_true()) + ); + assert_eq!( + whole_hugr_results.read_out_wire(Wire::new(hugr.root(), 1)), + Some(pv_false()) + ); + + let subview = DescendantsGraph::::try_new(&hugr, nested.node()).unwrap(); + // Do not provide a value on the second input (constant false in the whole hugr, above) + let sub_hugr_results = Machine::new(subview).run(TestContext, [(0.into(), pv_true())]); + assert_eq!( + sub_hugr_results.read_out_wire(Wire::new(nested_input, 0)), + Some(pv_true()) + ); + assert_eq!( + sub_hugr_results.read_out_wire(Wire::new(nested_input, 1)), + Some(PartialValue::Top) + ); + for w in [0, 1] { + assert_eq!( + sub_hugr_results.read_out_wire(Wire::new(hugr.root(), w)), + None + ); + } +} + +#[test] +fn test_module() { + let mut modb = ModuleBuilder::new(); + let leaf_fn = modb + .define_function("leaf", Signature::new_endo(vec![bool_t(); 2])) + .unwrap(); + let outs = leaf_fn.input_wires(); + let leaf_fn = leaf_fn.finish_with_outputs(outs).unwrap(); + + let mut f2 = modb + .define_function("f2", Signature::new(bool_t(), vec![bool_t(); 2])) + .unwrap(); + let [inp] = f2.input_wires_arr(); + let cst_true = f2.add_load_value(Value::true_val()); + let f2_call = f2 + .call(leaf_fn.handle(), &[], [inp, cst_true], &EMPTY_REG) + .unwrap(); + let f2 = f2.finish_with_outputs(f2_call.outputs()).unwrap(); + + let mut main = modb + .define_function("main", Signature::new(bool_t(), vec![bool_t(); 2])) + .unwrap(); + let [inp] = main.input_wires_arr(); + let cst_false = main.add_load_value(Value::false_val()); + let main_call = main + .call(leaf_fn.handle(), &[], [inp, cst_false], &EMPTY_REG) + .unwrap(); + main.finish_with_outputs(main_call.outputs()).unwrap(); + let hugr = modb.finish_hugr(&EMPTY_REG).unwrap(); + let [f2_inp, _] = hugr.get_io(f2.node()).unwrap(); + + let results_just_main = Machine::new(&hugr).run(TestContext, [(0.into(), pv_true())]); + assert_eq!( + results_just_main.read_out_wire(Wire::new(f2_inp, 0)), + Some(PartialValue::Bottom) + ); + for call in [f2_call, main_call] { + // The first output of the Call comes from `main` because no value was fed in from f2 + assert_eq!( + results_just_main.read_out_wire(Wire::new(call.node(), 0)), + Some(pv_true()) + ); + // (Without reachability) the second output of the Call is the join of the two constant inputs from the two calls + assert_eq!( + results_just_main.read_out_wire(Wire::new(call.node(), 1)), + Some(pv_true_or_false()) + ); + } + + let results_two_calls = { + let mut m = Machine::new(&hugr); + m.prepopulate_df_inputs(f2.node(), [(0.into(), pv_true())]); + m.run(TestContext, [(0.into(), pv_false())]) + }; + + for call in [f2_call, main_call] { + assert_eq!( + results_two_calls.read_out_wire(Wire::new(call.node(), 0)), + Some(pv_true_or_false()) + ); + assert_eq!( + results_two_calls.read_out_wire(Wire::new(call.node(), 1)), + Some(pv_true_or_false()) + ); + } +} diff --git a/hugr-passes/src/dataflow/value_row.rs b/hugr-passes/src/dataflow/value_row.rs new file mode 100644 index 000000000..50cf10318 --- /dev/null +++ b/hugr-passes/src/dataflow/value_row.rs @@ -0,0 +1,103 @@ +// Wrap a (known-length) row of values into a lattice. + +use std::{ + cmp::Ordering, + ops::{Index, IndexMut}, +}; + +use ascent::{lattice::BoundedLattice, Lattice}; +use itertools::zip_eq; + +use super::{AbstractValue, PartialValue}; + +#[derive(PartialEq, Clone, Debug, Eq, Hash)] +pub(super) struct ValueRow(Vec>); + +impl ValueRow { + pub fn new(len: usize) -> Self { + Self(vec![PartialValue::bottom(); len]) + } + + pub fn set(mut self, idx: usize, v: PartialValue) -> Self { + *self.0.get_mut(idx).unwrap() = v; + self + } + + pub fn singleton(v: PartialValue) -> Self { + Self(vec![v]) + } + + /// The first value in this ValueRow must be a sum; + /// returns a new ValueRow given by unpacking the elements of the specified variant of said first value, + /// then appending the rest of the values in this row. + pub fn unpack_first( + &self, + variant: usize, + len: usize, + ) -> Option>> { + let vals = self[0].variant_values(variant, len)?; + Some(vals.into_iter().chain(self.0[1..].to_owned())) + } +} + +impl FromIterator> for ValueRow { + fn from_iter>>(iter: T) -> Self { + Self(iter.into_iter().collect()) + } +} + +impl PartialOrd for ValueRow { + fn partial_cmp(&self, other: &Self) -> Option { + self.0.partial_cmp(&other.0) + } +} + +impl Lattice for ValueRow { + fn join_mut(&mut self, other: Self) -> bool { + assert_eq!(self.0.len(), other.0.len()); + let mut changed = false; + for (v1, v2) in zip_eq(self.0.iter_mut(), other.0.into_iter()) { + changed |= v1.join_mut(v2); + } + changed + } + + fn meet_mut(&mut self, other: Self) -> bool { + assert_eq!(self.0.len(), other.0.len()); + let mut changed = false; + for (v1, v2) in zip_eq(self.0.iter_mut(), other.0.into_iter()) { + changed |= v1.meet_mut(v2); + } + changed + } +} + +impl IntoIterator for ValueRow { + type Item = PartialValue; + + type IntoIter = > as IntoIterator>::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + +impl Index for ValueRow +where + Vec>: Index, +{ + type Output = > as Index>::Output; + + fn index(&self, index: Idx) -> &Self::Output { + self.0.index(index) + } +} + +impl IndexMut for ValueRow +where + Vec>: IndexMut, +{ + fn index_mut(&mut self, index: Idx) -> &mut Self::Output { + self.0.index_mut(index) + } +} diff --git a/hugr-passes/src/lib.rs b/hugr-passes/src/lib.rs index 13dd47776..a2208430d 100644 --- a/hugr-passes/src/lib.rs +++ b/hugr-passes/src/lib.rs @@ -1,6 +1,7 @@ //! Compilation passes acting on the HUGR program representation. pub mod const_fold; +pub mod dataflow; pub mod force_order; mod half_node; pub mod lower; @@ -11,3 +12,29 @@ pub mod validation; pub use force_order::{force_order, force_order_by_key}; pub use lower::{lower_ops, replace_many_ops}; pub use non_local::{ensure_no_nonlocal_edges, nonlocal_edges}; + +#[cfg(test)] +pub(crate) mod test { + + use lazy_static::lazy_static; + + use hugr_core::extension::{ExtensionRegistry, PRELUDE}; + use hugr_core::std_extensions::arithmetic; + use hugr_core::std_extensions::collections; + use hugr_core::std_extensions::logic; + + lazy_static! { + /// A registry containing various extensions for testing. + pub(crate) static ref TEST_REG: ExtensionRegistry = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_ops::EXTENSION.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + arithmetic::float_types::EXTENSION.to_owned(), + arithmetic::float_ops::EXTENSION.to_owned(), + logic::EXTENSION.to_owned(), + arithmetic::conversions::EXTENSION.to_owned(), + collections::EXTENSION.to_owned(), + ]) + .unwrap(); + } +} diff --git a/hugr-passes/src/lower.rs b/hugr-passes/src/lower.rs index b799638fc..3c3d8a40c 100644 --- a/hugr-passes/src/lower.rs +++ b/hugr-passes/src/lower.rs @@ -80,7 +80,7 @@ pub fn lower_ops( mod test { use hugr_core::{ builder::{DFGBuilder, Dataflow, DataflowHugr}, - extension::prelude::{Noop, BOOL_T}, + extension::prelude::{bool_t, Noop}, std_extensions::logic::LogicOp, types::Signature, HugrView, @@ -91,9 +91,9 @@ mod test { #[fixture] fn noop_hugr() -> Hugr { - let mut b = DFGBuilder::new(Signature::new_endo(BOOL_T).with_prelude()).unwrap(); + let mut b = DFGBuilder::new(Signature::new_endo(bool_t()).with_prelude()).unwrap(); let out = b - .add_dataflow_op(Noop::new(BOOL_T), [b.input_wires().next().unwrap()]) + .add_dataflow_op(Noop::new(bool_t()), [b.input_wires().next().unwrap()]) .unwrap() .out_wire(0); b.finish_prelude_hugr_with_outputs([out]).unwrap() @@ -101,7 +101,7 @@ mod test { #[fixture] fn identity_hugr() -> Hugr { - let b = DFGBuilder::new(Signature::new_endo(BOOL_T)).unwrap(); + let b = DFGBuilder::new(Signature::new_endo(bool_t())).unwrap(); let out = b.input_wires().next().unwrap(); b.finish_prelude_hugr_with_outputs([out]).unwrap() } @@ -110,7 +110,7 @@ mod test { fn test_replace(noop_hugr: Hugr) { let mut h = noop_hugr; let mut replaced = replace_many_ops(&mut h, |op| { - let noop = Noop::new(BOOL_T); + let noop = Noop::new(bool_t()); if op.cast() == Some(noop) { Some(LogicOp::Not) } else { @@ -121,7 +121,7 @@ mod test { assert_eq!(replaced.len(), 1); let (n, op) = replaced.remove(0); - assert_eq!(op, Noop::new(BOOL_T).into()); + assert_eq!(op, Noop::new(bool_t()).into()); assert_eq!(h.get_optype(n), &LogicOp::Not.into()); } @@ -130,7 +130,7 @@ mod test { let mut h = noop_hugr; let lowered = lower_ops(&mut h, |op| { - let noop = Noop::new(BOOL_T); + let noop = Noop::new(bool_t()); if op.cast() == Some(noop) { Some(identity_hugr.clone()) } else { diff --git a/hugr-passes/src/merge_bbs.rs b/hugr-passes/src/merge_bbs.rs index a34ecc351..1acda2ba7 100644 --- a/hugr-passes/src/merge_bbs.rs +++ b/hugr-passes/src/merge_bbs.rs @@ -157,13 +157,14 @@ fn mk_rep( #[cfg(test)] mod test { use std::collections::HashSet; + use std::sync::Arc; use hugr_core::extension::prelude::Lift; use itertools::Itertools; use rstest::rstest; use hugr_core::builder::{endo_sig, inout_sig, CFGBuilder, DFGWrapper, Dataflow, HugrBuilder}; - use hugr_core::extension::prelude::{ConstUsize, PRELUDE_ID, QB_T, USIZE_T}; + use hugr_core::extension::prelude::{qb_t, usize_t, ConstUsize, PRELUDE_ID}; use hugr_core::extension::{ExtensionRegistry, PRELUDE, PRELUDE_REGISTRY}; use hugr_core::hugr::views::sibling::SiblingMut; use hugr_core::ops::constant::Value; @@ -178,28 +179,30 @@ mod test { const EXT_ID: ExtensionId = "TestExt"; } - fn extension() -> Extension { - let mut e = Extension::new(EXT_ID, hugr_core::extension::Version::new(0, 1, 0)); - e.add_op( - "Test".into(), - String::new(), - Signature::new( - type_row![QB_T, USIZE_T], - TypeRow::from(vec![Type::new_sum(vec![ - type_row![QB_T], - type_row![USIZE_T], - ])]), - ), + fn extension() -> Arc { + Extension::new_arc( + EXT_ID, + hugr_core::extension::Version::new(0, 1, 0), + |ext, extension_ref| { + ext.add_op( + "Test".into(), + String::new(), + Signature::new( + vec![qb_t(), usize_t()], + TypeRow::from(vec![Type::new_sum(vec![vec![qb_t()], vec![usize_t()]])]), + ), + extension_ref, + ) + .unwrap(); + }, ) - .unwrap(); - e } fn lifted_unary_unit_sum + AsRef, T>(b: &mut DFGWrapper) -> Wire { let lc = b.add_load_value(Value::unary_unit_sum()); let lift = b .add_dataflow_op( - Lift::new(type_row![Type::new_unit_sum(1)], PRELUDE_ID), + Lift::new(vec![Type::new_unit_sum(1)].into(), PRELUDE_ID), [lc], ) .unwrap(); @@ -224,14 +227,14 @@ mod test { */ use hugr_core::extension::prelude::Noop; - let loop_variants = type_row![QB_T]; - let exit_types = type_row![USIZE_T]; + let loop_variants: TypeRow = vec![qb_t()].into(); + let exit_types: TypeRow = vec![usize_t()].into(); let e = extension(); let tst_op = e.instantiate_extension_op("Test", [], &PRELUDE_REGISTRY)?; - let reg = ExtensionRegistry::try_new([PRELUDE.clone(), e.into()])?; + let reg = ExtensionRegistry::try_new([PRELUDE.clone(), e])?; let mut h = CFGBuilder::new(inout_sig(loop_variants.clone(), exit_types.clone()))?; let mut no_b1 = h.simple_entry_builder_exts(loop_variants.clone(), 1, PRELUDE_ID)?; - let n = no_b1.add_dataflow_op(Noop::new(QB_T), no_b1.input_wires())?; + let n = no_b1.add_dataflow_op(Noop::new(qb_t()), no_b1.input_wires())?; let br = lifted_unary_unit_sum(&mut no_b1); let no_b1 = no_b1.finish_with_outputs(br, n.outputs())?; let mut test_block = h.block_builder( @@ -249,7 +252,7 @@ mod test { no_b1 } else { let mut no_b2 = h.simple_block_builder(endo_sig(loop_variants), 1)?; - let n = no_b2.add_dataflow_op(Noop::new(QB_T), no_b2.input_wires())?; + let n = no_b2.add_dataflow_op(Noop::new(qb_t()), no_b2.input_wires())?; let br = lifted_unary_unit_sum(&mut no_b2); let nid = no_b2.finish_with_outputs(br, n.outputs())?; h.branch(&nid, 0, &no_b1)?; @@ -299,7 +302,7 @@ mod test { // And the Noop in the entry block is consumed by the custom Test op let tst = find_unique( h.nodes(), - |n| matches!(h.get_optype(*n), OpType::ExtensionOp(c) if c.def().extension() != &PRELUDE_ID), + |n| matches!(h.get_optype(*n), OpType::ExtensionOp(c) if c.def().extension_id() != &PRELUDE_ID), ); assert_eq!(h.get_parent(tst), Some(entry)); assert_eq!( @@ -325,24 +328,24 @@ mod test { .into_owned() .try_into() .unwrap(); - let mut h = CFGBuilder::new(inout_sig(QB_T, res_t.clone()))?; - let mut bb1 = h.simple_entry_builder(type_row![USIZE_T, QB_T], 1)?; + let mut h = CFGBuilder::new(inout_sig(qb_t(), res_t.clone()))?; + let mut bb1 = h.simple_entry_builder(vec![usize_t(), qb_t()].into(), 1)?; let [inw] = bb1.input_wires_arr(); let load_cst = bb1.add_load_value(ConstUsize::new(1)); let pred = lifted_unary_unit_sum(&mut bb1); let bb1 = bb1.finish_with_outputs(pred, [load_cst, inw])?; let mut bb2 = h.block_builder( - type_row![USIZE_T, QB_T], + vec![usize_t(), qb_t()].into(), vec![type_row![]], - type_row![QB_T, USIZE_T], + vec![qb_t(), usize_t()].into(), )?; let [u, q] = bb2.input_wires_arr(); let pred = lifted_unary_unit_sum(&mut bb2); let bb2 = bb2.finish_with_outputs(pred, [q, u])?; let mut bb3 = h.block_builder( - type_row![QB_T, USIZE_T], + vec![qb_t(), usize_t()].into(), vec![type_row![]], res_t.clone().into(), )?; @@ -355,7 +358,7 @@ mod test { h.branch(&bb2, 0, &bb3)?; h.branch(&bb3, 0, &h.exit_block())?; - let reg = ExtensionRegistry::try_new([e.into(), PRELUDE.clone()])?; + let reg = ExtensionRegistry::try_new([e, PRELUDE.clone()])?; let mut h = h.finish_hugr(®)?; let root = h.root(); merge_basic_blocks(&mut SiblingMut::try_new(&mut h, root)?); @@ -365,7 +368,7 @@ mod test { let [bb, _exit] = h.children(h.root()).collect::>().try_into().unwrap(); let tst = find_unique( h.nodes(), - |n| matches!(h.get_optype(*n), OpType::ExtensionOp(c) if c.def().extension() != &PRELUDE_ID), + |n| matches!(h.get_optype(*n), OpType::ExtensionOp(c) if c.def().extension_id() != &PRELUDE_ID), ); assert_eq!(h.get_parent(tst), Some(bb)); @@ -375,7 +378,10 @@ mod test { let [other_input] = tst_inputs.try_into().unwrap(); assert_eq!( h.get_optype(other_input), - &(LoadConstant { datatype: USIZE_T }.into()) + &(LoadConstant { + datatype: usize_t() + } + .into()) ); Ok(()) } diff --git a/hugr-passes/src/nest_cfgs.rs b/hugr-passes/src/nest_cfgs.rs index fa7106432..b8091c7af 100644 --- a/hugr-passes/src/nest_cfgs.rs +++ b/hugr-passes/src/nest_cfgs.rs @@ -571,16 +571,14 @@ pub(crate) mod test { endo_sig, BuildError, CFGBuilder, Container, DataflowSubContainer, HugrBuilder, }; use hugr_core::extension::PRELUDE_REGISTRY; - use hugr_core::extension::{prelude::USIZE_T, ExtensionSet}; + use hugr_core::extension::{prelude::usize_t, ExtensionSet}; use hugr_core::hugr::rewrite::insert_identity::{IdentityInsertion, IdentityInsertionError}; use hugr_core::hugr::views::RootChecked; use hugr_core::ops::handle::{ConstID, NodeHandle}; use hugr_core::ops::Value; - use hugr_core::type_row; - use hugr_core::types::{EdgeKind, Signature, Type}; + use hugr_core::types::{EdgeKind, Signature}; use hugr_core::utils::depth; - const NAT: Type = USIZE_T; pub fn group_by(h: HashMap) -> HashSet> { let mut res = HashMap::new(); @@ -601,23 +599,27 @@ pub(crate) mod test { // /-> left --\ // entry -> split > merge -> head -> tail -> exit // \-> right -/ \-<--<-/ - let mut cfg_builder = CFGBuilder::new(Signature::new_endo(NAT))?; + let mut cfg_builder = CFGBuilder::new(Signature::new_endo(usize_t()))?; let pred_const = cfg_builder.add_constant(Value::unit_sum(0, 2).expect("0 < 2")); let const_unit = cfg_builder.add_constant(Value::unary_unit_sum()); let entry = n_identity( - cfg_builder.simple_entry_builder_exts(type_row![NAT], 1, ExtensionSet::new())?, + cfg_builder.simple_entry_builder_exts( + vec![usize_t()].into(), + 1, + ExtensionSet::new(), + )?, &const_unit, )?; let (split, merge) = build_if_then_else_merge(&mut cfg_builder, &pred_const, &const_unit)?; cfg_builder.branch(&entry, 0, &split)?; let head = n_identity( - cfg_builder.simple_block_builder(endo_sig(NAT), 1)?, + cfg_builder.simple_block_builder(endo_sig(usize_t()), 1)?, &const_unit, )?; let tail = n_identity( - cfg_builder.simple_block_builder(endo_sig(NAT), 2)?, + cfg_builder.simple_block_builder(endo_sig(usize_t()), 2)?, &pred_const, )?; cfg_builder.branch(&tail, 1, &head)?; @@ -844,7 +846,10 @@ pub(crate) mod test { const_pred: &ConstID, unit_const: &ConstID, ) -> Result<(BasicBlockID, BasicBlockID), BuildError> { - let split = n_identity(cfg.simple_block_builder(endo_sig(NAT), 2)?, const_pred)?; + let split = n_identity( + cfg.simple_block_builder(endo_sig(usize_t()), 2)?, + const_pred, + )?; let merge = build_then_else_merge_from_if(cfg, unit_const, split)?; Ok((split, merge)) } @@ -854,9 +859,18 @@ pub(crate) mod test { unit_const: &ConstID, split: BasicBlockID, ) -> Result { - let merge = n_identity(cfg.simple_block_builder(endo_sig(NAT), 1)?, unit_const)?; - let left = n_identity(cfg.simple_block_builder(endo_sig(NAT), 1)?, unit_const)?; - let right = n_identity(cfg.simple_block_builder(endo_sig(NAT), 1)?, unit_const)?; + let merge = n_identity( + cfg.simple_block_builder(endo_sig(usize_t()), 1)?, + unit_const, + )?; + let left = n_identity( + cfg.simple_block_builder(endo_sig(usize_t()), 1)?, + unit_const, + )?; + let right = n_identity( + cfg.simple_block_builder(endo_sig(usize_t()), 1)?, + unit_const, + )?; cfg.branch(&split, 0, &left)?; cfg.branch(&split, 1, &right)?; cfg.branch(&left, 0, &merge)?; @@ -869,18 +883,18 @@ pub(crate) mod test { // \-> right -/ \-<--<-/ // Result is Hugr plus merge and tail blocks fn build_cond_then_loop_cfg() -> Result<(Hugr, BasicBlockID, BasicBlockID), BuildError> { - let mut cfg_builder = CFGBuilder::new(Signature::new_endo(NAT))?; + let mut cfg_builder = CFGBuilder::new(Signature::new_endo(usize_t()))?; let pred_const = cfg_builder.add_constant(Value::unit_sum(0, 2).expect("0 < 2")); let const_unit = cfg_builder.add_constant(Value::unary_unit_sum()); let entry = n_identity( - cfg_builder.simple_entry_builder(type_row![NAT], 2)?, + cfg_builder.simple_entry_builder(vec![usize_t()].into(), 2)?, &pred_const, )?; let merge = build_then_else_merge_from_if(&mut cfg_builder, &const_unit, entry)?; // The merge block is also the loop header (so it merges three incoming control-flow edges) let tail = n_identity( - cfg_builder.simple_block_builder(endo_sig(NAT), 2)?, + cfg_builder.simple_block_builder(endo_sig(usize_t()), 2)?, &pred_const, )?; cfg_builder.branch(&tail, 1, &merge)?; @@ -896,7 +910,7 @@ pub(crate) mod test { pub(crate) fn build_conditional_in_loop_cfg( separate_headers: bool, ) -> Result<(Hugr, BasicBlockID, BasicBlockID), BuildError> { - let mut cfg_builder = CFGBuilder::new(Signature::new_endo(NAT))?; + let mut cfg_builder = CFGBuilder::new(Signature::new_endo(usize_t()))?; let (head, tail) = build_conditional_in_loop(&mut cfg_builder, separate_headers)?; let h = cfg_builder.finish_prelude_hugr()?; Ok((h, head, tail)) @@ -910,14 +924,14 @@ pub(crate) mod test { let const_unit = cfg_builder.add_constant(Value::unary_unit_sum()); let entry = n_identity( - cfg_builder.simple_entry_builder(type_row![NAT], 1)?, + cfg_builder.simple_entry_builder(vec![usize_t()].into(), 1)?, &const_unit, )?; let (split, merge) = build_if_then_else_merge(cfg_builder, &pred_const, &const_unit)?; let head = if separate_headers { let head = n_identity( - cfg_builder.simple_block_builder(endo_sig(NAT), 1)?, + cfg_builder.simple_block_builder(endo_sig(usize_t()), 1)?, &const_unit, )?; cfg_builder.branch(&head, 0, &split)?; @@ -927,7 +941,7 @@ pub(crate) mod test { split }; let tail = n_identity( - cfg_builder.simple_block_builder(endo_sig(NAT), 2)?, + cfg_builder.simple_block_builder(endo_sig(usize_t()), 2)?, &pred_const, )?; cfg_builder.branch(&tail, 1, &head)?; diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index 347e5be27..598622089 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -42,28 +42,28 @@ pub fn ensure_no_nonlocal_edges(hugr: &impl HugrView) -> Result<(), NonLocalEdge mod test { use hugr_core::{ builder::{DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer}, - extension::{ - prelude::{Noop, BOOL_T}, - EMPTY_REG, - }, + extension::prelude::{bool_t, Noop}, ops::handle::NodeHandle, type_row, types::Signature, }; + use crate::test::TEST_REG; + use super::*; #[test] fn ensures_no_nonlocal_edges() { let hugr = { - let mut builder = DFGBuilder::new(Signature::new_endo(BOOL_T).with_prelude()).unwrap(); + let mut builder = + DFGBuilder::new(Signature::new_endo(bool_t()).with_prelude()).unwrap(); let [in_w] = builder.input_wires_arr(); let [out_w] = builder - .add_dataflow_op(Noop::new(BOOL_T), [in_w]) + .add_dataflow_op(Noop::new(bool_t()), [in_w]) .unwrap() .outputs_arr(); builder - .finish_hugr_with_outputs([out_w], &EMPTY_REG) + .finish_hugr_with_outputs([out_w], &TEST_REG) .unwrap() }; ensure_no_nonlocal_edges(&hugr).unwrap(); @@ -72,14 +72,15 @@ mod test { #[test] fn find_nonlocal_edges() { let (hugr, edge) = { - let mut builder = DFGBuilder::new(Signature::new_endo(BOOL_T).with_prelude()).unwrap(); + let mut builder = + DFGBuilder::new(Signature::new_endo(bool_t()).with_prelude()).unwrap(); let [in_w] = builder.input_wires_arr(); let ([out_w], edge) = { let mut dfg_builder = builder - .dfg_builder(Signature::new(type_row![], BOOL_T).with_prelude(), []) + .dfg_builder(Signature::new(type_row![], bool_t()).with_prelude(), []) .unwrap(); let noop = dfg_builder - .add_dataflow_op(Noop::new(BOOL_T), [in_w]) + .add_dataflow_op(Noop::new(bool_t()), [in_w]) .unwrap(); let noop_edge = (noop.node(), IncomingPort::from(0)); ( @@ -92,7 +93,7 @@ mod test { }; ( builder - .finish_hugr_with_outputs([out_w], &EMPTY_REG) + .finish_hugr_with_outputs([out_w], &TEST_REG) .unwrap(), edge, ) diff --git a/hugr-py/pyproject.toml b/hugr-py/pyproject.toml index b7ed4714f..aaec4df67 100644 --- a/hugr-py/pyproject.toml +++ b/hugr-py/pyproject.toml @@ -27,7 +27,7 @@ classifiers = [ ] dependencies = [ - "pydantic>=2.8,<2.10", + "pydantic>=2.8,<2.11", "pydantic-extra-types>=2.9.0", "semver>=3.0.2", "graphviz>=0.20.3", diff --git a/hugr/benches/benchmarks/hugr/examples.rs b/hugr/benches/benchmarks/hugr/examples.rs index 3abcee535..9cc64f610 100644 --- a/hugr/benches/benchmarks/hugr/examples.rs +++ b/hugr/benches/benchmarks/hugr/examples.rs @@ -1,21 +1,22 @@ //! Builders and utilities for benchmarks. +use std::sync::Arc; + use hugr::builder::{ BuildError, CFGBuilder, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, ModuleBuilder, }; -use hugr::extension::prelude::{BOOL_T, QB_T, USIZE_T}; +use hugr::extension::prelude::{bool_t, qb_t, usize_t}; use hugr::extension::PRELUDE_REGISTRY; use hugr::ops::OpName; use hugr::std_extensions::arithmetic::float_ops::FLOAT_OPS_REGISTRY; -use hugr::std_extensions::arithmetic::float_types::FLOAT64_TYPE; +use hugr::std_extensions::arithmetic::float_types::float64_type; use hugr::types::Signature; use hugr::{type_row, Extension, Hugr, Node}; use lazy_static::lazy_static; pub fn simple_dfg_hugr() -> Hugr { - let dfg_builder = - DFGBuilder::new(Signature::new(type_row![BOOL_T], type_row![BOOL_T])).unwrap(); + let dfg_builder = DFGBuilder::new(Signature::new(vec![bool_t()], vec![bool_t()])).unwrap(); let [i1] = dfg_builder.input_wires_arr(); dfg_builder.finish_prelude_hugr_with_outputs([i1]).unwrap() } @@ -23,7 +24,7 @@ pub fn simple_dfg_hugr() -> Hugr { pub fn simple_cfg_builder + AsRef>( cfg_builder: &mut CFGBuilder, ) -> Result<(), BuildError> { - let sum2_variants = vec![type_row![USIZE_T], type_row![USIZE_T]]; + let sum2_variants = vec![vec![usize_t()].into(), vec![usize_t()].into()]; let mut entry_b = cfg_builder.entry_builder(sum2_variants.clone(), type_row![])?; let entry = { let [inw] = entry_b.input_wires_arr(); @@ -31,8 +32,8 @@ pub fn simple_cfg_builder + AsRef>( let sum = entry_b.make_sum(1, sum2_variants, [inw])?; entry_b.finish_with_outputs(sum, [])? }; - let mut middle_b = cfg_builder - .simple_block_builder(Signature::new(type_row![USIZE_T], type_row![USIZE_T]), 1)?; + let mut middle_b = + cfg_builder.simple_block_builder(Signature::new(vec![usize_t()], vec![usize_t()]), 1)?; let middle = { let c = middle_b.add_load_const(hugr::ops::Value::unary_unit_sum()); let [inw] = middle_b.input_wires_arr(); @@ -47,41 +48,41 @@ pub fn simple_cfg_builder + AsRef>( pub fn simple_cfg_hugr() -> Hugr { let mut cfg_builder = - CFGBuilder::new(Signature::new(type_row![USIZE_T], type_row![USIZE_T])).unwrap(); + CFGBuilder::new(Signature::new(vec![usize_t()], vec![usize_t()])).unwrap(); simple_cfg_builder(&mut cfg_builder).unwrap(); cfg_builder.finish_prelude_hugr().unwrap() } lazy_static! { - static ref QUANTUM_EXT: Extension = { - let mut extension = Extension::new( + static ref QUANTUM_EXT: Arc = { + Extension::new_arc( "bench.quantum".try_into().unwrap(), hugr::extension::Version::new(0, 0, 0), - ); - - extension - .add_op( - OpName::new_inline("H"), - "".into(), - Signature::new_endo(QB_T), - ) - .unwrap(); - extension - .add_op( - OpName::new_inline("Rz"), - "".into(), - Signature::new(type_row![QB_T, FLOAT64_TYPE], type_row![QB_T]), - ) - .unwrap(); - - extension - .add_op( - OpName::new_inline("CX"), - "".into(), - Signature::new_endo(type_row![QB_T, QB_T]), - ) - .unwrap(); - extension + |ext, extension_ref| { + ext.add_op( + OpName::new_inline("H"), + "".into(), + Signature::new_endo(qb_t()), + extension_ref, + ) + .unwrap(); + ext.add_op( + OpName::new_inline("Rz"), + "".into(), + Signature::new(vec![qb_t(), float64_type()], vec![qb_t()]), + extension_ref, + ) + .unwrap(); + + ext.add_op( + OpName::new_inline("CX"), + "".into(), + Signature::new_endo(vec![qb_t(), qb_t()]), + extension_ref, + ) + .unwrap(); + }, + ) }; } @@ -104,7 +105,7 @@ pub fn circuit(layers: usize) -> (Hugr, Vec) { // .instantiate_extension_op("Rz", [], &FLOAT_OPS_REGISTRY) // .unwrap(); let signature = - Signature::new_endo(type_row![QB_T, QB_T]).with_extension_delta(QUANTUM_EXT.name().clone()); + Signature::new_endo(vec![qb_t(), qb_t()]).with_extension_delta(QUANTUM_EXT.name().clone()); let mut module_builder = ModuleBuilder::new(); let mut f_build = module_builder.define_function("main", signature).unwrap(); diff --git a/hugr/benches/benchmarks/types.rs b/hugr/benches/benchmarks/types.rs index a109a05d8..07cbdb24d 100644 --- a/hugr/benches/benchmarks/types.rs +++ b/hugr/benches/benchmarks/types.rs @@ -1,6 +1,6 @@ // Required for black_box uses #![allow(clippy::unit_arg)] -use hugr::extension::prelude::{QB_T, USIZE_T}; +use hugr::extension::prelude::{qb_t, usize_t}; use hugr::ops::AliasDecl; use hugr::types::{Signature, Type, TypeBound}; @@ -8,8 +8,8 @@ use criterion::{black_box, criterion_group, AxisScale, Criterion, PlotConfigurat /// Construct a complex type. fn make_complex_type() -> Type { - let qb = QB_T; - let int = USIZE_T; + let qb = qb_t(); + let int = usize_t(); let q_register = Type::new_tuple(vec![qb; 8]); let b_register = Type::new_tuple(vec![int; 8]); let q_alias = Type::new_alias(AliasDecl::new("QReg", TypeBound::Any)); diff --git a/hugr/src/lib.rs b/hugr/src/lib.rs index 94dd141a3..3a96696e8 100644 --- a/hugr/src/lib.rs +++ b/hugr/src/lib.rs @@ -28,17 +28,17 @@ //! a simple quantum extension and then use the [[builder::DFGBuilder]] as follows: //! ``` //! use hugr::builder::{BuildError, DFGBuilder, Dataflow, DataflowHugr, inout_sig}; -//! use hugr::extension::prelude::{BOOL_T, QB_T}; +//! use hugr::extension::prelude::{bool_t, qb_t}; //! use hugr::hugr::Hugr; //! use hugr::type_row; //! use hugr::types::FuncValueType; //! -//! // The type of qubits, `QB_T` is in the prelude but, by default, no gateset +//! // The type of qubits, `qb_t()` is in the prelude but, by default, no gateset //! // is defined. This module provides Hadamard and CX gates. //! mod mini_quantum_extension { //! use hugr::{ //! extension::{ -//! prelude::{BOOL_T, QB_T}, +//! prelude::{bool_t, qb_t}, //! ExtensionId, ExtensionRegistry, PRELUDE, Version, //! }, //! ops::{ExtensionOp, OpName}, @@ -51,41 +51,37 @@ //! use lazy_static::lazy_static; //! //! fn one_qb_func() -> PolyFuncTypeRV { -//! FuncValueType::new_endo(type_row![QB_T]).into() +//! FuncValueType::new_endo(vec![qb_t()]).into() //! } //! //! fn two_qb_func() -> PolyFuncTypeRV { -//! FuncValueType::new_endo(type_row![QB_T, QB_T]).into() +//! FuncValueType::new_endo(vec![qb_t(), qb_t()]).into() //! } //! /// The extension identifier. //! pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("mini.quantum"); //! pub const VERSION: Version = Version::new(0, 1, 0); //! fn extension() -> Arc { -//! let mut extension = Extension::new(EXTENSION_ID, VERSION); +//! Extension::new_arc(EXTENSION_ID, VERSION, |ext, extension_ref| { +//! ext.add_op(OpName::new_inline("H"), "Hadamard".into(), one_qb_func(), extension_ref) +//! .unwrap(); //! -//! extension -//! .add_op(OpName::new_inline("H"), "Hadamard".into(), one_qb_func()) -//! .unwrap(); -//! -//! extension -//! .add_op(OpName::new_inline("CX"), "CX".into(), two_qb_func()) -//! .unwrap(); +//! ext.add_op(OpName::new_inline("CX"), "CX".into(), two_qb_func(), extension_ref) +//! .unwrap(); //! -//! extension -//! .add_op( +//! ext.add_op( //! OpName::new_inline("Measure"), //! "Measure a qubit, returning the qubit and the measurement result.".into(), -//! FuncValueType::new(type_row![QB_T], type_row![QB_T, BOOL_T]), +//! FuncValueType::new(vec![qb_t()], vec![qb_t(), bool_t()]), +//! extension_ref, //! ) //! .unwrap(); -//! -//! Arc::new(extension) +//! }) //! } //! //! lazy_static! { //! /// Quantum extension definition. //! pub static ref EXTENSION: Arc = extension(); -//! static ref REG: ExtensionRegistry = +//! pub static ref REG: ExtensionRegistry = //! ExtensionRegistry::try_new([EXTENSION.clone(), PRELUDE.clone()]).unwrap(); //! } //! fn get_gate(gate_name: impl Into) -> ExtensionOp { @@ -107,7 +103,7 @@ //! } //! } //! -//! use mini_quantum_extension::{cx_gate, h_gate, measure}; +//! use mini_quantum_extension::{cx_gate, h_gate, measure, REG}; //! //! // ┌───┐ //! // q_0: ┤ H ├──■───── @@ -117,15 +113,15 @@ //! // c: ╚═ //! fn make_dfg_hugr() -> Result { //! let mut dfg_builder = DFGBuilder::new(inout_sig( -//! type_row![QB_T, QB_T], -//! type_row![QB_T, QB_T, BOOL_T], +//! vec![qb_t(), qb_t()], +//! vec![qb_t(), qb_t(), bool_t()], //! ))?; //! let [wire0, wire1] = dfg_builder.input_wires_arr(); //! let h0 = dfg_builder.add_dataflow_op(h_gate(), vec![wire0])?; //! let h1 = dfg_builder.add_dataflow_op(h_gate(), vec![wire1])?; //! let cx = dfg_builder.add_dataflow_op(cx_gate(), h0.outputs().chain(h1.outputs()))?; //! let measure = dfg_builder.add_dataflow_op(measure(), cx.outputs().last())?; -//! dfg_builder.finish_prelude_hugr_with_outputs(cx.outputs().take(1).chain(measure.outputs())) +//! dfg_builder.finish_hugr_with_outputs(cx.outputs().take(1).chain(measure.outputs()), ®) //! } //! //! let h: Hugr = make_dfg_hugr().unwrap(); diff --git a/justfile b/justfile index 78267180a..ed8c1b73c 100644 --- a/justfile +++ b/justfile @@ -32,7 +32,7 @@ fix language="[rust|python]": (_run_lang language \ # Format the code. format language="[rust|python]": (_run_lang language \ - "cargo fmt" \ + "cargo fmt --all" \ "uv run ruff format" ) diff --git a/uv.lock b/uv.lock index 9dc359a56..7f8bd6af4 100644 --- a/uv.lock +++ b/uv.lock @@ -280,7 +280,7 @@ docs = [ [package.metadata] requires-dist = [ { name = "graphviz", specifier = ">=0.20.3" }, - { name = "pydantic", specifier = ">=2.8,<2.10" }, + { name = "pydantic", specifier = ">=2.8,<2.11" }, { name = "pydantic-extra-types", specifier = ">=2.9.0" }, { name = "semver", specifier = ">=3.0.2" }, { name = "sphinx", marker = "extra == 'docs'", specifier = ">=8.0.2,<9.0.0" },