Skip to content

Commit

Permalink
egraph: add function to debug match failures
Browse files Browse the repository at this point in the history
  • Loading branch information
ekiwi committed Dec 20, 2024
1 parent 086cfae commit 52b72ad
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 10 deletions.
2 changes: 1 addition & 1 deletion patronus-egraphs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ mod rewrites;
pub use arithmetic::{
from_arith, get_const_width_or_sign, is_bin_op, to_arith, Arith, EGraph, Sign,
};
pub use rewrites::{create_egg_rewrites, create_rewrites, ArithRewrite};
pub use rewrites::{create_egg_rewrites, create_rewrites, ArithRewrite, Assignment};
64 changes: 57 additions & 7 deletions patronus-egraphs/src/rewrites.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ introspect them in order to check re-write conditions or debug matches.
!*/

use crate::{get_const_width_or_sign, is_bin_op, Arith, EGraph};
use egg::{ConditionalApplier, ENodeOrVar, Language, Pattern, PatternAst, Rewrite, Subst, Var};
use egg::{
ConditionalApplier, ENodeOrVar, Id, Language, Pattern, PatternAst, Rewrite, Searcher, Subst,
Var,
};
use patronus::expr::WidthInt;
use std::cmp::max;

Expand Down Expand Up @@ -159,6 +162,54 @@ impl ArithRewrite {
true
}
}

pub fn find_lhs_matches(&self, egraph: &EGraph) -> Vec<ArithMatch> {
self.lhs
.search(egraph)
.into_iter()
.flat_map(|m| {
let eclass = m.eclass;
m.substs.into_iter().map(move |s| {
let assign = substitution_to_assignment(egraph, &s, &self.lhs.ast);
let cond_res = self.eval_condition(&assign);
ArithMatch {
eclass,
assign,
cond_res,
}
})
})
.collect()
}
}

fn substitution_to_assignment(
egraph: &EGraph,
s: &Subst,
pattern: &PatternAst<Arith>,
) -> Assignment {
vars_in_pattern(pattern)
.flat_map(|v| match get_const_width_or_sign(egraph, s[v]) {
Some(w) => Some((v, w)),
None => None,
})
.collect()
}

fn vars_in_pattern(pattern: &PatternAst<Arith>) -> impl Iterator<Item = Var> + '_ {
pattern.as_ref().iter().flat_map(|e| match e {
ENodeOrVar::Var(v) => Some(*v),
ENodeOrVar::ENode(_) => None,
})
}

pub type Assignment = Vec<(Var, WidthInt)>;

#[derive(Debug, Clone)]
pub struct ArithMatch {
eclass: Id,
assign: Assignment,
cond_res: bool,
}

/// Checks that input and output widths of operations are consistent.
Expand Down Expand Up @@ -222,7 +273,6 @@ mod tests {
use super::*;
use crate::arithmetic::verification_fig_1;
use crate::to_arith;
use egg::Searcher;
use patronus::expr::{Context, SerializableIrNode};
#[test]
fn test_data_path_verification_fig_1_rewrites() {
Expand All @@ -247,11 +297,11 @@ mod tests {
let impl_class = runner.roots[1];
println!("{spec_class} {impl_class}");

let pat: Pattern<Arith> =
"(<< ?wo ?wab unsign (* ?wab ?wa unsign ?a ?wb unsign ?b) ?wc unsign ?c)"
.parse()
.unwrap();
let r = pat.search(&runner.egraph);
let left_shift_mult = create_rewrites()
.into_iter()
.find(|r| r.name == "left-shift-mult")
.unwrap();
let r = left_shift_mult.find_lhs_matches(&runner.egraph);
for m in r {
println!("{m:?}");
}
Expand Down
2 changes: 0 additions & 2 deletions tools/egraphs-cond-synth/src/samples.rs
Original file line number Diff line number Diff line change
Expand Up @@ -321,8 +321,6 @@ pub struct RuleSymbol {
sign: VarOrConst,
}

pub type Assignment = Vec<(Var, WidthInt)>;

impl RuleInfo {
pub fn signs(&self) -> impl Iterator<Item = Var> + '_ {
self.signs.iter().cloned()
Expand Down

0 comments on commit 52b72ad

Please sign in to comment.