Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LAYOUTS] [BE] Simplify Ampere/Hopper paths introduced in #5189 #5200

Merged
merged 1 commit into from
Nov 22, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 40 additions & 39 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -938,23 +938,18 @@ DotOperandEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
elemsPerThread[rank - 1] = (idx == 0) ? rep[2] * kWidth : rep[2];
return elemsPerThread;
} else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
if (mma.isAmpere() || mma.isHopper()) {
auto bitwidth = getPointeeType(eltTy).getIntOrFloatBitWidth();
auto rep = mma.getRepForOperand(shape, bitwidth, kWidth, idx);
auto sizePerThread = getSizePerThread();
auto elemsPerKRep = mma.isHopper() ? (kWidth * 2) : (32 / bitwidth * 2);
if (rank == 3)
elemsPerThread[0] = rep[0];
elemsPerThread[rank - 2] =
(idx == 0)
? rep[1] * sizePerThread[rank - 2]
: std::max<int>(rep[1] * elemsPerKRep, sizePerThread[rank - 2]);
elemsPerThread[rank - 1] =
(idx == 0)
? std::max<int>(rep[2] * elemsPerKRep, sizePerThread[rank - 1])
: rep[2] * sizePerThread[rank - 1];
return elemsPerThread;
assert(getCTALayout(*this) ==
CTALayoutAttr::getDefault(getContext(), rank) &&
"NYI");
auto sizePerThread = getSizePerThread();
auto threadsPerWarp = getThreadsPerWarp();
auto warpsPerCTA = getWarpsPerCTA();
SmallVector<unsigned> regs;
for (auto [n, nsize, nThread, nWarp] :
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: nsize->nSize

llvm::zip(shape, sizePerThread, threadsPerWarp, warpsPerCTA)) {
regs.push_back(std::max<int64_t>(nsize, n / (nThread * nWarp)));
}
return regs;
}

llvm_unreachable("getElemsPerThread is not supported for dot operand");
Expand Down Expand Up @@ -1976,35 +1971,41 @@ NvidiaMmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
SmallVector<int64_t>
NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> shape, int bitwidth,
int kWidth, int opIdx) const {
assert(
kWidth >= 32 / bitwidth &&
"kWidth must be >= 32 / bitwidth for this function to be well-defined");
auto rank = shape.size();
// Broadcast long K
auto warpsPerCTA = getWarpsPerCTA();
auto kDim = opIdx == 0 ? rank - 1 : rank - 2;
warpsPerCTA[kDim] = 1;

// {batch, m, n, k}
// Hopper path never uses the n value, since this method is only invoked
// for in-RF (dotOpEnc) operands, but WGMMA only supports the A operand in RF
// TODO: rep per operand is not accurate for Hopper. It is currently done that
// way to allow us to get the correct total number of elements. this will be
// fixed when moving to linear layout.
SmallVector<int> shapePerWarp = {
1, 16, 8, isHopper() ? 4 * 2 * kWidth : 4 * 64 / bitwidth};
int numRepBatch =
rank == 3
? std::max<int64_t>(1, shape[0] / (shapePerWarp[0] * warpsPerCTA[0]))
: 1;

SmallVector<int> tileSize;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just for clarity: maybe we can initialize the tileSize and numRep with all 1s.

Suggested change
SmallVector<int> tileSize;
SmallVector<int> tileSize(rank, 1);

Then we use something like

tileSize[rank - 2] = 16;

It's up to you.

if (rank == 3) {
tileSize.push_back(1);
}
if (opIdx == 0) {
return {numRepBatch,
std::max<int64_t>(1, /*repM=*/shape[rank - 2] /
(shapePerWarp[1] * warpsPerCTA[rank - 2])),
std::max<int64_t>(1, /*repK=*/shape[rank - 1] / shapePerWarp[3])};
// m x k
tileSize.push_back(16);
tileSize.push_back(4 * 64 / bitwidth);
} else {
assert(opIdx == 1);
return {
numRepBatch,
std::max<int64_t>(1, /*repK=*/shape[rank - 2] / shapePerWarp[3]),
std::max<int64_t>(1, /*repN=*/shape[rank - 1] /
(shapePerWarp[2] * warpsPerCTA[rank - 1]))};
// k x n
// Hopper path never uses the n value, since this method is only invoked
// for in-RF (dotOpEnc) operands, but WGMMA only supports the A operand in RF,
// so it's fine if n is incorrect here
tileSize.push_back(4 * 64 / bitwidth);
tileSize.push_back(8);
}

SmallVector<int64_t> numRep;
// Lezcano: This is odd. Why do we always return a vector of size 3?
if (rank != 3) {
numRep.push_back(1);
}
for (auto [s, size, warp] : llvm::zip(shape, tileSize, warpsPerCTA)) {
numRep.push_back(std::max<int64_t>(1, s / (size * warp)));
}
return numRep;
}

SmallVector<unsigned> NvidiaMmaEncodingAttr::getShapePerCTATileForOperand(
Expand Down
Loading