From 44ad13a0c545dd1e5caa1b7981524e6b9f74de5d Mon Sep 17 00:00:00 2001
From: Lixun Zhang
Date: Wed, 30 Oct 2024 21:05:03 -0500
Subject: [PATCH] [AMD] Reland sinking the 2nd tt.load after local_load's (#4935)

This PR adds more restrictions on when we should apply the sched-load
optimization, and un-reverts
https://github.com/triton-lang/triton/pull/4823.

We will only apply the optimization when all of the following are satisfied:
1. pureMatmulProblem, i.e. 1 `tt.dot` in the main loop
2. two `tt.load`s in the main loop
3. 2nd `tt.load` is ahead of the `tt.dot`
4. 1st user of 2nd `tt.load` is after the `tt.dot`
5. tile size is large enough, i.e. nonKDim >= 128 and kDim >= 64

(cherry picked from commit 4f6f76874ff623562903d5452d499cae3d40d448)
---
 .../amd/amd-reorder-instructions.mlir       | 423 ------------------
 test/TritonGPU/amd/amd-sched-2nd-load.mlir  | 211 +++++++++
 .../ReorderInstructions.cpp                 |  92 +++-
 3 files changed, 297 insertions(+), 429 deletions(-)
 create mode 100644 test/TritonGPU/amd/amd-sched-2nd-load.mlir

diff --git a/test/TritonGPU/amd/amd-reorder-instructions.mlir b/test/TritonGPU/amd/amd-reorder-instructions.mlir
index 686e5a24e8dd..5dfd0f2a5f4c 100644
--- a/test/TritonGPU/amd/amd-reorder-instructions.mlir
+++ b/test/TritonGPU/amd/amd-reorder-instructions.mlir
@@ -460,429 +460,6 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
 }
 }

-// -----
-// This test ensures that loads will not be moved across `for` loops.
-
-// CHECK-LABEL: tt.func public @_attn_bwd
-// CHECK: tt.load
-// CHECK: tt.load
-// CHECK: scf.for
-// CHECK: }
-// CHECK: scf.for
-// CHECK: }
-// Moved before the independent `tt.store` ops but not before the `for` ops.
-// CHECK: tt.load
-// CHECK: tt.load
-// CHECK: tt.load
-// CHECK: tt.load
-// CHECK: tt.load
-// CHECK: tt.load
-// CHECK: tt.store
-// CHECK: tt.store
-// CHECK: scf.for
-// CHECK: }
-// CHECK: scf.for
-// CHECK: }
-// CHECK: tt.store
-
-#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
-#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
-#blocked2 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 4], order = [0, 1]}>
-#blocked3 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}>
-#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
-#mma1 = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>
-#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}>
-#shared1 = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>
-#shared2 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1], hasLeadingOffset = false}>
-#shared3 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}>
-module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} {
-  tt.func public @_attn_bwd(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: !tt.ptr
{tt.divisibility = 16 : i32}, %arg6: !tt.ptr {tt.divisibility = 16 : i32}, %arg7: !tt.ptr {tt.divisibility = 16 : i32}, %arg8: !tt.ptr {tt.divisibility = 16 : i32}, %arg9: !tt.ptr {tt.divisibility = 16 : i32}, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32 {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32 {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { - %c-1_i32 = arith.constant -1 : i32 - %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf32, #mma> - %c128_i32 = arith.constant 128 : i32 - %c8_i32 = arith.constant 8 : i32 - %c32_i32 = arith.constant 32 : i32 - %c1_i32 = arith.constant 1 : i32 - %c0_i32 = arith.constant 0 : i32 - %c16_i32 = arith.constant 16 : i32 - %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> - %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> - %cst_2 = arith.constant dense<0.693147182> : tensor<128x64xf32, #mma> - %0 = tt.get_program_id z : i32 - %1 = arith.muli %0, %arg14 : i32 - %2 = arith.extsi %1 : i32 to i64 - %3 = arith.remsi %0, %arg13 : i32 - %4 = arith.muli %arg11, %3 : i32 - %5 = arith.divsi %0, %arg13 : i32 - %6 = arith.muli %arg10, %5 : i32 - %7 = arith.addi %4, %6 : i32 - %8 = arith.extsi %7 : i32 to i64 - %9 = tt.get_program_id x : i32 - %10 = tt.addptr %arg0, %8 : !tt.ptr, i64 - %11 = tt.addptr %arg1, %8 : !tt.ptr, i64 - %12 = tt.addptr %arg2, %8 : !tt.ptr, i64 - %13 = tt.addptr %arg4, %8 : !tt.ptr, i64 - %14 = tt.addptr %arg5, %8 : !tt.ptr, i64 - %15 = tt.addptr %arg6, %8 : !tt.ptr, i64 - %16 = tt.addptr %arg7, %8 : !tt.ptr, i64 - %17 = tt.addptr %arg8, %2 : !tt.ptr, i64 - %18 = tt.addptr %arg9, %2 : !tt.ptr, i64 - %19 = arith.muli %9, %c128_i32 : i32 - %20 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %21 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %22 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> - %23 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> - %24 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %25 = tt.splat %19 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %26 = tt.splat %19 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %27 = tt.splat %19 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> - %28 = tt.splat %19 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> - %29 = tt.splat %19 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %30 = arith.addi %25, %20 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %31 = arith.addi %26, %21 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %32 = arith.addi %27, %22 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> - %33 = arith.addi %28, %23 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> - %34 = arith.addi %29, %24 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %35 = tt.expand_dims %30 {axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xi32, #mma> - %36 = tt.expand_dims %31 {axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> 
tensor<128x1xi32, #blocked> - %37 = tt.expand_dims %32 {axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> -> tensor<128x1xi32, #mma1> - %38 = tt.splat %arg12 : i32 -> tensor<128x1xi32, #mma> - %39 = tt.splat %arg12 : i32 -> tensor<128x1xi32, #blocked> - %40 = arith.muli %35, %38 : tensor<128x1xi32, #mma> - %41 = arith.muli %36, %39 : tensor<128x1xi32, #blocked> - %42 = tt.splat %11 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> - %43 = tt.addptr %42, %41 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> - %44 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> - %45 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %46 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %47 = tt.expand_dims %44 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> -> tensor<1x64xi32, #mma> - %48 = tt.expand_dims %45 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> - %49 = tt.expand_dims %46 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> - %50 = tt.broadcast %43 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x64x!tt.ptr, #blocked> - %51 = tt.broadcast %47 : tensor<1x64xi32, #mma> -> tensor<128x64xi32, #mma> - %52 = tt.broadcast %48 : tensor<1x64xi32, #blocked> -> tensor<128x64xi32, #blocked> - %53 = tt.addptr %50, %52 : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> - %54 = tt.load %53 : tensor<128x64x!tt.ptr, #blocked> - %55 = tt.splat %12 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> - %56 = tt.addptr %55, %41 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> - %57 = tt.broadcast %56 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x64x!tt.ptr, #blocked> - %58 = tt.addptr %57, %52 : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> - %59 = tt.load %58 : tensor<128x64x!tt.ptr, #blocked> - %60 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> - %61 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> - %62 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> - %63 = tt.splat %19 : i32 -> tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> - %64 = tt.splat %19 : i32 -> tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> - %65 = arith.addi %63, %60 : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> - %66 = arith.addi %64, %62 : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> - %67 = tt.expand_dims %65 {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x16xi32, #blocked2> - %68 = tt.splat %arg12 : i32 -> tensor<1x16xi32, #blocked2> - %69 = arith.muli %67, %68 : tensor<1x16xi32, #blocked2> - %70 = tt.splat %10 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked2> - %71 = tt.addptr %70, %69 : tensor<1x16x!tt.ptr, #blocked2>, tensor<1x16xi32, #blocked2> - %72 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %73 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> - %74 = 
tt.expand_dims %72 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<64x1xi32, #blocked2> - %75 = tt.expand_dims %73 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> -> tensor<64x1xi32, #blocked3> - %76 = tt.broadcast %71 : tensor<1x16x!tt.ptr, #blocked2> -> tensor<64x16x!tt.ptr, #blocked2> - %77 = tt.broadcast %74 : tensor<64x1xi32, #blocked2> -> tensor<64x16xi32, #blocked2> - %78 = tt.addptr %76, %77 : tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16xi32, #blocked2> - %79 = tt.expand_dims %66 {axis = 1 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<16x1xi32, #blocked1> - %80 = tt.splat %arg12 : i32 -> tensor<16x1xi32, #blocked1> - %81 = arith.muli %79, %80 : tensor<16x1xi32, #blocked1> - %82 = tt.splat %13 : !tt.ptr -> tensor<16x1x!tt.ptr, #blocked1> - %83 = tt.addptr %82, %81 : tensor<16x1x!tt.ptr, #blocked1>, tensor<16x1xi32, #blocked1> - %84 = tt.broadcast %83 : tensor<16x1x!tt.ptr, #blocked1> -> tensor<16x64x!tt.ptr, #blocked1> - %85 = tt.broadcast %49 : tensor<1x64xi32, #blocked1> -> tensor<16x64xi32, #blocked1> - %86 = tt.addptr %84, %85 : tensor<16x64x!tt.ptr, #blocked1>, tensor<16x64xi32, #blocked1> - %87 = tt.splat %17 : !tt.ptr -> tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma1}>> - %88 = tt.broadcast %37 : tensor<128x1xi32, #mma1> -> tensor<128x16xi32, #mma1> - %89 = tt.splat %18 : !tt.ptr -> tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma1}>> - %90 = arith.muli %arg12, %c16_i32 : i32 - %91 = tt.splat %90 : i32 -> tensor<64x16xi32, #blocked2> - %92 = tt.splat %90 : i32 -> tensor<16x64xi32, #blocked1> - %93:5 = scf.for %arg15 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg16 = %cst_1, %arg17 = %cst_1, %arg18 = %19, %arg19 = %78, %arg20 = %86) -> (tensor<128x64xf32, #mma>, tensor<128x64xf32, #mma>, i32, tensor<64x16x!tt.ptr, #blocked2>, tensor<16x64x!tt.ptr, #blocked1>) : i32 { - %206 = tt.load %arg19 : tensor<64x16x!tt.ptr, #blocked2> - %207 = tt.splat %arg18 : i32 -> tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> - %208 = arith.addi %207, %61 : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> - %209 = tt.addptr %87, %208 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma1}>>, tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> - %210 = tt.load %209 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma1}>> - %211 = triton_gpu.local_alloc %54 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> - %212 = triton_gpu.local_load %211 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> - %213 = triton_gpu.local_alloc %206 : (tensor<64x16xf16, #blocked2>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> - %214 = triton_gpu.local_load %213 : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> - %215 = tt.dot %212, %214, %cst_0 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> -> tensor<128x16xf32, #mma1> - %216 = tt.expand_dims %210 {axis = 0 : i32} : tensor<16xf32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> -> tensor<1x16xf32, #mma1> - %217 = tt.broadcast %216 : tensor<1x16xf32, #mma1> -> tensor<128x16xf32, #mma1> - %218 = arith.subf %215, 
%217 : tensor<128x16xf32, #mma1> - %219 = math.exp2 %218 : tensor<128x16xf32, #mma1> - %220 = tt.expand_dims %208 {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> -> tensor<1x16xi32, #mma1> - %221 = tt.broadcast %220 : tensor<1x16xi32, #mma1> -> tensor<128x16xi32, #mma1> - %222 = arith.cmpi sge, %221, %88 : tensor<128x16xi32, #mma1> - %223 = arith.select %222, %219, %cst_0 : tensor<128x16xi1, #mma1>, tensor<128x16xf32, #mma1> - %224 = tt.load %arg20 : tensor<16x64x!tt.ptr, #blocked1> - %225 = arith.truncf %223 : tensor<128x16xf32, #mma1> to tensor<128x16xf16, #mma1> - %226 = triton_gpu.local_alloc %225 : (tensor<128x16xf16, #mma1>) -> !tt.memdesc<128x16xf16, #shared2, #triton_gpu.shared_memory> - %227 = triton_gpu.local_load %226 : !tt.memdesc<128x16xf16, #shared2, #triton_gpu.shared_memory> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> - %228 = triton_gpu.local_alloc %224 : (tensor<16x64xf16, #blocked1>) -> !tt.memdesc<16x64xf16, #shared3, #triton_gpu.shared_memory> - %229 = triton_gpu.local_load %228 : !tt.memdesc<16x64xf16, #shared3, #triton_gpu.shared_memory> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> - %230 = tt.dot %227, %229, %arg16 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x64xf32, #mma> - %231 = tt.addptr %89, %208 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma1}>>, tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> - %232 = tt.load %231 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma1}>> - %233 = triton_gpu.local_alloc %224 : (tensor<16x64xf16, #blocked1>) -> !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> - %234 = tt.trans %233 {order = array} : !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> - %235 = triton_gpu.local_load %234 : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> - %236 = triton_gpu.local_alloc %59 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> - %237 = triton_gpu.local_load %236 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> - %238 = tt.dot %237, %235, %cst_0 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> -> tensor<128x16xf32, #mma1> - %239 = tt.expand_dims %232 {axis = 0 : i32} : tensor<16xf32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> -> tensor<1x16xf32, #mma1> - %240 = tt.broadcast %239 : tensor<1x16xf32, #mma1> -> tensor<128x16xf32, #mma1> - %241 = arith.subf %238, %240 : tensor<128x16xf32, #mma1> - %242 = arith.mulf %223, %241 : tensor<128x16xf32, #mma1> - %243 = arith.truncf %242 : tensor<128x16xf32, #mma1> to tensor<128x16xf16, #mma1> - %244 = triton_gpu.local_alloc %206 : (tensor<64x16xf16, #blocked2>) -> !tt.memdesc<64x16xf16, #shared2, #triton_gpu.shared_memory> - %245 = tt.trans %244 {order = array} : !tt.memdesc<64x16xf16, #shared2, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared3, #triton_gpu.shared_memory> - %246 = triton_gpu.local_load %245 : !tt.memdesc<16x64xf16, #shared3, #triton_gpu.shared_memory> -> tensor<16x64xf16, 
#triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> - %247 = triton_gpu.local_alloc %243 : (tensor<128x16xf16, #mma1>) -> !tt.memdesc<128x16xf16, #shared2, #triton_gpu.shared_memory> - %248 = triton_gpu.local_load %247 : !tt.memdesc<128x16xf16, #shared2, #triton_gpu.shared_memory> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> - %249 = tt.dot %248, %246, %arg17 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x64xf32, #mma> - %250 = arith.addi %arg18, %c16_i32 : i32 - %251 = tt.addptr %arg19, %91 : tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16xi32, #blocked2> - %252 = tt.addptr %arg20, %92 : tensor<16x64x!tt.ptr, #blocked1>, tensor<16x64xi32, #blocked1> - scf.yield %230, %249, %250, %251, %252 : tensor<128x64xf32, #mma>, tensor<128x64xf32, #mma>, i32, tensor<64x16x!tt.ptr, #blocked2>, tensor<16x64x!tt.ptr, #blocked1> - } - %94 = arith.addi %19, %c128_i32 : i32 - %95 = arith.subi %arg14, %94 : i32 - %96 = arith.divsi %95, %c32_i32 : i32 - %97 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> - %98 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> - %99 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %100 = tt.splat %94 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> - %101 = tt.splat %94 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %102 = arith.addi %100, %97 : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> - %103 = arith.addi %101, %99 : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %104 = tt.expand_dims %102 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x32xi32, #blocked3> - %105 = tt.splat %arg12 : i32 -> tensor<1x32xi32, #blocked3> - %106 = arith.muli %104, %105 : tensor<1x32xi32, #blocked3> - %107 = tt.splat %10 : !tt.ptr -> tensor<1x32x!tt.ptr, #blocked3> - %108 = tt.addptr %107, %106 : tensor<1x32x!tt.ptr, #blocked3>, tensor<1x32xi32, #blocked3> - %109 = tt.broadcast %108 : tensor<1x32x!tt.ptr, #blocked3> -> tensor<64x32x!tt.ptr, #blocked3> - %110 = tt.broadcast %75 : tensor<64x1xi32, #blocked3> -> tensor<64x32xi32, #blocked3> - %111 = tt.addptr %109, %110 : tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32xi32, #blocked3> - %112 = tt.expand_dims %103 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> - %113 = tt.splat %arg12 : i32 -> tensor<32x1xi32, #blocked> - %114 = arith.muli %112, %113 : tensor<32x1xi32, #blocked> - %115 = tt.splat %13 : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked> - %116 = tt.addptr %115, %114 : tensor<32x1x!tt.ptr, #blocked>, tensor<32x1xi32, #blocked> - %117 = tt.broadcast %116 : tensor<32x1x!tt.ptr, #blocked> -> tensor<32x64x!tt.ptr, #blocked> - %118 = tt.broadcast %48 : tensor<1x64xi32, #blocked> -> tensor<32x64xi32, #blocked> - %119 = tt.addptr %117, %118 : tensor<32x64x!tt.ptr, #blocked>, tensor<32x64xi32, #blocked> - %120 = tt.splat %17 : !tt.ptr -> tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma}>> - %121 = tt.splat %18 : !tt.ptr -> tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma}>> - %122 = arith.muli %arg12, %c32_i32 : i32 - %123 = tt.splat %122 : i32 -> 
tensor<64x32xi32, #blocked3> - %124 = tt.splat %122 : i32 -> tensor<32x64xi32, #blocked> - %125:5 = scf.for %arg15 = %c0_i32 to %96 step %c1_i32 iter_args(%arg16 = %93#0, %arg17 = %93#1, %arg18 = %94, %arg19 = %111, %arg20 = %119) -> (tensor<128x64xf32, #mma>, tensor<128x64xf32, #mma>, i32, tensor<64x32x!tt.ptr, #blocked3>, tensor<32x64x!tt.ptr, #blocked>) : i32 { - %206 = tt.load %arg19 : tensor<64x32x!tt.ptr, #blocked3> - %207 = tt.splat %arg18 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> - %208 = arith.addi %207, %98 : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> - %209 = tt.addptr %120, %208 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma}>>, tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> - %210 = tt.load %209 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma}>> - %211 = triton_gpu.local_alloc %54 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> - %212 = triton_gpu.local_load %211 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> - %213 = triton_gpu.local_alloc %206 : (tensor<64x32xf16, #blocked3>) -> !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory> - %214 = triton_gpu.local_load %213 : !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> - %215 = tt.dot %212, %214, %cst : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x32xf32, #mma> - %216 = tt.expand_dims %210 {axis = 0 : i32} : tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>> -> tensor<1x32xf32, #mma> - %217 = tt.broadcast %216 : tensor<1x32xf32, #mma> -> tensor<128x32xf32, #mma> - %218 = arith.subf %215, %217 : tensor<128x32xf32, #mma> - %219 = math.exp2 %218 : tensor<128x32xf32, #mma> - %220 = tt.load %arg20 : tensor<32x64x!tt.ptr, #blocked> - %221 = arith.truncf %219 : tensor<128x32xf32, #mma> to tensor<128x32xf16, #mma> - %222 = triton_gpu.convert_layout %221 : tensor<128x32xf16, #mma> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> - %223 = triton_gpu.local_alloc %220 : (tensor<32x64xf16, #blocked>) -> !tt.memdesc<32x64xf16, #shared3, #triton_gpu.shared_memory> - %224 = triton_gpu.local_load %223 : !tt.memdesc<32x64xf16, #shared3, #triton_gpu.shared_memory> -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> - %225 = tt.dot %222, %224, %arg16 : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x64xf32, #mma> - %226 = tt.addptr %121, %208 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma}>>, tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> - %227 = tt.load %226 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma}>> - %228 = triton_gpu.local_alloc %220 : (tensor<32x64xf16, #blocked>) -> !tt.memdesc<32x64xf16, #shared, #triton_gpu.shared_memory> - %229 = tt.trans %228 {order = array} : !tt.memdesc<32x64xf16, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory> - %230 = triton_gpu.local_load %229 : !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, 
kWidth = 4}>> - %231 = triton_gpu.local_alloc %59 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> - %232 = triton_gpu.local_load %231 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> - %233 = tt.dot %232, %230, %cst : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x32xf32, #mma> - %234 = tt.expand_dims %227 {axis = 0 : i32} : tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>> -> tensor<1x32xf32, #mma> - %235 = tt.broadcast %234 : tensor<1x32xf32, #mma> -> tensor<128x32xf32, #mma> - %236 = arith.subf %233, %235 : tensor<128x32xf32, #mma> - %237 = arith.mulf %219, %236 : tensor<128x32xf32, #mma> - %238 = arith.truncf %237 : tensor<128x32xf32, #mma> to tensor<128x32xf16, #mma> - %239 = triton_gpu.local_alloc %206 : (tensor<64x32xf16, #blocked3>) -> !tt.memdesc<64x32xf16, #shared2, #triton_gpu.shared_memory> - %240 = tt.trans %239 {order = array} : !tt.memdesc<64x32xf16, #shared2, #triton_gpu.shared_memory> -> !tt.memdesc<32x64xf16, #shared3, #triton_gpu.shared_memory> - %241 = triton_gpu.local_load %240 : !tt.memdesc<32x64xf16, #shared3, #triton_gpu.shared_memory> -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> - %242 = triton_gpu.convert_layout %238 : tensor<128x32xf16, #mma> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> - %243 = tt.dot %242, %241, %arg17 : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x64xf32, #mma> - %244 = arith.addi %arg18, %c32_i32 : i32 - %245 = tt.addptr %arg19, %123 : tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32xi32, #blocked3> - %246 = tt.addptr %arg20, %124 : tensor<32x64x!tt.ptr, #blocked>, tensor<32x64xi32, #blocked> - scf.yield %225, %243, %244, %245, %246 : tensor<128x64xf32, #mma>, tensor<128x64xf32, #mma>, i32, tensor<64x32x!tt.ptr, #blocked3>, tensor<32x64x!tt.ptr, #blocked> - } - %126 = tt.splat %16 : !tt.ptr -> tensor<128x1x!tt.ptr, #mma> - %127 = tt.addptr %126, %40 : tensor<128x1x!tt.ptr, #mma>, tensor<128x1xi32, #mma> - %128 = tt.broadcast %127 : tensor<128x1x!tt.ptr, #mma> -> tensor<128x64x!tt.ptr, #mma> - %129 = tt.addptr %128, %51 : tensor<128x64x!tt.ptr, #mma>, tensor<128x64xi32, #mma> - %130 = arith.truncf %125#0 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma> - tt.store %129, %130 : tensor<128x64x!tt.ptr, #mma> - %131 = tt.splat %arg3 : f32 -> tensor<128x64xf32, #mma> - %132 = arith.mulf %125#1, %131 : tensor<128x64xf32, #mma> - %133 = tt.splat %15 : !tt.ptr -> tensor<128x1x!tt.ptr, #mma> - %134 = tt.addptr %133, %40 : tensor<128x1x!tt.ptr, #mma>, tensor<128x1xi32, #mma> - %135 = tt.broadcast %134 : tensor<128x1x!tt.ptr, #mma> -> tensor<128x64x!tt.ptr, #mma> - %136 = tt.addptr %135, %51 : tensor<128x64x!tt.ptr, #mma>, tensor<128x64xi32, #mma> - %137 = arith.truncf %132 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma> - tt.store %136, %137 : tensor<128x64x!tt.ptr, #mma> - %138 = tt.splat %10 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> - %139 = tt.addptr %138, %41 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> - %140 = tt.broadcast %139 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x64x!tt.ptr, #blocked> - %141 = tt.addptr %140, %52 : 
tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> - %142 = tt.load %141 : tensor<128x64x!tt.ptr, #blocked> - %143 = tt.splat %13 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> - %144 = tt.addptr %143, %41 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> - %145 = tt.broadcast %144 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x64x!tt.ptr, #blocked> - %146 = tt.addptr %145, %52 : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> - %147 = tt.load %146 : tensor<128x64x!tt.ptr, #blocked> - %148 = tt.splat %17 : !tt.ptr -> tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma1}>> - %149 = tt.splat %17 : !tt.ptr -> tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %150 = tt.addptr %148, %33 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma1}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> - %151 = tt.addptr %149, %34 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %152 = tt.load %150 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma1}>> - %153 = tt.load %151 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %154 = tt.expand_dims %152 {axis = 1 : i32} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> -> tensor<128x1xf32, #mma1> - %155 = tt.expand_dims %153 {axis = 1 : i32} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xf32, #mma> - %156 = tt.splat %11 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked2> - %157 = tt.addptr %156, %69 : tensor<1x16x!tt.ptr, #blocked2>, tensor<1x16xi32, #blocked2> - %158 = tt.broadcast %157 : tensor<1x16x!tt.ptr, #blocked2> -> tensor<64x16x!tt.ptr, #blocked2> - %159 = tt.addptr %158, %77 : tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16xi32, #blocked2> - %160 = tt.splat %12 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked2> - %161 = tt.addptr %160, %69 : tensor<1x16x!tt.ptr, #blocked2>, tensor<1x16xi32, #blocked2> - %162 = tt.broadcast %161 : tensor<1x16x!tt.ptr, #blocked2> -> tensor<64x16x!tt.ptr, #blocked2> - %163 = tt.addptr %162, %77 : tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16xi32, #blocked2> - %164 = tt.splat %18 : !tt.ptr -> tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma1}>> - %165 = tt.splat %18 : !tt.ptr -> tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %166 = tt.addptr %164, %33 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma1}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> - %167 = tt.addptr %165, %34 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %168 = tt.load %166 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma1}>> - %169 = tt.load %167 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma}>> - %170 = tt.broadcast %154 : tensor<128x1xf32, #mma1> -> tensor<128x16xf32, #mma1> - %171 = tt.broadcast %37 : tensor<128x1xi32, #mma1> -> tensor<128x16xi32, #mma1> - %172 = tt.expand_dims %168 {axis = 1 : i32} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> -> tensor<128x1xf32, #mma1> - %173 = tt.broadcast %172 : tensor<128x1xf32, #mma1> -> tensor<128x16xf32, #mma1> - %174 = arith.muli %arg12, %c16_i32 : i32 - %175 = tt.splat %174 : i32 -> tensor<64x16xi32, #blocked2> - %176 = triton_gpu.local_alloc : () -> !tt.memdesc<1x64x16xf16, #shared1, #triton_gpu.shared_memory, mutable> - %177:5 = scf.for %arg15 = %c0_i32 to %c8_i32 step 
%c1_i32 iter_args(%arg16 = %cst_1, %arg17 = %19, %arg18 = %159, %arg19 = %163, %arg20 = %c-1_i32) -> (tensor<128x64xf32, #mma>, i32, tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16x!tt.ptr, #blocked2>, i32) : i32 { - %206 = arith.addi %arg20, %c1_i32 : i32 - %207 = arith.cmpi slt, %206, %c1_i32 : i32 - %208 = arith.select %207, %206, %c0_i32 : i32 - %209 = tt.load %arg18 : tensor<64x16x!tt.ptr, #blocked2> - %210 = tt.load %arg19 : tensor<64x16x!tt.ptr, #blocked2> - %211 = triton_gpu.memdesc_subview %176[%208, %c0_i32, %c0_i32] : !tt.memdesc<1x64x16xf16, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %210, %211 : tensor<64x16xf16, #blocked2> -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory, mutable> - %212 = triton_gpu.local_load %211 : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> - %213 = triton_gpu.local_alloc %142 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> - %214 = triton_gpu.local_load %213 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> - %215 = triton_gpu.local_alloc %209 : (tensor<64x16xf16, #blocked2>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> - %216 = triton_gpu.local_load %215 : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> - %217 = tt.dot %214, %216, %cst_0 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> -> tensor<128x16xf32, #mma1> - %218 = arith.subf %217, %170 : tensor<128x16xf32, #mma1> - %219 = math.exp2 %218 : tensor<128x16xf32, #mma1> - %220 = tt.splat %arg17 : i32 -> tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> - %221 = arith.addi %220, %61 : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> - %222 = tt.expand_dims %221 {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> -> tensor<1x16xi32, #mma1> - %223 = tt.broadcast %222 : tensor<1x16xi32, #mma1> -> tensor<128x16xi32, #mma1> - %224 = arith.cmpi sge, %171, %223 : tensor<128x16xi32, #mma1> - %225 = arith.select %224, %219, %cst_0 : tensor<128x16xi1, #mma1>, tensor<128x16xf32, #mma1> - %226 = triton_gpu.local_alloc %147 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> - %227 = triton_gpu.local_load %226 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> - %228 = tt.dot %227, %212, %cst_0 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> -> tensor<128x16xf32, #mma1> - %229 = arith.subf %228, %173 : tensor<128x16xf32, #mma1> - %230 = arith.mulf %225, %229 : tensor<128x16xf32, #mma1> - %231 = arith.truncf %230 : tensor<128x16xf32, #mma1> to tensor<128x16xf16, #mma1> - %232 = triton_gpu.local_alloc %209 : (tensor<64x16xf16, #blocked2>) -> !tt.memdesc<64x16xf16, #shared2, #triton_gpu.shared_memory> - %233 = tt.trans %232 {order = array} : !tt.memdesc<64x16xf16, #shared2, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared3, 
#triton_gpu.shared_memory> - %234 = triton_gpu.local_load %233 : !tt.memdesc<16x64xf16, #shared3, #triton_gpu.shared_memory> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> - %235 = triton_gpu.local_alloc %231 : (tensor<128x16xf16, #mma1>) -> !tt.memdesc<128x16xf16, #shared2, #triton_gpu.shared_memory> - %236 = triton_gpu.local_load %235 : !tt.memdesc<128x16xf16, #shared2, #triton_gpu.shared_memory> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> - %237 = tt.dot %236, %234, %arg16 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x64xf32, #mma> - %238 = arith.addi %arg17, %c16_i32 : i32 - %239 = tt.addptr %arg18, %175 : tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16xi32, #blocked2> - %240 = tt.addptr %arg19, %175 : tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16xi32, #blocked2> - scf.yield %237, %238, %239, %240, %208 : tensor<128x64xf32, #mma>, i32, tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16x!tt.ptr, #blocked2>, i32 - } - triton_gpu.local_dealloc %176 : !tt.memdesc<1x64x16xf16, #shared1, #triton_gpu.shared_memory, mutable> - %178 = arith.divsi %19, %c32_i32 : i32 - %179 = arith.muli %178, %c32_i32 : i32 - %180 = arith.subi %19, %179 : i32 - %181 = tt.splat %180 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> - %182 = arith.addi %181, %97 : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> - %183 = tt.expand_dims %182 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x32xi32, #blocked3> - %184 = arith.muli %183, %105 : tensor<1x32xi32, #blocked3> - %185 = tt.splat %11 : !tt.ptr -> tensor<1x32x!tt.ptr, #blocked3> - %186 = tt.addptr %185, %184 : tensor<1x32x!tt.ptr, #blocked3>, tensor<1x32xi32, #blocked3> - %187 = tt.broadcast %186 : tensor<1x32x!tt.ptr, #blocked3> -> tensor<64x32x!tt.ptr, #blocked3> - %188 = tt.addptr %187, %110 : tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32xi32, #blocked3> - %189 = tt.splat %12 : !tt.ptr -> tensor<1x32x!tt.ptr, #blocked3> - %190 = tt.addptr %189, %184 : tensor<1x32x!tt.ptr, #blocked3>, tensor<1x32xi32, #blocked3> - %191 = tt.broadcast %190 : tensor<1x32x!tt.ptr, #blocked3> -> tensor<64x32x!tt.ptr, #blocked3> - %192 = tt.addptr %191, %110 : tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32xi32, #blocked3> - %193 = tt.broadcast %155 : tensor<128x1xf32, #mma> -> tensor<128x32xf32, #mma> - %194 = tt.expand_dims %169 {axis = 1 : i32} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xf32, #mma> - %195 = tt.broadcast %194 : tensor<128x1xf32, #mma> -> tensor<128x32xf32, #mma> - %196 = arith.muli %arg12, %c32_i32 : i32 - %197 = tt.splat %196 : i32 -> tensor<64x32xi32, #blocked3> - %198 = triton_gpu.local_alloc : () -> !tt.memdesc<1x64x32xf16, #shared1, #triton_gpu.shared_memory, mutable> - %199:4 = scf.for %arg15 = %c0_i32 to %178 step %c1_i32 iter_args(%arg16 = %177#0, %arg17 = %188, %arg18 = %192, %arg19 = %c-1_i32) -> (tensor<128x64xf32, #mma>, tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32x!tt.ptr, #blocked3>, i32) : i32 { - %206 = arith.addi %arg19, %c1_i32 : i32 - %207 = arith.cmpi slt, %206, %c1_i32 : i32 - %208 = arith.select %207, %206, %c0_i32 : i32 - %209 = tt.load %arg17 : tensor<64x32x!tt.ptr, #blocked3> - %210 = tt.load %arg18 : tensor<64x32x!tt.ptr, #blocked3> - %211 = triton_gpu.memdesc_subview %198[%208, %c0_i32, %c0_i32] : 
!tt.memdesc<1x64x32xf16, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %210, %211 : tensor<64x32xf16, #blocked3> -> !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory, mutable> - %212 = triton_gpu.local_load %211 : !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> - %213 = triton_gpu.local_alloc %142 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> - %214 = triton_gpu.local_load %213 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> - %215 = triton_gpu.local_alloc %209 : (tensor<64x32xf16, #blocked3>) -> !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory> - %216 = triton_gpu.local_load %215 : !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> - %217 = tt.dot %214, %216, %cst : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x32xf32, #mma> - %218 = arith.subf %217, %193 : tensor<128x32xf32, #mma> - %219 = math.exp2 %218 : tensor<128x32xf32, #mma> - %220 = triton_gpu.local_alloc %147 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> - %221 = triton_gpu.local_load %220 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> - %222 = tt.dot %221, %212, %cst : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x32xf32, #mma> - %223 = arith.subf %222, %195 : tensor<128x32xf32, #mma> - %224 = arith.mulf %219, %223 : tensor<128x32xf32, #mma> - %225 = arith.truncf %224 : tensor<128x32xf32, #mma> to tensor<128x32xf16, #mma> - %226 = triton_gpu.local_alloc %209 : (tensor<64x32xf16, #blocked3>) -> !tt.memdesc<64x32xf16, #shared2, #triton_gpu.shared_memory> - %227 = tt.trans %226 {order = array} : !tt.memdesc<64x32xf16, #shared2, #triton_gpu.shared_memory> -> !tt.memdesc<32x64xf16, #shared3, #triton_gpu.shared_memory> - %228 = triton_gpu.local_load %227 : !tt.memdesc<32x64xf16, #shared3, #triton_gpu.shared_memory> -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> - %229 = triton_gpu.convert_layout %225 : tensor<128x32xf16, #mma> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> - %230 = tt.dot %229, %228, %arg16 : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x64xf32, #mma> - %231 = tt.addptr %arg17, %197 : tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32xi32, #blocked3> - %232 = tt.addptr %arg18, %197 : tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32xi32, #blocked3> - scf.yield %230, %231, %232, %208 : tensor<128x64xf32, #mma>, tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32x!tt.ptr, #blocked3>, i32 - } - triton_gpu.local_dealloc %198 : !tt.memdesc<1x64x32xf16, #shared1, #triton_gpu.shared_memory, mutable> - %200 = tt.splat %14 : !tt.ptr -> tensor<128x1x!tt.ptr, #mma> - %201 = tt.addptr %200, %40 : 
tensor<128x1x!tt.ptr, #mma>, tensor<128x1xi32, #mma> - %202 = tt.broadcast %201 : tensor<128x1x!tt.ptr, #mma> -> tensor<128x64x!tt.ptr, #mma> - %203 = tt.addptr %202, %51 : tensor<128x64x!tt.ptr, #mma>, tensor<128x64xi32, #mma> - %204 = arith.mulf %199#0, %cst_2 : tensor<128x64xf32, #mma> - %205 = arith.truncf %204 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma> - tt.store %203, %205 : tensor<128x64x!tt.ptr, #mma> - tt.return - } -} - // ----- // CHECK-LABEL: sink_convert_dealloc diff --git a/test/TritonGPU/amd/amd-sched-2nd-load.mlir b/test/TritonGPU/amd/amd-sched-2nd-load.mlir new file mode 100644 index 000000000000..5c173ffb4858 --- /dev/null +++ b/test/TritonGPU/amd/amd-sched-2nd-load.mlir @@ -0,0 +1,211 @@ +// RUN: triton-opt %s -split-input-file -tritonamdgpu-reorder-instructions | FileCheck %s + +// Check the logic of sched-2nd-load optimizations +// + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [0, 1]}> +#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}> +#dotOp0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}> +#dotOp1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> + +// Category 1: Single dot with two loads, we make sure the optimization is applied when tile size is large enough +// The following tile sizes should apply the optimization +// 256x256x128 +// 256x256x64 +// The following tile sizes should NOT apply the optimization +// 256x64x128 +// 256x256x32 +// + +// Should apply: tile size 256x256x128 with single dot +// CHECK-LABEL: sink_2nd_load_256x256x128 +// CHECK: %[[tileA:.*]] = tt.load +// CHECK-NEXT: local_load +// CHECK-NEXT: local_load +// CHECK-NEXT: %[[tileB:.*]] = tt.load +// CHECK-NEXT: tt.dot +// CHECK-NEXT: triton_gpu.local_store %[[tileA]] +// CHECK-NEXT: triton_gpu.local_store %[[tileB]] +module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @sink_2nd_load_256x256x128(%A_ptr: tensor<256x128x!tt.ptr, #blocked>, %B_ptr: tensor<128x256x!tt.ptr, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr, #mma>, %A_LDS: !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !tt.memdesc<128x256xf16, #shared1, #triton_gpu.shared_memory, mutable>) { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 { + %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x128xf16, #dotOp0> + %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<128x256xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<128x256xf16, #dotOp1> + %3 = tt.dot %1, %2, %arg1 : tensor<256x128xf16, #dotOp0> * tensor<128x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> + %4 = tt.load %A_ptr : tensor<256x128x!tt.ptr, #blocked> + %5 = tt.load %B_ptr : tensor<128x256x!tt.ptr, #blocked1> + triton_gpu.local_store %4, %A_LDS : tensor<256x128xf16, #blocked> -> 
!tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %5, %B_LDS : tensor<128x256xf16, #blocked1> -> !tt.memdesc<128x256xf16, #shared1, #triton_gpu.shared_memory, mutable> + scf.yield %3 : tensor<256x256xf32, #mma> + } + tt.store %C_ptr, %0#0: tensor<256x256x!tt.ptr, #mma> + tt.return + } +} + +// Should apply: tile size 256x256x64 with single dot +// CHECK-LABEL: sink_2nd_load_256x256x64 +// CHECK: %[[tileA:.*]] = tt.load +// CHECK-NEXT: local_load +// CHECK-NEXT: local_load +// CHECK-NEXT: %[[tileB:.*]] = tt.load +// CHECK-NEXT: tt.dot +// CHECK-NEXT: triton_gpu.local_store %[[tileA]] +// CHECK-NEXT: triton_gpu.local_store %[[tileB]] +module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @sink_2nd_load_256x256x64(%A_ptr: tensor<256x64x!tt.ptr, #blocked>, %B_ptr: tensor<64x256x!tt.ptr, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr, #mma>, %A_LDS: !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable>) { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 { + %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x64xf16, #dotOp0> + %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<64x256xf16, #dotOp1> + %3 = tt.dot %1, %2, %arg1 : tensor<256x64xf16, #dotOp0> * tensor<64x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> + %4 = tt.load %A_ptr : tensor<256x64x!tt.ptr, #blocked> + %5 = tt.load %B_ptr : tensor<64x256x!tt.ptr, #blocked1> + triton_gpu.local_store %4, %A_LDS : tensor<256x64xf16, #blocked> -> !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %5, %B_LDS : tensor<64x256xf16, #blocked1> -> !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable> + scf.yield %3 : tensor<256x256xf32, #mma> + } + tt.store %C_ptr, %0#0: tensor<256x256x!tt.ptr, #mma> + tt.return + } +} + +// Should NOT apply: tile size 256x64x128 with single dot +// CHECK-LABEL: sink_2nd_load_256x64x128 +// CHECK: %[[tileA:.*]] = tt.load +// CHECK-NEXT: %[[tileB:.*]] = tt.load +// CHECK-NEXT: local_load +// CHECK-NEXT: local_load +// CHECK-NEXT: tt.dot +// CHECK-NEXT: triton_gpu.local_store %[[tileA]] +// CHECK-NEXT: triton_gpu.local_store %[[tileB]] +module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @sink_2nd_load_256x64x128(%A_ptr: tensor<256x128x!tt.ptr, #blocked>, %B_ptr: tensor<128x64x!tt.ptr, #blocked1>, %C_ptr: tensor<256x64x!tt.ptr, #mma>, %A_LDS: !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !tt.memdesc<128x64xf16, #shared1, #triton_gpu.shared_memory, mutable>) { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x64xf32, #mma> + %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x64xf32, #mma>) : i32 { + %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x128xf16, #dotOp0> + %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<128x64xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<128x64xf16, 
#dotOp1> + %3 = tt.dot %1, %2, %arg1 : tensor<256x128xf16, #dotOp0> * tensor<128x64xf16, #dotOp1> -> tensor<256x64xf32, #mma> + %4 = tt.load %A_ptr : tensor<256x128x!tt.ptr, #blocked> + %5 = tt.load %B_ptr : tensor<128x64x!tt.ptr, #blocked1> + triton_gpu.local_store %4, %A_LDS : tensor<256x128xf16, #blocked> -> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %5, %B_LDS : tensor<128x64xf16, #blocked1> -> !tt.memdesc<128x64xf16, #shared1, #triton_gpu.shared_memory, mutable> + scf.yield %3 : tensor<256x64xf32, #mma> + } + tt.store %C_ptr, %0#0: tensor<256x64x!tt.ptr, #mma> + tt.return + } +} + +// Should NOT apply: tile size 256x256x32 with single dot +// CHECK-LABEL: sink_2nd_load_256x256x32 +// CHECK: %[[tileA:.*]] = tt.load +// CHECK-NEXT: %[[tileB:.*]] = tt.load +// CHECK-NEXT: local_load +// CHECK-NEXT: local_load +// CHECK-NEXT: tt.dot +// CHECK-NEXT: triton_gpu.local_store %[[tileA]] +// CHECK-NEXT: triton_gpu.local_store %[[tileB]] +module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @sink_2nd_load_256x256x32(%A_ptr: tensor<256x32x!tt.ptr, #blocked>, %B_ptr: tensor<32x256x!tt.ptr, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr, #mma>, %A_LDS: !tt.memdesc<256x32xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !tt.memdesc<32x256xf16, #shared1, #triton_gpu.shared_memory, mutable>) { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 { + %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x32xf16, #dotOp0> + %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<32x256xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<32x256xf16, #dotOp1> + %3 = tt.dot %1, %2, %arg1 : tensor<256x32xf16, #dotOp0> * tensor<32x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> + %4 = tt.load %A_ptr : tensor<256x32x!tt.ptr, #blocked> + %5 = tt.load %B_ptr : tensor<32x256x!tt.ptr, #blocked1> + triton_gpu.local_store %4, %A_LDS : tensor<256x32xf16, #blocked> -> !tt.memdesc<256x32xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %5, %B_LDS : tensor<32x256xf16, #blocked1> -> !tt.memdesc<32x256xf16, #shared1, #triton_gpu.shared_memory, mutable> + scf.yield %3 : tensor<256x256xf32, #mma> + } + tt.store %C_ptr, %0#0: tensor<256x256x!tt.ptr, #mma> + tt.return + } +} + + +// Category 2: single dot with two loads and tile size is large enough (128x128x128). +// We make sure the move is legal. 
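+// (Per the conditions in the commit message, the sink is legal only when the
+// 2nd tt.load comes before the tt.dot and its first user comes after the
+// tt.dot; see sinkSecondLoad in ReorderInstructions.cpp.)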
+// Should NOT apply: the 2nd load has a user before the dot +// CHECK-LABEL: sink_2nd_load_128x128x128_user_before_dot +// CHECK: %[[tileA:.*]] = tt.load +// CHECK-NEXT: %[[tileB:.*]] = tt.load +// CHECK-NEXT: local_load +// CHECK-NEXT: local_load +// CHECK-NEXT: tt.store +// CHECK-NEXT: tt.dot +// CHECK-NEXT: triton_gpu.local_store %[[tileA]] +module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @sink_2nd_load_128x128x128_user_before_dot(%A_ptr: tensor<128x128x!tt.ptr, #blocked>, %B_ptr: tensor<128x128x!tt.ptr, #blocked>, %B_ptr2: tensor<128x128x!tt.ptr, #blocked>, %C_ptr: tensor<128x128x!tt.ptr, #mma>, %A_LDS: !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !tt.memdesc<128x128xf16, #shared1, #triton_gpu.shared_memory, mutable>) { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> + %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<128x128xf32, #mma>) : i32 { + %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<128x128xf16, #dotOp0> + %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<128x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<128x128xf16, #dotOp1> + %4 = tt.load %A_ptr : tensor<128x128x!tt.ptr, #blocked> + %5 = tt.load %B_ptr : tensor<128x128x!tt.ptr, #blocked> + tt.store %B_ptr, %5 : tensor<128x128x!tt.ptr, #blocked> + %3 = tt.dot %1, %2, %arg1 : tensor<128x128xf16, #dotOp0> * tensor<128x128xf16, #dotOp1> -> tensor<128x128xf32, #mma> + triton_gpu.local_store %4, %A_LDS : tensor<128x128xf16, #blocked> -> !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory, mutable> + scf.yield %3 : tensor<128x128xf32, #mma> + } + tt.store %C_ptr, %0#0: tensor<128x128x!tt.ptr, #mma> + tt.return + } +} + + +// ----- + +// Category 3: two dots in the for loop. 
Make sure the optimization is not applied +// should NOT apply: two dots +// CHECK-LABEL: sink_2nd_load_256x256x64_two_dot +// CHECK: triton_gpu.local_load +// CHECK-NEXT: triton_gpu.local_load +// CHECK-NEXT: tt.dot +// CHECK-NEXT: tt.dot +// CHECK-NEXT: tt.load +// CHECK-NEXT: tt.load +// CHECK-NEXT: triton_gpu.local_store +// CHECK-NEXT: triton_gpu.local_store +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [0, 1]}> +#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}> +#dotOp0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}> +#dotOp1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> +module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @sink_2nd_load_256x256x64_two_dot(%A_ptr: tensor<256x64x!tt.ptr, #blocked>, %B_ptr: tensor<64x256x!tt.ptr, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr, #mma>, %A_LDS: !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable>) { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 { + %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x64xf16, #dotOp0> + %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<64x256xf16, #dotOp1> + %3 = tt.dot %1, %2, %arg1 : tensor<256x64xf16, #dotOp0> * tensor<64x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> + %6 = tt.dot %1, %2, %3 : tensor<256x64xf16, #dotOp0> * tensor<64x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> + %4 = tt.load %A_ptr : tensor<256x64x!tt.ptr, #blocked> + %5 = tt.load %B_ptr : tensor<64x256x!tt.ptr, #blocked1> + triton_gpu.local_store %4, %A_LDS : tensor<256x64xf16, #blocked> -> !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %5, %B_LDS : tensor<64x256xf16, #blocked1> -> !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable> + scf.yield %3 : tensor<256x256xf32, #mma> + } + tt.store %C_ptr, %0#0: tensor<256x256x!tt.ptr, #mma> + tt.return + } +} diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp index 22349c50e308..9371c8b5f897 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp @@ -20,13 +20,15 @@ namespace ttg = mlir::triton::gpu; // Return true if the given moduleOp contains a pure matmul problem; i.e., // single dot in the main loop. 
 static bool isPureMatmulProblem(ModuleOp moduleOp) {
-  for (auto forOp : moduleOp.getOps<scf::ForOp>()) {
+  bool isMatmul = true;
+  bool foundLoop = false;
+  moduleOp.walk([&](scf::ForOp forOp) -> void {
     int counter = 0;
     forOp.walk([&counter](triton::DotOp dotOp) { ++counter; });
-    if (counter != 1)
-      return false;
-  }
-  return true;
+    isMatmul = (isMatmul && (counter == 1));
+    foundLoop = true;
+  });
+  return foundLoop && isMatmul;
 }

 // Search through block to find earliest insertion point for move op. This can
@@ -267,6 +269,82 @@ static void scheduleGlobalLoadLocalStore(ModuleOp m) {
   }
 }

+/**
+ * Sched-load optimization for matmul kernels with large tile sizes
+ * The basic idea of the sched-load optimization is to sink the 2nd tt.load
+ * after the local_load's so that global_load instructions can be interleaved
+ * with mfma's. This can help hide the issue latency of global_load
+ * instructions and improve performance on MI300X.
+ *
+ * It's assumed that the IR before this optimization has the following
+ * structure:
+ * ```mlir
+ * scf.for ..
+ * {
+ *   tileA = tt.load a_ptr
+ *   tileB = tt.load b_ptr
+ *   opA = local_load bufferA
+ *   opB = local_load bufferB
+ *   res = tt.dot opA, opB
+ *   local_store tileA, bufferA
+ *   local_store tileB, bufferB
+ * }
+ * ```
+ * After this optimization, the IR is transformed to
+ * ```mlir
+ * scf.for ..
+ * {
+ *   tileA = tt.load a_ptr
+ *   opA = local_load bufferA
+ *   opB = local_load bufferB
+ *   tileB = tt.load b_ptr  <-- 2nd tt.load is sunk here
+ *   res = tt.dot opA, opB
+ *   local_store tileA, bufferA
+ *   local_store tileB, bufferB
+ * }
+ * ```
+ * For now, we don't have a perfect heuristic for when this optimization
+ * should be applied. Therefore, we implement a simple heuristic: it is
+ * applied when the tile sizes of A and B are large enough, i.e.
+ * nonKDim >= 128 and kDim >= 64. Also, it is only applied for typical
+ * matmul kernels, i.e. exactly two tt.load's and one tt.dot inside the
+ * loop. We are experimenting with how to better control instruction
+ * scheduling and enable such optimizations.
+ */
+static void sinkSecondLoad(ModuleOp m) {
+  m.walk([&](scf::ForOp forOp) -> void {
+    SetVector<triton::LoadOp> loadOps;
+    triton::DotOp dotOp;
+    for (Operation &op : forOp) {
+      if (auto loadOp = dyn_cast<triton::LoadOp>(&op))
+        loadOps.insert(loadOp);
+      if (auto curOp = dyn_cast<triton::DotOp>(&op))
+        dotOp = curOp;
+    }
+    // Only apply the optimization when there are exactly two loads in the loop
+    if (loadOps.size() != 2)
+      return;
+    // Only apply the optimization when the tile size is large enough, i.e.
+    // 1. nonKDim >= 128
+    // 2. kDim >= 64
+    auto ldAOp = loadOps[0];
+    auto tileAShape = cast<RankedTensorType>(ldAOp.getType()).getShape();
+    auto ldBOp = loadOps[1];
+    auto tileBShape = cast<RankedTensorType>(ldBOp.getType()).getShape();
+    if (!(tileAShape[0] >= 128 && tileAShape[1] >= 64 && tileBShape[1] >= 128))
+      return;
+    // Only apply the optimization when the move is legal:
+    // 1. Make sure the 2nd loadOp is before the dot
+    // 2. Make sure the first user of the 2nd loadOp is after the dot.
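+    // (Note: the loop body was walked in program order above, so loadOps[1]
+    //  is the textually 2nd tt.load in the loop.)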
+    bool isBeforeDotOp = ldBOp->isBeforeInBlock(dotOp);
+    auto firstUser = *ldBOp.getResult().getUsers().begin();
+    bool firstUserAfterDotOp = dotOp->isBeforeInBlock(firstUser);
+    if (isBeforeDotOp && firstUserAfterDotOp)
+      // move ldBOp right before tt.dot
+      ldBOp->moveBefore(dotOp);
+  });
+}
+
 //===----------------------------------------------------------------------===//
 // Pass definition
 //===----------------------------------------------------------------------===//
@@ -288,8 +366,10 @@ struct TritonAMDGPUReorderInstructionsPass
     moveUpTranspose(m);

-    if (isPureMatmulProblem(m))
+    if (isPureMatmulProblem(m)) {
       scheduleGlobalLoadLocalStore(m);
+      sinkSecondLoad(m);
+    }
   }
 };
 } // namespace
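Note: the new lit test can be exercised in isolation by expanding its RUN
line, e.g. (assuming triton-opt and FileCheck from the build are on PATH,
and paths are relative to the repo root):

  triton-opt test/TritonGPU/amd/amd-sched-2nd-load.mlir -split-input-file \
    -tritonamdgpu-reorder-instructions \
    | FileCheck test/TritonGPU/amd/amd-sched-2nd-load.mlir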