Skip to content

Commit

Permalink
[AMD] Use warp shuffle for fp8 MFMA to dot operand layout conversion (#…
Browse files Browse the repository at this point in the history
…5139)

Adding a shortcut case for fp8 MFMA to dot operand layout conversion
that avoids using shared memory, to speed up FP8 attention kernels.
  • Loading branch information
ilia-cher authored Nov 22, 2024
1 parent 4ae95e7 commit af0649d
Show file tree
Hide file tree
Showing 5 changed files with 378 additions and 2 deletions.
5 changes: 5 additions & 0 deletions include/triton/Analysis/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,11 @@ bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
RankedTensorType dstTy);

// Check if MFMA layout can be converted to the dot operand
// layout using warp shuffle.
bool matchMFMAAndDotOperandShuffleCase(RankedTensorType srcTy,
RankedTensorType dstTy);

// TODO: Move utility functions that belong to ConvertLayoutOp to class
// ConvertLayoutOpHelper in the future
bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout);
Expand Down
25 changes: 24 additions & 1 deletion lib/Analysis/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Support/LLVM.h"
#include "triton/Conversion/MLIRTypes.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
Expand Down Expand Up @@ -632,6 +633,25 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
return ans;
}

bool matchMFMAAndDotOperandShuffleCase(RankedTensorType srcTy,
RankedTensorType dstTy) {
auto mfmaLayout = dyn_cast<AMDMfmaEncodingAttr>(srcTy.getEncoding());
auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
if (!mfmaLayout || !dotOperandLayout)
return false;

// Currently supporting 32x32 and 16x16 FP8 MFMA -> dot operand case
return dotOperandLayout.getParent() == mfmaLayout &&
dotOperandLayout.getOpIdx() == 0 && mfmaLayout.getIsTransposed() &&
dotOperandLayout.getKWidth() == 8 &&
getContigPerThread(mfmaLayout)[1] == 4 &&
((mfmaLayout.getMDim() == 16 && mfmaLayout.getNDim() == 16) ||
(mfmaLayout.getMDim() == 32 && mfmaLayout.getNDim() == 32)) &&
triton::type::isFloat8(srcTy.getElementType()) &&
triton::type::isFloat8(dstTy.getElementType()) &&
mfmaLayout.getWarpsPerCTA()[1] == 1;
}

// We get the smallest submap of srcTy^{-1} * dstTy that is not the identity
// under kBlock, kWarp or kLane (in that order). The idea here is that if we
// have a transformation that's the identity on kBlock, we don't need to use
Expand Down Expand Up @@ -730,7 +750,10 @@ bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) {
// supported yet in Triton's backend.
return !cvtReordersRegisters(srcTy, dstTy) &&
!isBlockedToDotShortcut(srcTy, dstTy) &&
!matchMmaV3AndDotOperandLayout(srcTy, dstTy);
!matchMmaV3AndDotOperandLayout(srcTy, dstTy) &&
// to be removed when generalized warp shuffle conversions
// are ready:
!matchMFMAAndDotOperandShuffleCase(srcTy, dstTy);
}

bool atomicNeedsSharedMemory(Value value) {
Expand Down
6 changes: 6 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,12 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
return failure();
}

// The following check can be removed when generalized warp shuffle
// conversions are ready:
if (matchMFMAAndDotOperandShuffleCase(srcTy, dstTy)) {
return failure();
}

assert(cvtNeedsSharedMemory(srcTy, dstTy));

SmallVector<Value> inVals =
Expand Down
190 changes: 189 additions & 1 deletion test/Conversion/amd/mfma-shortcut.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: triton-opt %s --decompose-unsupported-amd-conversions --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch="gfx90a" -split-input-file | FileCheck %s
// RUN: triton-opt %s --decompose-unsupported-amd-conversions --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch="gfx942" -split-input-file | FileCheck %s

#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>
#dotop = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=4}>
Expand Down Expand Up @@ -27,3 +27,191 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
tt.return
}
}

// -----

