diff --git a/python/test/unit/language/test_compile_errors.py b/python/test/unit/language/test_compile_errors.py
index 12c3997ec7c4..0128385824b5 100644
--- a/python/test/unit/language/test_compile_errors.py
+++ b/python/test/unit/language/test_compile_errors.py
@@ -7,22 +7,7 @@ import triton.language as tl
 from triton.compiler.errors import CompilationError, CompileTimeAssertionFailure
 import traceback
-
-
-def is_interpreter():
-    return os.environ.get('TRITON_INTERPRET', '0') == '1'
-
-
-def is_cuda():
-    return not is_interpreter() and triton.runtime.driver.active.get_current_target().backend == "cuda"
-
-
-def is_hip():
-    return not is_interpreter() and triton.runtime.driver.active.get_current_target().backend == "hip"
-
-
-def is_on_mi300():
-    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942')
+from triton._internal_testing import is_interpreter, is_cuda, is_hip, is_hip_mi300
 
 
 def test_err_undefined_variable():
@@ -367,7 +352,7 @@ def test_fp8_support(dtype):
         if cc >= (8, 9):
             supported_dtypes.append(tl.float8e4nv)
     elif is_hip():
-        if is_on_mi300():
+        if is_hip_mi300():
             supported_dtypes += [tl.float8e4b8, tl.float8e5b16]
     elif is_interpreter():
         supported_dtypes = [tl.float8e5, tl.float8e5b16, tl.float8e4nv, tl.float8e4b8, tl.float8e4b15]
diff --git a/python/test/unit/language/test_conversions.py b/python/test/unit/language/test_conversions.py
index 723a15fe847f..12f3e9c66fe8 100644
--- a/python/test/unit/language/test_conversions.py
+++ b/python/test/unit/language/test_conversions.py
@@ -1,24 +1,14 @@
 # fmt: off
 
-import os
 import numpy as np
 import torch
 
 import pytest
 import triton
 import triton.language as tl
 
-def is_interpreter():
-    return os.environ.get('TRITON_INTERPRET', '0') == '1'
+from triton._internal_testing import is_cuda, is_hip, is_hip_mi300
 
-def is_cuda():
-    return not is_interpreter() and triton.runtime.driver.active.get_current_target().backend == "cuda"
-
-def is_hip():
-    return not is_interpreter() and triton.runtime.driver.active.get_current_target().backend == "hip"
-
-def is_on_mi300():
-    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942')
 
 def matching_int(dtype):
     if dtype.primitive_bitwidth == 8:
@@ -283,7 +273,7 @@ def upcast_test(src_dtype, dst_dtype, exponent_bits, mantissa_bits, exponent_bia
 def test_typeconvert_upcast(src_dtype, dst_dtype, device):
     if ((src_dtype == 'float8e4nv' and is_cuda() and torch.cuda.get_device_capability(0) < (8, 9))
        or (src_dtype in ('float8e4nv', 'float8e4b15') and is_hip())
-       or (src_dtype in ('float8e4b8', 'float8e5b16') and (is_cuda() or not is_on_mi300()))):
+       or (src_dtype in ('float8e4b8', 'float8e5b16') and (is_cuda() or not is_hip_mi300()))):
         # If the dtype should error out in the given device, we assert that and return
         with pytest.raises(triton.CompilationError, match="not supported in this architecture"):
             launch_exhaustive_populate(getattr(tl, src_dtype), 0, 65536, False, 8, 0x7f, device=device)
@@ -334,7 +324,7 @@ def test_typeconvert_downcast(src_dtype, dst_dtype, rounding, max_repr, device):
     if dst_dtype in ('float8e5', 'float8e4nv') and rounding == 'rtne' and (is_hip() or torch.cuda.get_device_capability(0) < (9, 0)):
         pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on NVGPU with compute capability 9.0+")
 
-    if dst_dtype in ('float8e5b16', 'float8e4b8') and rounding == 'rtne' and (is_cuda() or not is_on_mi300()):
+    if dst_dtype in ('float8e5b16', 'float8e4b8') and rounding == 'rtne' and (is_cuda() or not is_hip_mi300()):
         pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on AMDGPU MI300")
 
     # dtype : (exponent_bits, mantissa_bits, exponent_bias)
diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index d2dbe2044236..18e22139ba4c 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -29,6 +29,7 @@
     is_cuda,
     is_interpreter,
     is_hip,
+    is_hip_cdna,
    is_hip_mi200,
     get_arch,
     torch_float8_dtypes,
@@ -3338,13 +3339,12 @@ def test_scaled_dot(M, N, K, col_a, col_b, type_a, type_b, num_warps, mma, kpack
         if cc < (8, 9):
             pytest.skip("float8e4nv not supported on CUDA < 8.9")
     if is_hip():
+        if not is_hip_cdna():
+            pytest.skip("scaled_dot only implemented for HIP CDNA")
         if (type_a not in ["e2m1", "e5m2"]) or (type_b not in ["e2m1", "e5m2", "bf16"]):
             pytest.skip(f"scaled_dot({type_a}, {type_b}) not yet implemented for HIP")
         if mma == 16 and K == 64:
             pytest.skip(f"K == {K} too small for mfma {mma} in scaled_dot")
-        arch = triton.runtime.driver.active.get_current_target().arch
-        if "gfx11" in arch or "gfx12" in arch:
-            pytest.skip("scaled_dot not yet implemented for gfx11 and gfx12")
 
     @triton.jit
     def dot_scale_kernel(a_base, stride_a0, stride_a1, a_scale, b_base, stride_b0, stride_b1, out,
diff --git a/python/test/unit/language/test_pipeliner.py b/python/test/unit/language/test_pipeliner.py
index fa5f34290b49..f92cd5892b29 100644
--- a/python/test/unit/language/test_pipeliner.py
+++ b/python/test/unit/language/test_pipeliner.py
@@ -6,22 +6,7 @@ import triton.language as tl
 import triton.tools.experimental_descriptor
 
-
-def is_cuda():
-    return triton.runtime.driver.active.get_current_target().backend == "cuda"
-
-
-def is_hopper():
-    return is_cuda() and torch.cuda.get_device_capability()[0] >= 9
-
-
-def is_hip():
-    return triton.runtime.driver.active.get_current_target().backend == "hip"
-
-
-def is_hip_mi200():
-    target = triton.runtime.driver.active.get_current_target()
-    return target.backend == 'hip' and target.arch == 'gfx90a'
+from triton._internal_testing import is_cuda, is_hopper, is_hip_cdna, is_hip_mi200
 
 
 def check_capabilities():
@@ -229,8 +214,8 @@ def __exit__(self, exc_type, exc_val, exc_tb):
 @pytest.mark.parametrize("scale", [True, False])
 def test_pipeline_matmul(scale, device):
     check_capabilities()
-    if scale and not is_cuda():
-        pytest.skip("NYI: scale_dot just implemented in CUDA")
+    if scale and not (is_cuda() or is_hip_cdna()):
+        pytest.skip("NYI: scale_dot just implemented in CUDA/HIP")
     M, N, K = 512, 512, 128
     BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, 32
     NUM_STAGES = 4
diff --git a/python/triton/_internal_testing.py b/python/triton/_internal_testing.py
index ac50f6372a42..5765357fca36 100644
--- a/python/triton/_internal_testing.py
+++ b/python/triton/_internal_testing.py
@@ -36,6 +36,10 @@ def is_cuda():
     return False if target is None else target.backend == "cuda"
 
 
+def is_hopper():
+    return is_cuda() and torch.cuda.get_device_capability()[0] >= 9
+
+
 def is_hip():
     target = get_current_target()
     return False if target is None else target.backend == "hip"
@@ -46,6 +50,15 @@ def is_hip_mi200():
     return target.backend == 'hip' and target.arch == 'gfx90a'
 
 
+def is_hip_mi300():
+    target = get_current_target()
+    return target.backend == 'hip' and target.arch in ('gfx940', 'gfx941', 'gfx942')
+
+
+def is_hip_cdna():
+    return is_hip_mi200() or is_hip_mi300()
+
+
 def get_arch():
     target = get_current_target()
     return "" if target is None else str(target.arch)