diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 2c1f7da609ae..09330e7eda63 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -777,6 +777,11 @@ def MmaEncodingTrait : AttrInterface<"MmaEncodingTrait"> { "getSizePerThreadForOperand", (ins "int":$opIdx, "int":$kWidth)>, + + InterfaceMethod<"Get the order of reps (tiles of this layout that tile the whole tensor). The fastest-changing axis first", + "SmallVector", + "getRepOrderForOperand", + (ins "int":$opIdx)>, ]; } diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 1e63c4b390d4..0237d9815c39 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -1658,7 +1658,14 @@ AMDMfmaEncodingAttr::getInstrShapeForOperand(int kWidth, int opIdx) const { } SmallVector AMDMfmaEncodingAttr::getRepOrder() const { - llvm::report_fatal_error("NYI. AMDMfmaEncodingAttr::getRepOrder"); + auto rank = getWarpsPerCTA().size(); + return getMatrixOrder(rank, /*rowMajor*/ true); +} + +SmallVector +AMDMfmaEncodingAttr::getRepOrderForOperand(int opIdx) const { + auto rank = getWarpsPerCTA().size(); + return getOrderForDotOperand(opIdx, rank, /*kMajor*/ true); } SmallVector @@ -1745,8 +1752,16 @@ AMDWmmaEncodingAttr::getShapePerCTATile(ArrayRef tensorShape) const { return shapePerCTATile; } SmallVector AMDWmmaEncodingAttr::getRepOrder() const { - llvm::report_fatal_error("NYI. AMDWmmaEncodingAttr::getRepOrder"); + auto rank = getWarpsPerCTA().size(); + return getMatrixOrder(rank, /*rowMajor*/ true); } + +SmallVector +AMDWmmaEncodingAttr::getRepOrderForOperand(int opIdx) const { + auto rank = getWarpsPerCTA().size(); + return getOrderForDotOperand(opIdx, rank, /*kMajor*/ true); +} + SmallVector AMDWmmaEncodingAttr::getCTAsPerCGA() const { return SmallVector(getCTALayout().getCTAsPerCGA()); } @@ -2016,7 +2031,7 @@ NvidiaMmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const { // DotOperand Encoding //===----------------------------------------------------------------------===// SmallVector DotOperandEncodingAttr::getRepOrder() const { - if (auto mma = mlir::dyn_cast(getParent())) { + if (auto mma = mlir::dyn_cast(getParent())) { return mma.getRepOrderForOperand(getOpIdx()); } llvm::report_fatal_error(