Skip to content

Commit

Permalink
[AArch64]: Use PatternMatch to spot (de)interleave accesses
Browse files Browse the repository at this point in the history
Change-Id: Id7639dcb125a2f642b2fea78ea884b74be1c6b74
  • Loading branch information
hassnaaHamdi committed May 17, 2024
1 parent 71120ec commit dec7c6b
Show file tree
Hide file tree
Showing 9 changed files with 143 additions and 366 deletions.
4 changes: 0 additions & 4 deletions llvm/include/llvm/CodeGen/TargetLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,6 @@
#include <cstdint>
#include <iterator>
#include <map>
#include <queue>
#include <stack>
#include <string>
#include <utility>
#include <vector>
Expand Down Expand Up @@ -3159,7 +3157,6 @@ class TargetLoweringBase {
/// \p DI is the deinterleave intrinsic.
/// \p LI is the accompanying load instruction
virtual bool lowerDeinterleaveIntrinsicToLoad(IntrinsicInst *DI,
SmallVector<Value *> &LeafNodes,
LoadInst *LI) const {
return false;
}
Expand All @@ -3171,7 +3168,6 @@ class TargetLoweringBase {
/// \p II is the interleave intrinsic.
/// \p SI is the accompanying store instruction
virtual bool lowerInterleaveIntrinsicToStore(IntrinsicInst *II,
SmallVector<Value *> &LeafNodes,
StoreInst *SI) const {
return false;
}
Expand Down
83 changes: 6 additions & 77 deletions llvm/lib/CodeGen/InterleavedAccessPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Utils/Local.h"
#include <cassert>
#include <queue>
#include <utility>

using namespace llvm;
Expand Down Expand Up @@ -489,57 +488,12 @@ bool InterleavedAccessImpl::lowerDeinterleaveIntrinsic(

LLVM_DEBUG(dbgs() << "IA: Found a deinterleave intrinsic: " << *DI << "\n");

std::stack<IntrinsicInst *> DeinterleaveTreeQueue;
SmallVector<Value *> TempLeafNodes, LeafNodes;
std::map<IntrinsicInst *, bool> mp;
SmallVector<Instruction *> TempDeadInsts;

DeinterleaveTreeQueue.push(DI);
while (!DeinterleaveTreeQueue.empty()) {
auto CurrentDI = DeinterleaveTreeQueue.top();
DeinterleaveTreeQueue.pop();
TempDeadInsts.push_back(CurrentDI);
// iterate over extract users of deinterleave
for (auto UserExtract : CurrentDI->users()) {
Instruction *Extract = dyn_cast<Instruction>(UserExtract);
if (!Extract || Extract->getOpcode() != Instruction::ExtractValue)
continue;
bool IsLeaf = true;
// iterate over deinterleave users of extract
for (auto UserDI : UserExtract->users()) {
IntrinsicInst *Child_DI = dyn_cast<IntrinsicInst>(UserDI);
if (!Child_DI || Child_DI->getIntrinsicID() !=
Intrinsic::experimental_vector_deinterleave2)
continue;
IsLeaf = false;
if (mp.count(Child_DI) == 0) {
DeinterleaveTreeQueue.push(Child_DI);
}
continue;
}
if (IsLeaf) {
TempLeafNodes.push_back(UserExtract);
TempDeadInsts.push_back(Extract);
} else {
TempDeadInsts.push_back(Extract);
}
}
}
// sort the deinterleaved nodes in the order that
// they will be extracted from the target-specific intrinsic.
for (unsigned I = 1; I < TempLeafNodes.size(); I += 2)
LeafNodes.push_back(TempLeafNodes[I]);

for (unsigned I = 0; I < TempLeafNodes.size(); I += 2)
LeafNodes.push_back(TempLeafNodes[I]);

// Try and match this with target specific intrinsics.
if (!TLI->lowerDeinterleaveIntrinsicToLoad(DI, LeafNodes, LI))
if (!TLI->lowerDeinterleaveIntrinsicToLoad(DI, LI))
return false;

// We now have a target-specific load, so delete the old one.
DeadInsts.insert(DeadInsts.end(), TempDeadInsts.rbegin(),
TempDeadInsts.rend());
DeadInsts.push_back(DI);
DeadInsts.push_back(LI);
return true;
}
Expand All @@ -555,38 +509,14 @@ bool InterleavedAccessImpl::lowerInterleaveIntrinsic(
return false;

LLVM_DEBUG(dbgs() << "IA: Found an interleave intrinsic: " << *II << "\n");
std::queue<IntrinsicInst *> IeinterleaveTreeQueue;
SmallVector<Value *> TempLeafNodes, LeafNodes;
SmallVector<Instruction *> TempDeadInsts;

IeinterleaveTreeQueue.push(II);
while (!IeinterleaveTreeQueue.empty()) {
auto node = IeinterleaveTreeQueue.front();
TempDeadInsts.push_back(node);
IeinterleaveTreeQueue.pop();
for (unsigned i = 0; i < 2; i++) {
auto op = node->getOperand(i);
if (auto CurrentII = dyn_cast<IntrinsicInst>(op)) {
if (CurrentII->getIntrinsicID() !=
Intrinsic::experimental_vector_interleave2)
continue;
IeinterleaveTreeQueue.push(CurrentII);
continue;
}
TempLeafNodes.push_back(op);
}
}
for (unsigned I = 0; I < TempLeafNodes.size(); I += 2)
LeafNodes.push_back(TempLeafNodes[I]);
for (unsigned I = 1; I < TempLeafNodes.size(); I += 2)
LeafNodes.push_back(TempLeafNodes[I]);

// Try and match this with target specific intrinsics.
if (!TLI->lowerInterleaveIntrinsicToStore(II, LeafNodes, SI))
if (!TLI->lowerInterleaveIntrinsicToStore(II, SI))
return false;

// We now have a target-specific store, so delete the old one.
DeadInsts.push_back(SI);
DeadInsts.insert(DeadInsts.end(), TempDeadInsts.begin(), TempDeadInsts.end());
DeadInsts.push_back(II);
return true;
}

Expand All @@ -607,8 +537,7 @@ bool InterleavedAccessImpl::runOnFunction(Function &F) {
// with a factor of 2.
if (II->getIntrinsicID() == Intrinsic::vector_deinterleave2)
Changed |= lowerDeinterleaveIntrinsic(II, DeadInsts);

else if (II->getIntrinsicID() == Intrinsic::vector_interleave2)
if (II->getIntrinsicID() == Intrinsic::vector_interleave2)
Changed |= lowerInterleaveIntrinsic(II, DeadInsts);
}
}
Expand Down
153 changes: 104 additions & 49 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16441,18 +16441,56 @@ bool AArch64TargetLowering::lowerInterleavedStore(StoreInst *SI,
return true;
}

/// Walk the tree of deinterleave2 intrinsics rooted at \p DI and collect, in
/// \p DeinterleaveUsers, the leaf extractvalue instructions in the order the
/// lanes will be produced by the target ldN intrinsic. Interior extracts and
/// child deinterleave intrinsics (which become dead once the target load is
/// emitted) are recorded in \p DeadInsts.
/// Returns true if \p DI is a deinterleave2 consumed by exactly its two
/// extractvalue results, false otherwise (i.e. \p DI is not an interior node).
bool GetDeinterleaveLeaves(Value *DI,
                           SmallVectorImpl<Value *> &DeinterleaveUsers,
                           SmallVectorImpl<Instruction *> &DeadInsts) {
  // An interior deinterleave2 must be consumed by exactly its two
  // extractvalue results.
  if (!DI->hasNUses(2))
    return false;

  // Use-list iteration order is an implementation detail; identify the two
  // extracts by their extractvalue index rather than by position.
  Value *UserA = *DI->user_begin();
  Value *UserB = *std::next(DI->user_begin());
  Value *Extr0 = nullptr, *Extr1 = nullptr;
  if (match(UserA, m_ExtractValue<0>(m_Deinterleave2(m_Value()))) &&
      match(UserB, m_ExtractValue<1>(m_Deinterleave2(m_Value())))) {
    Extr0 = UserA;
    Extr1 = UserB;
  } else if (match(UserB, m_ExtractValue<0>(m_Deinterleave2(m_Value()))) &&
             match(UserA, m_ExtractValue<1>(m_Deinterleave2(m_Value())))) {
    Extr0 = UserB;
    Extr1 = UserA;
  } else {
    return false;
  }

  // Visit the index-0 subtree first so leaves are collected in lane order.
  for (Value *Extr : {Extr0, Extr1}) {
    // Only recurse when the extract feeds exactly one instruction: an extract
    // with zero users must not be dereferenced, and an extract with several
    // users cannot be retired as an interior node.
    if (Extr->hasOneUse()) {
      Value *Child = *Extr->user_begin();
      if (GetDeinterleaveLeaves(Child, DeinterleaveUsers, DeadInsts)) {
        // Interior node: both the child intrinsic and the extract feeding it
        // die once the target-specific load is in place.
        DeadInsts.push_back(cast<Instruction>(Child));
        DeadInsts.push_back(cast<Instruction>(Extr));
        continue;
      }
    }
    // Leaf extract: its value is one of the deinterleaved results.
    DeinterleaveUsers.push_back(Extr);
  }
  return true;
}

bool AArch64TargetLowering::lowerDeinterleaveIntrinsicToLoad(
IntrinsicInst *DI, SmallVector<Value *> &LeafNodes, LoadInst *LI) const {
IntrinsicInst *DI, LoadInst *LI) const {
// Only deinterleave2 supported at present.
if (DI->getIntrinsicID() != Intrinsic::vector_deinterleave2)
return false;

const unsigned Factor = std::max(2, (int)LeafNodes.size());

VectorType *VTy = (LeafNodes.size() > 0)
? cast<VectorType>(LeafNodes.front()->getType())
: cast<VectorType>(DI->getType()->getContainedType(0));
SmallVector<Value *, 4> ValuesToDeinterleave;
SmallVector<Instruction *, 10> DeadInsts;
const DataLayout &DL = DI->getModule()->getDataLayout();
unsigned Factor = 2;
VectorType *VTy = cast<VectorType>(DI->getType()->getContainedType(0));
if (GetDeinterleaveLeaves(DI, ValuesToDeinterleave, DeadInsts)) {
Factor = ValuesToDeinterleave.size();
VTy = cast<VectorType>(ValuesToDeinterleave[0]->getType());
}

assert(Factor && "Expected Interleave Factor >= 2");

bool UseScalable;
if (!isLegalInterleavedAccessType(VTy, DL, UseScalable))
return false;
Expand All @@ -16463,7 +16501,6 @@ bool AArch64TargetLowering::lowerDeinterleaveIntrinsicToLoad(
return false;

unsigned NumLoads = getNumInterleavedAccesses(VTy, DL, UseScalable);

VectorType *LdTy =
VectorType::get(VTy->getElementType(),
VTy->getElementCount().divideCoefficientBy(NumLoads));
Expand All @@ -16473,7 +16510,6 @@ bool AArch64TargetLowering::lowerDeinterleaveIntrinsicToLoad(
UseScalable, LdTy, PtrTy);

IRBuilder<> Builder(LI);

Value *Pred = nullptr;
if (UseScalable)
Pred =
Expand All @@ -16482,9 +16518,8 @@ bool AArch64TargetLowering::lowerDeinterleaveIntrinsicToLoad(
Value *BaseAddr = LI->getPointerOperand();
Value *Result;
if (NumLoads > 1) {
Value *Left = PoisonValue::get(VTy);
Value *Right = PoisonValue::get(VTy);

// Create multiple legal small ldN instead of a wide one.
SmallVector<Value *, 4> WideValues(Factor, (PoisonValue::get(VTy)));
for (unsigned I = 0; I < NumLoads; ++I) {
Value *Offset = Builder.getInt64(I * Factor);

Expand All @@ -16494,49 +16529,71 @@ bool AArch64TargetLowering::lowerDeinterleaveIntrinsicToLoad(
LdN = Builder.CreateCall(LdNFunc, {Pred, Address}, "ldN");
else
LdN = Builder.CreateCall(LdNFunc, Address, "ldN");

Value *Idx =
Builder.getInt64(I * LdTy->getElementCount().getKnownMinValue());
Left = Builder.CreateInsertVector(
VTy, Left, Builder.CreateExtractValue(LdN, 0), Idx);
Right = Builder.CreateInsertVector(
VTy, Right, Builder.CreateExtractValue(LdN, 1), Idx);
for (int J = 0; J < Factor; ++J) {
WideValues[J] = Builder.CreateInsertVector(
VTy, WideValues[J], Builder.CreateExtractValue(LdN, J), Idx);
}
}
// FIXME: the types should NOT be added manually.
if (2 == Factor)
Result = PoisonValue::get(StructType::get(VTy, VTy));
else
Result = PoisonValue::get(StructType::get(VTy, VTy, VTy, VTy));
// Construct the wide result out of the small results.
for (int J = 0; J < Factor; ++J) {
Result = Builder.CreateInsertValue(Result, WideValues[J], J);
}

Result = PoisonValue::get(DI->getType());
Result = Builder.CreateInsertValue(Result, Left, 0);
Result = Builder.CreateInsertValue(Result, Right, 1);
} else {
if (UseScalable) {
if (UseScalable)
Result = Builder.CreateCall(LdNFunc, {Pred, BaseAddr}, "ldN");
if (Factor == 2) {
DI->replaceAllUsesWith(Result);
return true;
}
for (unsigned I = 0; I < LeafNodes.size(); I++) {
llvm::Value *CurrentExtract = LeafNodes[I];
Value *Newextrct = Builder.CreateExtractValue(Result, I);
CurrentExtract->replaceAllUsesWith(Newextrct);
}
return true;
} else
else
Result = Builder.CreateCall(LdNFunc, BaseAddr, "ldN");
}
if (Factor > 2) {
for (unsigned I = 0; I < ValuesToDeinterleave.size(); I++) {
llvm::Value *CurrentExtract = ValuesToDeinterleave[I];
Value *NewExtract = Builder.CreateExtractValue(Result, I);
CurrentExtract->replaceAllUsesWith(NewExtract);
dyn_cast<Instruction>(CurrentExtract)->eraseFromParent();
}

for (auto &dead : DeadInsts)
dead->eraseFromParent();
return true;
}
DI->replaceAllUsesWith(Result);
return true;
}

/// Recursively flatten the tree of interleave2 intrinsics rooted at \p II,
/// appending the leaf operands (the values that will actually be interleaved
/// by the target stN intrinsic) to \p InterleaveOps in left-to-right order.
/// Returns false when \p II is not an interleave2 intrinsic, i.e. it is
/// itself a leaf.
bool GetInterleaveLeaves(Value *II, SmallVectorImpl<Value *> &InterleaveOps) {
  Value *LHS = nullptr, *RHS = nullptr;
  if (!match(II, m_Interleave2(m_Value(LHS), m_Value(RHS))))
    return false;

  // Descend into each operand; an operand that is not itself an interleave2
  // is a leaf and is recorded directly.
  for (Value *Op : {LHS, RHS})
    if (!GetInterleaveLeaves(Op, InterleaveOps))
      InterleaveOps.push_back(Op);

  return true;
}

bool AArch64TargetLowering::lowerInterleaveIntrinsicToStore(
IntrinsicInst *II, SmallVector<Value *> &LeafNodes, StoreInst *SI) const {
IntrinsicInst *II, StoreInst *SI) const {
// Only interleave2 supported at present.
if (II->getIntrinsicID() != Intrinsic::vector_interleave2)
return false;

// leaf nodes are the nodes that will be interleaved
const unsigned Factor = LeafNodes.size();
SmallVector<Value *, 4> ValuesToInterleave;
GetInterleaveLeaves(II, ValuesToInterleave);
unsigned Factor = ValuesToInterleave.size();
assert(Factor >= 2 && "Expected Interleave Factor >= 2");
VectorType *VTy = cast<VectorType>(ValuesToInterleave[0]->getType());

VectorType *VTy = cast<VectorType>(LeafNodes.front()->getType());
const DataLayout &DL = II->getModule()->getDataLayout();
bool UseScalable;
if (!isLegalInterleavedAccessType(VTy, DL, UseScalable))
Expand Down Expand Up @@ -16566,28 +16623,26 @@ bool AArch64TargetLowering::lowerInterleaveIntrinsicToStore(
Pred =
Builder.CreateVectorSplat(StTy->getElementCount(), Builder.getTrue());

Value *L = II->getOperand(0);
Value *R = II->getOperand(1);

auto InterleaveOps = ValuesToInterleave;
if (UseScalable)
ValuesToInterleave.push_back(Pred);
ValuesToInterleave.push_back(BaseAddr);
for (unsigned I = 0; I < NumStores; ++I) {
Value *Address = BaseAddr;
if (NumStores > 1) {
Value *Offset = Builder.getInt64(I * Factor);
Address = Builder.CreateGEP(StTy, BaseAddr, {Offset});

Value *Idx =
Builder.getInt64(I * StTy->getElementCount().getKnownMinValue());
L = Builder.CreateExtractVector(StTy, II->getOperand(0), Idx);
R = Builder.CreateExtractVector(StTy, II->getOperand(1), Idx);
for (int J = 0; J < Factor; J++) {
ValuesToInterleave[J] =
Builder.CreateExtractVector(StTy, InterleaveOps[J], Idx);
}
// update the address
ValuesToInterleave[ValuesToInterleave.size() - 1] = Address;
}

if (UseScalable) {
SmallVector<Value *> Args(LeafNodes);
Args.push_back(Pred);
Args.push_back(Address);
Builder.CreateCall(StNFunc, Args);
} else
Builder.CreateCall(StNFunc, {L, R, Address});
Builder.CreateCall(StNFunc, ValuesToInterleave);
}

return true;
Expand Down
2 changes: 0 additions & 2 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -686,11 +686,9 @@ class AArch64TargetLowering : public TargetLowering {
unsigned Factor) const override;

bool lowerDeinterleaveIntrinsicToLoad(IntrinsicInst *DI,
SmallVector<Value *> &LeafNodes,
LoadInst *LI) const override;

bool lowerInterleaveIntrinsicToStore(IntrinsicInst *II,
SmallVector<Value *> &LeafNodes,
StoreInst *SI) const override;

bool isLegalAddImmediate(int64_t) const override;
Expand Down
Loading

0 comments on commit dec7c6b

Please sign in to comment.