[mlir][nvgpu] Improve `WarpgroupAccumulator` type to simplify IR #68728
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-nvgpu

Author: Guray Ozen (grypp)

Changes

This PR improves the transformation of the `WarpgroupAccumulator` type in every nvgpu Op that uses it.

Patch is 43.01 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/68728.diff

5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index 57cd1a3806c2ed6..fd16376be366912 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -719,8 +719,8 @@ def NVGPU_WarpgroupMmaOp : NVGPU_Op<"warpgroup.mma"> {
DefaultValuedOptionalAttr<I32Attr, "1">:$waitGroup,
OptionalAttr<UnitAttr>:$transposeA,
OptionalAttr<UnitAttr>:$transposeB,
- Variadic<NVGPU_WarpgroupAccumulator>:$matrixC);
- let results = (outs Variadic<NVGPU_WarpgroupAccumulator>:$matrixD);
+ NVGPU_WarpgroupAccumulator:$matrixC);
+ let results = (outs NVGPU_WarpgroupAccumulator:$matrixD);
let assemblyFormat = [{
$descriptorA`,` $descriptorB`,` $matrixC attr-dict
`:` type($descriptorA) `,` type($descriptorB) `,` type($matrixC) `->` type($matrixD)
@@ -739,13 +739,25 @@ def NVGPU_WarpgroupMmaStoreOp : NVGPU_Op<"warpgroup.mma.store"> {
Note that, the op must be run with warp group.
}];
- let arguments = (ins Variadic<NVGPU_WarpgroupAccumulator>:$matrixD,
+ let arguments = (ins NVGPU_WarpgroupAccumulator:$matrixD,
Arg<AnyMemRef, "", [MemWrite]>:$dstMemref);
let assemblyFormat = [{
- `[` $matrixD `]` `,` $dstMemref attr-dict `:` type($matrixD) `to` type($dstMemref)
+ $matrixD `,` $dstMemref attr-dict `:` type($matrixD) `to` type($dstMemref)
}];
let hasVerifier = 1;
}
+def NVGPU_WarpgroupMmaInitAccumulatorOp : NVGPU_Op<"warpgroup.mma.init.accumulator"> {
+ let summary = "Initializes the accumulator matrix";
+
+ let description = [{
+ This Op generates and initializes the accumulator matrix for
+ `nvgpu.warpgroup.mma` op to perform matrix-multiply-and-accumulate.
+ }];
+ let results = (outs NVGPU_WarpgroupAccumulator:$matrixC);
+ let assemblyFormat = "attr-dict `->` type($matrixC)";
+ let hasVerifier = 1;
+}
+
#endif // NVGPU
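For illustration, a minimal usage sketch of the op defined above, following its `attr-dict -> type($matrixC)` assembly format; the accumulator shape and SSA name are assumptions, not taken from this patch:

```mlir
// Hypothetical usage of nvgpu.warpgroup.mma.init.accumulator; the
// 128x128xf32 fragment shape is assumed for illustration.
%acc = nvgpu.warpgroup.mma.init.accumulator
         -> !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>
```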
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
index 96af26842dafea2..e6bba7e6082964b 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
@@ -23,6 +23,9 @@
constexpr int kWarpSize = 32;
+/// M size of wgmma.mma_async instruction
+constexpr int kWgmmaSizeM = 64;
+
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.h.inc"
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 99c4d4223351352..2d43230938526b9 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -412,10 +412,28 @@ struct ConvertNVGPUToNVVMPass
return converter.convertType(IntegerType::get(type.getContext(), 32));
});
converter.addConversion([&](nvgpu::WarpgroupAccumulatorType type) -> Type {
- VectorType vtype = type.getFragmented();
+ Type elemType = type.getFragmented().getElementType();
+ int64_t sizeM = type.getFragmented().getDimSize(0);
+ int64_t sizeN = type.getFragmented().getDimSize(1);
+
+ unsigned numMembers;
+ if (elemType.isF32() || elemType.isInteger(32))
+ numMembers = sizeN / 2;
+ else if (elemType.isF16())
+ numMembers = sizeN / 4;
+ else
+ llvm_unreachable("unsupported type for warpgroup accumulator");
+
+ SmallVector<Type> innerStructBody;
+ for (unsigned i = 0; i < numMembers; i++)
+ innerStructBody.push_back(elemType);
+ auto innerStructType =
+ LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody);
+
SmallVector<Type> structBody;
- for (unsigned i = 0; i < vtype.getDimSize(0); i++)
- structBody.push_back(vtype.getElementType());
+ for (int i = 0; i < sizeM; i += kWgmmaSizeM)
+ structBody.push_back(innerStructType);
+
auto convertedType =
LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
return converter.convertType(convertedType);
@@ -1186,7 +1204,6 @@ struct NVGPUWarpgroupMmaOpLowering
nvgpu::WarpgroupMmaOp op;
ImplicitLocOpBuilder b;
OpAdaptor adaptor;
- const LLVMTypeConverter &typeConverter;
// Entire shape of the given Op
int64_t totalM, totalN, totalK;
@@ -1330,7 +1347,7 @@ struct NVGPUWarpgroupMmaOpLowering
/// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
/// descriptors and arranges them based on induction variables: i, j, and k.
- Value generateWgmma(int i, int j, int k, Value matrixC, Value matrixD) {
+ Value generateWgmma(int i, int j, int k, Value matrixC) {
LLVM_DEBUG(DBGS() << "\t wgmma."
<< "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
<< "(A[" << (iterationM * wgmmaM) << ":"
@@ -1359,34 +1376,36 @@ struct NVGPUWarpgroupMmaOpLowering
auto overflow = NVVM::MMAIntOverflowAttr::get(
op->getContext(), NVVM::MMAIntOverflow::wrapped);
- Type resultStructType = typeConverter.convertType(matrixD.getType());
-
return b.create<NVVM::WgmmaMmaAsyncOp>(
- resultStructType, matrixC, descriptorA, descriptorB, shape, itypeA,
+ matrixC.getType(), matrixC, descriptorA, descriptorB, shape, itypeA,
itypeB, scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
}
/// Generates multiple wgmma instructions to complete the given GEMM shape
- SmallVector<Value> generateWgmmaGroup() {
- SmallVector<Value> wgmmaResults;
+ Value generateWgmmaGroup() {
+ Value wgmmaResult =
+ b.create<LLVM::UndefOp>(adaptor.getMatrixC().getType());
// Perform GEMM
+ SmallVector<Value> wgmmaResults;
for (int i = 0; i < iterationM; ++i) {
- Value matrixC = adaptor.getMatrixC()[i];
- Value matrixD = op.getMatrixD()[i];
+ Value matrixC = b.create<LLVM::ExtractValueOp>(adaptor.getMatrixC(), i);
for (int j = 0; j < iterationN; ++j)
for (int k = 0; k < iterationK; ++k)
- matrixC = generateWgmma(i, j, k, matrixC, matrixD);
+ matrixC = generateWgmma(i, j, k, matrixC);
wgmmaResults.push_back(matrixC);
}
-
- return wgmmaResults;
+ for (auto [idx, matrix] : llvm::enumerate(wgmmaResults)) {
+ wgmmaResult = b.create<LLVM::InsertValueOp>(wgmmaResult.getType(),
+ wgmmaResult, matrix, idx);
+ }
+ return wgmmaResult;
}
public:
WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ImplicitLocOpBuilder &b,
- OpAdaptor adaptor, const LLVMTypeConverter &typeConverter)
- : op(op), b(b), adaptor(adaptor), typeConverter(typeConverter) {
+ OpAdaptor adaptor)
+ : op(op), b(b), adaptor(adaptor) {
// Find the entire GEMM Shape
totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
@@ -1411,27 +1430,27 @@ struct NVGPUWarpgroupMmaOpLowering
/// instructions and group synchronization, as well as waiting
/// (WgmmaGroupSyncAlignedOp) for group synchronization
/// (WgmmaWaitGroupSyncOp) after the instructions.
- SmallVector<Value> generateWarpgroupMma() {
+ Value generateWarpgroupMma() {
b.create<NVVM::WgmmaFenceAlignedOp>();
- SmallVector<Value> wgmmaResults = generateWgmmaGroup();
+ Value wgmmaResult = generateWgmmaGroup();
b.create<NVVM::WgmmaGroupSyncAlignedOp>();
b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup());
- return wgmmaResults;
+ return wgmmaResult;
}
};
-
LogicalResult
matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+
// Step 1. Build a helper class
- WarpgroupGemm warpgroupGemm(op, b, adaptor, *this->getTypeConverter());
+ WarpgroupGemm warpgroupGemm(op, b, adaptor);
// Step 2. Get the entire GEMM Shape
- SmallVector<Value> wgmmaResults = warpgroupGemm.generateWarpgroupMma();
+ Value wgmmaResult = warpgroupGemm.generateWarpgroupMma();
// Step 3. Replace fragmented result struct with the op results
- rewriter.replaceOp(op, wgmmaResults);
+ rewriter.replaceOp(op, wgmmaResult);
return success();
}
};
@@ -1535,10 +1554,13 @@ struct NVGPUWarpgroupMmaStoreOpLowering
matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
int offset = 0;
- ImplicitLocOpBuilder lb(op->getLoc(), rewriter);
- for (Value matrixD : adaptor.getMatrixD()) {
- auto structType = matrixD.getType().cast<LLVM::LLVMStructType>();
- storeFragmentedMatrix(lb, matrixD, op.getDstMemref(), offset);
+ ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+ Value matriDValue = adaptor.getMatrixD();
+ auto stype = matriDValue.getType().cast<LLVM::LLVMStructType>();
+ for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) {
+ auto structType = matrixD.cast<LLVM::LLVMStructType>();
+ Value innerStructValue = b.create<LLVM::ExtractValueOp>(matriDValue, idx);
+ storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset);
offset += structType.getBody().size();
}
rewriter.eraseOp(op);
@@ -1546,6 +1568,39 @@ struct NVGPUWarpgroupMmaStoreOpLowering
}
};
+struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
+ : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaInitAccumulatorOp> {
+ using ConvertOpToLLVMPattern<
+ nvgpu::WarpgroupMmaInitAccumulatorOp>::ConvertOpToLLVMPattern;
+ LogicalResult
+ matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+ LLVM::LLVMStructType structType =
+ getTypeConverter()
+ ->convertType(op.getMatrixC().getType())
+ .cast<LLVM::LLVMStructType>();
+ Type elemType = structType.getBody()
+ .front()
+ .cast<LLVM::LLVMStructType>()
+ .getBody()
+ .front();
+ Value zero = b.create<LLVM::ConstantOp>(elemType, b.getZeroAttr(elemType));
+ Value structValue = b.create<LLVM::UndefOp>(structType);
+ for (auto [idx, s] : llvm::enumerate(structType.getBody())) {
+ auto innerStructType = s.cast<LLVM::LLVMStructType>();
+ int ii = idx;
+ Value innerStructValue = b.create<LLVM::ExtractValueOp>(structValue, ii);
+ for (unsigned i = 0; i < innerStructType.getBody().size(); ++i) {
+ innerStructValue = b.create<LLVM::InsertValueOp>(
+ innerStructType, innerStructValue, zero, ArrayRef<int64_t>({i}));
+ }
+ }
+ rewriter.replaceOp(op, structValue);
+ return success();
+ }
+};
+
} // namespace
void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
@@ -1563,6 +1618,7 @@ void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
NVGPUGenerateWarpgroupDescriptorLowering, // nvgpu.warpgroup.generate.descriptor
NVGPUWarpgroupMmaOpLowering, // nvgpu.warpgroup.mma
NVGPUWarpgroupMmaStoreOpLowering, // nvgpu.warpgroup.mma.store
+ NVGPUWarpgroupMmaInitAccumulatorOpLowering, // nvgpu.warpgroup.mma.init.accumulator
MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
NVGPUMmaSparseSyncLowering>(converter);
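To make the new `WarpgroupAccumulatorType` conversion above concrete, a small worked sketch under an assumed shape: for an f32 accumulator, each inner struct holds `sizeN / 2` scalars, and there is one inner struct per `kWgmmaSizeM` (64) rows of `sizeM`. For a hypothetical `vector<64x8xf32>` fragment that gives a single inner struct of four f32s:

```mlir
// Sketch (hypothetical shape): conversion result for
// !nvgpu.warpgroup.accumulator<fragmented = vector<64x8xf32>>.
// numMembers = sizeN / 2 = 4; outer members = sizeM / kWgmmaSizeM = 1.
!llvm.struct<(struct<(f32, f32, f32, f32)>)>
```

For the `vector<128x128xf32>` accumulator used in the tests below, the same rule yields two inner structs of 64 f32s each.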
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index e8ecd0faa4c86d3..f5b02fe1b515591 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -435,6 +435,12 @@ LogicalResult isAllowedWGMMADataType(Type typeD, Type typeA, Type typeB) {
return failure();
}
+LogicalResult isAllowedSizeM(int sizeM) {
+ if (sizeM % kWgmmaSizeM)
+ return failure();
+ return success();
+}
+
LogicalResult isAllowedSizeN(int sizeN, Type typeA) {
SmallVector<int> allowedN = {8, 16, 24, 32, 40, 48, 56, 64,
72, 80, 88, 96, 104, 112, 120, 128,
@@ -443,7 +449,7 @@ LogicalResult isAllowedSizeN(int sizeN, Type typeA) {
SmallVector<int> allowedNshort = {8, 16, 24, 32, 48, 64,
80, 96, 112, 128, 144, 160,
176, 192, 208, 224, 240, 256};
- if (typeA.isBF16() || typeA.isF16() || typeA.isTF32() ||
+ if (typeA.isBF16() || typeA.isF16() || typeA.isF32() || typeA.isTF32() ||
typeA.isFloat8E4M3FN() || typeA.isFloat8E5M2())
if (llvm::is_contained(allowedN, sizeN))
return success();
@@ -456,35 +462,16 @@ LogicalResult isAllowedSizeN(int sizeN, Type typeA) {
LogicalResult WarpgroupMmaOp::verify() {
if (getTransposeA() && !getTransposeB())
- return emitOpError() << "supports non-transpose A (Row Major) "
- "and transpose B (Column Major) for the time being";
+ return emitOpError()
+ << "supports non-transpose A (Row Major) "
+ "and transpose B (Column Major) for the time being ";
MemRefType matrixA = getDescriptorA().getType().getTensor();
MemRefType matrixB = getDescriptorB().getType().getTensor();
- VectorType matrixC = getMatrixC()
- .front()
- .getType()
- .cast<WarpgroupAccumulatorType>()
- .getFragmented();
- VectorType matrixD = getMatrixD()
- .front()
- .getType()
- .cast<WarpgroupAccumulatorType>()
- .getFragmented();
- unsigned sizeAcc = getMatrixC().size();
-
- if (getMatrixC().size() != getMatrixD().size())
- return emitOpError() << "number of matrix C and matrix D must be the same";
-
- if (llvm::all_of(getMatrixC(),
- [&](Value rhs) { return rhs.getType() == matrixC; })) {
- return emitOpError()
- << "types of all operands in matrix C must be the same";
- }
- if (llvm::all_of(getMatrixD(),
- [&](Value rhs) { return rhs.getType() == matrixC; })) {
- return emitOpError()
- << "types of all operands in matrix D must be the same as matrix C";
- }
+ VectorType matrixC = getMatrixC().getType().getFragmented();
+ VectorType matrixD = getMatrixD().getType().getFragmented();
+
+ if (matrixC != matrixD)
+ return emitOpError() << "type of matrix C and matrix D must be the same";
if (matrixA.getRank() != 2 || matrixB.getRank() != 2 ||
matrixC.getRank() != 2 || matrixD.getRank() != 2) {
@@ -496,7 +483,7 @@ LogicalResult WarpgroupMmaOp::verify() {
return emitOpError() << "2nd dim matrix-A (" << matrixA.getShape()[1]
<< ")!= 1st dim matrix-B (" << matrixB.getShape()[0]
<< " )";
- if (matrixA.getShape()[0] != (matrixC.getShape()[0] * sizeAcc))
+ if (matrixA.getShape()[0] != matrixC.getShape()[0])
return emitOpError() << "1st dim matrix-A ( " << matrixA.getShape()[0]
<< " )!= 1st dim matrix-C ( " << matrixC.getShape()[0]
<< " )";
@@ -532,29 +519,16 @@ LogicalResult WarpgroupMmaOp::verify() {
LogicalResult WarpgroupMmaStoreOp::verify() {
MemRefType dstMemrefType = getDstMemref().getType();
- VectorType firstVtype = getMatrixD()
- .front()
- .getType()
- .cast<WarpgroupAccumulatorType>()
- .getFragmented();
-
- int64_t totalFirstDimension = 0;
- for (Value result : getMatrixD()) {
- VectorType vtype =
- result.getType().cast<WarpgroupAccumulatorType>().getFragmented();
- if (vtype != firstVtype)
- return emitOpError() << "all fragmented types must be the same";
- // Limitation
- if (!vtype.getElementType().isF32()) {
- return emitOpError()
- << "hit a limitation: only f32 results for the time being";
- }
- totalFirstDimension += vtype.getDimSize(0);
+ VectorType vtype = getMatrixD().getType().getFragmented();
+
+ // Limitation
+ if (!vtype.getElementType().isF32()) {
+ return emitOpError()
+ << "hit a limitation: only f32 results for the time being";
}
- if (totalFirstDimension != dstMemrefType.getDimSize(0) ||
- firstVtype.getDimSize(1) != dstMemrefType.getDimSize(1)) {
- return emitOpError() << "results [" << totalFirstDimension << "]["
- << firstVtype.getDimSize(1)
+ if (vtype.getDimSize(0) != dstMemrefType.getDimSize(0) ||
+ vtype.getDimSize(1) != dstMemrefType.getDimSize(1)) {
+ return emitOpError() << "results [" << vtype << "][" << vtype.getDimSize(1)
<< "] values. However, destination memref["
<< dstMemrefType.getDimSize(0) << "]["
<< dstMemrefType.getDimSize(1)
@@ -563,6 +537,27 @@ LogicalResult WarpgroupMmaStoreOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// WarpgroupMmaInitAccumulatorOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult WarpgroupMmaInitAccumulatorOp::verify() {
+
+ nvgpu::WarpgroupAccumulatorType accType = getMatrixC().getType();
+ int64_t sizeM = accType.getFragmented().getDimSize(0);
+ int64_t sizeN = accType.getFragmented().getDimSize(1);
+ Type elemType = accType.getFragmented().getElementType();
+
+ if (failed(isAllowedSizeM(sizeM)) ||
+ failed(isAllowedSizeN(sizeN, elemType))) {
+ return emitOpError() << "has type " << accType.getFragmented()
+ << ". It does not fit into warp-group "
+ "level (wgmma) matrix multiplication instruction "
+ "(or not supported yet)";
+ }
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd dialect, type, and op definitions
//===----------------------------------------------------------------------===//
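As an illustration of the new `WarpgroupMmaInitAccumulatorOp::verify`, a hypothetical type it would reject: 48 is not a multiple of `kWgmmaSizeM` (64), so `isAllowedSizeM` fails and the op emits the diagnostic above:

```mlir
// Hypothetical invalid example: sizeM = 48 is rejected because
// 48 % kWgmmaSizeM != 0.
%bad = nvgpu.warpgroup.mma.init.accumulator
         -> !nvgpu.warpgroup.accumulator<fragmented = vector<48x128xf32>>
```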
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index e54b62a06d4313a..bf660e2683158e5 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -713,18 +713,18 @@ func.func @create_wgmma_descriptor(%tensorMap : !tensorMap) -> !nvgpu.warpgroup.
}
// CHECK-LABEL: @warpgroup_mma_128_128_64(
-// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>, %[[arg1:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>, %[[arg2:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, %[[arg3:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>)
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>, %[[arg1:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>, %[[arg2:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>)
func.func @warpgroup_mma_128_128_64(
%descA: !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>,
%descB: !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>,
- %acc1: !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>,
- %acc2: !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>)
+ %acc: !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>)
{
// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>> to i64
// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>> to i64
-// CHECK: %[[S2:.+]] = builtin.unrealized_conversion_cast %[[ar...
[truncated]
Force-pushed from 5b8b2bd to ce2d8ed
Force-pushed from ce2d8ed to 04582c0
llvm#68728 significantly simplified the accumulator matrix, but we forgot to pack the struct after initialization. This PR fixes that.
PR llvm#68728 significantly simplified the accumulator matrix type, making it easier to work with the nvgpu dialect without worrying about the number of required structs, as this information is abstracted away in the nvgpu-to-nvvm transformation. However, we forgot to pack the structs after initialization, causing the accumulator matrix to hold undefined values, which is wrong. This PR addresses that.
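A minimal sketch of the intended lowering output, assuming a single-member outer struct with a two-element inner struct (shapes and SSA names are illustrative). The final `llvm.insertvalue` into the outer struct is the packing step this fix adds; without it, the outer struct keeps its undef members:

```mlir
// Zero-fill the inner struct, then pack it back into the outer struct.
%undef  = llvm.mlir.undef : !llvm.struct<(struct<(f32, f32)>)>
%zero   = llvm.mlir.constant(0.000000e+00 : f32) : f32
%inner0 = llvm.extractvalue %undef[0] : !llvm.struct<(struct<(f32, f32)>)>
%inner1 = llvm.insertvalue %zero, %inner0[0] : !llvm.struct<(f32, f32)>
%inner2 = llvm.insertvalue %zero, %inner1[1] : !llvm.struct<(f32, f32)>
// The previously missing step: write the initialized inner struct back.
%packed = llvm.insertvalue %inner2, %undef[0] : !llvm.struct<(struct<(f32, f32)>)>
```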
Merged into main as 52db7e2: [mlir][nvgpu] Improve `WarpgroupAccumulator` type to simplify IR (llvm#68728)
`WarpgroupAccumulator` (or `!nvgpu.warpgroup.accumulator`) is a type that keeps the accumulator matrix used by warp-group level matrix multiplication. It is handy to have a special type for that, as the matrix is distributed among the threads of the warp-group. However, the current transformations require creating and using multiple `WarpgroupAccumulator` values when the GEMM shape is larger than the shape a single `wgmma.mma_async` instruction supports. This makes the IR look dense.

This PR improves the transformation of the `WarpgroupAccumulator` type in every nvgpu Op that uses it.

Example: Current GEMM in NVGPU-IR

Example: This PR simplifies the IR as below:
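A sketch of the two examples, reconstructed from the test changes in this patch (a 128x128x64 GEMM with f16 inputs and an f32 accumulator; SSA names and the destination memref are illustrative). Before, the accumulator had to be split into per-tile values:

```mlir
// Before (sketch): the 128x128 accumulator is carried as two 64x128
// warpgroup.accumulator values, and every op takes the whole list.
%d:2 = nvgpu.warpgroup.mma %descA, %descB, %acc1, %acc2 :
    !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>,
    !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>,
    !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
    !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
    -> !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
       !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
nvgpu.warpgroup.mma.store [%d#0, %d#1], %memref :
    !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
    !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
    to memref<128x128xf32, 3>
```

With this PR, a single accumulator value covers the whole result, and the new init op creates it:

```mlir
// After (sketch): one accumulator value for the whole 128x128 result.
%acc = nvgpu.warpgroup.mma.init.accumulator
         -> !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>
%d = nvgpu.warpgroup.mma %descA, %descB, %acc :
    !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>,
    !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>,
    !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>
    -> !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>
nvgpu.warpgroup.mma.store %d, %memref :
    !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>
    to memref<128x128xf32, 3>
```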