Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR][NVGPU] Introduce warpgroup.init.accumulator Op #67530

Merged
merged 4 commits into from
Oct 11, 2023

Conversation

grypp
Copy link
Member

@grypp grypp commented Sep 27, 2023

This Op generates and initilizes the accumulator matrix for nvgpu.warpgroup.mma op to perform matrix-multiply-and-accumulate (mma).

Its associated transformation generates !llvm.struct<> and fill it with the initial values. The size of struct is number of required inout registers for nvgpu.warpgroup.mma op.

@llvmbot
Copy link
Member

llvmbot commented Sep 27, 2023

@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-mlir-nvgpu

@llvm/pr-subscribers-mlir

Changes

This Op generates and initilizes the accumulator matrix for nvgpu.warpgroup.mma op to perform matrix-multiply-and-accumulate (mma).

Its associated transformation generates !llvm.struct&lt;&gt; and fill it with the initial values. The size of struct is number of required inout registers for nvgpu.warpgroup.mma op.


Full diff: https://github.com/llvm/llvm-project/pull/67530.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td (+11)
  • (modified) mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (+30)
  • (modified) mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir (+71)
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index 31b137160545772..e129ab39ff115c5 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -727,4 +727,15 @@ def NVGPU_WarpgroupMmaOp : NVGPU_Op<"warpgroup.mma"> {
   let hasVerifier = 1;
 }
 
+def NVGPU_WarpgroupMmaInitAccumulatorOp : NVGPU_Op<"warpgroup.mma.init.accumulator"> {  
+  let summary = "Initialize accumulator matrix for `warppgroup.mma`";
+
+  let description = [{
+    This Op generates and initilizes the accumulator matrix for 
+    `nvgpu.warpgroup.mma` op to perform matrix-multiply-and-accumulate (mma).
+  }];
+  let results = (outs Variadic<NVGPU_WarpgroupAccumulator>:$matrixC);
+  let assemblyFormat = "attr-dict `->` type($matrixC)";
+}
+
 #endif // NVGPU
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 4d1f6641af6dca3..ecc7da9bd7414bb 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -1306,6 +1306,35 @@ struct NVGPUWarpgroupMmaOpLowering
   }
 };
 
+struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
+    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaInitAccumulatorOp> {
+  using ConvertOpToLLVMPattern<
+      nvgpu::WarpgroupMmaInitAccumulatorOp>::ConvertOpToLLVMPattern;
+  LogicalResult
+  matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    SmallVector<Value> results;
+    for (auto matrixD : op.getMatrixC()) {
+      nvgpu::WarpgroupAccumulatorType matrixDType =
+          matrixD.getType().cast<nvgpu::WarpgroupAccumulatorType>();
+      Type stype = getTypeConverter()->convertType(matrixDType);
+      Value undefStruct = rewriter.create<LLVM::UndefOp>(loc, stype);
+      Type elemType = matrixDType.getFragmented().getElementType();
+      int64_t elemSize = matrixDType.getFragmented().getDimSize(0);
+      Value zero = rewriter.create<LLVM::ConstantOp>(
+          loc, elemType, rewriter.getZeroAttr(elemType));
+      for (int64_t i = 0; i < elemSize; ++i) {
+        undefStruct = rewriter.create<LLVM::InsertValueOp>(
+            loc, stype, undefStruct, zero, ArrayRef<int64_t>({i}));
+      }
+      results.push_back(undefStruct);
+    }
+    rewriter.replaceOp(op, results);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
@@ -1322,6 +1351,7 @@ void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
       NVGPUMBarrierArriveExpectTxLowering,   // nvgpu.mbarrier.arrive.expect_tx
       NVGPUGenerateGmmaDescriptorLowering,   // nvgpu.wgmma.generate.descriptor
       NVGPUWarpgroupMmaOpLowering,           // nvgpu.warpgroup.mma
+      NVGPUWarpgroupMmaInitAccumulatorOpLowering, // nvgpu.warpgroup.mma.init.accumulator
       MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
       NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
       NVGPUMmaSparseSyncLowering>(converter);
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index 8c2f8dbbd5ad9a3..65f2bba8d4bf4ca 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -772,6 +772,77 @@ func.func @warpgroup_mma_128_128_64(
   return
 }
 
+func.func @warpgroup_mma_init(){
+  //CHECK: %[[S0:.+]] = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+  //CHECK: %[[S1:.+]] = llvm.mlir.constant(0.000000e+00 : f32) : f3
+  //CHECK: %[[S2:.+]] = llvm.insertvalue %[[S1]], %[[S0]][0] : !llvm.struct
+  //CHECK: %[[S3:.+]] = llvm.insertvalue %[[S1]], %[[S2]][1] : !llvm.struct
+  //CHECK: %[[S4:.+]] = llvm.insertvalue %[[S1]], %[[S3]][2] : !llvm.struct
+  //CHECK: %[[S5:.+]] = llvm.insertvalue %[[S1]], %[[S4]][3] : !llvm.struct
+  //CHECK: %[[S6:.+]] = llvm.insertvalue %[[S1]], %[[S5]][4] : !llvm.struct
+  //CHECK: %[[S7:.+]] = llvm.insertvalue %[[S1]], %[[S6]][5] : !llvm.struct
+  //CHECK: %[[S8:.+]] = llvm.insertvalue %[[S1]], %[[S7]][6] : !llvm.struct
+  //CHECK: %[[S9:.+]] = llvm.insertvalue %[[S1]], %[[S8]][7] : !llvm.struct
+  //CHECK: %[[S10:.+]] = llvm.insertvalue %[[S1]], %[[S9]][8] : !llvm.struct
+  //CHECK: %[[S11:.+]] = llvm.insertvalue %[[S1]], %[[S10]][9] : !llvm.struct
+  //CHECK: %[[S12:.+]] = llvm.insertvalue %[[S1]], %[[S11]][10] : !llvm.struct
+  //CHECK: %[[S13:.+]] = llvm.insertvalue %[[S1]], %[[S12]][11] : !llvm.struct
+  //CHECK: %[[S14:.+]] = llvm.insertvalue %[[S1]], %[[S13]][12] : !llvm.struct
+  //CHECK: %[[S15:.+]] = llvm.insertvalue %[[S1]], %[[S14]][13] : !llvm.struct
+  //CHECK: %[[S16:.+]] = llvm.insertvalue %[[S1]], %[[S15]][14] : !llvm.struct
+  //CHECK: %[[S17:.+]] = llvm.insertvalue %[[S1]], %[[S16]][15] : !llvm.struct
+  //CHECK: %[[S18:.+]] = llvm.insertvalue %[[S1]], %[[S17]][16] : !llvm.struct
+  //CHECK: %[[S19:.+]] = llvm.insertvalue %[[S1]], %[[S18]][17] : !llvm.struct
+  //CHECK: %[[S20:.+]] = llvm.insertvalue %[[S1]], %[[S19]][18] : !llvm.struct
+  //CHECK: %[[S21:.+]] = llvm.insertvalue %[[S1]], %[[S20]][19] : !llvm.struct
+  //CHECK: %[[S22:.+]] = llvm.insertvalue %[[S1]], %[[S21]][20] : !llvm.struct
+  //CHECK: %[[S23:.+]] = llvm.insertvalue %[[S1]], %[[S22]][21] : !llvm.struct
+  //CHECK: %[[S24:.+]] = llvm.insertvalue %[[S1]], %[[S23]][22] : !llvm.struct
+  //CHECK: %[[S25:.+]] = llvm.insertvalue %[[S1]], %[[S24]][23] : !llvm.struct
+  //CHECK: %[[S26:.+]] = llvm.insertvalue %[[S1]], %[[S25]][24] : !llvm.struct
+  //CHECK: %[[S27:.+]] = llvm.insertvalue %[[S1]], %[[S26]][25] : !llvm.struct
+  //CHECK: %[[S28:.+]] = llvm.insertvalue %[[S1]], %[[S27]][26] : !llvm.struct
+  //CHECK: %[[S29:.+]] = llvm.insertvalue %[[S1]], %[[S28]][27] : !llvm.struct
+  //CHECK: %[[S30:.+]] = llvm.insertvalue %[[S1]], %[[S29]][28] : !llvm.struct
+  //CHECK: %[[S31:.+]] = llvm.insertvalue %[[S1]], %[[S30]][29] : !llvm.struct
+  //CHECK: %[[S32:.+]] = llvm.insertvalue %[[S1]], %[[S31]][30] : !llvm.struct
+  //CHECK: %[[S33:.+]] = llvm.insertvalue %[[S1]], %[[S32]][31] : !llvm.struct
+  //CHECK: %[[S34:.+]] = llvm.insertvalue %[[S1]], %[[S33]][32] : !llvm.struct
+  //CHECK: %[[S35:.+]] = llvm.insertvalue %[[S1]], %[[S34]][33] : !llvm.struct
+  //CHECK: %[[S36:.+]] = llvm.insertvalue %[[S1]], %[[S35]][34] : !llvm.struct
+  //CHECK: %[[S37:.+]] = llvm.insertvalue %[[S1]], %[[S36]][35] : !llvm.struct
+  //CHECK: %[[S38:.+]] = llvm.insertvalue %[[S1]], %[[S37]][36] : !llvm.struct
+  //CHECK: %[[S39:.+]] = llvm.insertvalue %[[S1]], %[[S38]][37] : !llvm.struct
+  //CHECK: %[[S40:.+]] = llvm.insertvalue %[[S1]], %[[S39]][38] : !llvm.struct
+  //CHECK: %[[S41:.+]] = llvm.insertvalue %[[S1]], %[[S40]][39] : !llvm.struct
+  //CHECK: %[[S42:.+]] = llvm.insertvalue %[[S1]], %[[S41]][40] : !llvm.struct
+  //CHECK: %[[S43:.+]] = llvm.insertvalue %[[S1]], %[[S42]][41] : !llvm.struct
+  //CHECK: %[[S44:.+]] = llvm.insertvalue %[[S1]], %[[S43]][42] : !llvm.struct
+  //CHECK: %[[S45:.+]] = llvm.insertvalue %[[S1]], %[[S44]][43] : !llvm.struct
+  //CHECK: %[[S46:.+]] = llvm.insertvalue %[[S1]], %[[S45]][44] : !llvm.struct
+  //CHECK: %[[S47:.+]] = llvm.insertvalue %[[S1]], %[[S46]][45] : !llvm.struct
+  //CHECK: %[[S48:.+]] = llvm.insertvalue %[[S1]], %[[S47]][46] : !llvm.struct
+  //CHECK: %[[S49:.+]] = llvm.insertvalue %[[S1]], %[[S48]][47] : !llvm.struct
+  //CHECK: %[[S50:.+]] = llvm.insertvalue %[[S1]], %[[S49]][48] : !llvm.struct
+  //CHECK: %[[S51:.+]] = llvm.insertvalue %[[S1]], %[[S50]][49] : !llvm.struct
+  //CHECK: %[[S52:.+]] = llvm.insertvalue %[[S1]], %[[S51]][50] : !llvm.struct
+  //CHECK: %[[S53:.+]] = llvm.insertvalue %[[S1]], %[[S52]][51] : !llvm.struct
+  //CHECK: %[[S54:.+]] = llvm.insertvalue %[[S1]], %[[S53]][52] : !llvm.struct
+  //CHECK: %[[S55:.+]] = llvm.insertvalue %[[S1]], %[[S54]][53] : !llvm.struct
+  //CHECK: %[[S56:.+]] = llvm.insertvalue %[[S1]], %[[S55]][54] : !llvm.struct
+  //CHECK: %[[S57:.+]] = llvm.insertvalue %[[S1]], %[[S56]][55] : !llvm.struct
+  //CHECK: %[[S58:.+]] = llvm.insertvalue %[[S1]], %[[S57]][56] : !llvm.struct
+  //CHECK: %[[S59:.+]] = llvm.insertvalue %[[S1]], %[[S58]][57] : !llvm.struct
+  //CHECK: %[[S60:.+]] = llvm.insertvalue %[[S1]], %[[S59]][58] : !llvm.struct
+  //CHECK: %[[S61:.+]] = llvm.insertvalue %[[S1]], %[[S60]][59] : !llvm.struct
+  //CHECK: %[[S62:.+]] = llvm.insertvalue %[[S1]], %[[S61]][60] : !llvm.struct
+  //CHECK: %[[S63:.+]] = llvm.insertvalue %[[S1]], %[[S62]][61] : !llvm.struct
+  //CHECK: %[[S64:.+]] = llvm.insertvalue %[[S1]], %[[S63]][62] : !llvm.struct
+  //CHECK: %[[S65:.+]] = llvm.insertvalue %[[S1]], %[[S64]][63] : !llvm.struct
+  %matrixC = nvgpu.warpgroup.mma.init.accumulator -> !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>
+  return 
+}
+
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["func.func"]} in %arg1 

mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td Outdated Show resolved Hide resolved
mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp Outdated Show resolved Hide resolved
mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td Show resolved Hide resolved
//CHECK: %[[S63:.+]] = llvm.insertvalue %[[S1]], %[[S62]][61] : !llvm.struct
//CHECK: %[[S64:.+]] = llvm.insertvalue %[[S1]], %[[S63]][62] : !llvm.struct
//CHECK: %[[S65:.+]] = llvm.insertvalue %[[S1]], %[[S64]][63] : !llvm.struct
%matrixC = nvgpu.warpgroup.mma.init.accumulator -> !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I'm not mistaken, the matrixC needs to live in register.
How do we make sure this gets lowered to that?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes it needs to be in virtual registers. We generate a struct for that

I'm not sure I really understood the question.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The accumulator (or the struct) is used by the wgmma.mma_async PTX instruction. You pass the members this struct as operands.

Do you think there might be case that the struct can end up in different place?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes it needs to be in virtual registers. We generate a struct for that

I'm not sure I really understood the question.

You got the question right :).

I was afraid that this may produce some spill, but thinking more about it, even if that's the case, the uses in wgmma.* should force them back to register.

So nothing to worry about.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now you got me thinking :) It is actually possible for llvm to put the struct in local memory. I don't have a good answer on how to prevent this.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think you can :)

mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td Show resolved Hide resolved
@grypp grypp force-pushed the nvgpu-mma-init branch 2 times, most recently from 25071a0 to 2781170 Compare October 5, 2023 07:10
Copy link
Collaborator

@qcolombet qcolombet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm still unsure about the variadic part of the new init instruction, as in, I don't understand what it buys us. Maybe a snippet of an example would help.

Let's do one more iteration on that, but other than that, lgtm.

mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td Outdated Show resolved Hide resolved
mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td Show resolved Hide resolved
grypp added 4 commits October 5, 2023 18:25
This Op generates and initilizes the accumulator matrix for `nvgpu.warpgroup.mma` op to perform matrix-multiply-and-accumulate (mma).

Its associated transformation generates `!llvm.struct<>` and fill it with the initial values. The size of struct is number of required inout registers for `nvgpu.warpgroup.mma` op.
Copy link
Collaborator

@qcolombet qcolombet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@grypp grypp merged commit 315ab3c into llvm:main Oct 11, 2023
@grypp grypp deleted the nvgpu-mma-init branch October 11, 2023 15:28
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants