[AMD] Enable pipeliner test for scaled_dot (triton-lang#5068)
This commit enables the pipeliner test for scaled_dot on the AMD backend.

Along the way, it unifies some target/arch probe utilities into the common `_internal_testing` file.
antiagainst authored and Luosuu committed Nov 13, 2024
1 parent 9498cff commit db1b6ed
Showing 5 changed files with 24 additions and 51 deletions.
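
For illustration, a minimal sketch (not part of the commit) of how a test can now reuse the shared probe helpers instead of redefining them locally. The import path and helper names come from the diff below; the test function itself is hypothetical.

import pytest

from triton._internal_testing import is_cuda, is_hip_cdna


def test_scaled_dot_pipelining_smoke():
    # Mirror the gating added to test_pipeline_matmul below: scaled_dot
    # pipelining is only exercised on CUDA and on AMD CDNA (MI200/MI300).
    if not (is_cuda() or is_hip_cdna()):
        pytest.skip("NYI: scale_dot just implemented in CUDA/HIP")
    # ... actual pipelined matmul checks would go here ...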
19 changes: 2 additions & 17 deletions python/test/unit/language/test_compile_errors.py
@@ -7,22 +7,7 @@
import triton.language as tl
from triton.compiler.errors import CompilationError, CompileTimeAssertionFailure
import traceback
-
-
-def is_interpreter():
-    return os.environ.get('TRITON_INTERPRET', '0') == '1'
-
-
-def is_cuda():
-    return not is_interpreter() and triton.runtime.driver.active.get_current_target().backend == "cuda"
-
-
-def is_hip():
-    return not is_interpreter() and triton.runtime.driver.active.get_current_target().backend == "hip"
-
-
-def is_on_mi300():
-    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942')
+from triton._internal_testing import is_interpreter, is_cuda, is_hip, is_hip_mi300


def test_err_undefined_variable():
@@ -367,7 +352,7 @@ def test_fp8_support(dtype):
        if cc >= (8, 9):
            supported_dtypes.append(tl.float8e4nv)
    elif is_hip():
-        if is_on_mi300():
+        if is_hip_mi300():
            supported_dtypes += [tl.float8e4b8, tl.float8e5b16]
    elif is_interpreter():
        supported_dtypes = [tl.float8e5, tl.float8e5b16, tl.float8e4nv, tl.float8e4b8, tl.float8e4b15]
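
As an aside (not part of the diff), a sketch of how the MI300 probe typically gates the AMD-specific fp8 dtypes, in the spirit of test_fp8_support above; the helper function below is hypothetical.

import triton.language as tl

from triton._internal_testing import is_hip, is_hip_mi300


def supported_amd_fp8_dtypes():
    # Baseline fp8 support everywhere; MI300 (gfx940/gfx941/gfx942) additionally
    # exposes the AMD-native float8e4b8 and float8e5b16 formats.
    dtypes = [tl.float8e5]
    if is_hip() and is_hip_mi300():
        dtypes += [tl.float8e4b8, tl.float8e5b16]
    return dtypes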
16 changes: 3 additions & 13 deletions python/test/unit/language/test_conversions.py
@@ -1,24 +1,14 @@
# fmt: off


-import os
import numpy as np
import torch
import pytest
import triton
import triton.language as tl

-def is_interpreter():
-    return os.environ.get('TRITON_INTERPRET', '0') == '1'
+from triton._internal_testing import is_cuda, is_hip, is_hip_mi300
-
-def is_cuda():
-    return not is_interpreter() and triton.runtime.driver.active.get_current_target().backend == "cuda"
-
-def is_hip():
-    return not is_interpreter() and triton.runtime.driver.active.get_current_target().backend == "hip"
-
-def is_on_mi300():
-    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942')

def matching_int(dtype):
    if dtype.primitive_bitwidth == 8:
@@ -283,7 +273,7 @@ def upcast_test(src_dtype, dst_dtype, exponent_bits, mantissa_bits, exponent_bia
def test_typeconvert_upcast(src_dtype, dst_dtype, device):
    if ((src_dtype == 'float8e4nv' and is_cuda() and torch.cuda.get_device_capability(0) < (8, 9))
        or (src_dtype in ('float8e4nv', 'float8e4b15') and is_hip())
-        or (src_dtype in ('float8e4b8', 'float8e5b16') and (is_cuda() or not is_on_mi300()))):
+        or (src_dtype in ('float8e4b8', 'float8e5b16') and (is_cuda() or not is_hip_mi300()))):
        # If the dtype should error out in the given device, we assert that and return
        with pytest.raises(triton.CompilationError, match="not supported in this architecture"):
            launch_exhaustive_populate(getattr(tl, src_dtype), 0, 65536, False, 8, 0x7f, device=device)
@@ -334,7 +324,7 @@ def test_typeconvert_downcast(src_dtype, dst_dtype, rounding, max_repr, device):
    if dst_dtype in ('float8e5', 'float8e4nv') and rounding == 'rtne' and (is_hip() or torch.cuda.get_device_capability(0) < (9, 0)):
        pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on NVGPU with compute capability 9.0+")

-    if dst_dtype in ('float8e5b16', 'float8e4b8') and rounding == 'rtne' and (is_cuda() or not is_on_mi300()):
+    if dst_dtype in ('float8e5b16', 'float8e4b8') and rounding == 'rtne' and (is_cuda() or not is_hip_mi300()):
        pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on AMDGPU MI300")

    # dtype : (exponent_bits, mantissa_bits, exponent_bias)
6 changes: 3 additions & 3 deletions python/test/unit/language/test_core.py
@@ -29,6 +29,7 @@
    is_cuda,
    is_interpreter,
    is_hip,
+    is_hip_cdna,
    is_hip_mi200,
    get_arch,
    torch_float8_dtypes,
@@ -3338,13 +3339,12 @@ def test_scaled_dot(M, N, K, col_a, col_b, type_a, type_b, num_warps, mma, kpack
        if cc < (8, 9):
            pytest.skip("float8e4nv not supported on CUDA < 8.9")
    if is_hip():
+        if not is_hip_cdna():
+            pytest.skip("scaled_dot only implemented for HIP CDNA")
        if (type_a not in ["e2m1", "e5m2"]) or (type_b not in ["e2m1", "e5m2", "bf16"]):
            pytest.skip(f"scaled_dot({type_a}, {type_b}) not yet implemented for HIP")
        if mma == 16 and K == 64:
            pytest.skip(f"K == {K} too small for mfma {mma} in scaled_dot")
-        arch = triton.runtime.driver.active.get_current_target().arch
-        if "gfx11" in arch or "gfx12" in arch:
-            pytest.skip("scaled_dot not yet implemented for gfx11 and gfx12")

@triton.jit
def dot_scale_kernel(a_base, stride_a0, stride_a1, a_scale, b_base, stride_b0, stride_b1, out,
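
As a side note (not part of the diff), the arch-string check removed above is subsumed by the new helper: gating on is_hip_cdna() allowlists MI200/MI300 instead of denylisting gfx11/gfx12, so RDNA targets are still skipped. A rough sketch of the resulting skip logic, with a hypothetical helper name:

from triton._internal_testing import is_hip, is_hip_cdna


def scaled_dot_skip_reason():
    # Sketch of the HIP gating in test_scaled_dot after this commit.
    if not is_hip():
        return None
    if not is_hip_cdna():
        # Previously a denylist skipped "gfx11"/"gfx12" arches explicitly;
        # allowlisting CDNA (gfx90a, gfx940/941/942) also covers those targets.
        return "scaled_dot only implemented for HIP CDNA"
    return None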
21 changes: 3 additions & 18 deletions python/test/unit/language/test_pipeliner.py
@@ -6,22 +6,7 @@
import triton.language as tl
import triton.tools.experimental_descriptor

-
-def is_cuda():
-    return triton.runtime.driver.active.get_current_target().backend == "cuda"
-
-
-def is_hopper():
-    return is_cuda() and torch.cuda.get_device_capability()[0] >= 9
-
-
-def is_hip():
-    return triton.runtime.driver.active.get_current_target().backend == "hip"
-
-
-def is_hip_mi200():
-    target = triton.runtime.driver.active.get_current_target()
-    return target.backend == 'hip' and target.arch == 'gfx90a'
+from triton._internal_testing import is_cuda, is_hopper, is_hip_cdna, is_hip_mi200


def check_capabilities():
@@ -229,8 +214,8 @@ def __exit__(self, exc_type, exc_val, exc_tb):
@pytest.mark.parametrize("scale", [True, False])
def test_pipeline_matmul(scale, device):
    check_capabilities()
-    if scale and not is_cuda():
-        pytest.skip("NYI: scale_dot just implemented in CUDA")
+    if scale and not (is_cuda() or is_hip_cdna()):
+        pytest.skip("NYI: scale_dot just implemented in CUDA/HIP")
    M, N, K = 512, 512, 128
    BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, 32
    NUM_STAGES = 4
13 changes: 13 additions & 0 deletions python/triton/_internal_testing.py
@@ -36,6 +36,10 @@ def is_cuda():
    return False if target is None else target.backend == "cuda"


+def is_hopper():
+    return is_cuda() and torch.cuda.get_device_capability()[0] >= 9
+
+
def is_hip():
    target = get_current_target()
    return False if target is None else target.backend == "hip"
@@ -46,6 +50,15 @@ def is_hip_mi200():
    return target.backend == 'hip' and target.arch == 'gfx90a'


+def is_hip_mi300():
+    target = get_current_target()
+    return target.backend == 'hip' and target.arch in ('gfx940', 'gfx941', 'gfx942')
+
+
+def is_hip_cdna():
+    return is_hip_mi200() or is_hip_mi300()
+
+
def get_arch():
    target = get_current_target()
    return "" if target is None else str(target.arch)
