Skip to content

Commit

Permalink
[NFC][InstCombine] Extract canTryToConstantAddTwoShiftAmounts() as he…
Browse files Browse the repository at this point in the history
…lper
  • Loading branch information
LebedevRI committed Apr 4, 2021
1 parent 5352490 commit dceb3e5
Showing 1 changed file with 26 additions and 20 deletions.
46 changes: 26 additions & 20 deletions llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,30 @@ using namespace PatternMatch;

#define DEBUG_TYPE "instcombine"

bool canTryToConstantAddTwoShiftAmounts(Value *Sh0, Value *ShAmt0, Value *Sh1,
Value *ShAmt1) {
// We have two shift amounts from two different shifts. The types of those
// shift amounts may not match. If that's the case let's bailout now..
if (ShAmt0->getType() != ShAmt1->getType())
return false;

// As input, we have the following pattern:
// Sh0 (Sh1 X, Q), K
// We want to rewrite that as:
// Sh x, (Q+K) iff (Q+K) u< bitwidth(x)
// While we know that originally (Q+K) would not overflow
// (because 2 * (N-1) u<= iN -1), we have looked past extensions of
// shift amounts. so it may now overflow in smaller bitwidth.
// To ensure that does not happen, we need to ensure that the total maximal
// shift amount is still representable in that smaller bit width.
unsigned MaximalPossibleTotalShiftAmount =
(Sh0->getType()->getScalarSizeInBits() - 1) +
(Sh1->getType()->getScalarSizeInBits() - 1);
APInt MaximalRepresentableShiftAmount =
APInt::getAllOnesValue(ShAmt0->getType()->getScalarSizeInBits());
return MaximalRepresentableShiftAmount.uge(MaximalPossibleTotalShiftAmount);
}

// Given pattern:
// (x shiftopcode Q) shiftopcode K
// we should rewrite it as
Expand Down Expand Up @@ -57,26 +81,8 @@ Value *InstCombinerImpl::reassociateShiftAmtsOfTwoSameDirectionShifts(
if (!match(Sh1, m_Shift(m_Value(X), m_ZExtOrSelf(m_Value(ShAmt1)))))
return nullptr;

// We have two shift amounts from two different shifts. The types of those
// shift amounts may not match. If that's the case let's bailout now..
if (ShAmt0->getType() != ShAmt1->getType())
return nullptr;

// As input, we have the following pattern:
// Sh0 (Sh1 X, Q), K
// We want to rewrite that as:
// Sh x, (Q+K) iff (Q+K) u< bitwidth(x)
// While we know that originally (Q+K) would not overflow
// (because 2 * (N-1) u<= iN -1), we have looked past extensions of
// shift amounts. so it may now overflow in smaller bitwidth.
// To ensure that does not happen, we need to ensure that the total maximal
// shift amount is still representable in that smaller bit width.
unsigned MaximalPossibleTotalShiftAmount =
(Sh0->getType()->getScalarSizeInBits() - 1) +
(Sh1->getType()->getScalarSizeInBits() - 1);
APInt MaximalRepresentableShiftAmount =
APInt::getAllOnesValue(ShAmt0->getType()->getScalarSizeInBits());
if (MaximalRepresentableShiftAmount.ult(MaximalPossibleTotalShiftAmount))
// Verify that it would be safe to try to add those two shift amounts.
if (!canTryToConstantAddTwoShiftAmounts(Sh0, ShAmt0, Sh1, ShAmt1))
return nullptr;

// We are only looking for signbit extraction if we have two right shifts.
Expand Down

0 comments on commit dceb3e5

Please sign in to comment.