#mfma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
#dotop0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
// CHECK-LABEL: mfma_dot_cvt_f8_mfma32
tt.func public @mfma_dot_cvt_f8_mfma32(%arg0: tensor<128x32xf8E4M3FNUZ, #mfma>) {
// CHECK-NOT: store
// CHECK-NOT: load

// CHECK: [[val3:%.*]] = llvm.extractvalue %arg0[3]
// CHECK: [[val7:%.*]] = llvm.extractvalue %arg0[7]

// CHECK-DAG: [[c32:%.*]] = llvm.mlir.constant(32 : i32)
// CHECK-DAG: [[c64:%.*]] = llvm.mlir.constant(64 : i32)

// CHECK: [[threadId:%.*]] = rocdl.workitem.id.x
// CHECK: [[laneId:%.*]] = llvm.urem [[threadId]], [[c64]]
// CHECK: [[mask0:%.*]] = llvm.icmp "slt" [[laneId]], [[c32]]

// CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c32]]
// CHECK: [[addr32:%.*]] = llvm.urem [[shflLaneId]], [[c64]]

// CHECK: [[vec0:%.*]] = llvm.insertelement [[val3]], {{.*}} : vector<4xi8>
// CHECK: [[vec1:%.*]] = llvm.insertelement [[val7]], {{.*}} : vector<4xi8>

// CHECK: [[bvec0:%.*]] = llvm.bitcast [[vec0]]
// CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32)
// CHECK: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]]
// CHECK: [[bShflVec0:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec0]]
// CHECK: [[shflVec0:%.*]] = llvm.bitcast [[bShflVec0]]

// CHECK: [[bvec1:%.*]] = llvm.bitcast [[vec1]]
// CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32)
// CHECK: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]]
// CHECK: [[bShflVec1:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec1]]
// CHECK: [[shflVec1:%.*]] = llvm.bitcast [[bShflVec1]]

// Input (8 values): (vec0, vec1)
// Output (8 values shuffled, '>> n' - take the value from (lane + n) % 64):
// resVec0 resVec1
// lanes 0-31: (vec0 , vec0 >> 32) (mask0=1)
// lanes 32-63: (vec1 >> 32, vec1 ) (mask0=0)

// CHECK: [[resVec0:%.*]] = llvm.select [[mask0]], [[vec0]], [[shflVec1]]
// CHECK: [[resVec1:%.*]] = llvm.select [[mask0]], [[shflVec0]], [[vec1]]

// CHECK: [[c3:%.*]] = llvm.mlir.constant(3 : i32)
// CHECK: [[resVal3:%.*]] = llvm.extractelement [[resVec0]][[[c3]] : i32] : vector<4xi8>
// CHECK: [[c3:%.*]] = llvm.mlir.constant(3 : i32) : i32
// CHECK: [[resVal7:%.*]] = llvm.extractelement [[resVec1]][[[c3]] : i32] : vector<4xi8>

// CHECK: llvm.insertvalue [[resVal3]], {{.*}}[3]
// CHECK: llvm.insertvalue [[resVal7]], {{.*}}[7]

// CHECK: llvm.return
%0 = triton_gpu.convert_layout %arg0 : tensor<128x32xf8E4M3FNUZ, #mfma> -> tensor<128x32xf8E4M3FNUZ, #dotop0>
tt.return
}
}

// -----

#mfma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
#dotop0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
// CHECK-LABEL: mfma_dot_cvt_bf8_mfma32
tt.func public @mfma_dot_cvt_bf8_mfma32(%arg0: tensor<128x32xf8E5M2, #mfma>) {
// CHECK-NOT: store
// CHECK-NOT: load
// CHECK: rocdl.ds_bpermute
// CHECK: llvm.return
%0 = triton_gpu.convert_layout %arg0 : tensor<128x32xf8E5M2, #mfma> -> tensor<128x32xf8E5M2, #dotop0>
tt.return
}
}

// -----

