From aa7e5ca1e7b655c4e5d8968caa5ca7bd7d133f7b Mon Sep 17 00:00:00 2001 From: Camille GILLOT Date: Sat, 22 Jul 2023 16:59:13 +0000 Subject: [PATCH] Refactor UninhabitedEnumBranching to mark targets unreachable. --- .../src/uninhabited_enum_branching.rs | 107 ++++++++---------- ...after-uninhabited-enum-branching.after.mir | 8 +- ...anching.main.UninhabitedEnumBranching.diff | 9 +- ...after-uninhabited-enum-branching.after.mir | 12 ++ ...nching2.main.UninhabitedEnumBranching.diff | 8 +- ..._fallthrough.UninhabitedEnumBranching.diff | 6 +- 6 files changed, 83 insertions(+), 67 deletions(-) diff --git a/compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs b/compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs index 092bcb5c97930..98f67e18a8d0e 100644 --- a/compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs +++ b/compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs @@ -3,8 +3,7 @@ use crate::MirPass; use rustc_data_structures::fx::FxHashSet; use rustc_middle::mir::{ - BasicBlockData, Body, Local, Operand, Rvalue, StatementKind, SwitchTargets, Terminator, - TerminatorKind, + BasicBlockData, Body, Local, Operand, Rvalue, StatementKind, Terminator, TerminatorKind, }; use rustc_middle::ty::layout::TyAndLayout; use rustc_middle::ty::{Ty, TyCtxt}; @@ -30,18 +29,16 @@ fn get_switched_on_type<'tcx>( let terminator = block_data.terminator(); // Only bother checking blocks which terminate by switching on a local. - if let Some(local) = get_discriminant_local(&terminator.kind) { - let stmt_before_term = (!block_data.statements.is_empty()) - .then(|| &block_data.statements[block_data.statements.len() - 1].kind); - - if let Some(StatementKind::Assign(box (l, Rvalue::Discriminant(place)))) = stmt_before_term - { - if l.as_local() == Some(local) { - let ty = place.ty(body, tcx).ty; - if ty.is_enum() { - return Some(ty); - } - } + let local = get_discriminant_local(&terminator.kind)?; + + let stmt_before_term = block_data.statements.last()?; + + if let StatementKind::Assign(box (l, Rvalue::Discriminant(place))) = stmt_before_term.kind + && l.as_local() == Some(local) + { + let ty = place.ty(body, tcx).ty; + if ty.is_enum() { + return Some(ty); } } @@ -72,28 +69,6 @@ fn variant_discriminants<'tcx>( } } -/// Ensures that the `otherwise` branch leads to an unreachable bb, returning `None` if so and a new -/// bb to use as the new target if not. -fn ensure_otherwise_unreachable<'tcx>( - body: &Body<'tcx>, - targets: &SwitchTargets, -) -> Option> { - let otherwise = targets.otherwise(); - let bb = &body.basic_blocks[otherwise]; - if bb.terminator().kind == TerminatorKind::Unreachable - && bb.statements.iter().all(|s| matches!(&s.kind, StatementKind::StorageDead(_))) - { - return None; - } - - let mut new_block = BasicBlockData::new(Some(Terminator { - source_info: bb.terminator().source_info, - kind: TerminatorKind::Unreachable, - })); - new_block.is_cleanup = bb.is_cleanup; - Some(new_block) -} - impl<'tcx> MirPass<'tcx> for UninhabitedEnumBranching { fn is_enabled(&self, sess: &rustc_session::Session) -> bool { sess.mir_opt_level() > 0 @@ -102,13 +77,16 @@ impl<'tcx> MirPass<'tcx> for UninhabitedEnumBranching { fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { trace!("UninhabitedEnumBranching starting for {:?}", body.source); - for bb in body.basic_blocks.indices() { + let mut removable_switchs = Vec::new(); + + for (bb, bb_data) in body.basic_blocks.iter_enumerated() { trace!("processing block {:?}", bb); - let Some(discriminant_ty) = get_switched_on_type(&body.basic_blocks[bb], tcx, body) - else { + if bb_data.is_cleanup { continue; - }; + } + + let Some(discriminant_ty) = get_switched_on_type(&bb_data, tcx, body) else { continue }; let layout = tcx.layout_of( tcx.param_env_reveal_all_normalized(body.source.def_id()).and(discriminant_ty), @@ -122,31 +100,38 @@ impl<'tcx> MirPass<'tcx> for UninhabitedEnumBranching { trace!("allowed_variants = {:?}", allowed_variants); - if let TerminatorKind::SwitchInt { targets, .. } = - &mut body.basic_blocks_mut()[bb].terminator_mut().kind - { - let mut new_targets = SwitchTargets::new( - targets.iter().filter(|(val, _)| allowed_variants.contains(val)), - targets.otherwise(), - ); - - if new_targets.iter().count() == allowed_variants.len() { - if let Some(updated) = ensure_otherwise_unreachable(body, &new_targets) { - let new_otherwise = body.basic_blocks_mut().push(updated); - *new_targets.all_targets_mut().last_mut().unwrap() = new_otherwise; - } - } + let terminator = bb_data.terminator(); + let TerminatorKind::SwitchInt { targets, .. } = &terminator.kind else { bug!() }; - if let TerminatorKind::SwitchInt { targets, .. } = - &mut body.basic_blocks_mut()[bb].terminator_mut().kind - { - *targets = new_targets; + let mut reachable_count = 0; + for (index, (val, _)) in targets.iter().enumerate() { + if allowed_variants.contains(&val) { + reachable_count += 1; } else { - unreachable!() + removable_switchs.push((bb, index)); } - } else { - unreachable!() } + + if reachable_count == allowed_variants.len() { + removable_switchs.push((bb, targets.iter().count())); + } + } + + if removable_switchs.is_empty() { + return; + } + + let new_block = BasicBlockData::new(Some(Terminator { + source_info: body.basic_blocks[removable_switchs[0].0].terminator().source_info, + kind: TerminatorKind::Unreachable, + })); + let unreachable_block = body.basic_blocks.as_mut().push(new_block); + + for (bb, index) in removable_switchs { + let bb = &mut body.basic_blocks.as_mut()[bb]; + let terminator = bb.terminator_mut(); + let TerminatorKind::SwitchInt { targets, .. } = &mut terminator.kind else { bug!() }; + targets.all_targets_mut()[index] = unreachable_block; } } } diff --git a/tests/mir-opt/uninhabited_enum_branching.main.SimplifyCfg-after-uninhabited-enum-branching.after.mir b/tests/mir-opt/uninhabited_enum_branching.main.SimplifyCfg-after-uninhabited-enum-branching.after.mir index 1ee44e48c8ec8..1df3b74b3e7d9 100644 --- a/tests/mir-opt/uninhabited_enum_branching.main.SimplifyCfg-after-uninhabited-enum-branching.after.mir +++ b/tests/mir-opt/uninhabited_enum_branching.main.SimplifyCfg-after-uninhabited-enum-branching.after.mir @@ -12,14 +12,20 @@ fn main() -> () { let mut _8: isize; let _9: &str; let mut _10: bool; + let mut _11: bool; + let mut _12: bool; bb0: { StorageLive(_1); StorageLive(_2); _2 = Test1::C; _3 = discriminant(_2); - _10 = Eq(_3, const 2_isize); + _10 = Ne(_3, const 0_isize); assume(move _10); + _11 = Ne(_3, const 1_isize); + assume(move _11); + _12 = Eq(_3, const 2_isize); + assume(move _12); StorageLive(_5); _5 = const "C"; _1 = &(*_5); diff --git a/tests/mir-opt/uninhabited_enum_branching.main.UninhabitedEnumBranching.diff b/tests/mir-opt/uninhabited_enum_branching.main.UninhabitedEnumBranching.diff index 9db95abec34c8..5b107f80ce80b 100644 --- a/tests/mir-opt/uninhabited_enum_branching.main.UninhabitedEnumBranching.diff +++ b/tests/mir-opt/uninhabited_enum_branching.main.UninhabitedEnumBranching.diff @@ -19,7 +19,7 @@ _2 = Test1::C; _3 = discriminant(_2); - switchInt(move _3) -> [0: bb3, 1: bb4, 2: bb1, otherwise: bb2]; -+ switchInt(move _3) -> [2: bb1, otherwise: bb2]; ++ switchInt(move _3) -> [0: bb9, 1: bb9, 2: bb1, otherwise: bb9]; } bb1: { @@ -54,7 +54,8 @@ StorageLive(_7); _7 = Test2::D; _8 = discriminant(_7); - switchInt(move _8) -> [4: bb7, 5: bb6, otherwise: bb2]; +- switchInt(move _8) -> [4: bb7, 5: bb6, otherwise: bb2]; ++ switchInt(move _8) -> [4: bb7, 5: bb6, otherwise: bb9]; } bb6: { @@ -75,6 +76,10 @@ StorageDead(_6); _0 = const (); return; ++ } ++ ++ bb9: { ++ unreachable; } } diff --git a/tests/mir-opt/uninhabited_enum_branching2.main.SimplifyCfg-after-uninhabited-enum-branching.after.mir b/tests/mir-opt/uninhabited_enum_branching2.main.SimplifyCfg-after-uninhabited-enum-branching.after.mir index 9c0c5d18917ae..06375b3ffaebf 100644 --- a/tests/mir-opt/uninhabited_enum_branching2.main.SimplifyCfg-after-uninhabited-enum-branching.after.mir +++ b/tests/mir-opt/uninhabited_enum_branching2.main.SimplifyCfg-after-uninhabited-enum-branching.after.mir @@ -15,6 +15,10 @@ fn main() -> () { let _11: &str; let _12: &str; let _13: &str; + let mut _14: bool; + let mut _15: bool; + let mut _16: bool; + let mut _17: bool; scope 1 { debug plop => _1; } @@ -29,6 +33,10 @@ fn main() -> () { StorageLive(_4); _4 = &(_1.1: Test1); _5 = discriminant((*_4)); + _16 = Ne(_5, const 0_isize); + assume(move _16); + _17 = Ne(_5, const 1_isize); + assume(move _17); switchInt(move _5) -> [2: bb3, 3: bb1, otherwise: bb2]; } @@ -57,6 +65,10 @@ fn main() -> () { StorageDead(_3); StorageLive(_9); _10 = discriminant((_1.1: Test1)); + _14 = Ne(_10, const 0_isize); + assume(move _14); + _15 = Ne(_10, const 1_isize); + assume(move _15); switchInt(move _10) -> [2: bb6, 3: bb5, otherwise: bb2]; } diff --git a/tests/mir-opt/uninhabited_enum_branching2.main.UninhabitedEnumBranching.diff b/tests/mir-opt/uninhabited_enum_branching2.main.UninhabitedEnumBranching.diff index 12ce6505af948..165421acd6916 100644 --- a/tests/mir-opt/uninhabited_enum_branching2.main.UninhabitedEnumBranching.diff +++ b/tests/mir-opt/uninhabited_enum_branching2.main.UninhabitedEnumBranching.diff @@ -31,7 +31,7 @@ _4 = &(_1.1: Test1); _5 = discriminant((*_4)); - switchInt(move _5) -> [0: bb3, 1: bb4, 2: bb5, 3: bb1, otherwise: bb2]; -+ switchInt(move _5) -> [2: bb5, 3: bb1, otherwise: bb2]; ++ switchInt(move _5) -> [0: bb12, 1: bb12, 2: bb5, 3: bb1, otherwise: bb12]; } bb1: { @@ -73,7 +73,7 @@ StorageLive(_9); _10 = discriminant((_1.1: Test1)); - switchInt(move _10) -> [0: bb8, 1: bb9, 2: bb10, 3: bb7, otherwise: bb2]; -+ switchInt(move _10) -> [2: bb10, 3: bb7, otherwise: bb2]; ++ switchInt(move _10) -> [0: bb12, 1: bb12, 2: bb10, 3: bb7, otherwise: bb12]; } bb7: { @@ -110,6 +110,10 @@ _0 = const (); StorageDead(_1); return; ++ } ++ ++ bb12: { ++ unreachable; } } diff --git a/tests/mir-opt/uninhabited_fallthrough_elimination.keep_fallthrough.UninhabitedEnumBranching.diff b/tests/mir-opt/uninhabited_fallthrough_elimination.keep_fallthrough.UninhabitedEnumBranching.diff index 498e1e20f8a79..79948139f888d 100644 --- a/tests/mir-opt/uninhabited_fallthrough_elimination.keep_fallthrough.UninhabitedEnumBranching.diff +++ b/tests/mir-opt/uninhabited_fallthrough_elimination.keep_fallthrough.UninhabitedEnumBranching.diff @@ -9,7 +9,7 @@ bb0: { _2 = discriminant(_1); - switchInt(move _2) -> [0: bb2, 1: bb3, otherwise: bb1]; -+ switchInt(move _2) -> [1: bb3, otherwise: bb1]; ++ switchInt(move _2) -> [0: bb5, 1: bb3, otherwise: bb1]; } bb1: { @@ -29,6 +29,10 @@ bb4: { return; ++ } ++ ++ bb5: { ++ unreachable; } }