Skip to content

Commit

Permalink
[AMD] Implement RepOrder for AMD MMA layouts (#5126)
Browse files Browse the repository at this point in the history
Implement RepOrder methods for MFMA and WMMA layouts. Both layouts have
row major rep layout. Also,
isTranspose flag in MFMA layout does not affect RepOrder, meaning
RepOrder is row major in both cases.

Co-authored-by: Ognjen Plavsic <[email protected]>
  • Loading branch information
oplavsic and Ognjen Plavsic authored Nov 18, 2024
1 parent 0bd30a2 commit 66e8629
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 3 deletions.
5 changes: 5 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -777,6 +777,11 @@ def MmaEncodingTrait : AttrInterface<"MmaEncodingTrait"> {
"getSizePerThreadForOperand",
(ins "int":$opIdx,
"int":$kWidth)>,

InterfaceMethod<"Get the order of reps (tiles of this layout that tile the whole tensor). The fastest-changing axis first",
"SmallVector<unsigned>",
"getRepOrderForOperand",
(ins "int":$opIdx)>,
];
}

Expand Down
21 changes: 18 additions & 3 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1658,7 +1658,14 @@ AMDMfmaEncodingAttr::getInstrShapeForOperand(int kWidth, int opIdx) const {
}

SmallVector<unsigned> AMDMfmaEncodingAttr::getRepOrder() const {
llvm::report_fatal_error("NYI. AMDMfmaEncodingAttr::getRepOrder");
auto rank = getWarpsPerCTA().size();
return getMatrixOrder(rank, /*rowMajor*/ true);
}

SmallVector<unsigned>
AMDMfmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
auto rank = getWarpsPerCTA().size();
return getOrderForDotOperand(opIdx, rank, /*kMajor*/ true);
}

SmallVector<int64_t>
Expand Down Expand Up @@ -1745,8 +1752,16 @@ AMDWmmaEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
return shapePerCTATile;
}
SmallVector<unsigned> AMDWmmaEncodingAttr::getRepOrder() const {
llvm::report_fatal_error("NYI. AMDWmmaEncodingAttr::getRepOrder");
auto rank = getWarpsPerCTA().size();
return getMatrixOrder(rank, /*rowMajor*/ true);
}

SmallVector<unsigned>
AMDWmmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
auto rank = getWarpsPerCTA().size();
return getOrderForDotOperand(opIdx, rank, /*kMajor*/ true);
}

SmallVector<unsigned> AMDWmmaEncodingAttr::getCTAsPerCGA() const {
return SmallVector<unsigned>(getCTALayout().getCTAsPerCGA());
}
Expand Down Expand Up @@ -2016,7 +2031,7 @@ NvidiaMmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
// DotOperand Encoding
//===----------------------------------------------------------------------===//
SmallVector<unsigned> DotOperandEncodingAttr::getRepOrder() const {
if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(getParent())) {
if (auto mma = mlir::dyn_cast<MmaEncodingTrait>(getParent())) {
return mma.getRepOrderForOperand(getOpIdx());
}
llvm::report_fatal_error(
Expand Down

0 comments on commit 66e8629

Please sign in to comment.