Skip to content

Commit

Permalink
Strengthen getRepForOperand and simplify getElemsPerThread
Browse files Browse the repository at this point in the history
  • Loading branch information
lezcano committed Nov 20, 2024
1 parent 7bce361 commit 54aba83
Showing 1 changed file with 40 additions and 39 deletions.
79 changes: 40 additions & 39 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -938,23 +938,18 @@ DotOperandEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
elemsPerThread[rank - 1] = (idx == 0) ? rep[2] * kWidth : rep[2];
return elemsPerThread;
} else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
if (mma.isAmpere() || mma.isHopper()) {
auto bitwidth = getPointeeType(eltTy).getIntOrFloatBitWidth();
auto rep = mma.getRepForOperand(shape, bitwidth, kWidth, idx);
auto sizePerThread = getSizePerThread();
auto elemsPerKRep = mma.isHopper() ? (kWidth * 2) : (32 / bitwidth * 2);
if (rank == 3)
elemsPerThread[0] = rep[0];
elemsPerThread[rank - 2] =
(idx == 0)
? rep[1] * sizePerThread[rank - 2]
: std::max<int>(rep[1] * elemsPerKRep, sizePerThread[rank - 2]);
elemsPerThread[rank - 1] =
(idx == 0)
? std::max<int>(rep[2] * elemsPerKRep, sizePerThread[rank - 1])
: rep[2] * sizePerThread[rank - 1];
return elemsPerThread;
assert(getCTALayout(*this) ==
CTALayoutAttr::getDefault(getContext(), rank) &&
"NYI");
auto sizePerThread = getSizePerThread();
auto threadsPerWarp = getThreadsPerWarp();
auto warpsPerCTA = getWarpsPerCTA();
SmallVector<unsigned> regs;
for (auto [n, nsize, nThread, nWarp] :
llvm::zip(shape, sizePerThread, threadsPerWarp, warpsPerCTA)) {
regs.push_back(std::max<int64_t>(nsize, n / (nThread * nWarp)));
}
return regs;
}

llvm_unreachable("getElemsPerThread is not supported for dot operand");
Expand Down Expand Up @@ -1976,35 +1971,41 @@ NvidiaMmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
SmallVector<int64_t>
NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> shape, int bitwidth,
int kWidth, int opIdx) const {
assert(
kWidth >= 32 / bitwidth &&
"kWidth must be >= 32 / bitwidth for this function to be well-defined");
auto rank = shape.size();
// Broadcast long K
auto warpsPerCTA = getWarpsPerCTA();
auto kDim = opIdx == 0 ? rank - 1 : rank - 2;
warpsPerCTA[kDim] = 1;

// {batch, m, n, k}
// Hopper path never uses the n value, since this method is only invoked
// for in-RF (dotOpEnc) operands, but WGMMA only supports in A to be in RF
// TODO: rep per operand is not accurate for Hopper. It is currently done that
// way to allow us to get the correct total number of elements. this will be
// fixed when moving to linear layout.
SmallVector<int> shapePerWarp = {
1, 16, 8, isHopper() ? 4 * 2 * kWidth : 4 * 64 / bitwidth};
int numRepBatch =
rank == 3
? std::max<int64_t>(1, shape[0] / (shapePerWarp[0] * warpsPerCTA[0]))
: 1;

SmallVector<int> tileSize;
if (rank == 3) {
tileSize.push_back(1);
}
if (opIdx == 0) {
return {numRepBatch,
std::max<int64_t>(1, /*repM=*/shape[rank - 2] /
(shapePerWarp[1] * warpsPerCTA[rank - 2])),
std::max<int64_t>(1, /*repK=*/shape[rank - 1] / shapePerWarp[3])};
// m x k
tileSize.push_back(16);
tileSize.push_back(4 * 64 / bitwidth);
} else {
assert(opIdx == 1);
return {
numRepBatch,
std::max<int64_t>(1, /*repK=*/shape[rank - 2] / shapePerWarp[3]),
std::max<int64_t>(1, /*repN=*/shape[rank - 1] /
(shapePerWarp[2] * warpsPerCTA[rank - 1]))};
// k x n
// Hopper path never uses the n value, since this method is only invoked
// for in-RF (dotOpEnc) operands, but WGMMA only supports in A to be in RF
// so it's fine if the n is incorrect here
tileSize.push_back(4 * 64 / bitwidth);
tileSize.push_back(8);
}

SmallVector<int64_t> numRep;
// Lezcano: This is odd. Why do we always return a vector of size 3?
if (rank != 3) {
numRep.push_back(1);
}
for (auto [s, size, warp] : llvm::zip(shape, tileSize, warpsPerCTA)) {
numRep.push_back(std::max<int64_t>(1, s / (size * warp)));
}
return numRep;
}

SmallVector<unsigned> NvidiaMmaEncodingAttr::getShapePerCTATileForOperand(
Expand Down

0 comments on commit 54aba83

Please sign in to comment.