Skip to content

Commit

Permalink
Add diverge_callback, extract all from AnalysisResults
Browse files Browse the repository at this point in the history
  • Loading branch information
acl-cqc committed Jan 29, 2025
1 parent 6fa6ff6 commit cb7bf75
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 50 deletions.
35 changes: 27 additions & 8 deletions hugr-passes/src/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
//! An (example) use of the [dataflow analysis framework](super::dataflow).
pub mod value_handle;
use std::collections::HashMap;
use std::{collections::HashMap, sync::Arc};
use thiserror::Error;

use hugr_core::{
Expand All @@ -20,10 +20,14 @@ use hugr_core::{
};
use value_handle::ValueHandle;

use crate::dataflow::{
partial_from_const, ConstLoader, ConstLocation, DFContext, Machine, PartialValue,
};
use crate::validation::{ValidatePassError, ValidationLevel};
use crate::{
dataflow::{
partial_from_const, ConstLoader, ConstLocation, DFContext, Machine, PartialValue,
TailLoopTermination,
},
dead_code::NodeDivergence,
};
use crate::{dead_code::DeadCodeElimPass, find_main};

#[derive(Debug, Clone, Default)]
Expand Down Expand Up @@ -111,6 +115,13 @@ impl ConstantFoldPass {
))
})
.collect::<Vec<_>>();
// Sadly the results immutably borrow the hugr, so we must extract everything we need before mutation
let terminating_tail_loops = hugr
.nodes()
.filter(|n| {
results.tail_loop_terminates(*n) == Some(TailLoopTermination::NeverContinues)
})
.collect::<Vec<_>>();

for (n, inport, v) in wires_to_break {
let parent = hugr.get_parent(n).unwrap();
Expand All @@ -126,11 +137,19 @@ impl ConstantFoldPass {
let mut dce = DeadCodeElimPass::default();
if hugr.get_optype(hugr.root()).is_module() {
dce = dce.with_entry_points([find_main(hugr).unwrap()])
};
if self.allow_increase_termination {
dce = dce.allow_increase_termination()
}
dce.run(hugr)?;
dce.set_diverge_callback(if self.allow_increase_termination {
Arc::new(|_, _| NodeDivergence::CanRemove)
} else {
Arc::new(move |_, n| {
if terminating_tail_loops.contains(&n) {
NodeDivergence::RemoveIfAllChildrenCanBeRemoved
} else {
NodeDivergence::UseDefault
}
})
})
.run(hugr)?;
Ok(())
}

Expand Down
133 changes: 91 additions & 42 deletions hugr-passes/src/dead_code.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,51 @@
//! Pass for removing dead code, i.e. that computes values that are then discarded
use std::collections::{HashSet, VecDeque};

use hugr_core::{hugr::hugrmut::HugrMut, ops::OpType, HugrView, Node};
use hugr_core::{hugr::hugrmut::HugrMut, ops::OpType, Hugr, HugrView, Node};
use std::fmt::{Debug, Formatter};
use std::{
collections::{HashSet, VecDeque},
sync::Arc,
};

use crate::validation::{ValidatePassError, ValidationLevel};

/// Configuration for Dead Code Elimination pass
#[derive(Clone, Debug, Default)]
#[derive(Clone, Default)]
pub struct DeadCodeElimPass {
entry_points: Vec<Node>,
allow_increase_termination: bool,
diverge_callback: Option<Arc<dyn Fn(&Hugr, Node) -> NodeDivergence>>,
validation: ValidationLevel,
}

impl Debug for DeadCodeElimPass {
fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
// Use "derive Debug" by defining an identical struct without the unprintable fields

#[allow(unused)] // Rust ignores the derive-Debug in figuring out what's used
#[derive(Debug)]
struct DCEDebug<'a> {
entry_points: &'a Vec<Node>,
validation: ValidationLevel,
}

Debug::fmt(
&DCEDebug {
entry_points: &self.entry_points,
validation: self.validation,
},
f,
)
}
}

pub enum NodeDivergence {
#[allow(unused)]
MustKeep,
CanRemove,
UseDefault,
RemoveIfAllChildrenCanBeRemoved,
}

