From c37989c36b2063a479f645c7b156eefd218bc75c Mon Sep 17 00:00:00 2001
From: Kyle Wang
Date: Fri, 10 Jan 2025 06:17:18 +0800
Subject: [PATCH] [AMD] Support upcasting FP8E4M3NV to FP16 (#5558)

This commit adds support for upcasting FP8E4M3NV to FP16.
---
 python/test/unit/language/test_conversions.py |  4 +-
 .../ElementwiseOpToLLVM.cpp                   | 47 +++++++++++++++++++
 2 files changed, 49 insertions(+), 2 deletions(-)

diff --git a/python/test/unit/language/test_conversions.py b/python/test/unit/language/test_conversions.py
index a63af6bea3c6..37394eabe49b 100644
--- a/python/test/unit/language/test_conversions.py
+++ b/python/test/unit/language/test_conversions.py
@@ -272,8 +272,8 @@ def upcast_test(src_dtype, dst_dtype, exponent_bits, mantissa_bits, exponent_bia
 ])
 def test_typeconvert_upcast(src_dtype, dst_dtype, device):
 
-    # On HIP, fp8e4nv upcasting is only supported to bf16, and it's only supported on MI300.
-    if src_dtype == 'float8e4nv' and is_hip() and (dst_dtype != 'bfloat16' or not is_hip_mi300()):
+    # On HIP, fp8e4nv upcasting is only supported to bf16 and fp16, and it's only supported on MI300.
+    if src_dtype == 'float8e4nv' and is_hip() and (dst_dtype not in ('bfloat16', 'float16') or not is_hip_mi300()):
         pytest.skip(f"upcasting {src_dtype} to {dst_dtype} not supported in this architecture")
 
     if ((src_dtype == 'float8e4nv' and is_cuda() and torch.cuda.get_device_capability(0) < (8, 9))
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp
index 02ad0fe5e2ce..ce913c7127c7 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -278,6 +278,52 @@ ConverterT Fp16_to_Fp8E5M2FNUZ(AMD::ISAFamily isaFamily) {
                                             : Fp16_to_Fp8E5M2FNUZ_SW;
 }
 
+static SmallVector<Value> Fp8E4M3FN_to_Fp16(Location loc,
+                                            ConversionPatternRewriter &rewriter,
+                                            const SmallVector<Value> &v) {
+  auto fp8x4VecTy = vec_ty(i8_ty, 4);
+  Value a = undef(fp8x4VecTy);
+  a = insert_element(fp8x4VecTy, a, i8_val(0), i32_val(0));
+  a = insert_element(fp8x4VecTy, a, v[0], i32_val(1));
+  a = insert_element(fp8x4VecTy, a, i8_val(0), i32_val(2));
+  a = insert_element(fp8x4VecTy, a, v[1], i32_val(3));
+  a = bitcast(a, i32_ty);
+
+  // Get sign and absolute value
+  Value sign = and_(a, i32_val(0x80008000));
+  a = and_(a, i32_val(0x7FFF7FFF));
+
+  // Right shift 1 bit to adjust the positions of exponent and mantissa
+  a = lshr(a, i32_val(1));
+
+  // Adjust exponent, (15 - 7) << 10 == 0x2000
+  a = add(a, i32_val(0x20002000));
+
+  // Check NaN
+  // If the fp8 input is NaN (S.1111.111), the output is set to NaN by masking
+  // all the bits of exponent and mantissa to 1.
+  auto i16x2VecTy = vec_ty(i16_ty, 2);
+  Value maskVec = undef(i16x2VecTy);
+
+  Value isNaN0 = icmp_uge(and_(bitcast(v[0], i8_ty), i8_val(0x7F)), i8_val(0x7F));
+  Value mask0 = select(isNaN0, i16_val(0x7FFF), i16_val(0));
+  maskVec = insert_element(i16x2VecTy, maskVec, mask0, i32_val(0));
+
+  Value isNaN1 = icmp_uge(and_(bitcast(v[1], i8_ty), i8_val(0x7F)), i8_val(0x7F));
+  Value mask1 = select(isNaN1, i16_val(0x7FFF), i16_val(0));
+  maskVec = insert_element(i16x2VecTy, maskVec, mask1, i32_val(1));
+
+  a = or_(a, bitcast(maskVec, i32_ty));
+
+  // Set sign
+  a = or_(a, sign);
+
+  auto fp16x2VecTy = vec_ty(f16_ty, 2);
+  Value fp16x2Vec = bitcast(a, fp16x2VecTy);
+  return {extract_element(f16_ty, fp16x2Vec, i32_val(0)),
+          extract_element(f16_ty, fp16x2Vec, i32_val(1))};
+}
+
 static SmallVector<Value> Fp8E5M2_to_Fp16(Location loc,
                                           ConversionPatternRewriter &rewriter,
                                           const SmallVector<Value> &v) {
@@ -914,6 +960,7 @@ struct FpToFpOpConversion
         // F8 -> F16
         {{F8E4M3FNUZTyID, F16TyID, undefRounding},
          Fp8E4M3FNUZ_to_Fp16(isaFamily)},
+        {{F8E4M3FNTyID, F16TyID, undefRounding}, Fp8E4M3FN_to_Fp16},
         {{F8E5M2FNUZTyID, F16TyID, undefRounding},
          Fp8E5M2FNUZ_to_Fp16(isaFamily)},
         {{F8E5M2TyID, F16TyID, undefRounding}, Fp8E5M2_to_Fp16},
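
For reference, below is a minimal Python model (not part of the patch) of the per-lane bit manipulation performed by the new Fp8E4M3FN_to_Fp16 helper: sign extraction, the one-bit right shift, the exponent re-bias by (15 - 7) << 10, and the NaN mask. The helper names, example byte values, and the struct-based fp16 decoding are illustrative assumptions; only normalized values and NaN are exercised here.

    import math
    import struct

    def fp8e4m3fn_to_fp16_bits(x: int) -> int:
        # Place the fp8 byte in the high byte of a 16-bit lane.
        a = (x & 0xFF) << 8
        sign = a & 0x8000
        a &= 0x7FFF
        # Shift exponent/mantissa into the fp16 field positions.
        a >>= 1
        # Re-bias the exponent: (15 - 7) << 10 == 0x2000.
        a += 0x2000
        # fp8e4m3fn NaN (S.1111.111) maps to an fp16 NaN.
        if (x & 0x7F) == 0x7F:
            a |= 0x7FFF
        return (a | sign) & 0xFFFF

    def fp16_bits_to_float(bits: int) -> float:
        return struct.unpack('<e', struct.pack('<H', bits))[0]

    # Normalized examples: 0x38 -> 1.0, 0xC0 -> -2.0, 0x7E -> 448.0 (max finite).
    for byte, expected in [(0x38, 1.0), (0xC0, -2.0), (0x7E, 448.0)]:
        assert fp16_bits_to_float(fp8e4m3fn_to_fp16_bits(byte)) == expected
    # Both NaN encodings (0x7F and 0xFF) map to an fp16 NaN.
    assert all(math.isnan(fp16_bits_to_float(fp8e4m3fn_to_fp16_bits(b))) for b in (0x7F, 0xFF))

Since fp16 has both a wider exponent and a wider mantissa than fp8e4m3fn, every normalized fp8 value is representable exactly in fp16, which is why this upcast path needs no rounding logic.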
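As a usage sketch (assuming a ROCm build of Triton, an MI300-class GPU, and a PyTorch version that provides torch.float8_e4m3fn), a kernel along the following lines exercises the new fp8e4nv-to-fp16 upcast path; the kernel name and block size are hypothetical.

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def upcast_kernel(src_ptr, dst_ptr, BLOCK: tl.constexpr):
        offs = tl.arange(0, BLOCK)
        x = tl.load(src_ptr + offs)                 # fp8e4nv elements
        tl.store(dst_ptr + offs, x.to(tl.float16))  # upcast lowered via FpToFpOp

    src = torch.randn(128, device='cuda').to(torch.float8_e4m3fn)
    dst = torch.empty(128, device='cuda', dtype=torch.float16)
    upcast_kernel[(1,)](src, dst, BLOCK=128)
    # The upcast is exact, so the kernel output should match PyTorch's cast.
    torch.testing.assert_close(dst, src.to(torch.float16))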