Skip to content

Commit

Permalink
[NFC] Add test_bessel into test_libdevice.py (#5261)
Browse files Browse the repository at this point in the history
Just a port of one of our tests. I didn't find any similar ones in
Triton itself, this should increase the test coverage.

Signed-off-by: Anatoly Myachev <[email protected]>
  • Loading branch information
anmyachev authored Nov 27, 2024
1 parent 8b29bb7 commit 2ea9daa
Showing 1 changed file with 34 additions and 0 deletions.
34 changes: 34 additions & 0 deletions python/test/unit/language/test_libdevice.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,45 @@
import pytest
import torch

import triton
import triton.language as tl

from triton.language.extra import libdevice
from triton.language.extra.libdevice import fast_dividef as my_fast_dividef


@pytest.mark.parametrize("dtype_str", ["float32", "float64"])
@pytest.mark.parametrize(
"libdevice_fn, torch_special_fn",
[
("j0", "bessel_j0"),
("j1", "bessel_j1"),
("y0", "bessel_y0"),
("y1", "bessel_y1"),
("cyl_bessel_i0", "i0"),
("cyl_bessel_i1", "i1"),
],
)
def test_bessel(dtype_str, libdevice_fn, torch_special_fn, device):
SIZE = 128
dtype = getattr(torch, dtype_str)

x = torch.randn((SIZE, ), dtype=dtype, device=device)
y_exp = torch.empty((SIZE, ), dtype=dtype, device=device)
y_ref = getattr(torch.special, torch_special_fn)(x)

@triton.jit
def kernel(in_p, out_p, fn: tl.constexpr, SIZE: tl.constexpr):
off = tl.arange(0, SIZE)
x = tl.load(in_p + off)
res = getattr(libdevice, fn)(x)
tl.store(out_p + off, res)

kernel[(1, )](x, y_exp, fn=libdevice_fn, SIZE=SIZE, num_warps=4, num_ctas=1)

torch.testing.assert_close(y_ref, y_exp, equal_nan=True)


def test_libdevice_rename(device):
# mark the import as used by this test
_ = my_fast_dividef
Expand Down

0 comments on commit 2ea9daa

Please sign in to comment.