Skip to content

Commit

Permalink
Always promote phi node
Browse files Browse the repository at this point in the history
  • Loading branch information
pvelesko committed Jan 23, 2025
1 parent d0fe9e6 commit a16bf16
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 16 deletions.
44 changes: 28 additions & 16 deletions llvm_passes/HipPromoteInts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,19 +43,26 @@ unsigned HipPromoteIntsPass::getPromotedBitWidth(unsigned Original) {
return 64;
}

// Add this new structure to track replacements
Type* HipPromoteIntsPass::getPromotedType(Type* TypeToPromote) {
if (auto* IntTy = dyn_cast<IntegerType>(TypeToPromote)) {
unsigned PromotedWidth = getPromotedBitWidth(IntTy->getBitWidth());
return Type::getIntNTy(TypeToPromote->getContext(), PromotedWidth);
}
return TypeToPromote; // Return original type if not an integer
}

struct Replacement {
Instruction* Old;
Value* New;
Replacement(Instruction* O, Value* N) : Old(O), New(N) {}
};

void processInstruction(Instruction *I, Type *OldTy, Type *PromotedTy, std::string Indent,
void processInstruction(Instruction *I, Type *NonStdType, Type *PromotedTy, std::string Indent,
SmallVectorImpl<Replacement> &Replacements,
SmallDenseMap<Value*, Value*> &PromotedValues) {
IRBuilder<> Builder(I);

// Helper to get or create promoted version of a value
/// Helper to get or create promoted version of a value
auto getPromotedValue = [&](Value* V) -> Value* {
// First check if we already promoted this value
if (PromotedValues.count(V))
Expand All @@ -65,8 +72,8 @@ void processInstruction(Instruction *I, Type *OldTy, Type *PromotedTy, std::stri
if (V->getType() == PromotedTy)
return V;

// If it's the old type, promote it
if (V->getType() == OldTy) {
// If it's the non-standard type, promote it
if (V->getType() == NonStdType) {
auto NewV = Builder.CreateZExt(V, PromotedTy);
PromotedValues[V] = NewV;
return NewV;
Expand All @@ -78,7 +85,9 @@ void processInstruction(Instruction *I, Type *OldTy, Type *PromotedTy, std::stri

if (isa<PHINode>(I)) {
PHINode* Phi = cast<PHINode>(I);
PHINode* NewPhi = PHINode::Create(Phi->getType(), Phi->getNumIncomingValues(), "", Phi);
// Create new PHI node with the promoted type (e.g., i64) instead of original type
Type* PromotedType = HipPromoteIntsPass::getPromotedType(Phi->getType());
PHINode* NewPhi = PHINode::Create(PromotedType, Phi->getNumIncomingValues(), "", Phi);

// Copy all incoming values and blocks
for (unsigned i = 0; i < Phi->getNumIncomingValues(); ++i) {
Expand All @@ -88,24 +97,27 @@ void processInstruction(Instruction *I, Type *OldTy, Type *PromotedTy, std::stri
// If the incoming value is from our promotion chain, use the promoted value
Value* NewIncomingValue = PromotedValues.count(IncomingValue) ?
PromotedValues[IncomingValue] : IncomingValue;

if (NewIncomingValue->getType() != Phi->getType())
NewIncomingValue = Builder.CreateTrunc(NewIncomingValue, Phi->getType());

// If the incoming value isn't promoted yet, promote it now
if (NewIncomingValue->getType() != PromotedType) {
NewIncomingValue = Builder.CreateZExt(NewIncomingValue, PromotedType);
}

NewPhi->addIncoming(NewIncomingValue, IncomingBlock);
}

errs() << Indent << " " << *I << " ============> " << *NewPhi << "\n";
PromotedValues[I] = NewPhi;
PromotedValues[Phi] = NewPhi;
Phi->replaceAllUsesWith(NewPhi);
Replacements.push_back(Replacement(I, NewPhi));
}
else if (isa<ZExtInst>(I)) {
ZExtInst* ZExtI = cast<ZExtInst>(I);
Value* SrcOp = ZExtI->getOperand(0);

// If we're extending from our old type to our promoted type,
// If we're extending from our non-standard type to our promoted type,
// just use the promoted value directly
if (SrcOp->getType() == OldTy && ZExtI->getDestTy() == PromotedTy) {
if (SrcOp->getType() == NonStdType && ZExtI->getDestTy() == PromotedTy) {
Value* PromotedSrc = PromotedValues.count(SrcOp) ? PromotedValues[SrcOp] : SrcOp;
errs() << Indent << " " << *I << " ============> Using promoted: " << *PromotedSrc << "\n";
PromotedValues[I] = PromotedSrc;
Expand Down Expand Up @@ -141,7 +153,7 @@ void processInstruction(Instruction *I, Type *OldTy, Type *PromotedTy, std::stri
}
else if (isa<BinaryOperator>(I)) {
BinaryOperator* BinOp = cast<BinaryOperator>(I);
bool NeedsPromotion = (BinOp->getType() == OldTy);
bool NeedsPromotion = (BinOp->getType() == NonStdType);

Value* LHS = getPromotedValue(BinOp->getOperand(0));
Value* RHS = getPromotedValue(BinOp->getOperand(1));
Expand Down Expand Up @@ -195,7 +207,7 @@ void processInstruction(Instruction *I, Type *OldTy, Type *PromotedTy, std::stri
}
}

bool promoteChainPrint(Instruction *OldI, Type *OldTy, Type *PromotedTy,
bool promoteChainPrint(Instruction *OldI, Type *NonStdType, Type *PromotedTy,
SmallPtrSetImpl<Instruction*> &Visited,
SmallVectorImpl<Replacement> &Replacements,
SmallDenseMap<Value*, Value*> &PromotedValues,
Expand All @@ -213,12 +225,12 @@ bool promoteChainPrint(Instruction *OldI, Type *OldTy, Type *PromotedTy,
std::string Indent(Depth * 2, ' ');

// Process instruction
processInstruction(OldI, OldTy, PromotedTy, Indent, Replacements, PromotedValues);
processInstruction(OldI, NonStdType, PromotedTy, Indent, Replacements, PromotedValues);

// Recursively process all users
for (User *U : OldI->users()) {
if (auto *UI = dyn_cast<Instruction>(U)) {
promoteChainPrint(UI, OldTy, PromotedTy, Visited, Replacements, PromotedValues, Depth + 1);
promoteChainPrint(UI, NonStdType, PromotedTy, Visited, Replacements, PromotedValues, Depth + 1);
}
}

Expand Down
3 changes: 3 additions & 0 deletions llvm_passes/HipPromoteInts.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ class HipPromoteIntsPass : public PassInfoMixin<HipPromoteIntsPass> {

// Check if the given bit width is a standard size (8, 16, 32, 64)
static bool isStandardBitWidth(unsigned BitWidth);

// Get the promoted type for a given type
static Type* getPromotedType(Type* TypeToPromote);
};

} // namespace llvm
Expand Down

0 comments on commit a16bf16

Please sign in to comment.