diff --git a/patronus-egraphs/src/lib.rs b/patronus-egraphs/src/lib.rs index 01141d0..81897b1 100644 --- a/patronus-egraphs/src/lib.rs +++ b/patronus-egraphs/src/lib.rs @@ -7,4 +7,4 @@ mod rewrites; pub use arithmetic::{ from_arith, get_const_width_or_sign, is_bin_op, to_arith, Arith, EGraph, Sign, }; -pub use rewrites::{create_egg_rewrites, create_rewrites, ArithRewrite}; +pub use rewrites::{create_egg_rewrites, create_rewrites, ArithRewrite, Assignment}; diff --git a/patronus-egraphs/src/rewrites.rs b/patronus-egraphs/src/rewrites.rs index 7022016..0d2cdd3 100644 --- a/patronus-egraphs/src/rewrites.rs +++ b/patronus-egraphs/src/rewrites.rs @@ -10,7 +10,10 @@ introspect them in order to check re-write conditions or debug matches. !*/ use crate::{get_const_width_or_sign, is_bin_op, Arith, EGraph}; -use egg::{ConditionalApplier, ENodeOrVar, Language, Pattern, PatternAst, Rewrite, Subst, Var}; +use egg::{ + ConditionalApplier, ENodeOrVar, Id, Language, Pattern, PatternAst, Rewrite, Searcher, Subst, + Var, +}; use patronus::expr::WidthInt; use std::cmp::max; @@ -159,6 +162,54 @@ impl ArithRewrite { true } } + + pub fn find_lhs_matches(&self, egraph: &EGraph) -> Vec { + self.lhs + .search(egraph) + .into_iter() + .flat_map(|m| { + let eclass = m.eclass; + m.substs.into_iter().map(move |s| { + let assign = substitution_to_assignment(egraph, &s, &self.lhs.ast); + let cond_res = self.eval_condition(&assign); + ArithMatch { + eclass, + assign, + cond_res, + } + }) + }) + .collect() + } +} + +fn substitution_to_assignment( + egraph: &EGraph, + s: &Subst, + pattern: &PatternAst, +) -> Assignment { + vars_in_pattern(pattern) + .flat_map(|v| match get_const_width_or_sign(egraph, s[v]) { + Some(w) => Some((v, w)), + None => None, + }) + .collect() +} + +fn vars_in_pattern(pattern: &PatternAst) -> impl Iterator + '_ { + pattern.as_ref().iter().flat_map(|e| match e { + ENodeOrVar::Var(v) => Some(*v), + ENodeOrVar::ENode(_) => None, + }) +} + +pub type Assignment = Vec<(Var, WidthInt)>; + +#[derive(Debug, Clone)] +pub struct ArithMatch { + eclass: Id, + assign: Assignment, + cond_res: bool, } /// Checks that input and output widths of operations are consistent. @@ -222,7 +273,6 @@ mod tests { use super::*; use crate::arithmetic::verification_fig_1; use crate::to_arith; - use egg::Searcher; use patronus::expr::{Context, SerializableIrNode}; #[test] fn test_data_path_verification_fig_1_rewrites() { @@ -247,11 +297,11 @@ mod tests { let impl_class = runner.roots[1]; println!("{spec_class} {impl_class}"); - let pat: Pattern = - "(<< ?wo ?wab unsign (* ?wab ?wa unsign ?a ?wb unsign ?b) ?wc unsign ?c)" - .parse() - .unwrap(); - let r = pat.search(&runner.egraph); + let left_shift_mult = create_rewrites() + .into_iter() + .find(|r| r.name == "left-shift-mult") + .unwrap(); + let r = left_shift_mult.find_lhs_matches(&runner.egraph); for m in r { println!("{m:?}"); } diff --git a/tools/egraphs-cond-synth/src/samples.rs b/tools/egraphs-cond-synth/src/samples.rs index cb99b73..3fbc7c5 100644 --- a/tools/egraphs-cond-synth/src/samples.rs +++ b/tools/egraphs-cond-synth/src/samples.rs @@ -321,8 +321,6 @@ pub struct RuleSymbol { sign: VarOrConst, } -pub type Assignment = Vec<(Var, WidthInt)>; - impl RuleInfo { pub fn signs(&self) -> impl Iterator + '_ { self.signs.iter().cloned()