Skip to content

Commit

Permalink
fix: constant folding a LoadConstant with no out neighbours
Browse files Browse the repository at this point in the history
  • Loading branch information
doug-q committed May 10, 2024
1 parent 4c9cc1f commit 4bcf6b1
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 23 deletions.
31 changes: 20 additions & 11 deletions hugr/src/algorithm/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ pub fn find_consts<'a, 'r: 'a>(
hugr: &'a impl HugrView,
candidate_nodes: impl IntoIterator<Item = Node> + 'a,
reg: &'r ExtensionRegistry,
) -> impl Iterator<Item = (SimpleReplacement, Vec<RemoveLoadConstant>)> + 'a {
) -> impl Iterator<Item = (Option<SimpleReplacement>, Vec<RemoveLoadConstant>)> + 'a {
// track nodes for operations that have already been considered for folding
let mut used_neighbours = BTreeSet::new();

Expand All @@ -113,13 +113,22 @@ pub fn find_consts<'a, 'r: 'a>(
.filter(|(n, _)| used_neighbours.insert(*n))
.collect_vec();
if neighbours.is_empty() {
// no uses of LoadConstant that haven't already been considered.
return None;
if hugr.linked_inputs(n, out_p).next().is_none() {
// LoadConstant with no out neighbours
Some(itertools::Either::Left(std::iter::once((
None,
vec![RemoveLoadConstant(n)],
))))
} else {
// No uses of LoadConstant that haven't already been considered.
None
}
} else {
let fold_iter = neighbours
.into_iter()
.filter_map(|(neighbour, _)| fold_op(hugr, neighbour, reg));
Some(itertools::Either::Right(fold_iter))
}
let fold_iter = neighbours
.into_iter()
.filter_map(|(neighbour, _)| fold_op(hugr, neighbour, reg));
Some(fold_iter)
})
.flatten()
}
Expand All @@ -129,7 +138,7 @@ fn fold_op(
hugr: &impl HugrView,
op_node: Node,
reg: &ExtensionRegistry,
) -> Option<(SimpleReplacement, Vec<RemoveLoadConstant>)> {
) -> Option<(Option<SimpleReplacement>, Vec<RemoveLoadConstant>)> {
// only support leaf folding for now.
let neighbour_op = hugr.get_optype(op_node);
let (in_consts, removals): (Vec<_>, Vec<_>) = hugr
Expand Down Expand Up @@ -163,7 +172,7 @@ fn fold_op(
HashMap::new(),
nu_out,
);
Some((simple_replace, removals))
Some((Some(simple_replace), removals))
}

