
Commit

[BACKEND] Fix a missed transpose optimization during refactor (#5236)
ThomasRaoux authored Nov 22, 2024
1 parent 16ce143 commit 340cbc6
Showing 2 changed files with 22 additions and 3 deletions.
6 changes: 3 additions & 3 deletions lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
@@ -148,16 +148,16 @@ class SwizzleShmemConvert : public OpRewritePattern<ConvertLayoutOp> {
   LogicalResult matchAndRewrite(ConvertLayoutOp cvtOp,
                                 PatternRewriter &rewriter) const override {
     // Match outerCvt(trans(innerCvt(x))).
-    auto trans = cvtOp.getSrc().getDefiningOp<MemDescTransOp>();
+    auto trans = cvtOp.getSrc().getDefiningOp<TransOp>();
     if (!trans || trans.getOrder() != ArrayRef<int32_t>{1, 0})
       return failure();
 
-    auto srcTy = dyn_cast<RankedTensorType>(trans.getSrc().getType());
+    RankedTensorType srcTy = trans.getSrc().getType();
 
     if (auto srcCvt = trans.getSrc().getDefiningOp<ConvertLayoutOp>()) {
       srcTy = srcCvt.getSrc().getType();
     }
-    auto sharedLoadTy = cast<RankedTensorType>(cvtOp.getType());
+    RankedTensorType sharedLoadTy = cvtOp.getType();
     auto cvtEncoding =
         dyn_cast<DotOperandEncodingAttr>(sharedLoadTy.getEncoding());
     if (!cvtEncoding)
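
For context, the pattern being fixed rewrites a tensor-level transpose that feeds a dot-operand convert_layout into a shared-memory round trip, which is what the new test below checks for. The sketch that follows is illustrative only, pieced together from those CHECK lines; the %m/%mt names, the abbreviated !tt.memdesc types, and the #shared/#shared_t/#dot_a encodings are hypothetical placeholders for values the pattern computes during the actual rewrite.

// Matched input (schematic): a tt.trans result converted into a dot-operand layout.
%a  = tt.trans %t {order = array<i32: 1, 0>} : tensor<32x128xf16, #blocked1> -> tensor<128x32xf16, #blocked>
%cv = triton_gpu.convert_layout %a : tensor<128x32xf16, #blocked> -> tensor<128x32xf16, #dot_a>

// Rewritten form (schematic): stage the untransposed tensor in shared memory,
// transpose the memory descriptor, then load it directly in the dot-operand layout.
%m  = triton_gpu.local_alloc %t : tensor<32x128xf16, #blocked1> -> !tt.memdesc<32x128xf16, #shared>
%mt = triton_gpu.memdesc_trans %m {order = array<i32: 1, 0>} : !tt.memdesc<32x128xf16, #shared> -> !tt.memdesc<128x32xf16, #shared_t>
%cv = triton_gpu.local_load %mt : !tt.memdesc<128x32xf16, #shared_t> -> tensor<128x32xf16, #dot_a>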
19 changes: 19 additions & 0 deletions test/TritonGPU/dot-operands.mlir
@@ -276,3 +276,22 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 :
     tt.return %r : tensor<128x64xf32, #mma>
   }
 }
+
+// -----
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}>
+#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}>
+module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
+// CHECK-LABEL: mmav2_reorder_transpose
+// CHECK: triton_gpu.local_alloc
+// CHECK: triton_gpu.memdesc_trans
+// CHECK: triton_gpu.local_load
+// CHECK: tt.dot
+  tt.func @mmav2_reorder_transpose(%t: tensor<32x128xf16, #blocked1>, %dotb: tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, %dotc: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{
+    %a = tt.trans %t {order = array<i32: 1, 0>} : tensor<32x128xf16, #blocked1> -> tensor<128x32xf16, #blocked>
+    %cv = triton_gpu.convert_layout %a : tensor<128x32xf16, #blocked> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
+    %r = tt.dot %cv, %dotb, %dotc, inputPrecision = tf32 : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma>
+    tt.return %r : tensor<128x64xf32, #mma>
+  }
+}
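
The new case relies on this file's existing lit RUN line, which is outside the diff shown. For readers wanting to reproduce it locally, a typical RUN line for this kind of test is sketched below; the exact flags are an assumption about the file, not part of this commit.

// RUN line (assumed, not part of this diff): run the dot-operand optimization
// pass over each `// -----`-separated module, then verify the CHECK lines.
// RUN: triton-opt %s -split-input-file -tritongpu-optimize-dot-operands -canonicalize | FileCheck %s

Because `// -----` starts a fresh module under -split-input-file, the new mmav2_reorder_transpose case runs through the pass independently of the earlier cases, and FileCheck confirms that the tt.trans + convert_layout pair is replaced by the local_alloc / memdesc_trans / local_load sequence feeding tt.dot.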
