Skip to content

Commit

Permalink
[TritonGPU] Fix incorrect mask operand used in for loop pipeliner
Browse files Browse the repository at this point in the history
When the OOB values for a `tt.load` are non-zero, the for loop pipeliner
needs to generate an `arith.select` to mask the loaded values with the
default OOB value. However, if the load memory requires a layout change,
the wrong mask operand was being passed to the `arith.select`, causing a
shape mismatch. The fix is to just use the same mask operand of the
origianl `tt.load` op.

Fixes triton-lang#4739
  • Loading branch information
Mogball committed Nov 15, 2024
1 parent d5e06fe commit 9a9d425
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ static Operation *getFirstUseOfPipelinedLoad(Operation *loadOp) {
return firstUser;
}

static int createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc,
static int createAsyncCopy(scf::ForOp forOp, tt::LoadOp loadOp, Value alloc,
Value insertIdx, Value extractIdx,
llvm::MapVector<Operation *, LoadInfo> &loadToInfo,
int numStages, int maxClusterId) {
Expand Down Expand Up @@ -190,8 +190,10 @@ static int createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc,
Value other = loadOp.getOther();
if (other && !isZeroConst(other)) {
auto select = builder.createWithStage<arith::SelectOp>(
loc, stageForFirstUse, clusterForFirstUse, loadOp.getType(), mask,
sharedLoad.getResult(), other);
loc, stageForFirstUse, clusterForFirstUse, loadOp.getType(),
// Use the mask operand from the original load, not the one with a
// potentially transformed layout.
loadOp.getMask(), sharedLoad.getResult(), other);
result = select->getResults();
}

Expand Down
30 changes: 30 additions & 0 deletions test/TritonGPU/matmul-loop-pipeline.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// RUN: triton-opt %s -tritongpu-pipeline | FileCheck %s

#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} {

// CHECK-LABEL: @softmax_kernel
tt.func public @softmax_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32, %arg5: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
%cst = arith.constant dense<0xFF800000> : tensor<128xf32, #blocked>
%0 = tt.get_program_id x : i32
%1 = tt.get_num_programs x : i32
%2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked>
%3 = tt.splat %arg5 : i32 -> tensor<128xi32, #blocked>
// CHECK: [[MASK:%.*]] = arith.cmpi slt, {{.*}} tensor<128xi32,
%4 = arith.cmpi slt, %2, %3 : tensor<128xi32, #blocked>
// CHECK: scf.for
scf.for %arg6 = %0 to %arg4 step %1 : i32 {
%5 = tt.splat %arg1 {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.ptr<f32> -> tensor<128x!tt.ptr<f32>, #blocked>
%6 = tt.addptr %5, %2 {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<128x!tt.ptr<f32>, #blocked>, tensor<128xi32, #blocked>
// CHECK: [[RESULT:%.*]] = triton_gpu.local_load
// CHECK-NEXT: arith.select [[MASK]], [[RESULT]], %cst
%7 = tt.load %6, %4, %cst {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<128x!tt.ptr<f32>, #blocked>
%8 = tt.splat %arg0 {loop.cluster = 1 : i32, loop.stage = 1 : i32} : !tt.ptr<f32> -> tensor<128x!tt.ptr<f32>, #blocked>
%9 = tt.addptr %8, %2 {loop.cluster = 1 : i32, loop.stage = 1 : i32} : tensor<128x!tt.ptr<f32>, #blocked>, tensor<128xi32, #blocked>
tt.store %9, %7, %4 {loop.cluster = 1 : i32, loop.stage = 1 : i32} : tensor<128x!tt.ptr<f32>, #blocked>
} {tt.num_stages = 2 : i32}
tt.return
}

}

0 comments on commit 9a9d425

Please sign in to comment.