
[MLIR][NVGPU] Adding nvgpu.warpgroup.mma Op for Hopper GPUs #65440

Merged
merged 8 commits into llvm:main on Sep 22, 2023

Conversation

grypp
Member

@grypp grypp commented Sep 6, 2023

This work introduces a new operation called warpgroup.mma to the NVGPU dialect of MLIR. The purpose of this operation is to facilitate warpgroup-level matrix multiply and accumulate (WGMMA) operations on Hopper GPUs with the sm_90a architecture.

Previously, the nvvm.wgmma.mma_async operation was introduced to support warpgroup-level matrix operations in the NVVM dialect. Multiple instances of nvvm.wgmma.mma_async are typically needed to achieve the desired shape. The new nvgpu.warpgroup.mma operation abstracts this complexity and provides a higher-level interface for performing warpgroup-level matrix operations.

The nvgpu.warpgroup.mma op does the following:

  1. Corresponds to multiple wgmma instructions.
  2. Iterates over the input matrix descriptors to achieve the desired computation shape.
  3. Groups and runs the wgmma instructions asynchronously, and eventually waits for them. This is done by wgmma.fence.aligned, wgmma.commit.group.sync.aligned, and wgmma.wait.group.sync.aligned.
  4. Results in fragmented matrices.

Here's an example usage of the nvgpu.warpgroup.mma operation:

```mlir
%wgmmaResult, %wgmmaResult2 = nvgpu.warpgroup.mma %descA, %descB, %acc1, %acc2 {transposeB}:
      !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>,
      !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>,
      !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
      !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
      ->
      !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
      !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
```

The op will result in the following PTX:

```
wgmma.fence.sync.aligned;
wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 {%f1, %f2,    62 more registers}, %descA,     %descB,     p, 1, 1, 0, 1;
wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 {%f1, %f2,    62 more registers}, %descA+2,   %descB+128, p, 1, 1, 0, 1;
wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 {%f1, %f2,    62 more registers}, %descA+4,   %descB+256, p, 1, 1, 0, 1;
wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 {%f1, %f2,    62 more registers}, %descA+6,   %descB+384, p, 1, 1, 0, 1;
wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 {%f500,%f501, 62 more registers}, %descA+512, %descB,     p, 1, 1, 0, 1;
wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 {%f500,%f501, 62 more registers}, %descA+514, %descB+128, p, 1, 1, 0, 1;
wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 {%f500,%f501, 62 more registers}, %descA+516, %descB+256, p, 1, 1, 0, 1;
wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 {%f500,%f501, 62 more registers}, %descA+518, %descB+384, p, 1, 1, 0, 1;
wgmma.commit_group.sync.aligned;
wgmma.wait_group.sync.aligned 1;
```

The op assigns the registers to the results as follows:

  • first 64 registers ({%f1, %f2, 62 more registers}) -> %acc1
  • second 64 registers ({%f500, %f501, 62 more registers}) -> %acc2
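
For illustration, here is a standalone sketch (not part of the patch) that reproduces the descriptor-offset arithmetic of the lowering for the shapes above; the constant names are invented for the example and assume f16 inputs, m64n128k16 tiles, and offsets expressed in 16-byte units as in the wgmma descriptor:

```cpp
// Minimal standalone sketch: enumerate the wgmma instructions generated for
// A: 128x64xf16, B: 64x128xf16, C: 128x128xf32, mirroring the
// iterateDescA/iterateDescB arithmetic in the NVGPUToNVVM lowering below.
#include <cstdio>

int main() {
  const int sizeM = 128, sizeN = 128, sizeK = 64; // GEMM shape from the example
  const int wgmmaM = 64, wgmmaK = 16;             // wgmma tile for f16 inputs
  const int elemBytes = 2;                        // f16
  const int tileShapeA = 64;                      // second dimension of A's memref

  for (int iterM = 0; iterM < sizeM / wgmmaM; ++iterM) {
    for (int iterK = 0; iterK < sizeK / wgmmaK; ++iterK) {
      // Offsets drop the 4 least significant bits (16-byte units).
      int offA = ((wgmmaK * iterK) + (sizeK * tileShapeA * iterM)) * elemBytes >> 4;
      int offB = (sizeK * wgmmaK * iterK) * elemBytes >> 4;
      std::printf("wgmma.m%dn%dk%d  descA+%-4d descB+%-4d -> acc%d\n",
                  wgmmaM, sizeN, wgmmaK, offA, offB, iterM + 1);
    }
  }
  return 0;
}
```

Running it prints one line per generated wgmma.mma_async: descriptor A advances along K and jumps by a full 64-row tile per accumulator, while descriptor B advances along K only.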

@joker-eph
Collaborator

Typo in the title: nvgpu.wargroup.mma -> nvgpu.warpgroup.mma

@grypp grypp changed the title [MLIR][NVGPU] Adding nvgpu.wargroup.mma Op for Hopper GPUs [MLIR][NVGPU] Adding nvgpu.warpgroup.mma Op for Hopper GPUs Sep 7, 2023
@grypp
Member Author

grypp commented Sep 7, 2023

Typo in the title: nvgpu.wargroup.mma -> nvgpu.warpgroup.mma

Good catch. I somehow wrote wargroup, I put it everywhere :)

Collaborator

@qcolombet qcolombet left a comment


Looks mostly good to me.
Highlighted a few places where comments would make things easier to understand.

```mlir
%res = nvgpu.warpgroup.mma %wgmmaDescA, %wgmmaDescB, %acc:
!nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>,
!nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>,
```
Collaborator

Maybe call out that we only support sizes that are a multiple of 8 and are in [8; 256].
Unless we plan to lift that up?

Member Author

Not sure I understood this comment
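
For reference, a minimal standalone sketch of the size rule mentioned in the comment, assuming the "multiple of 8 within [8; 256]" phrasing; the patch itself checks explicit per-type lists in isAllowedSizeN, shown further down in the diff:

```cpp
// Hypothetical standalone check of the constraint discussed above: N must be
// a multiple of 8 in the range [8, 256] for the supported input types.
#include <cstdio>

static bool isSupportedSizeN(int sizeN) {
  return sizeN >= 8 && sizeN <= 256 && sizeN % 8 == 0;
}

int main() {
  std::printf("N=128 -> %s\n", isSupportedSizeN(128) ? "ok" : "rejected");
  std::printf("N=100 -> %s\n", isSupportedSizeN(100) ? "ok" : "rejected");
  return 0;
}
```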

@llvmbot
Member

llvmbot commented Sep 13, 2023

@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-mlir-nvgpu

Changes

This work introduces a new operation called `warpgroup.mma` to the NVGPU dialect of MLIR. The purpose of this operation is to facilitate warpgroup-level matrix multiply and accumulate (WGMMA) operations on Hopper GPUs with the sm_90a architecture.

Previously, the nvvm.wgmma.mma_async operation was introduced to support warpgroup-level matrix operations in the NVVM dialect. Multiple instances of nvvm.wgmma.mma_async are typically needed to achieve the desired shape. The new nvgpu.warpgroup.mma operation abstracts this complexity and provides a higher-level interface for performing warpgroup-level matrix operations.

The nvgpu.warpgroup.mma op does the following:

  1. Corresponds to multiple wgmma instructions.
  2. Iterates over the input matrix descriptors to achieve the desired computation shape.
  3. Groups and runs the wgmma instructions asynchronously, and eventually waits for them. This is done by wgmma.fence.aligned, wgmma.commit.group.sync.aligned, and wgmma.wait.group.sync.aligned.
  4. Results in fragmented matrices.

Here's an example usage of the nvgpu.warpgroup.mma operation:

```mlir
%wgmmaResult, %wgmmaResult2 = nvgpu.warpgroup.mma %descA, %descB, %acc, group = 1 {transposeB}:
      !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>,
      !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>,
      vector<128x128xf32>
      -> !nvgpu.warpgroup.result<tensor = !llvm.struct<...>>,
         !nvgpu.warpgroup.result<tensor = !llvm.struct<...>>
```

--

Patch is 36.34 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/65440.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td (+56)
  • (modified) mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h (+2)
  • (modified) mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (+163-3)
  • (modified) mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp (+131-4)
  • (modified) mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp (+15)
  • (modified) mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir (+61-1)
  • (modified) mlir/test/Dialect/NVGPU/invalid.mlir (+44)

```diff
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index a3245bf9196eed1..90381648dac6acc 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -192,6 +192,19 @@ def NVGPU_WarpgroupMatrixDescriptor : NVGPU_Type<"WarpgroupMatrixDescriptor", "w
   let assemblyFormat = "`<` struct(params) `>`";
 }
 
+def NVGPU_WarpgroupAccumulator : NVGPU_Type<"WarpgroupAccumulator", "warpgroup.accumulator", []> {
+  let parameters = (ins "VectorType":$fragmented);
+  let assemblyFormat = "`<` struct(params) `>`";
+  let description = [{
+    This type represents the result matrix obtained from nvgpu.warpgroup.mma.
+    The $fragmented type signifies the distributed or fragmented result
+    vector that is collectively owned by all the threads in the warp-group
+    that executed nvgpu.warpgroup.mma.
+    [See the details of register fragment layout for accumulator matrix D]
+    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d)
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // NVGPU Op Definitions
 //===----------------------------------------------------------------------===//
@@ -664,5 +677,48 @@ def NVGPU_GenerateGmmaDescriptorOp : NVGPU_Op<"wgmma.generate.descriptor", []> {
   let hasVerifier = 1;
 }
 
+def NVGPU_WarpgroupMmaOp : NVGPU_Op<"warpgroup.mma"> {
+  let description = [{
+    The nvgpu.warpgroup.mma op performs the warpgroup-level (4 warps)
+    matrix-multiply-and-accumulate (mma) operation that results in
+    nvvm.wgmma.mma_async.
+
+    The operands are descriptorA and descriptorB that are wgmma matrix
+    descriptors that shows the properties of the matrix in shared memory. The
+    results are thread-level ownership to the warpgroup-level mma operation
+    shape. The shape is deduced from the descriptor types and output vector.
+
+    The Op corresponds multiple nvvm.wgmma.mma_async operations to complete the
+    given shape. As the instruction nvvm.wgmma.async is an asynchronous,
+    this Op groups the nvvm.wgmma.async and surrounds them between
+    wgmma.fence.aligned and wgmma.commit.group.sync.aligned,
+    wgmma.wait.group.sync.aligned Ops.
+
+    Example:
+      %r1,%r2 = nvgpu.warpgroup.mma %wgmmaDescA, %wgmmaDescB, %acc1, %acc2:
+                 !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>,
+                 !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>,
+                 !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
+                 !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
+                 ->
+                 !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
+                 !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
+  }];
+
+  let arguments = (ins NVGPU_WarpgroupMatrixDescriptor:$descriptorA,
+                       NVGPU_WarpgroupMatrixDescriptor:$descriptorB,
+                       DefaultValuedOptionalAttr<I32Attr, "1">:$waitGroup,
+                       OptionalAttr<UnitAttr>:$transposeA,
+                       OptionalAttr<UnitAttr>:$transposeB,
+                       Variadic<NVGPU_WarpgroupAccumulator>:$matrixC);
+  let results = (outs Variadic<NVGPU_WarpgroupAccumulator>:$matrixD);
+  let assemblyFormat = [{
+    $descriptorA`,` $descriptorB`,` $matrixC attr-dict
+    `:` type($descriptorA) `,` type($descriptorB) `,` type($matrixC) `->` type($matrixD)
+  }];
+  let hasVerifier = 1;
+}
+
 #endif // NVGPU
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
index 192afcb2dba7913..96af26842dafea2 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
@@ -21,6 +21,8 @@
 
 #include "mlir/Dialect/NVGPU/IR/NVGPUEnums.h.inc"
 
+constexpr int kWarpSize = 32;
+
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.h.inc"
 
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index b045089244ff1a7..046727e4ea9ab83 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -17,10 +17,12 @@
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Pass/Pass.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/raw_ostream.h"
 
 #define DEBUG_TYPE "nvgpu-to-nvvm"
@@ -34,6 +36,10 @@ namespace mlir {
 
 using namespace mlir;
 
+/// Number of bits that needs to excluded when building matrix descriptor for
+/// wgmma operations.
+constexpr int exclude4LSB = 4;
+
 /// GPU has 32 bit registers, this function truncates values when larger width
 /// is not needed.
 static Value truncToI32(ConversionPatternRewriter &rewriter, Location loc,
@@ -419,6 +425,15 @@ struct ConvertNVGPUToNVVMPass
     converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type {
       return converter.convertType(IntegerType::get(type.getContext(), 32));
     });
+    converter.addConversion([&](nvgpu::WarpgroupAccumulatorType type) -> Type {
+      VectorType vtype = type.getFragmented();
+      SmallVector<Type> structBody;
+      for (unsigned i = 0; i < vtype.getDimSize(0); i++)
+        structBody.push_back(vtype.getElementType());
+      auto convertedType =
+          LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
+      return converter.convertType(convertedType);
+    });
     converter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
       return converter.convertType(IntegerType::get(type.getContext(), 64));
     });
@@ -438,6 +453,8 @@ struct ConvertNVGPUToNVVMPass
     target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
     target.addLegalDialect<::mlir::memref::MemRefDialect>();
     target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
+    mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
+        converter, patterns, target);
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
       signalPassFailure();
@@ -984,10 +1001,9 @@ struct NVGPUGenerateGmmaDescriptorLowering
                        shiftLeft(val, startBit));
     };
 
-    int ex4LSB = 4;
     int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0);
-    uint64_t strideDimVal = (layout << 3) >> ex4LSB;
-    uint64_t leadDimVal = (sizeN * layout) >> ex4LSB;
+    uint64_t strideDimVal = (layout << 3) >> exclude4LSB;
+    uint64_t leadDimVal = (sizeN * layout) >> exclude4LSB;
     uint64_t offsetVal = 0;
 
     Value strideDim = makeConst(strideDimVal);
@@ -1141,6 +1157,149 @@ struct NVGPUTmaCreateDescriptorOpLowering
   }
 };
 
+struct NVGPUWarpgroupMmaOpLowering
+    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
+  using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult getWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType,
+                              int &wgmmaShapeM, int &wgmmaShapeN,
+                              int &wgmmaShapeK) const {
+    wgmmaShapeM = 64;
+    wgmmaShapeN = sizeN;
+    if (inputElemType.isTF32()) {
+      wgmmaShapeK = 8;
+    } else if (inputElemType.isF16() || inputElemType.isBF16()) {
+      wgmmaShapeK = 16;
+    } else if (inputElemType.isFloat8E4M3FN() || inputElemType.isFloat8E5M2() ||
+               inputElemType.isInteger(16)) {
+      wgmmaShapeK = 32;
+    } else if (inputElemType.isInteger(1)) {
+      wgmmaShapeK = 256;
+    } else {
+      llvm_unreachable("msg: not supported K shape");
+    }
+    LLVM_DEBUG(DBGS() << "Generating wgmma.mma.async shape[m = " << wgmmaShapeM
+                      << ", n = " << wgmmaShapeN << ", k = " << wgmmaShapeK
+                      << "]\n");
+    return success();
+  }
+
+  Value generateNVVMWgmmaOp(MLIRContext *ctx,
+                            ConversionPatternRewriter &rewriter, Location loc,
+                            int m, int n, int k, Type resultStructType,
+                            Value inout, Value descriptorA,
+                            Value descriptorB) const {
+    TypeRange resultTypes = {resultStructType};
+    auto shape = NVVM::MMAShapeAttr::get(ctx, m, n, k);
+    auto scaleOut = NVVM::WGMMAScaleOutAttr::get(ctx, NVVM::WGMMAScaleOut::one);
+    auto scaleIn = NVVM::WGMMAScaleInAttr::get(ctx, NVVM::WGMMAScaleIn::one);
+    auto layoutA = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::row);
+    auto layoutB = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::col);
+    // todo input type
+    auto itype = NVVM::WGMMATypesAttr::get(ctx, NVVM::WGMMATypes::f16);
+    auto overflow =
+        NVVM::MMAIntOverflowAttr::get(ctx, NVVM::MMAIntOverflow::wrapped);
+    Value res = rewriter.create<NVVM::WgmmaMmaAsyncOp>(
+        loc, resultTypes, inout, descriptorA, descriptorB, shape, itype, itype,
+        scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
+    return res;
+  }
+
+  LogicalResult
+  matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    int64_t sizeM = op.getDescriptorA().getType().getTensor().getDimSize(0);
+    int64_t sizeN = op.getDescriptorB().getType().getTensor().getDimSize(1);
+    int64_t sizeK = op.getDescriptorA().getType().getTensor().getDimSize(1);
+
+    LLVM_DEBUG(DBGS() << "===--- GEMM D[" << sizeM << "][" << sizeN << "] += A["
+                      << sizeM << "][" << sizeK << "] * B[" << sizeK << "]["
+                      << sizeN << "] ---===\n");
+
+    int wgmmaShapeM, wgmmaShapeN, wgmmaShapeK;
+    if (failed(getWgmmaShape(sizeM, sizeN, rewriter.getF16Type(), wgmmaShapeM,
+                             wgmmaShapeN, wgmmaShapeK))) {
+      return failure();
+    }
+
+    Value descriptorA = adaptor.getDescriptorA();
+    Value descriptorB = adaptor.getDescriptorB();
+
+    //  Generate wgmma group
+    auto loc = op->getLoc();
+    MemRefType typeTensorA = op.getDescriptorA().getType().getTensor();
+    MemRefType typeTensorB = op.getDescriptorB().getType().getTensor();
+
+    auto makeAdd = [&](Value lhs, Value rhs) -> Value {
+      return rewriter.create<LLVM::AddOp>(loc, lhs.getType(), lhs, rhs);
+    };
+
+    auto iterateDescA = [&](Value desc, int iterM, int iterN,
+                            int iterK) -> Value {
+      // todo : Handle column major
+      int byte = typeTensorA.getElementTypeBitWidth() / 8;
+      int tileShapeA = typeTensorA.getDimSize(1);
+      int incrementVal =
+          ((wgmmaShapeK * iterK) + (sizeK * tileShapeA * iterM)) * byte;
+      incrementVal = incrementVal >> exclude4LSB;
+      LLVM_DEBUG(DBGS() << "\t\t[m: " << iterM << " n: " << iterN << " k: "
+                        << iterK << "] [wgmma descriptors] Descriptor A + "
+                        << incrementVal << " | \t ");
+      if (!incrementVal)
+        return desc;
+      return makeAdd(desc, makeI64Const(rewriter, op, incrementVal));
+    };
+
+    auto iterateDescB = [&](Value desc, int iterM, int iterN,
+                            int iterK) -> Value {
+      // todo : Handle row major
+      int byte = typeTensorB.getElementTypeBitWidth() / 8;
+      int incrementVal = typeTensorB.getDimSize(0) * wgmmaShapeK * iterK * byte;
+      incrementVal = incrementVal >> exclude4LSB;
+      LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
+      if (!incrementVal)
+        return desc;
+      return makeAdd(desc, makeI64Const(rewriter, op, incrementVal));
+    };
+
+    rewriter.create<NVVM::WgmmaFenceAlignedOp>(loc);
+
+    SmallVector<Value> wgmmaResults;
+    for (int iterM = 0; iterM < (sizeM / wgmmaShapeM); iterM++) {
+      Value matrixC = adaptor.getMatrixC()[iterM];
+      Value matrixD = op.getMatrixD()[iterM];
+      Type structType = getTypeConverter()->convertType(matrixD.getType());
+      LLVM_DEBUG(DBGS() << " D[" << (iterM * wgmmaShapeM) << ":"
+                        << (iterM * wgmmaShapeM) + wgmmaShapeM << "][" << 0
+                        << ":" << wgmmaShapeN << "] += \n");
+      for (int iterK = 0; iterK < (sizeK / wgmmaShapeK); iterK++) {
+        Value descA = iterateDescA(descriptorA, iterM, 0, iterK);
+        Value descB = iterateDescB(descriptorB, iterM, 0, iterK);
+        LLVM_DEBUG(DBGS() << "\t wgmma."
+                          << "m" << wgmmaShapeM << "n" << wgmmaShapeN << "k"
+                          << wgmmaShapeK << "(A[" << (iterM * wgmmaShapeM)
+                          << ":" << (iterM * wgmmaShapeM) + wgmmaShapeM << "]["
+                          << (iterK * wgmmaShapeK) << ":"
+                          << (iterK * wgmmaShapeK + wgmmaShapeK) << "] * "
+                          << " B[" << (iterK * wgmmaShapeK) << ":"
+                          << (iterK * wgmmaShapeK + wgmmaShapeK) << "][" << 0
+                          << ":" << wgmmaShapeN << "])\n");
+        matrixC = generateNVVMWgmmaOp(op->getContext(), rewriter, loc,
+                                      wgmmaShapeM, wgmmaShapeN, wgmmaShapeK,
+                                      structType, matrixC, descA, descB);
+      }
+      wgmmaResults.push_back(matrixC);
+    }
+
+    rewriter.create<NVVM::WgmmaGroupSyncAlignedOp>(loc);
+    rewriter.create<NVVM::WgmmaWaitGroupSyncOp>(loc, op.getWaitGroup());
+
+    ValueRange myres(wgmmaResults);
+    rewriter.replaceOp(op, myres);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
@@ -1156,6 +1315,7 @@ void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
       NVGPUTmaCreateDescriptorOpLowering,     // nvgpu.tma.create.descriptor
       NVGPUMBarrierArriveExpectTxLowering,    // nvgpu.mbarrier.arrive.expect_tx
       NVGPUGenerateGmmaDescriptorLowering,    // nvgpu.wgmma.generate.descriptor
+      NVGPUWarpgroupMmaOpLowering,            // nvgpu.warpgroup.mma
       MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
       NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
       NVGPUMmaSparseSyncLowering>(converter);
 
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index d832a983a132d61..d96ed69982870b4 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -22,6 +22,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Verifier.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
 
@@ -151,7 +152,6 @@ static LogicalResult verifyMmaSyncOp(Operation *op,
   //  - For F32 (TF32), F16, S8, and S4 data
   //    types the fundamental tensor core operation is of shape 8-by-8-by-128b.
   //  - F64 is an exception and is of shape 8-by-8-by-256b.
-  constexpr int kThreads = 32; // 32 threads per warp
   int64_t shapeM = 8;
   int64_t shapeN = 8;
   int64_t shapeK; // set based on data type (128b for all data types except F64)
@@ -206,17 +206,17 @@ static LogicalResult verifyMmaSyncOp(Operation *op,
 
   // verify warp-wide size for vector a
   int64_t sparseFactor = sparse ? 2 : 1;
-  if (aShape[0] * aShape[1] * kThreads != m * k / sparseFactor)
+  if (aShape[0] * aShape[1] * kWarpSize != m * k / sparseFactor)
     return op->emitOpError()
            << "expected " << m * k << " warp-wide matrix A elements";
 
   // verify warp-wide size for vector b
-  if (bShape[0] * bShape[1] * kThreads != k * n)
+  if (bShape[0] * bShape[1] * kWarpSize != k * n)
     return op->emitOpError()
            << "expected " << k * n << " warp-wide matrix B elements";
 
   // verify warp-wide size for vector c
-  if (cShape[0] * cShape[1] * kThreads != m * n)
+  if (cShape[0] * cShape[1] * kWarpSize != m * n)
     return op->emitOpError()
            << "expected " << m * n << " warp-wide matrix C elements";
 
@@ -402,6 +402,133 @@ LogicalResult GenerateGmmaDescriptorOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// WarpgroupMmaOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult isAllowedWGMMADataType(Type typeD, Type typeA, Type typeB) {
+  // F32 += F16 + F16
+  // F16 += F16 + F16
+  if (typeA.isF16() && typeB.isF16() && (typeD.isF32() || typeD.isF16()))
+    return success();
+  // F32 += TF32 + TF32
+  if (typeA.isTF32() && typeD.isF32() && typeB.isTF32())
+    return success();
+  // s32 += i8 + i8
+  if (typeA.isInteger(16) && typeB.isInteger(16) && typeD.isInteger(32))
+    return success();
+  // s32 += i1 + i1
+  if (typeA.isInteger(1) && typeB.isInteger(1) && typeD.isInteger(32))
+    return success();
+  // F32 += BF16 + BF16
+  // F16 += BF16 + BF16
+  if (typeA.isBF16() && typeB.isBF16() && (typeD.isF32() || typeD.isF16()))
+    return success();
+  // F16 += f8 + f8
+  // F32 += f8 + f8
+  if ((typeA.isFloat8E5M2() || typeA.isFloat8E4M3FN()) &&
+      (typeB.isFloat8E5M2() || typeB.isFloat8E4M3FN()) &&
+      (typeD.isF32() || typeD.isF16()))
+    return success();
+  return failure();
+}
+
+LogicalResult isAllowedSizeN(int sizeN, Type typeA) {
+  SmallVector<int> allowedN = {8,   16,  24,  32,  40,  48,  56,  64,
+                               72,  80,  88,  96,  104, 112, 120, 128,
+                               136, 144, 152, 160, 168, 176, 184, 192,
+                               200, 208, 216, 224, 232, 240, 248, 256};
+  SmallVector<int> allowedNshort = {8,   16,  24,  32,  48,  64,
+                                    80,  96,  112, 128, 144, 160,
+                                    176, 192, 208, 224, 240, 256};
+  if (typeA.isBF16() || typeA.isF16() || typeA.isTF32() ||
+      typeA.isFloat8E4M3FN() || typeA.isFloat8E5M2())
+    if (llvm::any_of(allo...
```


Differential Revision: https://reviews.llvm.org/D158434
@grypp grypp merged commit 2388222 into llvm:main Sep 22, 2023
grypp added a commit to grypp/llvm-project that referenced this pull request Oct 5, 2023
This work introduces a new operation called `warpgroup.mma.store` to the NVGPU dialect of MLIR. The purpose of this operation is to facilitate storing the fragmented results of WGMMA to the given memref.

An example of fragmentation is given here :
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d

The `warpgroup.mma.store` op does the following:
1) Takes one or more fragmented result matrices.
2) Calculates per-thread indexes within the warpgroup and stores the data into the given memref.

Here's an example usage of the `nvgpu.warpgroup.mma.store` operation:
```
// Performs matmul, results are fragmented and in registers
%res1, %res2 = nvgpu.warpgroup.mma ...

// Stores the fragmented result to the given memory
nvgpu.warpgroup.mma.store [%res1, %res2], %matrixD :
                !nvgpu.warpgroup.result<tensor = !llvm.struct<...>>,
                !nvgpu.warpgroup.result<tensor = !llvm.struct<...>>
                to memref<128x128xf32,3>
```

Depends on llvm#65440
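
As a rough sanity check on the fragment sizes involved in the store example above (a standalone sketch, not the in-tree per-thread index calculation):

```cpp
// Bookkeeping sketch: each 64x128 f32 accumulator from the example in #65440
// holds 8192 elements, so each of the 128 threads in a warpgroup stores 64
// values, matching the "64 registers per accumulator" split described above.
#include <cassert>
#include <cstdio>

int main() {
  const int warpgroupThreads = 4 * 32; // a warpgroup is 4 warps
  const int accRows = 64, accCols = 128;
  const int elems = accRows * accCols;                  // 8192 f32 elements
  assert(elems % warpgroupThreads == 0);
  const int valuesPerThread = elems / warpgroupThreads; // 64 per thread
  std::printf("each of %d threads stores %d f32 values per accumulator\n",
              warpgroupThreads, valuesPerThread);
  return 0;
}
```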