#mfma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>
#dotop0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
// CHECK-LABEL: mfma_dot_cvt_f8_mfma16
tt.func public @mfma_dot_cvt_f8_mfma16(%arg0: tensor<128x32xf8E4M3FNUZ, #mfma>) {
// CHECK-NOT: store
// CHECK-NOT: load

// CHECK: [[val3:%.*]] = llvm.extractvalue %arg0[3]
// CHECK: [[val7:%.*]] = llvm.extractvalue %arg0[7]

// CHECK-DAG: [[c16:%.*]] = llvm.mlir.constant(16 : i32)
// CHECK-DAG: [[c32:%.*]] = llvm.mlir.constant(32 : i32)
// CHECK-DAG: [[c48:%.*]] = llvm.mlir.constant(48 : i32)
// CHECK-DAG: [[c64:%.*]] = llvm.mlir.constant(64 : i32)

// CHECK: [[threadId:%.*]] = rocdl.workitem.id.x
// CHECK: [[laneId:%.*]] = llvm.urem [[threadId]], [[c64]]
// CHECK: [[mask0:%.*]] = llvm.icmp "slt" [[laneId]], [[c32]]

// CHECK: [[laneIdRem:%.*]] = llvm.urem [[laneId]], [[c32]]
// CHECK: [[mask1:%.*]] = llvm.icmp "slt" [[laneIdRem]], [[c16]]

// CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c16]]
// CHECK: [[addr16:%.*]] = llvm.urem [[shflLaneId]], [[c64]]

// CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c32]]
// CHECK: [[addr32:%.*]] = llvm.urem [[shflLaneId]], [[c64]]

// CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c48]]
// CHECK: [[addr48:%.*]] = llvm.urem [[shflLaneId]], [[c64]]

// CHECK: [[vec0:%.*]] = llvm.insertelement [[val3]], {{.*}} : vector<4xi8>
// CHECK: [[vec1:%.*]] = llvm.insertelement [[val7]], {{.*}} : vector<4xi8>

// CHECK: [[bvec0:%.*]] = llvm.bitcast [[vec0]]
// CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32)
// CHECK: [[addr:%.*]] = llvm.shl [[addr16]], [[c2]]
// CHECK: [[bShflVec0_16:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec0]]
// CHECK: [[shflVec0_16:%.*]] = llvm.bitcast [[bShflVec0_16]]

// CHECK: [[bvec0:%.*]] = llvm.bitcast [[vec0]]
// CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32)
// CHECK: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]]
// CHECK: [[bShflVec0_32:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec0]]
// CHECK: [[shflVec0_32:%.*]] = llvm.bitcast [[bShflVec0_32]]

// CHECK: [[bvec1:%.*]] = llvm.bitcast [[vec1]]
// CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32)
// CHECK: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]]
// CHECK: [[bShflVec1_32:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec1]]
// CHECK: [[shflVec1_32:%.*]] = llvm.bitcast [[bShflVec1_32]]

// CHECK: [[bvec1:%.*]] = llvm.bitcast [[vec1]]
// CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32)
// CHECK: [[addr:%.*]] = llvm.shl [[addr48]], [[c2]]
// CHECK: [[bShflVec1_48:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec1]]
// CHECK: [[shflVec1_48:%.*]] = llvm.bitcast [[bShflVec1_48]]

// Input (8 values): (vec0, vec1)
// Output (8 values shuffled, '>> n' - take the value from (lane + n) % 64):
// resVec0 resVec1
// lanes 0-15: (vec0 , vec0 >> 16) (mask0=1, mask1=1)
// lanes 16-31: (vec0 >> 16, vec0 >> 32) (mask0=1, mask1=0)
// lanes 32-47: (vec1 >> 32, vec1 >> 48) (mask0=0, mask1=1)
// lanes 48-63: (vec1 >> 48, vec1 ) (mask0=0, mask1=0)

// CHECK-DAG: [[mask0_true:%.*]] = llvm.select [[mask1]], [[vec0]], [[shflVec0_16]] : i1, vector<4xi8>
// CHECK-DAG: [[mask0_false:%.*]] = llvm.select [[mask1]], [[shflVec1_32]], [[shflVec1_48]] : i1, vector<4xi8>
// CHECK: [[resVec0:%.*]] = llvm.select [[mask0]], [[mask0_true]], [[mask0_false]] : i1, vector<4xi8>

// CHECK-DAG: [[mask0_true:%.*]] = llvm.select [[mask1]], [[shflVec0_16]], [[shflVec0_32]] : i1, vector<4xi8>
// CHECK-DAG: [[mask0_false:%.*]] = llvm.select [[mask1]], [[shflVec1_48]], [[vec1]] : i1, vector<4xi8>
// CHECK: [[resVec1:%.*]] = llvm.select [[mask0]], [[mask0_true]], [[mask0_false]] : i1, vector<4xi8>

// CHECK: [[c3:%.*]] = llvm.mlir.constant(3 : i32)
// CHECK: [[resVal3:%.*]] = llvm.extractelement [[resVec0]][[[c3]] : i32] : vector<4xi8>
// CHECK: [[c3:%.*]] = llvm.mlir.constant(3 : i32) : i32
// CHECK: [[resVal7:%.*]] = llvm.extractelement [[resVec1]][[[c3]] : i32] : vector<4xi8>

// CHECK: llvm.insertvalue [[resVal3]], {{.*}}[3]
// CHECK: llvm.insertvalue [[resVal7]], {{.*}}[7]

// CHECK: llvm.return
%0 = triton_gpu.convert_layout %arg0 : tensor<128x32xf8E4M3FNUZ, #mfma> -> tensor<128x32xf8E4M3FNUZ, #dotop0>
tt.return
}
}

// -----

#mfma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>
#dotop0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
// CHECK-LABEL: mfma_dot_cvt_bf8_mfma16
tt.func public @mfma_dot_cvt_bf8_mfma16(%arg0: tensor<128x32xf8E5M2, #mfma>) {
// CHECK-NOT: store
// CHECK-NOT: load
// CHECK: rocdl.ds_bpermute
// CHECK: llvm.return
%0 = triton_gpu.convert_layout %arg0 : tensor<128x32xf8E5M2, #mfma> -> tensor<128x32xf8E5M2, #dotop0>
tt.return
}
}
Loading

0 comments on commit af0649d

Please sign in to comment.