Skip to content

Commit

Permalink
Fix invalid bitwidth LLVM Pass.
Browse files Browse the repository at this point in the history
1. Upgrade non-standard sizes
2. Remove redundant trunc instructions
  • Loading branch information
pvelesko committed Jan 20, 2025
1 parent 620cca2 commit c470576
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 1 deletion.
2 changes: 1 addition & 1 deletion llvm_passes/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ add_library(LLVMHipPasses MODULE HipPasses.cpp
HipPrintf.cpp HipGlobalVariables.cpp HipTextureLowering.cpp HipAbort.cpp
HipEmitLoweredNames.cpp HipWarps.cpp HipKernelArgSpiller.cpp
HipLowerZeroLengthArrays.cpp HipSanityChecks.cpp HipLowerSwitch.cpp
HipLowerMemset.cpp HipIGBADetector.cpp ${EXTRA_OBJS})
HipLowerMemset.cpp HipIGBADetector.cpp HipPromoteInts.cpp ${EXTRA_OBJS})

# If trying to recompile with LLVM unloaded, the inlcude path is not found
target_compile_options(LLVMHipDynMem PRIVATE -I/${LLVM_INCLUDE_DIRS})
Expand Down
4 changes: 4 additions & 0 deletions llvm_passes/HipPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "HipLowerSwitch.h"
#include "HipLowerMemset.h"
#include "HipIGBADetector.h"
#include "HipPromoteInts.h"

#include "llvm/IR/Module.h"
#include "llvm/Passes/PassBuilder.h"
Expand Down Expand Up @@ -99,6 +100,9 @@ class HipFixOpenCLMDPass : public PassInfoMixin<HipFixOpenCLMDPass> {
static void addFullLinkTimePasses(ModulePassManager &MPM) {
MPM.addPass(HipSanityChecksPass());

// Fix InvalidBitWidth errors due to non-standard integer types
MPM.addPass(HipPromoteIntsPass());

/// For extracting name expression to lowered name expressions (hiprtc).
MPM.addPass(HipEmitLoweredNamesPass());

Expand Down
115 changes: 115 additions & 0 deletions llvm_passes/HipPromoteInts.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
#include "HipPromoteInts.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "hip-promote-ints"

/**
* This pass promotes integer types to the next standard bit width.
* During optimization of loops, LLVM generates non-standard integer types
* such as i33 or i56
*
* __global__ void testWarpCalc(int* debug) {
int tid = threadIdx.x;
int bid = blockIdx.x;
int globalIdx = bid * blockDim.x + tid;
// Optimizations on this loop will generate i33 types.
int result = 0;
for(int i = 0; i < tid + 1; i++) {
result += i * globalIdx;
}
// Store using atomic operation
atomicExch(&debug[globalIdx], result);
}
*
* https://github.com/KhronosGroup/SPIRV-LLVM-Translator/issues/2823
*/

using namespace llvm;

bool HipPromoteIntsPass::isStandardBitWidth(unsigned BitWidth) {
return BitWidth == 1 || BitWidth == 8 || BitWidth == 16 || BitWidth == 32 || BitWidth == 64;
}

unsigned HipPromoteIntsPass::getPromotedBitWidth(unsigned Original) {
if (Original <= 8) return 8;
if (Original <= 16) return 16;
if (Original <= 32) return 32;
return 64;
}

PreservedAnalyses HipPromoteIntsPass::run(Module &M, ModuleAnalysisManager &AM) {
bool Changed = false;

for (Function &F : M) {
LLVM_DEBUG(dbgs() << "[HipPromoteInts] Analyzing function: " << F.getName() << "\n");

for (BasicBlock &BB : F) {
// Use a vector to store instructions that need modification
std::vector<Instruction*> WorkList;
for (Instruction &I : BB) {
WorkList.push_back(&I);
}

// Process the worklist safely outside the BB iteration
for (Instruction *I : WorkList) {
if (auto *IntTy = dyn_cast<IntegerType>(I->getType())) {
if (!isStandardBitWidth(IntTy->getBitWidth())) {
LLVM_DEBUG(dbgs() << "[HipPromoteInts] Found non-standard type in result: " << *I << "\n");

unsigned NextStdSize = getPromotedBitWidth(IntTy->getBitWidth());
Type *PromotedType = Type::getIntNTy(M.getContext(), NextStdSize);

LLVM_DEBUG(dbgs() << "[HipPromoteInts] Promoting from i" << IntTy->getBitWidth()
<< " to i" << NextStdSize << "\n");

// Update the instruction type
I->mutateType(PromotedType);

// Special handling for trunc instructions where source and dest are same size
if (isa<TruncInst>(I)) {
auto *Trunc = cast<TruncInst>(I);
Value *Src = Trunc->getOperand(0);
if (auto *SrcIntTy = dyn_cast<IntegerType>(Src->getType())) {
if (SrcIntTy->getBitWidth() == NextStdSize) {
LLVM_DEBUG(dbgs() << "[HipPromoteInts] Found trunc with matching source size: " << *Trunc << "\n");
LLVM_DEBUG(dbgs() << "[HipPromoteInts] Source operand: " << *Src << "\n");
// When source and dest types are the same, just use the source directly
Trunc->replaceAllUsesWith(Src);
Trunc->eraseFromParent();
Changed = true;
continue;
}
}
}

// Update operands if needed
if (auto *BinOp = dyn_cast<BinaryOperator>(I)) {
Value *LHS = BinOp->getOperand(0);
Value *RHS = BinOp->getOperand(1);

IRBuilder<> Builder(I);
if (LHS->getType() != PromotedType) {
LHS = Builder.CreateZExtOrTrunc(LHS, PromotedType);
BinOp->setOperand(0, LHS);
}
if (RHS->getType() != PromotedType) {
RHS = Builder.CreateZExtOrTrunc(RHS, PromotedType);
BinOp->setOperand(1, RHS);
}
}

LLVM_DEBUG(dbgs() << "[HipPromoteInts] Instruction after promotion: " << *I << "\n");
Changed = true;
}
}
}
}
}

return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
}
22 changes: 22 additions & 0 deletions llvm_passes/HipPromoteInts.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#ifndef HIP_PROMOTE_INTS_H
#define HIP_PROMOTE_INTS_H

#include "llvm/IR/PassManager.h"
#include "llvm/IR/Module.h"

namespace llvm {

class HipPromoteIntsPass : public PassInfoMixin<HipPromoteIntsPass> {
public:
PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM);

// Promote a non-standard integer type to the next larger standard size
static unsigned getPromotedBitWidth(unsigned Original);

// Check if the given bit width is a standard size (8, 16, 32, 64)
static bool isStandardBitWidth(unsigned BitWidth);
};

} // namespace llvm

#endif // HIP_PROMOTE_INTS_H

0 comments on commit c470576

Please sign in to comment.