Skip to content

Commit

Permalink
Fix assmuption of the extraction order, make it generic then make sur…
Browse files Browse the repository at this point in the history
…e of the order using pattern match

Change-Id: I053e47d156c37cf4d7ab5b2af83c348b4210631a
  • Loading branch information
hassnaaHamdi committed Jun 18, 2024
1 parent 6b0cbee commit e5154bd
Showing 1 changed file with 61 additions and 49 deletions.
110 changes: 61 additions & 49 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16586,67 +16586,80 @@ bool AArch64TargetLowering::lowerInterleavedStore(StoreInst *SI,
}

bool getDeinterleavedValues(Value *DI,
SmallVectorImpl<Value *> &DeinterleavedValues,
SmallVectorImpl<Instruction *> &DeadInsts) {
if (!DI->hasNUses(2))
SmallVectorImpl<Instruction *> &DeinterleavedValues) {
if (!DI->hasNUsesOrMore(2))
return false;

// make sure that the users of DI are extractValue instructions
auto *Extr0 = *(++DI->user_begin());
if (!match(Extr0, m_ExtractValue<0>(m_Deinterleave2(m_Value()))))
return false;
auto *Extr1 = *(DI->user_begin());
if (!match(Extr1, m_ExtractValue<1>(m_Deinterleave2(m_Value()))))
auto *Extr1 = dyn_cast<ExtractValueInst>(*(DI->user_begin()));
auto *Extr2 = dyn_cast<ExtractValueInst>(*(++DI->user_begin()));
if (!Extr1 || !Extr2)
return false;

// each extractValue instruction is expected to have a single user,
// which should be another DI
if (!Extr0->hasOneUser() || !Extr1->hasOneUser())
if (!Extr1->hasNUsesOrMore(1) || !Extr2->hasNUsesOrMore(1))
return false;
auto *DI1 = *(Extr0->user_begin());
if (!match(DI1, m_Deinterleave2(m_Value())))
auto *DI1 = *(Extr1->user_begin());
auto *DI2 = *(Extr2->user_begin());

if (!DI1->hasNUsesOrMore(2) || !DI2->hasNUsesOrMore(2))
return false;
auto *DI2 = *(Extr1->user_begin());
if (!match(DI2, m_Deinterleave2(m_Value())))
// Leaf nodes of the deinterleave tree:
auto *A = dyn_cast<ExtractValueInst>(*(DI1->user_begin()));
auto *B = dyn_cast<ExtractValueInst>(*(++DI1->user_begin()));
auto *C = dyn_cast<ExtractValueInst>(*(DI2->user_begin()));
auto *D = dyn_cast<ExtractValueInst>(*(++DI2->user_begin()));
// Make sure that the A,B,C,D are instructions of ExtractValue,
// before getting the extract index
if (!A || !B || !C || !D)
return false;

if (!DI1->hasNUses(2) || !DI2->hasNUses(2))
DeinterleavedValues.resize(4);
// Place the values into the vector in the order of extraction:
DeinterleavedValues[A->getIndices()[0] + (Extr1->getIndices()[0] * 2)] = A;
DeinterleavedValues[B->getIndices()[0] + (Extr1->getIndices()[0] * 2)] = B;
DeinterleavedValues[C->getIndices()[0] + (Extr2->getIndices()[0] * 2)] = C;
DeinterleavedValues[D->getIndices()[0] + (Extr2->getIndices()[0] * 2)] = D;

// Make sure that A,B,C,D match the deinterleave tree pattern
if (!match(DeinterleavedValues[0], m_ExtractValue<0>(m_Deinterleave2(
m_ExtractValue<0>(m_Deinterleave2(m_Value()))))) ||
!match(DeinterleavedValues[1], m_ExtractValue<1>(m_Deinterleave2(
m_ExtractValue<0>(m_Deinterleave2(m_Value()))))) ||
!match(DeinterleavedValues[2], m_ExtractValue<0>(m_Deinterleave2(
m_ExtractValue<1>(m_Deinterleave2(m_Value()))))) ||
!match(DeinterleavedValues[3], m_ExtractValue<1>(m_Deinterleave2(
m_ExtractValue<1>(m_Deinterleave2(m_Value())))))) {
LLVM_DEBUG(dbgs() << "matching deinterleave4 failed\n");
return false;

// Leaf nodes of the deinterleave tree
auto *A = *(++DI1->user_begin());
auto *C = *(DI1->user_begin());
auto *B = *(++DI2->user_begin());
auto *D = *(DI2->user_begin());

DeinterleavedValues.push_back(A);
DeinterleavedValues.push_back(B);
DeinterleavedValues.push_back(C);
DeinterleavedValues.push_back(D);

// These Values will not be used anymre,
// DI4 will be created instead of nested DI1 and DI2
DeadInsts.push_back(cast<Instruction>(DI1));
DeadInsts.push_back(cast<Instruction>(Extr0));
DeadInsts.push_back(cast<Instruction>(DI2));
DeadInsts.push_back(cast<Instruction>(Extr1));

}
// Order the values according to the deinterleaving order.
std::swap(DeinterleavedValues[1], DeinterleavedValues[2]);
return true;
}

