[Frontend] [BC breaking] Implement PyTorch/JAX/NumPy 2.0 typecast semantics for scalars (triton-lang#4613)

The idea here is that if you have a tensor `t` of dtype `uint8` and you want to do `t << 2`, the result should be of dtype `uint8`, not `int32`! We do this for all dunder ops that don't output booleans. This roughly follows the semantics of PyTorch, JAX, and NumPy 2.0. I would like to document this behaviour, but it's not clear to me where the best place to do so is.

The PR has much more churn than I would like, as I had to move the `to_tensor` method to `semantic` (which is where it belongs anyway). For reviewers, the only two relevant changes are in `computation_type_impl` and in `bitwise_op_type_checking_impl`, where we now perform casting for bitwise ops.
Showing 10 changed files with 373 additions and 206 deletions.
Triton mostly follows the semantics of NumPy, with minor exceptions. In this document, we go over some of the array computing features supported in Triton, and we cover the cases where Triton's semantics deviate from those of NumPy.

Type Promotion
==============

**Type Promotion** occurs when tensors of different data types are used in an operation. For binary operations associated with `dunder methods <https://docs.python.org/3/reference/datamodel.html#emulating-numeric-types>`_, and for the last two arguments of the ternary function ``tl.where``, Triton automatically converts the input tensors to a common data type following a hierarchy of kinds (sets of dtypes): ``{bool} < {integral dtypes} < {floating point dtypes}``.

The algorithm is as follows (a short sketch follows the list):

1. **Kind** If one tensor is of a dtype of a higher kind, the other tensor is promoted to this dtype: ``(int32, bfloat16) -> bfloat16``

2. **Width** If both tensors are of dtypes of the same kind, and one of them is of a higher width, the other one is promoted to this dtype: ``(float32, float16) -> float32``

3. **Supremum** If both tensors are of the same width and signedness but different dtypes, they are both promoted to the next larger dtype: ``(float16, bfloat16) -> float32``

3.1 If both tensors are of different ``fp8`` dtypes, they are both cast to ``float16``.

4. **Prefer unsigned** Otherwise (same width, different signedness), they are promoted to the unsigned dtype: ``(int32, uint32) -> uint32``

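These tensor-tensor rules can be checked at compile time with ``tl.static_assert``. The following is a minimal sketch, not part of the Triton docs or test suite; the kernel name is hypothetical and it assumes a recent Triton:

.. code-block:: python

    import triton
    import triton.language as tl

    @triton.jit
    def _promotion_examples():  # hypothetical kernel, for illustration only
        i32 = tl.zeros((4,), dtype=tl.int32)
        u32 = tl.zeros((4,), dtype=tl.uint32)
        f16 = tl.zeros((4,), dtype=tl.float16)
        bf16 = tl.zeros((4,), dtype=tl.bfloat16)
        f32 = tl.zeros((4,), dtype=tl.float32)

        tl.static_assert((i32 + bf16).dtype == tl.bfloat16)  # 1. kind
        tl.static_assert((f32 + f16).dtype == tl.float32)    # 2. width
        tl.static_assert((f16 + bf16).dtype == tl.float32)   # 3. supremum
        tl.static_assert((i32 + u32).dtype == tl.uint32)     # 4. prefer unsigned
        tl.static_assert(tl.where(i32 > 0, f16, f32).dtype == tl.float32)

Launching it on a trivial grid, e.g. ``_promotion_examples[(1,)]()``, should compile the kernel and evaluate the assertions.
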
The rules are a bit different when the operation involves a scalar. By "scalar" here we mean a numeric literal, a variable marked with ``tl.constexpr``, or a combination of these. These are represented by NumPy scalars and have types ``bool``, ``int``, and ``float``.

When an operation involves a tensor and a scalar (see the sketch after this list):

1. If the scalar is of a kind lower than or equal to that of the tensor, it does not participate in the promotion: ``(uint8, int) -> uint8``

2. If the scalar is of a higher kind, we choose the lowest dtype in which it fits among ``int32`` < ``uint32`` < ``int64`` < ``uint64`` for ints and ``float32`` < ``float64`` for floats. Then both the tensor and the scalar are promoted to this dtype: ``(int16, 4.0) -> float32``

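These scalar rules are what make ``t << 2`` on a ``uint8`` tensor stay ``uint8``. A minimal sketch in the same style as above, with the same assumptions and a hypothetical kernel name:

.. code-block:: python

    import triton
    import triton.language as tl

    @triton.jit
    def _scalar_promotion_examples():  # hypothetical, for illustration only
        u8 = tl.zeros((4,), dtype=tl.uint8)
        i16 = tl.zeros((4,), dtype=tl.int16)

        tl.static_assert((u8 << 2).dtype == tl.uint8)      # 1. int scalar defers to the tensor
        tl.static_assert((i16 + 4.0).dtype == tl.float32)  # 2. float scalar lifts both to float32
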
Broadcasting
============

**Broadcasting** allows operations on tensors of different shapes by automatically expanding their shapes to a compatible size without copying the data. It follows these rules (a short sketch follows the list):

1. If one of the tensor shapes is shorter, pad it on the left with ones until both tensors have the same number of dimensions: ``((3, 4), (5, 3, 4)) -> ((1, 3, 4), (5, 3, 4))``

2. Two dimensions are compatible if they are equal, or if one of them is 1. A dimension of 1 is expanded to match the dimension of the other tensor: ``((1, 3, 4), (5, 3, 4)) -> ((5, 3, 4), (5, 3, 4))``

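A minimal sketch of both rules, with the same assumptions and a hypothetical kernel name as above:

.. code-block:: python

    import triton
    import triton.language as tl

    @triton.jit
    def _broadcast_examples():  # hypothetical, for illustration only
        a = tl.zeros((3, 4), dtype=tl.float32)
        b = tl.zeros((5, 3, 4), dtype=tl.float32)
        c = a + b  # (3, 4) is padded to (1, 3, 4), then expanded to (5, 3, 4)
        tl.static_assert(c.shape[0] == 5)
        tl.static_assert(c.shape[1] == 3)
        tl.static_assert(c.shape[2] == 4)
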
Differences with NumPy
======================

**C rounding in integer division** For efficiency, operators in Triton follow C semantics rather than Python semantics. As such, ``int // int`` implements `rounding towards zero as in C <https://en.wikipedia.org/wiki/Modulo#In_programming_languages>`_ for integers of mixed signs, rather than rounding towards minus infinity as in Python. For the same reason, the modulus operator ``int % int`` (defined as ``a % b = a - b * (a // b)``) also follows C semantics rather than Python semantics.

Perhaps confusingly, integer division and modulus follow Python semantics for computations where all the inputs are scalars.
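
A small sketch of the difference; the kernel name is hypothetical, and ``q_ptr``/``r_ptr`` are assumed to point to single ``int32`` values:

.. code-block:: python

    import triton
    import triton.language as tl

    @triton.jit
    def _division_examples(q_ptr, r_ptr):  # hypothetical, for illustration only
        idx = tl.arange(0, 1)
        t = tl.full((1,), -7, dtype=tl.int32)
        tl.store(q_ptr + idx, t // 2)  # C semantics: truncates toward zero -> -3
        tl.store(r_ptr + idx, t % 2)   # a - b * (a // b) = -7 - 2 * (-3) -> -1

In plain Python, and per the note above for all-scalar computations in Triton, ``-7 // 2`` is ``-4`` and ``-7 % 2`` is ``1``.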