2024-10-19.md

Understanding E-Graphs

This is a work in progress, as I try and further understand e-graphs.

I was introduced to e-graphs a while ago, and while they seemed tasty at the time, I did not fully understand either their benefits or their implementations. In learning more about them, I found they are both more interesting and less complicated than I initially thought. However, I took a slightly different path than the papers I read, and I thought sharing that might help you as it has me.

First, I'm largely drawing from the egg paper, which you can find among other helpful resources on the egg website. Egg is among other things a Rust implementation of e-graphs, but also a fairly clear and clarifying presentation of the concepts. I recommend you start there, especially if in reading what follows you worry that I may have botched it all. Secondly, my implementation is the result of a few days of typing, and likely has innumerable defects; I don't recommend you use it.

We'll tell the story through a bit of code review, where I show off what I've written (a few hundred lines of Rust).

For Starters: An Abstract Language

E-graphs represent information about expressions in abstract languages. We'll need one of these languages to get started. I'm going to use a concrete language for presentation, but one of the points of egg is that it is generic over your choice of language. The only thing you need to point out is how many arguments each operator in your language requires, which we'll do with a trait:

/// A type that requires inputs and produces an output.
pub trait Function {
    /// The number of inputs the function requires.
    fn inputs(&self) -> usize;
}

We'll use a simple expression language, chosen mostly to be able to get through the worked example in the egg paper. There are a few binary operators, as well as literal integers, and placeholders for variables.

#[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]
enum Op { Add, Mul, Div, Shf, Lit(i32), Var(i32), }

impl Function for Op {
    fn inputs(&self) -> usize {
        match self {
            Op::Add | Op::Mul | Op::Div | Op::Shf => 2,
            Op::Lit(_) | Op::Var(_) => 0,
        }
    }
}

We'll think about expressions as pre-order sequences of Op. For example, we could represent the expression (a x 2) / 2 as:

[Op::Div, Op::Mul, Op::Var(0), Op::Lit(2), Op::Lit(2)]

You can also think about putting these into expression trees, if that appeals to you. Each tree node would have an Op and a list of other tree nodes for their inputs. We won't be doing that, but it's a fine way to do things if you don't get stressed out by the allocations.
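
Because each operator declares how many inputs it takes, a pre-order sequence can be checked for well-formedness without building a tree. Here is a minimal sketch, with the trait and enum from above repeated so the block stands alone. The `is_well_formed` helper is hypothetical and not part of the code under review.

```rust
/// A type that requires inputs and produces an output.
pub trait Function {
    /// The number of inputs the function requires.
    fn inputs(&self) -> usize;
}

#[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub enum Op { Add, Mul, Div, Shf, Lit(i32), Var(i32) }

impl Function for Op {
    fn inputs(&self) -> usize {
        match self {
            Op::Add | Op::Mul | Op::Div | Op::Shf => 2,
            Op::Lit(_) | Op::Var(_) => 0,
        }
    }
}

/// Hypothetical helper: true if `expr` is exactly one pre-order expression.
pub fn is_well_formed<T: Function>(expr: &[T]) -> bool {
    // The number of sub-expressions we still expect to read.
    let mut needed = 1usize;
    for op in expr {
        if needed == 0 {
            // Operators remain after a complete expression.
            return false;
        }
        // This operator satisfies one expectation and adds one per input.
        needed = needed - 1 + op.inputs();
    }
    needed == 0
}
```

The same counting argument is why the flat representation loses nothing relative to a tree: the arities alone determine where each sub-expression ends.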

We could write a parser and evaluator and such for this language, but we won't. At least, not yet, because they are not an essential part of understanding e-graphs.

Introducing E-Graphs

An e-graph is a compact way to represent "equivalence" information between expressions. "Equivalence" means something like "can be freely substituted for the other in all contexts", I'm hoping. Our expression above, (a x 2) / 2, is equivalent to the expression a. But, how might we both figure that out and represent it, efficiently?

We will use a representation called an e-graph, short for (I think?) "equivalence graph". An e-graph represents sets of equivalent expressions, in equivalence classes each called an e-class. An e-class represents one set of equivalent expressions, but specifically it is a set of e-nodes. An e-node is a pair of an Op and as many e-class identifiers as the operator takes arguments. So, without understanding anything yet, an e-graph contains several e-classes, each of which contains several e-nodes, each of which references some e-classes.

The intent with an e-node is to observe that we don't really care about the specific arguments to an Op, but really only about the equivalence class of each argument, represented by the e-class identifier. What's nice about this is that when you replace arguments with their e-class identifiers, more expressions can become equivalent. Moreover, if the operator is the same, and the inputs are equivalent, we can judge the results equivalent. Equating all expressions with the same operator and equivalent inputs is a process called "closure under congruence", and is one of the primary exciting properties of e-graphs.
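
To make congruence a little more concrete: once arguments are replaced by canonical e-class identifiers, two e-nodes with the same operator and equivalent arguments become literally equal. The sketch below is only an illustration. The `canonicalize` and `congruent` helpers, the string operator names, and the `rep` map are all made up for this example.

```rust
use std::collections::BTreeMap;

/// An e-class identifier.
pub type Id = usize;

/// Hypothetical helper: rewrite each argument to its representative e-class.
/// `rep` maps a merged-away identifier to the identifier that replaced it;
/// identifiers absent from `rep` are treated as already canonical.
pub fn canonicalize(args: &[Id], rep: &BTreeMap<Id, Id>) -> Vec<Id> {
    args.iter().map(|id| *rep.get(id).unwrap_or(id)).collect()
}

/// Two e-nodes are congruent if their operators match and their
/// canonicalized arguments are equal.
pub fn congruent(
    op_a: &str,
    args_a: &[Id],
    op_b: &str,
    args_b: &[Id],
    rep: &BTreeMap<Id, Id>,
) -> bool {
    op_a == op_b && canonicalize(args_a, rep) == canonicalize(args_b, rep)
}
```

With an empty `rep`, a `Mul` over `[1, 6]` and a `Mul` over `[1, 4]` are distinct. Once `rep` records that 6 was merged into 4, they are the same e-node, and their e-classes should merge as well.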

Here are the Rust types I'll be using.

/// An e-class identifier.
pub type Id = usize;

/// An operator, and a list of e-class identifiers for arguments.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct ENode<T> {
    pub op: T,
    pub args: Vec<Id>,  // Could/should be a `SmallVec<Id>`.
}

/// An equivalence relation on e-nodes, closed under congruence.
///
/// The `self.nodes` member is the source of truth, but there are additional reverse
/// indexes that allow us to go from e-class identifiers to the e-nodes that use them.
#[derive(Clone, Debug)]
pub struct EGraph<T> {
    /// A map from e-nodes to their e-class.
    ///
    /// This map acts as the source of truth, from which other members are derived.
    /// Whenever this map is modified, we should also update the derived indices.
    pub nodes: BTreeMap<ENode<T>, Id>,

    /// A reverse index from e-class to the e-nodes with the e-class as an argument.
    pub parents: BTreeMap<Id, BTreeSet<ENode<T>>>,
    /// A reverse index from e-class to the e-nodes that belong to (map to) the e-class.
    pub members: BTreeMap<Id, BTreeSet<ENode<T>>>,

    /// A list of pending requests to merge e-class pairs.
    ///
    /// These are not reflected in `self.nodes` until we refresh the e-graph.
    pub pending: Vec<(Id, Id)>,
}

The EGraph struct is centered around a map nodes from e-nodes to e-class identifiers. However, it also has reverse indexes from e-class identifiers back to e-nodes that map to them (members) and e-nodes that use the identifier as an argument (parents). These reverse indexes will be crucial for efficiently updating the e-graph as e-classes merge, and we must rewrite e-nodes to use new e-class identifiers. The pending member is an implementation detail that allows us to perform egg's "deferred merging". You'll see!

The impressive thing, from my point of view, is that this is all you need to implement an e-graph. We will add some more moving parts, but these are the actual type definitions I'm using. Egg adds some more derived indexes, e.g. from Op to Id for e-classes containing each operator, but these are still data derived from nodes.
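
Since `members` and `parents` are pure views over `nodes`, they could in principle be rebuilt from scratch at any time. The sketch below spells out that derivation, which pins down what each index is supposed to contain. The `derive_indexes` function and the `ReverseIndex` alias are hypothetical names. The implementation discussed here maintains the indexes incrementally instead.

```rust
use std::collections::{BTreeMap, BTreeSet};

/// An e-class identifier.
pub type Id = usize;

/// An operator, and a list of e-class identifiers for arguments.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct ENode<T> {
    pub op: T,
    pub args: Vec<Id>,
}

/// Shape shared by both reverse indexes (alias introduced for this sketch).
pub type ReverseIndex<T> = BTreeMap<Id, BTreeSet<ENode<T>>>;

/// Hypothetical helper: derive `(parents, members)` from the forward map.
pub fn derive_indexes<T: Ord + Clone>(
    nodes: &BTreeMap<ENode<T>, Id>,
) -> (ReverseIndex<T>, ReverseIndex<T>) {
    let mut parents: ReverseIndex<T> = BTreeMap::new();
    let mut members: ReverseIndex<T> = BTreeMap::new();
    for (e_node, id) in nodes {
        // The e-node belongs to the e-class it maps to.
        members.entry(*id).or_default().insert(e_node.clone());
        // The e-node is a parent of each e-class it uses as an argument.
        for arg in &e_node.args {
            parents.entry(*arg).or_default().insert(e_node.clone());
        }
    }
    (parents, members)
}
```

A function like this can also serve as a debugging check: after any update, the incrementally maintained indexes should equal the freshly derived ones.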

Working with E-Graphs

What might you do with an e-graph?

Here are the methods I've implemented, with the details hidden for the moment behind unimplemented!(). We'll investigate them in turn, unpacking the complexity in smaller bites.

impl<T: Function + Ord + Clone> EGraph<T> {
    /// Introduces the equivalence of two expressions.
    pub fn equate(&mut self, e0: Vec<T>, e1: Vec<T>) {
        let id0 = self.insert(e0);
        let id1 = self.insert(e1);
        self.merge(id0, id1);
    }

    /// Adds an expression to the e-graph.
    ///
    /// The `expr` argument is expected to be in pre-order.
    pub fn insert(&mut self, mut expr: Vec<T>) -> Id {
        unimplemented!()
    }

    /// Requests a merge of two e-classes.
    pub fn merge(&mut self, a: Id, b: Id) {
        self.pending.push((a, b));
    }

    /// Applies all pending merges, and restores invariants.
    pub fn refresh(&mut self) {
        unimplemented!()
    }
}

There are several other methods you may want, which we'll develop a bit later. For example, how do you get information out of an e-graph? How can you have the e-graph help you figure out which classes you should merge, automatically? Stuff like this.

For the moment, the idea is that you might start from a fresh e-graph, insert some expression you are interested in, and then apply equivalences that match your understanding. Here's an example, taken directly from the egg paper. Ah, their example is better, actually, because they apply general rules rather than hard-wired ones, but it's still a good starting point.

    // (a x 2) / 2
    let eggsample = vec![Op::Div, Op::Mul, Op::Var(0), Op::Lit(2), Op::Lit(2)];
    let mut e_graph: EGraph<Op> = Default::default();
    e_graph.insert(eggsample.clone());
    
    // a x 2 == a << 1
    e_graph.equate(
        vec![Op::Mul, Op::Var(0), Op::Lit(2)],
        vec![Op::Shf, Op::Var(0), Op::Lit(1)],
    );
    // (a x 2) / 2 == a x (2 / 2)
    e_graph.equate(
        eggsample.clone(),
        vec![Op::Mul, Op::Var(0), Op::Div, Op::Lit(2), Op::Lit(2)],
    );
    // 2 / 2 == 1
    e_graph.equate(
        vec![Op::Div, Op::Lit(2), Op::Lit(2)],
        vec![Op::Lit(1)],
    );
    // a x 1 == a
    e_graph.equate(
        vec![Op::Mul, Op::Var(0), Op::Lit(1)],
        vec![Op::Var(0)],
    );
    // Refresh the e-graph to reflect equated terms.
    e_graph.refresh();

    println!("{}", e_graph);

This is a worked example from the egg paper, where the narrative arc is that while swapping out a x 2 for a << 1 feels smart, it is even smarter to keep the x 2 around to cancel the / 2. When you introduce each of these equivalences (egg also finds the equivalences; more on that later), you end up with four equivalence classes:

EClass #0: -> [ Lit(2) ]
EClass #3: -> [ Mul[3, 4], Div[5, 0], Var(0) ]
EClass #4: -> [ Div[0, 0], Lit(1) ]
EClass #5: -> [ Mul[3, 0], Shf[3, 4] ]

This might actually be a good moment to see if you can read an e-graph.

Our a term is Var(0), found in e-class 3. Also in 3 is Div[5, 0], which it turns out is the root of our inserted eggsample expression. You can see this by following e-classes 5 and 0. E-class 5 contains Mul[3, 0], where 3 contains a and 0 contains Lit(2); this is the numerator in our expression. E-class 0 contains Lit(2), which is the denominator of our expression.

There are several other e-nodes in here, like Lit(1) and Div[0, 0], which play a role in going from (a x 2) / 2 to the much simpler a.
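
One way to check a reading like this mechanically is to look up the e-class of an expression: walk it from the back with a stack of e-class identifiers, and find each e-node in the map without inserting anything. The sketch below hard-codes the printed e-graph. The `lookup` and `printed_e_graph` functions, and the tuple encoding of e-nodes, are assumptions made for this example.

```rust
use std::collections::BTreeMap;

pub type Id = usize;

#[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub enum Op { Add, Mul, Div, Shf, Lit(i32), Var(i32) }

impl Op {
    /// Same arities as the `Function` impl, inlined to stay self-contained.
    fn inputs(&self) -> usize {
        match self {
            Op::Add | Op::Mul | Op::Div | Op::Shf => 2,
            Op::Lit(_) | Op::Var(_) => 0,
        }
    }
}

/// The four e-classes printed above, as a map from e-node to e-class.
pub fn printed_e_graph() -> BTreeMap<(Op, Vec<Id>), Id> {
    BTreeMap::from([
        ((Op::Lit(2), vec![]), 0),
        ((Op::Mul, vec![3, 4]), 3),
        ((Op::Div, vec![5, 0]), 3),
        ((Op::Var(0), vec![]), 3),
        ((Op::Div, vec![0, 0]), 4),
        ((Op::Lit(1), vec![]), 4),
        ((Op::Mul, vec![3, 0]), 5),
        ((Op::Shf, vec![3, 4]), 5),
    ])
}

/// Hypothetical read-only lookup: the e-class of a pre-order expression,
/// or `None` if the expression is malformed or not present in the e-graph.
pub fn lookup(nodes: &BTreeMap<(Op, Vec<Id>), Id>, expr: &[Op]) -> Option<Id> {
    let mut ids: Vec<Id> = Vec::new();
    for op in expr.iter().rev() {
        let inputs = op.inputs();
        if ids.len() < inputs {
            return None;
        }
        // Arguments were pushed right-to-left, so reverse the popped tail.
        let args: Vec<Id> = ids.split_off(ids.len() - inputs).into_iter().rev().collect();
        ids.push(*nodes.get(&(op.clone(), args))?);
    }
    if ids.len() == 1 { ids.pop() } else { None }
}
```

Looking up `[Div, Mul, Var(0), Lit(2), Lit(2)]` and `[Var(0)]` in this map lands both in e-class 3, which is the equivalence we were after.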

Inserting into E-Graphs

Insertion of an expression into an e-graph isn't all that hard!

Recall that an expression is a sequence of operators, in pre-order. We would be able to evaluate the expression by starting at the end and working back towards the front. Instead, we are going to "evaluate" the e-class identifier of the terms, looking them up in the e-graph, in the same manner. Starting at the end of the expression, or leaves of the tree if you prefer, we'll treat each operator as if it is an e-node, with e-class id arguments. We'll then look up that e-node, or insert it if it does not exist, and put the resulting e-class identifier onto our evaluation stack.

/// Adds an expression to the e-graph.
///
/// The `expr` argument is expected to be in pre-order.
pub fn insert(&mut self, mut expr: Vec<T>) -> Id {
    // Repeatedly pop the tail of `expr` and the inserted `args`,
    // to find that e-node in `self.nodes` and push that e-class id.
    let mut ids = Vec::new();
    while let Some(op) = expr.pop() {
        let inputs = op.inputs();
        let args = ids[ids.len() - inputs..].iter().cloned().rev();
        let id = self.ensure(ENode::new(op, args));
        ids.truncate(ids.len() - inputs);
        ids.push(id);
    }

    ids.pop().unwrap()
}

There is a helper method, ensure, that is worth looking at. The method first looks for the e-node, and if it doesn't find it, it mints a new e-class identifier and inserts the e-node with that identifier. Insertion amounts to inserting into self.nodes, after first updating the reverse indexes.

/// Ensures that an e-node is present in the e-graph, and returns its e-class id.
///
/// If the e-node is not initially present, it is added and a new e-class is set up.
fn ensure(&mut self, e_node: ENode<T>) -> Id {
    if let Some(id) = self.nodes.get(&e_node) {
        *id
    }
    else {
        let new_id = self.new_id();
        // 1. Add e_node to parents of each element of `args`.
        for e_class in e_node.args.iter() {
            self.parents.entry(*e_class).or_default().insert(e_node.clone());
        }
        // 2. Form a new e-class with `e_node` as its only member.
        self.members.insert(new_id, BTreeSet::from([e_node.clone()]));
        // 3. Add `e_node` to `self.nodes`.
        self.nodes.insert(e_node, new_id);
        new_id
    }
}

The main risk is when adding an e-node: care must be taken to set all invariants correctly, especially around how self.members and self.parents are meant to mirror self.nodes. This code would need to change were any other views and indexes over self.nodes introduced. For completeness, the new_id() function just adds one to the largest e-class identifier we have in a map (they are b-trees, and looking at the maximum key is cheap).
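
Neither `ENode::new` nor `new_id` is shown here. Going by how `insert` calls the first and by the description of the second, they might look something like the following. This is a guess at their shape rather than the actual code, and `new_id` is written as a free function over a `members`-style map to keep the sketch self-contained.

```rust
use std::collections::{BTreeMap, BTreeSet};

/// An e-class identifier.
pub type Id = usize;

/// An operator, and a list of e-class identifiers for arguments.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct ENode<T> {
    pub op: T,
    pub args: Vec<Id>,
}

impl<T> ENode<T> {
    /// Assumed constructor: collects any iterator of ids into `args`.
    pub fn new<I: IntoIterator<Item = Id>>(op: T, args: I) -> Self {
        ENode { op, args: args.into_iter().collect() }
    }
}

/// Assumed shape of `new_id`: one more than the largest e-class identifier
/// in the map, or zero if the map is empty. B-tree maps keep their keys
/// sorted, so the largest key is available without scanning.
pub fn new_id<T>(members: &BTreeMap<Id, BTreeSet<ENode<T>>>) -> Id {
    members.keys().next_back().map_or(0, |id| id + 1)
}
```

Starting from zero on an empty map is at least consistent with the printed example, where Lit(2), the first e-node popped from the tail of eggsample, sits in e-class 0.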

You may have noticed above, but this also gives us our EGraph::equate implementation as well. If you scroll up, it inserts its two arguments, and then enqueues a merge between them. That merge request just goes into a list, where it sits until someone calls EGraph::refresh. Let's dive into that next!

Refreshing E-Graphs

One of the contributions of the egg work is that e-graphs may defer their merging work. This allows them to spend a decent amount of time thinking about what to merge, and to then perform the merges in batch. The batch updates end up more efficient, as a fair bit of redundant work can be eliminated.

What is all the work that needs to be done?

There are three main things that need to happen when merges are processed:

  1. Any merged e-classes must now be part of the same e-class. There can be many merges, so this looks a lot like a graph connectivity problem. Each connected component of merged e-classes becomes one e-class.
  2. E-nodes that reference any merged e-class may need to update their identifiers. Fortunately, all references to an e-class are tracked in the parents index, so these are easy to find.
  3. E-nodes that update identifiers may find they collide with other updated e-nodes when re-inserted into self.nodes. These e-nodes were not necessarily part of an e-class that merged, but their e-classes should now merge (this is good news!). This prompts more merging work that must be completed before the congruence invariant is restored.

We'll structure our approach as a loop that applies these steps in order, repeatedly, until there is no pending merge work to do. We'll go fragment by fragment, to reveal the structure.

/// Applies pending merges and restores invariants. Returns true if any merges occurred.
pub fn refresh(&mut self) -> bool {

    // Record the initial number of e-classes.
    let prior_classes = self.members.len();
    
    // Continue for as long as there are pending merges.
    // Stopping early is not an option, as merges that are produced
    // must be resolved to restore the "congruence" invariant.
    while !self.pending.is_empty() {

That first part is always the easiest to write.

The next step is to identify the connected components of e-classes that must merge. We use union-find for this, but any undirected connectivity algorithm should work. It just turns out I've forgotten how to write anything other than union-find. Also, I forgot how to write union-find, and had to read the Wikipedia page to get it correct.

        // 0. Plan merges between e-classes.
        let mut uf = BTreeMap::new();
        for (a, b) in self.pending.iter().copied() { uf.union(&a, &b); }
        // Groups of e-classes that will each be merged into one e-class.
        let mut new_classes = BTreeMap::new();
        for x in self.pending.drain(..).flat_map(|(a, b)| [a, b]) {
            let root = *uf.find(&x).unwrap();
            new_classes.entry(root).or_insert_with(BTreeSet::new).insert(x);
        }
        // Map for each merged e-class to its representative e-class identifier.
        // Choose the e-class with the most members and parents as the representative.
        let mut new_nodes = BTreeMap::new();
        for (_root, e_class) in new_classes.iter() {
            // Choose the new id to be the e-class with the most combined members and parents.
            // This minimizes the amount of work we'll have to do to refresh the stale e-nodes.
            let new_id = *e_class
                .iter()
                .map(|id| {
                    let members = self.members.get(id).map_or(0, |m| m.len());
                    let parents = self.parents.get(id).map_or(0, |p| p.len());
                    (members + parents, id)
                })
                .max()
                .unwrap()
                .1;
            for id in e_class.iter().copied() {
                new_nodes.insert(id, new_id);
            }
        }
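
The `union` and `find` calls above are not methods that `BTreeMap` provides, so they presumably come from a small extension trait. Here is a minimal sketch of what such a trait could look like, with path compression but no union by rank. The trait name `UnionFind` and its exact signatures are assumptions, chosen to fit the calls above.

```rust
use std::collections::BTreeMap;

/// Assumed extension trait: union-find over a map from element to parent.
pub trait UnionFind<T> {
    /// The representative of `x`, or `None` if `x` has never been seen.
    fn find(&mut self, x: &T) -> Option<&T>;
    /// Merge the sets containing `a` and `b`, adding them if absent.
    fn union(&mut self, a: &T, b: &T);
}

impl<T: Clone + Ord> UnionFind<T> for BTreeMap<T, T> {
    fn find(&mut self, x: &T) -> Option<&T> {
        if !self.contains_key(x) {
            return None;
        }
        // Walk parent pointers until we reach a self-parented root.
        let mut root = x.clone();
        while self[&root] != root {
            root = self[&root].clone();
        }
        // Path compression: point every element on the path at the root.
        let mut cur = x.clone();
        while cur != root {
            let parent = self.get_mut(&cur).expect("element on path to root");
            cur = std::mem::replace(parent, root.clone());
        }
        self.get(&root)
    }

    fn union(&mut self, a: &T, b: &T) {
        // Elements seen for the first time start out as their own root.
        self.entry(a.clone()).or_insert_with(|| a.clone());
        self.entry(b.clone()).or_insert_with(|| b.clone());
        let root_a = self.find(a).cloned().expect("just inserted");
        let root_b = self.find(b).cloned().expect("just inserted");
        if root_a != root_b {
            // Attach one root beneath the other.
            self.insert(root_b, root_a);
        }
    }
}
```

The root that union-find settles on is arbitrary, which is why the code above picks its own representative for each group rather than reusing the root.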

The new_nodes map allows us to correct any merged e-class identifier to its new shared identifier. The new_classes map now has our groups of e-class identifiers to merge. The next step is doing something with those groups, namely writing down all the affected e-nodes, which must update their identifiers.

        // 1. Remove any defunct e-classes, and enqueue their e-nodes for refresh.
        let mut refresh = BTreeSet::new();
        for (_root, class) in new_classes {
            for e_class in class {
                if new_nodes[&e_class] != e_class {
                    if let Some(members) = self.members.remove(&e_class) {
                        refresh.extend(members);
                    }
                    if let Some(parents) = self.parents.remove(&e_class) {
                        refresh.extend(parents);
                    }
                }
            }
        }

We totally removed the merged e-classes, and only wrote down the e-nodes that were present in them. We will rebuild the merged class as part of updating the identifiers; recall that the information is a view over self.nodes, and we can rebuild it.

The next step, last in each iteration, is to update all of the e-nodes. This isn't too hard, conceptually. We remove all references to the old e-nodes, update them, and then insert all references to the new e-nodes. While re-inserting into self.nodes, we may find the node already present (!!), which requires that we schedule a merge between those two e-classes.

        // 2. Refresh each stale e-node.
        for mut e_node in refresh {
            let mut id = self.remove(&e_node).unwrap();
            // Update the e-classes referenced by the e-node.
            if let Some(new) = new_nodes.get(&id) {
                id = *new;
            }
            for arg in e_node.args.iter_mut() {
                if let Some(id) = new_nodes.get(arg) {
                    *arg = *id;
                }
            }
            // Introduce evidence of the refreshed e-node.
            self.members.entry(id).or_default().insert(e_node.clone());
            for arg in e_node.args.iter() {
                self.parents.entry(*arg).or_default().insert(e_node.clone());
            }
            // Re-introduce the e-node to the forward map, and enqueue a merge if necessary.
            if let Some(other) = self.nodes.insert(e_node, id) {
                if other != id {
                    self.pending.push((other, id));
                }
            }
        }

And with that, we are done with the loop body! The only thing left to do is go around the loop until we run out of self.pending. As we only ever merge e-classes, and always merge at least one e-class, the loop must eventually terminate.

    }
    // If e-classes have vanished, merges occurred.
    self.members.len() < prior_classes
}

That's the whole thing. E-graphs accomplished!

Searching E-Graphs

As interesting as it is to manually introduce equivalences like (a x 2) / 2 == a x (2 / 2), the e-graph community has concluded that humans should not be involved in effecting these rules. The reason the rule is appropriate is because of the more general (a x b) / c == a x (b / c), and it happened that b = c = 2 in our case. But how do we start from general rules and automatically find application targets?

It turns out this is a neat instance of graph motif finding. Our equality up above is looking for matches of the form [Div, Mul, v0, v1, v2], where the v terms are variables we would like to bind. Our e-graph is a giant labeled graph, where e-classes reference e-nodes that reference operators and other e-classes. There are a lot of interesting algorithms for finding motifs in graphs; you might remember that I used to do this a long time ago.

The tippy-top best ways to do this are various forms of worst-case optimal join, one of the best things to come out of the databases community in the past decade or so. These are approaches to implementing joins (which, it turns out, describe graph motifs well) that avoid massively exploding the data only to find no results. They have a slightly different flavor than most join algorithms you know about, and I won't be able to explain them fully here.

But we can talk through a bad implementation of pattern matching!

When given a pattern like [Div, Mul, v0, v1, v2] one approach is to bind the variables in turn, and cease exploring some bindings when you find a conflict. For example, we could try all values of v0, and bail out on any values that couldn't possibly be solutions. For our example, using the e-class numbers printed earlier, all e-classes except 3 would be failures, because none of the other e-classes appear as the first argument of a Mul. Continuing with 3, we would find that both e-classes 0 and 4 participate in a (Mul, [3, _]) e-node. Continuing with both [3, 0] and [3, 4] we would find that only e-class 0 divides such a term. This gives us the only binding, [3, 0, 0], which corresponds to (a x 2) / 2.
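
For this one pattern, the search can be sketched as a pair of nested loops over the e-nodes: bind v0 and v1 from each Mul e-node, then keep only the bindings for which some Div e-node takes that Mul's e-class as its first argument. The sketch below hard-codes the e-graph printed earlier and uses its e-class numbers. The `match_div_mul` function and the tuple encoding of e-nodes are made up for this example. A general matcher would take the pattern as data and use the reverse indexes rather than scanning.

```rust
use std::collections::BTreeMap;

pub type Id = usize;

#[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub enum Op { Add, Mul, Div, Shf, Lit(i32), Var(i32) }

/// The e-graph printed earlier, as a map from e-node to e-class.
pub fn printed_e_graph() -> BTreeMap<(Op, Vec<Id>), Id> {
    BTreeMap::from([
        ((Op::Lit(2), vec![]), 0),
        ((Op::Mul, vec![3, 4]), 3),
        ((Op::Div, vec![5, 0]), 3),
        ((Op::Var(0), vec![]), 3),
        ((Op::Div, vec![0, 0]), 4),
        ((Op::Lit(1), vec![]), 4),
        ((Op::Mul, vec![3, 0]), 5),
        ((Op::Shf, vec![3, 4]), 5),
    ])
}

/// Hypothetical matcher for the single pattern `[Div, Mul, v0, v1, v2]`.
/// Returns every binding `[v0, v1, v2]` of pattern variables to e-classes.
pub fn match_div_mul(nodes: &BTreeMap<(Op, Vec<Id>), Id>) -> Vec<[Id; 3]> {
    let mut bindings = Vec::new();
    for ((op, args), mul_class) in nodes {
        if *op != Op::Mul {
            continue;
        }
        // Each Mul e-node proposes a binding for v0 and v1.
        let (v0, v1) = (args[0], args[1]);
        for ((op, args), _div_class) in nodes {
            // A Div whose numerator is the Mul's e-class completes the match.
            if *op == Op::Div && args[0] == *mul_class {
                bindings.push([v0, v1, args[1]]);
            }
        }
    }
    bindings
}
```

On the printed e-graph this finds a single binding, with v0 in e-class 3 and both v1 and v2 in e-class 0.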

I implemented .. a worse version of this.

There are many other ways to propose and prune extensions. For example, if you want to do constant folding, you only need to start in e-classes that contain literals, and look at their parents to find e-nodes that could possibly be evaluated. Egg maintains a map from operators to e-classes so that they can quickly investigate e-classes that might match their operators.
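
The constant-folding strategy can be sketched against `members`-style and `parents`-style indexes: find the e-classes that contain a literal, then collect the parents of those classes whose arguments are all such classes. The `foldable` function, the `ReverseIndex` alias, and the `is_literal` callback are hypothetical names for this sketch, which only finds candidates and does not evaluate them.

```rust
use std::collections::{BTreeMap, BTreeSet};

pub type Id = usize;

#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct ENode<T> {
    pub op: T,
    pub args: Vec<Id>,
}

/// Shape shared by both reverse indexes (alias introduced for this sketch).
pub type ReverseIndex<T> = BTreeMap<Id, BTreeSet<ENode<T>>>;

/// Hypothetical helper: e-nodes whose arguments are all e-classes that
/// contain a literal, and which could therefore be evaluated.
pub fn foldable<T: Ord + Clone>(
    members: &ReverseIndex<T>,
    parents: &ReverseIndex<T>,
    is_literal: impl Fn(&T) -> bool,
) -> BTreeSet<ENode<T>> {
    // E-classes with at least one literal member.
    let literal_classes: BTreeSet<Id> = members
        .iter()
        .filter(|(_, nodes)| nodes.iter().any(|n| is_literal(&n.op)))
        .map(|(id, _)| *id)
        .collect();
    // Parents of those classes, keeping only e-nodes whose every argument
    // is itself a literal-containing e-class.
    literal_classes
        .iter()
        .filter_map(|id| parents.get(id))
        .flatten()
        .filter(|n| n.args.iter().all(|a| literal_classes.contains(a)))
        .cloned()
        .collect()
}
```

Starting from the literal classes means the work is proportional to the parents of those classes, rather than to the size of the whole e-graph.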

More generally, worst-case optimal joins have a class of algorithms based on proposers and validators. For each variable v, any number of constraints may exist on it, and you pick the constraint that can propose the fewest candidates. Each other constraint validates each candidate, and any that survive all validations are accepted.
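
Here is a toy version of the proposer and validator idea for a single variable, where each constraint is just the set of e-classes it permits: the smallest set proposes candidates, and every other set validates them. Real worst-case optimal joins use indexes that can count and test candidates without materializing sets. The `extend` function and the set-based constraints are simplifications for illustration.

```rust
use std::collections::BTreeSet;

pub type Id = usize;

/// Hypothetical single-variable extension step: the candidates accepted by
/// every constraint, enumerated from the most selective constraint.
pub fn extend(constraints: &[BTreeSet<Id>]) -> Vec<Id> {
    // The constraint with the fewest candidates acts as the proposer.
    let proposer = match constraints.iter().min_by_key(|c| c.len()) {
        Some(proposer) => proposer,
        // With no constraints there is nothing to propose from.
        None => return Vec::new(),
    };
    proposer
        .iter()
        .copied()
        // Every constraint must accept the candidate; the proposer does so trivially.
        .filter(|id| constraints.iter().all(|c| c.contains(id)))
        .collect()
}
```

The work done is proportional to the size of the smallest constraint times the number of constraints, which is the property that keeps these joins from exploding on large but irrelevant inputs.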

There is an art to getting the data into the right shape so that these proposals are fast. This is exactly where treating EGraph::nodes as a source of truth, and various other collections as derived views, pays off. The EGraph::members collection is a view over nodes, set up to provide fast random access. Similarly, egg's classes_by_op is an index by operator to find classes that reference them. Any number of other indexes are potentially available!