void deleteDeadDeinterleaveInstructions(Instruction *DeadRoot) {
Value *DeadDeinterleave = nullptr, *DeadExtract = nullptr;
match(DeadRoot, m_ExtractValue(m_Value(DeadDeinterleave)));
assert(DeadDeinterleave != nullptr && "Match is expected to succeed");
match(DeadDeinterleave, m_Deinterleave2(m_Value(DeadExtract)));
assert(DeadExtract != nullptr && "Match is expected to succeed");
DeadRoot->eraseFromParent();
if (DeadDeinterleave->getNumUses() == 0)
cast<Instruction>(DeadDeinterleave)->eraseFromParent();
if (DeadExtract->getNumUses() == 0)
cast<Instruction>(DeadExtract)->eraseFromParent();
}

bool AArch64TargetLowering::lowerDeinterleaveIntrinsicToLoad(
IntrinsicInst *DI, LoadInst *LI) const {
// Only deinterleave2 supported at present.
if (DI->getIntrinsicID() != Intrinsic::vector_deinterleave2)
return false;

SmallVector<Value *, 4> DeinterleavedValues;
SmallVector<Instruction *, 10> DeadInsts;
SmallVector<Instruction *, 4> DeinterleavedValues;
const DataLayout &DL = DI->getModule()->getDataLayout();
unsigned Factor = 2;
VectorType *VTy = cast<VectorType>(DI->getType()->getContainedType(0));

if (getDeinterleavedValues(DI, DeinterleavedValues, DeadInsts)) {
if (getDeinterleavedValues(DI, DeinterleavedValues)) {
Factor = DeinterleavedValues.size();
VTy = cast<VectorType>(DeinterleavedValues[0]->getType());
}
Expand Down Expand Up @@ -16693,7 +16706,7 @@ bool AArch64TargetLowering::lowerDeinterleaveIntrinsicToLoad(
LdN = Builder.CreateCall(LdNFunc, Address, "ldN");
Value *Idx =
Builder.getInt64(I * LdTy->getElementCount().getKnownMinValue());
for (int J = 0; J < Factor; ++J) {
for (unsigned J = 0; J < Factor; ++J) {
WideValues[J] = Builder.CreateInsertVector(
VTy, WideValues[J], Builder.CreateExtractValue(LdN, J), Idx);
}
Expand All @@ -16703,7 +16716,7 @@ bool AArch64TargetLowering::lowerDeinterleaveIntrinsicToLoad(
else
Result = PoisonValue::get(StructType::get(VTy, VTy, VTy, VTy));
// Construct the wide result out of the small results.
for (int J = 0; J < Factor; ++J) {
for (unsigned J = 0; J < Factor; ++J) {
Result = Builder.CreateInsertValue(Result, WideValues[J], J);
}
} else {
Expand All @@ -16713,15 +16726,14 @@ bool AArch64TargetLowering::lowerDeinterleaveIntrinsicToLoad(
Result = Builder.CreateCall(LdNFunc, BaseAddr, "ldN");
}
if (Factor > 2) {
// Itereate over old deinterleaved values to replace it by
// the new deinterleaved values.
for (unsigned I = 0; I < DeinterleavedValues.size(); I++) {
llvm::Value *CurrentExtract = DeinterleavedValues[I];
Value *NewExtract = Builder.CreateExtractValue(Result, I);
CurrentExtract->replaceAllUsesWith(NewExtract);
cast<Instruction>(CurrentExtract)->eraseFromParent();
DeinterleavedValues[I]->replaceAllUsesWith(NewExtract);
}

for (auto &dead : DeadInsts)
dead->eraseFromParent();
for (unsigned I = 0; I < DeinterleavedValues.size(); I++)
deleteDeadDeinterleaveInstructions(DeinterleavedValues[I]);
return true;
}
DI->replaceAllUsesWith(Result);
Expand Down Expand Up @@ -16803,7 +16815,7 @@ bool AArch64TargetLowering::lowerInterleaveIntrinsicToStore(
Address = Builder.CreateGEP(StTy, BaseAddr, {Offset});
Value *Idx =
Builder.getInt64(I * StTy->getElementCount().getKnownMinValue());
for (int J = 0; J < Factor; J++) {
for (unsigned J = 0; J < Factor; J++) {
ValuesToInterleave[J] =
Builder.CreateExtractVector(StTy, WideValues[J], Idx);
}
Expand Down

0 comments on commit e5154bd

Please sign in to comment.