Skip to content

Commit

Permalink
egraph: extract derived widths from egraph
Browse files Browse the repository at this point in the history
  • Loading branch information
ekiwi committed Dec 19, 2024
1 parent 863bd24 commit f2839a5
Showing 1 changed file with 15 additions and 10 deletions.
25 changes: 15 additions & 10 deletions patronus-egraphs/src/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ macro_rules! arith_rewrite {
/// Generate our ROVER inspired rewrite rules.
pub fn create_rewrites() -> Vec<ArithRewrite> {
vec![
// a + b <=> b + a
// a + b => b + a
arith_rewrite!("commute-add"; "(+ ?wo ?wa ?sa ?a ?wb ?sb ?b)" => "(+ ?wo ?wb ?sb ?b ?wa ?sa ?a)"),
// (a << b) << x => a << (b + c)
arith_rewrite!("merge-left-shift";
Expand Down Expand Up @@ -530,7 +530,7 @@ impl ArithRewrite {
let condition = move |egraph: &mut EGraph, _, subst: &Subst| {
let values: Vec<WidthInt> = vars
.iter()
.map(|v| get_const_width_or_sign(egraph, subst, *v))
.map(|v| get_const_width_or_sign(egraph, subst[*v]))
.collect();
cond(values.as_slice())
};
Expand Down Expand Up @@ -568,13 +568,18 @@ type EGraph = egg::EGraph<Arith, ()>;

/// Finds a width or sign constant in the e-class referred to by the substitution
/// and returns its value. Errors if no such constant can be found.
fn get_const_width_or_sign(egraph: &EGraph, subst: &Subst, v: Var) -> WidthInt {
egraph[subst[v]]
fn get_const_width_or_sign(egraph: &EGraph, id: Id) -> WidthInt {
egraph[id]
.nodes
.iter()
.flat_map(|n| match n {
Arith::Width(w) => Some((*w).into()),
Arith::Sign(s) => Some((*s).into()),
Arith::WidthMaxPlus1([a, b]) => {
let a = get_const_width_or_sign(egraph, *a);
let b = get_const_width_or_sign(egraph, *b);
Some(max(a, b) + 1)
}
_ => None,
})
.next()
Expand Down Expand Up @@ -690,12 +695,12 @@ mod tests {

// run egraph operations
let egg_rewrites = create_egg_rewrites();
// let runner = egg::Runner::default()
// .with_expr(&spec_e)
// .with_expr(&impl_e)
// .run(&egg_rewrites);
//
// runner.print_report();
let runner = egg::Runner::default()
.with_expr(&spec_e)
.with_expr(&impl_e)
.run(&egg_rewrites);

runner.print_report();

// todo!();
}
Expand Down

0 comments on commit f2839a5

Please sign in to comment.