Skip to content

Commit

Permalink
Rewrite with rstest and struct fixture to common-up node extraction
Browse files Browse the repository at this point in the history
  • Loading branch information
acl-cqc committed May 11, 2024
1 parent a3b4211 commit 871d684
Showing 1 changed file with 121 additions and 80 deletions.
201 changes: 121 additions & 80 deletions hugr/src/hugr/rewrite/outline_cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,74 +263,107 @@ mod test {
use crate::{type_row, Hugr, HugrView, Node};
use cool_asserts::assert_matches;
use itertools::Itertools;
use rstest::rstest;

use super::{OutlineCfg, OutlineCfgError};

// /-> left --\
// entry > merge -> head -> tail -> exit
// \-> right -/ \-<--<-/
// Result is Hugr plus merge and tail blocks
fn build_cond_then_loop_cfg() -> Result<(Hugr, BasicBlockID, BasicBlockID), BuildError> {
let block_ty = FunctionType::new_endo(USIZE_T);
let mut cfg_builder = CFGBuilder::new(block_ty.clone())?;
let pred_const = cfg_builder.add_constant(Value::unit_sum(0, 2).expect("0 < 2"));
let const_unit = cfg_builder.add_constant(Value::unary_unit_sum());
fn n_identity(
mut bbldr: BlockBuilder<&mut Hugr>,
cst: &ConstID,
) -> Result<BasicBlockID, BuildError> {
let pred = bbldr.load_const(cst);
let vals = bbldr.input_wires();
bbldr.finish_with_outputs(pred, vals)
}
let id_block = |c: &mut CFGBuilder<_>| {
n_identity(c.simple_block_builder(block_ty.clone(), 1)?, &const_unit)
};
/// /-> left --\
/// entry > merge -> head -> tail -> exit
/// \-> right -/ \-<--<-/
struct CondThenLoopCfg {
h: Hugr,
left: Node,
right: Node,
merge: Node,
head: Node,
tail: Node,
}
impl CondThenLoopCfg {
fn new() -> Result<CondThenLoopCfg, BuildError> {
let block_ty = FunctionType::new_endo(USIZE_T);
let mut cfg_builder = CFGBuilder::new(block_ty.clone())?;
let pred_const = cfg_builder.add_constant(Value::unit_sum(0, 2).expect("0 < 2"));
let const_unit = cfg_builder.add_constant(Value::unary_unit_sum());
fn n_identity(
mut bbldr: BlockBuilder<&mut Hugr>,
cst: &ConstID,
) -> Result<BasicBlockID, BuildError> {
let pred = bbldr.load_const(cst);
let vals = bbldr.input_wires();
bbldr.finish_with_outputs(pred, vals)
}
let id_block = |c: &mut CFGBuilder<_>| {
n_identity(c.simple_block_builder(block_ty.clone(), 1)?, &const_unit)
};

let entry = n_identity(
cfg_builder.simple_entry_builder(USIZE_T.into(), 2, ExtensionSet::new())?,
&pred_const,
)?;

let entry = n_identity(
cfg_builder.simple_entry_builder(USIZE_T.into(), 2, ExtensionSet::new())?,
&pred_const,
)?;
let merge = {
let merge = id_block(&mut cfg_builder)?;
let left = id_block(&mut cfg_builder)?;
let right = id_block(&mut cfg_builder)?;
cfg_builder.branch(&entry, 0, &left)?;
cfg_builder.branch(&entry, 1, &right)?;

let merge = id_block(&mut cfg_builder)?;
cfg_builder.branch(&left, 0, &merge)?;
cfg_builder.branch(&right, 0, &merge)?;
merge
};
let head = id_block(&mut cfg_builder)?;
cfg_builder.branch(&merge, 0, &head)?;
let tail = n_identity(
cfg_builder.simple_block_builder(FunctionType::new_endo(USIZE_T), 2)?,
&pred_const,
)?;
cfg_builder.branch(&tail, 1, &head)?;
cfg_builder.branch(&head, 0, &tail)?; // trivial "loop body"
let exit = cfg_builder.exit_block();
cfg_builder.branch(&tail, 0, &exit)?;

let h = cfg_builder.finish_prelude_hugr()?;
Ok((h, merge, tail))

let head = id_block(&mut cfg_builder)?;
cfg_builder.branch(&merge, 0, &head)?;
let tail = n_identity(
cfg_builder.simple_block_builder(FunctionType::new_endo(USIZE_T), 2)?,
&pred_const,
)?;
cfg_builder.branch(&tail, 1, &head)?;
cfg_builder.branch(&head, 0, &tail)?; // trivial "loop body"
let exit = cfg_builder.exit_block();
cfg_builder.branch(&tail, 0, &exit)?;

let h = cfg_builder.finish_prelude_hugr()?;
let (left, right) = (left.node(), right.node());
let (merge, head, tail) = (merge.node(), head.node(), tail.node());
Ok(Self {
h,
left,
right,
merge,
head,
tail,
})
}
fn entry_exit(&self) -> (Node, Node) {
self.h
.children(self.h.root())
.take(2)
.collect_tuple()
.unwrap()
}
}

#[test]
fn test_outline_cfg_errors() {
let (mut h, merge, tail) = build_cond_then_loop_cfg().unwrap();
let (merge, tail) = (merge.node(), tail.node());
let head = h.input_neighbours(tail).exactly_one().unwrap();
assert_eq!(h.output_neighbours(merge).collect_vec(), vec![head]);
let entry = h.children(h.root()).next().unwrap();
#[rstest::fixture]
fn cond_then_loop_cfg() -> CondThenLoopCfg {
CondThenLoopCfg::new().unwrap()
}

#[rstest]
fn test_outline_cfg_errors(cond_then_loop_cfg: CondThenLoopCfg) {
let (entry, _) = cond_then_loop_cfg.entry_exit();
let CondThenLoopCfg {
mut h,
left,
right,
merge,
head,
tail,
} = cond_then_loop_cfg;
let backup = h.clone();

let r = h.apply_rewrite(OutlineCfg::new([tail]));
assert_matches!(r, Err(OutlineCfgError::MultipleExitEdges(_, _)));
assert_eq!(h, backup);

let [left, right]: [Node; 2] = h.output_neighbours(entry).collect_vec().try_into().unwrap();
let r = h.apply_rewrite(OutlineCfg::new([entry, left, right]));
assert_matches!(r, Err(OutlineCfgError::MultipleExitNodes(a,b))
=> assert_eq!(HashSet::from([a,b]), HashSet::from_iter([left, right])));
Expand All @@ -348,19 +381,20 @@ mod test {
assert_eq!(h, backup);
}

#[test]
fn test_outline_cfg() {
#[rstest::rstest]
fn test_outline_cfg(cond_then_loop_cfg: CondThenLoopCfg) {
// Outline the loop, producing:
// /-> left -->\
// entry merge -> newblock -> exit
// \-> right ->/

let (mut h, merge, tail) = build_cond_then_loop_cfg().unwrap();
let (merge, tail) = (merge.node(), tail.node());
let exit = h.children(h.root()).nth(1).unwrap();
let head = h.input_neighbours(tail).exactly_one().unwrap();
assert_eq!(h.output_neighbours(merge.node()).collect_vec(), vec![head]);

let (_, exit) = cond_then_loop_cfg.entry_exit();
let CondThenLoopCfg {
mut h,
merge,
head,
tail,
..
} = cond_then_loop_cfg;
let root = h.root();
let (new_block, _, exit_block) = outline_cfg_check_parents(&mut h, root, vec![head, tail]);
assert_eq!(h.output_neighbours(merge).collect_vec(), vec![new_block]);
Expand All @@ -371,36 +405,38 @@ mod test {
);
}

#[test]
fn test_outline_cfg_multiple_in_edges() {
#[rstest]
fn test_outline_cfg_multiple_in_edges(cond_then_loop_cfg: CondThenLoopCfg) {
// Outline merge, head and tail, producing
// /-> left -->\
// entry newblock -> exit
// \-> right ->/
let (mut h, merge, tail) = build_cond_then_loop_cfg().unwrap();
let (merge, tail) = (merge.node(), tail.node());
let exit = h.children(h.root()).nth(1).unwrap();
let left_and_right = h.input_neighbours(merge).collect::<HashSet<_>>();
assert_eq!(left_and_right.len(), 2);
let head = h.input_neighbours(tail).exactly_one().unwrap();
assert_eq!(h.output_neighbours(merge.node()).collect_vec(), vec![head]);
let (_, exit) = cond_then_loop_cfg.entry_exit();
let CondThenLoopCfg {
mut h,
left,
right,
merge,
head,
tail,
} = cond_then_loop_cfg;

let root = h.root();
let (new_block, _, inner_exit) =
outline_cfg_check_parents(&mut h, root, vec![merge, head, tail]);
assert_eq!(h.input_neighbours(exit).collect_vec(), vec![new_block]);
assert_eq!(
h.input_neighbours(new_block).collect::<HashSet<_>>(),
left_and_right
HashSet::from([left, right])
);
assert_eq!(
h.output_neighbours(tail).collect::<HashSet<Node>>(),
HashSet::from([head, inner_exit])
);
}

#[test]
fn test_outline_cfg_subregion() {
#[rstest]
fn test_outline_cfg_subregion(cond_then_loop_cfg: CondThenLoopCfg) {
// Outline the loop, as above, but with the CFG inside a Function + Module,
// operating via a SiblingMut
let mut module_builder = ModuleBuilder::new();
Expand All @@ -411,10 +447,12 @@ mod test {
)
.unwrap();
let [i1] = fbuild.input_wires_arr();
let (h, _, _) = build_cond_then_loop_cfg().unwrap();
let cfg = fbuild.add_hugr_with_wires(h, [i1]).unwrap();
let cfg = fbuild
.add_hugr_with_wires(cond_then_loop_cfg.h, [i1])
.unwrap();
fbuild.finish_with_outputs(cfg.outputs()).unwrap();
let mut h = module_builder.finish_prelude_hugr().unwrap();
// `add_hugr_with_wires` does not return an InsertionResult, so recover the nodes manually:
let cfg = cfg.node();
let exit_node = h.children(cfg).nth(1).unwrap();
let tail = h.input_neighbours(exit_node).exactly_one().unwrap();
Expand All @@ -433,19 +471,22 @@ mod test {
h.update_validate(&PRELUDE_REGISTRY).unwrap();
}

#[test]
fn test_outline_cfg_move_entry() {
#[rstest]
fn test_outline_cfg_move_entry(cond_then_loop_cfg: CondThenLoopCfg) {
// Outline the conditional, producing
//
// newblock -> head -> tail -> exit
// \<--</
// (where the new block becomes the entry block)

let (mut h, merge, tail) = build_cond_then_loop_cfg().unwrap();
let (entry, _) = h.children(h.root()).take(2).collect_tuple().unwrap();
let (left, right) = h.output_neighbours(entry).take(2).collect_tuple().unwrap();
let (merge, _) = (merge.node(), tail.node());
let head = h.output_neighbours(merge).exactly_one().unwrap();
let (entry, _) = cond_then_loop_cfg.entry_exit();
let CondThenLoopCfg {
mut h,
left,
right,
merge,
head,
..
} = cond_then_loop_cfg;

let root = h.root();
let (new_block, _, _) =
Expand Down

0 comments on commit 871d684

Please sign in to comment.