Skip to content

Commit

Permalink
egraph: move around debugging code
Browse files Browse the repository at this point in the history
  • Loading branch information
ekiwi committed Dec 20, 2024
1 parent 52b72ad commit 3fb8ffd
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 105 deletions.
94 changes: 0 additions & 94 deletions patronus-egraphs/src/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@
use baa::BitVecOps;
use egg::{define_language, Id, Language, RecExpr};
use patronus::expr::*;
use rustc_hash::FxHashMap;
use std::cmp::{max, Ordering};
use std::fmt::{Display, Formatter};
use std::io::Write;
use std::str::FromStr;

define_language! {
Expand Down Expand Up @@ -428,98 +426,6 @@ fn extend(
}
}

fn to_pdf(filename: &str, egraph: &EGraph) -> std::io::Result<()> {
use std::process::{Command, Stdio};
let mut child = Command::new("dot")
.args(["-Tpdf", "-o", filename])
.stdin(Stdio::piped())
.stdout(Stdio::null())
.spawn()?;
let stdin = child.stdin.as_mut().expect("Failed to open stdin");
write_to_dot(stdin, egraph)?;
match child.wait()?.code() {
Some(0) => Ok(()),
Some(e) => panic!("dot program returned error code {}", e),
None => panic!("dot program was killed by a signal"),
}
}

/// Reimplements egg's `to_dot` functionality.
/// This is necessary because we do not want to show the Width nodes in the graph, because
/// otherwise it becomes very confusing.
fn write_to_dot(out: &mut impl Write, egraph: &EGraph) -> std::io::Result<()> {
writeln!(out, "digraph egraph {{")?;

// set compound=true to enable edges to clusters
writeln!(out, " compound=true")?;
writeln!(out, " clusterrank=local")?;

// create a map from e-class id to width
let widths = FxHashMap::from_iter(
egraph
.classes()
.flat_map(|class| get_const_width_or_sign(egraph, class.id).map(|w| (class.id, w))),
);

// define all the nodes, clustered by eclass
for class in egraph.classes() {
if !widths.contains_key(&class.id) {
writeln!(out, " subgraph cluster_{} {{", class.id)?;
writeln!(out, " style=dotted")?;
for (i, node) in class.iter().enumerate() {
writeln!(out, " {}.{}[label = \"{}\"]", class.id, i, node)?;
}
writeln!(out, " }}")?;
}
}

for class in egraph.classes() {
if !widths.contains_key(&class.id) {
for (i_in_class, node) in class.iter().enumerate() {
let nodes_and_labels = if is_bin_op(node) {
// w, w_a, s_a, a, w_b, s_b, b
let cc = node.children();
let w_a = widths[&cc[1]];
let s_a = widths[&cc[2]];
let a = cc[3];
let w_b = widths[&cc[4]];
let s_b = widths[&cc[5]];
let b = cc[6];
vec![
(a, format!("{w_a}{}", if s_a == 0 { "" } else { "s" })),
(b, format!("{w_b}{}", if s_b == 0 { "" } else { "s" })),
]
} else {
assert_eq!(node.len(), 0);
vec![]
};
for (child, label) in nodes_and_labels.into_iter() {
// write the edge to the child, but clip it to the eclass with lhead
let anchor = "";
let child_leader = egraph.find(child);
if child_leader == class.id {
writeln!(
out,
// {}.0 to pick an arbitrary node in the cluster
" {}.{}{} -> {}.{}:n [lhead = cluster_{}, label=\"{}\"]",
class.id, i_in_class, anchor, class.id, i_in_class, class.id, label
)?;
} else {
writeln!(
out,
// {}.0 to pick an arbitrary node in the cluster
" {}.{}{} -> {}.0 [lhead = cluster_{}, label=\"{}\"]",
class.id, i_in_class, anchor, child, child_leader, label
)?;
}
}
}
}
}

write!(out, "}}")
}

pub type EGraph = egg::EGraph<Arith, ()>;

/// Finds a width or sign constant in the e-class referred to by the substitution
Expand Down
101 changes: 101 additions & 0 deletions patronus-egraphs/src/dot.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
// Copyright 2024 Cornell University
// released under BSD 3-Clause License
// author: Kevin Laeufer <[email protected]>
// some of the code is based on `egg` source code which is licenced under MIT

use crate::{get_const_width_or_sign, is_bin_op, EGraph};
use egg::Language;
use rustc_hash::FxHashMap;
use std::io::Write;

pub fn to_pdf(filename: &str, egraph: &EGraph) -> std::io::Result<()> {
use std::process::{Command, Stdio};
let mut child = Command::new("dot")
.args(["-Tpdf", "-o", filename])
.stdin(Stdio::piped())
.stdout(Stdio::null())
.spawn()?;
let stdin = child.stdin.as_mut().expect("Failed to open stdin");
write_to_dot(stdin, egraph)?;
match child.wait()?.code() {
Some(0) => Ok(()),
Some(e) => panic!("dot program returned error code {}", e),
None => panic!("dot program was killed by a signal"),
}
}

