[Frontend] [BC breaking] Implement PyTorch/JAX/NumPy 2.0 typecast semantics for scalars (triton-lang#4613)

The idea here is that if you have a tensor `t` of dtype `uint8` and you want to do `t << 2`, the result should be of dtype `uint8`, not `int32`!

We do this for all dunder ops that don't output booleans.

This follows roughly the semantics of PyTorch, JAX and NumPy 2.0.
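
A minimal sketch of the new behaviour (a hypothetical kernel, not part of this PR, assuming the usual `tl.full`/`tl.static_assert` APIs):

```python
import triton
import triton.language as tl

@triton.jit
def shift_keeps_uint8():
    t = tl.full((16, ), 1, dtype=tl.uint8)
    # The int literal 2 no longer promotes the result to int32.
    tl.static_assert((t << 2).dtype == tl.uint8)
```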

I would like to document this behaviour, but it's not clear to me where the best place to say so is.

The PR has much more churn than I would like, as I had to move the `to_tensor` method to `semantic` (which is where it belongs anyway). For reviewers, the only two relevant changes are in `computation_type_impl` and in `bitwise_op_type_checking_impl`, where we specify that we do perform casting for bitwise ops.
lezcano authored Sep 6, 2024
1 parent b933f0f commit e14ee2d
Showing 10 changed files with 373 additions and 206 deletions.
2 changes: 2 additions & 0 deletions docs/index.rst
@@ -25,6 +25,7 @@ Python API
- :doc:`triton <python-api/triton>`
- :doc:`triton.language <python-api/triton.language>`
- :doc:`triton.testing <python-api/triton.testing>`
- :doc:`Triton semantics <python-api/triton-semantics>`


.. toctree::
@@ -35,6 +36,7 @@ Python API
python-api/triton
python-api/triton.language
python-api/triton.testing
python-api/triton-semantics


Triton MLIR Dialects and Ops
44 changes: 44 additions & 0 deletions docs/python-api/triton-semantics.rst
@@ -0,0 +1,44 @@
Triton mostly follows the semantics of NumPy, with minor exceptions. In this document, we go over some of the array computing features supported in Triton, and we cover the exceptions where Triton's semantics deviate from those of NumPy.

Type Promotion
==============

**Type Promotion** occurs when tensors of different data types are used in an operation. For binary operations associated with `dunder methods <https://docs.python.org/3/reference/datamodel.html#emulating-numeric-types>`_ and for the last two arguments of the ternary function ``tl.where``, Triton automatically converts the input tensors to a common data type following a hierarchy of kinds (sets of dtypes): ``{bool} < {integral dtypes} < {floating point dtypes}``.

The algorithm is as follows:

1. **Kind** If one tensor is of a dtype of a higher kind, the other tensor is promoted to this dtype: ``(int32, bfloat16) -> bfloat16``

2. **Width** If both tensors are of dtypes of the same kind, and one of them is of a higher width, the other one is promoted to this dtype: ``(float32, float16) -> float32``

3. **Supremum** If both tensors are of the same width and signedness but different dtypes, they are both promoted to the next larger dtype: ``(float16, bfloat16) -> float32``

3.1 If both tensors are of different ``fp8`` dtypes, they are both cast to ``float16``.

4. **Prefer unsigned** Otherwise (same width, different signedness), they are promoted to the unsigned dtype: ``(int32, uint32) -> uint32``
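
As a rough illustration (a hypothetical kernel, assuming the standard ``tl.full`` and ``tl.static_assert`` APIs), the first three rules can be checked at compile time:

.. code-block:: python

    import triton
    import triton.language as tl

    @triton.jit
    def promotion_rules_sketch():
        i32 = tl.full((16, ), 1, dtype=tl.int32)
        bf16 = tl.full((16, ), 1, dtype=tl.bfloat16)
        f16 = tl.full((16, ), 1, dtype=tl.float16)
        f32 = tl.full((16, ), 1, dtype=tl.float32)
        tl.static_assert((i32 + bf16).dtype == tl.bfloat16)  # 1. Kind
        tl.static_assert((f32 + f16).dtype == tl.float32)    # 2. Width
        tl.static_assert((f16 + bf16).dtype == tl.float32)   # 3. Supremum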

The rules are a bit different when they involve a scalar. By scalar we mean a numeric literal, a variable marked with ``tl.constexpr``, or a combination of these. These are represented by NumPy scalars and have types ``bool``, ``int`` and ``float``.

When an operation involves a tensor and a scalar:

1. If the scalar is of a kind lower than or equal to that of the tensor, it does not participate in the promotion: ``(uint8, int) -> uint8``

2. If the scalar is of a higher kind, we choose the lowest dtype in which it fits among ``int32`` < ``uint32`` < ``int64`` < ``uint64`` for ints and ``float32`` < ``float64`` for floats. Then, both the tensor and the scalar are promoted to this dtype: ``(int16, 4.0) -> float32``
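
A short sketch of these scalar rules (again a hypothetical kernel using ``tl.full`` and ``tl.static_assert``):

.. code-block:: python

    @triton.jit
    def scalar_rules_sketch():
        u8 = tl.full((16, ), 1, dtype=tl.uint8)
        i16 = tl.full((16, ), 1, dtype=tl.int16)
        tl.static_assert((u8 + 2).dtype == tl.uint8)       # rule 1: the int scalar does not promote uint8
        tl.static_assert((u8 << 2).dtype == tl.uint8)      # rule 1 also applies to shifts
        tl.static_assert((i16 + 4.0).dtype == tl.float32)  # rule 2: the float scalar fits in float32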


Broadcasting
============

**Broadcasting** allows operations on tensors of different shapes by automatically expanding their shapes to a compatible size without copying the data. It follows these rules:

1. If one of the tensor shapes is shorter, pad it on the left with ones until both tensors have the same number of dimensions: ``((3, 4), (5, 3, 4)) -> ((1, 3, 4), (5, 3, 4))``

2. Two dimensions are compatible if they are equal, or if one of them is 1. A dimension of 1 will be expanded to match the corresponding dimension of the other tensor: ``((1, 3, 4), (5, 3, 4)) -> ((5, 3, 4), (5, 3, 4))``
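
A sketch of both rules (hypothetical kernel; shapes are kept to powers of two, as Triton block shapes require):

.. code-block:: python

    @triton.jit
    def broadcasting_sketch():
        x = tl.full((1, 8), 1.0, dtype=tl.float32)
        y = tl.full((4, 8), 2.0, dtype=tl.float32)
        w = tl.full((8, ), 3.0, dtype=tl.float32)
        z = x + y   # rule 2: the size-1 dimension of x is expanded to 4
        tl.static_assert(z.shape[0] == 4)
        tl.static_assert(z.shape[1] == 8)
        v = y + w   # rule 1: w is first padded on the left to shape (1, 8)
        tl.static_assert(v.shape[0] == 4)
        tl.static_assert(v.shape[1] == 8)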


Differences with NumPy
======================

**C rounding in integer division** Operators in Triton follow C semantics rather than Python semantics for efficiency. As such, ``int // int`` implements `rounding towards zero as in C <https://en.wikipedia.org/wiki/Modulo#In_programming_languages>`_ for integers of mixed signs, rather than rounding towards minus infinity as in Python. For the same reason, the modulus operator ``int % int`` (which is defined as ``a % b = a - b * (a // b)``) also follows C semantics rather than Python semantics.

Perhaps confusingly, integer division and modulus follow Python semantics for computations where all the inputs are scalars.
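
A worked example of the difference (the tensor results are shown as comments, since they come from device computations):

.. code-block:: python

    # Python semantics, also used by Triton when every operand is a scalar:
    assert -7 // 2 == -4
    assert -7 % 2 == 1
    # Triton tensor semantics follow C and round towards zero:
    #   int32(-7) // int32(2)  ->  -3
    #   int32(-7) %  int32(2)  ->  -1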
137 changes: 96 additions & 41 deletions python/test/unit/language/test_core.py
@@ -1,4 +1,5 @@
# flake8: noqa: F821,F841
import contextlib
import itertools
import re
from typing import Optional
@@ -18,11 +19,16 @@
from triton.language.extra import libdevice

from triton._internal_testing import (
integral_dtypes,
int_dtypes,
uint_dtypes,
float_dtypes,
dtypes,
dtypes_with_bfloat16,
is_cuda,
is_interpreter,
is_hip,
get_arch,
torch_float8_dtypes,
torch_dtypes,
numpy_random,
@@ -32,29 +38,14 @@
)


def is_interpreter():
return os.environ.get('TRITON_INTERPRET', '0') == '1'


def get_current_target():
if is_interpreter():
return None
return triton.runtime.driver.active.get_current_target()


def is_cuda():
target = get_current_target()
return False if target is None else target.backend == "cuda"


def is_hip():
target = get_current_target()
return False if target is None else target.backend == "hip"


def get_arch():
target = get_current_target()
return "" if target is None else str(target.arch)
@contextlib.contextmanager
def promotion_numpy_2_0():
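# Temporarily switch NumPy to its 2.0 "weak" scalar promotion rules so that
# reference results computed against NumPy scalars match the semantics under
# test; the previous promotion state is restored on exit.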
state = np._get_promotion_state()
np._set_promotion_state("weak")
try:
yield
finally:
np._set_promotion_state(state)


# TODO: enable multiple cta cluster testing.
@@ -276,7 +267,7 @@ def _binary_op_dtype_override(a: str, b: str) -> Optional[np.dtype]:


def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda', num_ctas=1,
y_low=None, y_high=None, test_broadcast=True):
y_low=None, y_high=None, filter_y=None, test_broadcast=True, test_scalar=True):
check_type_supported(dtype_x, device) # early return if dtype_x is not supported
check_type_supported(dtype_y, device)
SIZE = 128
@@ -306,45 +297,92 @@ def kernel_broadcast_rhs(Z, X, Y, SIZE: tl.constexpr):
z = GENERATE_TEST_HERE
tl.store(Z + off, z)

@triton.jit
def kernel_scalar_rhs(Z, X, y: tl.constexpr, SIZE: tl.constexpr):
off = tl.arange(0, SIZE)
x = tl.load(X + off)
z = GENERATE_TEST_HERE
tl.store(Z + off, z)

replacements = {'GENERATE_TEST_HERE': expr}
kernel = patch_kernel(kernel, replacements)
kernel_broadcast_lhs = patch_kernel(kernel_broadcast_lhs, replacements)
kernel_broadcast_rhs = patch_kernel(kernel_broadcast_rhs, replacements)
kernel_scalar_rhs = patch_kernel(kernel_scalar_rhs, replacements)

# inputs
rs = RandomState(17)
x = numpy_random(SIZE, dtype_str=dtype_x, rs=rs)
y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs, low=y_low, high=y_high)
if filter_y:
y[filter_y(y)] = 1
if mode_x == 'nan':
x[:] = float('nan')
if mode_y == 'nan':
y[:] = float('nan')

def do_test(x, y, kernel_fn):
# reference result
z_ref = eval(expr if numpy_expr is None else numpy_expr)
x_is_scalar = isinstance(x, (bool, int, float))
y_is_scalar = isinstance(y, (bool, int, float))
scalar_test = x_is_scalar or y_is_scalar

# For scalars, we follow the NumPy 2.0 (and JAX/PyTorch pretty much) casting rules.
if scalar_test:
# We remove any explicit casting
pattern = r'\.astype\(np\.\w+\)'
scalar_expr = expr if numpy_expr is None else re.sub(pattern, '', numpy_expr)
with promotion_numpy_2_0():
z_ref = eval(scalar_expr)
else:
z_ref = eval(expr if numpy_expr is None else numpy_expr)

dtype_z = _binary_op_dtype_override(dtype_x, dtype_y)
if dtype_z is not None:
if not scalar_test and dtype_z is not None:
z_ref = z_ref.astype(dtype_z)

# triton result
x_tri = to_triton(x, device=device, dst_type=dtype_x)
y_tri = to_triton(y, device=device, dst_type=dtype_y)
x_tri = x if x_is_scalar else to_triton(x, device=device, dst_type=dtype_x)
y_tri = y if y_is_scalar else to_triton(y, device=device, dst_type=dtype_y)
z_tri = to_triton(np.empty(SIZE, dtype=z_ref.dtype), device=device)
kernel_fn[(1, )](z_tri, x_tri, y_tri, SIZE=SIZE, num_warps=4, num_ctas=num_ctas)
err_msg = f"{expr}, {kernel_fn.__name__}"
np.testing.assert_allclose(z_ref, to_numpy(z_tri), err_msg=err_msg, atol=1e-3, rtol=0.01)
np.testing.assert_allclose(z_ref, to_numpy(z_tri), err_msg=err_msg, atol=3e-3, rtol=0.01)

def get_scalar(x, dtype, low, high, filter):
# If dtype is int, don't choose a huge number for the scalar
# as it'll overflow easily when converted to the other dtype
if dtype in integral_dtypes:
# Choose in range [-7, 7] ([0, 7] for uints)
low_x = 0 if dtype in uint_dtypes else -7
if low is not None:
low_x = max(low_x, low)
high_x = 7
if high is not None:
high_x = min(high_x, high)
scalar = numpy_random((), dtype_str=dtype, rs=rs, low=low_x, high=high_x).item()
if filter and filter(scalar):
# https://xkcd.com/221/
scalar = 4
else:
scalar = x.flat[0].item()
return scalar

do_test(x, y, kernel)
if mode_y != 'nan' and test_scalar:
if dtype_x in uint_dtypes:
low = 0 if y_low is None else max(y_low, 0)
else:
low = y_low
y_scalar = get_scalar(y, dtype_y, low, y_high, filter_y)
do_test(x, y_scalar, kernel_scalar_rhs)
if test_broadcast:
do_test(x[:1].reshape(()), y, kernel_broadcast_lhs)
do_test(x, y[:1].reshape(()), kernel_broadcast_rhs)


def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
# The result of x % y is ill-conditioned if x % y is much smaller than x.
# pytorch/CUDA has slightly different (probably better) rounding on
# remainders than stock LLVM. We currently don't expect to match it
# bit-for-bit.
# FIXME For large x, we are casting x to a floating point where it does not fit
# For small y, we are computing floor(div(float(x), y)) which may not fit
return (dtype_x, dtype_y) in [
('int32', 'bfloat16'),
('int32', 'float16'),
@@ -386,7 +424,7 @@ def test_dtype_codegen():
])
@pytest.mark.parametrize("num_ctas", num_ctas_list)
def test_bin_op(dtype_x, dtype_y, op, num_ctas, device):
expr = f' x {op} y'
expr = f'x {op} y'
if op == '%' and dtype_x in int_dtypes + uint_dtypes and dtype_y in int_dtypes + uint_dtypes:
# LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders.
numpy_expr = 'np.fmod(x, y)'
@@ -410,11 +448,25 @@ def test_bin_op(dtype_x, dtype_y, op, num_ctas, device):
with pytest.raises(triton.TritonError, match='Cannot use .* because they have different signedness'):
_test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas)
else:
# skip when bfloat16, as NumPy's ref performs the computation in float32
# while Triton performs it in bfloat16
# We also skip mod when it is ill-conditioned
skip_scalar_test = ((dtype_x == "bfloat16" and "float" in dtype_y)
or (expr == "x % y" and dtype_x in int_dtypes + uint_dtypes and dtype_y in float_dtypes
and _mod_operation_ill_conditioned(dtype_x, "float32")))
# can't divide by zero
not_zero = op in ('/', '%') and dtype_x in integral_dtypes and dtype_y in integral_dtypes
# can't represent -int(max)
not_minus_one = op in ('*', '/') and dtype_x in int_dtypes and dtype_y in int_dtypes
if not_zero or not_minus_one:
filter_y = lambda y: not_zero * (y == 0) | not_minus_one * (y == -1)
else:
filter_y = None
_test_binary(
dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas,
# fails with values where fmod(x, y) is roughly zero, but happens to
# pass with the random values chosen for non-broadcast tests
test_broadcast=(op != "%"))
test_broadcast=(op != "%"), filter_y=filter_y, test_scalar=not skip_scalar_test)


@pytest.mark.interpreter
@@ -454,7 +506,13 @@ def test_floordiv(dtype_x, dtype_y, num_ctas, device):
# reference result for //.
expr = 'x // y'
numpy_expr = '((x - np.fmod(x, y)) / y)'
_test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas)
# can't represent -int(max)
not_minus_one = dtype_x in int_dtypes and dtype_y in int_dtypes
if not_minus_one:
filter_y = lambda y: y == -1
else:
filter_y = None
_test_binary(dtype_x, dtype_y, expr, numpy_expr, filter_y=filter_y, device=device, num_ctas=num_ctas)


def test_unsigned_name_mangling(device):
@@ -519,10 +577,7 @@ def test_bitwise_op(dtype_x, dtype_y, op, num_ctas, device):

@pytest.mark.interpreter
@pytest.mark.parametrize("dtype_x, dtype_y, op", [ #
(dtype_x, dtype_y, op)
for op in ['<<', '>>']
for dtype_x in int_dtypes + uint_dtypes
for dtype_y in int_dtypes + uint_dtypes
(dtype_x, dtype_y, op) for op in ['<<', '>>'] for dtype_x in int_dtypes + uint_dtypes for dtype_y in uint_dtypes
])
@pytest.mark.parametrize("num_ctas", num_ctas_list)
def test_shift_op(dtype_x, dtype_y, op, num_ctas, device):
33 changes: 28 additions & 5 deletions python/triton/_internal_testing.py
@@ -1,3 +1,4 @@
import os
import re
import numpy as np
import torch
@@ -11,13 +12,39 @@

int_dtypes = ['int8', 'int16', 'int32', 'int64']
uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64']
integral_dtypes = int_dtypes + uint_dtypes
float_dtypes = ['float16', 'float32', 'float64']
dtypes = int_dtypes + uint_dtypes + float_dtypes
dtypes = integral_dtypes + float_dtypes
dtypes_with_bfloat16 = dtypes + ['bfloat16']
torch_float8_dtypes = ['float8_e4m3fn', 'float8_e5m2']
torch_dtypes = ['bool'] + int_dtypes + ['uint8'] + float_dtypes + ['bfloat16']


def is_interpreter():
return os.environ.get('TRITON_INTERPRET', '0') == '1'


def get_current_target():
if is_interpreter():
return None
return triton.runtime.driver.active.get_current_target()


def is_cuda():
target = get_current_target()
return False if target is None else target.backend == "cuda"


def is_hip():
target = get_current_target()
return False if target is None else target.backend == "hip"


def get_arch():
target = get_current_target()
return "" if target is None else str(target.arch)


def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None, low=None, high=None):
"""
Override `rs` if you're calling this function twice and don't want the same
@@ -89,10 +116,6 @@ def to_numpy(x):
raise ValueError(f"Not a triton-compatible tensor: {x}")


def is_cuda():
return triton.runtime.driver.active.get_current_target().backend == "cuda"


def supports_tma():
return is_cuda() and torch.cuda.get_device_capability()[0] >= 9
