Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[InstCombine] Fold comparison of adding two z/sext booleans #67895

Merged
merged 1 commit into from
Oct 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 73 additions & 56 deletions llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Transforms/InstCombine/InstCombiner.h"
#include <bitset>

using namespace llvm;
using namespace PatternMatch;
Expand Down Expand Up @@ -2895,19 +2896,89 @@ Instruction *InstCombinerImpl::foldICmpSubConstant(ICmpInst &Cmp,
return new ICmpInst(SwappedPred, Add, ConstantInt::get(Ty, ~C));
}

static Value *createLogicFromTable(const std::bitset<4> &Table, Value *Op0,
Value *Op1, IRBuilderBase &Builder,
bool HasOneUse) {
switch (Table.to_ulong()) {
case 0: // 0 0 0 0
return Builder.getFalse();
case 1: // 0 0 0 1
return HasOneUse ? Builder.CreateNot(Builder.CreateOr(Op0, Op1)) : nullptr;
case 2: // 0 0 1 0
return HasOneUse ? Builder.CreateAnd(Builder.CreateNot(Op0), Op1) : nullptr;
case 3: // 0 0 1 1
return Builder.CreateNot(Op0);
case 4: // 0 1 0 0
return HasOneUse ? Builder.CreateAnd(Op0, Builder.CreateNot(Op1)) : nullptr;
case 5: // 0 1 0 1
return Builder.CreateNot(Op1);
case 6: // 0 1 1 0
return Builder.CreateXor(Op0, Op1);
case 7: // 0 1 1 1
return HasOneUse ? Builder.CreateNot(Builder.CreateAnd(Op0, Op1)) : nullptr;
case 8: // 1 0 0 0
return Builder.CreateAnd(Op0, Op1);
case 9: // 1 0 0 1
return HasOneUse ? Builder.CreateNot(Builder.CreateXor(Op0, Op1)) : nullptr;
case 10: // 1 0 1 0
return Op1;
case 11: // 1 0 1 1
return HasOneUse ? Builder.CreateOr(Builder.CreateNot(Op0), Op1) : nullptr;
case 12: // 1 1 0 0
return Op0;
case 13: // 1 1 0 1
return HasOneUse ? Builder.CreateOr(Op0, Builder.CreateNot(Op1)) : nullptr;
case 14: // 1 1 1 0
return Builder.CreateOr(Op0, Op1);
case 15: // 1 1 1 1
return Builder.getTrue();
default:
llvm_unreachable("Invalid Operation");
}
return nullptr;
}