/// If `op_node` is connected to a LoadConstant at `in_p`, return the constant
Expand Down Expand Up @@ -192,8 +201,8 @@ pub fn constant_fold_pass(h: &mut impl HugrMut, reg: &ExtensionRegistry) {
if rewrites.is_empty() {
break;
}
for (replace, removes) in rewrites {
h.apply_rewrite(replace).unwrap();
for (mb_replace, removes) in rewrites {
let _ = mb_replace.and_then(|x| h.apply_rewrite(x).ok());
for rem in removes {
if let Ok(const_node) = h.apply_rewrite(rem) {
// if the LoadConst was removed, try removing the Const too.
Expand Down
65 changes: 54 additions & 11 deletions hugr/src/hugr/rewrite/simple_replace.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Implementation of the `SimpleReplace` operation.
use std::collections::{HashMap, HashSet};
use std::collections::{HashMap};

use crate::hugr::views::sealed::HugrInternals;
use crate::hugr::views::sibling_subgraph::InvalidSubgraph;
Expand Down Expand Up @@ -74,34 +74,78 @@ impl SimpleReplacement {

let self_output_node = h.children(parent).nth(1).unwrap();
assert!(h.get_optype(self_output_node).is_output());
let replacement_output_node = self.replacement.children(self.replacement.root()).nth(1).unwrap();
assert!(self.replacement.get_optype(replacement_output_node).is_output());
let replacement_output_node = self
.replacement
.children(self.replacement.root())
.nth(1)
.unwrap();
assert!(self
.replacement
.get_optype(replacement_output_node)
.is_output());

let problem_unless = |is_good: bool, err: SimpleReplacementError| is_good.then_some(()).ok_or(err);
let problem_unless =
|is_good: bool, err: SimpleReplacementError| is_good.then_some(()).ok_or(err);

for ((rep_inp_node, rep_inp_port), (rem_inp_node, rem_inp_port)) in &self.nu_inp {
use SimpleReplacementError::*;

// (rem_inp_node,rem_inp_port) should exist in h
problem_unless(h.valid_non_root(*rem_inp_node), InvalidRemovedNode())?;
problem_unless(h.portgraph().port_index(rem_inp_node.pg_index(), Into::<Port>::into(*rem_inp_port).pg_offset()).is_some(), InvalidRemovedNode())?;
problem_unless(
h.portgraph()
.port_index(
rem_inp_node.pg_index(),
Into::<Port>::into(*rem_inp_port).pg_offset(),
)
.is_some(),
InvalidRemovedNode(),
)?;

// (rep_inp_node, rep_inp_port) should exist in replacement
problem_unless(self.replacement.valid_non_root(*rep_inp_node), InvalidReplacementNode())?;
problem_unless(self.replacement.portgraph().port_index(rep_inp_node.pg_index(), Into::<Port>::into(*rep_inp_port).pg_offset()).is_some(), InvalidReplacementNode())?;
problem_unless(
self.replacement.valid_non_root(*rep_inp_node),
InvalidReplacementNode(),
)?;
problem_unless(
self.replacement
.portgraph()
.port_index(
rep_inp_node.pg_index(),
Into::<Port>::into(*rep_inp_port).pg_offset(),
)
.is_some(),
InvalidReplacementNode(),
)?;
}

for ((rem_out_node, rem_out_port), rep_out_port) in &self.nu_out {
use SimpleReplacementError::*;
// (rem_out_node, rem_out_port) should exist in h
problem_unless(h.valid_non_root(*rem_out_node), InvalidRemovedNode())?;
problem_unless(h.portgraph().port_index(rem_out_node.pg_index(), Into::<Port>::into(*rem_out_port).pg_offset()).is_some(), InvalidRemovedNode())?;
problem_unless(
h.portgraph()
.port_index(
rem_out_node.pg_index(),
Into::<Port>::into(*rem_out_port).pg_offset(),
)
.is_some(),
InvalidRemovedNode(),
)?;

// rep_out_port must be valid on replacement_output_node
problem_unless(self.replacement.portgraph().port_index(replacement_output_node.pg_index(), Into::<Port>::into(*rep_out_port).pg_offset()).is_some(), InvalidReplacementNode())?;
problem_unless(
self.replacement
.portgraph()
.port_index(
replacement_output_node.pg_index(),
Into::<Port>::into(*rep_out_port).pg_offset(),
)
.is_some(),
InvalidReplacementNode(),
)?;
}
Ok(())

}
}

Expand Down Expand Up @@ -239,7 +283,6 @@ pub enum SimpleReplacementError {
InvalidSubgraph(#[from] InvalidSubgraph),
}


#[cfg(test)]
pub(in crate::hugr::rewrite) mod test {
use itertools::Itertools;
Expand Down
2 changes: 1 addition & 1 deletion hugr/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ pub(crate) mod test {
match op {
OpType::Input(_) | OpType::Output(_) | OpType::LoadConstant(_) => node_count += 1,
OpType::Const(c) if c.value() == expected_value => node_count += 1,
_ => panic!("unexpected op: {:?}", op),
_ => panic!("unexpected op: {:?}\n\n{}", op, h.mermaid_string()),
}
}

Expand Down

0 comments on commit 4bcf6b1

Please sign in to comment.