Skip to content

Commit

Permalink
[BACKEND] Fix ProgramPoint passing in AxisInfoAnalysis (#5181)
Browse files Browse the repository at this point in the history
Fixes #5122.

The `ProgramPoint`
[here](https://github.com/triton-lang/triton/blob/0bd30a2f3192204c5a50d5ffde27ad8493f6c026/lib/Analysis/AxisInfo.cpp#L1087)
is created on the stack. Then its address is
[passed](https://github.com/triton-lang/triton/blob/0bd30a2f3192204c5a50d5ffde27ad8493f6c026/lib/Analysis/AxisInfo.cpp#L1088-L1089)
to the MLIR `SparseAnalysis` code, where it is [added as a
dependency](https://github.com/llvm/llvm-project/blob/33ff9e43b4c5bdc3da31c6b11ad51d35a69bec5f/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp#L311)
and later
[dereferenced](https://github.com/llvm/llvm-project/blob/33ff9e43b4c5bdc3da31c6b11ad51d35a69bec5f/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp#L90).
By the time the `ProramPoint` is dereferenced in the
`AbstractSparseForwardDataFlowAnalysis::visit`, the
`AxisInfoAnalysis::visitForOpInductionVar` will have finished and the
`ProgramPoint` stack variable destroyed. This leads to a segfault (which
can be reproed on the base rev with the lit test added in this PR).

The code modified in this PR was originally added in #4927, in
conjunction with updating the `llvm-project` hash to `b5cc222d7429`.
However, as noted in llvm/llvm-project#110344
(the `llvm-project` PR that has made the refactoring prompting the
`AxisInfo.cpp` change in #4927):

> For dense forward data-flow analysis and other analysis (except dense
backward data-flow analysis), the program point corresponding to the
original operation can be obtained by `getProgramPointAfter(op)`

As the `AxisInfoAnalysis` (in Triton) inherits from
`SparseForwardDataFlowAnalysis` (in MLIR), in this PR we follow the
above which resolves the segfault issue (as the `ProgramPoint` is now
stored in the instance-level state of the pass).

P.S. The lit test added in this PR is not exactly minimal. However, I
did my best to minimize it starting from the 400-line repro TTGIR in
#5122. Further minimization does not seem to expose the segfault.
  • Loading branch information
aakhundov authored Nov 18, 2024
1 parent 66e8629 commit 220e51c
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 3 deletions.
6 changes: 3 additions & 3 deletions lib/Analysis/AxisInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1084,9 +1084,9 @@ LogicalResult AxisInfoAnalysis::visitOperation(

void AxisInfoAnalysis::visitForOpInductionVar(
scf::ForOp op, ArrayRef<dataflow::Lattice<AxisInfo> *> argLattices) {
ProgramPoint programPoint(op);
auto lb = getLatticeElementFor(&programPoint, op.getLowerBound())->getValue();
auto step = getLatticeElementFor(&programPoint, op.getStep())->getValue();
ProgramPoint *programPoint = getProgramPointAfter(op);
auto lb = getLatticeElementFor(programPoint, op.getLowerBound())->getValue();
auto step = getLatticeElementFor(programPoint, op.getStep())->getValue();

AxisInfo::DimVectorT knownContiguity(1, 1);
AxisInfo::DimVectorT knownDivisibility(1, 1);
Expand Down
29 changes: 29 additions & 0 deletions test/TritonGPU/coalesce.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,32 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war
tt.return
}
}

// -----

// COM: Reproducer for issue #5122
// CHECK-LABEL: @test_5122
module {
tt.func public @test_5122(%arg0: i32) attributes {noinline = false} {
%c1_i32 = arith.constant 1 : i32
%0 = arith.cmpi sgt, %arg0, %c1_i32 : i32
scf.if %0 {
%1 = scf.if %0 -> (i32) {
scf.yield %c1_i32 : i32
} else {
scf.yield %c1_i32 : i32
}
%2 = arith.cmpi sgt, %1, %c1_i32 : i32
%3 = scf.if %2 -> (i32) {
scf.yield %c1_i32 : i32
} else {
scf.yield %c1_i32 : i32
}
%4 = scf.for %arg1 = %1 to %1 step %c1_i32 iter_args(%arg2 = %3) -> (i32) : i32 {
%5 = arith.addi %arg2, %c1_i32 : i32
scf.yield %5 : i32
}
}
tt.return
}
}

0 comments on commit 220e51c

Please sign in to comment.