Skip to content

Commit

Permalink
mir-opt: Merge all branch BBs into a single copy statement
Browse files Browse the repository at this point in the history
  • Loading branch information
DianQK committed Oct 5, 2024
1 parent b2cedc4 commit dddf046
Show file tree
Hide file tree
Showing 13 changed files with 602 additions and 61 deletions.
2 changes: 2 additions & 0 deletions compiler/rustc_mir_transform/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ mod lower_intrinsics;
mod lower_slice_len;
mod match_branches;
mod mentioned_items;
mod merge_branches;
mod multiple_return_terminators;
mod nrvo;
mod post_drop_elaboration;
Expand Down Expand Up @@ -609,6 +610,7 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
&dead_store_elimination::DeadStoreElimination::Initial,
&gvn::GVN,
&simplify::SimplifyLocals::AfterGVN,
&merge_branches::MergeBranchSimplification,
&dataflow_const_prop::DataflowConstProp,
&single_use_consts::SingleUseConsts,
&o1(simplify_branches::SimplifyConstCondition::AfterConstProp),
Expand Down
279 changes: 279 additions & 0 deletions compiler/rustc_mir_transform/src/merge_branches.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,279 @@
//! This pass attempts to merge all branches to eliminate switch terminator.
//! Ideally, we could combine it with `MatchBranchSimplification`, as these two passes
//! match and merge statements with different patterns. Given the compile time and
//! code complexity, we have not merged them into a more general pass for now.
use rustc_const_eval::const_eval::mk_eval_cx_for_const_val;
use rustc_index::bit_set::BitSet;
use rustc_middle::mir::patch::MirPatch;
use rustc_middle::mir::*;
use rustc_middle::ty;
use rustc_middle::ty::util::Discr;
use rustc_middle::ty::{ParamEnv, TyCtxt};
use rustc_mir_dataflow::impls::{MaybeTransitiveLiveLocals, borrowed_locals};
use rustc_mir_dataflow::{Analysis, ResultsCursor};

pub(super) struct MergeBranchSimplification;

impl<'tcx> crate::MirPass<'tcx> for MergeBranchSimplification {
fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
sess.mir_opt_level() >= 2
}

fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
let def_id = body.source.def_id();
let param_env = tcx.param_env_reveal_all_normalized(def_id);

let borrowed_locals = borrowed_locals(body);
let mut maybe_live: ResultsCursor<'_, '_, MaybeTransitiveLiveLocals<'_>> =
MaybeTransitiveLiveLocals::new(&borrowed_locals)
.into_engine(tcx, body)
.iterate_to_fixpoint()
.into_results_cursor(body);
for i in 0..body.basic_blocks.len() {
let bbs = &*body.basic_blocks;
let switch_bb_idx = BasicBlock::from_usize(i);
let Some((switch_discr, targets)) = bbs[switch_bb_idx].terminator().kind.as_switch()
else {
continue;
};
// Check if the copy source matches the following pattern.
// _2 = discriminant(*_1); // "*_1" is the expected the copy source.
// switchInt(move _2) -> [0: bb3, 1: bb2, otherwise: bb1];
let Some(&Statement {
kind: StatementKind::Assign(box (discr_place, Rvalue::Discriminant(src_place))),
..
}) = bbs[switch_bb_idx].statements.last()
else {
continue;
};
if switch_discr.place() != Some(discr_place) {
continue;
}
let src_ty = src_place.ty(body.local_decls(), tcx);
if !src_ty.ty.is_enum() || src_ty.variant_index.is_some() {
continue;
}
// We require that the possible target blocks all be distinct.
if !targets.is_distinct() {
continue;
}
if !bbs[targets.otherwise()].is_empty_unreachable() {
continue;
}
// Check that destinations are identical, and if not, then don't optimize this block.
let mut targets_iter = targets.iter();
let first_terminator_kind = &bbs[targets_iter.next().unwrap().1].terminator().kind;
if !targets_iter.all(|(_, other_target)| {
first_terminator_kind == &bbs[other_target].terminator().kind
}) {
continue;
}
if let Some(dest_place) = can_simplify_to_copy(
tcx,
param_env,
body,
targets,
src_place,
src_ty,
&borrowed_locals,
&mut maybe_live,
) {
let statement_index = bbs[switch_bb_idx].statements.len();
let parent_end = Location { block: switch_bb_idx, statement_index };
let mut patch = MirPatch::new(body);
patch.add_assign(parent_end, dest_place, Rvalue::Use(Operand::Copy(src_place)));
patch.patch_terminator(switch_bb_idx, first_terminator_kind.clone());
patch.apply(body);
super::simplify::remove_dead_blocks(body);
// After modifying the MIR, the result of `MaybeTransitiveLiveLocals` may become invalid,
// keeping it simple to process only once.
break;
}
}
}
}

