[BACKEND] Add folder for addptr(ptr, 0) -> ptr (#5166)
I noticed this rather obvious pattern was missing. It might come up, for example, if you have an expression like:
```python
ptrs = ptr + y_stride * tl.arange(0, YBLOCK)[:, None]
```
and `YBLOCK` is set to 1 during autotuning.
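
A minimal MLIR sketch of that situation (the function name, shapes, and `f16` element type are invented for illustration; only the fold pattern itself comes from this commit): with `YBLOCK = 1`, `y_stride * tl.arange(0, 1)[:, None]` is a splat of zeros, so the `tt.addptr` contributes nothing and the new folder lets the canonicalizer drop it.
```mlir
// Hypothetical IR for the expression above when YBLOCK = 1.
tt.func @yblock_is_one(%base: tensor<1x1x!tt.ptr<f16>>) -> tensor<1x1x!tt.ptr<f16>> {
  // y_stride * tl.arange(0, 1)[:, None] reduces to an all-zero offset tensor.
  %zero_off = arith.constant dense<0> : tensor<1x1xi32>
  // addptr(ptr, 0): the new folder lets canonicalization replace %ptrs with %base.
  %ptrs = tt.addptr %base, %zero_off : tensor<1x1x!tt.ptr<f16>>, tensor<1x1xi32>
  tt.return %ptrs : tensor<1x1x!tt.ptr<f16>>
}
```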
peterbell10 authored Nov 15, 2024
1 parent 03b807d commit 9883a9b
Showing 4 changed files with 39 additions and 2 deletions.
1 change: 1 addition & 0 deletions include/triton/Dialect/Triton/IR/TritonOps.td
@@ -211,6 +211,7 @@ def TT_AddPtrOp : TT_Op<"addptr",
let results = (outs TT_PtrLike:$result);

let assemblyFormat = "$ptr `,` $offset attr-dict `:` type($result) `,` type($offset)";
let hasFolder = 1;
}

def TT_AdvanceOp : TT_Op<"advance",
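For context: `let hasFolder = 1;` in the ODS definition makes TableGen declare a `fold(FoldAdaptor)` hook on `tt.addptr`; the hook itself is implemented in `Ops.cpp` in the next hunk.
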
9 changes: 9 additions & 0 deletions lib/Dialect/Triton/IR/Ops.cpp
@@ -848,6 +848,15 @@ void MakeTensorPtrOp::build(OpBuilder &builder, OperationState &state,
builder.getDenseI32ArrayAttr(order));
}

//-- AddPtrOp --
OpFoldResult AddPtrOp::fold(FoldAdaptor adaptor) {
// addptr(ptr, 0) -> ptr
if (matchPattern(adaptor.getOffset(), m_Zero())) {
return getPtr();
}
return {};
}

//-- AdvanceOp --
OpFoldResult AdvanceOp::fold(FoldAdaptor adaptor) {
// advance(ptr, 0, 0) -> ptr
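Note that `m_Zero()` matches both a scalar integer zero and a `dense<0>` splat, so this single folder covers tensor-of-pointer and scalar-pointer `addptr`, as the two new test cases below exercise.
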
29 changes: 28 additions & 1 deletion test/Triton/canonicalize.mlir
@@ -11,6 +11,8 @@ tt.func @dead_load(%ptr: tensor<32x128x!tt.ptr<f16>>) {
tt.return
}

// -----

// CHECK-LABEL: make_range
tt.func @make_range() -> (tensor<128x1xi32>, tensor<1xi32>) {
// CHECK-DAG: %[[c:.*]] = arith.constant dense<0> : tensor<128x1xi32>
@@ -25,6 +27,32 @@ tt.func @make_range() -> (tensor<128x1xi32>, tensor<1xi32>) {
tt.return %c, %d : tensor<128x1xi32>, tensor<1xi32>
}

// -----

// CHECK-LABEL: fold_addptr
tt.func @fold_addptr(%arg: tensor<64x64x!tt.ptr<f16>>) -> (tensor<64x64x!tt.ptr<f16>>) {
// CHECK-NOT: tt.addptr
// CHECK-NOT: arith.constant
// CHECK: tt.return %arg
%c0_i32 = arith.constant dense<0> : tensor<64x64xi32>
%0 = tt.addptr %arg, %c0_i32 : tensor<64x64x!tt.ptr<f16>>, tensor<64x64xi32>
tt.return %0 : tensor<64x64x!tt.ptr<f16>>
}

// -----

// CHECK-LABEL: fold_addptr_scalar
tt.func @fold_addptr_scalar(%arg: !tt.ptr<f16>) -> (!tt.ptr<f16>) {
// CHECK-NOT: tt.addptr
// CHECK-NOT: arith.constant
// CHECK: tt.return %arg
%c0_i32 = arith.constant 0 : i32
%0 = tt.addptr %arg, %c0_i32 : !tt.ptr<f16>, i32
tt.return %0 : !tt.ptr<f16>
}

// -----

// CHECK-LABEL: fold_advance
tt.func @fold_advance(%arg: !tt.ptr<tensor<64x64xf16>>) -> (!tt.ptr<tensor<64x64xf16>>) {
%c0_i32 = arith.constant 0 : i32
@@ -34,7 +62,6 @@ tt.func @fold_advance(%arg: !tt.ptr<tensor<64x64
tt.return %0 : !tt.ptr<tensor<64x64xf16>>
}


// -----

#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
2 changes: 1 addition & 1 deletion test/TritonGPU/loop-pipeline-hopper.mlir
@@ -617,7 +617,7 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 :
module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
// CHECK-LABEL: async_following_sync
tt.func @async_following_sync(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}) -> (tensor<128x64xf32, #mma>, tensor<128x16xf32, #mma1>) {
%cst = arith.constant dense<0> : tensor<64x16xi32, #blocked>
%cst = arith.constant dense<64> : tensor<64x16xi32, #blocked>
%c0_i32 = arith.constant 0 : i32
%cst_0 = arith.constant dense<0> : tensor<1x16xi32, #blocked>
%cst_1 = arith.constant dense<0> : tensor<128x1xi32, #blocked1>
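In this hunk the old `dense<0>` constant becomes `dense<64>`, presumably so that the new folder does not fold away the `tt.addptr` this pipelining test depends on, which would invalidate what the test checks.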
