Skip to content

Commit

Permalink
Fix post-merge extension issues by using TEST_REG everywhere
Browse files Browse the repository at this point in the history
  • Loading branch information
acl-cqc committed Dec 4, 2024
1 parent 26466db commit 79e98f2
Showing 1 changed file with 23 additions and 26 deletions.
49 changes: 23 additions & 26 deletions hugr-passes/src/const_fold/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use itertools::Itertools;
use lazy_static::lazy_static;
use rstest::rstest;

use crate::test::TEST_REG;
use hugr_core::builder::{
endo_sig, inout_sig, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer,
SubContainer,
Expand All @@ -13,12 +12,11 @@ use hugr_core::extension::prelude::{
bool_t, const_ok, error_type, string_type, sum_with_error, ConstError, ConstString, MakeTuple,
UnpackTuple,
};
use hugr_core::extension::ExtensionRegistry;

use hugr_core::hugr::hugrmut::HugrMut;
use hugr_core::hugr::views::{DescendantsGraph, HierarchyView};
use hugr_core::ops::{constant::CustomConst, handle::BasicBlockID, OpTag, OpTrait, OpType, Value};
use hugr_core::std_extensions::arithmetic::{
self,
conversions::ConvertOpDef,
float_ops::FloatOps,
float_types::{float64_type, ConstF64},
Expand All @@ -29,8 +27,10 @@ use hugr_core::std_extensions::logic::LogicOp;
use hugr_core::types::{Signature, SumType, Type, TypeRow, TypeRowRV};
use hugr_core::{type_row, Hugr, HugrView, IncomingPort, Node};

use super::{constant_fold_pass, ConstFoldContext, ConstFoldPass, ValueHandle};
use crate::dataflow::{partial_from_const, DFContext, PartialValue};
use crate::test::TEST_REG;

use super::{constant_fold_pass, ConstFoldContext, ConstFoldPass, ValueHandle};

#[rstest]
#[case(ConstInt::new_u(4, 2).unwrap(), true)]
Expand Down Expand Up @@ -1423,12 +1423,11 @@ fn test_via_part_unknown_tuple() {
let res = builder
.add_dataflow_op(IntOpDef::iadd.with_log_width(3), [a, c])
.unwrap();
let reg = ExtensionRegistry::try_new([arithmetic::int_types::EXTENSION.to_owned()]).unwrap();
let mut hugr = builder
.finish_hugr_with_outputs(res.outputs(), &reg)
.finish_hugr_with_outputs(res.outputs(), &TEST_REG)
.unwrap();

constant_fold_pass(&mut hugr, &reg);
constant_fold_pass(&mut hugr, &TEST_REG);

// We expect: root dfg, input, output, const 9, load constant, iadd
let mut expected_op_tags: HashSet<_, std::hash::RandomState> = [
Expand All @@ -1452,8 +1451,7 @@ fn test_via_part_unknown_tuple() {
assert!(expected_op_tags.is_empty());
}

fn tail_loop_hugr(int_cst: ConstInt) -> (Hugr, ExtensionRegistry) {
let reg = ExtensionRegistry::try_new([arithmetic::int_types::EXTENSION.to_owned()]).unwrap();
fn tail_loop_hugr(int_cst: ConstInt) -> Hugr {
let int_ty = int_cst.get_type();
let lw = int_cst.log_width();
let mut builder = DFGBuilder::new(inout_sig(bool_t(), int_ty.clone())).unwrap();
Expand All @@ -1471,17 +1469,17 @@ fn tail_loop_hugr(int_cst: ConstInt) -> (Hugr, ExtensionRegistry) {
.unwrap();

let hugr = builder
.finish_hugr_with_outputs(add.outputs(), &reg)
.finish_hugr_with_outputs(add.outputs(), &TEST_REG)
.unwrap();
(hugr, reg)
hugr
}

#[test]
fn test_tail_loop_unknown() {
let cst5 = ConstInt::new_u(3, 5).unwrap();
let (mut h, reg) = tail_loop_hugr(cst5.clone());
let mut h = tail_loop_hugr(cst5.clone());

constant_fold_pass(&mut h, &reg);
constant_fold_pass(&mut h, &TEST_REG);
// Must keep the loop, even though we know the output, in case the output doesn't happen
assert_eq!(h.node_count(), 12);
let tl = h
Expand Down Expand Up @@ -1509,15 +1507,15 @@ fn test_tail_loop_unknown() {
.map(tag_string)
.sorted()
.collect::<Vec<_>>(),
Vec::from([
vec![
"Const",
"Const",
"Input",
"LoadConst",
"LoadConst",
"Output",
"TailLoop"
])
]
);

assert_eq!(
Expand Down Expand Up @@ -1563,26 +1561,25 @@ fn test_tail_loop_unknown() {

#[test]
fn test_tail_loop_never_iterates() {
let (mut h, reg) = tail_loop_hugr(ConstInt::new_u(4, 6).unwrap());
let mut h = tail_loop_hugr(ConstInt::new_u(4, 6).unwrap());
ConstFoldPass::default()
.with_inputs([(0, Value::true_val())]) // true = 1 = break
.run(&mut h, &reg)
.run(&mut h, &TEST_REG)
.unwrap();
assert_fully_folded(&h, &ConstInt::new_u(4, 12).unwrap().into());
}

#[test]
fn test_tail_loop_increase_termination() {
let (mut h, reg) = tail_loop_hugr(ConstInt::new_u(4, 6).unwrap());
let mut h = tail_loop_hugr(ConstInt::new_u(4, 6).unwrap());
ConstFoldPass::default()
.allow_increase_termination()
.run(&mut h, &reg)
.run(&mut h, &TEST_REG)
.unwrap();
assert_fully_folded(&h, &ConstInt::new_u(4, 12).unwrap().into());
}

fn cfg_hugr() -> (Hugr, ExtensionRegistry) {
let reg = ExtensionRegistry::try_new([arithmetic::int_types::EXTENSION.to_owned()]).unwrap();
fn cfg_hugr() -> Hugr {
let int_ty = INT_TYPES[4].clone();
let mut builder = DFGBuilder::new(inout_sig(vec![bool_t(); 2], int_ty.clone())).unwrap();
let [p, q] = builder.input_wires_arr();
Expand Down Expand Up @@ -1621,9 +1618,9 @@ fn cfg_hugr() -> (Hugr, ExtensionRegistry) {
let cfg = cfg.finish_sub_container().unwrap();
let nested = nested.finish_with_outputs(cfg.outputs()).unwrap();
let hugr = builder
.finish_hugr_with_outputs(nested.outputs(), &reg)
.finish_hugr_with_outputs(nested.outputs(), &TEST_REG)
.unwrap();
(hugr, reg)
hugr
}

#[rstest]
Expand All @@ -1637,11 +1634,11 @@ fn test_cfg(
#[case] fold_blk: bool,
#[case] fold_res: Option<u16>,
) {
let (backup, reg) = cfg_hugr();
let backup = cfg_hugr();
let mut hugr = backup.clone();
let pass = ConstFoldPass::default()
.with_inputs(inputs.into_iter().map(|(p, b)| (*p, Value::from_bool(*b))));
pass.run(&mut hugr, &reg).unwrap();
pass.run(&mut hugr, &TEST_REG).unwrap();
// CFG inside DFG retained
let nested = hugr
.children(hugr.root())
Expand Down Expand Up @@ -1701,7 +1698,7 @@ fn test_cfg(

let mut hugr2 = backup;
pass.allow_increase_termination()
.run(&mut hugr2, &reg)
.run(&mut hugr2, &TEST_REG)
.unwrap();
assert_fully_folded(&hugr2, &res_v);
} else {
Expand Down

0 comments on commit 79e98f2

Please sign in to comment.