[BACKEND][LAYOUT] Use LL for AMDMfma related layout conversions (#5210)
Jokeren authored Nov 21, 2024
1 parent 9c7a8c6 commit d5ba6ac
Showing 2 changed files with 13 additions and 11 deletions.
lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp (22 changes: 11 additions & 11 deletions)
@@ -374,24 +374,24 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     // TODO (Keren): Currently, we handle general mma/blocked/slice/dot(ampere)
     // -> mma/blocked/slice/dot(ampere) conversions. The following tasks must be
     // completed before we can remove the layoutIsOK check:
-    // 1. Support for AMD's MFMA and WMMA
+    // 1. Support for AMD's WMMA
     std::function<bool(Attribute)> layoutIsOK = [&](Attribute layout) {
-      if (auto nvidiaMma = dyn_cast<NvidiaMmaEncodingAttr>(layout)) {
-        if (useLegacyMMAConversion) {
-          return false;
-        }
-        return true;
+      if (isa<NvidiaMmaEncodingAttr, AMDMfmaEncodingAttr>(layout)) {
+        return !useLegacyMMAConversion;
       }
       if (auto dotOperand = dyn_cast<DotOperandEncodingAttr>(layout)) {
-        if (auto nvidiaMma =
-                dyn_cast<NvidiaMmaEncodingAttr>(dotOperand.getParent())) {
-          if (useLegacyMMAConversion) {
-            return false;
-          }
+        auto parent = dotOperand.getParent();
+        if (isa<MmaEncodingTrait>(parent) && useLegacyMMAConversion) {
+          return false;
+        }
+        if (auto nvidiaMma = dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
           if (nvidiaMma.isAmpere()) {
             return true;
           }
         }
+        if (isa<AMDMfmaEncodingAttr>(parent)) {
+          return true;
+        }
         return false;
       }
       if (isa<BlockedEncodingAttr>(layout)) {
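
For context, the rewritten check folds AMDMfmaEncodingAttr into the same LinearLayout ("LL") path that already handled NVIDIA MMA layouts, both as a direct source/target layout and as the parent of a dot-operand layout. The sketch below restates the new predicate from the hunk above as a standalone function; it is a condensed paraphrase rather than the verbatim source, the blocked/slice cases that follow in the original lambda are elided, and it assumes the attribute types and flags already visible in ConvertLayoutOpToLLVM.cpp are in scope.

    // Minimal sketch, paraphrasing the new layoutIsOK predicate from the hunk
    // above; not the verbatim source. The blocked/slice cases that follow in
    // the original lambda are omitted here.
    static bool layoutIsOKSketch(Attribute layout, bool useLegacyMMAConversion) {
      // NVIDIA MMA and AMD MFMA layouts now share the LinearLayout path,
      // unless the legacy MMA conversion is explicitly requested.
      if (isa<NvidiaMmaEncodingAttr, AMDMfmaEncodingAttr>(layout))
        return !useLegacyMMAConversion;
      if (auto dotOperand = dyn_cast<DotOperandEncodingAttr>(layout)) {
        auto parent = dotOperand.getParent();
        // Any MMA-like parent bails out when the legacy conversion is forced.
        if (isa<MmaEncodingTrait>(parent) && useLegacyMMAConversion)
          return false;
        // NVIDIA dot operands are only supported on Ampere.
        if (auto nvidiaMma = dyn_cast<NvidiaMmaEncodingAttr>(parent))
          return nvidiaMma.isAmpere();
        // AMD MFMA dot operands are now accepted as well.
        return isa<AMDMfmaEncodingAttr>(parent);
      }
      return false;
    }

The practical effect is that MFMA-related layout conversions go through the shared LinearLayout machinery, which is what the updated mfma-shortcut.mlir test below exercises.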
test/Conversion/amd/mfma-shortcut.mlir (2 changes: 2 additions & 0 deletions)
@@ -7,6 +7,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
   tt.func public @shortcut_mfma16(%arg0: tensor<16x16xf16, #mfma>) {
     // CHECK-NOT: store
     // CHECK-NOT: load
+    // CHECK: llvm.return
     %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mfma> -> tensor<16x16xf16, #dotop>
     tt.return
   }
@@ -21,6 +22,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
   tt.func public @no_shortcut_mfma16(%arg0: tensor<16x16xf16, #mfma>) {
     // CHECK: store
     // CHECK: load
+    // CHECK: llvm.return
     %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mfma> -> tensor<16x16xf16, #dotop>
     tt.return
   }
