Skip to content

Commit

Permalink
[MLIR][NVGPU] Introduce warpgroup.init.accumulator Op (#67530)
Browse files Browse the repository at this point in the history
This Op generates and initilizes the accumulator matrix for
`nvgpu.warpgroup.mma` op to perform matrix-multiply-and-accumulate
(mma).

Its associated transformation generates `!llvm.struct<>` and fill it
with the initial values. The size of struct is number of required inout
registers for `nvgpu.warpgroup.mma` op.
  • Loading branch information
grypp authored Oct 11, 2023
1 parent 660a78f commit 315ab3c
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 1 deletion.
12 changes: 12 additions & 0 deletions mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
Original file line number Diff line number Diff line change
Expand Up @@ -748,4 +748,16 @@ def NVGPU_WarpgroupMmaStoreOp : NVGPU_Op<"warpgroup.mma.store"> {
let hasVerifier = 1;
}

def NVGPU_WarpgroupMmaInitAccumulatorOp : NVGPU_Op<"warpgroup.mma.init.accumulator"> {
let summary = "Initializes the accumulator matrix";

let description = [{
This Op generates and initializes the accumulator matrix for
`nvgpu.warpgroup.mma` op to perform matrix-multiply-and-accumulate.
}];
let results = (outs Variadic<NVGPU_WarpgroupAccumulator>:$matrixC);
let assemblyFormat = "attr-dict `->` type($matrixC)";
let hasVerifier = 1;
}

#endif // NVGPU
30 changes: 30 additions & 0 deletions mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1546,6 +1546,35 @@ struct NVGPUWarpgroupMmaStoreOpLowering
}
};

struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
: public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaInitAccumulatorOp> {
using ConvertOpToLLVMPattern<
nvgpu::WarpgroupMmaInitAccumulatorOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
SmallVector<Value> results;
for (OpResult m : op.getMatrixC()) {
nvgpu::WarpgroupAccumulatorType mType =
m.getType().cast<nvgpu::WarpgroupAccumulatorType>();
Type stype = getTypeConverter()->convertType(mType);
Value undefStruct = b.create<LLVM::UndefOp>(stype);
Type elemType = mType.getFragmented().getElementType();
int64_t elemSize = mType.getFragmented().getDimSize(0);
Value zero =
b.create<LLVM::ConstantOp>(elemType, rewriter.getZeroAttr(elemType));
for (int64_t i = 0; i < elemSize; ++i) {
undefStruct = b.create<LLVM::InsertValueOp>(stype, undefStruct, zero,
ArrayRef<int64_t>({i}));
}
results.push_back(undefStruct);
}
rewriter.replaceOp(op, results);
return success();
}
};

} // namespace

void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
Expand All @@ -1563,6 +1592,7 @@ void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
NVGPUGenerateWarpgroupDescriptorLowering, // nvgpu.warpgroup.generate.descriptor
NVGPUWarpgroupMmaOpLowering, // nvgpu.warpgroup.mma
NVGPUWarpgroupMmaStoreOpLowering, // nvgpu.warpgroup.mma.store
NVGPUWarpgroupMmaInitAccumulatorOpLowering, // nvgpu.warpgroup.mma.init.accumulator
MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
NVGPUMmaSparseSyncLowering>(converter);
Expand Down
26 changes: 25 additions & 1 deletion mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,8 @@ LogicalResult isAllowedWGMMADataType(Type typeD, Type typeA, Type typeB) {
return failure();
}

LogicalResult isAllowedSizeM(int sizeM) { return success(sizeM == 64); }

