Skip to content

Commit

Permalink
[AMD] Support upcasting FP8E4M3NV to FP16 (triton-lang#5558)
Browse files Browse the repository at this point in the history
This commit supported upcasting FP8E4M3NV to FP16.
  • Loading branch information
knwng authored and makslevental committed Jan 13, 2025
1 parent b4b6419 commit ac4aa6e
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 2 deletions.
4 changes: 2 additions & 2 deletions python/test/unit/language/test_conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,8 @@ def upcast_test(src_dtype, dst_dtype, exponent_bits, mantissa_bits, exponent_bia
])
def test_typeconvert_upcast(src_dtype, dst_dtype, device):

# On HIP, fp8e4nv upcasting is only supported to bf16, and it's only supported on MI300.
if src_dtype == 'float8e4nv' and is_hip() and (dst_dtype != 'bfloat16' or not is_hip_mi300()):
# On HIP, fp8e4nv upcasting is only supported to bf16 and fp16, and it's only supported on MI300.
if src_dtype == 'float8e4nv' and is_hip() and (dst_dtype != 'bfloat16' or dst_dtype != 'float16' or not is_hip_mi300()):
pytest.skip(f"upcasting {src_dtype} to {dst_dtype} not supported in this architecture")

if ((src_dtype == 'float8e4nv' and is_cuda() and torch.cuda.get_device_capability(0) < (8, 9))
Expand Down
47 changes: 47 additions & 0 deletions third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,52 @@ ConverterT Fp16_to_Fp8E5M2FNUZ(AMD::ISAFamily isaFamily) {
: Fp16_to_Fp8E5M2FNUZ_SW;
}

static SmallVector<Value> Fp8E4M3FN_to_Fp16(Location loc,
ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
auto fp8x4VecTy = vec_ty(i8_ty, 4);
Value a = undef(fp8x4VecTy);
a = insert_element(fp8x4VecTy, a, i8_val(0), i32_val(0));
a = insert_element(fp8x4VecTy, a, v[0], i32_val(1));
a = insert_element(fp8x4VecTy, a, i8_val(0), i32_val(2));
a = insert_element(fp8x4VecTy, a, v[1], i32_val(3));
a = bitcast(a, i32_ty);

// Get sign and absolute value
Value sign = and_(a, i32_val(0x80008000));
a = and_(a, i32_val(0x7FFF7FFF));

// Right shift 1 bit to adjust the positions of exponent and mantissa
a = lshr(a, i32_val(1));

// Adjust exponent, (15 - 7) << 10 === 0x2000
a = add(a, i32_val(0x20002000));

// Check NaN
// If the fp8 input is NaN(S.1111.111), the output is set to NaN by masking
// all the bits of exponent and mantissa to 1.
auto i16x2VecTy = vec_ty(i16_ty, 2);
Value maskVec = undef(i16x2VecTy);

Value isNaN0 = icmp_uge(bitcast(v[0], i8_ty), i8_val(0x7F));
Value mask0 = select(isNaN0, i16_val(0x7FFF), i16_val(0));
maskVec = insert_element(i16x2VecTy, maskVec, mask0, i32_val(0));

Value isNaN1 = icmp_uge(bitcast(v[1], i8_ty), i8_val(0x7F));
Value mask1 = select(isNaN1, i16_val(0x7FFF), i16_val(0));
maskVec = insert_element(i16x2VecTy, maskVec, mask1, i32_val(1));

a = or_(a, bitcast(maskVec, i32_ty));

// Set sign
a = or_(a, sign);

auto fp16x2VecTy = vec_ty(f16_ty, 2);
Value fp16x2Vec = bitcast(a, fp16x2VecTy);
return {extract_element(f16_ty, fp16x2Vec, i32_val(0)),
extract_element(f16_ty, fp16x2Vec, i32_val(1))};
}

static SmallVector<Value> Fp8E5M2_to_Fp16(Location loc,
ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
Expand Down Expand Up @@ -914,6 +960,7 @@ struct FpToFpOpConversion
// F8 -> F16
{{F8E4M3FNUZTyID, F16TyID, undefRounding},
Fp8E4M3FNUZ_to_Fp16(isaFamily)},
{{F8E4M3FNTyID, F16TyID, undefRounding}, Fp8E4M3FN_to_Fp16},
{{F8E5M2FNUZTyID, F16TyID, undefRounding},
Fp8E5M2FNUZ_to_Fp16(isaFamily)},
{{F8E5M2TyID, F16TyID, undefRounding}, Fp8E5M2_to_Fp16},
Expand Down

0 comments on commit ac4aa6e

Please sign in to comment.