/// The GVN simplified
/// ```ignore (syntax-highlighting-only)
/// match a {
/// Foo::A(x) => Foo::A(*x),
/// Foo::B => Foo::B
/// }
/// ```
/// to
/// ```ignore (syntax-highlighting-only)
/// match a {
/// Foo::A(_x) => a, // copy a
/// Foo::B => Foo::B
/// }
/// ```
/// This function answers whether it can be simplified to a copy statement
/// by returning the copy destination.
fn can_simplify_to_copy<'tcx>(
tcx: TyCtxt<'tcx>,
param_env: ParamEnv<'tcx>,
body: &Body<'tcx>,
targets: &SwitchTargets,
src_place: Place<'tcx>,
src_ty: tcx::PlaceTy<'tcx>,
borrowed_locals: &BitSet<Local>,
maybe_live: &mut ResultsCursor<'_, 'tcx, MaybeTransitiveLiveLocals<'_>>,
) -> Option<Place<'tcx>> {
let mut targets_iter = targets.iter();
let dest_place = targets_iter.next().and_then(|(index, target)| {
find_copy_assign(
tcx,
param_env,
body,
index,
target,
src_place,
src_ty,
borrowed_locals,
maybe_live,
)
})?;
let dest_ty = dest_place.ty(body.local_decls(), tcx);
if dest_ty.ty != src_ty.ty || dest_ty.variant_index.is_some() {
return None;
}
if targets_iter.any(|(other_index, other_target)| {
Some(dest_place)
!= find_copy_assign(
tcx,
param_env,
body,
other_index,
other_target,
src_place,
src_ty,
borrowed_locals,
maybe_live,
)
}) {
return None;
}
Some(dest_place)
}

fn find_copy_assign<'tcx>(
tcx: TyCtxt<'tcx>,
param_env: ParamEnv<'tcx>,
body: &Body<'tcx>,
index: u128,
target_block: BasicBlock,
src_place: Place<'tcx>,
src_ty: tcx::PlaceTy<'tcx>,
borrowed_locals: &BitSet<Local>,
maybe_live: &mut ResultsCursor<'_, 'tcx, MaybeTransitiveLiveLocals<'_>>,
) -> Option<Place<'tcx>> {
let statements = &body.basic_blocks[target_block].statements;
if statements.is_empty() {
return None;
}
let assign_stmt = if statements.len() == 1 {
0
} else {
// We are matching a statement copied from the source to the same destination from the BB,
// and dead statements can be ignored.
// We can treat the rvalue is the source if it's equal to the source.
let mut lived_stmts: BitSet<usize> = BitSet::new_filled(statements.len());
let mut expected_assign_stmt = None;
for (statement_index, statement) in statements.iter().enumerate().rev() {
let loc = Location { block: target_block, statement_index };
if let StatementKind::Assign(assign) = &statement.kind {
if !assign.1.is_safe_to_remove() {
return None;
}
}
match &statement.kind {
StatementKind::Assign(box (dest_place, _))
| StatementKind::SetDiscriminant { place: box dest_place, .. }
| StatementKind::Deinit(box dest_place) => {
if dest_place.is_indirect() || borrowed_locals.contains(dest_place.local) {
return None;
}
maybe_live.seek_before_primary_effect(loc);
if !maybe_live.get().contains(dest_place.local) {
lived_stmts.remove(statement_index);
} else if matches!(statement.kind, StatementKind::Assign(_))
&& expected_assign_stmt.is_none()
{
// There is only one statement that cannot be ignored
// that can be used as an expected copy statement.
expected_assign_stmt = Some(statement_index);
lived_stmts.remove(statement_index);
} else {
return None;
}
}
StatementKind::StorageLive(_)
| StatementKind::StorageDead(_)
| StatementKind::Nop => (),

StatementKind::Retag(_, _)
| StatementKind::Coverage(_)
| StatementKind::Intrinsic(_)
| StatementKind::ConstEvalCounter
| StatementKind::PlaceMention(_)
| StatementKind::FakeRead(_)
| StatementKind::AscribeUserType(_, _) => {
return None;
}
}
}
let expected_assign = expected_assign_stmt?;
// We can ignore the paired StorageLive and StorageDead.
let mut storage_live_locals: BitSet<Local> = BitSet::new_empty(body.local_decls.len());
for stmt_index in lived_stmts.iter() {
let statement = &statements[stmt_index];
match &statement.kind {
StatementKind::StorageLive(local) if storage_live_locals.insert(*local) => {}
StatementKind::StorageDead(local) if storage_live_locals.remove(*local) => {}
StatementKind::Nop => {}
_ => return None,
}
}
if !storage_live_locals.is_empty() {
return None;
}
expected_assign
};
let Statement { kind: StatementKind::Assign(box (dest_place, ref rvalue)), .. } =
statements[assign_stmt]
else {
return None;
};
let dest_ty = dest_place.ty(body.local_decls(), tcx);
if dest_ty.ty != src_ty.ty || dest_ty.variant_index.is_some() {
return None;
}
let ty::Adt(def, _) = dest_ty.ty.kind() else {
return None;
};
match rvalue {
// Check if `_3 = const Foo::B` can be transformed to `_3 = copy *_1`.
Rvalue::Use(Operand::Constant(box constant))
if let Const::Val(const_, ty) = constant.const_ =>
{
let (ecx, op) = mk_eval_cx_for_const_val(tcx.at(constant.span), param_env, const_, ty)?;
let variant = ecx.read_discriminant(&op).discard_err()?;
if !def.variants()[variant].fields.is_empty() {
return None;
}
let Discr { val, .. } = ty.discriminant_for_variant(tcx, variant)?;
if val != index {
return None;
}
}
Rvalue::Use(Operand::Copy(place)) if *place == src_place => {}
// Check if `_3 = Foo::B` can be transformed to `_3 = copy *_1`.
Rvalue::Aggregate(box AggregateKind::Adt(_, variant_index, _, _, None), fields)
if fields.is_empty()
&& let Some(Discr { val, .. }) =
src_ty.ty.discriminant_for_variant(tcx, *variant_index)
&& val == index => {}
_ => return None,
}
Some(dest_place)
}
14 changes: 8 additions & 6 deletions tests/codegen/match-optimizes-away.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//
//@ compile-flags: -O
//@ compile-flags: -O -Cno-prepopulate-passes

