Skip to content

Commit

Permalink
Auto merge of #75382 - JulianKnodt:match_branches, r=oli-obk
Browse files Browse the repository at this point in the history
First iteration of simplify match branches

This is a simple MIR pass that attempts to convert
```
   bb0: {
        StorageLive(_2);
        _3 = discriminant(_1);
        switchInt(move _3) -> [0isize: bb2, otherwise: bb1];
    }

    bb1: {
        _2 = const false;
        goto -> bb3;
    }

    bb2: {
        _2 = const true;
        goto -> bb3;
    }
```
into
```
    bb0: {
        StorageLive(_2);
        _3 = discriminant(_1);
        _2 = _3 == 0;
        goto -> bb3;
    }
```
There are still missing components(like checking if the assignments are bools).
Was hoping that this could get some review though.

Handles #75141

r? @oli-obk
  • Loading branch information
bors committed Aug 13, 2020
2 parents b6396b7 + 46e5699 commit 5e3f1b1
Show file tree
Hide file tree
Showing 5 changed files with 240 additions and 0 deletions.
93 changes: 93 additions & 0 deletions src/librustc_mir/transform/match_branches.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
use crate::transform::{MirPass, MirSource};
use rustc_middle::mir::*;
use rustc_middle::ty::TyCtxt;

pub struct MatchBranchSimplification;

// What's the intent of this pass?
// If one block is found that switches between blocks which both go to the same place
// AND both of these blocks set a similar const in their ->
// condense into 1 block based on discriminant AND goto the destination afterwards

impl<'tcx> MirPass<'tcx> for MatchBranchSimplification {
fn run_pass(&self, tcx: TyCtxt<'tcx>, src: MirSource<'tcx>, body: &mut Body<'tcx>) {
let param_env = tcx.param_env(src.def_id());
let bbs = body.basic_blocks_mut();
'outer: for bb_idx in bbs.indices() {
let (discr, val, switch_ty, first, second) = match bbs[bb_idx].terminator().kind {
TerminatorKind::SwitchInt {
discr: Operand::Move(ref place),
switch_ty,
ref targets,
ref values,
..
} if targets.len() == 2 && values.len() == 1 => {
(place, values[0], switch_ty, targets[0], targets[1])
}
// Only optimize switch int statements
_ => continue,
};

// Check that destinations are identical, and if not, then don't optimize this block
if &bbs[first].terminator().kind != &bbs[second].terminator().kind {
continue;
}

// Check that blocks are assignments of consts to the same place or same statement,
// and match up 1-1, if not don't optimize this block.
let first_stmts = &bbs[first].statements;
let scnd_stmts = &bbs[second].statements;
if first_stmts.len() != scnd_stmts.len() {
continue;
}
for (f, s) in first_stmts.iter().zip(scnd_stmts.iter()) {
match (&f.kind, &s.kind) {
// If two statements are exactly the same just ignore them.
(f_s, s_s) if f_s == s_s => (),

(
StatementKind::Assign(box (lhs_f, Rvalue::Use(Operand::Constant(f_c)))),
StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))),
) if lhs_f == lhs_s => {
if let Some(f_c) = f_c.literal.try_eval_bool(tcx, param_env) {
// This should also be a bool because it's writing to the same place
let s_c = s_c.literal.try_eval_bool(tcx, param_env).unwrap();
if f_c != s_c {
// have to check this here because f_c & s_c might have
// different spans.
continue;
}
}
continue 'outer;
}
// If there are not exclusively assignments, then ignore this
_ => continue 'outer,
}
}
// Take owenership of items now that we know we can optimize.
let discr = discr.clone();
let (from, first) = bbs.pick2_mut(bb_idx, first);

let new_stmts = first.statements.iter().cloned().map(|mut s| {
if let StatementKind::Assign(box (_, ref mut rhs)) = s.kind {
if let Rvalue::Use(Operand::Constant(c)) = rhs {
let size = tcx.layout_of(param_env.and(switch_ty)).unwrap().size;
let const_cmp = Operand::const_from_scalar(
tcx,
switch_ty,
crate::interpret::Scalar::from_uint(val, size),
rustc_span::DUMMY_SP,
);
if let Some(c) = c.literal.try_eval_bool(tcx, param_env) {
let op = if c { BinOp::Eq } else { BinOp::Ne };
*rhs = Rvalue::BinaryOp(op, Operand::Move(discr), const_cmp);
}
}
}
s
});
from.statements.extend(new_stmts);
from.terminator_mut().kind = first.terminator().kind.clone();
}
}
}
2 changes: 2 additions & 0 deletions src/librustc_mir/transform/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ pub mod generator;
pub mod inline;
pub mod instcombine;
pub mod instrument_coverage;
pub mod match_branches;
pub mod no_landing_pads;
pub mod nrvo;
pub mod promote_consts;
Expand Down Expand Up @@ -440,6 +441,7 @@ fn run_optimization_passes<'tcx>(
// with async primitives.
&generator::StateTransform,
&instcombine::InstCombine,
&match_branches::MatchBranchSimplification,
&const_prop::ConstProp,
&simplify_branches::SimplifyBranches::new("after-const-prop"),
&simplify_try::SimplifyArmIdentity,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
- // MIR for `foo` before MatchBranchSimplification
+ // MIR for `foo` after MatchBranchSimplification

