Skip to content

Commit

Permalink
[mlir][nvgpu] Improve WarpgroupAccumulator type to simplify IR (#68728
Browse files Browse the repository at this point in the history
)

`WarpgroupAccumulator` (or `!nvgpu.warpgroup.accumulator`) is a type
that keeps the accumulator matrix that is used by warp-group level
matrix multiplication. It is handy to have a special type for that as
the matrix is distributed among the threads of the warp-group. However,
current transformations requires to create and use multiple
`WarpgroupAccumulator` if the shape of GEMM is larger than the supported
shape of `wgmma.mma_async` instruction. This makes IR looks dense.

This PR improves the transformation of `WarpgroupAccumulator` type in
every nvgpu Op that uses it.

**Example: Current GEMM in NVGPU-IR**
```
// Init
%m1, %m2 = nvgpu.warpgroup.mma.init.accumulator ->  
                    !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
                    !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>

// GEMM
%r1, %r2 = nvgpu.warpgroup.mma %descA, %descB, %m1, %m2 {transposeB}: 
      !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>, 
      !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>, 
      !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
      !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> 
      -> 
      !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
      !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>  


// Epilogue 
nvgpu.warpgroup.mma.store [%r1, %r2] to %sharedMemoryBuffer
  : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, 
    !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
    into memref<128x128xf32,3>
```

**Example: This PR simplifies the IR as below:**
```
// Init
%m = nvgpu.warpgroup.mma.init.accumulator ->  
           !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>

// GEMM
%r1 = nvgpu.warpgroup.mma %descA, %descB, %m1 {transposeB}: 
      !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>, 
      !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>, 
      !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>> 
      -> 
      !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>  

// Epilogue 
nvgpu.warpgroup.mma.store [%matrixD1, %matrixD2] to %sharedMemoryBuffer
  : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, 
    !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
    into memref<128x128xf32,3>
```
  • Loading branch information
grypp authored Oct 17, 2023
1 parent 838f289 commit 52db7e2
Show file tree
Hide file tree
Showing 7 changed files with 177 additions and 158 deletions.
10 changes: 5 additions & 5 deletions mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
Original file line number Diff line number Diff line change
Expand Up @@ -719,8 +719,8 @@ def NVGPU_WarpgroupMmaOp : NVGPU_Op<"warpgroup.mma"> {
DefaultValuedOptionalAttr<I32Attr, "1">:$waitGroup,
OptionalAttr<UnitAttr>:$transposeA,
OptionalAttr<UnitAttr>:$transposeB,
Variadic<NVGPU_WarpgroupAccumulator>:$matrixC);
let results = (outs Variadic<NVGPU_WarpgroupAccumulator>:$matrixD);
NVGPU_WarpgroupAccumulator:$matrixC);
let results = (outs NVGPU_WarpgroupAccumulator:$matrixD);
let assemblyFormat = [{
$descriptorA`,` $descriptorB`,` $matrixC attr-dict
`:` type($descriptorA) `,` type($descriptorB) `,` type($matrixC) `->` type($matrixD)
Expand All @@ -739,11 +739,11 @@ def NVGPU_WarpgroupMmaStoreOp : NVGPU_Op<"warpgroup.mma.store"> {
Note that, the op must be run with warp group.
}];

let arguments = (ins Variadic<NVGPU_WarpgroupAccumulator>:$matrixD,
let arguments = (ins NVGPU_WarpgroupAccumulator:$matrixD,
Arg<AnyMemRef, "", [MemWrite]>:$dstMemref);

let assemblyFormat = [{
`[` $matrixD `]` `,` $dstMemref attr-dict `:` type($matrixD) `to` type($dstMemref)
$matrixD `,` $dstMemref attr-dict `:` type($matrixD) `to` type($dstMemref)
}];
let hasVerifier = 1;
}
Expand All @@ -755,7 +755,7 @@ def NVGPU_WarpgroupMmaInitAccumulatorOp : NVGPU_Op<"warpgroup.mma.init.accumulat
This Op generates and initializes the accumulator matrix for
`nvgpu.warpgroup.mma` op to perform matrix-multiply-and-accumulate.
}];
let results = (outs Variadic<NVGPU_WarpgroupAccumulator>:$matrixC);
let results = (outs NVGPU_WarpgroupAccumulator:$matrixC);
let assemblyFormat = "attr-dict `->` type($matrixC)";
let hasVerifier = 1;
}
Expand Down
3 changes: 3 additions & 0 deletions mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@

constexpr int kWarpSize = 32;

/// M size of wgmma.mma_async instruction
constexpr int kWgmmaSizeM = 64;

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.h.inc"

