Skip to content

Commit

Permalink
Working!
Browse files Browse the repository at this point in the history
  • Loading branch information
pvelesko committed Jan 22, 2025
1 parent 828f7ea commit 5e3d475
Showing 1 changed file with 93 additions and 63 deletions.
156 changes: 93 additions & 63 deletions llvm_passes/HipPromoteInts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,22 +99,44 @@ void processInstruction(Instruction *I, Type *OldTy, Type *PromotedTy, std::stri
Replacements.push_back(Replacement(I, NewPhi));
}
else if (isa<ZExtInst>(I)) {
Value* SrcOp = cast<ZExtInst>(I)->getOperand(0);
auto NewInst = Builder.CreateZExt(SrcOp, PromotedTy);
errs() << Indent << " " << *I << " ============> " << *NewInst << "\n";
PromotedValues[I] = NewInst;
Replacements.push_back(Replacement(I, NewInst));
ZExtInst* ZExtI = cast<ZExtInst>(I);
Value* SrcOp = ZExtI->getOperand(0);

// If we're extending from our old type to our promoted type,
// just use the promoted value directly
if (SrcOp->getType() == OldTy && ZExtI->getDestTy() == PromotedTy) {
Value* PromotedSrc = PromotedValues.count(SrcOp) ? PromotedValues[SrcOp] : SrcOp;
errs() << Indent << " " << *I << " ============> Using promoted: " << *PromotedSrc << "\n";
PromotedValues[I] = PromotedSrc;
Replacements.push_back(Replacement(I, PromotedSrc));
} else {
// Otherwise handle as normal
Value* PromotedSrc = PromotedValues.count(SrcOp) ? PromotedValues[SrcOp] : SrcOp;
if (PromotedSrc->getType() != PromotedTy) {
PromotedSrc = Builder.CreateZExt(PromotedSrc, PromotedTy);
}
PromotedValues[I] = PromotedSrc;
Replacements.push_back(Replacement(I, PromotedSrc));
errs() << Indent << " " << *I << " ============> " << *PromotedSrc << "\n";
}
}
else if (isa<TruncInst>(I)) {
TruncInst* TruncI = cast<TruncInst>(I);
Value* SrcOp = TruncI->getOperand(0);
Value* PromotedSrc = getPromotedValue(SrcOp);
Value* PromotedSrc = PromotedValues.count(SrcOp) ? PromotedValues[SrcOp] : SrcOp;

// Always truncate to the original destination type
auto NewInst = Builder.CreateTrunc(PromotedSrc, TruncI->getDestTy());
errs() << Indent << " " << *I << " ============> " << *NewInst << "\n";
PromotedValues[I] = NewInst;
Replacements.push_back(Replacement(I, NewInst));
// Verify the source is actually of our promoted type
if (PromotedSrc->getType() != PromotedTy) {
PromotedSrc = Builder.CreateZExt(PromotedSrc, PromotedTy);
}

// Create a new trunc for external users
Value* NewTrunc = Builder.CreateTrunc(PromotedSrc, TruncI->getType());
errs() << Indent << " " << *I << " ============> " << *NewTrunc << "\n";

// Store both the promoted and truncated versions
PromotedValues[I] = PromotedSrc; // Use promoted version in our chain
Replacements.push_back(Replacement(I, NewTrunc)); // Replace old instruction with new trunc for external users
}
else if (isa<BinaryOperator>(I)) {
BinaryOperator* BinOp = cast<BinaryOperator>(I);
Expand Down Expand Up @@ -177,78 +199,86 @@ bool promoteChainPrint(Instruction *OldI, Type *OldTy, Type *PromotedTy,
SmallVectorImpl<Replacement> &Replacements,
SmallDenseMap<Value*, Value*> &PromotedValues,
unsigned Depth = 0) {
if (!Visited.insert(OldI).second)
// If we've already processed this instruction, just return
if (!Visited.insert(OldI).second) {
// If we have a promoted value for this instruction, use it
if (PromotedValues.count(OldI)) {
errs() << std::string(Depth * 2, ' ') << "Already processed: " << *OldI << "\n";
return true;
}
return false;
}

std::string Indent(Depth * 2, ' ');

// Process instruction but don't replace yet
// Process instruction
processInstruction(OldI, OldTy, PromotedTy, Indent, Replacements, PromotedValues);

// Recursively process all users before replacing anything
// Recursively process all users
for (User *U : OldI->users()) {
if (auto *UI = dyn_cast<Instruction>(U))
if (auto *UI = dyn_cast<Instruction>(U)) {
promoteChainPrint(UI, OldTy, PromotedTy, Visited, Replacements, PromotedValues, Depth + 1);
}
}

return true;
}

PreservedAnalyses HipPromoteIntsPass::run(Module &M, ModuleAnalysisManager &AM) {
bool Changed = false;

for (Function &F : M) {
(errs() << "[HipPromoteInts] Analyzing function: " << F.getName() << "\n");
bool Changed = false;
SmallPtrSet<Instruction*, 32> GlobalVisited; // Track all visited instructions across chains

for (BasicBlock &BB : F) {
// Use a vector to store instructions that need modification
std::vector<Instruction*> WorkList;
for (Instruction &I : BB) {
WorkList.push_back(&I);
}

// Process the worklist safely outside the BB iteration
for (Instruction *I : WorkList) {
if (auto *IntTy = dyn_cast<IntegerType>(I->getType())) {
if (!isStandardBitWidth(IntTy->getBitWidth())) {
(errs() << "[HipPromoteInts] Found non-standard type in result: " << *I << "\n");

unsigned NextStdSize = getPromotedBitWidth(IntTy->getBitWidth());
Type *PromotedType = Type::getIntNTy(M.getContext(), NextStdSize);

(errs() << "[HipPromoteInts] Promoting from i" << IntTy->getBitWidth()
<< " to i" << NextStdSize << "\n\nDependency chain:\n");

SmallPtrSet<Instruction*, 8> Visited;
SmallVector<Replacement, 16> Replacements;
SmallDenseMap<Value*, Value*> PromotedValues;

// First create all new instructions
promoteChainPrint(I, IntTy, PromotedType, Visited, Replacements, PromotedValues, 0);

// Now update all uses that aren't part of our promotion chain
for (const auto &R : Replacements) {
for (auto &U : R.Old->uses()) {
User *User = U.getUser();
// Only update uses that aren't in our promotion chain
if (!Visited.count(cast<Instruction>(User))) {
U.set(R.New);
for (Function &F : M) {
SmallVector<Instruction*, 16> WorkList;

// First collect all instructions we need to promote
for (BasicBlock &BB : F) {
for (Instruction &I : BB) {
if (auto *IntTy = dyn_cast<IntegerType>(I.getType())) {
if (!isStandardBitWidth(IntTy->getBitWidth())) {
WorkList.push_back(&I);
}
}
}

// Finally, remove old instructions in reverse order
for (auto It = Replacements.rbegin(); It != Replacements.rend(); ++It) {
It->Old->eraseFromParent();
}

// Process the worklist
for (Instruction *I : WorkList) {
// Skip if we've already processed this instruction as part of another chain
if (GlobalVisited.count(I))
continue;

if (auto *IntTy = dyn_cast<IntegerType>(I->getType())) {
if (!isStandardBitWidth(IntTy->getBitWidth())) {
Type *PromotedType = Type::getInt64Ty(M.getContext());

SmallVector<Replacement, 16> Replacements;
SmallDenseMap<Value*, Value*> PromotedValues;

// Use GlobalVisited instead of creating a new set
promoteChainPrint(I, IntTy, PromotedType, GlobalVisited,
Replacements, PromotedValues, 0);

// Update uses and cleanup as before
for (const auto &R : Replacements) {
for (auto &U : R.Old->uses()) {
User *User = U.getUser();
if (!GlobalVisited.count(cast<Instruction>(User))) {
U.set(R.New);
}
}
}

for (auto It = Replacements.rbegin(); It != Replacements.rend(); ++It) {
It->Old->eraseFromParent();
}

Changed = true;
}
}

Changed = true;
}
}
}
}
}

return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();

return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
}

0 comments on commit 5e3d475

Please sign in to comment.