[LAYOUTS] Implement LL conversion for DotOperand(Hopper)
We also rewrite the way we implement DotOperand(Ampere) and mma Ampere
to promote code reuse. I also started using what I believe is a rather
compact pattern to write these things: first call `identityND` with the
`repOrder`, which gives you an LL with the dims in the correct order,
and then construct the final layout by specifying the tiles as products
of `identity1D` maps. Using this allowed me to heavily simplify the
handling of the `warps` of `DotOperand`, which used to be a tad messy.
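
A minimal sketch of the pattern (illustrative only: `rank`, `repOrder`,
`kWidth`, `inner`, and `outer` are stand-ins, while `identityND`, `S`,
and `standardOutDimNames` are the helpers from
LinearLayoutConversions.cpp):

// Step 1: a trivial LL whose out dims already follow repOrder.
auto dimNames = standardOutDimNames(ctx, rank);
LinearLayout ll = identityND(S("register"), SmallVector<unsigned>(rank, 1),
                             repOrder, dimNames);
// Step 2: build the tile by multiplying identity1D maps, inner dim first.
ll *= LinearLayout::identity1D(kWidth, S("register"), dimNames[inner]);
ll *= LinearLayout::identity1D(4, S("lane"), dimNames[inner]);
ll *= LinearLayout::identity1D(8, S("lane"), dimNames[outer]);
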
lezcano committed Nov 19, 2024
1 parent 0bd30a2 commit d7f36a5
Showing 3 changed files with 192 additions and 173 deletions.
3 changes: 2 additions & 1 deletion lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -1017,7 +1017,8 @@ SmallVector<unsigned> DotOperandEncodingAttr::getCTASplitNum() const {
assert((rank == 2 || rank == 3) && "Invalid dotLayout");

// Do not split CTA in K dimension
getOpIdx() == 0 ? res[rank - 1] = 1 : res[rank - 2] = 1;
auto kDim = getOpIdx() == 0 ? rank - 1 : rank - 2;
res[kDim] = 1;
return res;
}
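// Illustration (not part of this diff): A (opIdx == 0) has shape M x K, so
// its K dim is rank - 1; B (opIdx == 1) has shape K x N, so its K dim is
// rank - 2. Either way, the split along K is forced to 1.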
SmallVector<unsigned> DotOperandEncodingAttr::getWarpsPerCTA() const {
272 changes: 115 additions & 157 deletions lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
@@ -280,78 +280,6 @@ LinearLayout combineCtaCgaWithShape(LinearLayout ctaLayout,
return ret;
}

LinearLayout ampereMmaToLinearLayout(ArrayRef<int64_t> shape,
NvidiaMmaEncodingAttr mma) {
int rank = shape.size();

assert(mma.isAmpere());
assert(rank == 2 || rank == 3);
assert(mma.getInstrShape().size() == rank);
assert((rank == 2 && mma.getInstrShape() == ArrayRef<unsigned>({16, 8})) ||
(rank == 3 && mma.getInstrShape() == ArrayRef<unsigned>({1, 16, 8})));

MLIRContext *ctx = mma.getContext();
SmallVector<StringAttr> dimNames = standardOutDimNames(ctx, rank);

auto orderedDimNames = permuteDimNames(dimNames, mma.getRepOrder());
assert(mma.getRepOrder() == getMatrixOrder(rank, /*rowMajor=*/true));

LinearLayout ctaLayout(
{{S("register"), {{1, 0}, {0, 8}}},
{S("lane"), {{2, 0}, {4, 0}, {0, 1}, {0, 2}, {0, 4}}}},
ArrayRef(orderedDimNames).take_front(2));
assert(getWarpOrder(mma) == getMatrixOrder(rank, /*rowMajor=*/true));
// FIXME(Lezcano). identityND should not have an `order` param as it's
// redundant with the order of the out dims.
ctaLayout *=
identityND(S("warp"), mma.getWarpsPerCTA(), mma.getWarpOrder(), dimNames);

return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape);
}

LinearLayout hopperMmaToLinearLayout(ArrayRef<int64_t> shape,
NvidiaMmaEncodingAttr mma) {
int rank = shape.size();
assert(mma.isHopper());
assert(rank == 2);

// wgmma operates on groups of 4 warps.
assert(product(mma.getWarpsPerCTA()) % 4 == 0);

// Check that it's a known MMA layout.
assert(mma.getInstrShape().size() == 3);
int m = mma.getInstrShape()[0];
int n = mma.getInstrShape()[1];
int k = mma.getInstrShape()[2];
assert(m == 16);
assert(n == 8 || n == 16 || n == 32 || n == 64 || n == 128 || n == 256);
assert(k == 8 || k == 16 || k == 32);

MLIRContext *ctx = mma.getContext();
LinearLayout ctaLayout(
{{S("register"), {{1, 0}, {0, 8}}},
{S("lane"), {{2, 0}, {4, 0}, {0, 1}, {0, 2}, {0, 4}}}},
{S("dim1"), S("dim0")});

// Expand the `register` dimension so the size of dim1 matches `n`.
ctaLayout *= LinearLayout::identity1D(n / ctaLayout.getOutDimSize(S("dim1")),
S("register"), S("dim1"));

// The order given by choosing (`dim1`, `dim0`) is [1, 0], that is, N-major.
// Since the warpOrder needs to be M-major, we need to transpose the out
// dimensions AND transpose the order
// FIXME(Lezcano). identityND should not have an `order` param as it's
// redundant. The order is already given by the order of the
// out dims, and if it has an order, it shouldn't change the
// order of the out dims.
assert(getWarpOrder(mma) == SmallVector<unsigned>({0, 1}));
ctaLayout *= identityND(S("warp"), mma.getWarpsPerCTA(), /*order=*/{0, 1},
{S("dim0"), S("dim1")})
.transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames()));

return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape);
}

LinearLayout sharedToLinearLayoutNoLeadingOffset(ArrayRef<int64_t> shape,
SharedEncodingAttr shared) {
assert(!shared.getHasLeadingOffset());
@@ -779,15 +707,76 @@ BlockedEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape);
}

LinearLayout nvidiaMmaTile(MLIRContext *ctx, ArrayRef<unsigned> tileShape,
unsigned kWidth, ArrayRef<unsigned> order,
ArrayRef<unsigned> repOrder) {
// Trivial layout mapping 0 -> (0, 0), but we set the order to repOrder
int rank = repOrder.size();
auto dimNames = standardOutDimNames(ctx, rank);
auto trivialShape = SmallVector<unsigned>(rank, 1);
LinearLayout ctaLayout =
identityND(S("register"), trivialShape, repOrder, dimNames);

assert(rank >= 2);
auto inner = order[0];
auto outer = order[1];

assert(tileShape.size() == rank);
int m = tileShape[outer];
int n = tileShape[inner];

// The relative order of registers and lanes is given by:
// - Inner dim: kWidth registers
// - Inner dim: 4 lanes
// - Outer dim: 8 lanes
// - Outer dim: repeat m / 8 times
// - Inner dim: repeat n / (kWidth * 4) times
assert(m % 8 == 0);
assert(n % (kWidth * 4) == 0);
// There is at least one subtile on the inner-most dimension
// FIXME. We should implement operator* in terms of operator*=
// and chain *= instead of using *
auto outDimNames = llvm::to_vector(ctaLayout.getOutDimNames());
ctaLayout = ctaLayout *
LinearLayout::identity1D(kWidth, S("register"), dimNames[inner]) *
LinearLayout::identity1D(4, S("lane"), dimNames[inner]) *
LinearLayout::identity1D(8, S("lane"), dimNames[outer]) *
LinearLayout::identity1D(m / 8, S("register"), dimNames[outer]) *
LinearLayout::identity1D(n / (kWidth * 4), S("register"),
dimNames[inner]);
return ctaLayout;
}
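// Worked example (illustrative, not part of this diff): for the Ampere mma
// tile we have tileShape = {16, 8}, kWidth = 2 and order = {1, 0}, so
// inner = dim1 and outer = dim0. The product above then yields the familiar
// mma.m16n8 bases:
//   register: +1 along dim1 (kWidth contiguous elements),
//             +8 along dim0 (m / 8 = 2 repeats)
//   lane:     +2, +4 along dim1 (4 lanes),
//             +1, +2, +4 along dim0 (8 lanes)
// Since n / (kWidth * 4) = 1, there are no extra repeats along dim1.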

std::optional<LinearLayout>
NvidiaMmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
auto ctx = getContext();
int rank = shape.size();

SmallVector<unsigned> tileShape;
if (isAmpere()) {
return ampereMmaToLinearLayout(shape, *this);
}
if (isHopper()) {
return hopperMmaToLinearLayout(shape, *this);
// Ampere.getInstrShape() returns the tile shape
tileShape = SmallVector<unsigned>(getInstrShape());
} else {
assert(isHopper());
auto instrShapeMNK = getInstrShape();
tileShape = SmallVector<unsigned>({instrShapeMNK[0], instrShapeMNK[1]});
}
return std::nullopt;
// The NVIDIA mma layout always assumes kWidth = 2
constexpr auto kWidth = 2;
auto ctaLayout =
nvidiaMmaTile(ctx, tileShape, kWidth, getOrder(*this), getRepOrder());

// The triton orders are defined on [dim0, dim1, ...], so we need to pass
// those dims. Then, for some reason, operator* requires the orders to match,
// so we need to reorder the outs to match.
// FIXME(Lezcano). identityND should not take a dim name, as it's redundant.
// The order in triton assumes the standardDims, so it should
// use those.
ctaLayout *= identityND(S("warp"), getWarpsPerCTA(), getWarpOrder(),
standardOutDimNames(ctx, rank))
.transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames()));

return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape);
}

std::optional<LinearLayout>
@@ -860,62 +849,11 @@ SliceEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
return ret;
}

LinearLayout ampereDotToLinearLayout(ArrayRef<int64_t> shape,
DotOperandEncodingAttr dot) {
// Note that, even though MMAv2 looks similar to this layout, they are only
// the same at the register and lane level. The treatment of warps is
// different!
int rank = shape.size();
auto mma = cast<NvidiaMmaEncodingAttr>(dot.getParent());
int kWidth = dot.getKWidth();
bool isA = dot.getOpIdx() == 0;

assert((rank == 2 && mma.getInstrShape() == ArrayRef<unsigned>({16, 8})) ||
(rank == 3 && mma.getInstrShape() == ArrayRef<unsigned>({1, 16, 8})));
assert(mma.isAmpere());

MLIRContext *ctx = mma.getContext();

// The A and B operands are tiled in a kMajor fashion
auto kMajorOrder = dot.getRepOrder();
assert(kMajorOrder ==
getOrderForDotOperand(dot.getOpIdx(), rank, /*kMajor=*/true));

auto kMajorDims =
permuteDimNames(standardOutDimNames(ctx, rank), kMajorOrder);
// This agrees with the order of the elements, which means that we can share
// the code below for both A and B without having to perform any swaps
assert(getOrder(dot) == kMajorOrder);

std::vector<std::vector<int32_t>> registers;
std::vector<std::vector<int32_t>> lanes;
int32_t i = 1;
// kWidth contiguous elements
while (i < kWidth) {
registers.push_back({i, 0});
i *= 2;
}
// 4 threads per chunk
for (int j = 0; j < 2; j++) {
lanes.push_back({i, 0});
i *= 2;
}
// 8 threads going down
lanes.push_back({0, 1});
lanes.push_back({0, 2});
lanes.push_back({0, 4});
// 2 tiles in column-major order
// Just one if it's the B operand
if (isA) {
registers.push_back({0, 8});
}
registers.push_back({i, 0});

LinearLayout ctaLayout({{S("register"), registers}, {S("lane"), lanes}},
ArrayRef(kMajorDims).take_front(2));

LinearLayout warpsNvidiaDot(MLIRContext *ctx, ArrayRef<unsigned> mmaWarpShape,
ArrayRef<unsigned> mmaWarpOrder, bool isA) {
// Let warpsPerCTAMma = {2, 2}, then
// warpsPerCTA = {2, 1} for opA and warpsPerCTA = {1, 2} for opB
// assume warpOrder = {0, 1}
// assume warpOrder = {1, 0}
// Assume that C is tiled by 2x2 tiles. Since warpOrder={1, 0}, we have that
// the C is owned as per the following layout:
// C: 0 | 1
@@ -926,33 +864,55 @@ LinearLayout ampereDotToLinearLayout(ArrayRef<int64_t> shape,
// A: 0 1 | 0 1 B: 0 2 | 1 3
// - - | - - - - | - -
// 2 3 | 2 3 0 2 | 1 3
// In particular, for A and B we need to broadcast along K
// In other words, we need to broadcast along K
auto rank = mmaWarpOrder.size();
auto inner = isA ? rank - 1 : rank - 2;
auto outer = isA ? rank - 2 : rank - 1;

assert(mma.getWarpOrder() == getMatrixOrder(rank, /*rowMajor=*/true));
auto warpsPerCTAMma = mma.getWarpsPerCTA();
std::vector<std::vector<int32_t>> warps;
if (isA) {
for (int i = 1; i < warpsPerCTAMma[1]; i *= 2) {
warps.push_back({0, 0});
}
for (int i = 1; i < warpsPerCTAMma[0]; i *= 2) {
warps.push_back({0, i});
}
} else {
for (int i = 1; i < warpsPerCTAMma[1]; i *= 2) {
warps.push_back({0, i});
}
for (int i = 1; i < warpsPerCTAMma[0]; i *= 2) {
warps.push_back({0, 0});
}
}
if (rank == 3) {
for (auto &w : warps) {
w.push_back(0);
auto dimNames = standardOutDimNames(ctx, rank);
LinearLayout ctaLayout =
identityND(S("register"), {1, 1}, mmaWarpOrder, dimNames);

// We have to broadcast along the inner dimension.
// For A, when moving along M we go from warp 0 to warp 2.
// For B, when moving along N we go from warp 0 to warp 1.
// As such, choosing the order of A as {1, 0} gives us the correct
// broadcasting. The same happens if the mmaWarpOrder is {0, 1}, as in Hopper.
for (auto d : mmaWarpOrder) {
if (d == inner) {
ctaLayout *=
LinearLayout::zeros1D(mmaWarpShape[d], S("warp"), dimNames[d]);
} else {
ctaLayout *=
LinearLayout::identity1D(mmaWarpShape[d], S("warp"), dimNames[d]);
}
}
return ctaLayout;
}
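// Worked example (illustrative, not part of this diff): let mmaWarpShape =
// {2, 2} and mmaWarpOrder = {1, 0}. For A, inner is the K dim (dim1), so the
// loop emits zeros1D(2, warp, dim1) * identity1D(2, warp, dim0): warps that
// differ only along N see the same A data, i.e. A is broadcast along K. For
// B the roles of the dims swap, and B is broadcast along its K dim (dim0).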

LinearLayout nvidiaDotToLinearLayout(ArrayRef<int64_t> shape,
DotOperandEncodingAttr dot) {
int rank = shape.size();
auto mma = cast<NvidiaMmaEncodingAttr>(dot.getParent());
int kWidth = dot.getKWidth();
bool isA = dot.getOpIdx() == 0;
MLIRContext *ctx = mma.getContext();

ctaLayout *= LinearLayout({{S("warp"), warps}}, kMajorDims);
SmallVector<unsigned> tileShape(rank, 1);
if (isA) {
tileShape[rank - 2] = 16;
tileShape[rank - 1] = kWidth * 8;
} else {
// Hopper takes the rhs via shared memory
assert(mma.isAmpere());
tileShape[rank - 2] = kWidth * 8;
tileShape[rank - 1] = 8;
}
auto ctaLayout =
nvidiaMmaTile(ctx, tileShape, kWidth, getOrder(dot), dot.getRepOrder());
ctaLayout *=
warpsNvidiaDot(ctx, mma.getWarpsPerCTA(), mma.getWarpOrder(), isA)
.transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames()));

return combineCtaCgaWithShape(ctaLayout, getCTALayout(dot), shape);
}
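// Illustration (not part of this diff): with kWidth = 2, the A tile above is
// 16 x 16 (M x K) and the B tile is 16 x 8 (K x N), i.e. the operand tiles
// of an mma.m16n8k16 instruction.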
Expand All @@ -963,9 +923,7 @@ DotOperandEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
if (auto mfmaLayout = llvm::dyn_cast<AMDMfmaEncodingAttr>(parent)) {
return mfmaDotToLinearLayout(*this, shape);
} else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
if (mma.isAmpere()) {
return ampereDotToLinearLayout(shape, *this);
}
return nvidiaDotToLinearLayout(shape, *this);
}
return std::nullopt;
}