/// Fold icmp (add X, Y), C.
Instruction *InstCombinerImpl::foldICmpAddConstant(ICmpInst &Cmp,
BinaryOperator *Add,
const APInt &C) {
Value *Y = Add->getOperand(1);
Value *X = Add->getOperand(0);

Value *Op0, *Op1;
Instruction *Ext0, *Ext1;
const CmpInst::Predicate Pred = Cmp.getPredicate();
if (match(Add,
m_Add(m_CombineAnd(m_Instruction(Ext0), m_ZExtOrSExt(m_Value(Op0))),
m_CombineAnd(m_Instruction(Ext1),
m_ZExtOrSExt(m_Value(Op1))))) &&
Op0->getType()->isIntOrIntVectorTy(1) &&
Op1->getType()->isIntOrIntVectorTy(1)) {
unsigned BW = C.getBitWidth();
std::bitset<4> Table;
auto ComputeTable = [&](bool Op0Val, bool Op1Val) {
int Res = 0;
if (Op0Val)
Res += isa<ZExtInst>(Ext0) ? 1 : -1;
if (Op1Val)
Res += isa<ZExtInst>(Ext1) ? 1 : -1;
return ICmpInst::compare(APInt(BW, Res, true), C, Pred);
};

Table[0] = ComputeTable(false, false);
Table[1] = ComputeTable(false, true);
Table[2] = ComputeTable(true, false);
Table[3] = ComputeTable(true, true);
if (auto *Cond =
createLogicFromTable(Table, Op0, Op1, Builder, Add->hasOneUse()))
return replaceInstUsesWith(Cmp, Cond);
}
const APInt *C2;
if (Cmp.isEquality() || !match(Y, m_APInt(C2)))
return nullptr;

// Fold icmp pred (add X, C2), C.
Value *X = Add->getOperand(0);
Type *Ty = Add->getType();
const CmpInst::Predicate Pred = Cmp.getPredicate();

// If the add does not wrap, we can always adjust the compare by subtracting
// the constants. Equality comparisons are handled elsewhere. SGE/SLE/UGE/ULE
Expand Down Expand Up @@ -6410,60 +6481,6 @@ Instruction *InstCombinerImpl::foldICmpUsingBoolRange(ICmpInst &I) {
Y->getType()->isIntOrIntVectorTy(1) && Pred == ICmpInst::ICMP_ULE)
return BinaryOperator::CreateOr(Builder.CreateIsNull(X), Y);

const APInt *C;
if (match(I.getOperand(0), m_c_Add(m_ZExt(m_Value(X)), m_SExt(m_Value(Y)))) &&
match(I.getOperand(1), m_APInt(C)) &&
X->getType()->isIntOrIntVectorTy(1) &&
Y->getType()->isIntOrIntVectorTy(1)) {
unsigned BitWidth = C->getBitWidth();
Pred = I.getPredicate();
APInt Zero = APInt::getZero(BitWidth);
APInt MinusOne = APInt::getAllOnes(BitWidth);
APInt One(BitWidth, 1);
if ((C->sgt(Zero) && Pred == ICmpInst::ICMP_SGT) ||
(C->slt(Zero) && Pred == ICmpInst::ICMP_SLT))
return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
if ((C->sgt(One) && Pred == ICmpInst::ICMP_SLT) ||
(C->slt(MinusOne) && Pred == ICmpInst::ICMP_SGT))
return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));

if (I.getOperand(0)->hasOneUse()) {
APInt NewC = *C;
// canonicalize predicate to eq/ne
if ((*C == Zero && Pred == ICmpInst::ICMP_SLT) ||
(*C != Zero && *C != MinusOne && Pred == ICmpInst::ICMP_UGT)) {
// x s< 0 in [-1, 1] --> x == -1
// x u> 1(or any const !=0 !=-1) in [-1, 1] --> x == -1
NewC = MinusOne;
Pred = ICmpInst::ICMP_EQ;
} else if ((*C == MinusOne && Pred == ICmpInst::ICMP_SGT) ||
(*C != Zero && *C != One && Pred == ICmpInst::ICMP_ULT)) {
// x s> -1 in [-1, 1] --> x != -1
// x u< -1 in [-1, 1] --> x != -1
Pred = ICmpInst::ICMP_NE;
} else if (*C == Zero && Pred == ICmpInst::ICMP_SGT) {
// x s> 0 in [-1, 1] --> x == 1
NewC = One;
Pred = ICmpInst::ICMP_EQ;
} else if (*C == One && Pred == ICmpInst::ICMP_SLT) {
// x s< 1 in [-1, 1] --> x != 1
Pred = ICmpInst::ICMP_NE;
}

if (NewC == MinusOne) {
if (Pred == ICmpInst::ICMP_EQ)
return BinaryOperator::CreateAnd(Builder.CreateNot(X), Y);
if (Pred == ICmpInst::ICMP_NE)
return BinaryOperator::CreateOr(X, Builder.CreateNot(Y));
} else if (NewC == One) {
if (Pred == ICmpInst::ICMP_EQ)
return BinaryOperator::CreateAnd(X, Builder.CreateNot(Y));
if (Pred == ICmpInst::ICMP_NE)
return BinaryOperator::CreateOr(Builder.CreateNot(X), Y);
}
}
}

return nullptr;
}

Expand Down
Loading