Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Revert "[AMD] Use warp shuffle for MFMA to Dot operand layout conversion (FP8)" #5240

Merged
merged 1 commit into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions include/triton/Analysis/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -218,11 +218,6 @@ bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
RankedTensorType dstTy);

// Check if MFMA layout can be converted to the dot operand
// layout using warp shuffle.
bool matchMFMAAndDotOperandShuffleCase(RankedTensorType srcTy,
RankedTensorType dstTy);

// TODO: Move utility functions that belong to ConvertLayoutOp to class
// ConvertLayoutOpHelper in the future
bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout);
Expand Down
25 changes: 1 addition & 24 deletions lib/Analysis/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Support/LLVM.h"
#include "triton/Conversion/MLIRTypes.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
Expand Down Expand Up @@ -633,25 +632,6 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
return ans;
}

bool matchMFMAAndDotOperandShuffleCase(RankedTensorType srcTy,
RankedTensorType dstTy) {
auto mfmaLayout = dyn_cast<AMDMfmaEncodingAttr>(srcTy.getEncoding());
auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
if (!mfmaLayout || !dotOperandLayout)
return false;

// Currently supporting 32x32 and 16x16 FP8 MFMA -> dot operand case
return dotOperandLayout.getParent() == mfmaLayout &&
dotOperandLayout.getOpIdx() == 0 && mfmaLayout.getIsTransposed() &&
dotOperandLayout.getKWidth() == 8 &&
getContigPerThread(mfmaLayout)[1] == 4 &&
((mfmaLayout.getMDim() == 16 && mfmaLayout.getNDim() == 16) ||
(mfmaLayout.getMDim() == 32 && mfmaLayout.getNDim() == 32)) &&
triton::type::isFloat8(srcTy.getElementType()) &&
triton::type::isFloat8(dstTy.getElementType()) &&
mfmaLayout.getWarpsPerCTA()[1] == 1;
}

// We get the smallest submap of srcTy^{-1} * dstTy that is not the identity
// under kBlock, kWarp or kLane (in that order). The idea here is that if we
// have a transformation that's the identity on kBlock, we don't need to use
Expand Down Expand Up @@ -750,10 +730,7 @@ bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) {
// supported yet in Triton's backend.
return !cvtReordersRegisters(srcTy, dstTy) &&
!isBlockedToDotShortcut(srcTy, dstTy) &&
!matchMmaV3AndDotOperandLayout(srcTy, dstTy) &&
// to be removed when generalized warp shuffle conversions
// are ready:
!matchMFMAAndDotOperandShuffleCase(srcTy, dstTy);
!matchMmaV3AndDotOperandLayout(srcTy, dstTy);
}

bool atomicNeedsSharedMemory(Value value) {
Expand Down
6 changes: 0 additions & 6 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -409,12 +409,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
return failure();
}

// The following check can be removed when generalized warp shuffle
// conversions are ready:
if (matchMFMAAndDotOperandShuffleCase(srcTy, dstTy)) {
return failure();
}

assert(cvtNeedsSharedMemory(srcTy, dstTy));

SmallVector<Value> inVals =
Expand Down
190 changes: 1 addition & 189 deletions test/Conversion/amd/mfma-shortcut.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: triton-opt %s --decompose-unsupported-amd-conversions --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch="gfx942" -split-input-file | FileCheck %s
// RUN: triton-opt %s --decompose-unsupported-amd-conversions --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch="gfx90a" -split-input-file | FileCheck %s

#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>
#dotop = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=4}>
Expand Down Expand Up @@ -27,191 +27,3 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
tt.return
}
}

// -----

