Skip to content

Commit

Permalink
Cranelift: Use a fixpoint loop to compute the best value for each eclass
Browse files Browse the repository at this point in the history
Fixes #7857
  • Loading branch information
fitzgen committed Feb 2, 2024
1 parent 32d5ff4 commit 7223562
Show file tree
Hide file tree
Showing 4 changed files with 189 additions and 55 deletions.
1 change: 1 addition & 0 deletions cranelift/codegen/src/egraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -701,4 +701,5 @@ pub(crate) struct Stats {
pub(crate) elaborate_func: u64,
pub(crate) elaborate_func_pre_insts: u64,
pub(crate) elaborate_func_post_insts: u64,
pub(crate) elaborate_best_cost_fixpoint_iters: u64,
}
79 changes: 61 additions & 18 deletions cranelift/codegen/src/egraph/cost.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ impl Cost {
const DEPTH_BITS: u8 = 8;
const DEPTH_MASK: u32 = (1 << Self::DEPTH_BITS) - 1;
const OP_COST_MASK: u32 = !Self::DEPTH_MASK;
const MAX_OP_COST: u32 = (Self::OP_COST_MASK >> Self::DEPTH_BITS) - 1;
const MAX_OP_COST: u32 = Self::OP_COST_MASK >> Self::DEPTH_BITS;

pub(crate) fn infinity() -> Cost {
// 2^32 - 1 is, uh, pretty close to infinite... (we use `Cost`
Expand All @@ -86,14 +86,24 @@ impl Cost {
Cost(0)
}

/// Construct a new finite cost from the given parts.
pub(crate) fn is_infinite(&self) -> bool {
*self == Cost::infinity()
}

pub(crate) fn is_finite(&self) -> bool {
!self.is_infinite()
}

/// Construct a new `Cost` from the given parts.
///
/// The opcode cost is clamped to the maximum value representable.
fn new_finite(opcode_cost: u32, depth: u8) -> Cost {
let opcode_cost = std::cmp::min(opcode_cost, Self::MAX_OP_COST);
let cost = Cost((opcode_cost << Self::DEPTH_BITS) | u32::from(depth));
debug_assert_ne!(cost, Cost::infinity());
cost
/// If the opcode cost is greater than or equal to the maximum representable
/// opcode cost, then the resulting `Cost` saturates to infinity.
fn new(opcode_cost: u32, depth: u8) -> Cost {
if opcode_cost >= Self::MAX_OP_COST {
Self::infinity()
} else {
Cost(opcode_cost << Self::DEPTH_BITS | u32::from(depth))
}
}

fn depth(&self) -> u8 {
Expand All @@ -111,7 +121,7 @@ impl Cost {
/// that satisfies `inst_predicates::is_pure_for_egraph()`.
pub(crate) fn of_pure_op(op: Opcode, operand_costs: impl IntoIterator<Item = Self>) -> Self {
let c = pure_op_cost(op) + operand_costs.into_iter().sum();
Cost::new_finite(c.op_cost(), c.depth().saturating_add(1))
Cost::new(c.op_cost(), c.depth().saturating_add(1))
}
}

Expand All @@ -131,12 +141,9 @@ impl std::ops::Add<Cost> for Cost {
type Output = Cost;

fn add(self, other: Cost) -> Cost {
let op_cost = std::cmp::min(
self.op_cost().saturating_add(other.op_cost()),
Self::MAX_OP_COST,
);
let op_cost = self.op_cost().saturating_add(other.op_cost());
let depth = std::cmp::max(self.depth(), other.depth());
Cost::new_finite(op_cost, depth)
Cost::new(op_cost, depth)
}
}

Expand All @@ -147,11 +154,11 @@ impl std::ops::Add<Cost> for Cost {
fn pure_op_cost(op: Opcode) -> Cost {
match op {
// Constants.
Opcode::Iconst | Opcode::F32const | Opcode::F64const => Cost::new_finite(1, 0),
Opcode::Iconst | Opcode::F32const | Opcode::F64const => Cost::new(1, 0),

// Extends/reduces.
Opcode::Uextend | Opcode::Sextend | Opcode::Ireduce | Opcode::Iconcat | Opcode::Isplit => {
Cost::new_finite(2, 0)
Cost::new(2, 0)
}

// "Simple" arithmetic.
Expand All @@ -163,9 +170,45 @@ fn pure_op_cost(op: Opcode) -> Cost {
| Opcode::Bnot
| Opcode::Ishl
| Opcode::Ushr
| Opcode::Sshr => Cost::new_finite(3, 0),
| Opcode::Sshr => Cost::new(3, 0),

// Everything else (pure.)
_ => Cost::new_finite(4, 0),
_ => Cost::new(4, 0),
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn add_cost() {
let a = Cost::new(5, 2);
let b = Cost::new(37, 3);
assert_eq!(a + b, Cost::new(42, 3));
}

#[test]
fn add_infinity() {
let a = Cost::new(5, 2);
let b = Cost::infinity();
assert_eq!(a + b, Cost::infinity());
}

#[test]
fn op_cost_saturates_to_infinity() {
let a = Cost::new(Cost::MAX_OP_COST - 10, 2);
let b = Cost::new(11, 2);
assert_eq!(a + b, Cost::infinity());
}

#[test]
fn depth_saturates_to_max_depth() {
let a = Cost::new(10, u8::MAX);
let b = Cost::new(10, 1);
assert_eq!(
Cost::of_pure_op(Opcode::Iconst, [a, b]),
Cost::new(21, u8::MAX)
);
}
}
127 changes: 90 additions & 37 deletions cranelift/codegen/src/egraph/elaborate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use super::Stats;
use crate::dominator_tree::DominatorTree;
use crate::fx::{FxHashMap, FxHashSet};
use crate::hash_map::Entry as HashEntry;
use crate::inst_predicates::is_pure_for_egraph;
use crate::ir::{Block, Function, Inst, Value, ValueDef};
use crate::loop_analysis::{Loop, LoopAnalysis};
use crate::scoped_hash_map::ScopedHashMap;
Expand Down Expand Up @@ -216,46 +217,92 @@ impl<'a> Elaborator<'a> {

fn compute_best_values(&mut self) {
let best = &mut self.value_to_best_value;
for (value, def) in self.func.dfg.values_and_defs() {
trace!("computing best for value {:?} def {:?}", value, def);
match def {
ValueDef::Union(x, y) => {
// Pick the best of the two options based on
// min-cost. This works because each element of `best`
// is a `(cost, value)` tuple; `cost` comes first so
// the natural comparison works based on cost, and
// breaks ties based on value number.
trace!(" -> best of {:?} and {:?}", best[x], best[y]);
best[value] = std::cmp::min(best[x], best[y]);
trace!(" -> {:?}", best[value]);
}
ValueDef::Param(_, _) => {
best[value] = BestEntry(Cost::zero(), value);

// Do a fixpoint loop to compute the best value for each eclass.
//
// The maximum number of iterations is the length of the longest chain
// of `vNN -> vMM` edges in the dataflow graph where `NN < MM`, so this
// is *technically* quadratic, but `cranelift-frontend` won't construct
// any such edges. NaN canonicalization will introduce some of these
// edges, but they are chains of only two or three edges. So in
// practice, we *never* do more than a handful of iterations here unless
// (a) we parsed the CLIF from text and the text was funkily numbered,
// which we don't really care about, or (b) the CLIF producer did
// something weird, in which case it is their responsibility to stop
// doing that.
trace!("Entering fixpoint loop to compute the best values for each eclass");
let mut keep_going = true;
while keep_going {
keep_going = false;
trace!(
"fixpoint iteration {}",
self.stats.elaborate_best_cost_fixpoint_iters
);
self.stats.elaborate_best_cost_fixpoint_iters += 1;

for (value, def) in self.func.dfg.values_and_defs() {
// If the cost of this value is finite, then we've already found
// its final cost.
if best[value].0.is_finite() {
continue;
}
// If the Inst is inserted into the layout (which is,
// at this point, only the side-effecting skeleton),
// then it must be computed and thus we give it zero
// cost.
ValueDef::Result(inst, _) => {
if let Some(_) = self.func.layout.inst_block(inst) {
best[value] = BestEntry(Cost::zero(), value);
} else {
trace!(" -> value {}: result, computing cost", value);
let inst_data = &self.func.dfg.insts[inst];
// N.B.: at this point we know that the opcode is
// pure, so `pure_op_cost`'s precondition is
// satisfied.
let cost = Cost::of_pure_op(
inst_data.opcode(),
self.func.dfg.inst_values(inst).map(|value| best[value].0),

trace!("computing best for value {:?} def {:?}", value, def);
let orig_best_value = best[value];

match def {
ValueDef::Union(x, y) => {
// Pick the best of the two options based on
// min-cost. This works because each element of `best`
// is a `(cost, value)` tuple; `cost` comes first so
// the natural comparison works based on cost, and
// breaks ties based on value number.
best[value] = std::cmp::min(best[x], best[y]);
trace!(
" -> best of union({:?}, {:?}) = {:?}",
best[x],
best[y],
best[value]
);
best[value] = BestEntry(cost, value);
}
}
};
debug_assert_ne!(best[value].0, Cost::infinity());
debug_assert_ne!(best[value].1, Value::reserved_value());
trace!("best for eclass {:?}: {:?}", value, best[value]);
ValueDef::Param(_, _) => {
best[value] = BestEntry(Cost::zero(), value);
}
// If the Inst is inserted into the layout (which is,
// at this point, only the side-effecting skeleton),
// then it must be computed and thus we give it zero
// cost.
ValueDef::Result(inst, _) => {
if let Some(_) = self.func.layout.inst_block(inst) {
best[value] = BestEntry(Cost::zero(), value);
} else {
let inst_data = &self.func.dfg.insts[inst];
// N.B.: at this point we know that the opcode is
// pure, so `pure_op_cost`'s precondition is
// satisfied.
let cost = Cost::of_pure_op(
inst_data.opcode(),
self.func.dfg.inst_values(inst).map(|value| best[value].0),
);
best[value] = BestEntry(cost, value);
trace!(" -> cost of value {} = {:?}", value, cost);
}
}
};

// Keep on iterating the fixpoint loop while we are finding new
// best values.
keep_going |= orig_best_value != best[value];
}
}

if cfg!(any(feature = "trace-log", debug_assertions)) {
trace!("finished fixpoint loop to compute best value for each eclass");
for value in self.func.dfg.values() {
debug_assert_ne!(best[value].0, Cost::infinity());
debug_assert_ne!(best[value].1, Value::reserved_value());
trace!("-> best for eclass {:?}: {:?}", value, best[value]);
}
}
}

Expand Down Expand Up @@ -606,7 +653,13 @@ impl<'a> Elaborator<'a> {
}
inst
};

// Place the inst just before `before`.
debug_assert!(
is_pure_for_egraph(self.func, inst),
"something has gone very wrong if we are elaborating effectful \
instructions, they should have remained in the skeleton"
);
self.func.layout.insert_inst(inst, before);

// Update the inst's arguments.
Expand Down
37 changes: 37 additions & 0 deletions cranelift/filetests/filetests/egraph/issue-7875.clif
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
test optimize
set enable_verifier=true
set opt_level=speed
target x86_64

;; This test case should optimize just fine, and should definitely not produce
;; CLIF that has verifier errors like
;;
;; error: inst10 (v12 = select.f32 v11, v4, v10 ; v11 = 1): uses value arg
;; from non-dominating block4

function %foo() {
block0:
v0 = iconst.i64 0
v2 = f32const 0.0
v9 = f32const 0.0
v20 = fneg v2
v18 = fcmp eq v20, v20
v4 = select v18, v2, v20
v8 = iconst.i32 0
v11 = iconst.i32 1
brif v0, block2, block3

block2:
brif.i32 v8, block4(v2), block4(v9)

block4(v10: f32):
v12 = select.f32 v11, v4, v10
v13 = bitcast.i32 v12
store v13, v0
trap user0

block3:
v15 = bitcast.i32 v4
store v15, v0
return
}

0 comments on commit 7223562

Please sign in to comment.