Generalize unit tests for different backends (#5576)
Generalize unit tests for different backends, for example by not hard-coding `device` to `cuda`.

---------

Signed-off-by: Whitney Tsang <[email protected]>
whitneywhtsang authored Jan 11, 2025
1 parent 74de6b4 commit 2efb067
Showing 7 changed files with 24 additions and 20 deletions.
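Most of the changes below follow one pattern: instead of hard-coding `cuda`, a test accepts the `device` argument that the suite supplies as a pytest fixture and uses it wherever tensors are allocated. A minimal sketch of the before/after shape (the kernel-free test bodies and the fallback fixture are illustrative, not taken from this commit):

```python
import pytest
import torch


@pytest.fixture
def device():
    # Stand-in for the fixture Triton's conftest provides; it would normally
    # report the active backend (for example cuda, xpu, or cpu).
    return "cuda" if torch.cuda.is_available() else "cpu"


# Before: only runs on NVIDIA GPUs because the device is hard-coded.
def test_add_hardcoded():
    x = torch.randn(128, device="cuda")
    torch.testing.assert_close(x + x, 2 * x)


# After: the backend is injected, so the same test runs wherever the suite does.
def test_add_generalized(device):
    x = torch.randn(128, device=device)
    torch.testing.assert_close(x + x, 2 * x)
```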
python/test/regression/test_cast_matmul.py (1 addition, 2 deletions)
@@ -85,12 +85,11 @@ def matmul_kernel(A, B, C, M, N, K, #
                          for w in input_dtypes
                          for x in input_dtypes  #
                          for o in out_dtypes])
-def test_cast_matmul(M, K, N, BLOCK_K, BLOCK_M, w_dtype, x_dtype, out_dtype):
+def test_cast_matmul(M, K, N, BLOCK_K, BLOCK_M, w_dtype, x_dtype, out_dtype, device):
     if x_dtype == w_dtype:
         pytest.skip("skip the same input dtype")
     if is_hip() and BLOCK_M == 64 and w_dtype in ["float8_e5m2", "float8_e4m3fnuz"]:
         pytest.skip("skip due to bug on HIP path")
-    device = torch.cuda.current_device()
     x_dtype: torch.dtype = getattr(torch, x_dtype)
     w_dtype: torch.dtype = getattr(torch, w_dtype)
python/test/unit/language/test_conversions.py (1 addition, 1 deletion)
@@ -333,7 +333,7 @@ def test_typeconvert_downcast(src_dtype, dst_dtype, rounding, max_repr, device):
     if src_dtype != 'float32' and is_cuda() and torch.cuda.get_device_capability(0) < (9, 0):
         pytest.skip("non-float32 downcast tests only supported on NVGPU with compute capability 9.0+")

-    if dst_dtype in ('float8e5', 'float8e4nv') and rounding == 'rtne' and (is_hip() or torch.cuda.get_device_capability(0) < (9, 0)):
+    if dst_dtype in ('float8e5', 'float8e4nv') and rounding == 'rtne' and (is_hip() or torch.cuda.is_available() and torch.cuda.get_device_capability(0) < (9, 0)):
         pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on NVGPU with compute capability 9.0+")

     if dst_dtype in ('float8e5b16', 'float8e4b8') and rounding == 'rtne' and (is_cuda() or not is_hip_mi300()):
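The added `torch.cuda.is_available()` check matters because `and` binds tighter than `or`, so the capability query only runs when a CUDA device actually exists; on a backend where `is_hip()` is false and no CUDA device is present, `get_device_capability(0)` would otherwise raise. A small sketch of the same guard pulled out into a helper (the helper name is an illustration, not part of the diff):

```python
import torch


def should_skip_rtne_downcast(on_hip: bool) -> bool:
    # Short-circuit: the capability query is only evaluated when CUDA is present,
    # so this is safe to call on machines without an NVIDIA GPU.
    return on_hip or (torch.cuda.is_available()
                      and torch.cuda.get_device_capability(0) < (9, 0))
```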
python/test/unit/language/test_core.py (12 additions, 10 deletions)
@@ -217,8 +217,7 @@ def __str__(self):


 def is_layout_applicable(layout) -> bool:
-    common_layouts = [BlockedLayout, SharedLayout]
-    if layout in common_layouts:
+    if isinstance(layout, (BlockedLayout, SharedLayout)):
         return True
     elif isinstance(layout, SliceLayout):
         return is_layout_applicable(layout.parent)
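The rewritten check also fixes a subtle bug: `layout in common_layouts` compared the layout object against the classes in the list, so it was only true when `layout` was literally one of those class objects, whereas the tests pass layout instances. `isinstance` matches instances as intended. A stand-alone illustration (the layout classes here are stubs, not Triton's real ones):

```python
class BlockedLayout:
    pass


class SharedLayout:
    pass


layout = BlockedLayout()
print(layout in [BlockedLayout, SharedLayout])             # False: equality against the classes
print(isinstance(layout, (BlockedLayout, SharedLayout)))   # True: checks the instance's type
```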
@@ -1447,6 +1446,7 @@ def kernel(X, Y, Z):
                          for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']
                          for sem in [None, 'acquire', 'release', 'acq_rel', 'relaxed']]))
 def test_atomic_rmw(op, dtype_x_str, mode, sem, device):
+    check_type_supported(dtype_x_str, device)
     if is_interpreter():
         if dtype_x_str == 'float16':
             pytest.skip("Only test atomic float16 ops on GPU")
@@ -1523,6 +1523,7 @@ def kernel(X):
                          for num_ctas in num_ctas_list
                          for dtype_x_str in ['float16', 'float32', 'uint64', 'int64', 'float64']])
 def test_tensor_atomic_rmw(shape, axis, num_ctas, dtype_x_str, device):
+    check_type_supported(dtype_x_str, device)
     shape0, shape1 = shape
     # triton kernel
@@ -2874,7 +2875,7 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, add_ov


 @pytest.mark.parametrize("M", [32, 64, 128, 256])
-@pytest.mark.parametrize("src_layout", layouts)
+@pytest.mark.parametrize("src_layout", filter_layouts(layouts))
 def test_store_op(M, src_layout, device, tmp_path: pathlib.Path):

     ir = f"""
@@ -3807,7 +3808,7 @@ def test_dot3d(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_

     if B == 8 and M == 64 and in_dtype_str == "float32" and out_dtype_str == "float32":
         if not is_interpreter() and triton.runtime.driver.active.utils.get_device_properties(
-                torch.cuda.current_device())["max_shared_mem"] < 131072:
+                triton.runtime.driver.active.get_current_device())["max_shared_mem"] < 131072:
             pytest.skip(
                 "Skipping tests with B = 8, M = 64, in_type = float32, out_type = float32 due to insufficient shared memory (less than 128 KB per SM) on this GPU."
             )
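`triton.runtime.driver.active.get_current_device()` asks whichever backend driver Triton selected, so the shared-memory check no longer assumes CUDA. The same query, isolated into a helper (a sketch built only from the calls visible in this diff):

```python
import triton


def active_max_shared_mem() -> int:
    # Query the device through the active Triton driver rather than torch.cuda,
    # so the same code path works on CUDA, HIP, XPU, and other backends.
    driver = triton.runtime.driver.active
    device = driver.get_current_device()
    return driver.utils.get_device_properties(device)["max_shared_mem"]
```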
@@ -6550,7 +6551,7 @@ def gather_test_kernel(src_ptr, idx_ptr, out_ptr, axis: tl.constexpr, src_dim0:
     ([128, 64], [256, 64], 0),
     ([128, 64], [128, 128], 1),
 ])
-def test_gather(src_shape, indices_shape, axis):
+def test_gather(src_shape, indices_shape, axis, device):

     def triton_gather(src: torch.Tensor, axis: int, indices: torch.Tensor):
         output = torch.empty(indices.shape, dtype=src.dtype, device=src.device)
@@ -6562,8 +6563,8 @@ def triton_gather(src: torch.Tensor, axis: int, indices: torch.Tensor):

         return output

-    src = torch.randn(src_shape, device='cuda')
-    indices = torch.randint(0, src.shape[axis], indices_shape, device='cuda')
+    src = torch.randn(src_shape, device=device)
+    indices = torch.randint(0, src.shape[axis], indices_shape, device=device)
     ref = torch.gather(src, axis, indices)
     result = triton_gather(src, axis, indices)
     torch.testing.assert_close(result, ref, rtol=0, atol=0)
@@ -6580,7 +6581,8 @@ def triton_gather(src: torch.Tensor, axis: int, indices: torch.Tensor):
     "linear<{register = [[0, 2], [32, 0], [0, 32], [2, 0], [0, 16], [64, 0], [128, 0]], lane = [[0, 8], [8, 0], [1, 0], [4, 0], [16, 0]], warp = [[0, 1], [0, 4]], block = []}>"
     ),
 ])
-def test_gather_warp_shuffle(src_shape, indices_shape, axis, src_layout, indices_layout, tmp_path: pathlib.Path):
+def test_gather_warp_shuffle(src_shape, indices_shape, axis, src_layout, indices_layout, tmp_path: pathlib.Path,
+                             device):
     if is_hip():
         pytest.skip("warp-local gather has issues on HIP")

@@ -6623,8 +6625,8 @@ def inject_layout(ir, src: torch.Tensor, axis, indices: torch.Tensor, src_layout
        \1 = ttg.convert_layout %out : tensor<""" + output_spec + r""", #idx_layout> -> tensor<""" + output_spec + r""", \6>"""
         return re.sub(pat, repl, ir)

-    src = torch.randn(src_shape, device='cuda')
-    indices = torch.randint(0, src.shape[axis], indices_shape, device='cuda')
+    src = torch.randn(src_shape, device=device)
+    indices = torch.randint(0, src.shape[axis], indices_shape, device=device)
     ref = torch.gather(src, axis, indices)

     output, compiled = prepare_kernel(src, axis, indices)
python/test/unit/language/test_random.py (1 addition, 1 deletion)
@@ -203,7 +203,7 @@ def kernel(X, seed):

     x = torch.empty(1, dtype=torch.float32, device=device)
     with pytest.raises(triton.compiler.errors.CompilationError):
-        seed0 = torch.zeros(1, dtype=torch.int32, device="cuda")
+        seed0 = torch.zeros(1, dtype=torch.int32, device=device)
         kernel[(1, )](x, seed0)
     with pytest.raises(triton.compiler.errors.CompilationError):
         seed1 = 2.3
python/test/unit/language/test_tuple.py (4 additions, 4 deletions)
@@ -25,7 +25,7 @@ def _tuple_index(_0, Ptrs, _1: tl.constexpr, values, _2, _3: tl.constexpr, _4):


 @pytest.mark.parametrize("size", [0, 1, 2, 3, 4])
-def test_index(size, device="cuda"):
+def test_index(size, device):
     vals = tuple([i + 1 for i in range(size)])
     rets = tuple([torch.zeros((1, ), dtype=torch.float32, device=device) for _ in vals])
     _tuple_index[(1, )](0, rets, 0, vals, 0, 0, 0)
@@ -51,7 +51,7 @@ def _tuple_assign(XPtrs, YPtrs, values):
     tl.store(Y[2], y[2])


-def test_assign(device="cuda"):
+def test_assign(device):
     vals = (2., 3.)
     x = tuple([torch.zeros((1, ), dtype=torch.float32, device=device) for _ in range(2)])
     y = tuple([torch.zeros((1, ), dtype=torch.float32, device=device) for _ in range(3)])
@@ -91,7 +91,7 @@ def _tuple_serialize(Ptr, N1, tuple1, cst1: tl.constexpr, val1, tuple2):
     _tuple_fn0(Ptr, 15, (-1, None, tuple1))


-def test_serialize(device="cuda"):
+def test_serialize(device):
     x0 = torch.tensor([8], dtype=torch.int32, device=device)
     x1 = torch.tensor([12], dtype=torch.int32, device=device)
     y0 = torch.tensor([10], dtype=torch.int32, device=device)
@@ -133,7 +133,7 @@ def _namedtuple_kernel(closure, _X, Y, BLOCK_M: tl.constexpr, BLOCK_N: tl.conste
     tl.store(Ys, y, mask=_namedtuple_mask_func(Y, BLOCK_M, BLOCK_N))


-def test_namedtuple(device="cuda"):
+def test_namedtuple(device):
     x = torch.randn((32, 32), dtype=torch.float32, device=device)
     y = torch.empty((16, 16), dtype=torch.float32, device=device)
     a = torch.tensor([5.2], dtype=torch.float32, device=device)
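Dropping the `="cuda"` defaults is what actually activates the fixture: pytest does not inject fixtures into parameters that already have a default value, so with the default in place these tests always ran on `cuda` regardless of the configured backend. A minimal illustration (assumes a `device` fixture is defined in conftest):

```python
import torch


def test_with_default(device="cuda"):
    # pytest treats a parameter with a default as a plain argument, not a fixture
    # request, so this always allocates on "cuda".
    assert torch.zeros(1, device=device).numel() == 1


def test_with_fixture(device):
    # No default, so pytest resolves `device` as a fixture and the test follows
    # the backend selected for the run.
    assert torch.zeros(1, device=device).numel() == 1
```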
python/test/unit/runtime/test_autotuner.py (3 additions, 0 deletions)
@@ -11,6 +11,9 @@ def do_bench(kernel_call, quantiles):

 @pytest.mark.parametrize('use_cuda_graph', [False, True])
 def test_kwargs(use_cuda_graph: bool, device: str):
+    if use_cuda_graph and not torch.cuda.is_available():
+        pytest.xfail("CUDA is not available")
+
     M, N = 1024, 16
     src = torch.randn(M * N, device=device)
     dst = torch.empty(M * N, device=device)
python/test/unit/runtime/test_cache.py (2 additions, 2 deletions)
@@ -577,7 +577,7 @@ def cache_hook(*args, **kwargs):
     assert pointer_range_32 == [(0, )]


-def test_function_arguments():
+def test_function_arguments(device):

     @triton.jit
     def func1():
@@ -601,7 +601,7 @@ def kernel(Y, fn: tl.constexpr, fn_args):

     JITFunction.cache_hook = None
     JITFunction.compiled_hook = None
-    y = torch.zeros((5, ), dtype=torch.int32, device="cuda")
+    y = torch.zeros((5, ), dtype=torch.int32, device=device)
     kernel[(1, )](y[0], func1, tuple())
     kernel[(1, )](y[1], func2, tuple())
     kernel[(1, )](y[2], func3, (3, ))