#mfma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
#dotop0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
// CHECK-LABEL: mfma_dot_cvt_f8_mfma32
tt.func public @mfma_dot_cvt_f8_mfma32(%arg0: tensor<128x32xf8E4M3FNUZ, #mfma>) {
// CHECK-NOT: store
// CHECK-NOT: load

// CHECK: [[val3:%.*]] = llvm.extractvalue %arg0[3]
// CHECK: [[val7:%.*]] = llvm.extractvalue %arg0[7]

// CHECK-DAG: [[c32:%.*]] = llvm.mlir.constant(32 : i32)
// CHECK-DAG: [[c64:%.*]] = llvm.mlir.constant(64 : i32)

// CHECK: [[threadId:%.*]] = rocdl.workitem.id.x
// CHECK: [[laneId:%.*]] = llvm.urem [[threadId]], [[c64]]
// CHECK: [[mask0:%.*]] = llvm.icmp "slt" [[laneId]], [[c32]]

// CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c32]]
// CHECK: [[addr32:%.*]] = llvm.urem [[shflLaneId]], [[c64]]

// CHECK: [[vec0:%.*]] = llvm.insertelement [[val3]], {{.*}} : vector<4xi8>
// CHECK: [[vec1:%.*]] = llvm.insertelement [[val7]], {{.*}} : vector<4xi8>

// CHECK: [[bvec0:%.*]] = llvm.bitcast [[vec0]]
// CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32)
// CHECK: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]]
// CHECK: [[bShflVec0:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec0]]
// CHECK: [[shflVec0:%.*]] = llvm.bitcast [[bShflVec0]]

// CHECK: [[bvec1:%.*]] = llvm.bitcast [[vec1]]
// CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32)
// CHECK: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]]
// CHECK: [[bShflVec1:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec1]]
// CHECK: [[shflVec1:%.*]] = llvm.bitcast [[bShflVec1]]

// Input (8 values): (vec0, vec1)
// Output (8 values shuffled, '>> n' - take the value from (lane + n) % 64):
// resVec0 resVec1
// lanes 0-31: (vec0 , vec0 >> 32) (mask0=1)
// lanes 32-63: (vec1 >> 32, vec1 ) (mask0=0)

// CHECK: [[resVec0:%.*]] = llvm.select [[mask0]], [[vec0]], [[shflVec1]]
// CHECK: [[resVec1:%.*]] = llvm.select [[mask0]], [[shflVec0]], [[vec1]]

// CHECK: [[c3:%.*]] = llvm.mlir.constant(3 : i32)
// CHECK: [[resVal3:%.*]] = llvm.extractelement [[resVec0]][[[c3]] : i32] : vector<4xi8>
// CHECK: [[c3:%.*]] = llvm.mlir.constant(3 : i32) : i32
// CHECK: [[resVal7:%.*]] = llvm.extractelement [[resVec1]][[[c3]] : i32] : vector<4xi8>

// CHECK: llvm.insertvalue [[resVal3]], {{.*}}[3]
// CHECK: llvm.insertvalue [[resVal7]], {{.*}}[7]

// CHECK: llvm.return
%0 = triton_gpu.convert_layout %arg0 : tensor<128x32xf8E4M3FNUZ, #mfma> -> tensor<128x32xf8E4M3FNUZ, #dotop0>
tt.return
}
}

// -----

#mfma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
#dotop0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
// CHECK-LABEL: mfma_dot_cvt_bf8_mfma32
tt.func public @mfma_dot_cvt_bf8_mfma32(%arg0: tensor<128x32xf8E5M2, #mfma>) {
// CHECK-NOT: store
// CHECK-NOT: load
// CHECK: rocdl.ds_bpermute
// CHECK: llvm.return
%0 = triton_gpu.convert_layout %arg0 : tensor<128x32xf8E5M2, #mfma> -> tensor<128x32xf8E5M2, #dotop0>
tt.return
}
}

// -----

