Skip to content

Commit

Permalink
Add SiblingMut allowing modification of a SiblingGraph (#522)
Browse files Browse the repository at this point in the history
* Combine `HugrView::get_io` methods - could pull that out into a
preliminary PR, it's just tidying a wrapper function
* Also `HugrMut::add_op_with_parent` is just a wrapper around
`add_node_with_parent` so make that a default-impl too
* SibingMut implementation parallels SiblingView for many methods, but
not all, as does not its own PortGraph
* Reimplement OutlineCfg to use nested SiblingMut instances (escaping to
&mut Hugr for `set_parent`)
* Refactor tests so we test calling OutlineCfg on a SiblingMut of only a
portion of a Hugr
  • Loading branch information
acl-cqc authored Sep 18, 2023
1 parent 1ab326f commit 0ea19dd
Show file tree
Hide file tree
Showing 6 changed files with 326 additions and 121 deletions.
21 changes: 13 additions & 8 deletions src/algorithm/nest_cfgs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -708,13 +708,19 @@ pub(crate) mod test {
}

// Build a CFG, returning the Hugr
pub fn build_conditional_in_loop_cfg(
pub(crate) fn build_conditional_in_loop_cfg(
separate_headers: bool,
) -> Result<(Hugr, BasicBlockID, BasicBlockID), BuildError> {
//let sum2_type = Type::new_predicate(2);

let mut cfg_builder = CFGBuilder::new(FunctionType::new(type_row![NAT], type_row![NAT]))?;
let (head, tail) = build_conditional_in_loop(&mut cfg_builder, separate_headers)?;
let h = cfg_builder.finish_prelude_hugr()?;
Ok((h, head, tail))
}

pub(crate) fn build_conditional_in_loop<T: AsMut<Hugr> + AsRef<Hugr>>(
cfg_builder: &mut CFGBuilder<T>,
separate_headers: bool,
) -> Result<(BasicBlockID, BasicBlockID), BuildError> {
let pred_const =
cfg_builder.add_constant(Const::simple_predicate(0, 2), ExtensionSet::new())?; // Nothing here cares which
let const_unit =
Expand All @@ -724,15 +730,15 @@ pub(crate) mod test {
cfg_builder.simple_entry_builder(type_row![NAT], 1, ExtensionSet::new())?,
&const_unit,
)?;
let (split, merge) = build_if_then_else_merge(&mut cfg_builder, &pred_const, &const_unit)?;
let (split, merge) = build_if_then_else_merge(cfg_builder, &pred_const, &const_unit)?;

let (head, tail) = if separate_headers {
let (head, tail) = build_loop(&mut cfg_builder, &pred_const, &const_unit)?;
let (head, tail) = build_loop(cfg_builder, &pred_const, &const_unit)?;
cfg_builder.branch(&head, 0, &split)?;
(head, tail)
} else {
// Combine loop header with split.
let tail = build_loop_from_header(&mut cfg_builder, &pred_const, split)?;
let tail = build_loop_from_header(cfg_builder, &pred_const, split)?;
(split, tail)
};
cfg_builder.branch(&merge, 0, &tail)?;
Expand All @@ -742,7 +748,6 @@ pub(crate) mod test {
cfg_builder.branch(&entry, 0, &head)?;
cfg_builder.branch(&tail, 0, &exit)?;

let h = cfg_builder.finish_prelude_hugr()?;
Ok((h, head, tail))
Ok((head, tail))
}
}
13 changes: 2 additions & 11 deletions src/hugr/hugrmut.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ pub trait HugrMut: HugrView + HugrMutInternals {
parent: Node,
op: impl Into<OpType>,
) -> Result<Node, HugrError> {
self.valid_node(parent)?;
self.hugr_mut().add_op_with_parent(parent, op)
// TODO: Default to `NodeType::open_extensions` once we can infer extensions
self.add_node_with_parent(parent, NodeType::pure(op))
}

/// Add a node to the graph with a parent in the hierarchy.
Expand Down Expand Up @@ -192,15 +192,6 @@ impl<T> HugrMut for T
where
T: HugrView + AsMut<Hugr>,
{
fn add_op_with_parent(
&mut self,
parent: Node,
op: impl Into<OpType>,
) -> Result<Node, HugrError> {
// TODO: Default to `NodeType::open_extensions` once we can infer extensions
self.add_node_with_parent(parent, NodeType::pure(op))
}

fn add_node_with_parent(&mut self, parent: Node, node: NodeType) -> Result<Node, HugrError> {
let node = self.as_mut().add_node(node);
self.as_mut()
Expand Down
121 changes: 93 additions & 28 deletions src/hugr/rewrite/outline_cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ use thiserror::Error;

use crate::builder::{BlockBuilder, Container, Dataflow, SubContainer};
use crate::extension::ExtensionSet;
use crate::hugr::hugrmut::sealed::HugrMutInternals;
use crate::hugr::rewrite::Rewrite;
use crate::hugr::views::sibling::SiblingMut;
use crate::hugr::{HugrMut, HugrView};
use crate::ops;
use crate::ops::handle::NodeHandle;
use crate::ops::handle::{BasicBlockID, CfgID, NodeHandle};
use crate::ops::{BasicBlock, OpTrait, OpType};
use crate::{type_row, Node};

Expand Down Expand Up @@ -162,20 +164,8 @@ impl Rewrite for OutlineCfg {
h.move_before_sibling(new_block, outer_entry).unwrap();
}

// 4. Children of new CFG.
let inner_exit = h.children(cfg_node).exactly_one().ok().unwrap();
// Entry node must be first
h.move_before_sibling(entry, inner_exit).unwrap();
// And remaining nodes
for n in self.blocks {
// Do not move the entry node, as we have already
if n != entry {
h.set_parent(n, cfg_node).unwrap();
}
}

// 5. Exit edges.
// Retarget edge from exit_node (that used to target outside) to inner_exit
// 4(a). Exit edges.
// Remove edge from exit_node (that used to target outside)
let exit_port = h
.node_outputs(exit)
.filter(|p| {
Expand All @@ -187,10 +177,37 @@ impl Rewrite for OutlineCfg {
.ok() // NodePorts does not implement Debug
.unwrap();
h.disconnect(exit, exit_port).unwrap();
h.connect(exit, exit_port.index(), inner_exit, 0).unwrap();
// And connect new_block to outside instead
h.connect(new_block, 0, outside, 0).unwrap();

// 5. Children of new CFG.
let inner_exit = {
// These operations do not fit within any CSG/SiblingMut
// so we need to access the Hugr directly.
let h = h.hugr_mut();
let inner_exit = h.children(cfg_node).exactly_one().ok().unwrap();
// Entry node must be first
h.move_before_sibling(entry, inner_exit).unwrap();
// And remaining nodes
for n in self.blocks {
// Do not move the entry node, as we have already
if n != entry {
h.set_parent(n, cfg_node).unwrap();
}
}
inner_exit
};

// 4(b). Reconnect exit edge to the new exit node within the inner CFG
// Use nested SiblingMut's in case the outer `h` is only a SiblingMut itself.
let mut in_bb_view: SiblingMut<'_, BasicBlockID> =
SiblingMut::try_new(h, new_block).unwrap();
let mut in_cfg_view: SiblingMut<'_, CfgID> =
SiblingMut::try_new(&mut in_bb_view, cfg_node).unwrap();
in_cfg_view
.connect(exit, exit_port.index(), inner_exit, 0)
.unwrap();

Ok(())
}
}
Expand Down Expand Up @@ -226,18 +243,24 @@ mod test {
use std::collections::HashSet;

use crate::algorithm::nest_cfgs::test::{
build_cond_then_loop_cfg, build_conditional_in_loop_cfg,
build_cond_then_loop_cfg, build_conditional_in_loop, build_conditional_in_loop_cfg,
};
use crate::builder::{
Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder, SubContainer,
};
use crate::extension::prelude::USIZE_T;
use crate::extension::PRELUDE_REGISTRY;
use crate::hugr::views::sibling::SiblingMut;
use crate::hugr::HugrMut;
use crate::ops::handle::NodeHandle;
use crate::{HugrView, Node};
use crate::ops::handle::{BasicBlockID, CfgID, NodeHandle};
use crate::types::FunctionType;
use crate::{type_row, Hugr, HugrView, Node};
use cool_asserts::assert_matches;
use itertools::Itertools;

use super::{OutlineCfg, OutlineCfgError};

fn depth(h: &impl HugrView, n: Node) -> u32 {
fn depth(h: &Hugr, n: Node) -> u32 {
match h.get_parent(n) {
Some(p) => 1 + depth(h, p),
None => 0,
Expand Down Expand Up @@ -274,6 +297,17 @@ mod test {
#[test]
fn test_outline_cfg() {
let (mut h, head, tail) = build_conditional_in_loop_cfg(false).unwrap();
h.infer_and_validate(&PRELUDE_REGISTRY).unwrap();
do_outline_cfg_test(&mut h, head, tail, 1);
h.validate(&PRELUDE_REGISTRY).unwrap();
}

fn do_outline_cfg_test(
h: &mut impl HugrMut,
head: BasicBlockID,
tail: BasicBlockID,
expected_depth: u32,
) {
let head = head.node();
let tail = tail.node();
let parent = h.get_parent(head).unwrap();
Expand All @@ -283,29 +317,60 @@ mod test {
// | \-> right -/ |
// \---<---<---<---<---<--<---/
// merge is unique predecessor of tail
let merge = h.input_neighbours(tail).exactly_one().unwrap();
let merge = h.input_neighbours(tail).exactly_one().ok().unwrap();
let [left, right]: [Node; 2] = h.output_neighbours(head).collect_vec().try_into().unwrap();
for n in [head, tail, merge] {
assert_eq!(depth(&h, n), 1);
assert_eq!(depth(h.base_hugr(), n), expected_depth);
}
h.infer_and_validate(&PRELUDE_REGISTRY).unwrap();
let blocks = [head, left, right, merge];
h.apply_rewrite(OutlineCfg::new(blocks)).unwrap();
h.validate(&PRELUDE_REGISTRY).unwrap();
for n in blocks {
assert_eq!(depth(&h, n), 3);
assert_eq!(depth(h.base_hugr(), n), expected_depth + 2);
}
let new_block = h.output_neighbours(entry).exactly_one().unwrap();
let new_block = h.output_neighbours(entry).exactly_one().ok().unwrap();
for n in [entry, exit, tail, new_block] {
assert_eq!(depth(&h, n), 1);
assert_eq!(depth(h.base_hugr(), n), expected_depth);
}
assert_eq!(h.input_neighbours(tail).exactly_one().unwrap(), new_block);
assert_eq!(
h.input_neighbours(tail).exactly_one().ok().unwrap(),
new_block
);
assert_eq!(
h.output_neighbours(tail).take(2).collect::<HashSet<Node>>(),
HashSet::from([exit, new_block])
);
}

#[test]
fn test_outline_cfg_subregion() {
let mut module_builder = ModuleBuilder::new();
let mut fbuild = module_builder
.define_function(
"main",
FunctionType::new(type_row![USIZE_T], type_row![USIZE_T]).pure(),
)
.unwrap();
let [i1] = fbuild.input_wires_arr();
let mut cfg_builder = fbuild
.cfg_builder(
[(USIZE_T, i1)],
None,
type_row![USIZE_T],
Default::default(),
)
.unwrap();
let (head, tail) = build_conditional_in_loop(&mut cfg_builder, false).unwrap();
let cfg = cfg_builder.finish_sub_container().unwrap();
fbuild.finish_with_outputs(cfg.outputs()).unwrap();
let mut h = module_builder.finish_prelude_hugr().unwrap();
do_outline_cfg_test(
&mut SiblingMut::<'_, CfgID>::try_new(&mut h, cfg.node()).unwrap(),
head,
tail,
3,
);
}

#[test]
fn test_outline_cfg_move_entry() {
// /-> left --\
Expand Down
21 changes: 10 additions & 11 deletions src/hugr/views.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,16 @@ pub trait HugrView: sealed::HugrInternals {

/// Get the input and output child nodes of a dataflow parent.
/// If the node isn't a dataflow parent, then return None
fn get_io(&self, node: Node) -> Option<[Node; 2]>;
#[inline]
fn get_io(&self, node: Node) -> Option<[Node; 2]> {
let op = self.get_nodetype(node);
// Nodes outside the view have no children (and a non-DataflowParent NodeType::default())
if OpTag::DataflowParent.is_superset(op.tag()) {
self.children(node).take(2).collect_vec().try_into().ok()
} else {
None
}
}

/// For function-like HUGRs (DFG, FuncDefn, FuncDecl), report the function
/// type. Otherwise return None.
Expand Down Expand Up @@ -407,16 +416,6 @@ where
fn all_neighbours(&self, node: Node) -> Self::Neighbours<'_> {
self.as_ref().graph.all_neighbours(node.index).map_into()
}

#[inline]
fn get_io(&self, node: Node) -> Option<[Node; 2]> {
let op = self.get_nodetype(node);
if OpTag::DataflowParent.is_superset(op.tag()) {
self.children(node).take(2).collect_vec().try_into().ok()
} else {
None
}
}
}

pub(crate) mod sealed {
Expand Down
5 changes: 0 additions & 5 deletions src/hugr/views/descendants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,11 +154,6 @@ where
fn all_neighbours(&self, node: Node) -> Self::Neighbours<'_> {
self.graph.all_neighbours(node.index).map_into()
}

#[inline]
fn get_io(&self, node: Node) -> Option<[Node; 2]> {
self.base_hugr().get_io(node)
}
}

impl<'a, Root> HierarchyView<'a> for DescendantsGraph<'a, Root>
Expand Down
Loading

0 comments on commit 0ea19dd

Please sign in to comment.