Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cranelift: Break op cost ties with expression depth in egraphs #7456

Merged
merged 5 commits into from
Nov 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 92 additions & 15 deletions cranelift/codegen/src/egraph/cost.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,52 @@ use crate::ir::Opcode;
/// `finite()` method.) An infinite cost is used to represent a value
/// that cannot be computed, or otherwise serve as a sentinel when
/// performing search for the lowest-cost representation of a value.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
#[derive(Clone, Copy, PartialEq, Eq)]
pub(crate) struct Cost(u32);

impl core::fmt::Debug for Cost {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
if *self == Cost::infinity() {
write!(f, "Cost::Infinite")
} else {
f.debug_struct("Cost::Finite")
.field("op_cost", &self.op_cost())
.field("depth", &self.depth())
.finish()
}
}
}
fitzgen marked this conversation as resolved.
Show resolved Hide resolved

impl Ord for Cost {
#[inline]
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
// We make sure that the high bits are the op cost and the low bits are
// the depth. This means that we can use normal integer comparison to
// order by op cost and then depth.
//
// We want to break op cost ties with depth (rather than the other way
// around). When the op cost is the same, we prefer shallow and wide
// expressions to narrow and deep expressions and breaking ties with
// `depth` gives us that. For example, `(a + b) + (c + d)` is preferred
// to `((a + b) + c) + d`. This is beneficial because it exposes more
// instruction-level parallelism and shortens live ranges.
self.0.cmp(&other.0)
}
}

impl PartialOrd for Cost {
#[inline]
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}

impl Cost {
const DEPTH_BITS: u8 = 8;
const DEPTH_MASK: u32 = (1 << Self::DEPTH_BITS) - 1;
const OP_COST_MASK: u32 = !Self::DEPTH_MASK;
const MAX_OP_COST: u32 = (Self::OP_COST_MASK >> Self::DEPTH_BITS) - 1;

pub(crate) fn infinity() -> Cost {
// 2^32 - 1 is, uh, pretty close to infinite... (we use `Cost`
// only for heuristics and always saturate so this suffices!)
Expand All @@ -43,11 +86,38 @@ impl Cost {
Cost(0)
}

/// Clamp this cost at a "finite" value. Can be used in
/// conjunction with saturating ops to avoid saturating into
/// `infinity()`.
fn finite(self) -> Cost {
Cost(std::cmp::min(u32::MAX - 1, self.0))
/// Construct a new finite cost from the given parts.
///
/// The opcode cost is clamped to the maximum value representable.
fn new_finite(opcode_cost: u32, depth: u8) -> Cost {
let opcode_cost = std::cmp::min(opcode_cost, Self::MAX_OP_COST);
let cost = Cost((opcode_cost << Self::DEPTH_BITS) | u32::from(depth));
debug_assert_ne!(cost, Cost::infinity());
cost
}

fn depth(&self) -> u8 {
let depth = self.0 & Self::DEPTH_MASK;
u8::try_from(depth).unwrap()
}

fn op_cost(&self) -> u32 {
(self.0 & Self::OP_COST_MASK) >> Self::DEPTH_BITS
fitzgen marked this conversation as resolved.
Show resolved Hide resolved
}

/// Compute the cost of the operation and its given operands.
///
/// Caller is responsible for checking that the opcode came from an instruction
/// that satisfies `inst_predicates::is_pure_for_egraph()`.
pub(crate) fn of_pure_op(op: Opcode, operand_costs: impl IntoIterator<Item = Self>) -> Self {
let c = pure_op_cost(op) + operand_costs.into_iter().sum();
Cost::new_finite(c.op_cost(), c.depth().saturating_add(1))
}
}

impl std::iter::Sum<Cost> for Cost {
fn sum<I: Iterator<Item = Cost>>(iter: I) -> Self {
iter.fold(Self::zero(), |a, b| a + b)
}
}

Expand All @@ -59,22 +129,29 @@ impl std::default::Default for Cost {

impl std::ops::Add<Cost> for Cost {
type Output = Cost;

fn add(self, other: Cost) -> Cost {
Cost(self.0.saturating_add(other.0)).finite()
let op_cost = std::cmp::min(
self.op_cost().saturating_add(other.op_cost()),
Self::MAX_OP_COST,
);
let depth = std::cmp::max(self.depth(), other.depth());
Cost::new_finite(op_cost, depth)
}
}

/// Return the cost of a *pure* opcode. Caller is responsible for
/// checking that the opcode came from an instruction that satisfies
/// `inst_predicates::is_pure_for_egraph()`.
pub(crate) fn pure_op_cost(op: Opcode) -> Cost {
/// Return the cost of a *pure* opcode.
///
/// Caller is responsible for checking that the opcode came from an instruction
/// that satisfies `inst_predicates::is_pure_for_egraph()`.
fn pure_op_cost(op: Opcode) -> Cost {
match op {
// Constants.
Opcode::Iconst | Opcode::F32const | Opcode::F64const => Cost(1),
Opcode::Iconst | Opcode::F32const | Opcode::F64const => Cost::new_finite(1, 0),

// Extends/reduces.
Opcode::Uextend | Opcode::Sextend | Opcode::Ireduce | Opcode::Iconcat | Opcode::Isplit => {
Cost(2)
Cost::new_finite(2, 0)
}

// "Simple" arithmetic.
Expand All @@ -86,9 +163,9 @@ pub(crate) fn pure_op_cost(op: Opcode) -> Cost {
| Opcode::Bnot
| Opcode::Ishl
| Opcode::Ushr
| Opcode::Sshr => Cost(3),
| Opcode::Sshr => Cost::new_finite(3, 0),

// Everything else (pure.)
_ => Cost(4),
_ => Cost::new_finite(4, 0),
}
}
13 changes: 5 additions & 8 deletions cranelift/codegen/src/egraph/elaborate.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! Elaboration phase: lowers EGraph back to sequences of operations
//! in CFG nodes.

use super::cost::{pure_op_cost, Cost};
use super::cost::Cost;
use super::domtree::DomTreeWithChildren;
use super::Stats;
use crate::dominator_tree::DominatorTree;
Expand Down Expand Up @@ -245,13 +245,10 @@ impl<'a> Elaborator<'a> {
// N.B.: at this point we know that the opcode is
// pure, so `pure_op_cost`'s precondition is
// satisfied.
let cost = self
.func
.dfg
.inst_values(inst)
.fold(pure_op_cost(inst_data.opcode()), |cost, value| {
cost + best[value].0
});
let cost = Cost::of_pure_op(
inst_data.opcode(),
self.func.dfg.inst_values(inst).map(|value| best[value].0),
);
best[value] = BestEntry(cost, value);
}
}
Expand Down