impl DeadCodeElimPass {
/// Sets the validation level used before and after the pass is run
#[allow(unused)]
Expand All @@ -22,13 +54,21 @@ impl DeadCodeElimPass {
self
}

/// Allows the pass to remove potentially-non-terminating [TailLoop]s and [CFG] if their
/// result (if/when they do terminate) is either known or not needed.
/// Allows setting a callback that determines whether a node is considered as diverging
/// (non-terminating) - that is, nodes for which the callback returns
/// * Some(true) => cannot be removed, even if we don't need their results
/// * Some(false) => can be removed so long as we don't need their results
/// (note that this means we can remove their descendants too, *even if* said descendants diverge)
/// * None => use default algorithm for whether we can remove or not
///
/// [TailLoop]: hugr_core::ops::TailLoop
/// The default algorithm says that [Cfg], [Call] and [TailLoop] nodes can never be removed,
/// nor can any node that (recursively) contains a diverging node.
///
/// [Call]: hugr_core::ops::Call
/// [CFG]: hugr_core::ops::CFG
pub fn allow_increase_termination(mut self) -> Self {
self.allow_increase_termination = true;
/// [TailLoop]: hugr_core::ops::TailLoop
pub fn set_diverge_callback(mut self, cb: Arc<dyn Fn(&Hugr, Node) -> NodeDivergence>) -> Self {
self.diverge_callback = Some(cb);
self
}

Expand Down Expand Up @@ -73,13 +113,11 @@ impl DeadCodeElimPass {
// including StateOrder edges.
q.extend(inout); // Input also necessary for legality even if unreachable

if !self.allow_increase_termination {
// Also add on anything that might not terminate (even if results not required -
// if its results are required we'll add it by following dataflow, below.)
for ch in h.children(n) {
if might_diverge(&h, ch) {
q.push_back(ch);
}
// Also add on anything that might not terminate (even if results not required -
// if its results are required we'll add it by following dataflow, below.)
for ch in h.children(n) {
if self.might_diverge(&h, ch) {
q.push_back(ch);
}
}
}
Expand Down Expand Up @@ -109,32 +147,43 @@ impl DeadCodeElimPass {
}
Ok(())
}
}

// "Diverge" aka "never-terminate"
// TODO would be more efficient to compute this bottom-up and cache (dynamic programming)
fn might_diverge(h: &impl HugrView, n: Node) -> bool {
match h.get_optype(n) {
OpType::CFG(_) => {
// TODO if the CFG has no cycles (that are possible given predicates)
// then we could say it definitely terminates (i.e. return false)
true
}
OpType::TailLoop(_) => {
// If the TailLoop never continues, clearly it doesn't terminate, but we haven't got
// dataflow results to tell us that. Instead rely on an earlier pass having rewritten
// such a TailLoop into a non-loop.
// Even just an upper-bound on the number of iterations would allow returning false.
true
}
OpType::Call(_) => {
// We could scan the target FuncDefn, but that might contain calls to itself, so we'd need
// a "seen" set...instead just rely on calls being inlined if we want to remove them.
true
}
_ => {
// Node does not introduce non-termination, but still non-terminates if any of its children does
h.children(n).any(|ch| might_diverge(h, ch))
// "Diverge" aka "never-terminate"
// TODO would be more efficient to compute this bottom-up and cache (dynamic programming)
fn might_diverge(&self, h: &impl HugrView, n: Node) -> bool {
match self
.diverge_callback
.as_ref()
.map_or(NodeDivergence::UseDefault, |f| f(h.base_hugr(), n))
{
NodeDivergence::MustKeep => return true,
NodeDivergence::CanRemove => return false,
NodeDivergence::UseDefault => {
match h.get_optype(n) {
OpType::CFG(_) => {
// TODO if the CFG has no cycles (that are possible given predicates)
// then we could say it definitely terminates (i.e. return false)
return true;
}
OpType::TailLoop(_) => {
// If the TailLoop never continues, clearly it doesn't terminate, but we haven't got
// dataflow results to tell us that. Instead rely on an earlier pass having rewritten
// such a TailLoop into a non-loop.
// Even just an upper-bound on the number of iterations would allow returning false.
return true;
}
OpType::Call(_) => {
// We could scan the target FuncDefn, but that might contain calls to itself, so we'd need
// a "seen" set...instead just rely on calls being inlined if we want to remove them.
return true;
}
_ => (), // fall through to check children
}
}
NodeDivergence::RemoveIfAllChildrenCanBeRemoved => (), // fall through to check children
}

// Node does not introduce non-termination, but still non-terminates if any of its children does
h.children(n).any(|ch| self.might_diverge(h, ch))
}
}

0 comments on commit cb7bf75

Please sign in to comment.