From 491376587c11ce7f66c1de09031e895162d5c344 Mon Sep 17 00:00:00 2001
From: Thomas Raoux
Date: Fri, 22 Nov 2024 12:07:12 -0800
Subject: [PATCH] [BACKEND] Fix a missed transpose optimization during refactor

---
 .../Transforms/OptimizeDotOperands.cpp |  6 +++---
 test/TritonGPU/dot-operands.mlir       | 19 +++++++++++++++++++
 2 files changed, 22 insertions(+), 3 deletions(-)

diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
index c776944a24b9..01e8acf25842 100644
--- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
@@ -148,16 +148,16 @@ class SwizzleShmemConvert : public OpRewritePattern<ConvertLayoutOp> {
   LogicalResult matchAndRewrite(ConvertLayoutOp cvtOp,
                                 PatternRewriter &rewriter) const override {
     // Match outerCvt(trans(innerCvt(x))).
-    auto trans = cvtOp.getSrc().getDefiningOp<MemDescTransOp>();
+    auto trans = cvtOp.getSrc().getDefiningOp<TransOp>();
     if (!trans || trans.getOrder() != ArrayRef<int32_t>{1, 0})
       return failure();

-    auto srcTy = dyn_cast<RankedTensorType>(trans.getSrc().getType());
+    RankedTensorType srcTy = trans.getSrc().getType();

     if (auto srcCvt = trans.getSrc().getDefiningOp<ConvertLayoutOp>()) {
       srcTy = srcCvt.getSrc().getType();
     }
-    auto sharedLoadTy = cast<RankedTensorType>(cvtOp.getType());
+    RankedTensorType sharedLoadTy = cvtOp.getType();
     auto cvtEncoding =
         dyn_cast<DotOperandEncodingAttr>(sharedLoadTy.getEncoding());
     if (!cvtEncoding)
diff --git a/test/TritonGPU/dot-operands.mlir b/test/TritonGPU/dot-operands.mlir
index 990a0b4f7a78..5e244889fb04 100644
--- a/test/TritonGPU/dot-operands.mlir
+++ b/test/TritonGPU/dot-operands.mlir
@@ -276,3 +276,22 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 :
     tt.return %r : tensor<128x64xf32, #mma>
   }
 }
+
+// -----
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}>
+#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}>
+module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
+// CHECK-LABEL: mmav2_reorder_transpose
+// CHECK: triton_gpu.local_alloc
+// CHECK: triton_gpu.memdesc_trans
+// CHECK: triton_gpu.local_load
+// CHECK: tt.dot
+  tt.func @mmav2_reorder_transpose(%t: tensor<32x128xf16, #blocked1>, %dotb: tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, %dotc: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{
+    %a = tt.trans %t {order = array<i32: 1, 0>} : tensor<32x128xf16, #blocked1> -> tensor<128x32xf16, #blocked>
+    %cv = triton_gpu.convert_layout %a : tensor<128x32xf16, #blocked> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
+    %r = tt.dot %cv, %dotb, %dotc, inputPrecision = tf32 : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma>
+    tt.return %r : tensor<128x64xf32, #mma>
+  }
+}