[BACKEND] Implement dot_scaled(mmav3) #5269

Merged (2 commits) on Nov 27, 2024
4 changes: 4 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/Dialect.h
@@ -214,6 +214,10 @@ LinearLayout ensureLayoutNotSmallerThan(
const LinearLayout &layout,
const llvm::SmallDenseMap<StringAttr, int64_t> &shape);

SmallVector<StringAttr> standardOutDimNames(MLIRContext *ctx, int rank);
LinearLayout identityStandardND(StringAttr inDimName, ArrayRef<unsigned> shape,
ArrayRef<unsigned> order);

// Dump information about which threads/registers contain each of the tensor
// elements.
void dumpLayout(RankedTensorType tensorType);
@@ -91,7 +91,7 @@ void decomposeBlockedToDotLayoutConversion(ModuleOp module) {
dyn_cast<triton::gpu::DotOperandEncodingAttr>(dstType.getEncoding());
if (srcBlocked && dstDotOp) {
auto dotParent = dyn_cast<NvidiaMmaEncodingAttr>(dstDotOp.getParent());
if (dotParent && dotParent.isAmpere()) {
if (dotParent) {
return;
}
Attribute sharedMemorySpace =
30 changes: 30 additions & 0 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -642,6 +642,36 @@ LinearLayout ensureLayoutNotSmallerThan(
return ret;
}

// Returns ["dim0", "dim1", ..., "dim<rank-1>"].
SmallVector<StringAttr> standardOutDimNames(MLIRContext *ctx, int rank) {
SmallVector<StringAttr> ret;
for (int i = 0; i < rank; i++) {
ret.push_back(StringAttr::get(ctx, "dim" + llvm::Twine(i)));
}
return ret;
}

// Returns a 1D -> ND layout into [dim0, dim1, ...] that's equivalent to
// creating a 1D -> 1D mapping of size product(shape) and then reshaping to
// permute(shape, order).
LinearLayout identityStandardND(StringAttr inDimName, ArrayRef<unsigned> shape,
ArrayRef<unsigned> order) {
assert(shape.size() == order.size());
MLIRContext *ctx = inDimName.getContext();
auto rank = shape.size();

// The order in triton is written wrt. [dim0, dim1, ...].
SmallVector<StringAttr> outDimNames = standardOutDimNames(ctx, rank);

LinearLayout ret = LinearLayout::empty();
for (int i = 0; i < shape.size(); i++) {
// Start with the most-minor dimension, which is order[0].
int dim = order[i];
ret *= LinearLayout::identity1D(shape[dim], inDimName, outDimNames[dim]);
}
return ret;
}

} // namespace gpu
} // namespace triton
} // namespace mlir
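For intuition, here is a minimal standalone sketch of the 1D -> ND index mapping that identityStandardND encodes for power-of-two sizes. It is plain C++ over arrays, not Triton's LinearLayout API, and the shape/order values are made up for illustration.

// Illustration only: the 1D -> ND unpacking that identityStandardND encodes,
// written against plain arrays rather than Triton's LinearLayout API.
// order[0] names the most-minor dimension; shape is indexed by dimension.
#include <array>
#include <cassert>
#include <cstdio>

static std::array<unsigned, 2> unpack(unsigned flat,
                                      std::array<unsigned, 2> shape,
                                      std::array<unsigned, 2> order) {
  std::array<unsigned, 2> idx = {0, 0};
  for (unsigned i = 0; i < 2; ++i) {
    unsigned dim = order[i]; // start with the most-minor dimension
    idx[dim] = flat % shape[dim];
    flat /= shape[dim];
  }
  return idx;
}

int main() {
  // shape = {4, 8} (dim0 = 4, dim1 = 8), order = {1, 0}: dim1 is most minor.
  // Flat index 9 unpacks to dim1 = 9 % 8 = 1, dim0 = (9 / 8) % 4 = 1.
  auto idx = unpack(9, {4, 8}, {1, 0});
  assert(idx[0] == 1 && idx[1] == 1);
  std::printf("dim0=%u dim1=%u\n", idx[0], idx[1]);
}

In the actual layout this unpacking is expressed as a product of identity1D factors over basis bits, but for power-of-two shapes the resulting index arithmetic is the same.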
30 changes: 0 additions & 30 deletions lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
@@ -32,15 +32,6 @@ namespace {

#define S(v) StringAttr::get(ctx, (v))

// Returns ["dim0", "dim1", ..., "dim<rank-1>"].
SmallVector<StringAttr> standardOutDimNames(MLIRContext *ctx, int rank) {
SmallVector<StringAttr> ret;
for (int i = 0; i < rank; i++) {
ret.push_back(S("dim" + llvm::Twine(i)));
}
return ret;
}

// TODO Have order be a mandatory argument of standardOutDimNames.
SmallVector<StringAttr> permuteDimNames(const SmallVector<StringAttr> &names,
const SmallVector<unsigned> &order) {
@@ -52,27 +43,6 @@ SmallVector<StringAttr> permuteDimNames(const SmallVector<StringAttr> &names,
return ret;
}

// Returns a 1D -> ND layout into [dim0, dim1, ...] that's equivalent to
// creating a 1D -> 1D mapping of size product(shape) and then reshaping to
// permute(shape, order).
LinearLayout identityStandardND(StringAttr inDimName, ArrayRef<unsigned> shape,
ArrayRef<unsigned> order) {
assert(shape.size() == order.size());
MLIRContext *ctx = inDimName.getContext();
auto rank = shape.size();

// The order in triton is written wrt. [dim0, dim1, ...].
SmallVector<StringAttr> outDimNames = standardOutDimNames(ctx, rank);

LinearLayout ret = LinearLayout::empty();
for (int i = 0; i < shape.size(); i++) {
// Start with the most-minor dimension, which is order[0].
int dim = order[i];
ret *= LinearLayout::identity1D(shape[dim], inDimName, outDimNames[dim]);
}
return ret;
}

// Make a LinearLayout that maps a block-id to an N-dimensional index.
//
// The tensor is split up into CTAsPerCGA pieces, which are distributed among
117 changes: 73 additions & 44 deletions lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -12,6 +12,7 @@
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "triton/Tools/StrUtil.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

@@ -394,6 +395,10 @@ class DecomposeScaledBlocked
auto aType = scaledDotOp.getLhsType();
auto bType = scaledDotOp.getRhsType();

auto rank = oldRetType.getShape().size();
if (rank != 2)
return rewriter.notifyMatchFailure(scaledDotOp, "NYI: rank==3");

assert((aType == ScaleDotElemType::E4M3 ||
aType == ScaleDotElemType::E5M2 ||
aType == ScaleDotElemType::E2M1) &&
@@ -430,71 +435,95 @@
// `bases[warps] = {(0, 0), (0, 0), ...}`

auto newAEncoding = DotOperandEncodingAttr::get(ctx, 0, mmaEnc, aKWidth);
auto rank = mmaEnc.getInstrShape().size();

// MMAv3 uses the first dimension for the M dimension, while MMAv2 uses the
// penultimate (ugh)
auto instrShapeM = mmaEnc.getInstrShape()[versionMajor == 3 ? 0 : rank - 2];
auto instrShapeM =
mmaEnc.getInstrShape()[versionMajor == 3
? 0
: mmaEnc.getInstrShape().size() - 2];
auto warpSize = getWarpSize(newAEncoding);
assert(instrShapeM <= warpSize);
// Necessary choice to leave all the scales of the tile in that given warp
auto threadsPerWarp =
SmallVector<unsigned>{instrShapeM, warpSize / instrShapeM};

assert(versionMajor == 2 &&
"NYI: MMAv3. Need to rethink the scale layout otherwise");

// Copy the bases

// This has to align with the order in UpcastMXFPOp
auto order = getMatrixOrder(rank, /*rowMajor=*/true);
Attribute newScaleEncoding = triton::gpu::BlockedEncodingAttr::get(
ctx, {1, 1}, threadsPerWarp, newAEncoding.getWarpsPerCTA(),
newAEncoding.getCTAOrder(), mmaEnc.getCTALayout());
ctx, {1, 1}, threadsPerWarp, newAEncoding.getWarpsPerCTA(), order,
mmaEnc.getCTALayout());

// Lezcano: In the future we could just use the LLs unconditionally
// Not doing it now as they are not as performant as Blocked encoding at
// times. E.g., we bail on them in the backwardMaterialization pass
auto dotBroadcastsWarpLevel = mmaEnc.getWarpsPerCTA()[1] != 1;
if (dotBroadcastsWarpLevel) {
// If mma has warpsPerCTA == {2, 2}, then newAEncoding has
// warpsPerCTA == {2, 1}. In this case, we need to broadcast the warps
// on the second dimension as per
// A: 0 1 | 0 1
// - - | - -
// 2 3 | 2 3
// This broadcasting is not representable by standard blocked encodings,
// so we need to use linear layouts.
// This broadcasting is implemented in ampereDotToLinearLayout
auto blocked = cast<BlockedEncodingAttr>(newScaleEncoding);
auto blockedLL = *blocked.toLinearLayout(a.getType().getShape());
LinearLayout::BasesT scaleBases = blockedLL.getBases();
auto nBases = llvm::Log2_32(mmaEnc.getWarpsPerCTA()[1]);
auto &warps = scaleBases[StringAttr::get(ctx, "warp")];
// Prepend the vector of zeros to the warpBases
warps.insert(warps.begin(), nBases, std::vector<int32_t>(rank, 0));
auto outDims = llvm::to_vector(blockedLL.getOutDimNames());
auto newLL = LinearLayout(scaleBases, outDims);
auto llEncoding = LinearEncodingAttr::get(ctx, std::move(newLL));
// Adjust the shape of the layout to match the scale operand
auto scaleShape = scale.getType().getShape();
newScaleEncoding =
LinearEncodingAttr::get(ctx, *llEncoding.toLinearLayout(scaleShape));
auto kRegister = StringAttr::get(ctx, "register");
auto regs = identityStandardND(kRegister, {1, 1}, order);
auto lanes =
identityStandardND(StringAttr::get(ctx, "lane"), {16, 2}, order);

// Extract warp layout from dotAEncoding
// In the future we'll have some nice division utils, but until then...
auto dotLL = *newAEncoding.toLinearLayout(a.getType().getShape());
LinearLayout::BasesT scaleBases = dotLL.getBases();
auto kWarp = StringAttr::get(ctx, "warp");
auto &warpBases = scaleBases[kWarp];
// The tile shape was [16, 2 * 4 * kWidth] with broadcasting in K
// We divide the M dimension by 16
auto div = 16;
for (auto &warpBase : warpBases) {
if (warpBase[rank - 2] != 0) {
assert(warpBase[rank - 2] % div == 0);
warpBase[rank - 2] /= div;
}
}

LinearLayout::BasesT warpBlockBases;
auto standardOutDims = llvm::to_vector(dotLL.getOutDimNames());
warpBlockBases[kWarp] = warpBases;
auto kBlock = StringAttr::get(ctx, "block");
assert(scaleBases[kBlock].empty() && "NYI: CGAs");
warpBlockBases[kBlock] = {};
auto warpBlock = LinearLayout(std::move(warpBlockBases), standardOutDims);

auto newLL =
(regs * lanes) *
warpBlock.transposeOuts(llvm::to_vector(lanes.getOutDimNames()));
auto shape = scale.getType().getShape();

// Broadcast to the correct shape. Equivalent to:
// newLL = ensureLayoutNotSmallerThan(newLL.transposeOuts(getRepOrder),
// shape);
for (auto d : newAEncoding.getRepOrder()) {
auto outDim = standardOutDims[d];
auto dimSize = newLL.getOutDimSize(outDim);
newLL *=
LinearLayout::identity1D(shape[d] / dimSize, kRegister, outDim);
}
newLL = newLL.transposeOuts(standardOutDims);
newScaleEncoding = LinearEncodingAttr::get(ctx, std::move(newLL));
}

a = createArg(rewriter, a, 0, aType, newAEncoding, scale, newScaleEncoding);

// Upcast B operand
assert(bType != ScaleDotElemType::E2M1 && "NYI: rhs scale for fp4");
auto newBEncoding = DotOperandEncodingAttr::get(ctx, 1, mmaEnc, bKWidth);
b = createArg(rewriter, b, 1, bType, newBEncoding,
/*scale=*/std::nullopt, /*scaleEncoding=*/std::nullopt);
Operation *newDot = nullptr;
if (versionMajor == 2) {
// Upcast B operand
assert(bType != ScaleDotElemType::E2M1 && "NYI: rhs scale for fp4");
auto newBEncoding = DotOperandEncodingAttr::get(ctx, 1, mmaEnc, bKWidth);
b = createArg(rewriter, b, 1, bType, newBEncoding,
/*scale=*/std::nullopt, /*scaleEncoding=*/std::nullopt);
newDot = rewriter.create<DotOp>(scaledDotOp.getLoc(), newRetType, a, b,
newAcc);
} else {
assert(versionMajor == 3);
// At the time of this writing, this is always true
auto allowTranspose = b.getType().getElementType().isBF16();
b = cast<TypedValue<RankedTensorType>>(
getSharedMemoryMMAOperand(b, rewriter, 1, allowTranspose));
auto bShmem = getSharedMemoryMMAOperand(b, rewriter, 1, allowTranspose);
newDot = rewriter.create<triton::nvidia_gpu::WarpGroupDotOp>(
scaledDotOp.getLoc(), newRetType, a, b, newAcc, nullptr);
scaledDotOp.getLoc(), newRetType, a, bShmem, newAcc, nullptr);
}

// convert dot instruction
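A note on the scale-layout construction above: a linear layout maps an index to tensor coordinates by XOR-ing the basis vectors selected by the set bits of the index, so a zero warp basis means that bit of the warp id does not change the coordinates, i.e. the corresponding warps hold the same scale data. For the typical instrShapeM = 16 and warpSize = 32, the threadsPerWarp computed above works out to {16, 2}, which is the 16x2 thread shape referenced later in UpcastMXFPToLLVM.cpp. Below is a minimal standalone sketch of the broadcast effect; it is not Triton's LinearLayout class, and the 16-row offset is made up for illustration.

// Illustration only (not Triton's LinearLayout class): a zero warp basis
// broadcasts data across warps, matching the
//   A: 0 1 | 0 1
//      2 3 | 2 3
// picture from the comment above.
#include <array>
#include <cstdio>
#include <vector>

using Coord = std::array<int, 2>; // (dim0, dim1)

static Coord apply(const std::vector<Coord> &bases, unsigned warpId) {
  Coord out = {0, 0};
  for (unsigned b = 0; b < bases.size(); ++b)
    if (warpId & (1u << b)) {
      out[0] ^= bases[b][0];
      out[1] ^= bases[b][1];
    }
  return out;
}

int main() {
  // warpsPerCTA = {2, 2}. Basis 0 (the broadcast bit) is zero; basis 1 moves
  // the warp by 16 rows along dim0 (offset chosen for illustration).
  std::vector<Coord> warpBases = {{0, 0}, {16, 0}};
  for (unsigned w = 0; w < 4; ++w) {
    Coord c = apply(warpBases, w);
    std::printf("warp %u -> (%d, %d)\n", w, c[0], c[1]);
  }
  // Prints: warps 0 and 1 -> (0, 0); warps 2 and 3 -> (16, 0).
}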
@@ -578,11 +607,11 @@ class DecomposeScaledBlocked
auto dotOp = rewriter.create<DotOp>(
scaledDotOp.getLoc(), scaledDotOp.getType(), a, b, scaledDotOp.getC());

// Waiting for https://github.com/triton-lang/triton/pull/5003 to land
// cf.
// https://github.com/triton-lang/triton/pull/5003#issuecomment-2445091746
// int versionMajor = getMMAVersionSafe(computeCapability, dotOp);
int versionMajor = 2;
// We just support bf16 for MMAv3 on the rhs
if (bType == ScaleDotElemType::BF16) {
versionMajor = getMMAVersionSafe(computeCapability, dotOp);
}
int versionMinor = computeCapability == 75 ? 1 : 0;

RankedTensorType oldRetType = dotOp.getType();
25 changes: 20 additions & 5 deletions test/TritonGPU/accelerate-matmul.mlir
@@ -164,7 +164,7 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 :

// -----

// Verify that dot_scaled (mxfp4 x bf16) decomposes as expected
// Verify that dot_scaled (mxfp4 x {bf16,fp8}) decomposes to mmav3 if it's bf16, otherwise it falls back to mmav2
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
@@ -174,13 +174,28 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 :
tt.func @dot_scaled(
%a: tensor<128x32xi8, #blocked2>,
%scale: tensor<128x2xi8, #blocked1>,
%b: tensor<64x128xbf16, #blocked>)
-> tensor<128x128xf32, #blocked> {
%b_bf16: tensor<64x128xbf16, #blocked>
) -> tensor<128x128xf32, #blocked> {
// CHECK: triton_gpu.convert_layout {{.*}} : tensor<128x2xi8, #blocked1> -> tensor<128x2xi8, {{.*}}>
// CHECK: triton_gpu.upcast_mxfp {{.*}}, {{.*}} fp_type = e2m1 : tensor<128x32xi8, #triton_gpu.dot_op<{{.*}}>>, tensor<128x2xi8, {{.*}}> -> tensor<128x64xbf16, #triton_gpu.dot_op<{{.*}}>>
// CHECK: triton_nvidia_gpu.warp_group_dot
%cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked>
%result = tt.dot_scaled %a scale %scale, %b_bf16, %cst lhs = e2m1 rhs = bf16 : tensor<128x32xi8, #blocked2>, tensor<128x2xi8, #blocked1> * tensor<64x128xbf16, #blocked> -> tensor<128x128xf32, #blocked>
tt.return %result : tensor<128x128xf32, #blocked>
}

// Verify that dot_scaled (mxfp4 x fp8) decomposes into mmav2
// CHECK: dot_scaled_fp8
tt.func @dot_scaled_fp8(
%a: tensor<128x32xi8, #blocked2>,
%scale: tensor<128x2xi8, #blocked1>,
%b_fp8: tensor<64x128xf8E4M3FN, #blocked>
) -> tensor<128x128xf32, #blocked> {
%cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked>
// CHECK: triton_gpu.convert_layout {{.*}} : tensor<128x2xi8, #blocked1> -> tensor<128x2xi8, #[[LINEAR]]>
// CHECK: triton_gpu.upcast_mxfp {{.*}}, {{.*}} fp_type = e2m1 : tensor<128x32xi8, #triton_gpu.dot_op<{{.*}}>>, tensor<128x2xi8, #[[LINEAR]]> -> tensor<128x64xbf16, #triton_gpu.dot_op<{{.*}}>>
// CHECK: tt.dot
%cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked>
%result = tt.dot_scaled %a scale %scale, %b, %cst lhs = e2m1 rhs = bf16 : tensor<128x32xi8, #blocked2>, tensor<128x2xi8, #blocked1> * tensor<64x128xbf16, #blocked> -> tensor<128x128xf32, #blocked>
%result = tt.dot_scaled %a scale %scale, %b_fp8, %cst lhs = e2m1 rhs = e4m3 : tensor<128x32xi8, #blocked2>, tensor<128x2xi8, #blocked1> * tensor<64x128xf8E4M3FN, #blocked> -> tensor<128x128xf32, #blocked>
tt.return %result : tensor<128x128xf32, #blocked>
}
}
@@ -74,15 +74,7 @@ struct DecomposeUnsupportedConversions
// Remove the decomposeTensorCoreToDotLayoutConversion class entirely after
// we have enabled the new layout conversion for all the cases.
auto nvidiaShortCutFn = [&](RankedTensorType srcTy,
RankedTensorType dstTy) {
auto nvidiaMma = dyn_cast<NvidiaMmaEncodingAttr>(srcTy.getEncoding());
// Supported mma to dot conversion
if (nvidiaMma && nvidiaMma.isAmpere())
return true;
// No need to decompose if shared memory is not needed
return matchMmaV3AndDotOperandLayout(srcTy, dstTy) ||
cvtReordersRegisters(srcTy, dstTy);
};
RankedTensorType dstTy) { return true; };
ModuleOp mod = getOperation();
triton::gpu::decomposeSplatOpToSharedLayoutConversion(mod);
triton::gpu::decomposeTensorCoreToDotLayoutConversion(mod,
43 changes: 29 additions & 14 deletions third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp
@@ -49,28 +49,43 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
Value warpId = udiv(tid, warpSize);
Value laneId = urem(tid, warpSize);

auto kWidth =
cast<DotOperandEncodingAttr>(op.getType().getEncoding()).getKWidth();

if (fpType == ScaleDotElemType::E2M1)
xVals = LLVM::convertMxfp4x2ToBf16x2(rewriter, loc, xVals);

// Each thread owns elements of 4 mxfp vectors so we need 4 scales
// Letting c = tid / 4 * 2, we need the elements from threads c, c + 1, c +
// 16, c + 17
// Since we go from a threadShape of 8x4 to 16x2, we let c = tid / 4 * 2
// Then, we need elements c and c + 16 for the first two mxfp vectors
// and elements c + 1 and c + 17 for the last two mxfp vectors
auto c = mul(udiv(laneId, i32_val(4)), i32_val(2));
std::array<Value, 4> ci = {c, add(c, i32_val(1)), add(c, i32_val(16)),
std::array<Value, 4> ci = {c, add(c, i32_val(16)), add(c, i32_val(1)),
add(c, i32_val(17))};

// TODO Move this logic to using LinearLayouts
// Each scale in a warp has to be replicated to cover a tile of shape
// mxk = 16x64. This 16x64 tile is split into 4 subtiles of shape 8x32,
// each of which has to gather a scale and multiply its relevant part of
// the mxfp vector. Each 8x32 subtile is split into 8x4 vectors, leaving
// each vector with 1x8 mxfp elements as long as kWidth * 4 <= 32.
assert(kWidth <= 8 &&
"NYI for larger kWidth (but we could do it with less shuffles!)");
for (auto [i, scaleVal] : llvm::enumerate(scaleVals)) {
// column major as per the DotOperandEncoding(opidx=0) layout
auto si = std::array<Value, 4>{
targetInfo.shuffleIdx(rewriter, loc, scaleVal, ci[0]),
targetInfo.shuffleIdx(rewriter, loc, scaleVal, ci[2]),
targetInfo.shuffleIdx(rewriter, loc, scaleVal, ci[1]),
targetInfo.shuffleIdx(rewriter, loc, scaleVal, ci[3]),
};

for (int j = 0; j < 32; ++j) {
xVals[32 * i + j] =
LLVM::mxfpScaleBf16(rewriter, loc, xVals[32 * i + j], si[j / 8]);
for (int mxfp = 0; mxfp < 2; ++mxfp) {
auto si = std::array<Value, 2>{
targetInfo.shuffleIdx(rewriter, loc, scaleVal, ci[mxfp * 2 + 0]),
targetInfo.shuffleIdx(rewriter, loc, scaleVal, ci[mxfp * 2 + 1])};
for (int rep = 0; rep < 8 / kWidth; ++rep) {
for (int subTile = 0; subTile < 2; ++subTile) {
for (int k = 0; k < kWidth; ++k) {
auto idx =
32 * i + 16 * mxfp + rep * 2 * kWidth + subTile * kWidth + k;
xVals[idx] =
LLVM::mxfpScaleBf16(rewriter, loc, xVals[idx], si[subTile]);
}
}
}
}
}

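To see which shuffled scale each element picks up, the loop nest above can be traced directly. A minimal standalone sketch that just prints the index arithmetic follows; the kWidth value is assumed for illustration, and everything else mirrors the loop structure in the diff.

// Illustration only: enumerate the per-element scale selection implemented by
// the loop nest above. Each thread holds 32 bf16 values per scale value i,
// split into two mxfp halves; within each half, groups of kWidth contiguous
// values alternate between the two shuffled scales si[0] and si[1].
#include <cstdio>

int main() {
  const int kWidth = 4; // assumed for illustration; the code requires kWidth <= 8
  const int i = 0;      // first scale value held by this thread
  for (int mxfp = 0; mxfp < 2; ++mxfp)
    for (int rep = 0; rep < 8 / kWidth; ++rep)
      for (int subTile = 0; subTile < 2; ++subTile)
        for (int k = 0; k < kWidth; ++k) {
          int idx = 32 * i + 16 * mxfp + rep * 2 * kWidth + subTile * kWidth + k;
          std::printf("xVals[%2d] *= si[%d] (mxfp half %d)\n", idx, subTile, mxfp);
        }
}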