Skip to content

Commit

Permalink
[InstCombine] Fold comparison of adding two z/sext booleans (#67895)
Browse files Browse the repository at this point in the history
- Add test coverage for sext/zext boolean additions
- [InstCombine] Fold comparison of adding two z/sext booleans

Fixes #64859.
  • Loading branch information
elhewaty authored Oct 6, 2023
1 parent 185e16d commit 5d8fb47
Show file tree
Hide file tree
Showing 2 changed files with 157 additions and 285 deletions.
129 changes: 73 additions & 56 deletions llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Transforms/InstCombine/InstCombiner.h"
#include <bitset>

using namespace llvm;
using namespace PatternMatch;
Expand Down Expand Up @@ -2895,19 +2896,89 @@ Instruction *InstCombinerImpl::foldICmpSubConstant(ICmpInst &Cmp,
return new ICmpInst(SwappedPred, Add, ConstantInt::get(Ty, ~C));
}

static Value *createLogicFromTable(const std::bitset<4> &Table, Value *Op0,
Value *Op1, IRBuilderBase &Builder,
bool HasOneUse) {
switch (Table.to_ulong()) {
case 0: // 0 0 0 0
return Builder.getFalse();
case 1: // 0 0 0 1
return HasOneUse ? Builder.CreateNot(Builder.CreateOr(Op0, Op1)) : nullptr;
case 2: // 0 0 1 0
return HasOneUse ? Builder.CreateAnd(Builder.CreateNot(Op0), Op1) : nullptr;
case 3: // 0 0 1 1
return Builder.CreateNot(Op0);
case 4: // 0 1 0 0
return HasOneUse ? Builder.CreateAnd(Op0, Builder.CreateNot(Op1)) : nullptr;
case 5: // 0 1 0 1
return Builder.CreateNot(Op1);
case 6: // 0 1 1 0
return Builder.CreateXor(Op0, Op1);
case 7: // 0 1 1 1
return HasOneUse ? Builder.CreateNot(Builder.CreateAnd(Op0, Op1)) : nullptr;
case 8: // 1 0 0 0
return Builder.CreateAnd(Op0, Op1);
case 9: // 1 0 0 1
return HasOneUse ? Builder.CreateNot(Builder.CreateXor(Op0, Op1)) : nullptr;
case 10: // 1 0 1 0
return Op1;
case 11: // 1 0 1 1
return HasOneUse ? Builder.CreateOr(Builder.CreateNot(Op0), Op1) : nullptr;
case 12: // 1 1 0 0
return Op0;
case 13: // 1 1 0 1
return HasOneUse ? Builder.CreateOr(Op0, Builder.CreateNot(Op1)) : nullptr;
case 14: // 1 1 1 0
return Builder.CreateOr(Op0, Op1);
case 15: // 1 1 1 1
return Builder.getTrue();
default:
llvm_unreachable("Invalid Operation");
}
return nullptr;
}

/// Fold icmp (add X, Y), C.
Instruction *InstCombinerImpl::foldICmpAddConstant(ICmpInst &Cmp,
BinaryOperator *Add,
const APInt &C) {
Value *Y = Add->getOperand(1);
Value *X = Add->getOperand(0);

Value *Op0, *Op1;
Instruction *Ext0, *Ext1;
const CmpInst::Predicate Pred = Cmp.getPredicate();
if (match(Add,
m_Add(m_CombineAnd(m_Instruction(Ext0), m_ZExtOrSExt(m_Value(Op0))),
m_CombineAnd(m_Instruction(Ext1),
m_ZExtOrSExt(m_Value(Op1))))) &&
Op0->getType()->isIntOrIntVectorTy(1) &&
Op1->getType()->isIntOrIntVectorTy(1)) {
unsigned BW = C.getBitWidth();
std::bitset<4> Table;
auto ComputeTable = [&](bool Op0Val, bool Op1Val) {
int Res = 0;
if (Op0Val)
Res += isa<ZExtInst>(Ext0) ? 1 : -1;
if (Op1Val)
Res += isa<ZExtInst>(Ext1) ? 1 : -1;
return ICmpInst::compare(APInt(BW, Res, true), C, Pred);
};

Table[0] = ComputeTable(false, false);
Table[1] = ComputeTable(false, true);
Table[2] = ComputeTable(true, false);
Table[3] = ComputeTable(true, true);
if (auto *Cond =
createLogicFromTable(Table, Op0, Op1, Builder, Add->hasOneUse()))
return replaceInstUsesWith(Cmp, Cond);
}
const APInt *C2;
if (Cmp.isEquality() || !match(Y, m_APInt(C2)))
return nullptr;

// Fold icmp pred (add X, C2), C.
Value *X = Add->getOperand(0);
Type *Ty = Add->getType();
const CmpInst::Predicate Pred = Cmp.getPredicate();

// If the add does not wrap, we can always adjust the compare by subtracting
// the constants. Equality comparisons are handled elsewhere. SGE/SLE/UGE/ULE
Expand Down Expand Up @@ -6410,60 +6481,6 @@ Instruction *InstCombinerImpl::foldICmpUsingBoolRange(ICmpInst &I) {
Y->getType()->isIntOrIntVectorTy(1) && Pred == ICmpInst::ICMP_ULE)
return BinaryOperator::CreateOr(Builder.CreateIsNull(X), Y);

const APInt *C;
if (match(I.getOperand(0), m_c_Add(m_ZExt(m_Value(X)), m_SExt(m_Value(Y)))) &&
match(I.getOperand(1), m_APInt(C)) &&
X->getType()->isIntOrIntVectorTy(1) &&
Y->getType()->isIntOrIntVectorTy(1)) {
unsigned BitWidth = C->getBitWidth();
Pred = I.getPredicate();
APInt Zero = APInt::getZero(BitWidth);
APInt MinusOne = APInt::getAllOnes(BitWidth);
APInt One(BitWidth, 1);
if ((C->sgt(Zero) && Pred == ICmpInst::ICMP_SGT) ||
(C->slt(Zero) && Pred == ICmpInst::ICMP_SLT))
return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
if ((C->sgt(One) && Pred == ICmpInst::ICMP_SLT) ||
(C->slt(MinusOne) && Pred == ICmpInst::ICMP_SGT))
return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));

if (I.getOperand(0)->hasOneUse()) {
APInt NewC = *C;
// canonicalize predicate to eq/ne
if ((*C == Zero && Pred == ICmpInst::ICMP_SLT) ||
(*C != Zero && *C != MinusOne && Pred == ICmpInst::ICMP_UGT)) {
// x s< 0 in [-1, 1] --> x == -1
// x u> 1(or any const !=0 !=-1) in [-1, 1] --> x == -1
NewC = MinusOne;
Pred = ICmpInst::ICMP_EQ;
} else if ((*C == MinusOne && Pred == ICmpInst::ICMP_SGT) ||
(*C != Zero && *C != One && Pred == ICmpInst::ICMP_ULT)) {
// x s> -1 in [-1, 1] --> x != -1
// x u< -1 in [-1, 1] --> x != -1
Pred = ICmpInst::ICMP_NE;
} else if (*C == Zero && Pred == ICmpInst::ICMP_SGT) {
// x s> 0 in [-1, 1] --> x == 1
NewC = One;
Pred = ICmpInst::ICMP_EQ;
} else if (*C == One && Pred == ICmpInst::ICMP_SLT) {
// x s< 1 in [-1, 1] --> x != 1
Pred = ICmpInst::ICMP_NE;
}

if (NewC == MinusOne) {
if (Pred == ICmpInst::ICMP_EQ)
return BinaryOperator::CreateAnd(Builder.CreateNot(X), Y);
if (Pred == ICmpInst::ICMP_NE)
return BinaryOperator::CreateOr(X, Builder.CreateNot(Y));
} else if (NewC == One) {
if (Pred == ICmpInst::ICMP_EQ)
return BinaryOperator::CreateAnd(X, Builder.CreateNot(Y));
if (Pred == ICmpInst::ICMP_NE)
return BinaryOperator::CreateOr(Builder.CreateNot(X), Y);
}
}
}

return nullptr;
}

Expand Down
Loading

0 comments on commit 5d8fb47

Please sign in to comment.