fn foo(_1: std::option::Option<()>) -> () {
debug bar => _1; // in scope 0 at $DIR/matches_reduce_branches.rs:4:8: 4:11
let mut _0: (); // return place in scope 0 at $DIR/matches_reduce_branches.rs:4:25: 4:25
let mut _2: bool; // in scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
let mut _3: isize; // in scope 0 at $DIR/matches_reduce_branches.rs:5:22: 5:26

bb0: {
StorageLive(_2); // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
_3 = discriminant(_1); // scope 0 at $DIR/matches_reduce_branches.rs:5:22: 5:26
- switchInt(move _3) -> [0_isize: bb2, otherwise: bb1]; // scope 0 at $DIR/matches_reduce_branches.rs:5:22: 5:26
+ _2 = Eq(move _3, const 0_isize); // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
+ // ty::Const
+ // + ty: isize
+ // + val: Value(Scalar(0x00000000))
+ // mir::Constant
+ // + span: $DIR/matches_reduce_branches.rs:1:1: 1:1
+ // + literal: Const { ty: isize, val: Value(Scalar(0x00000000)) }
+ goto -> bb3; // scope 0 at $DIR/matches_reduce_branches.rs:5:22: 5:26
}

bb1: {
_2 = const false; // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
// ty::Const
// + ty: bool
// + val: Value(Scalar(0x00))
// mir::Constant
// + span: $SRC_DIR/core/src/macros/mod.rs:LL:COL
// + literal: Const { ty: bool, val: Value(Scalar(0x00)) }
goto -> bb3; // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
}

bb2: {
_2 = const true; // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
// ty::Const
// + ty: bool
// + val: Value(Scalar(0x01))
// mir::Constant
// + span: $SRC_DIR/core/src/macros/mod.rs:LL:COL
// + literal: Const { ty: bool, val: Value(Scalar(0x01)) }
goto -> bb3; // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
}

bb3: {
switchInt(_2) -> [false: bb4, otherwise: bb5]; // scope 0 at $DIR/matches_reduce_branches.rs:5:5: 7:6
}

bb4: {
_0 = const (); // scope 0 at $DIR/matches_reduce_branches.rs:5:5: 7:6
// ty::Const
// + ty: ()
// + val: Value(Scalar(<ZST>))
// mir::Constant
// + span: $DIR/matches_reduce_branches.rs:5:5: 7:6
// + literal: Const { ty: (), val: Value(Scalar(<ZST>)) }
goto -> bb5; // scope 0 at $DIR/matches_reduce_branches.rs:5:5: 7:6
}

bb5: {
StorageDead(_2); // scope 0 at $DIR/matches_reduce_branches.rs:8:1: 8:2
return; // scope 0 at $DIR/matches_reduce_branches.rs:8:2: 8:2
}
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
- // MIR for `foo` before MatchBranchSimplification
+ // MIR for `foo` after MatchBranchSimplification

fn foo(_1: std::option::Option<()>) -> () {
debug bar => _1; // in scope 0 at $DIR/matches_reduce_branches.rs:4:8: 4:11
let mut _0: (); // return place in scope 0 at $DIR/matches_reduce_branches.rs:4:25: 4:25
let mut _2: bool; // in scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
let mut _3: isize; // in scope 0 at $DIR/matches_reduce_branches.rs:5:22: 5:26

bb0: {
StorageLive(_2); // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
_3 = discriminant(_1); // scope 0 at $DIR/matches_reduce_branches.rs:5:22: 5:26
- switchInt(move _3) -> [0_isize: bb2, otherwise: bb1]; // scope 0 at $DIR/matches_reduce_branches.rs:5:22: 5:26
+ _2 = Eq(move _3, const 0_isize); // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
+ // ty::Const
+ // + ty: isize
+ // + val: Value(Scalar(0x0000000000000000))
+ // mir::Constant
+ // + span: $DIR/matches_reduce_branches.rs:1:1: 1:1
+ // + literal: Const { ty: isize, val: Value(Scalar(0x0000000000000000)) }
+ goto -> bb3; // scope 0 at $DIR/matches_reduce_branches.rs:5:22: 5:26
}

bb1: {
_2 = const false; // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
// ty::Const
// + ty: bool
// + val: Value(Scalar(0x00))
// mir::Constant
// + span: $SRC_DIR/core/src/macros/mod.rs:LL:COL
// + literal: Const { ty: bool, val: Value(Scalar(0x00)) }
goto -> bb3; // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
}

bb2: {
_2 = const true; // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
// ty::Const
// + ty: bool
// + val: Value(Scalar(0x01))
// mir::Constant
// + span: $SRC_DIR/core/src/macros/mod.rs:LL:COL
// + literal: Const { ty: bool, val: Value(Scalar(0x01)) }
goto -> bb3; // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
}

bb3: {
switchInt(_2) -> [false: bb4, otherwise: bb5]; // scope 0 at $DIR/matches_reduce_branches.rs:5:5: 7:6
}

bb4: {
_0 = const (); // scope 0 at $DIR/matches_reduce_branches.rs:5:5: 7:6
// ty::Const
// + ty: ()
// + val: Value(Scalar(<ZST>))
// mir::Constant
// + span: $DIR/matches_reduce_branches.rs:5:5: 7:6
// + literal: Const { ty: (), val: Value(Scalar(<ZST>)) }
goto -> bb5; // scope 0 at $DIR/matches_reduce_branches.rs:5:5: 7:6
}

bb5: {
StorageDead(_2); // scope 0 at $DIR/matches_reduce_branches.rs:8:1: 8:2
return; // scope 0 at $DIR/matches_reduce_branches.rs:8:2: 8:2
}
}

13 changes: 13 additions & 0 deletions src/test/mir-opt/matches_reduce_branches.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// EMIT_MIR_FOR_EACH_BIT_WIDTH
// EMIT_MIR matches_reduce_branches.foo.MatchBranchSimplification.diff

fn foo(bar: Option<()>) {
if matches!(bar, None) {
()
}
}

fn main() {
let _ = foo(None);
let _ = foo(Some(()));
}

0 comments on commit 5e3f1b1

Please sign in to comment.