From 50c30220c35f6b05d0442f046798c381c6ee25f1 Mon Sep 17 00:00:00 2001 From: dianqk Date: Sun, 16 Nov 2025 23:05:01 +0800 Subject: [PATCH 1/4] Overhaul MatchBranchSimplification --- compiler/rustc_mir_transform/src/lib.rs | 1 + .../rustc_mir_transform/src/match_branches.rs | 796 ++++++++---------- ...tch_eq_bool.MatchBranchSimplification.diff | 49 ++ ...h_eq_bool_2.MatchBranchSimplification.diff | 44 + tests/mir-opt/matches_reduce_branches.rs | 71 ++ ...single_case.MatchBranchSimplification.diff | 22 + ....foo.SimplifyLocals-final.panic-abort.diff | 1 - ...foo.SimplifyLocals-final.panic-unwind.diff | 1 - 8 files changed, 539 insertions(+), 446 deletions(-) create mode 100644 tests/mir-opt/matches_reduce_branches.match_eq_bool.MatchBranchSimplification.diff create mode 100644 tests/mir-opt/matches_reduce_branches.match_eq_bool_2.MatchBranchSimplification.diff create mode 100644 tests/mir-opt/matches_reduce_branches.single_case.MatchBranchSimplification.diff diff --git a/compiler/rustc_mir_transform/src/lib.rs b/compiler/rustc_mir_transform/src/lib.rs index bc2c6bd81aca9..5860101771a69 100644 --- a/compiler/rustc_mir_transform/src/lib.rs +++ b/compiler/rustc_mir_transform/src/lib.rs @@ -6,6 +6,7 @@ #![feature(file_buffered)] #![feature(if_let_guard)] #![feature(impl_trait_in_assoc_type)] +#![feature(iterator_try_collect)] #![feature(try_blocks)] #![feature(yeet_expr)] // tidy-alphabetical-end diff --git a/compiler/rustc_mir_transform/src/match_branches.rs b/compiler/rustc_mir_transform/src/match_branches.rs index 5e511f1a418b6..d5b65958fdde3 100644 --- a/compiler/rustc_mir_transform/src/match_branches.rs +++ b/compiler/rustc_mir_transform/src/match_branches.rs @@ -1,15 +1,12 @@ -use std::iter; - use rustc_abi::Integer; -use rustc_index::IndexSlice; use rustc_middle::mir::*; use rustc_middle::ty::layout::{IntegerExt, TyAndLayout}; use rustc_middle::ty::{self, ScalarInt, Ty, TyCtxt}; -use tracing::instrument; use super::simplify::simplify_cfg; use crate::patch::MirPatch; +/// Merges all targets into one basic block if each statement can have the same statement. pub(super) struct MatchBranchSimplification; impl<'tcx> crate::MirPass<'tcx> for MatchBranchSimplification { @@ -19,32 +16,15 @@ impl<'tcx> crate::MirPass<'tcx> for MatchBranchSimplification { fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { let typing_env = body.typing_env(tcx); - let mut apply_patch = false; - let mut patch = MirPatch::new(body); - for (bb, bb_data) in body.basic_blocks.iter_enumerated() { - match &bb_data.terminator().kind { - TerminatorKind::SwitchInt { - discr: Operand::Copy(_) | Operand::Move(_), - targets, - .. - // We require that the possible target blocks don't contain this block. - } if !targets.all_targets().contains(&bb) => {} - // Only optimize switch int statements - _ => continue, - }; - - if SimplifyToIf.simplify(tcx, body, &mut patch, bb, typing_env).is_some() { - apply_patch = true; + let mut changed = false; + for bb in body.basic_blocks.indices() { + if !candidate_match(body, bb) { continue; - } - if SimplifyToExp::default().simplify(tcx, body, &mut patch, bb, typing_env).is_some() { - apply_patch = true; - continue; - } + }; + changed |= simplify_match(tcx, typing_env, body, bb) } - if apply_patch { - patch.apply(body); + if changed { simplify_cfg(tcx, body); } } @@ -54,222 +34,307 @@ impl<'tcx> crate::MirPass<'tcx> for MatchBranchSimplification { } } -trait SimplifyMatch<'tcx> { - /// Simplifies a match statement, returning `Some` if the simplification succeeds, `None` - /// otherwise. Generic code is written here, and we generally don't need a custom - /// implementation. - fn simplify( - &mut self, - tcx: TyCtxt<'tcx>, - body: &Body<'tcx>, - patch: &mut MirPatch<'tcx>, - switch_bb_idx: BasicBlock, - typing_env: ty::TypingEnv<'tcx>, - ) -> Option<()> { - let bbs = &body.basic_blocks; - let TerminatorKind::SwitchInt { discr, targets, .. } = - &bbs[switch_bb_idx].terminator().kind - else { - unreachable!(); - }; - - let discr_ty = discr.ty(body.local_decls(), tcx); - self.can_simplify(tcx, targets, typing_env, bbs, discr_ty)?; - - // Take ownership of items now that we know we can optimize. - let discr = discr.clone(); - - // Introduce a temporary for the discriminant value. - let source_info = bbs[switch_bb_idx].terminator().source_info; - let discr_local = patch.new_temp(discr_ty, source_info.span); +struct SimplifyMatch<'tcx, 'a> { + tcx: TyCtxt<'tcx>, + typing_env: ty::TypingEnv<'tcx>, + patch: MirPatch<'tcx>, + body: &'a Body<'tcx>, + switch_bb: BasicBlock, + discr_local: Option, + discr_ty: Ty<'tcx>, +} - let (_, first) = targets.iter().next().unwrap(); - let statement_index = bbs[switch_bb_idx].statements.len(); - let parent_end = Location { block: switch_bb_idx, statement_index }; - patch.add_statement(parent_end, StatementKind::StorageLive(discr_local)); - patch.add_assign(parent_end, Place::from(discr_local), Rvalue::Use(discr)); - self.new_stmts(tcx, targets, typing_env, patch, parent_end, bbs, discr_local, discr_ty); - patch.add_statement(parent_end, StatementKind::StorageDead(discr_local)); - patch.patch_terminator(switch_bb_idx, bbs[first].terminator().kind.clone()); - Some(()) +impl<'tcx, 'a> SimplifyMatch<'tcx, 'a> { + fn discr_local(&mut self) -> Local { + *self.discr_local.get_or_insert_with(|| { + // Introduce a temporary for the discriminant value. + let source_info = self.body.basic_blocks[self.switch_bb].terminator().source_info; + self.patch.new_temp(self.discr_ty, source_info.span) + }) } - /// Check that the BBs to be simplified satisfies all distinct and - /// that the terminator are the same. - /// There are also conditions for different ways of simplification. - fn can_simplify( - &mut self, - tcx: TyCtxt<'tcx>, - targets: &SwitchTargets, - typing_env: ty::TypingEnv<'tcx>, - bbs: &IndexSlice>, - discr_ty: Ty<'tcx>, - ) -> Option<()>; - - fn new_stmts( + /// Merges the assignments if all rvalues are constants and equal. + fn merge_if_equal_const( &self, - tcx: TyCtxt<'tcx>, - targets: &SwitchTargets, - typing_env: ty::TypingEnv<'tcx>, - patch: &mut MirPatch<'tcx>, - parent_end: Location, - bbs: &IndexSlice>, - discr_local: Local, - discr_ty: Ty<'tcx>, - ); -} - -struct SimplifyToIf; + dest: Place<'tcx>, + consts: &[(u128, &ConstOperand<'tcx>)], + otherwise: Option<&ConstOperand<'tcx>>, + ) -> Option> { + let (_, first_const, mut others) = split_first_case(consts, otherwise); + let first_scalar_int = first_const.const_.try_eval_scalar_int(self.tcx, self.typing_env)?; + if others.all(|const_| { + const_.const_.try_eval_scalar_int(self.tcx, self.typing_env) == Some(first_scalar_int) + }) { + Some(StatementKind::Assign(Box::new(( + dest, + Rvalue::Use(Operand::Constant(Box::new(first_const.clone()))), + )))) + } else { + None + } + } -/// If a source block is found that switches between two blocks that are exactly -/// the same modulo const bool assignments (e.g., one assigns true another false -/// to the same place), merge a target block statements into the source block, -/// using Eq / Ne comparison with switch value where const bools value differ. -/// -/// For example: -/// -/// ```ignore (MIR) -/// bb0: { -/// switchInt(move _3) -> [42_isize: bb1, otherwise: bb2]; -/// } -/// -/// bb1: { -/// _2 = const true; -/// goto -> bb3; -/// } -/// -/// bb2: { -/// _2 = const false; -/// goto -> bb3; -/// } -/// ``` -/// -/// into: -/// -/// ```ignore (MIR) -/// bb0: { -/// _2 = Eq(move _3, const 42_isize); -/// goto -> bb3; -/// } -/// ``` -impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf { - #[instrument(level = "debug", skip(self, tcx), ret)] - fn can_simplify( + /// If a source block is found that switches between two blocks that are exactly + /// the same modulo const bool assignments (e.g., one assigns true another false + /// to the same place), merge a target block statements into the source block, + /// using Eq / Ne comparison with switch value where const bools value differ. + /// + /// For example: + /// + /// ```ignore (MIR) + /// bb0: { + /// switchInt(move _3) -> [42_isize: bb1, otherwise: bb2]; + /// } + /// + /// bb1: { + /// _2 = const true; + /// goto -> bb3; + /// } + /// + /// bb2: { + /// _2 = const false; + /// goto -> bb3; + /// } + /// ``` + /// + /// into: + /// + /// ```ignore (MIR) + /// bb0: { + /// _2 = Eq(move _3, const 42_isize); + /// goto -> bb3; + /// } + /// ``` + fn merge_by_eq_op( &mut self, - tcx: TyCtxt<'tcx>, - targets: &SwitchTargets, - typing_env: ty::TypingEnv<'tcx>, - bbs: &IndexSlice>, - _discr_ty: Ty<'tcx>, - ) -> Option<()> { - let (first, second) = match targets.all_targets() { - &[first, otherwise] => (first, otherwise), - &[first, second, otherwise] if bbs[otherwise].is_empty_unreachable() => (first, second), - _ => { - return None; - } - }; - - // We require that the possible target blocks all be distinct. - if first == second { + dest: Place<'tcx>, + consts: &[(u128, &ConstOperand<'tcx>)], + otherwise: Option<&ConstOperand<'tcx>>, + ) -> Option> { + // FIXME: extend to any case. + let (first_case, first_const, mut others) = split_first_case(consts, otherwise); + if !first_const.ty().is_bool() { return None; } - // Check that destinations are identical, and if not, then don't optimize this block - if bbs[first].terminator().kind != bbs[second].terminator().kind { - return None; + let first_bool = first_const.const_.try_eval_bool(self.tcx, self.typing_env)?; + if others.all(|const_| { + const_.const_.try_eval_bool(self.tcx, self.typing_env) == Some(!first_bool) + }) { + // Make value conditional on switch condition. + let size = + self.tcx.layout_of(self.typing_env.as_query_input(self.discr_ty)).unwrap().size; + let const_cmp = Operand::const_from_scalar( + self.tcx, + self.discr_ty, + rustc_const_eval::interpret::Scalar::from_uint(first_case, size), + rustc_span::DUMMY_SP, + ); + let op = if first_bool { BinOp::Eq } else { BinOp::Ne }; + let rval = Rvalue::BinaryOp( + op, + Box::new((Operand::Copy(Place::from(self.discr_local())), const_cmp)), + ); + Some(StatementKind::Assign(Box::new((dest, rval)))) + } else { + None } + } - // Check that blocks are assignments of consts to the same place or same statement, - // and match up 1-1, if not don't optimize this block. - let first_stmts = &bbs[first].statements; - let second_stmts = &bbs[second].statements; - if first_stmts.len() != second_stmts.len() { + /// Merges the assignments if all rvalues can be cast from the discriminant value by IntToInt. + /// + /// For example: + /// + /// ```ignore (MIR) + /// bb0: { + /// switchInt(_1) -> [1: bb2, 2: bb3, 3: bb4, otherwise: bb1]; + /// } + /// + /// bb1: { + /// unreachable; + /// } + /// + /// bb2: { + /// _0 = const 1_i16; + /// goto -> bb5; + /// } + /// + /// bb3: { + /// _0 = const 2_i16; + /// goto -> bb5; + /// } + /// + /// bb4: { + /// _0 = const 3_i16; + /// goto -> bb5; + /// } + /// ``` + /// + /// into: + /// + /// ```ignore (MIR) + /// bb0: { + /// _0 = _1 as i16 (IntToInt); + /// goto -> bb5; + /// } + /// ``` + fn merge_by_int_to_int( + &mut self, + dest: Place<'tcx>, + consts: &[(u128, &ConstOperand<'tcx>)], + ) -> Option> { + let (_, first_const) = consts[0]; + if !first_const.ty().is_integral() { return None; } - for (f, s) in iter::zip(first_stmts, second_stmts) { - match (&f.kind, &s.kind) { - // If two statements are exactly the same, we can optimize. - (f_s, s_s) if f_s == s_s => {} + let discr_layout = + self.tcx.layout_of(self.typing_env.as_query_input(self.discr_ty)).unwrap(); + if consts.iter().all(|&(case, const_)| { + let Some(scalar_int) = const_.const_.try_eval_scalar_int(self.tcx, self.typing_env) + else { + return false; + }; + can_cast(self.tcx, case, discr_layout, const_.ty(), scalar_int) + }) { + let operand = Operand::Copy(Place::from(self.discr_local())); + let rval = if first_const.ty() == self.discr_ty { + Rvalue::Use(operand) + } else { + Rvalue::Cast(CastKind::IntToInt, operand, first_const.ty()) + }; + Some(StatementKind::Assign(Box::new((dest, rval)))) + } else { + None + } + } - // If two statements are const bool assignments to the same place, we can optimize. - ( - StatementKind::Assign(box (lhs_f, Rvalue::Use(Operand::Constant(f_c)))), - StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))), - ) if lhs_f == lhs_s - && f_c.const_.ty().is_bool() - && s_c.const_.ty().is_bool() - && f_c.const_.try_eval_bool(tcx, typing_env).is_some() - && s_c.const_.try_eval_bool(tcx, typing_env).is_some() => {} + /// Returns a new statement if we can use the statement replace all statements. + fn try_merge_stmts( + &mut self, + _index: usize, + stmts: &[(u128, &StatementKind<'tcx>)], + otherwise: Option<&StatementKind<'tcx>>, + ) -> Option> { + if let Some(new_stmt) = identical_stmts(stmts, otherwise) { + return Some(new_stmt); + } - // Otherwise we cannot optimize. Try another block. - _ => return None, + let (dest, rvals, otherwise) = candidate_assign(stmts, otherwise)?; + if let Some((consts, otherwise)) = candidate_const(&rvals, otherwise) { + if let Some(new_stmt) = self.merge_if_equal_const(dest, &consts, otherwise) { + return Some(new_stmt); + } + if let Some(new_stmt) = self.merge_by_eq_op(dest, &consts, otherwise) { + return Some(new_stmt); + } + // Requires the otherwise is unreachable. + if otherwise.is_none() + && let Some(new_stmt) = self.merge_by_int_to_int(dest, &consts) + { + return Some(new_stmt); } } - Some(()) + None } +} - fn new_stmts( - &self, - tcx: TyCtxt<'tcx>, - targets: &SwitchTargets, - typing_env: ty::TypingEnv<'tcx>, - patch: &mut MirPatch<'tcx>, - parent_end: Location, - bbs: &IndexSlice>, - discr_local: Local, - discr_ty: Ty<'tcx>, - ) { - let ((val, first), second) = match (targets.all_targets(), targets.all_values()) { - (&[first, otherwise], &[val]) => ((val, first), otherwise), - (&[first, second, otherwise], &[val, _]) if bbs[otherwise].is_empty_unreachable() => { - ((val, first), second) - } - _ => unreachable!(), - }; - - // We already checked that first and second are different blocks, - // and bb_idx has a different terminator from both of them. - let first = &bbs[first]; - let second = &bbs[second]; - for (f, s) in iter::zip(&first.statements, &second.statements) { - match (&f.kind, &s.kind) { - (f_s, s_s) if f_s == s_s => { - patch.add_statement(parent_end, f.kind.clone()); - } - - ( - StatementKind::Assign(box (lhs, Rvalue::Use(Operand::Constant(f_c)))), - StatementKind::Assign(box (_, Rvalue::Use(Operand::Constant(s_c)))), - ) => { - // From earlier loop we know that we are dealing with bool constants only: - let f_b = f_c.const_.try_eval_bool(tcx, typing_env).unwrap(); - let s_b = s_c.const_.try_eval_bool(tcx, typing_env).unwrap(); - if f_b == s_b { - // Same value in both blocks. Use statement as is. - patch.add_statement(parent_end, f.kind.clone()); - } else { - // Different value between blocks. Make value conditional on switch - // condition. - let size = tcx.layout_of(typing_env.as_query_input(discr_ty)).unwrap().size; - let const_cmp = Operand::const_from_scalar( - tcx, - discr_ty, - rustc_const_eval::interpret::Scalar::from_uint(val, size), - rustc_span::DUMMY_SP, - ); - let op = if f_b { BinOp::Eq } else { BinOp::Ne }; - let rhs = Rvalue::BinaryOp( - op, - Box::new((Operand::Copy(Place::from(discr_local)), const_cmp)), - ); - patch.add_assign(parent_end, *lhs, rhs); - } - } +/// Returns the first case target if all targets have an equal number of statements and identical destination. +fn candidate_match<'tcx>(body: &Body<'tcx>, switch_bb: BasicBlock) -> bool { + let targets = match &body.basic_blocks[switch_bb].terminator().kind { + TerminatorKind::SwitchInt { + discr: Operand::Copy(_) | Operand::Move(_), targets, .. + } => targets, + // Only optimize switch int statements + _ => return false, + }; + // We require that the possible target blocks don't contain this block. + if targets.all_targets().contains(&switch_bb) { + return false; + } + // We require that the possible target blocks all be distinct. + if !targets.is_distinct() { + return false; + } + let &[first, ref others @ .., otherwise] = targets.all_targets() else { + return false; + }; + let first_case_bb = &body.basic_blocks[first]; + let first_case_terminator_kind = &first_case_bb.terminator().kind; + let first_case_stmts_len = first_case_bb.statements.len(); + + let otherwise = + if body.basic_blocks[otherwise].is_empty_unreachable() { None } else { Some(&otherwise) }; + // Check that destinations are identical, and if not, then don't optimize this block + others.iter().chain(otherwise).all(|&bb| { + let bb = &body.basic_blocks[bb]; + first_case_stmts_len == bb.statements.len() + && first_case_terminator_kind == &bb.terminator().kind + }) +} - _ => unreachable!(), - } +fn simplify_match<'tcx>( + tcx: TyCtxt<'tcx>, + typing_env: ty::TypingEnv<'tcx>, + body: &mut Body<'tcx>, + switch_bb: BasicBlock, +) -> bool { + let (discr, targets) = match &body.basic_blocks[switch_bb].terminator().kind { + TerminatorKind::SwitchInt { discr, targets, .. } => (discr, targets), + _ => unreachable!(), + }; + let mut simplify_match = SimplifyMatch { + tcx, + typing_env, + patch: MirPatch::new(body), + body, + switch_bb, + discr_local: None, + discr_ty: discr.ty(body.local_decls(), tcx), + }; + let stmts: Vec<_> = targets + .iter() + .map(|(case, bb)| (case, simplify_match.body.basic_blocks[bb].statements.as_slice())) + .collect(); + let mut new_stmts = Vec::new(); + let otherwise_stmts = if body.basic_blocks[targets.otherwise()].is_empty_unreachable() { + None + } else { + Some(body.basic_blocks[targets.otherwise()].statements.as_slice()) + }; + let first_case_bb = targets.all_targets()[0]; + let stmt_len = body.basic_blocks[first_case_bb].statements.len(); + let mut cases = Vec::with_capacity(stmt_len); + // Check at each position in the basic blocks whether these statements can be merged. + for index in 0..stmt_len { + let otherwise = otherwise_stmts.map(|stmt| &stmt[index].kind); + cases.clear(); + for &(case, stmts) in &stmts { + cases.push((case, &stmts[index].kind)); } + let Some(new_stmt) = simplify_match.try_merge_stmts(index, cases.as_slice(), otherwise) + else { + return false; + }; + new_stmts.push(new_stmt); } + // Take ownership of items now that we know we can optimize. + let discr = discr.clone(); + + let statement_index = body.basic_blocks[switch_bb].statements.len(); + let parent_end = Location { block: switch_bb, statement_index }; + let mut patch = simplify_match.patch; + if let Some(discr_local) = simplify_match.discr_local { + patch.add_statement(parent_end, StatementKind::StorageLive(discr_local)); + patch.add_assign(parent_end, Place::from(discr_local), Rvalue::Use(discr)); + } + for new_stmt in new_stmts { + patch.add_statement(parent_end, new_stmt); + } + if let Some(discr_local) = simplify_match.discr_local { + patch.add_statement(parent_end, StatementKind::StorageDead(discr_local)); + } + patch.patch_terminator(switch_bb, body.basic_blocks[first_case_bb].terminator().kind.clone()); + patch.apply(body); + true } /// Check if the cast constant using `IntToInt` is equal to the target constant. @@ -298,234 +363,77 @@ fn can_cast( cast_scalar == target_scalar } -#[derive(Default)] -struct SimplifyToExp { - transform_kinds: Vec, -} - -#[derive(Clone, Copy, Debug)] -enum ExpectedTransformKind<'a, 'tcx> { - /// Identical statements. - Same(&'a StatementKind<'tcx>), - /// Assignment statements have the same value. - SameByEq { place: &'a Place<'tcx>, ty: Ty<'tcx>, scalar: ScalarInt }, - /// Enum variant comparison type. - Cast { place: &'a Place<'tcx>, ty: Ty<'tcx> }, -} - -enum TransformKind { - Same, - Cast, -} - -impl From> for TransformKind { - fn from(compare_type: ExpectedTransformKind<'_, '_>) -> Self { - match compare_type { - ExpectedTransformKind::Same(_) => TransformKind::Same, - ExpectedTransformKind::SameByEq { .. } => TransformKind::Same, - ExpectedTransformKind::Cast { .. } => TransformKind::Cast, - } - } -} - -/// If we find that the value of match is the same as the assignment, -/// merge a target block statements into the source block, -/// using cast to transform different integer types. -/// -/// For example: -/// -/// ```ignore (MIR) -/// bb0: { -/// switchInt(_1) -> [1: bb2, 2: bb3, 3: bb4, otherwise: bb1]; -/// } -/// -/// bb1: { -/// unreachable; -/// } -/// -/// bb2: { -/// _0 = const 1_i16; -/// goto -> bb5; -/// } -/// -/// bb3: { -/// _0 = const 2_i16; -/// goto -> bb5; -/// } -/// -/// bb4: { -/// _0 = const 3_i16; -/// goto -> bb5; -/// } -/// ``` -/// -/// into: -/// -/// ```ignore (MIR) -/// bb0: { -/// _0 = _3 as i16 (IntToInt); -/// goto -> bb5; -/// } -/// ``` -impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp { - #[instrument(level = "debug", skip(self, tcx), ret)] - fn can_simplify( - &mut self, - tcx: TyCtxt<'tcx>, - targets: &SwitchTargets, - typing_env: ty::TypingEnv<'tcx>, - bbs: &IndexSlice>, - discr_ty: Ty<'tcx>, - ) -> Option<()> { - if targets.iter().len() < 2 || targets.iter().len() > 64 { - return None; - } - // We require that the possible target blocks all be distinct. - if !targets.is_distinct() { - return None; - } - if !bbs[targets.otherwise()].is_empty_unreachable() { - return None; - } - let mut target_iter = targets.iter(); - let (first_case_val, first_target) = target_iter.next().unwrap(); - let first_terminator_kind = &bbs[first_target].terminator().kind; - // Check that destinations are identical, and if not, then don't optimize this block - if !targets - .iter() - .all(|(_, other_target)| first_terminator_kind == &bbs[other_target].terminator().kind) - { +fn candidate_assign<'tcx, 'a>( + stmts: &'a [(u128, &'a StatementKind<'tcx>)], + otherwise: Option<&'a StatementKind<'tcx>>, +) -> Option<(Place<'tcx>, Vec<(u128, &'a Rvalue<'tcx>)>, Option<&'a Rvalue<'tcx>>)> { + let (_, first_stmt) = stmts[0]; + let (dest, _) = first_stmt.as_assign()?; + let otherwise = if let Some(otherwise) = otherwise { + let Some((otherwise_dest, rval)) = otherwise.as_assign() else { return None; - } - - let discr_layout = tcx.layout_of(typing_env.as_query_input(discr_ty)).unwrap(); - let first_stmts = &bbs[first_target].statements; - let (second_case_val, second_target) = target_iter.next().unwrap(); - let second_stmts = &bbs[second_target].statements; - if first_stmts.len() != second_stmts.len() { + }; + if otherwise_dest != dest { return None; } - - // We first compare the two branches, and then the other branches need to fulfill the same - // conditions. - let mut expected_transform_kinds = Vec::new(); - for (f, s) in iter::zip(first_stmts, second_stmts) { - let compare_type = match (&f.kind, &s.kind) { - // If two statements are exactly the same, we can optimize. - (f_s, s_s) if f_s == s_s => ExpectedTransformKind::Same(f_s), - - // If two statements are assignments with the match values to the same place, we - // can optimize. - ( - StatementKind::Assign(box (lhs_f, Rvalue::Use(Operand::Constant(f_c)))), - StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))), - ) if lhs_f == lhs_s - && f_c.const_.ty() == s_c.const_.ty() - && f_c.const_.ty().is_integral() => - { - match ( - f_c.const_.try_eval_scalar_int(tcx, typing_env), - s_c.const_.try_eval_scalar_int(tcx, typing_env), - ) { - (Some(f), Some(s)) if f == s => ExpectedTransformKind::SameByEq { - place: lhs_f, - ty: f_c.const_.ty(), - scalar: f, - }, - // Enum variants can also be simplified to an assignment statement, - // if we can use `IntToInt` cast to get an equal value. - (Some(f), Some(s)) - if (can_cast( - tcx, - first_case_val, - discr_layout, - f_c.const_.ty(), - f, - ) && can_cast( - tcx, - second_case_val, - discr_layout, - f_c.const_.ty(), - s, - )) => - { - ExpectedTransformKind::Cast { place: lhs_f, ty: f_c.const_.ty() } - } - _ => { - return None; - } - } - } - - // Otherwise we cannot optimize. Try another block. - _ => return None, - }; - expected_transform_kinds.push(compare_type); - } - - // All remaining BBs need to fulfill the same pattern as the two BBs from the previous step. - for (other_val, other_target) in target_iter { - let other_stmts = &bbs[other_target].statements; - if expected_transform_kinds.len() != other_stmts.len() { + Some(rval) + } else { + None + }; + let rvals = stmts + .into_iter() + .map(|&(case, stmt)| { + let (other_dest, rval) = stmt.as_assign()?; + if other_dest != dest { return None; } - for (f, s) in iter::zip(&expected_transform_kinds, other_stmts) { - match (*f, &s.kind) { - (ExpectedTransformKind::Same(f_s), s_s) if f_s == s_s => {} - ( - ExpectedTransformKind::SameByEq { place: lhs_f, ty: f_ty, scalar }, - StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))), - ) if lhs_f == lhs_s - && s_c.const_.ty() == f_ty - && s_c.const_.try_eval_scalar_int(tcx, typing_env) == Some(scalar) => {} - ( - ExpectedTransformKind::Cast { place: lhs_f, ty: f_ty }, - StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))), - ) if let Some(f) = s_c.const_.try_eval_scalar_int(tcx, typing_env) - && lhs_f == lhs_s - && s_c.const_.ty() == f_ty - && can_cast(tcx, other_val, discr_layout, f_ty, f) => {} - _ => return None, - } - } - } - self.transform_kinds = expected_transform_kinds.into_iter().map(|c| c.into()).collect(); - Some(()) - } + Some((case, rval)) + }) + .try_collect()?; + Some((*dest, rvals, otherwise)) +} - fn new_stmts( - &self, - _tcx: TyCtxt<'tcx>, - targets: &SwitchTargets, - _typing_env: ty::TypingEnv<'tcx>, - patch: &mut MirPatch<'tcx>, - parent_end: Location, - bbs: &IndexSlice>, - discr_local: Local, - discr_ty: Ty<'tcx>, - ) { - let (_, first) = targets.iter().next().unwrap(); - let first = &bbs[first]; +// Returns all ConstOperands if all Rvalues are ConstOperands. +fn candidate_const<'tcx, 'a>( + rvals: &'a [(u128, &'a Rvalue<'tcx>)], + otherwise: Option<&'a Rvalue<'tcx>>, +) -> Option<(Vec<(u128, &'a ConstOperand<'tcx>)>, Option<&'a ConstOperand<'tcx>>)> { + let otherwise = if let Some(otherwise) = otherwise { + let Rvalue::Use(Operand::Constant(box const_)) = otherwise else { + return None; + }; + Some(const_) + } else { + None + }; + let consts = rvals + .into_iter() + .map(|&(case, rval)| { + let Rvalue::Use(Operand::Constant(box const_)) = rval else { return None }; + Some((case, const_)) + }) + .try_collect()?; + Some((consts, otherwise)) +} - for (t, s) in iter::zip(&self.transform_kinds, &first.statements) { - match (t, &s.kind) { - (TransformKind::Same, _) => { - patch.add_statement(parent_end, s.kind.clone()); - } - ( - TransformKind::Cast, - StatementKind::Assign(box (lhs, Rvalue::Use(Operand::Constant(f_c)))), - ) => { - let operand = Operand::Copy(Place::from(discr_local)); - let r_val = if f_c.const_.ty() == discr_ty { - Rvalue::Use(operand) - } else { - Rvalue::Cast(CastKind::IntToInt, operand, f_c.const_.ty()) - }; - patch.add_assign(parent_end, *lhs, r_val); - } - _ => unreachable!(), - } - } +// Returns the first case and others (including otherwise if present). +fn split_first_case<'a, T>( + stmts: &'a [(u128, &'a T)], + otherwise: Option<&'a T>, +) -> (u128, &'a T, impl Iterator) { + let (first_case, first) = stmts[0]; + (first_case, first, stmts[1..].into_iter().map(|&(_, val)| val).chain(otherwise)) +} + +// If all statements are identical, we can optimize. +fn identical_stmts<'tcx>( + stmts: &[(u128, &StatementKind<'tcx>)], + otherwise: Option<&StatementKind<'tcx>>, +) -> Option> { + use itertools::Itertools; + let (_, first_stmt, others) = split_first_case(stmts, otherwise); + if std::iter::once(first_stmt).chain(others).all_equal() { + return Some(first_stmt.clone()); } + None } diff --git a/tests/mir-opt/matches_reduce_branches.match_eq_bool.MatchBranchSimplification.diff b/tests/mir-opt/matches_reduce_branches.match_eq_bool.MatchBranchSimplification.diff new file mode 100644 index 0000000000000..a896f66866e62 --- /dev/null +++ b/tests/mir-opt/matches_reduce_branches.match_eq_bool.MatchBranchSimplification.diff @@ -0,0 +1,49 @@ +- // MIR for `match_eq_bool` before MatchBranchSimplification ++ // MIR for `match_eq_bool` after MatchBranchSimplification + + fn match_eq_bool(_1: i32) -> bool { + debug i => _1; + let mut _0: bool; + let _2: bool; + let _3: (); ++ let mut _4: i32; + scope 1 { + debug a => _2; + } + + bb0: { + StorageLive(_2); + StorageLive(_3); +- switchInt(copy _1) -> [7: bb3, 8: bb2, otherwise: bb1]; +- } +- +- bb1: { +- _2 = const true; ++ StorageLive(_4); ++ _4 = copy _1; ++ _2 = Ne(copy _4, const 7_i32); + _3 = (); +- goto -> bb4; +- } +- +- bb2: { +- _2 = const true; +- _3 = (); +- goto -> bb4; +- } +- +- bb3: { +- _2 = const false; +- _3 = (); +- goto -> bb4; +- } +- +- bb4: { ++ StorageDead(_4); + StorageDead(_3); + _0 = copy _2; + StorageDead(_2); + return; + } + } + diff --git a/tests/mir-opt/matches_reduce_branches.match_eq_bool_2.MatchBranchSimplification.diff b/tests/mir-opt/matches_reduce_branches.match_eq_bool_2.MatchBranchSimplification.diff new file mode 100644 index 0000000000000..3ed8e13304172 --- /dev/null +++ b/tests/mir-opt/matches_reduce_branches.match_eq_bool_2.MatchBranchSimplification.diff @@ -0,0 +1,44 @@ +- // MIR for `match_eq_bool_2` before MatchBranchSimplification ++ // MIR for `match_eq_bool_2` after MatchBranchSimplification + + fn match_eq_bool_2(_1: i32) -> bool { + debug i => _1; + let mut _0: bool; + let _2: bool; + let _3: (); + scope 1 { + debug a => _2; + } + + bb0: { + StorageLive(_2); + StorageLive(_3); + switchInt(copy _1) -> [7: bb3, 8: bb2, otherwise: bb1]; + } + + bb1: { + _2 = const true; + _3 = (); + goto -> bb4; + } + + bb2: { + _2 = const false; + _3 = (); + goto -> bb4; + } + + bb3: { + _2 = const false; + _3 = (); + goto -> bb4; + } + + bb4: { + StorageDead(_3); + _0 = copy _2; + StorageDead(_2); + return; + } + } + diff --git a/tests/mir-opt/matches_reduce_branches.rs b/tests/mir-opt/matches_reduce_branches.rs index 89ef3bfb30857..df887d246a119 100644 --- a/tests/mir-opt/matches_reduce_branches.rs +++ b/tests/mir-opt/matches_reduce_branches.rs @@ -81,6 +81,54 @@ fn match_nested_if() -> bool { val } +// EMIT_MIR matches_reduce_branches.match_eq_bool.MatchBranchSimplification.diff +fn match_eq_bool(i: i32) -> bool { + // CHECK-LABEL: fn match_eq_bool( + // CHECK: = Ne( + // CHECK-NOT: switchInt + // CHECK: return + let a; + match i { + 7 => { + a = false; + () + } + 8 => { + a = true; + () + } + _ => { + a = true; + () + } + }; + a +} + +// EMIT_MIR matches_reduce_branches.match_eq_bool_2.MatchBranchSimplification.diff +fn match_eq_bool_2(i: i32) -> bool { + // CHECK-LABEL: fn match_eq_bool_2( + // CHECK-NOT: = Ne( + // CHECK: switchInt + // CHECK: return + let a; + match i { + 7 => { + a = false; + () + } + 8 => { + a = false; + () + } + _ => { + a = true; + () + } + }; + a +} + // # Fold switchInt into IntToInt. // To simplify writing and checking these test cases, I use the first character of // each case to distinguish the sign of the number: @@ -627,6 +675,29 @@ fn match_i128_u128(i: EnumAi128) -> u128 { } } +// EMIT_MIR matches_reduce_branches.single_case.MatchBranchSimplification.diff +#[custom_mir(dialect = "runtime")] +fn single_case(i: Option) -> i32 { + // CHECK-LABEL: fn single_case( + // CHECK-NOT: switchInt + mir! { + { + let discr = Discriminant(i); + match discr { + 0 => none, + _ => unreachable_bb, + } + } + none = { + RET = 1; + Return() + } + unreachable_bb = { + Unreachable() + } + } +} + // EMIT_MIR matches_reduce_branches.match_non_int_failed.MatchBranchSimplification.diff #[custom_mir(dialect = "runtime")] fn match_non_int_failed(i: char) -> u8 { diff --git a/tests/mir-opt/matches_reduce_branches.single_case.MatchBranchSimplification.diff b/tests/mir-opt/matches_reduce_branches.single_case.MatchBranchSimplification.diff new file mode 100644 index 0000000000000..54726613a3a55 --- /dev/null +++ b/tests/mir-opt/matches_reduce_branches.single_case.MatchBranchSimplification.diff @@ -0,0 +1,22 @@ +- // MIR for `single_case` before MatchBranchSimplification ++ // MIR for `single_case` after MatchBranchSimplification + + fn single_case(_1: Option) -> i32 { + let mut _0: i32; + let mut _2: isize; + + bb0: { + _2 = discriminant(_1); +- switchInt(copy _2) -> [0: bb1, otherwise: bb2]; +- } +- +- bb1: { + _0 = const 1_i32; + return; +- } +- +- bb2: { +- unreachable; + } + } + diff --git a/tests/mir-opt/simplify_locals_fixedpoint.foo.SimplifyLocals-final.panic-abort.diff b/tests/mir-opt/simplify_locals_fixedpoint.foo.SimplifyLocals-final.panic-abort.diff index ff1bc58524bc2..dd21719adb656 100644 --- a/tests/mir-opt/simplify_locals_fixedpoint.foo.SimplifyLocals-final.panic-abort.diff +++ b/tests/mir-opt/simplify_locals_fixedpoint.foo.SimplifyLocals-final.panic-abort.diff @@ -10,7 +10,6 @@ let mut _5: isize; - let mut _7: bool; - let mut _8: u8; -- let mut _9: bool; scope 1 { debug a => _6; let _6: u8; diff --git a/tests/mir-opt/simplify_locals_fixedpoint.foo.SimplifyLocals-final.panic-unwind.diff b/tests/mir-opt/simplify_locals_fixedpoint.foo.SimplifyLocals-final.panic-unwind.diff index 2c289c664754a..6e50b615030f9 100644 --- a/tests/mir-opt/simplify_locals_fixedpoint.foo.SimplifyLocals-final.panic-unwind.diff +++ b/tests/mir-opt/simplify_locals_fixedpoint.foo.SimplifyLocals-final.panic-unwind.diff @@ -10,7 +10,6 @@ let mut _5: isize; - let mut _7: bool; - let mut _8: u8; -- let mut _9: bool; scope 1 { debug a => _6; let _6: u8; From adf64bd305166b3c8848c44e62c83978cca3b69c Mon Sep 17 00:00:00 2001 From: dianqk Date: Sun, 16 Nov 2025 23:10:38 +0800 Subject: [PATCH 2/4] Simplify the canonical enum clone branches to a copy statement --- .../rustc_mir_transform/src/match_branches.rs | 177 ++++++++++--- ...unchecked.PreCodegen.after.panic-abort.mir | 12 - ...nchecked.PreCodegen.after.panic-unwind.mir | 12 - ...atch_option.MatchBranchSimplification.diff | 32 +++ ...option2_mut.MatchBranchSimplification.diff | 39 +++ tests/mir-opt/matches_reduce_branches.rs | 60 +++++ tests/mir-opt/pre-codegen/copy_and_clone.rs | 250 ++++++++++++++++++ ...witch_targets.ub_if_b.PreCodegen.after.mir | 10 - ....two_unwrap_unchecked.PreCodegen.after.mir | 20 +- 9 files changed, 524 insertions(+), 88 deletions(-) create mode 100644 tests/mir-opt/matches_reduce_branches.match_option.MatchBranchSimplification.diff create mode 100644 tests/mir-opt/matches_reduce_branches.match_option2_mut.MatchBranchSimplification.diff create mode 100644 tests/mir-opt/pre-codegen/copy_and_clone.rs diff --git a/compiler/rustc_mir_transform/src/match_branches.rs b/compiler/rustc_mir_transform/src/match_branches.rs index d5b65958fdde3..ed7734944bcc6 100644 --- a/compiler/rustc_mir_transform/src/match_branches.rs +++ b/compiler/rustc_mir_transform/src/match_branches.rs @@ -1,12 +1,14 @@ use rustc_abi::Integer; +use rustc_const_eval::const_eval::mk_eval_cx_for_const_val; use rustc_middle::mir::*; use rustc_middle::ty::layout::{IntegerExt, TyAndLayout}; +use rustc_middle::ty::util::Discr; use rustc_middle::ty::{self, ScalarInt, Ty, TyCtxt}; use super::simplify::simplify_cfg; use crate::patch::MirPatch; -/// Merges all targets into one basic block if each statement can have the same statement. +/// Unifies all targets into one basic block if each statement can have the same statement. pub(super) struct MatchBranchSimplification; impl<'tcx> crate::MirPass<'tcx> for MatchBranchSimplification { @@ -40,6 +42,7 @@ struct SimplifyMatch<'tcx, 'a> { patch: MirPatch<'tcx>, body: &'a Body<'tcx>, switch_bb: BasicBlock, + discr: &'a Operand<'tcx>, discr_local: Option, discr_ty: Ty<'tcx>, } @@ -53,8 +56,8 @@ impl<'tcx, 'a> SimplifyMatch<'tcx, 'a> { }) } - /// Merges the assignments if all rvalues are constants and equal. - fn merge_if_equal_const( + /// Unifies the assignments if all rvalues are constants and equal. + fn unify_if_equal_const( &self, dest: Place<'tcx>, consts: &[(u128, &ConstOperand<'tcx>)], @@ -76,7 +79,7 @@ impl<'tcx, 'a> SimplifyMatch<'tcx, 'a> { /// If a source block is found that switches between two blocks that are exactly /// the same modulo const bool assignments (e.g., one assigns true another false - /// to the same place), merge a target block statements into the source block, + /// to the same place), unify a target block statements into the source block, /// using Eq / Ne comparison with switch value where const bools value differ. /// /// For example: @@ -105,7 +108,7 @@ impl<'tcx, 'a> SimplifyMatch<'tcx, 'a> { /// goto -> bb3; /// } /// ``` - fn merge_by_eq_op( + fn unify_by_eq_op( &mut self, dest: Place<'tcx>, consts: &[(u128, &ConstOperand<'tcx>)], @@ -140,7 +143,7 @@ impl<'tcx, 'a> SimplifyMatch<'tcx, 'a> { } } - /// Merges the assignments if all rvalues can be cast from the discriminant value by IntToInt. + /// Unifies the assignments if all rvalues can be cast from the discriminant value by IntToInt. /// /// For example: /// @@ -177,7 +180,7 @@ impl<'tcx, 'a> SimplifyMatch<'tcx, 'a> { /// goto -> bb5; /// } /// ``` - fn merge_by_int_to_int( + fn unify_by_int_to_int( &mut self, dest: Place<'tcx>, consts: &[(u128, &ConstOperand<'tcx>)], @@ -207,10 +210,91 @@ impl<'tcx, 'a> SimplifyMatch<'tcx, 'a> { } } + /// This is primarily used to unify these copy statements that simplified the canonical enum clone method by GVN. + /// The GVN simplified + /// ```ignore (syntax-highlighting-only) + /// match a { + /// Foo::A(x) => Foo::A(*x), + /// Foo::B => Foo::B + /// } + /// ``` + /// to + /// ```ignore (syntax-highlighting-only) + /// match a { + /// Foo::A(_x) => a, // copy a + /// Foo::B => Foo::B + /// } + /// ``` + /// This will simplify into a copy statement. + fn unify_by_copy( + &self, + dest: Place<'tcx>, + rvals: &[(u128, &Rvalue<'tcx>)], + ) -> Option> { + let bbs = &self.body.basic_blocks; + // Check if the copy source matches the following pattern. + // _2 = discriminant(*_1); // "*_1" is the expected the copy source. + // switchInt(move _2) -> [0: bb3, 1: bb2, otherwise: bb1]; + let &Statement { + kind: StatementKind::Assign(box (discr_place, Rvalue::Discriminant(copy_src_place))), + .. + } = bbs[self.switch_bb].statements.last()? + else { + return None; + }; + if self.discr.place() != Some(discr_place) { + return None; + } + let src_ty = copy_src_place.ty(self.body.local_decls(), self.tcx); + if !src_ty.ty.is_enum() || src_ty.variant_index.is_some() { + return None; + } + let dest_ty = dest.ty(self.body.local_decls(), self.tcx); + if dest_ty.ty != src_ty.ty || dest_ty.variant_index.is_some() { + return None; + } + let ty::Adt(def, _) = dest_ty.ty.kind() else { + return None; + }; + + for &(case, rvalue) in rvals.iter() { + match rvalue { + // Check if `_3 = const Foo::B` can be transformed to `_3 = copy *_1`. + Rvalue::Use(Operand::Constant(box constant)) + if let Const::Val(const_, ty) = constant.const_ => + { + let (ecx, op) = mk_eval_cx_for_const_val( + self.tcx.at(constant.span), + self.typing_env, + const_, + ty, + )?; + let variant = ecx.read_discriminant(&op).discard_err()?; + if !def.variants()[variant].fields.is_empty() { + return None; + } + let Discr { val, .. } = ty.discriminant_for_variant(self.tcx, variant)?; + if val != case { + return None; + } + } + Rvalue::Use(Operand::Copy(src_place)) if *src_place == copy_src_place => {} + // Check if `_3 = Foo::B` can be transformed to `_3 = copy *_1`. + Rvalue::Aggregate(box AggregateKind::Adt(_, variant_index, _, _, None), fields) + if fields.is_empty() + && let Some(Discr { val, .. }) = + src_ty.ty.discriminant_for_variant(self.tcx, *variant_index) + && val == case => {} + _ => return None, + } + } + Some(StatementKind::Assign(Box::new((dest, Rvalue::Use(Operand::Copy(copy_src_place)))))) + } + /// Returns a new statement if we can use the statement replace all statements. - fn try_merge_stmts( + fn try_unify_stmts( &mut self, - _index: usize, + index: usize, stmts: &[(u128, &StatementKind<'tcx>)], otherwise: Option<&StatementKind<'tcx>>, ) -> Option> { @@ -220,25 +304,37 @@ impl<'tcx, 'a> SimplifyMatch<'tcx, 'a> { let (dest, rvals, otherwise) = candidate_assign(stmts, otherwise)?; if let Some((consts, otherwise)) = candidate_const(&rvals, otherwise) { - if let Some(new_stmt) = self.merge_if_equal_const(dest, &consts, otherwise) { + if let Some(new_stmt) = self.unify_if_equal_const(dest, &consts, otherwise) { return Some(new_stmt); } - if let Some(new_stmt) = self.merge_by_eq_op(dest, &consts, otherwise) { + if let Some(new_stmt) = self.unify_by_eq_op(dest, &consts, otherwise) { return Some(new_stmt); } // Requires the otherwise is unreachable. if otherwise.is_none() - && let Some(new_stmt) = self.merge_by_int_to_int(dest, &consts) + && let Some(new_stmt) = self.unify_by_int_to_int(dest, &consts) { return Some(new_stmt); } } + + // We only know the first statement is safe to introduce new dereferences. + if index == 0 + // We cannot create overlapping assignments. + && dest.is_stable_offset() + // Requires the otherwise is unreachable. + && otherwise.is_none() + && let Some(new_stmt) = self.unify_by_copy(dest, &rvals) + { + return Some(new_stmt); + } None } } /// Returns the first case target if all targets have an equal number of statements and identical destination. fn candidate_match<'tcx>(body: &Body<'tcx>, switch_bb: BasicBlock) -> bool { + use itertools::Itertools; let targets = match &body.basic_blocks[switch_bb].terminator().kind { TerminatorKind::SwitchInt { discr: Operand::Copy(_) | Operand::Move(_), targets, .. @@ -254,21 +350,14 @@ fn candidate_match<'tcx>(body: &Body<'tcx>, switch_bb: BasicBlock) -> bool { if !targets.is_distinct() { return false; } - let &[first, ref others @ .., otherwise] = targets.all_targets() else { - return false; - }; - let first_case_bb = &body.basic_blocks[first]; - let first_case_terminator_kind = &first_case_bb.terminator().kind; - let first_case_stmts_len = first_case_bb.statements.len(); - - let otherwise = - if body.basic_blocks[otherwise].is_empty_unreachable() { None } else { Some(&otherwise) }; // Check that destinations are identical, and if not, then don't optimize this block - others.iter().chain(otherwise).all(|&bb| { - let bb = &body.basic_blocks[bb]; - first_case_stmts_len == bb.statements.len() - && first_case_terminator_kind == &bb.terminator().kind - }) + targets + .all_targets() + .iter() + .map(|&bb| &body.basic_blocks[bb]) + .filter(|bb| !bb.is_empty_unreachable()) + .map(|bb| (bb.statements.len(), &bb.terminator().kind)) + .all_equal() } fn simplify_match<'tcx>( @@ -287,31 +376,41 @@ fn simplify_match<'tcx>( patch: MirPatch::new(body), body, switch_bb, + discr, discr_local: None, discr_ty: discr.ty(body.local_decls(), tcx), }; - let stmts: Vec<_> = targets - .iter() - .map(|(case, bb)| (case, simplify_match.body.basic_blocks[bb].statements.as_slice())) - .collect(); + let reachable_cases: Vec<_> = + targets.iter().filter(|&(_, bb)| !body.basic_blocks[bb].is_empty_unreachable()).collect(); let mut new_stmts = Vec::new(); - let otherwise_stmts = if body.basic_blocks[targets.otherwise()].is_empty_unreachable() { + let otherwise = if body.basic_blocks[targets.otherwise()].is_empty_unreachable() { None } else { - Some(body.basic_blocks[targets.otherwise()].statements.as_slice()) + Some(targets.otherwise()) + }; + // We can patch the terminator to goto because there is a single target. + match (&reachable_cases[..], otherwise) { + (&[(_, single_target)], None) | (&[], Some(single_target)) => { + let mut patch = simplify_match.patch; + patch.patch_terminator(switch_bb, TerminatorKind::Goto { target: single_target }); + patch.apply(body); + return true; + } + _ => {} + } + let Some(&(_, first_case_bb)) = reachable_cases.first() else { + return false; }; - let first_case_bb = targets.all_targets()[0]; let stmt_len = body.basic_blocks[first_case_bb].statements.len(); let mut cases = Vec::with_capacity(stmt_len); - // Check at each position in the basic blocks whether these statements can be merged. + // Check at each position in the basic blocks whether these statements can be unified. for index in 0..stmt_len { - let otherwise = otherwise_stmts.map(|stmt| &stmt[index].kind); cases.clear(); - for &(case, stmts) in &stmts { - cases.push((case, &stmts[index].kind)); + let otherwise = otherwise.map(|bb| &body.basic_blocks[bb].statements[index].kind); + for &(case, bb) in &reachable_cases { + cases.push((case, &body.basic_blocks[bb].statements[index].kind)); } - let Some(new_stmt) = simplify_match.try_merge_stmts(index, cases.as_slice(), otherwise) - else { + let Some(new_stmt) = simplify_match.try_unify_stmts(index, &cases, otherwise) else { return false; }; new_stmts.push(new_stmt); diff --git a/tests/mir-opt/inline/unwrap_unchecked.unwrap_unchecked.PreCodegen.after.panic-abort.mir b/tests/mir-opt/inline/unwrap_unchecked.unwrap_unchecked.PreCodegen.after.panic-abort.mir index b7b892c177c3e..2f1f925112b78 100644 --- a/tests/mir-opt/inline/unwrap_unchecked.unwrap_unchecked.PreCodegen.after.panic-abort.mir +++ b/tests/mir-opt/inline/unwrap_unchecked.unwrap_unchecked.PreCodegen.after.panic-abort.mir @@ -4,7 +4,6 @@ fn unwrap_unchecked(_1: Option) -> T { debug slf => _1; let mut _0: T; scope 1 (inlined #[track_caller] Option::::unwrap_unchecked) { - let mut _2: isize; scope 2 { } scope 3 (inlined #[track_caller] unreachable_unchecked) { @@ -16,18 +15,7 @@ fn unwrap_unchecked(_1: Option) -> T { } bb0: { - StorageLive(_2); - _2 = discriminant(_1); - switchInt(move _2) -> [0: bb2, 1: bb1, otherwise: bb2]; - } - - bb1: { _0 = copy ((_1 as Some).0: T); - StorageDead(_2); return; } - - bb2: { - unreachable; - } } diff --git a/tests/mir-opt/inline/unwrap_unchecked.unwrap_unchecked.PreCodegen.after.panic-unwind.mir b/tests/mir-opt/inline/unwrap_unchecked.unwrap_unchecked.PreCodegen.after.panic-unwind.mir index b7b892c177c3e..2f1f925112b78 100644 --- a/tests/mir-opt/inline/unwrap_unchecked.unwrap_unchecked.PreCodegen.after.panic-unwind.mir +++ b/tests/mir-opt/inline/unwrap_unchecked.unwrap_unchecked.PreCodegen.after.panic-unwind.mir @@ -4,7 +4,6 @@ fn unwrap_unchecked(_1: Option) -> T { debug slf => _1; let mut _0: T; scope 1 (inlined #[track_caller] Option::::unwrap_unchecked) { - let mut _2: isize; scope 2 { } scope 3 (inlined #[track_caller] unreachable_unchecked) { @@ -16,18 +15,7 @@ fn unwrap_unchecked(_1: Option) -> T { } bb0: { - StorageLive(_2); - _2 = discriminant(_1); - switchInt(move _2) -> [0: bb2, 1: bb1, otherwise: bb2]; - } - - bb1: { _0 = copy ((_1 as Some).0: T); - StorageDead(_2); return; } - - bb2: { - unreachable; - } } diff --git a/tests/mir-opt/matches_reduce_branches.match_option.MatchBranchSimplification.diff b/tests/mir-opt/matches_reduce_branches.match_option.MatchBranchSimplification.diff new file mode 100644 index 0000000000000..76148eb9bd436 --- /dev/null +++ b/tests/mir-opt/matches_reduce_branches.match_option.MatchBranchSimplification.diff @@ -0,0 +1,32 @@ +- // MIR for `match_option` before MatchBranchSimplification ++ // MIR for `match_option` after MatchBranchSimplification + + fn match_option(_1: &Option) -> Option { + debug i => _1; + let mut _0: std::option::Option; + let mut _2: isize; + + bb0: { + _2 = discriminant((*_1)); +- switchInt(move _2) -> [0: bb2, 1: bb3, otherwise: bb1]; +- } +- +- bb1: { +- unreachable; +- } +- +- bb2: { +- _0 = Option::::None; +- goto -> bb4; +- } +- +- bb3: { + _0 = copy (*_1); +- goto -> bb4; +- } +- +- bb4: { + return; + } + } + diff --git a/tests/mir-opt/matches_reduce_branches.match_option2_mut.MatchBranchSimplification.diff b/tests/mir-opt/matches_reduce_branches.match_option2_mut.MatchBranchSimplification.diff new file mode 100644 index 0000000000000..e6f273d086417 --- /dev/null +++ b/tests/mir-opt/matches_reduce_branches.match_option2_mut.MatchBranchSimplification.diff @@ -0,0 +1,39 @@ +- // MIR for `match_option2_mut` before MatchBranchSimplification ++ // MIR for `match_option2_mut` after MatchBranchSimplification + + fn match_option2_mut(_1: &mut Option2) -> Option2 { + let mut _0: Option2; + let mut _2: isize; + + bb0: { + _2 = discriminant((*_1)); + switchInt(copy _2) -> [0: bb1, 1: bb2, 2: bb3, otherwise: bb4]; + } + + bb1: { + (*_1) = Option2::::None2; + _0 = Option2::::None1; + goto -> bb5; + } + + bb2: { + (*_1) = Option2::::None2; + _0 = Option2::::None2; + goto -> bb5; + } + + bb3: { + (*_1) = Option2::::None2; + _0 = copy (*_1); + goto -> bb5; + } + + bb4: { + unreachable; + } + + bb5: { + return; + } + } + diff --git a/tests/mir-opt/matches_reduce_branches.rs b/tests/mir-opt/matches_reduce_branches.rs index df887d246a119..1766e77cf664b 100644 --- a/tests/mir-opt/matches_reduce_branches.rs +++ b/tests/mir-opt/matches_reduce_branches.rs @@ -675,6 +675,23 @@ fn match_i128_u128(i: EnumAi128) -> u128 { } } +// EMIT_MIR matches_reduce_branches.match_option.MatchBranchSimplification.diff +fn match_option(i: &Option) -> Option { + // CHECK-LABEL: fn match_option( + // CHECK-NOT: switchInt + // CHECK: _0 = copy (*_1); + match i { + Some(_) => *i, + None => None, + } +} + +enum Option2 { + None1, + None2, + Some(T), +} + // EMIT_MIR matches_reduce_branches.single_case.MatchBranchSimplification.diff #[custom_mir(dialect = "runtime")] fn single_case(i: Option) -> i32 { @@ -698,6 +715,47 @@ fn single_case(i: Option) -> i32 { } } +// We cannot dereference `i` after the value has been changed. +// EMIT_MIR matches_reduce_branches.match_option2_mut.MatchBranchSimplification.diff +#[custom_mir(dialect = "runtime")] +fn match_option2_mut(i: &mut Option2) -> Option2 { + // CHECK-LABEL: fn match_option2_mut( + // CHECK: switchInt + // CHECK: return + mir! { + { + let discr = Discriminant(*i); + match discr { + 0 => none1_bb, + 1 => none2_bb, + 2 => some_bb, + _ => unreachable_bb, + } + } + none1_bb = { + *i = Option2::None2; + RET = Option2::None1; + Goto(ret) + } + none2_bb = { + *i = Option2::None2; + RET = Option2::None2; + Goto(ret) + } + some_bb = { + *i = Option2::None2; + RET = *i; + Goto(ret) + } + unreachable_bb = { + Unreachable() + } + ret = { + Return() + } + } +} + // EMIT_MIR matches_reduce_branches.match_non_int_failed.MatchBranchSimplification.diff #[custom_mir(dialect = "runtime")] fn match_non_int_failed(i: char) -> u8 { @@ -767,4 +825,6 @@ fn main() { let _ = my_is_some(None); let _ = match_non_int_failed('a'); + let _ = match_option(&None); + let _ = match_option2_mut(&mut Option2::None1); } diff --git a/tests/mir-opt/pre-codegen/copy_and_clone.rs b/tests/mir-opt/pre-codegen/copy_and_clone.rs new file mode 100644 index 0000000000000..05da25afa2a39 --- /dev/null +++ b/tests/mir-opt/pre-codegen/copy_and_clone.rs @@ -0,0 +1,250 @@ +//@ [COPY] compile-flags: --cfg=copy +//@ revisions: COPY CLONE + +// Test case from https://github.com/rust-lang/rust/issues/128081. +// Ensure both Copy and Clone get optimized copy. + +#[unsafe(no_mangle)] +pub fn intra_clone(intra: &Av1BlockIntra) -> Av1BlockIntraInter { + // CHECK-LABEL: fn intra_clone( + // CHECK: [[C:_.*]] = copy (*_1); + // CHECK: _0 = Av1BlockIntraInter::Intra(move [[C]]); + Av1BlockIntraInter::Intra(intra.clone()) +} + +#[unsafe(no_mangle)] +pub fn inter_clone(inter: &Av1BlockInter) -> Av1BlockIntraInter { + // CHECK-LABEL: fn inter_clone( + // CHECK: [[C:_.*]] = copy (*_1); + // CHECK: _0 = Av1BlockIntraInter::Inter(move [[C]]); + Av1BlockIntraInter::Inter(inter.clone()) +} + +#[unsafe(no_mangle)] +pub fn dav1dsequenceheader_copy(v: &Dav1dSequenceHeader) -> Dav1dSequenceHeader { + // CHECK-LABEL: fn dav1dsequenceheader_copy( + // CHECK: _0 = copy (*_1); + v.clone() +} + +#[derive(Clone, Copy)] +#[repr(C)] +pub struct mv { + pub y: i16, + pub x: i16, +} + +#[derive(Clone, Copy)] +#[repr(transparent)] +pub struct MaskedInterIntraPredMode(u8); + +#[derive(Clone)] +#[cfg_attr(copy, derive(Copy))] +#[repr(C)] +pub struct Av1BlockInter1d { + pub mv: [mv; 2], + pub wedge_idx: u8, + pub mask_sign: u8, + pub interintra_mode: MaskedInterIntraPredMode, + pub _padding: u8, +} + +#[derive(Clone)] +#[cfg_attr(copy, derive(Copy))] +#[repr(C)] +pub struct Av1BlockInterNd { + pub one_d: Av1BlockInter1d, +} + +#[derive(Clone, Copy)] +pub enum CompInterType { + WeightedAvg = 1, + Avg = 2, + Seg = 3, + Wedge = 4, +} + +#[derive(Clone, Copy)] +pub enum MotionMode { + Translation = 0, + Obmc = 1, + Warp = 2, +} + +#[derive(Clone, Copy)] +pub enum DrlProximity { + Nearest, + Nearer, + Near, + Nearish, +} + +#[derive(Clone, Copy)] +pub enum TxfmSize { + S4x4 = 0, + S8x8 = 1, + S16x16 = 2, + S32x32 = 3, + S64x64 = 4, + R4x8 = 5, + R8x4 = 6, + R8x16 = 7, + R16x8 = 8, + R16x32 = 9, + R32x16 = 10, + R32x64 = 11, + R64x32 = 12, + R4x16 = 13, + R16x4 = 14, + R8x32 = 15, + R32x8 = 16, + R16x64 = 17, + R64x16 = 18, +} + +#[derive(Clone, Copy)] +pub enum Filter2d { + Regular8Tap = 0, + RegularSmooth8Tap = 1, + RegularSharp8Tap = 2, + SharpRegular8Tap = 3, + SharpSmooth8Tap = 4, + Sharp8Tap = 5, + SmoothRegular8Tap = 6, + Smooth8Tap = 7, + SmoothSharp8Tap = 8, + Bilinear = 9, +} + +#[derive(Clone, Copy)] +pub enum InterIntraType { + Blend, + Wedge, +} + +#[cfg_attr(copy, derive(Copy))] +#[derive(Clone)] +#[repr(C)] +pub struct Av1BlockInter { + pub nd: Av1BlockInterNd, + pub comp_type: Option, + pub inter_mode: u8, + pub motion_mode: MotionMode, + pub drl_idx: DrlProximity, + pub r#ref: [i8; 2], + pub max_ytx: TxfmSize, + pub filter2d: Filter2d, + pub interintra_type: Option, + pub tx_split0: u8, + pub tx_split1: u16, +} + +#[cfg_attr(copy, derive(Copy))] +#[derive(Clone)] +#[repr(C)] +pub struct Av1BlockIntra { + pub y_mode: u8, + pub uv_mode: u8, + pub tx: TxfmSize, + pub pal_sz: [u8; 2], + pub y_angle: i8, + pub uv_angle: i8, + pub cfl_alpha: [i8; 2], +} + +#[repr(C)] +pub enum Av1BlockIntraInter { + Intra(Av1BlockIntra), + Inter(Av1BlockInter), +} + +use std::ffi::{c_int, c_uint}; + +pub type Dav1dPixelLayout = c_uint; +pub type Dav1dColorPrimaries = c_uint; +pub type Dav1dTransferCharacteristics = c_uint; +pub type Dav1dMatrixCoefficients = c_uint; +pub type Dav1dChromaSamplePosition = c_uint; +pub type Dav1dAdaptiveBoolean = c_uint; + +#[derive(Clone, Copy)] +#[repr(C)] +pub struct Dav1dSequenceHeaderOperatingPoint { + pub major_level: u8, + pub minor_level: u8, + pub initial_display_delay: u8, + pub idc: u16, + pub tier: u8, + pub decoder_model_param_present: u8, + pub display_model_param_present: u8, +} + +#[derive(Clone, Copy)] +#[repr(C)] +pub struct Dav1dSequenceHeaderOperatingParameterInfo { + pub decoder_buffer_delay: u32, + pub encoder_buffer_delay: u32, + pub low_delay_mode: u8, +} + +pub const DAV1D_MAX_OPERATING_POINTS: usize = 32; + +#[cfg_attr(copy, derive(Copy))] +#[derive(Clone)] +#[repr(C)] +pub struct Dav1dSequenceHeader { + pub profile: u8, + pub max_width: c_int, + pub max_height: c_int, + pub layout: Dav1dPixelLayout, + pub pri: Dav1dColorPrimaries, + pub trc: Dav1dTransferCharacteristics, + pub mtrx: Dav1dMatrixCoefficients, + pub chr: Dav1dChromaSamplePosition, + pub hbd: u8, + pub color_range: u8, + pub num_operating_points: u8, + pub operating_points: [Dav1dSequenceHeaderOperatingPoint; DAV1D_MAX_OPERATING_POINTS], + pub still_picture: u8, + pub reduced_still_picture_header: u8, + pub timing_info_present: u8, + pub num_units_in_tick: u32, + pub time_scale: u32, + pub equal_picture_interval: u8, + pub num_ticks_per_picture: u32, + pub decoder_model_info_present: u8, + pub encoder_decoder_buffer_delay_length: u8, + pub num_units_in_decoding_tick: u32, + pub buffer_removal_delay_length: u8, + pub frame_presentation_delay_length: u8, + pub display_model_info_present: u8, + pub width_n_bits: u8, + pub height_n_bits: u8, + pub frame_id_numbers_present: u8, + pub delta_frame_id_n_bits: u8, + pub frame_id_n_bits: u8, + pub sb128: u8, + pub filter_intra: u8, + pub intra_edge_filter: u8, + pub inter_intra: u8, + pub masked_compound: u8, + pub warped_motion: u8, + pub dual_filter: u8, + pub order_hint: u8, + pub jnt_comp: u8, + pub ref_frame_mvs: u8, + pub screen_content_tools: Dav1dAdaptiveBoolean, + pub force_integer_mv: Dav1dAdaptiveBoolean, + pub order_hint_n_bits: u8, + pub super_res: u8, + pub cdef: u8, + pub restoration: u8, + pub ss_hor: u8, + pub ss_ver: u8, + pub monochrome: u8, + pub color_description_present: u8, + pub separate_uv_delta_q: u8, + pub film_grain_present: u8, + pub operating_parameter_info: + [Dav1dSequenceHeaderOperatingParameterInfo; DAV1D_MAX_OPERATING_POINTS], +} diff --git a/tests/mir-opt/pre-codegen/duplicate_switch_targets.ub_if_b.PreCodegen.after.mir b/tests/mir-opt/pre-codegen/duplicate_switch_targets.ub_if_b.PreCodegen.after.mir index 8a6732d5f745a..581665773272b 100644 --- a/tests/mir-opt/pre-codegen/duplicate_switch_targets.ub_if_b.PreCodegen.after.mir +++ b/tests/mir-opt/pre-codegen/duplicate_switch_targets.ub_if_b.PreCodegen.after.mir @@ -3,7 +3,6 @@ fn ub_if_b(_1: Thing) -> Thing { debug t => _1; let mut _0: Thing; - let mut _2: isize; scope 1 (inlined #[track_caller] unreachable_unchecked) { scope 2 (inlined core::ub_checks::check_language_ub) { scope 3 (inlined core::ub_checks::check_language_ub::runtime) { @@ -12,16 +11,7 @@ fn ub_if_b(_1: Thing) -> Thing { } bb0: { - _2 = discriminant(_1); - switchInt(move _2) -> [0: bb1, 1: bb2, otherwise: bb2]; - } - - bb1: { _0 = move _1; return; } - - bb2: { - unreachable; - } } diff --git a/tests/mir-opt/pre-codegen/two_unwrap_unchecked.two_unwrap_unchecked.PreCodegen.after.mir b/tests/mir-opt/pre-codegen/two_unwrap_unchecked.two_unwrap_unchecked.PreCodegen.after.mir index b2b7f88d8534b..a246104fe5ee7 100644 --- a/tests/mir-opt/pre-codegen/two_unwrap_unchecked.two_unwrap_unchecked.PreCodegen.after.mir +++ b/tests/mir-opt/pre-codegen/two_unwrap_unchecked.two_unwrap_unchecked.PreCodegen.after.mir @@ -4,11 +4,11 @@ fn two_unwrap_unchecked(_1: &Option) -> i32 { debug v => _1; let mut _0: i32; let mut _2: std::option::Option; - let _4: i32; + let _3: i32; scope 1 { - debug v1 => _4; + debug v1 => _3; scope 2 { - debug v2 => _4; + debug v2 => _3; } scope 8 (inlined #[track_caller] Option::::unwrap_unchecked) { scope 9 { @@ -22,7 +22,6 @@ fn two_unwrap_unchecked(_1: &Option) -> i32 { } } scope 3 (inlined #[track_caller] Option::::unwrap_unchecked) { - let mut _3: isize; scope 4 { } scope 5 (inlined #[track_caller] unreachable_unchecked) { @@ -35,17 +34,8 @@ fn two_unwrap_unchecked(_1: &Option) -> i32 { bb0: { _2 = copy (*_1); - _3 = discriminant(_2); - switchInt(copy _3) -> [0: bb2, 1: bb1, otherwise: bb2]; - } - - bb1: { - _4 = copy ((_2 as Some).0: i32); - _0 = Add(copy _4, copy _4); + _3 = copy ((_2 as Some).0: i32); + _0 = Add(copy _3, copy _3); return; } - - bb2: { - unreachable; - } } From 05f417cc46b835563700c02b971a8478b4f289f7 Mon Sep 17 00:00:00 2001 From: dianqk Date: Fri, 23 Jan 2026 18:36:27 +0800 Subject: [PATCH 3/4] Add assume to single target --- .../rustc_mir_transform/src/match_branches.rs | 9 ++++++--- .../rustc_mir_transform/src/unreachable_prop.rs | 14 +++++++------- .../issues/issue-107681-unwrap_unchecked.rs | 1 + .../issue-122600-ptr-discriminant-update.rs | 2 +- ...rap_unchecked.PreCodegen.after.panic-abort.mir | 7 +++++++ ...ap_unchecked.PreCodegen.after.panic-unwind.mir | 7 +++++++ ...hes.single_case.MatchBranchSimplification.diff | 3 +++ ...te_switch_targets.ub_if_b.PreCodegen.after.mir | 5 +++++ ...cked.two_unwrap_unchecked.PreCodegen.after.mir | 15 ++++++++++----- 9 files changed, 47 insertions(+), 16 deletions(-) diff --git a/compiler/rustc_mir_transform/src/match_branches.rs b/compiler/rustc_mir_transform/src/match_branches.rs index ed7734944bcc6..d72be74bc4594 100644 --- a/compiler/rustc_mir_transform/src/match_branches.rs +++ b/compiler/rustc_mir_transform/src/match_branches.rs @@ -7,6 +7,7 @@ use rustc_middle::ty::{self, ScalarInt, Ty, TyCtxt}; use super::simplify::simplify_cfg; use crate::patch::MirPatch; +use crate::unreachable_prop::remove_successors_from_switch; /// Unifies all targets into one basic block if each statement can have the same statement. pub(super) struct MatchBranchSimplification; @@ -389,10 +390,12 @@ fn simplify_match<'tcx>( Some(targets.otherwise()) }; // We can patch the terminator to goto because there is a single target. - match (&reachable_cases[..], otherwise) { - (&[(_, single_target)], None) | (&[], Some(single_target)) => { + match (reachable_cases.len(), otherwise.is_none()) { + (1, true) | (0, false) => { let mut patch = simplify_match.patch; - patch.patch_terminator(switch_bb, TerminatorKind::Goto { target: single_target }); + remove_successors_from_switch(tcx, switch_bb, body, &mut patch, |bb| { + body.basic_blocks[bb].is_empty_unreachable() + }); patch.apply(body); return true; } diff --git a/compiler/rustc_mir_transform/src/unreachable_prop.rs b/compiler/rustc_mir_transform/src/unreachable_prop.rs index c417a9272f2a9..ddc33eafc9138 100644 --- a/compiler/rustc_mir_transform/src/unreachable_prop.rs +++ b/compiler/rustc_mir_transform/src/unreachable_prop.rs @@ -35,7 +35,9 @@ impl crate::MirPass<'_> for UnreachablePropagation { } // Try to remove unreachable targets from the switch. TerminatorKind::SwitchInt { .. } => { - remove_successors_from_switch(tcx, bb, &unreachable_blocks, body, &mut patch) + remove_successors_from_switch(tcx, bb, body, &mut patch, |bb| { + unreachable_blocks.contains(&bb) + }) } _ => false, }; @@ -60,20 +62,18 @@ impl crate::MirPass<'_> for UnreachablePropagation { } /// Return whether the current terminator is fully unreachable. -fn remove_successors_from_switch<'tcx>( +pub(crate) fn remove_successors_from_switch<'tcx>( tcx: TyCtxt<'tcx>, bb: BasicBlock, - unreachable_blocks: &FxHashSet, body: &Body<'tcx>, patch: &mut MirPatch<'tcx>, + is_unreachable_block: impl Fn(BasicBlock) -> bool, ) -> bool { let terminator = body.basic_blocks[bb].terminator(); let TerminatorKind::SwitchInt { discr, targets } = &terminator.kind else { bug!() }; let source_info = terminator.source_info; let location = body.terminator_loc(bb); - let is_unreachable = |bb| unreachable_blocks.contains(&bb); - // If there are multiple targets, we want to keep information about reachability for codegen. // For example (see tests/codegen-llvm/match-optimizes-away.rs) // @@ -116,10 +116,10 @@ fn remove_successors_from_switch<'tcx>( }; let otherwise = targets.otherwise(); - let otherwise_unreachable = is_unreachable(otherwise); + let otherwise_unreachable = is_unreachable_block(otherwise); let reachable_iter = targets.iter().filter(|&(value, bb)| { - let is_unreachable = is_unreachable(bb); + let is_unreachable = is_unreachable_block(bb); // We remove this target from the switch, so record the inequality using `Assume`. if is_unreachable && !otherwise_unreachable { add_assumption(BinOp::Ne, value); diff --git a/tests/codegen-llvm/issues/issue-107681-unwrap_unchecked.rs b/tests/codegen-llvm/issues/issue-107681-unwrap_unchecked.rs index b8b9ea7436f33..5834255f3d313 100644 --- a/tests/codegen-llvm/issues/issue-107681-unwrap_unchecked.rs +++ b/tests/codegen-llvm/issues/issue-107681-unwrap_unchecked.rs @@ -14,6 +14,7 @@ pub unsafe fn foo(x: &mut Copied>) -> u32 { // CHECK-NOT: br {{.*}} // CHECK-NOT: select // CHECK: [[RET:%.*]] = load i32, ptr + // CHECK-NEXT: assume // CHECK-NEXT: ret i32 [[RET]] x.next().unwrap_unchecked() } diff --git a/tests/codegen-llvm/issues/issue-122600-ptr-discriminant-update.rs b/tests/codegen-llvm/issues/issue-122600-ptr-discriminant-update.rs index a0b453fac8e93..5b100d2cdc381 100644 --- a/tests/codegen-llvm/issues/issue-122600-ptr-discriminant-update.rs +++ b/tests/codegen-llvm/issues/issue-122600-ptr-discriminant-update.rs @@ -26,7 +26,7 @@ pub unsafe fn update(s: *mut State) { // CHECK-NOT: 75{{3|4}} // old: %[[TAG:.+]] = load i8, ptr %s, align 1 - // old-NEXT: trunc nuw i8 %[[TAG]] to i1 + // old-NEXT: and i8 %[[TAG]], 1 // CHECK-NOT: load // CHECK-NOT: store diff --git a/tests/mir-opt/inline/unwrap_unchecked.unwrap_unchecked.PreCodegen.after.panic-abort.mir b/tests/mir-opt/inline/unwrap_unchecked.unwrap_unchecked.PreCodegen.after.panic-abort.mir index 2f1f925112b78..e0fcd5c92247c 100644 --- a/tests/mir-opt/inline/unwrap_unchecked.unwrap_unchecked.PreCodegen.after.panic-abort.mir +++ b/tests/mir-opt/inline/unwrap_unchecked.unwrap_unchecked.PreCodegen.after.panic-abort.mir @@ -3,7 +3,9 @@ fn unwrap_unchecked(_1: Option) -> T { debug slf => _1; let mut _0: T; + let mut _3: bool; scope 1 (inlined #[track_caller] Option::::unwrap_unchecked) { + let mut _2: isize; scope 2 { } scope 3 (inlined #[track_caller] unreachable_unchecked) { @@ -15,7 +17,12 @@ fn unwrap_unchecked(_1: Option) -> T { } bb0: { + StorageLive(_2); + _2 = discriminant(_1); + _3 = Eq(copy _2, const 1_isize); + assume(move _3); _0 = copy ((_1 as Some).0: T); + StorageDead(_2); return; } } diff --git a/tests/mir-opt/inline/unwrap_unchecked.unwrap_unchecked.PreCodegen.after.panic-unwind.mir b/tests/mir-opt/inline/unwrap_unchecked.unwrap_unchecked.PreCodegen.after.panic-unwind.mir index 2f1f925112b78..e0fcd5c92247c 100644 --- a/tests/mir-opt/inline/unwrap_unchecked.unwrap_unchecked.PreCodegen.after.panic-unwind.mir +++ b/tests/mir-opt/inline/unwrap_unchecked.unwrap_unchecked.PreCodegen.after.panic-unwind.mir @@ -3,7 +3,9 @@ fn unwrap_unchecked(_1: Option) -> T { debug slf => _1; let mut _0: T; + let mut _3: bool; scope 1 (inlined #[track_caller] Option::::unwrap_unchecked) { + let mut _2: isize; scope 2 { } scope 3 (inlined #[track_caller] unreachable_unchecked) { @@ -15,7 +17,12 @@ fn unwrap_unchecked(_1: Option) -> T { } bb0: { + StorageLive(_2); + _2 = discriminant(_1); + _3 = Eq(copy _2, const 1_isize); + assume(move _3); _0 = copy ((_1 as Some).0: T); + StorageDead(_2); return; } } diff --git a/tests/mir-opt/matches_reduce_branches.single_case.MatchBranchSimplification.diff b/tests/mir-opt/matches_reduce_branches.single_case.MatchBranchSimplification.diff index 54726613a3a55..ba99ab0229497 100644 --- a/tests/mir-opt/matches_reduce_branches.single_case.MatchBranchSimplification.diff +++ b/tests/mir-opt/matches_reduce_branches.single_case.MatchBranchSimplification.diff @@ -4,6 +4,7 @@ fn single_case(_1: Option) -> i32 { let mut _0: i32; let mut _2: isize; ++ let mut _3: bool; bb0: { _2 = discriminant(_1); @@ -11,6 +12,8 @@ - } - - bb1: { ++ _3 = Eq(copy _2, const 0_isize); ++ assume(move _3); _0 = const 1_i32; return; - } diff --git a/tests/mir-opt/pre-codegen/duplicate_switch_targets.ub_if_b.PreCodegen.after.mir b/tests/mir-opt/pre-codegen/duplicate_switch_targets.ub_if_b.PreCodegen.after.mir index 581665773272b..d08aa8456e7f9 100644 --- a/tests/mir-opt/pre-codegen/duplicate_switch_targets.ub_if_b.PreCodegen.after.mir +++ b/tests/mir-opt/pre-codegen/duplicate_switch_targets.ub_if_b.PreCodegen.after.mir @@ -3,6 +3,8 @@ fn ub_if_b(_1: Thing) -> Thing { debug t => _1; let mut _0: Thing; + let mut _2: isize; + let mut _3: bool; scope 1 (inlined #[track_caller] unreachable_unchecked) { scope 2 (inlined core::ub_checks::check_language_ub) { scope 3 (inlined core::ub_checks::check_language_ub::runtime) { @@ -11,6 +13,9 @@ fn ub_if_b(_1: Thing) -> Thing { } bb0: { + _2 = discriminant(_1); + _3 = Eq(copy _2, const 0_isize); + assume(move _3); _0 = move _1; return; } diff --git a/tests/mir-opt/pre-codegen/two_unwrap_unchecked.two_unwrap_unchecked.PreCodegen.after.mir b/tests/mir-opt/pre-codegen/two_unwrap_unchecked.two_unwrap_unchecked.PreCodegen.after.mir index a246104fe5ee7..c0f3978663960 100644 --- a/tests/mir-opt/pre-codegen/two_unwrap_unchecked.two_unwrap_unchecked.PreCodegen.after.mir +++ b/tests/mir-opt/pre-codegen/two_unwrap_unchecked.two_unwrap_unchecked.PreCodegen.after.mir @@ -4,11 +4,12 @@ fn two_unwrap_unchecked(_1: &Option) -> i32 { debug v => _1; let mut _0: i32; let mut _2: std::option::Option; - let _3: i32; + let mut _4: bool; + let _5: i32; scope 1 { - debug v1 => _3; + debug v1 => _5; scope 2 { - debug v2 => _3; + debug v2 => _5; } scope 8 (inlined #[track_caller] Option::::unwrap_unchecked) { scope 9 { @@ -22,6 +23,7 @@ fn two_unwrap_unchecked(_1: &Option) -> i32 { } } scope 3 (inlined #[track_caller] Option::::unwrap_unchecked) { + let mut _3: isize; scope 4 { } scope 5 (inlined #[track_caller] unreachable_unchecked) { @@ -34,8 +36,11 @@ fn two_unwrap_unchecked(_1: &Option) -> i32 { bb0: { _2 = copy (*_1); - _3 = copy ((_2 as Some).0: i32); - _0 = Add(copy _3, copy _3); + _3 = discriminant(_2); + _4 = Eq(copy _3, const 1_isize); + assume(move _4); + _5 = copy ((_2 as Some).0: i32); + _0 = Add(copy _5, copy _5); return; } } From b4e3de09288f02d4b7ade045cafcc627e1709316 Mon Sep 17 00:00:00 2001 From: dianqk Date: Tue, 17 Feb 2026 12:38:45 +0800 Subject: [PATCH 4/4] Run MatchBranchSimplification with opt-level 2 --- compiler/rustc_mir_transform/src/match_branches.rs | 3 ++- tests/mir-opt/simplify_locals_fixedpoint.rs | 10 ++++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/compiler/rustc_mir_transform/src/match_branches.rs b/compiler/rustc_mir_transform/src/match_branches.rs index d72be74bc4594..05d085fafe937 100644 --- a/compiler/rustc_mir_transform/src/match_branches.rs +++ b/compiler/rustc_mir_transform/src/match_branches.rs @@ -14,7 +14,8 @@ pub(super) struct MatchBranchSimplification; impl<'tcx> crate::MirPass<'tcx> for MatchBranchSimplification { fn is_enabled(&self, sess: &rustc_session::Session) -> bool { - sess.mir_opt_level() >= 1 + // Enable only under -Zmir-opt-level=2 as this can make programs less debuggable. + sess.mir_opt_level() >= 2 } fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { diff --git a/tests/mir-opt/simplify_locals_fixedpoint.rs b/tests/mir-opt/simplify_locals_fixedpoint.rs index 0b6c95630c0a7..01aa7df0716f0 100644 --- a/tests/mir-opt/simplify_locals_fixedpoint.rs +++ b/tests/mir-opt/simplify_locals_fixedpoint.rs @@ -1,8 +1,12 @@ -// skip-filecheck // EMIT_MIR_FOR_EACH_PANIC_STRATEGY -//@ compile-flags: -Zmir-opt-level=1 +//@ compile-flags: -Zmir-opt-level=1 -Zmir-enable-passes=+MatchBranchSimplification +// EMIT_MIR simplify_locals_fixedpoint.foo.SimplifyLocals-final.diff fn foo() { + // CHECK-LABEL: fn foo( + // CHECK-NOT: let mut {{.*}}: bool; + // CHECK-NOT: let mut {{.*}}: u8; + // CHECK-NOT: let mut {{.*}}: bool; if let (Some(a), None) = (Option::::None, Option::::None) { if a > 42u8 {} } @@ -11,5 +15,3 @@ fn foo() { fn main() { foo::<()>(); } - -// EMIT_MIR simplify_locals_fixedpoint.foo.SimplifyLocals-final.diff