[BACKEND] Fix and document logic for creating warp shapes in MMAv3 (#5441)

Cherry pick of #5277 for 3.2 release

Co-authored-by: Keren Zhou <[email protected]>
bertmaher and Jokeren authored Dec 17, 2024
1 parent 8af9311 commit e74f027
Showing 2 changed files with 33 additions and 2 deletions.
8 changes: 6 additions & 2 deletions lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -106,8 +106,12 @@ warpsPerTileV3(DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,
                const SmallVector<unsigned, 3> &instrShape) {
   SetVector<Operation *> slices;
   mlir::getForwardSlice(dotOp.getResult(), &slices);
-  if (llvm::find_if(slices, [](Operation *op) { return isa<DotOp>(op); }) !=
-      slices.end())
+  // Contains a chained dot. We prefer to assign warps to one axis
+  // to facilitate use cases like flash attention, allowing reductions within
+  // the same warp.
+  if (llvm::find_if(slices, [](Operation *op) {
+        return op->hasTrait<OpTrait::DotLike>();
+      }) != slices.end())
     return {(unsigned)numWarps, 1};
 
   // For MMAv3, the smallest indivisible unit of warp shape is (4, 1).
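
To make the documented behavior easier to follow outside the pass, here is a minimal standalone C++ sketch of the decision this hunk describes. It is not the real warpsPerTileV3: the names pickWarpsPerTileV3 and hasChainedDot and the greedy doubling loop after the chained-dot check are illustrative assumptions; only the chained-dot rule (assign all warps to one axis) and the (4, 1) granularity come from the comments visible in the diff.

// Standalone model of the warp-shape decision documented in the hunk above.
// Illustrative sketch only: pickWarpsPerTileV3, hasChainedDot, and the greedy
// doubling loop are assumptions; the chained-dot rule and the (4, 1)
// granularity are taken from the comments in the diff.
#include <cstdint>
#include <cstdio>
#include <utility>

using WarpShape = std::pair<unsigned, unsigned>; // (warps along M, warps along N)

WarpShape pickWarpsPerTileV3(bool hasChainedDot, int64_t M, unsigned numWarps) {
  // hasChainedDot models "some op in the forward slice of the dot result is
  // DotLike", i.e. the result eventually feeds another dot.
  if (hasChainedDot) {
    // Chained dots (e.g. flash attention): keep every warp on one axis so the
    // reduction feeding the second dot stays within a warp.
    return {numWarps, 1};
  }
  // Otherwise grow the warp tile from the indivisible (4, 1) unit, doubling
  // along M while the tile still has rows to cover, then along N
  // (an assumed stand-in for the real heuristic).
  WarpShape ret = {4, 1};
  constexpr int64_t kRowsPerWarpUnit = 16; // assumed rows covered per M-warp
  while (ret.first * ret.second < numWarps) {
    if (M > kRowsPerWarpUnit * ret.first)
      ret.first *= 2;
    else
      ret.second *= 2;
  }
  return ret;
}

int main() {
  // 8 warps, 64 output rows: a chained dot keeps all warps on M.
  auto chained = pickWarpsPerTileV3(/*hasChainedDot=*/true, /*M=*/64, /*numWarps=*/8);
  // The same shape without a chained user splits the warps across M and N.
  auto plain = pickWarpsPerTileV3(/*hasChainedDot=*/false, /*M=*/64, /*numWarps=*/8);
  std::printf("chained: [%u, %u], plain: [%u, %u]\n", chained.first,
              chained.second, plain.first, plain.second);
  return 0;
}

Under these assumptions, 8 warps yield [8, 1] for the chained (first) dot and [4, 2] for the final dot, matching the warpsPerCTA values the new lit test below checks.
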
27 changes: 27 additions & 0 deletions test/TritonGPU/accelerate-matmul.mlir
@@ -73,6 +73,33 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 :
 
 // -----
 
+// CHECK: #mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 32, 16]}>
+// CHECK: #mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 2], instrShape = [16, 64, 16]}>
+#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}>
+#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
+module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
+  // CHECK-LABEL: chained_dot
+  tt.func public @chained_dot_wgmma(
+      %arg0: tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>,
+      %arg1: tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>,
+      %arg2: tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>>) -> tensor<64x128xf32, #blocked1> {
+    %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked>
+    %cst_1 = arith.constant dense<0.000000e+00> : tensor<64x128xf32, #blocked1>
+    // CHECK: triton_nvidia_gpu.warp_group_dot {{.*}} -> tensor<64x64xf32, #mma>
+    %d = tt.dot %arg0, %arg1, %cst_0 :
+      tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x64xf32, #blocked>
+    %t = arith.truncf %d : tensor<64x64xf32, #blocked> to tensor<64x64xf16, #blocked>
+    %c = triton_gpu.convert_layout %t : tensor<64x64xf16, #blocked> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>>
+    // CHECK: triton_nvidia_gpu.warp_group_dot {{.*}} -> tensor<64x128xf32, #mma1>
+    %r = tt.dot %c, %arg2, %cst_1 :
+      tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>> -> tensor<64x128xf32, #blocked1>
+    tt.return %r : tensor<64x128xf32, #blocked1>
+  }
+}
+
+// -----
+
 // CHECK: #[[$MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 2], instrShape = [16, 8]}>
 #blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
 #blocked1 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}>