/// Reimplements egg's `to_dot` functionality.
/// This is necessary because we do not want to show the Width nodes in the graph, because
/// otherwise it becomes very confusing.
pub fn write_to_dot(out: &mut impl Write, egraph: &EGraph) -> std::io::Result<()> {
writeln!(out, "digraph egraph {{")?;

// set compound=true to enable edges to clusters
writeln!(out, " compound=true")?;
writeln!(out, " clusterrank=local")?;

// create a map from e-class id to width
let widths = FxHashMap::from_iter(
egraph
.classes()
.flat_map(|class| get_const_width_or_sign(egraph, class.id).map(|w| (class.id, w))),
);

// define all the nodes, clustered by eclass
for class in egraph.classes() {
if !widths.contains_key(&class.id) {
writeln!(out, " subgraph cluster_{} {{", class.id)?;
writeln!(out, " style=dotted")?;
for (i, node) in class.iter().enumerate() {
writeln!(out, " {}.{}[label = \"{}\"]", class.id, i, node)?;
}
writeln!(out, " }}")?;
}
}

for class in egraph.classes() {
if !widths.contains_key(&class.id) {
for (i_in_class, node) in class.iter().enumerate() {
let nodes_and_labels = if is_bin_op(node) {
// w, w_a, s_a, a, w_b, s_b, b
let cc = node.children();
let w_a = widths[&cc[1]];
let s_a = widths[&cc[2]];
let a = cc[3];
let w_b = widths[&cc[4]];
let s_b = widths[&cc[5]];
let b = cc[6];
vec![
(a, format!("{w_a}{}", if s_a == 0 { "" } else { "s" })),
(b, format!("{w_b}{}", if s_b == 0 { "" } else { "s" })),
]
} else {
assert_eq!(node.len(), 0);
vec![]
};
for (child, label) in nodes_and_labels.into_iter() {
// write the edge to the child, but clip it to the eclass with lhead
let anchor = "";
let child_leader = egraph.find(child);
if child_leader == class.id {
writeln!(
out,
// {}.0 to pick an arbitrary node in the cluster
" {}.{}{} -> {}.{}:n [lhead = cluster_{}, label=\"{}\"]",
class.id, i_in_class, anchor, class.id, i_in_class, class.id, label
)?;
} else {
writeln!(
out,
// {}.0 to pick an arbitrary node in the cluster
" {}.{}{} -> {}.0 [lhead = cluster_{}, label=\"{}\"]",
class.id, i_in_class, anchor, child, child_leader, label
)?;
}
}
}
}
}

write!(out, "}}")
}
8 changes: 4 additions & 4 deletions patronus-egraphs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
// released under BSD 3-Clause License
// author: Kevin Laeufer <[email protected]>
mod arithmetic;
mod dot;
mod rewrites;

pub use arithmetic::{
from_arith, get_const_width_or_sign, is_bin_op, to_arith, Arith, EGraph, Sign,
};
pub use rewrites::{create_egg_rewrites, create_rewrites, ArithRewrite, Assignment};
pub use arithmetic::*;
pub use dot::*;
pub use rewrites::*;
15 changes: 8 additions & 7 deletions patronus-egraphs/src/rewrites.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,9 @@ impl ArithRewrite {
}
}

/// Find all matches of the left-hand-side and returns information about them.
/// This can be very useful when debugging why a certain rules does not match, when you expect
/// it to match.
pub fn find_lhs_matches(&self, egraph: &EGraph) -> Vec<ArithMatch> {
self.lhs
.search(egraph)
Expand All @@ -189,10 +192,7 @@ fn substitution_to_assignment(
pattern: &PatternAst<Arith>,
) -> Assignment {
vars_in_pattern(pattern)
.flat_map(|v| match get_const_width_or_sign(egraph, s[v]) {
Some(w) => Some((v, w)),
None => None,
})
.flat_map(|v| get_const_width_or_sign(egraph, s[v]).map(|w| (v, w)))
.collect()
}

Expand All @@ -207,9 +207,9 @@ pub type Assignment = Vec<(Var, WidthInt)>;

#[derive(Debug, Clone)]
pub struct ArithMatch {
eclass: Id,
assign: Assignment,
cond_res: bool,
pub eclass: Id,
pub assign: Assignment,
pub cond_res: bool,
}

/// Checks that input and output widths of operations are consistent.
Expand Down Expand Up @@ -301,6 +301,7 @@ mod tests {
.into_iter()
.find(|r| r.name == "left-shift-mult")
.unwrap();
println!("{}", left_shift_mult.patterns().0);
let r = left_shift_mult.find_lhs_matches(&runner.egraph);
for m in r {
println!("{m:?}");
Expand Down

0 comments on commit 3fb8ffd

Please sign in to comment.