Skip to content

Commit

Permalink
[NFC] Add functional regression test for cummax with bool type (#5264)
Browse files Browse the repository at this point in the history
This kernel was obtained using PyTorch inductor some time ago.

Signed-off-by: Anatoly Myachev <[email protected]>
  • Loading branch information
anmyachev authored Nov 27, 2024
1 parent 2ea9daa commit 9e508a4
Showing 1 changed file with 37 additions and 0 deletions.
37 changes: 37 additions & 0 deletions python/test/regression/test_functional_regressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,3 +239,40 @@ def kernel(in_ptr, out_ptr):
kernel[(1, )](data, res)
ref = torch.flip(data[1:513], [0])
assert (res == ref).all()


@triton.jit
def _triton_cummax_helper_fn(arg0_0, arg0_1, arg1_0, arg1_1):
    # Combine function over two (value, index) pairs:
    #   (arg0_0, arg0_1) and (arg1_0, arg1_1).
    # The first pair is selected when its value is strictly greater, or when
    # the values are equal and its index is strictly greater; otherwise the
    # second pair is selected. Returns the selected (value, index).
    value_is_greater = arg0_0 > arg1_0
    tie_with_greater_index = (arg0_0 == arg1_0) & (arg0_1 > arg1_1)
    pick_first = value_is_greater | tie_with_greater_index
    return tl.where(pick_first, arg0_0, arg1_0), tl.where(pick_first, arg0_1, arg1_1)


def test_inductor_cummax_bool(device):
    """Check a cummax-style associative scan on boolean input against torch.cummax.

    A single-program kernel loads XBLOCK elements, converts them to ``tl.int1``,
    pairs them with their int64 offsets and scans the pairs along axis 0 with
    ``_triton_cummax_helper_fn``. The stored values and indices are compared
    with the ``values`` and ``indices`` returned by ``torch.cummax``.
    """

    @triton.jit
    def triton_(in_ptr0, out_ptr0, out_ptr1, XBLOCK: tl.constexpr):
        offsets = tl.arange(0, XBLOCK)
        loaded = tl.load(in_ptr0 + offsets).to(tl.int1)
        # The repeated conversion to int1 is kept on purpose so the kernel
        # performs the same sequence of operations as the original reproducer.
        flags = loaded.to(tl.int1)
        positions = offsets.to(tl.int64)
        scanned_flags, scanned_positions = tl.associative_scan(
            (flags, positions),
            0,
            _triton_cummax_helper_fn,
        )
        tl.store(out_ptr0 + offsets, scanned_flags)
        tl.store(out_ptr1 + offsets, scanned_positions)

    numel = 64  # element count; also passed as XBLOCK for the single program
    a = torch.randn((numel, ), device=device) > 0
    values = torch.empty((numel, ), dtype=torch.bool, device=device)
    indices = torch.empty((numel, ), dtype=torch.int64, device=device)
    ref = torch.cummax(a, dim=0)

    # One program instance covers the whole input.
    triton_[(1, )](a, values, indices, numel)

    torch.testing.assert_close(ref.values, values)
    torch.testing.assert_close(ref.indices, indices)

0 comments on commit 9e508a4

Please sign in to comment.