Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
doug-q committed May 14, 2024
1 parent f37a55b commit b50ffc4
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 91 deletions.
32 changes: 1 addition & 31 deletions hugr/src/algorithm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,34 +4,4 @@ pub mod const_fold;
mod half_node;
pub mod merge_bbs;
pub mod nest_cfgs;

#[derive(Debug, Clone, Copy, Ord, Eq, PartialOrd, PartialEq)]
/// A type for algorithms to take as configuration, specifying how much
/// verification they should do. Algorithms that accept this configuration
/// should at least verify that input HUGRs are valid, and that output HUGRs are
/// valid.
///
/// The default level is `None` because verification can be expensive.
pub enum VerifyLevel {
/// Do no verification.
None,
/// Verify using [HugrView::validate_no_extensions]. This is useful when you
/// do not expect valid Extension annotations on Nodes.
///
/// [HugrView::validate_no_extensions]: crate::HugrView::validate_no_extensions
WithoutExtensions,
/// Verify using [HugrView::validate].
///
/// [HugrView::validate]: crate::HugrView::validate
WithExtensions,
}

impl Default for VerifyLevel {
fn default() -> Self {
if cfg!(test) {
Self::WithoutExtensions
} else {
Self::None
}
}
}
pub mod verify;
97 changes: 37 additions & 60 deletions hugr/src/algorithm/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ use std::collections::{BTreeSet, HashMap};
use itertools::Itertools;
use thiserror::Error;

use crate::hugr::{SimpleReplacementError, ValidationError};
use crate::algorithm::verify::{VerifyError, VerifyLevel};
use crate::hugr::SimpleReplacementError;
use crate::types::SumType;
use crate::Direction;
use crate::{
Expand All @@ -22,91 +23,67 @@ use crate::{
Hugr, HugrView, IncomingPort, Node, SimpleReplacement,
};

use super::VerifyLevel;

#[derive(Error, Debug)]
#[allow(missing_docs)]
pub enum ConstFoldError {
#[error("Failed to verify {label} HUGR: {err}")]
VerifyError {
label: String,
#[source]
err: ValidationError,
},
#[error(transparent)]
SimpleReplaceError(#[from] SimpleReplacementError),
}

impl ConstFoldError {
fn verify_err(label: impl Into<String>, err: ValidationError) -> Self {
Self::VerifyError {
label: label.into(),
err,
}
}
SimpleReplacementError(#[from] SimpleReplacementError),
#[error(transparent)]
VerifyError(#[from] VerifyError),
}

#[derive(Debug, Clone, Copy, Default)]
/// A configuration for the Constant Folding pass.
pub struct ConstFoldConfig {
pub struct ConstantFoldPass {
verify: VerifyLevel,
}

impl ConstFoldConfig {
impl ConstantFoldPass {
/// Create a new `ConstFoldConfig` with default configuration.
pub fn new() -> Self {
Self::default()
}

/// Build a `ConstFoldConfig` with the given [VerifyLevel].
pub fn with_verify(mut self, verify: VerifyLevel) -> Self {
pub fn verify_level(mut self, verify: VerifyLevel) -> Self {
self.verify = verify;
self
}

fn verify_impl(
/// Run the Constant Folding pass.
pub fn run<H: HugrMut>(
&self,
label: &str,
h: &impl HugrView,
hugr: &mut H,
reg: &ExtensionRegistry,
) -> Result<(), ConstFoldError> {
match self.verify {
VerifyLevel::None => Ok(()),
VerifyLevel::WithoutExtensions => h.validate_no_extensions(reg),
VerifyLevel::WithExtensions => h.validate(reg),
}
.map_err(|err| ConstFoldError::verify_err(label, err))
}

/// Run the Constant Folding pass.
pub fn run(&self, h: &mut impl HugrMut, reg: &ExtensionRegistry) -> Result<(), ConstFoldError> {
self.verify_impl("input", h, reg)?;
loop {
// We can only safely apply a single replacement. Applying a
// replacement removes nodes and edges which may be referenced by
// further replacements returned by find_consts. Even worse, if we
// attempted to apply those replacements, expecting them to fail if
// the nodes and edges they reference had been deleted, they may
// succeed because new nodes and edges reused the ids.
//
// We could be a lot smarter here, keeping track of `LoadConstant`
// nodes and only looking at their out neighbours.
let Some((replace, removes)) = find_consts(h, h.nodes(), reg).next() else {
break;
};
h.apply_rewrite(replace)?;
for rem in removes {
// We are optimistically applying these [RemoveLoadConstant] and
// [RemoveConst] rewrites without checking whether the nodes
// they attempt to remove have remaining uses. If they do, then
// the rewrite fails and we move on.
if let Ok(const_node) = h.apply_rewrite(rem) {
// if the LoadConst was removed, try removing the Const too.
let _ = h.apply_rewrite(RemoveConst(const_node));
self.verify.run_verified_pass(hugr, reg, |hugr: &mut H| {
loop {
// We can only safely apply a single replacement. Applying a
// replacement removes nodes and edges which may be referenced by
// further replacements returned by find_consts. Even worse, if we
// attempted to apply those replacements, expecting them to fail if
// the nodes and edges they reference had been deleted, they may
// succeed because new nodes and edges reused the ids.
//
// We could be a lot smarter here, keeping track of `LoadConstant`
// nodes and only looking at their out neighbours.
let Some((replace, removes)) = find_consts(hugr, hugr.nodes(), reg).next() else {
break Ok(());
};
hugr.apply_rewrite(replace)?;
for rem in removes {
// We are optimistically applying these [RemoveLoadConstant] and
// [RemoveConst] rewrites without checking whether the nodes
// they attempt to remove have remaining uses. If they do, then
// the rewrite fails and we move on.
if let Ok(const_node) = hugr.apply_rewrite(rem) {
// if the LoadConst was removed, try removing the Const too.
let _ = hugr.apply_rewrite(RemoveConst(const_node));
}
}
}
}
self.verify_impl("output", h, reg)
})
}
}

Expand Down Expand Up @@ -275,7 +252,7 @@ fn get_const(hugr: &impl HugrView, op_node: Node, in_p: IncomingPort) -> Option<

/// Exhaustively apply constant folding to a HUGR.
pub fn constant_fold_pass<H: HugrMut>(h: &mut H, reg: &ExtensionRegistry) {
ConstFoldConfig::default().run(h, reg).unwrap()
ConstantFoldPass::default().run(h, reg).unwrap()
}

#[cfg(test)]
Expand Down

0 comments on commit b50ffc4

Please sign in to comment.