diff --git a/hugr-passes/src/const_fold/test.rs b/hugr-passes/src/const_fold/test.rs index 7958a1194..35888f1b5 100644 --- a/hugr-passes/src/const_fold/test.rs +++ b/hugr-passes/src/const_fold/test.rs @@ -4,7 +4,6 @@ use itertools::Itertools; use lazy_static::lazy_static; use rstest::rstest; -use crate::test::TEST_REG; use hugr_core::builder::{ endo_sig, inout_sig, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, SubContainer, @@ -13,12 +12,11 @@ use hugr_core::extension::prelude::{ bool_t, const_ok, error_type, string_type, sum_with_error, ConstError, ConstString, MakeTuple, UnpackTuple, }; -use hugr_core::extension::ExtensionRegistry; + use hugr_core::hugr::hugrmut::HugrMut; use hugr_core::hugr::views::{DescendantsGraph, HierarchyView}; use hugr_core::ops::{constant::CustomConst, handle::BasicBlockID, OpTag, OpTrait, OpType, Value}; use hugr_core::std_extensions::arithmetic::{ - self, conversions::ConvertOpDef, float_ops::FloatOps, float_types::{float64_type, ConstF64}, @@ -29,8 +27,10 @@ use hugr_core::std_extensions::logic::LogicOp; use hugr_core::types::{Signature, SumType, Type, TypeRow, TypeRowRV}; use hugr_core::{type_row, Hugr, HugrView, IncomingPort, Node}; -use super::{constant_fold_pass, ConstFoldContext, ConstFoldPass, ValueHandle}; use crate::dataflow::{partial_from_const, DFContext, PartialValue}; +use crate::test::TEST_REG; + +use super::{constant_fold_pass, ConstFoldContext, ConstFoldPass, ValueHandle}; #[rstest] #[case(ConstInt::new_u(4, 2).unwrap(), true)] @@ -1423,12 +1423,11 @@ fn test_via_part_unknown_tuple() { let res = builder .add_dataflow_op(IntOpDef::iadd.with_log_width(3), [a, c]) .unwrap(); - let reg = ExtensionRegistry::try_new([arithmetic::int_types::EXTENSION.to_owned()]).unwrap(); let mut hugr = builder - .finish_hugr_with_outputs(res.outputs(), ®) + .finish_hugr_with_outputs(res.outputs(), &TEST_REG) .unwrap(); - constant_fold_pass(&mut hugr, ®); + constant_fold_pass(&mut hugr, &TEST_REG); // We expect: root dfg, input, output, const 9, load constant, iadd let mut expected_op_tags: HashSet<_, std::hash::RandomState> = [ @@ -1452,8 +1451,7 @@ fn test_via_part_unknown_tuple() { assert!(expected_op_tags.is_empty()); } -fn tail_loop_hugr(int_cst: ConstInt) -> (Hugr, ExtensionRegistry) { - let reg = ExtensionRegistry::try_new([arithmetic::int_types::EXTENSION.to_owned()]).unwrap(); +fn tail_loop_hugr(int_cst: ConstInt) -> Hugr { let int_ty = int_cst.get_type(); let lw = int_cst.log_width(); let mut builder = DFGBuilder::new(inout_sig(bool_t(), int_ty.clone())).unwrap(); @@ -1471,17 +1469,17 @@ fn tail_loop_hugr(int_cst: ConstInt) -> (Hugr, ExtensionRegistry) { .unwrap(); let hugr = builder - .finish_hugr_with_outputs(add.outputs(), ®) + .finish_hugr_with_outputs(add.outputs(), &TEST_REG) .unwrap(); - (hugr, reg) + hugr } #[test] fn test_tail_loop_unknown() { let cst5 = ConstInt::new_u(3, 5).unwrap(); - let (mut h, reg) = tail_loop_hugr(cst5.clone()); + let mut h = tail_loop_hugr(cst5.clone()); - constant_fold_pass(&mut h, ®); + constant_fold_pass(&mut h, &TEST_REG); // Must keep the loop, even though we know the output, in case the output doesn't happen assert_eq!(h.node_count(), 12); let tl = h @@ -1509,7 +1507,7 @@ fn test_tail_loop_unknown() { .map(tag_string) .sorted() .collect::>(), - Vec::from([ + vec![ "Const", "Const", "Input", @@ -1517,7 +1515,7 @@ fn test_tail_loop_unknown() { "LoadConst", "Output", "TailLoop" - ]) + ] ); assert_eq!( @@ -1563,26 +1561,25 @@ fn test_tail_loop_unknown() { #[test] fn test_tail_loop_never_iterates() { - let (mut h, reg) = tail_loop_hugr(ConstInt::new_u(4, 6).unwrap()); + let mut h = tail_loop_hugr(ConstInt::new_u(4, 6).unwrap()); ConstFoldPass::default() .with_inputs([(0, Value::true_val())]) // true = 1 = break - .run(&mut h, ®) + .run(&mut h, &TEST_REG) .unwrap(); assert_fully_folded(&h, &ConstInt::new_u(4, 12).unwrap().into()); } #[test] fn test_tail_loop_increase_termination() { - let (mut h, reg) = tail_loop_hugr(ConstInt::new_u(4, 6).unwrap()); + let mut h = tail_loop_hugr(ConstInt::new_u(4, 6).unwrap()); ConstFoldPass::default() .allow_increase_termination() - .run(&mut h, ®) + .run(&mut h, &TEST_REG) .unwrap(); assert_fully_folded(&h, &ConstInt::new_u(4, 12).unwrap().into()); } -fn cfg_hugr() -> (Hugr, ExtensionRegistry) { - let reg = ExtensionRegistry::try_new([arithmetic::int_types::EXTENSION.to_owned()]).unwrap(); +fn cfg_hugr() -> Hugr { let int_ty = INT_TYPES[4].clone(); let mut builder = DFGBuilder::new(inout_sig(vec![bool_t(); 2], int_ty.clone())).unwrap(); let [p, q] = builder.input_wires_arr(); @@ -1621,9 +1618,9 @@ fn cfg_hugr() -> (Hugr, ExtensionRegistry) { let cfg = cfg.finish_sub_container().unwrap(); let nested = nested.finish_with_outputs(cfg.outputs()).unwrap(); let hugr = builder - .finish_hugr_with_outputs(nested.outputs(), ®) + .finish_hugr_with_outputs(nested.outputs(), &TEST_REG) .unwrap(); - (hugr, reg) + hugr } #[rstest] @@ -1637,11 +1634,11 @@ fn test_cfg( #[case] fold_blk: bool, #[case] fold_res: Option, ) { - let (backup, reg) = cfg_hugr(); + let backup = cfg_hugr(); let mut hugr = backup.clone(); let pass = ConstFoldPass::default() .with_inputs(inputs.into_iter().map(|(p, b)| (*p, Value::from_bool(*b)))); - pass.run(&mut hugr, ®).unwrap(); + pass.run(&mut hugr, &TEST_REG).unwrap(); // CFG inside DFG retained let nested = hugr .children(hugr.root()) @@ -1701,7 +1698,7 @@ fn test_cfg( let mut hugr2 = backup; pass.allow_increase_termination() - .run(&mut hugr2, ®) + .run(&mut hugr2, &TEST_REG) .unwrap(); assert_fully_folded(&hugr2, &res_v); } else {