#mfma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>
#dotop0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
// CHECK-LABEL: mfma_dot_cvt_f8_mfma16
tt.func public @mfma_dot_cvt_f8_mfma16(%arg0: tensor<128x32xf8E4M3FNUZ, #mfma>) {
// CHECK-NOT: store
// CHECK-NOT: load

// CHECK: [[val3:%.*]] = llvm.extractvalue %arg0[3]
// CHECK: [[val7:%.*]] = llvm.extractvalue %arg0[7]

// CHECK-DAG: [[c16:%.*]] = llvm.mlir.constant(16 : i32)
// CHECK-DAG: [[c32:%.*]] = llvm.mlir.constant(32 : i32)
// CHECK-DAG: [[c48:%.*]] = llvm.mlir.constant(48 : i32)
// CHECK-DAG: [[c64:%.*]] = llvm.mlir.constant(64 : i32)

// CHECK: [[threadId:%.*]] = rocdl.workitem.id.x
// CHECK: [[laneId:%.*]] = llvm.urem [[threadId]], [[c64]]
// CHECK: [[mask0:%.*]] = llvm.icmp "slt" [[laneId]], [[c32]]

// CHECK: [[laneIdRem:%.*]] = llvm.urem [[laneId]], [[c32]]
// CHECK: [[mask1:%.*]] = llvm.icmp "slt" [[laneIdRem]], [[c16]]

// CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c16]]
// CHECK: [[addr16:%.*]] = llvm.urem [[shflLaneId]], [[c64]]

// CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c32]]
// CHECK: [[addr32:%.*]] = llvm.urem [[shflLaneId]], [[c64]]

// CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c48]]
// CHECK: [[addr48:%.*]] = llvm.urem [[shflLaneId]], [[c64]]

// CHECK: [[vec0:%.*]] = llvm.insertelement [[val3]], {{.*}} : vector<4xi8>
// CHECK: [[vec1:%.*]] = llvm.insertelement [[val7]], {{.*}} : vector<4xi8>

// CHECK: [[bvec0:%.*]] = llvm.bitcast [[vec0]]
// CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32)
// CHECK: [[addr:%.*]] = llvm.shl [[addr16]], [[c2]]
// CHECK: [[bShflVec0_16:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec0]]
// CHECK: [[shflVec0_16:%.*]] = llvm.bitcast [[bShflVec0_16]]

// CHECK: [[bvec0:%.*]] = llvm.bitcast [[vec0]]
// CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32)
// CHECK: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]]
// CHECK: [[bShflVec0_32:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec0]]
// CHECK: [[shflVec0_32:%.*]] = llvm.bitcast [[bShflVec0_32]]

// CHECK: [[bvec1:%.*]] = llvm.bitcast [[vec1]]
// CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32)
// CHECK: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]]
// CHECK: [[bShflVec1_32:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec1]]
// CHECK: [[shflVec1_32:%.*]] = llvm.bitcast [[bShflVec1_32]]

// CHECK: [[bvec1:%.*]] = llvm.bitcast [[vec1]]
// CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32)
// CHECK: [[addr:%.*]] = llvm.shl [[addr48]], [[c2]]
// CHECK: [[bShflVec1_48:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec1]]
// CHECK: [[shflVec1_48:%.*]] = llvm.bitcast [[bShflVec1_48]]

// Input (8 values): (vec0, vec1)
// Output (8 values shuffled, '>> n' - take the value from (lane + n) % 64):
// resVec0 resVec1
// lanes 0-15: (vec0 , vec0 >> 16) (mask0=1, mask1=1)
// lanes 16-31: (vec0 >> 16, vec0 >> 32) (mask0=1, mask1=0)
// lanes 32-47: (vec1 >> 32, vec1 >> 48) (mask0=0, mask1=1)
// lanes 48-63: (vec1 >> 48, vec1 ) (mask0=0, mask1=0)

// CHECK-DAG: [[mask0_true:%.*]] = llvm.select [[mask1]], [[vec0]], [[shflVec0_16]] : i1, vector<4xi8>
// CHECK-DAG: [[mask0_false:%.*]] = llvm.select [[mask1]], [[shflVec1_32]], [[shflVec1_48]] : i1, vector<4xi8>
// CHECK: [[resVec0:%.*]] = llvm.select [[mask0]], [[mask0_true]], [[mask0_false]] : i1, vector<4xi8>

// CHECK-DAG: [[mask0_true:%.*]] = llvm.select [[mask1]], [[shflVec0_16]], [[shflVec0_32]] : i1, vector<4xi8>
// CHECK-DAG: [[mask0_false:%.*]] = llvm.select [[mask1]], [[shflVec1_48]], [[vec1]] : i1, vector<4xi8>
// CHECK: [[resVec1:%.*]] = llvm.select [[mask0]], [[mask0_true]], [[mask0_false]] : i1, vector<4xi8>

// CHECK: [[c3:%.*]] = llvm.mlir.constant(3 : i32)
// CHECK: [[resVal3:%.*]] = llvm.extractelement [[resVec0]][[[c3]] : i32] : vector<4xi8>
// CHECK: [[c3:%.*]] = llvm.mlir.constant(3 : i32) : i32
// CHECK: [[resVal7:%.*]] = llvm.extractelement [[resVec1]][[[c3]] : i32] : vector<4xi8>

// CHECK: llvm.insertvalue [[resVal3]], {{.*}}[3]
// CHECK: llvm.insertvalue [[resVal7]], {{.*}}[7]

// CHECK: llvm.return
%0 = triton_gpu.convert_layout %arg0 : tensor<128x32xf8E4M3FNUZ, #mfma> -> tensor<128x32xf8E4M3FNUZ, #dotop0>
tt.return
}
}

// -----

#mfma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>
#dotop0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
// CHECK-LABEL: mfma_dot_cvt_bf8_mfma16
tt.func public @mfma_dot_cvt_bf8_mfma16(%arg0: tensor<128x32xf8E5M2, #mfma>) {
// CHECK-NOT: store
// CHECK-NOT: load
// CHECK: rocdl.ds_bpermute
// CHECK: llvm.return
%0 = triton_gpu.convert_layout %arg0 : tensor<128x32xf8E5M2, #mfma> -> tensor<128x32xf8E5M2, #dotop0>
tt.return
}
}
Loading
Loading