Address review comments
zhanglx13 committed Oct 30, 2024
1 parent 9c598bd commit ca22491
Showing 2 changed files with 59 additions and 15 deletions.
60 changes: 52 additions & 8 deletions test/TritonGPU/amd/amd-sched-2nd-load.mlir
@@ -2,6 +2,15 @@

// Check the logic of sched-2nd-load optimizations
//

+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [0, 1]}>
+#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}>
+#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}>
+#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}>
+#dotOp0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>
+#dotOp1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>

// Category 1: Single dot with two loads, we make sure the optimization is applied when tile size is large enough
// The following tile sizes should apply the optimization
// 256x256x128
@@ -11,13 +20,6 @@
// 256x256x32
//

-#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}>
-#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [0, 1]}>
-#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}>
-#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}>
-#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}>
-#dotOp0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>
-#dotOp1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>
// Should apply: tile size 256x256x128 with single dot
// CHECK-LABEL: sink_2nd_load_256x256x128
// CHECK: %[[tileA:.*]] = tt.load
@@ -147,7 +149,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
// CHECK-NEXT: tt.dot
// CHECK-NEXT: triton_gpu.local_store %[[tileA]]
module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
-tt.func public @sink_2nd_load_128x128x128_user_before_dot(%A_ptr: tensor<128x128x!tt.ptr<f16>, #blocked>, %B_ptr: tensor<128x128x!tt.ptr<i64>, #blocked>, %B_ptr2: tensor<128x128x!tt.ptr<f16>, #blocked>, %C_ptr: tensor<128x128x!tt.ptr<f32>, #mma>, %A_LDS: !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !tt.memdesc<128x128xf16, #shared1, #triton_gpu.shared_memory, mutable>) {
+tt.func public @sink_2nd_load_128x128x128_user_before_dot(%A_ptr: tensor<128x128x!tt.ptr<f16>, #blocked>, %B_ptr: tensor<128x128x!tt.ptr<i64>, #blocked>, %B_ptr2: tensor<128x128x!tt.ptr<f16>, #blocked>, %C_ptr: tensor<128x128x!tt.ptr<f32>, #mma>, %A_LDS: !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !tt.memdesc<128x128xf16, #shared1, #triton_gpu.shared_memory, mutable>) {
%c0 = arith.constant 0 : i32
%c1 = arith.constant 1 : i32
%cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma>
Expand All @@ -165,3 +167,45 @@ tt.func public @sink_2nd_load_128x128x128_user_before_dot(%A_ptr: tensor<128x128
tt.return
}
}


// -----

// Category 3: two dots in the for loop. Make sure the optimization is not applied
// should NOT apply: two dots
// CHECK-LABEL: sink_2nd_load_256x256x64_two_dot
// CHECK: triton_gpu.local_load
// CHECK-NEXT: triton_gpu.local_load
// CHECK-NEXT: tt.dot
// CHECK-NEXT: tt.dot
// CHECK-NEXT: tt.load
// CHECK-NEXT: tt.load
// CHECK-NEXT: triton_gpu.local_store
// CHECK-NEXT: triton_gpu.local_store
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [0, 1]}>
#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}>
#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}>
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}>
#dotOp0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>
#dotOp1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>
module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
tt.func public @sink_2nd_load_256x256x64_two_dot(%A_ptr: tensor<256x64x!tt.ptr<f16>, #blocked>, %B_ptr: tensor<64x256x!tt.ptr<f16>, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr<f32>, #mma>, %A_LDS: !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable>, %B_LDS: !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable>) {
%c0 = arith.constant 0 : i32
%c1 = arith.constant 1 : i32
%cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>
%0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 {
%1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x64xf16, #dotOp0>
%2 = triton_gpu.local_load %B_LDS : !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<64x256xf16, #dotOp1>
%3 = tt.dot %1, %2, %arg1 : tensor<256x64xf16, #dotOp0> * tensor<64x256xf16, #dotOp1> -> tensor<256x256xf32, #mma>
%6 = tt.dot %1, %2, %3 : tensor<256x64xf16, #dotOp0> * tensor<64x256xf16, #dotOp1> -> tensor<256x256xf32, #mma>
%4 = tt.load %A_ptr : tensor<256x64x!tt.ptr<f16>, #blocked>
%5 = tt.load %B_ptr : tensor<64x256x!tt.ptr<f16>, #blocked1>
triton_gpu.local_store %4, %A_LDS : tensor<256x64xf16, #blocked> -> !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable>
triton_gpu.local_store %5, %B_LDS : tensor<64x256xf16, #blocked1> -> !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable>
scf.yield %3 : tensor<256x256xf32, #mma>
}
tt.store %C_ptr, %0#0: tensor<256x256x!tt.ptr<f32>, #mma>
tt.return
}
}
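A minimal sketch, not taken from the diff above: it adapts the Category 3 loop body to the single-dot case, reusing its value names and types, and assumes the earlier global-load scheduling step has already hoisted both tt.loads above the tt.dot (the placement that sinkSecondLoad's legality check expects). The sink-2nd-load rewrite then only moves the second tt.load down so it sits immediately before the tt.dot.

// Schematic single-dot loop body before sink-2nd-load (both global loads already above the dot):
%4 = tt.load %A_ptr : tensor<256x64x!tt.ptr<f16>, #blocked>
%5 = tt.load %B_ptr : tensor<64x256x!tt.ptr<f16>, #blocked1>   // the "2nd load"
%1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x64xf16, #dotOp0>
%2 = triton_gpu.local_load %B_LDS : !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<64x256xf16, #dotOp1>
%3 = tt.dot %1, %2, %arg1 : tensor<256x64xf16, #dotOp0> * tensor<64x256xf16, #dotOp1> -> tensor<256x256xf32, #mma>
triton_gpu.local_store %4, %A_LDS : tensor<256x64xf16, #blocked> -> !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable>
triton_gpu.local_store %5, %B_LDS : tensor<64x256xf16, #blocked1> -> !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable>
// After sinkSecondLoad (single dot, tile large enough, first user of %5 is after the dot),
// only %5 moves; the resulting order is:
//   tt.load %A_ptr, local_load %A_LDS, local_load %B_LDS, tt.load %B_ptr, tt.dot, local_store x2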
@@ -313,35 +313,35 @@ static void scheduleGlobalLoadLocalStore(ModuleOp m) {
*/
static void sinkSecondLoad(ModuleOp m) {
m.walk([&](scf::ForOp forOp) -> void {
-SetVector<Operation *> loadOps;
-Operation *dotOp;
+SetVector<triton::LoadOp> loadOps;
+triton::DotOp dotOp;
for (Operation &op : forOp) {
if (auto loadOp = dyn_cast<triton::LoadOp>(&op))
loadOps.insert(loadOp);
if (auto curOp = dyn_cast<triton::DotOp>(&op))
-dotOp = &op;
+dotOp = curOp;
}
// Only apply the optimization when there are 2 loads in the loop
if (loadOps.size() != 2)
return;
// Only apply the optimization when tile size is large enough
// 1. nonKDim >= 128
// 2. kDim >= 64
-auto ldAOp = dyn_cast<triton::LoadOp>(loadOps[0]);
+auto ldAOp = loadOps[0];
auto tileAShape = cast<RankedTensorType>(ldAOp.getType()).getShape();
-auto ldBOp = dyn_cast<triton::LoadOp>(loadOps[1]);
+auto ldBOp = loadOps[1];
auto tileBShape = cast<RankedTensorType>(ldBOp.getType()).getShape();
if (!(tileAShape[0] >= 128 && tileAShape[1] >= 64 && tileBShape[1] >= 128))
return;
// Only apply the optimization when the moving is legal
// 1. Make sure the 2nd loadOp is before the dot
// 2. Make sure the first user of the 2nd loadOp is after the dot.
-bool isBeforeDotOp = loadOps[1]->isBeforeInBlock(dotOp);
+bool isBeforeDotOp = ldBOp->isBeforeInBlock(dotOp);
auto firstUser = *ldBOp.getResult().getUsers().begin();
bool firstUserAfterDotOp = dotOp->isBeforeInBlock(firstUser);
if (isBeforeDotOp && firstUserAfterDotOp)
// move ldBOp right before tt.dot
-loadOps[1]->moveBefore(dotOp);
+ldBOp->moveBefore(dotOp);
});
}
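
A minimal sketch, not part of the diff: a standalone helper mirroring the legality check in sinkSecondLoad above. It assumes the same MLIR and Triton headers already used by the pass; the name canSinkSecondLoad is hypothetical.

// Hypothetical helper mirroring sinkSecondLoad's legality check: the load may be
// moved to sit immediately before the dot only if it currently appears before the
// dot and its first user appears after the dot, so the move cannot place the load
// after one of its uses.
static bool canSinkSecondLoad(triton::LoadOp ldBOp, triton::DotOp dotOp) {
  // Extra guard, not present in the pass itself: a load with no users has no
  // "first user" to inspect.
  if (ldBOp.getResult().use_empty())
    return false;
  bool isBeforeDotOp = ldBOp->isBeforeInBlock(dotOp);
  Operation *firstUser = *ldBOp.getResult().getUsers().begin();
  bool firstUserAfterDotOp = dotOp->isBeforeInBlock(firstUser);
  return isBeforeDotOp && firstUserAfterDotOp;
}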