Expand Down
112 changes: 69 additions & 43 deletions mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -412,10 +412,28 @@ struct ConvertNVGPUToNVVMPass
return converter.convertType(IntegerType::get(type.getContext(), 32));
});
converter.addConversion([&](nvgpu::WarpgroupAccumulatorType type) -> Type {
VectorType vtype = type.getFragmented();
Type elemType = type.getFragmented().getElementType();
int64_t sizeM = type.getFragmented().getDimSize(0);
int64_t sizeN = type.getFragmented().getDimSize(1);

unsigned numMembers;
if (elemType.isF32() || elemType.isInteger(32))
numMembers = sizeN / 2;
else if (elemType.isF16())
numMembers = sizeN / 4;
else
llvm_unreachable("unsupported type for warpgroup accumulator");

SmallVector<Type> innerStructBody;
for (unsigned i = 0; i < numMembers; i++)
innerStructBody.push_back(elemType);
auto innerStructType =
LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody);

SmallVector<Type> structBody;
for (unsigned i = 0; i < vtype.getDimSize(0); i++)
structBody.push_back(vtype.getElementType());
for (int i = 0; i < sizeM; i += kWgmmaSizeM)
structBody.push_back(innerStructType);

auto convertedType =
LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
return converter.convertType(convertedType);
Expand Down Expand Up @@ -1186,7 +1204,6 @@ struct NVGPUWarpgroupMmaOpLowering
nvgpu::WarpgroupMmaOp op;
ImplicitLocOpBuilder b;
OpAdaptor adaptor;
const LLVMTypeConverter &typeConverter;

// Entire shape of the given Op
int64_t totalM, totalN, totalK;
Expand Down Expand Up @@ -1330,7 +1347,7 @@ struct NVGPUWarpgroupMmaOpLowering

/// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
/// descriptors and arranges them based on induction variables: i, j, and k.
Value generateWgmma(int i, int j, int k, Value matrixC, Value matrixD) {
Value generateWgmma(int i, int j, int k, Value matrixC) {
LLVM_DEBUG(DBGS() << "\t wgmma."
<< "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
<< "(A[" << (iterationM * wgmmaM) << ":"
Expand Down Expand Up @@ -1359,34 +1376,36 @@ struct NVGPUWarpgroupMmaOpLowering
auto overflow = NVVM::MMAIntOverflowAttr::get(
op->getContext(), NVVM::MMAIntOverflow::wrapped);

Type resultStructType = typeConverter.convertType(matrixD.getType());

return b.create<NVVM::WgmmaMmaAsyncOp>(
resultStructType, matrixC, descriptorA, descriptorB, shape, itypeA,
matrixC.getType(), matrixC, descriptorA, descriptorB, shape, itypeA,
itypeB, scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
}

/// Generates multiple wgmma instructions to complete the given GEMM shape
SmallVector<Value> generateWgmmaGroup() {
SmallVector<Value> wgmmaResults;
Value generateWgmmaGroup() {
Value wgmmaResult =
b.create<LLVM::UndefOp>(adaptor.getMatrixC().getType());

// Perform GEMM
SmallVector<Value> wgmmaResults;
for (int i = 0; i < iterationM; ++i) {
Value matrixC = adaptor.getMatrixC()[i];
Value matrixD = op.getMatrixD()[i];
Value matrixC = b.create<LLVM::ExtractValueOp>(adaptor.getMatrixC(), i);
for (int j = 0; j < iterationN; ++j)
for (int k = 0; k < iterationK; ++k)
matrixC = generateWgmma(i, j, k, matrixC, matrixD);
matrixC = generateWgmma(i, j, k, matrixC);
wgmmaResults.push_back(matrixC);
}

return wgmmaResults;
for (auto [idx, matrix] : llvm::enumerate(wgmmaResults)) {
wgmmaResult = b.create<LLVM::InsertValueOp>(wgmmaResult.getType(),
wgmmaResult, matrix, idx);
}
return wgmmaResult;
}

public:
WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ImplicitLocOpBuilder &b,
OpAdaptor adaptor, const LLVMTypeConverter &typeConverter)
: op(op), b(b), adaptor(adaptor), typeConverter(typeConverter) {
OpAdaptor adaptor)
: op(op), b(b), adaptor(adaptor) {
// Find the entire GEMM Shape
totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
Expand All @@ -1411,27 +1430,27 @@ struct NVGPUWarpgroupMmaOpLowering
/// instructions and group synchronization, as well as waiting
/// (WgmmaGroupSyncAlignedOp) for group synchronization
/// (WgmmaWaitGroupSyncOp) after the instructions.
SmallVector<Value> generateWarpgroupMma() {
Value generateWarpgroupMma() {
b.create<NVVM::WgmmaFenceAlignedOp>();
SmallVector<Value> wgmmaResults = generateWgmmaGroup();
Value wgmmaResult = generateWgmmaGroup();
b.create<NVVM::WgmmaGroupSyncAlignedOp>();
b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup());
return wgmmaResults;
return wgmmaResult;
}
};

LogicalResult
matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);

// Step 1. Build a helper class
WarpgroupGemm warpgroupGemm(op, b, adaptor, *this->getTypeConverter());
WarpgroupGemm warpgroupGemm(op, b, adaptor);

// Step 2. Get the entire GEMM Shape
SmallVector<Value> wgmmaResults = warpgroupGemm.generateWarpgroupMma();
Value wgmmaResult = warpgroupGemm.generateWarpgroupMma();

// Step 3. Replace fragmented result struct with the op results
rewriter.replaceOp(op, wgmmaResults);
rewriter.replaceOp(op, wgmmaResult);
return success();
}
};
Expand Down Expand Up @@ -1535,10 +1554,13 @@ struct NVGPUWarpgroupMmaStoreOpLowering
matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
int offset = 0;
ImplicitLocOpBuilder lb(op->getLoc(), rewriter);
for (Value matrixD : adaptor.getMatrixD()) {
auto structType = matrixD.getType().cast<LLVM::LLVMStructType>();
storeFragmentedMatrix(lb, matrixD, op.getDstMemref(), offset);
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value matriDValue = adaptor.getMatrixD();
auto stype = matriDValue.getType().cast<LLVM::LLVMStructType>();
for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) {
auto structType = matrixD.cast<LLVM::LLVMStructType>();
Value innerStructValue = b.create<LLVM::ExtractValueOp>(matriDValue, idx);
storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset);
offset += structType.getBody().size();
}
rewriter.eraseOp(op);
Expand All @@ -1554,23 +1576,27 @@ struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
SmallVector<Value> results;
for (OpResult m : op.getMatrixC()) {
nvgpu::WarpgroupAccumulatorType mType =
m.getType().cast<nvgpu::WarpgroupAccumulatorType>();
Type stype = getTypeConverter()->convertType(mType);
Value undefStruct = b.create<LLVM::UndefOp>(stype);
Type elemType = mType.getFragmented().getElementType();
int64_t elemSize = mType.getFragmented().getDimSize(0);
Value zero =
b.create<LLVM::ConstantOp>(elemType, rewriter.getZeroAttr(elemType));
for (int64_t i = 0; i < elemSize; ++i) {
undefStruct = b.create<LLVM::InsertValueOp>(stype, undefStruct, zero,
ArrayRef<int64_t>({i}));
LLVM::LLVMStructType structType =
getTypeConverter()
->convertType(op.getMatrixC().getType())
.cast<LLVM::LLVMStructType>();
Type elemType = structType.getBody()
.front()
.cast<LLVM::LLVMStructType>()
.getBody()
.front();
Value zero = b.create<LLVM::ConstantOp>(elemType, b.getZeroAttr(elemType));
Value structValue = b.create<LLVM::UndefOp>(structType);
for (auto [idx, s] : llvm::enumerate(structType.getBody())) {
auto innerStructType = s.cast<LLVM::LLVMStructType>();
int ii = idx;
Value innerStructValue = b.create<LLVM::ExtractValueOp>(structValue, ii);
for (unsigned i = 0; i < innerStructType.getBody().size(); ++i) {
innerStructValue = b.create<LLVM::InsertValueOp>(
innerStructType, innerStructValue, zero, ArrayRef<int64_t>({i}));
}
results.push_back(undefStruct);
}
rewriter.replaceOp(op, results);
rewriter.replaceOp(op, structValue);
return success();
}
};
Expand Down
99 changes: 35 additions & 64 deletions mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,11 @@ LogicalResult isAllowedWGMMADataType(Type typeD, Type typeA, Type typeB) {
return failure();
}

