Skip to content

Commit

Permalink
[Pass] fix sqrt add to or bug and getFirst/LastInst refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
chen.qian committed Dec 6, 2024
1 parent e116ff0 commit 89642ac
Showing 1 changed file with 89 additions and 75 deletions.
164 changes: 89 additions & 75 deletions llvm/lib/Target/RISCV/RISCVLoopUnrollAndRemainder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/InstCombine/InstCombine.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Scalar/DCE.h"
#include "llvm/Transforms/Scalar/DeadStoreElimination.h"
Expand Down Expand Up @@ -171,21 +172,19 @@ static ICmpInst *getLastICmpInstWithPredicate(BasicBlock *BB,
return lastICmp;
}

// Helper function to get the first ICmp instruction in a basic block
static ICmpInst *getFirstICmpInst(BasicBlock *BB) {
template <typename T> static T *getFirstInst(BasicBlock *BB) {
for (Instruction &I : *BB) {
if (auto *CI = dyn_cast<ICmpInst>(&I)) {
return CI;
if (T *Inst = dyn_cast<T>(&I)) {
return Inst;
}
}
return nullptr;
}

// Helper function to get the last ICmp instruction in a basic block
static ICmpInst *getLastICmpInst(BasicBlock *BB) {
for (auto it = BB->rbegin(); it != BB->rend(); ++it) {
if (auto *icmp = dyn_cast<ICmpInst>(&*it)) {
return icmp;
template <typename T> static T *getLastInst(BasicBlock *BB) {
for (Instruction &I : reverse(*BB)) {
if (T *Inst = dyn_cast<T>(&I)) {
return Inst;
}
}
return nullptr;
Expand Down Expand Up @@ -239,16 +238,6 @@ static PHINode *getLastI32Phi(BasicBlock *BB) {
return nullptr;
}

// Scans BB from its last instruction toward its first and returns the first
// PHINode met along the way (i.e. the last PHI in block order), or nullptr if
// the block contains no PHI node.
static PHINode *getLastPhi(BasicBlock *BB) {
  auto Cursor = BB->rbegin();
  const auto Stop = BB->rend();
  while (Cursor != Stop) {
    Instruction *Candidate = &*Cursor;
    if (auto *FoundPhi = dyn_cast<PHINode>(Candidate))
      return FoundPhi;
    ++Cursor;
  }
  return nullptr;
}

// Helper function to get the first CallInst with a specific name in a basic
// block
static CallInst *getFirstCallInstWithName(BasicBlock *BB, StringRef Name) {
Expand Down Expand Up @@ -406,6 +395,38 @@ static void movePHINodesToTop(BasicBlock &BB,
}
}

// Rewrites `add` instructions in ClonedForBody that carry both the
// no-signed-wrap and no-unsigned-wrap flags as `or disjoint` instructions on
// the same operands. The last qualifying add in the block is deliberately left
// as an add; only the ones before it are rewritten.
static void modifyFirdAddToOr(BasicBlock *ClonedForBody) {
  // Snapshot the candidates up front so the block is not mutated while it is
  // being iterated.
  SmallVector<BinaryOperator *> Candidates;
  for (Instruction &Inst : *ClonedForBody) {
    auto *BinOp = dyn_cast<BinaryOperator>(&Inst);
    if (!BinOp || BinOp->getOpcode() != Instruction::Add)
      continue;
    if (BinOp->hasNoSignedWrap() && BinOp->hasNoUnsignedWrap())
      Candidates.push_back(BinOp);
  }

  if (Candidates.empty())
    return;

  // Walk every candidate except the final one.
  for (size_t Idx = 0, LastIdx = Candidates.size() - 1; Idx != LastIdx;
       ++Idx) {
    BinaryOperator *OldAdd = Candidates[Idx];
    Value *LHS = OldAdd->getOperand(0);
    Value *RHS = OldAdd->getOperand(1);

    // Build the replacement immediately before the add it supersedes.
    Instruction *NewOr = BinaryOperator::CreateDisjoint(Instruction::Or, LHS,
                                                        RHS, "add", OldAdd);

    // Redirect all users to the new instruction, then drop the old add.
    OldAdd->replaceAllUsesWith(NewOr);
    OldAdd->eraseFromParent();

    // The name is set again once the old add is gone -- presumably so it is
    // not uniquified against the erased instruction's name; confirm.
    NewOr->setName("add");
  }
}

// Helper function to update predecessors to point to a new preheader
static void updatePredecessorsToPreheader(BasicBlock *ForBody,
BasicBlock *ForBodyPreheader) {
Expand Down Expand Up @@ -1151,7 +1172,7 @@ static Value *expandForCondPreheader(
}

// Get the icmp instruction in ForCondPreheader
ICmpInst *icmpInst = getFirstICmpInst(ForCondPreheader);
ICmpInst *icmpInst = getFirstInst<ICmpInst>(ForCondPreheader);

// Ensure we found the icmp instruction
assert(icmpInst && "Failed to find icmp instruction in ForCondPreheader");
Expand Down Expand Up @@ -1278,7 +1299,7 @@ static void insertUnusedInstructionsBeforeIcmp(PHINode *phiI32InClonedForBody,

static void modifyClonedForBody(BasicBlock *ClonedForBody) {

ICmpInst *lastIcmpEq = getLastICmpInst(ClonedForBody);
ICmpInst *lastIcmpEq = getLastInst<ICmpInst>(ClonedForBody);
assert(lastIcmpEq &&
"Failed to find last icmp eq instruction in ClonedForBody");

Expand Down Expand Up @@ -1472,7 +1493,7 @@ static void modifyForCondPreheader2(BasicBlock *ClonedForBody,
}

// Find operand 1 of the icmp instruction from ClonedForBody
ICmpInst *firstIcmp = getFirstICmpInst(ClonedForBody);
ICmpInst *firstIcmp = getFirstInst<ICmpInst>(ClonedForBody);
assert(firstIcmp && "Unable to find icmp instruction in ClonedForBody");
Value *IcmpOperand1 = firstIcmp->getOperand(1);

Expand Down Expand Up @@ -1549,7 +1570,7 @@ static void modifyForCondPreheader2(BasicBlock *ClonedForBody,

static Value *modifyClonedForBodyPreheader(BasicBlock *ClonedForBodyPreheader,
BasicBlock *ForBody) {
ICmpInst *firstIcmp = getFirstICmpInst(ForBody);
ICmpInst *firstIcmp = getFirstInst<ICmpInst>(ForBody);
assert(firstIcmp && "Unable to find icmp instruction in ForBody");

Value *IcmpOperand1 = firstIcmp->getOperand(1);
Expand Down Expand Up @@ -2011,35 +2032,27 @@ static Instruction *modifyAddToOrInClonedForBody(BasicBlock *ClonedForBody) {
return orInst;
}

static void modifyAddToOr(BasicBlock *ClonedForBody) {
SmallVector<BinaryOperator *> addInsts;
static void runInstCombinePass(Function &F) {
// Create necessary analysis managers
LoopAnalysisManager LAM;
FunctionAnalysisManager FAM;
CGSCCAnalysisManager CGAM;
ModuleAnalysisManager MAM;

// Collect all add instructions that meet the criteria
for (auto &I : *ClonedForBody) {
if (auto *binOp = dyn_cast<BinaryOperator>(&I)) {
if (binOp->getOpcode() == Instruction::Add) {
addInsts.push_back(binOp);
}
}
}
if (addInsts.empty()) {
return;
}
// Replace each add instruction with an or disjoint instruction
for (auto it = addInsts.begin(); it != std::prev(addInsts.end()); ++it) {
auto *addInst = *it;
// Create a new or disjoint instruction
Instruction *orInst =
BinaryOperator::CreateDisjoint(Instruction::Or, addInst->getOperand(0),
addInst->getOperand(1), "add", addInst);
// Create pass builder
PassBuilder PB;

// Replace all uses of the add instruction
addInst->replaceAllUsesWith(orInst);
// Register analyses
PB.registerModuleAnalyses(MAM);
PB.registerCGSCCAnalyses(CGAM);
PB.registerFunctionAnalyses(FAM);
PB.registerLoopAnalyses(LAM);
PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);

// Delete the original add instruction
addInst->eraseFromParent();
orInst->setName("add");
}
// Create function-level optimization pipeline
FunctionPassManager FPM;
FPM.addPass(InstCombinePass());
FPM.run(F, FAM);
}

static Value *unrolladdcClonedForBody(BasicBlock *ClonedForBody,
Expand All @@ -2058,7 +2071,7 @@ static Value *unrolladdcClonedForBody(BasicBlock *ClonedForBody,
assert(firstNonPHI && orInst && "Start or end instruction not found");

// Find the icmp instruction
Instruction *icmpInst = getFirstICmpInst(ClonedForBody);
Instruction *icmpInst = getFirstInst<ICmpInst>(ClonedForBody);

// Ensure that the icmp instruction is found
assert(icmpInst && "icmp instruction not found");
Expand Down Expand Up @@ -2298,7 +2311,7 @@ static void unrollAddc(Function &F, ScalarEvolution &SE, Loop *L,
assert(ForCondPreheader && "Expected to find for.cond.preheader!");
expandForCondPreheaderaddc(F, ForCondPreheader, ClonedForBody, ForBody, sub,
unroll_factor);
modifyAddToOr(ClonedForBody);
runInstCombinePass(F);
groupAndReorderInstructions(ClonedForBody);

// Verify the function
Expand Down Expand Up @@ -2816,11 +2829,11 @@ static void postUnrollLoopWithCount(Function &F, Loop *L, int unroll_count) {
insertPhiNodesForFMulAdd(LoopHeader, LoopPreheader, FMulAddCalls);

movePHINodesToTop(*LoopHeader);
modifyAddToOr(LoopHeader);
runInstCombinePass(F);
groupAndReorderInstructions(LoopHeader);

// Create for.end basic block after LoopHeader
ICmpInst *LastICmp = getLastICmpInst(LoopHeader);
ICmpInst *LastICmp = getLastInst<ICmpInst>(LoopHeader);
LastICmp->setPredicate(ICmpInst::ICMP_ULT);
// Get the first operand of LastICmp
Value *Operand1 = LastICmp->getOperand(1);
Expand Down Expand Up @@ -3023,7 +3036,7 @@ static bool shouldUnrollDotprodType(Function &F, LoopInfo *LI) {
}

static std::pair<Value *, Value *> modifyEntryBB(BasicBlock &entryBB) {
ICmpInst *icmp = getLastICmpInst(&entryBB);
ICmpInst *icmp = getLastInst<ICmpInst>(&entryBB);
assert(icmp && "icmp not found");
Value *start_index = icmp->getOperand(0);
Value *end_index = icmp->getOperand(1);
Expand Down Expand Up @@ -3115,7 +3128,7 @@ static void postUnrollLoopWithVariable(Function &F, Loop *L, int unroll_count) {
temp->insertBefore(LoopPreheader->getTerminator());
}

ICmpInst *lastICmp = getLastICmpInst(ForBody7);
ICmpInst *lastICmp = getLastInst<ICmpInst>(ForBody7);
assert(lastICmp && "icmp not found");
lastICmp->setOperand(1, Sub);
lastICmp->setPredicate(ICmpInst::ICMP_SLT);
Expand Down Expand Up @@ -3552,7 +3565,7 @@ static std::tuple<Value *, GetElementPtrInst *, Value *>
modifyOuterLoop4(Loop *L, BasicBlock *ForBodyMerged,
BasicBlock *CloneForBodyPreheader) {
BasicBlock *BB = L->getHeader();
PHINode *phi = getLastPhi(BB);
PHINode *phi = getLastInst<PHINode>(BB);
// Add new instructions
IRBuilder<> Builder(BB);
Builder.SetInsertPoint(phi->getNextNode());
Expand Down Expand Up @@ -3596,7 +3609,7 @@ static void modifyInnerLoop4(Loop *L, BasicBlock *ForBodyMerged, Value *Sub,
movePHINodesToTop(*ForBodyMerged);

groupAndReorderInstructions(ForBodyMerged);
ICmpInst *LastICmp = getLastICmpInst(ForBodyMerged);
ICmpInst *LastICmp = getLastInst<ICmpInst>(ForBodyMerged);
LastICmp->setPredicate(ICmpInst::ICMP_ULT);
LastICmp->setOperand(1, Sub);
swapTerminatorSuccessors(ForBodyMerged);
Expand Down Expand Up @@ -3653,7 +3666,8 @@ static void modifyInnerLoop4(Loop *L, BasicBlock *ForBodyMerged, Value *Sub,
AddPHI->addIncoming(Add2, NewForEnd);
Value *phifloatincomingvalue0 =
getFirstCallInstWithName(CloneForBody, "llvm.fmuladd.f32");
Value *phii32incomingvalue0 = getLastICmpInst(CloneForBody)->getOperand(0);
Value *phii32incomingvalue0 =
getLastInst<ICmpInst>(CloneForBody)->getOperand(0);
for (PHINode &Phi : CloneForBody->phis()) {
if (Phi.getType()->isIntegerTy(32)) {
Phi.setIncomingValue(0, AddPHI);
Expand All @@ -3676,7 +3690,7 @@ static void modifyInnerLoop4(Loop *L, BasicBlock *ForBodyMerged, Value *Sub,
static std::tuple<Value *, Value *, GetElementPtrInst *>
modifyOuterLoop8(Loop *L) {
BasicBlock *BB = L->getHeader();
ICmpInst *LastICmp = getLastICmpInst(BB);
ICmpInst *LastICmp = getLastInst<ICmpInst>(BB);
LastICmp->setPredicate(ICmpInst::ICMP_ULT);
swapTerminatorSuccessors(BB);

Expand Down Expand Up @@ -3714,7 +3728,7 @@ static std::tuple<Value *, Value *, GetElementPtrInst *>
modifyOuterLoop16(Loop *L) {
BasicBlock *BB = L->getHeader();
BasicBlock *BBLoopPreHeader = L->getLoopPreheader();
ICmpInst *LastICmp = getLastICmpInst(BB);
ICmpInst *LastICmp = getLastInst<ICmpInst>(BB);
LastICmp->setPredicate(ICmpInst::ICMP_ULT);
swapTerminatorSuccessors(BB);

Expand Down Expand Up @@ -3763,7 +3777,7 @@ static void modifyInnerLoop(Loop *L, BasicBlock *ForBodyMerged, Value *Add60,
movePHINodesToTop(*ForBodyMerged);

groupAndReorderInstructions(ForBodyMerged);
ICmpInst *LastICmp = getLastICmpInst(ForBodyMerged);
ICmpInst *LastICmp = getLastInst<ICmpInst>(ForBodyMerged);
LastICmp->setPredicate(ICmpInst::ICMP_ULT);
LastICmp->setOperand(1, Add60);
swapTerminatorSuccessors(ForBodyMerged);
Expand Down Expand Up @@ -3873,7 +3887,7 @@ static void modifyInnerLoop(Loop *L, BasicBlock *ForBodyMerged, Value *Add60,

Value *operand1 = unroll_count == 16
? getFirstI32Phi(OuterBB)
: getLastICmpInst(CloneForBody)->getOperand(1);
: getLastInst<ICmpInst>(CloneForBody)->getOperand(1);
// Create a new comparison instruction
ICmpInst *NewCmp =
new ICmpInst(ICmpInst::ICMP_UGT, PhiSum, operand1, "cmp182.not587");
Expand All @@ -3890,7 +3904,8 @@ static void modifyInnerLoop(Loop *L, BasicBlock *ForBodyMerged, Value *Add60,
getFirstCallInstWithName(CloneForBody, "llvm.fmuladd.f32");
for (PHINode &Phi : CloneForBody->phis()) {
if (Phi.getType()->isIntegerTy(32)) {
Phi.setIncomingValue(0, getLastICmpInst(CloneForBody)->getOperand(0));
Phi.setIncomingValue(0,
getLastInst<ICmpInst>(CloneForBody)->getOperand(0));
Phi.setIncomingBlock(0, CloneForBody);
Phi.setIncomingValue(1, PhiSum);
Phi.setIncomingBlock(1, ForEnd164);
Expand Down Expand Up @@ -3981,7 +3996,7 @@ static void modifyFirstCloneForBody(BasicBlock *CloneForBody,
lastAddInst = &I;
}
}
ICmpInst *LastCmpInst = getLastICmpInst(CloneForBody);
ICmpInst *LastCmpInst = getLastInst<ICmpInst>(CloneForBody);
LastCmpInst->setOperand(0, lastAddInst);
LastCmpInst->setOperand(1, Operand1);
FirstI32Phi->setIncomingValue(1, lastAddInst);
Expand Down Expand Up @@ -4045,7 +4060,7 @@ static void modifyFirdFirstLoop(Function &F, Loop *L, BasicBlock *ForBodyMerged,
getFirstI32Phi(ForCond23Preheader)->getIncomingBlock(0);
Instruction *FirstI32Phi = getFirstI32Phi(ForCondCleanup3);

ICmpInst *LastICmp = getLastICmpInst(ForCondCleanup3);
ICmpInst *LastICmp = getLastInst<ICmpInst>(ForCondCleanup3);
// Create new add instruction
IRBuilder<> Builder(LastICmp);
Value *Add269 = Builder.CreateNSWAdd(
Expand All @@ -4067,7 +4082,7 @@ static void modifyFirdFirstLoop(Function &F, Loop *L, BasicBlock *ForBodyMerged,

N_069->setIncomingValue(1, Add281);

ICmpInst *LastICmpInPreheader = getLastICmpInst(ForCond23Preheader);
ICmpInst *LastICmpInPreheader = getLastInst<ICmpInst>(ForCond23Preheader);
// Create new phi node
PHINode *N_0_lcssa = PHINode::Create(Type::getInt32Ty(F.getContext()), 2,
"n.0.lcssa", LastICmpInPreheader);
Expand All @@ -4093,7 +4108,7 @@ static void modifyFirdFirstLoop(Function &F, Loop *L, BasicBlock *ForBodyMerged,
Value *Add11 = Builder.CreateAdd(Operand1, CoeffPosLcssa);

ForBody27LrPh->getTerminator()->setSuccessor(0, CloneForBody);
ICmpInst *LastICmpInForBodyMerged = getLastICmpInst(ForBodyMerged);
ICmpInst *LastICmpInForBodyMerged = getLastInst<ICmpInst>(ForBodyMerged);
LastICmpInForBodyMerged->setOperand(1, Operand1);
LastICmpInForBodyMerged->setOperand(0, Inc20_7);

Expand Down Expand Up @@ -4159,9 +4174,8 @@ static void modifyFirdFirstLoop(Function &F, Loop *L, BasicBlock *ForBodyMerged,
CI->setOperand(2, PHI);
}
movePHINodesToTop(*ForBodyMerged);
modifyAddToOr(ForBodyMerged);

ICmpInst *LastICmpForBodyMerged = getLastICmpInst(ForBodyMerged);
modifyFirdAddToOr(ForBodyMerged);
ICmpInst *LastICmpForBodyMerged = getLastInst<ICmpInst>(ForBodyMerged);
LastICmpForBodyMerged->setPredicate(ICmpInst::ICMP_SGT);
cast<Instruction>(LastICmpForBodyMerged->getOperand(0))
->setOperand(0, getFirstI32Phi(ForBodyMerged));
Expand Down Expand Up @@ -4256,7 +4270,7 @@ static void modifyFirdFirstLoop(Function &F, Loop *L, BasicBlock *ForBodyMerged,
CoeffPosLcssaPhi->addIncoming(SubResult, ForCondCleanup26LoopExit);
// eraseAllStoreInstInBB(ForCondCleanup26);

ICmpInst *LastICmpForCondCleanup26 = getLastICmpInst(ForCondCleanup26);
ICmpInst *LastICmpForCondCleanup26 = getLastInst<ICmpInst>(ForCondCleanup26);

LastICmpForCondCleanup26->setPredicate(ICmpInst::ICMP_SLT);
PHINode *FirstI32ForCondCleanup3 = getFirstI32Phi(ForCondCleanup3);
Expand Down Expand Up @@ -4314,7 +4328,7 @@ static void modifyFirdFirstLoop(Function &F, Loop *L, BasicBlock *ForBodyMerged,
0, ConstantInt::get(getLastI32Phi(ForCond130Preheader)->getType(), 0));
LastI32Phi130->setIncomingValue(1, AndResult);

ICmpInst *LastICmp130 = getLastICmpInst(ForCond130Preheader);
ICmpInst *LastICmp130 = getLastInst<ICmpInst>(ForCond130Preheader);
LastICmp130->setOperand(1, FirstI32ForCondCleanup3);

PHINode *LastI32PhiClone = getLastFloatPhi(CloneForBody);
Expand Down Expand Up @@ -4434,9 +4448,8 @@ static void modifyFirdSecondLoop(Function &F, Loop *L,
Add76310->addIncoming(Add76, ForBodyMerged);

movePHINodesToTop(*ForBodyMerged);
modifyAddToOr(ForBodyMerged);

ICmpInst *LastICmp = getLastICmpInst(ForBodyMerged);
modifyFirdAddToOr(ForBodyMerged);
ICmpInst *LastICmp = getLastInst<ICmpInst>(ForBodyMerged);
LastICmp->setPredicate(ICmpInst::ICMP_SGT);
cast<Instruction>(Add76)->moveBefore(LastICmp);
LastICmp->setOperand(0, Add76);
Expand Down Expand Up @@ -5043,6 +5056,7 @@ RISCVLoopUnrollAndRemainderPass::run(Function &F, FunctionAnalysisManager &AM) {
if (currentUnrollType == UnrollType::FIRD) {
addLegacyCommonOptimizationPasses(F);
}

// Verify function
if (verifyFunction(F, &errs())) {
LLVM_DEBUG(errs() << "Function verification failed\n");
Expand Down

0 comments on commit 89642ac

Please sign in to comment.