[AMD] Support Upcasting FP8E4M3NV to FP16 #5558

Merged · 1 commit · Jan 9, 2025
4 changes: 2 additions & 2 deletions python/test/unit/language/test_conversions.py
@@ -272,8 +272,8 @@ def upcast_test(src_dtype, dst_dtype, exponent_bits, mantissa_bits, exponent_bia
])
def test_typeconvert_upcast(src_dtype, dst_dtype, device):

# On HIP, fp8e4nv upcasting is only supported to bf16, and it's only supported on MI300.
if src_dtype == 'float8e4nv' and is_hip() and (dst_dtype != 'bfloat16' or not is_hip_mi300()):
# On HIP, fp8e4nv upcasting is only supported to bf16 and fp16, and it's only supported on MI300.
if src_dtype == 'float8e4nv' and is_hip() and (dst_dtype not in ('bfloat16', 'float16') or not is_hip_mi300()):
pytest.skip(f"upcasting {src_dtype} to {dst_dtype} not supported in this architecture")

if ((src_dtype == 'float8e4nv' and is_cuda() and torch.cuda.get_device_capability(0) < (8, 9))
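For reference, a minimal sketch of a kernel that exercises this upcast path (an editorial example, not part of this PR; the kernel and tensor names are illustrative, and it assumes a ROCm build of Triton and PyTorch with float8_e4m3fn support on an MI300 device). The x.to(tl.float16) cast is what now lowers through the new AMD conversion added below.

import torch
import triton
import triton.language as tl

@triton.jit
def upcast_fp8e4nv_to_fp16(src_ptr, dst_ptr, n_elements, BLOCK: tl.constexpr):
    # Each program converts one contiguous block of fp8e4nv values to fp16.
    offsets = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    x = tl.load(src_ptr + offsets, mask=mask)
    tl.store(dst_ptr + offsets, x.to(tl.float16), mask=mask)

# Example launch:
# src = torch.randn(1024, device='cuda').to(torch.float8_e4m3fn)
# dst = torch.empty(1024, device='cuda', dtype=torch.float16)
# upcast_fp8e4nv_to_fp16[(1,)](src, dst, 1024, BLOCK=1024)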
47 changes: 47 additions & 0 deletions third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -278,6 +278,52 @@ ConverterT Fp16_to_Fp8E5M2FNUZ(AMD::ISAFamily isaFamily) {
: Fp16_to_Fp8E5M2FNUZ_SW;
}

static SmallVector<Value> Fp8E4M3FN_to_Fp16(Location loc,
ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
auto fp8x4VecTy = vec_ty(i8_ty, 4);
Value a = undef(fp8x4VecTy);
a = insert_element(fp8x4VecTy, a, i8_val(0), i32_val(0));
a = insert_element(fp8x4VecTy, a, v[0], i32_val(1));
a = insert_element(fp8x4VecTy, a, i8_val(0), i32_val(2));
a = insert_element(fp8x4VecTy, a, v[1], i32_val(3));
a = bitcast(a, i32_ty);

// Get sign and absolute value
Value sign = and_(a, i32_val(0x80008000));
a = and_(a, i32_val(0x7FFF7FFF));

// Right shift 1 bit to adjust the positions of exponent and mantissa
a = lshr(a, i32_val(1));

// Adjust the exponent bias from fp8e4m3 (7) to fp16 (15): (15 - 7) << 10 == 0x2000
a = add(a, i32_val(0x20002000));

// Check for NaN.
// If the fp8 input is NaN (S.1111.111), set the output to NaN by forcing all
// exponent and mantissa bits to 1. The sign bit is masked off before the
// comparison so that negative values are not misclassified as NaN.
auto i16x2VecTy = vec_ty(i16_ty, 2);
Value maskVec = undef(i16x2VecTy);

Value isNaN0 = icmp_uge(and_(bitcast(v[0], i8_ty), i8_val(0x7F)), i8_val(0x7F));
Value mask0 = select(isNaN0, i16_val(0x7FFF), i16_val(0));
maskVec = insert_element(i16x2VecTy, maskVec, mask0, i32_val(0));

Value isNaN1 = icmp_uge(and_(bitcast(v[1], i8_ty), i8_val(0x7F)), i8_val(0x7F));
Value mask1 = select(isNaN1, i16_val(0x7FFF), i16_val(0));
maskVec = insert_element(i16x2VecTy, maskVec, mask1, i32_val(1));

a = or_(a, bitcast(maskVec, i32_ty));

// Set sign
a = or_(a, sign);

auto fp16x2VecTy = vec_ty(f16_ty, 2);
Value fp16x2Vec = bitcast(a, fp16x2VecTy);
return {extract_element(f16_ty, fp16x2Vec, i32_val(0)),
extract_element(f16_ty, fp16x2Vec, i32_val(1))};
}
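As a sanity check on the bit manipulation above, here is a scalar reference in Python (an editorial sketch, not part of this PR) that mirrors the same steps for a single fp8e4nv byte: strip the sign, shift the 7-bit exponent/mantissa payload into position, add the exponent bias difference (15 - 7) << 10, force NaN inputs to an fp16 NaN, and restore the sign. Like the vectorized code, the additive rebias is exact for normal (non-zero, non-subnormal) fp8 values.

def fp8e4m3fn_to_fp16_bits(b: int) -> int:
    # Convert one fp8 E4M3FN byte to fp16 bits with the same bit trick.
    sign = (b & 0x80) << 8          # sign moves from bit 7 to bit 15
    h = ((b & 0x7F) << 8) >> 1      # exponent -> bits [13:10], mantissa -> bits [9:7]
    h += (15 - 7) << 10             # rebias the exponent: 0x2000
    if (b & 0x7F) == 0x7F:          # S.1111.111 is the only NaN encoding in E4M3FN
        h = 0x7FFF                  # all exponent/mantissa bits set -> fp16 NaN
    return sign | h

# 0x38 encodes +1.0 in E4M3FN (exponent 0111, mantissa 000); 0x3C00 is +1.0 in fp16.
assert fp8e4m3fn_to_fp16_bits(0x38) == 0x3C00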

static SmallVector<Value> Fp8E5M2_to_Fp16(Location loc,
ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
@@ -914,6 +960,7 @@ struct FpToFpOpConversion
// F8 -> F16
{{F8E4M3FNUZTyID, F16TyID, undefRounding},
Fp8E4M3FNUZ_to_Fp16(isaFamily)},
{{F8E4M3FNTyID, F16TyID, undefRounding}, Fp8E4M3FN_to_Fp16},
{{F8E5M2FNUZTyID, F16TyID, undefRounding},
Fp8E5M2FNUZ_to_Fp16(isaFamily)},
{{F8E5M2TyID, F16TyID, undefRounding}, Fp8E5M2_to_Fp16},