LogicalResult isAllowedSizeM(int sizeM) { return success(sizeM == 64); }
LogicalResult isAllowedSizeM(int sizeM) {
if (sizeM % kWgmmaSizeM)
return failure();
return success();
}

LogicalResult isAllowedSizeN(int sizeN, Type typeA) {
SmallVector<int> allowedN = {8, 16, 24, 32, 40, 48, 56, 64,
Expand All @@ -458,35 +462,16 @@ LogicalResult isAllowedSizeN(int sizeN, Type typeA) {

LogicalResult WarpgroupMmaOp::verify() {
if (getTransposeA() && !getTransposeB())
return emitOpError() << "supports non-transpose A (Row Major) "
"and transpose B (Column Major) for the time being";
return emitOpError()
<< "supports non-transpose A (Row Major) "
"and transpose B (Column Major) for the time being ";
MemRefType matrixA = getDescriptorA().getType().getTensor();
MemRefType matrixB = getDescriptorB().getType().getTensor();
VectorType matrixC = getMatrixC()
.front()
.getType()
.cast<WarpgroupAccumulatorType>()
.getFragmented();
VectorType matrixD = getMatrixD()
.front()
.getType()
.cast<WarpgroupAccumulatorType>()
.getFragmented();
unsigned sizeAcc = getMatrixC().size();

if (getMatrixC().size() != getMatrixD().size())
return emitOpError() << "number of matrix C and matrix D must be the same";

if (llvm::all_of(getMatrixC(),
[&](Value rhs) { return rhs.getType() == matrixC; })) {
return emitOpError()
<< "types of all operands in matrix C must be the same";
}
if (llvm::all_of(getMatrixD(),
[&](Value rhs) { return rhs.getType() == matrixC; })) {
return emitOpError()
<< "types of all operands in matrix D must be the same as matrix C";
}
VectorType matrixC = getMatrixC().getType().getFragmented();
VectorType matrixD = getMatrixD().getType().getFragmented();

if (matrixC != matrixD)
return emitOpError() << "type of matrix C and matrix D must be the same";

if (matrixA.getRank() != 2 || matrixB.getRank() != 2 ||
matrixC.getRank() != 2 || matrixD.getRank() != 2) {
Expand All @@ -498,7 +483,7 @@ LogicalResult WarpgroupMmaOp::verify() {
return emitOpError() << "2nd dim matrix-A (" << matrixA.getShape()[1]
<< ")!= 1st dim matrix-B (" << matrixB.getShape()[0]
<< " )";
if (matrixA.getShape()[0] != (matrixC.getShape()[0] * sizeAcc))
if (matrixA.getShape()[0] != matrixC.getShape()[0])
return emitOpError() << "1st dim matrix-A ( " << matrixA.getShape()[0]
<< " )!= 1st dim matrix-C ( " << matrixC.getShape()[0]
<< " )";
Expand Down Expand Up @@ -534,29 +519,16 @@ LogicalResult WarpgroupMmaOp::verify() {

LogicalResult WarpgroupMmaStoreOp::verify() {
MemRefType dstMemrefType = getDstMemref().getType();
VectorType firstVtype = getMatrixD()
.front()
.getType()
.cast<WarpgroupAccumulatorType>()
.getFragmented();

int64_t totalFirstDimension = 0;
for (Value result : getMatrixD()) {
VectorType vtype =
result.getType().cast<WarpgroupAccumulatorType>().getFragmented();
if (vtype != firstVtype)
return emitOpError() << "all fragmented types must be the same";
// Limitation
if (!vtype.getElementType().isF32()) {
return emitOpError()
<< "hit a limitation: only f32 results for the time being";
}
totalFirstDimension += vtype.getDimSize(0);
VectorType vtype = getMatrixD().getType().getFragmented();

// Limitation
if (!vtype.getElementType().isF32()) {
return emitOpError()
<< "hit a limitation: only f32 results for the time being";
}
if (totalFirstDimension != dstMemrefType.getDimSize(0) ||
firstVtype.getDimSize(1) != dstMemrefType.getDimSize(1)) {
return emitOpError() << "results [" << totalFirstDimension << "]["
<< firstVtype.getDimSize(1)
if (vtype.getDimSize(0) != dstMemrefType.getDimSize(0) ||
vtype.getDimSize(1) != dstMemrefType.getDimSize(1)) {
return emitOpError() << "results [" << vtype << "][" << vtype.getDimSize(1)
<< "] values. However, destination memref["
<< dstMemrefType.getDimSize(0) << "]["
<< dstMemrefType.getDimSize(1)
Expand All @@ -570,19 +542,18 @@ LogicalResult WarpgroupMmaStoreOp::verify() {
//===----------------------------------------------------------------------===//

LogicalResult WarpgroupMmaInitAccumulatorOp::verify() {
for (OpResult matrix : getMatrixC()) {
VectorType vectorType = matrix.getType()
.cast<nvgpu::WarpgroupAccumulatorType>()
.getFragmented();
// Check [M][N] shape
if (failed(isAllowedSizeM(vectorType.getDimSize(0))) ||
failed(isAllowedSizeN(vectorType.getDimSize(1),
vectorType.getElementType()))) {
return emitOpError() << "has type " << vectorType
<< ". It does not fit into warp-group "
"level (wgmma) matrix multiplication instruction "
"(or not supported yet)";
}

nvgpu::WarpgroupAccumulatorType accType = getMatrixC().getType();
int64_t sizeM = accType.getFragmented().getDimSize(0);
int64_t sizeN = accType.getFragmented().getDimSize(1);
Type elemType = accType.getFragmented().getElementType();

if (failed(isAllowedSizeM(sizeM)) ||
failed(isAllowedSizeN(sizeN, elemType))) {
return emitOpError() << "has type " << accType.getFragmented()
<< ". It does not fit into warp-group "
"level (wgmma) matrix multiplication instruction "
"(or not supported yet)";
}
return success();
}
Expand Down
Loading

0 comments on commit 52db7e2

Please sign in to comment.