#![crate_type = "lib"]

pub enum Three {
Expand All @@ -19,8 +19,9 @@ pub enum Four {
#[no_mangle]
pub fn three_valued(x: Three) -> Three {
// CHECK-LABEL: @three_valued
// CHECK-NEXT: {{^.*:$}}
// CHECK-NEXT: ret i8 %0
// CHECK-SAME: (i8{{.*}} [[X:%x]])
// CHECK-NEXT: start:
// CHECK-NEXT: ret i8 [[X]]
match x {
Three::A => Three::A,
Three::B => Three::B,
Expand All @@ -31,8 +32,9 @@ pub fn three_valued(x: Three) -> Three {
#[no_mangle]
pub fn four_valued(x: Four) -> Four {
// CHECK-LABEL: @four_valued
// CHECK-NEXT: {{^.*:$}}
// CHECK-NEXT: ret i16 %0
// CHECK-SAME: (i16{{.*}} [[X:%x]])
// CHECK-NEXT: start:
// CHECK-NEXT: ret i16 [[X]]
match x {
Four::A => Four::A,
Four::B => Four::B,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
- // MIR for `no_fields` before MergeBranchSimplification
+ // MIR for `no_fields` after MergeBranchSimplification

fn no_fields(_1: NoFields) -> NoFields {
debug a => _1;
let mut _0: NoFields;
let mut _2: isize;

bb0: {
_2 = discriminant(_1);
- switchInt(move _2) -> [0: bb3, 1: bb2, otherwise: bb1];
+ _0 = copy _1;
+ goto -> bb1;
}

bb1: {
- unreachable;
- }
-
- bb2: {
- _0 = NoFields::B;
- goto -> bb4;
- }
-
- bb3: {
- _0 = NoFields::A;
- goto -> bb4;
- }
-
- bb4: {
return;
}
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
- // MIR for `no_fields_failed` before MergeBranchSimplification
+ // MIR for `no_fields_failed` after MergeBranchSimplification

fn no_fields_failed(_1: NoFields) -> NoFields {
debug a => _1;
let mut _0: NoFields;
let mut _2: isize;

bb0: {
_2 = discriminant(_1);
switchInt(move _2) -> [0: bb3, 1: bb2, otherwise: bb1];
}

bb1: {
unreachable;
}

bb2: {
_0 = NoFields::A;
goto -> bb4;
}

bb3: {
_0 = NoFields::B;
goto -> bb4;
}

bb4: {
return;
}
}

Loading

0 comments on commit dddf046

Please sign in to comment.