LogicalResult isAllowedSizeN(int sizeN, Type typeA) {
SmallVector<int> allowedN = {8, 16, 24, 32, 40, 48, 56, 64,
72, 80, 88, 96, 104, 112, 120, 128,
Expand All @@ -443,7 +445,7 @@ LogicalResult isAllowedSizeN(int sizeN, Type typeA) {
SmallVector<int> allowedNshort = {8, 16, 24, 32, 48, 64,
80, 96, 112, 128, 144, 160,
176, 192, 208, 224, 240, 256};
if (typeA.isBF16() || typeA.isF16() || typeA.isTF32() ||
if (typeA.isBF16() || typeA.isF16() || typeA.isF32() || typeA.isTF32() ||
typeA.isFloat8E4M3FN() || typeA.isFloat8E5M2())
if (llvm::is_contained(allowedN, sizeN))
return success();
Expand Down Expand Up @@ -563,6 +565,28 @@ LogicalResult WarpgroupMmaStoreOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// WarpgroupMmaInitAccumulatorOp
//===----------------------------------------------------------------------===//

LogicalResult WarpgroupMmaInitAccumulatorOp::verify() {
for (OpResult matrix : getMatrixC()) {
VectorType vectorType = matrix.getType()
.cast<nvgpu::WarpgroupAccumulatorType>()
.getFragmented();
// Check [M][N] shape
if (failed(isAllowedSizeM(vectorType.getDimSize(0))) ||
failed(isAllowedSizeN(vectorType.getDimSize(1),
vectorType.getElementType()))) {
return emitOpError() << "has type " << vectorType
<< ". It does not fit into warp-group "
"level (wgmma) matrix multiplication instruction "
"(or not supported yet)";
}
}
return success();
}

//===----------------------------------------------------------------------===//
// TableGen'd dialect, type, and op definitions
//===----------------------------------------------------------------------===//
Expand Down
71 changes: 71 additions & 0 deletions mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -901,6 +901,77 @@ func.func @warpgroup_mma_store(
return
}

func.func @warpgroup_mma_init() {
//CHECK: %[[S0:.+]] = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
//CHECK: %[[S1:.+]] = llvm.mlir.constant(0.000000e+00 : f32) : f3
//CHECK: %[[S2:.+]] = llvm.insertvalue %[[S1]], %[[S0]][0] : !llvm.struct
//CHECK: %[[S3:.+]] = llvm.insertvalue %[[S1]], %[[S2]][1] : !llvm.struct
//CHECK: %[[S4:.+]] = llvm.insertvalue %[[S1]], %[[S3]][2] : !llvm.struct
//CHECK: %[[S5:.+]] = llvm.insertvalue %[[S1]], %[[S4]][3] : !llvm.struct
//CHECK: %[[S6:.+]] = llvm.insertvalue %[[S1]], %[[S5]][4] : !llvm.struct
//CHECK: %[[S7:.+]] = llvm.insertvalue %[[S1]], %[[S6]][5] : !llvm.struct
//CHECK: %[[S8:.+]] = llvm.insertvalue %[[S1]], %[[S7]][6] : !llvm.struct
//CHECK: %[[S9:.+]] = llvm.insertvalue %[[S1]], %[[S8]][7] : !llvm.struct
//CHECK: %[[S10:.+]] = llvm.insertvalue %[[S1]], %[[S9]][8] : !llvm.struct
//CHECK: %[[S11:.+]] = llvm.insertvalue %[[S1]], %[[S10]][9] : !llvm.struct
//CHECK: %[[S12:.+]] = llvm.insertvalue %[[S1]], %[[S11]][10] : !llvm.struct
//CHECK: %[[S13:.+]] = llvm.insertvalue %[[S1]], %[[S12]][11] : !llvm.struct
//CHECK: %[[S14:.+]] = llvm.insertvalue %[[S1]], %[[S13]][12] : !llvm.struct
//CHECK: %[[S15:.+]] = llvm.insertvalue %[[S1]], %[[S14]][13] : !llvm.struct
//CHECK: %[[S16:.+]] = llvm.insertvalue %[[S1]], %[[S15]][14] : !llvm.struct
//CHECK: %[[S17:.+]] = llvm.insertvalue %[[S1]], %[[S16]][15] : !llvm.struct
//CHECK: %[[S18:.+]] = llvm.insertvalue %[[S1]], %[[S17]][16] : !llvm.struct
//CHECK: %[[S19:.+]] = llvm.insertvalue %[[S1]], %[[S18]][17] : !llvm.struct
//CHECK: %[[S20:.+]] = llvm.insertvalue %[[S1]], %[[S19]][18] : !llvm.struct
//CHECK: %[[S21:.+]] = llvm.insertvalue %[[S1]], %[[S20]][19] : !llvm.struct
//CHECK: %[[S22:.+]] = llvm.insertvalue %[[S1]], %[[S21]][20] : !llvm.struct
//CHECK: %[[S23:.+]] = llvm.insertvalue %[[S1]], %[[S22]][21] : !llvm.struct
//CHECK: %[[S24:.+]] = llvm.insertvalue %[[S1]], %[[S23]][22] : !llvm.struct
//CHECK: %[[S25:.+]] = llvm.insertvalue %[[S1]], %[[S24]][23] : !llvm.struct
//CHECK: %[[S26:.+]] = llvm.insertvalue %[[S1]], %[[S25]][24] : !llvm.struct
//CHECK: %[[S27:.+]] = llvm.insertvalue %[[S1]], %[[S26]][25] : !llvm.struct
//CHECK: %[[S28:.+]] = llvm.insertvalue %[[S1]], %[[S27]][26] : !llvm.struct
//CHECK: %[[S29:.+]] = llvm.insertvalue %[[S1]], %[[S28]][27] : !llvm.struct
//CHECK: %[[S30:.+]] = llvm.insertvalue %[[S1]], %[[S29]][28] : !llvm.struct
//CHECK: %[[S31:.+]] = llvm.insertvalue %[[S1]], %[[S30]][29] : !llvm.struct
//CHECK: %[[S32:.+]] = llvm.insertvalue %[[S1]], %[[S31]][30] : !llvm.struct
//CHECK: %[[S33:.+]] = llvm.insertvalue %[[S1]], %[[S32]][31] : !llvm.struct
//CHECK: %[[S34:.+]] = llvm.insertvalue %[[S1]], %[[S33]][32] : !llvm.struct
//CHECK: %[[S35:.+]] = llvm.insertvalue %[[S1]], %[[S34]][33] : !llvm.struct
//CHECK: %[[S36:.+]] = llvm.insertvalue %[[S1]], %[[S35]][34] : !llvm.struct
//CHECK: %[[S37:.+]] = llvm.insertvalue %[[S1]], %[[S36]][35] : !llvm.struct
//CHECK: %[[S38:.+]] = llvm.insertvalue %[[S1]], %[[S37]][36] : !llvm.struct
//CHECK: %[[S39:.+]] = llvm.insertvalue %[[S1]], %[[S38]][37] : !llvm.struct
//CHECK: %[[S40:.+]] = llvm.insertvalue %[[S1]], %[[S39]][38] : !llvm.struct
//CHECK: %[[S41:.+]] = llvm.insertvalue %[[S1]], %[[S40]][39] : !llvm.struct
//CHECK: %[[S42:.+]] = llvm.insertvalue %[[S1]], %[[S41]][40] : !llvm.struct
//CHECK: %[[S43:.+]] = llvm.insertvalue %[[S1]], %[[S42]][41] : !llvm.struct
//CHECK: %[[S44:.+]] = llvm.insertvalue %[[S1]], %[[S43]][42] : !llvm.struct
//CHECK: %[[S45:.+]] = llvm.insertvalue %[[S1]], %[[S44]][43] : !llvm.struct
//CHECK: %[[S46:.+]] = llvm.insertvalue %[[S1]], %[[S45]][44] : !llvm.struct
//CHECK: %[[S47:.+]] = llvm.insertvalue %[[S1]], %[[S46]][45] : !llvm.struct
//CHECK: %[[S48:.+]] = llvm.insertvalue %[[S1]], %[[S47]][46] : !llvm.struct
//CHECK: %[[S49:.+]] = llvm.insertvalue %[[S1]], %[[S48]][47] : !llvm.struct
//CHECK: %[[S50:.+]] = llvm.insertvalue %[[S1]], %[[S49]][48] : !llvm.struct
//CHECK: %[[S51:.+]] = llvm.insertvalue %[[S1]], %[[S50]][49] : !llvm.struct
//CHECK: %[[S52:.+]] = llvm.insertvalue %[[S1]], %[[S51]][50] : !llvm.struct
//CHECK: %[[S53:.+]] = llvm.insertvalue %[[S1]], %[[S52]][51] : !llvm.struct
//CHECK: %[[S54:.+]] = llvm.insertvalue %[[S1]], %[[S53]][52] : !llvm.struct
//CHECK: %[[S55:.+]] = llvm.insertvalue %[[S1]], %[[S54]][53] : !llvm.struct
//CHECK: %[[S56:.+]] = llvm.insertvalue %[[S1]], %[[S55]][54] : !llvm.struct
//CHECK: %[[S57:.+]] = llvm.insertvalue %[[S1]], %[[S56]][55] : !llvm.struct
//CHECK: %[[S58:.+]] = llvm.insertvalue %[[S1]], %[[S57]][56] : !llvm.struct
//CHECK: %[[S59:.+]] = llvm.insertvalue %[[S1]], %[[S58]][57] : !llvm.struct
//CHECK: %[[S60:.+]] = llvm.insertvalue %[[S1]], %[[S59]][58] : !llvm.struct
//CHECK: %[[S61:.+]] = llvm.insertvalue %[[S1]], %[[S60]][59] : !llvm.struct
//CHECK: %[[S62:.+]] = llvm.insertvalue %[[S1]], %[[S61]][60] : !llvm.struct
//CHECK: %[[S63:.+]] = llvm.insertvalue %[[S1]], %[[S62]][61] : !llvm.struct
//CHECK: %[[S64:.+]] = llvm.insertvalue %[[S1]], %[[S63]][62] : !llvm.struct
//CHECK: %[[S65:.+]] = llvm.insertvalue %[[S1]], %[[S64]][63] : !llvm.struct
%matrixC = nvgpu.warpgroup.mma.init.accumulator -> !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>
return
}

transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
%0 = transform.structured.match ops{["func.func"]} in %arg1
Expand Down

0 comments on commit 315ab3c

Please sign in to comment.