diff --git a/include/triton/Dialect/Triton/IR/TritonAttrDefs.td b/include/triton/Dialect/Triton/IR/TritonAttrDefs.td index 04e4c25fd6d8..1e7e663ad279 100644 --- a/include/triton/Dialect/Triton/IR/TritonAttrDefs.td +++ b/include/triton/Dialect/Triton/IR/TritonAttrDefs.td @@ -128,8 +128,8 @@ def TT_ScaleDotElemTypeAttr : I32EnumAttr< I32EnumAttrCase<"E2M3", 2, "e2m3">, I32EnumAttrCase<"E3M2", 3, "e3m2">, I32EnumAttrCase<"E2M1", 4, "e2m1">, - I32EnumAttrCase<"BF16", 5, "bf16"> - + I32EnumAttrCase<"BF16", 5, "bf16">, + I32EnumAttrCase<"FP16", 6, "fp16"> ]>{ let cppNamespace = "::mlir::triton"; } diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td index 2b0643c338fb..546de144f09e 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td @@ -283,8 +283,8 @@ def TTG_LocalStoreOp : TTG_Op<"local_store", [DeclareOpInterfaceMethods]> { - let summary = "Convert an mxfp tensor to bf16"; +def TTG_UpcastMXFPOp : TTG_Op<"upcast_mxfp", [Pure]> { + let summary = "Convert an mxfp tensor to bf16/fp16"; let hasVerifier = 1; @@ -301,6 +301,11 @@ def TTG_UpcastMXFPOp : TTG_Op<"upcast_mxfp", [Pure, DeclareOpInterfaceMethods` type($result) }]; + + let extraClassDeclaration = [{ + static RankedTensorType deduceOutputType( + TypedValue inputTensor, ScaleDotElemType inputElemType, Type outputElemType); + }]; } // Allocate global memory diff --git a/lib/Dialect/TritonGPU/IR/Ops.cpp b/lib/Dialect/TritonGPU/IR/Ops.cpp index 39d52ac89100..f2088f3a84ca 100644 --- a/lib/Dialect/TritonGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonGPU/IR/Ops.cpp @@ -301,13 +301,15 @@ LogicalResult UpcastMXFPOp::verify() { auto xTy = getSrc().getType(); auto scaleTy = getScale().getType(); - - if (xTy.getElementType() != FloatType::getBF16(getContext()) && - xTy.getElementType() != IntegerType::get(getContext(), 8)) { - return emitOpError("element type of the first operand must be bf16 or i8"); + Builder b(getContext()); + if (xTy.getElementType() != b.getBF16Type() && + xTy.getElementType() != b.getF16Type() && + xTy.getElementType() != b.getI8Type()) { + return emitOpError( + "element type of the first operand must be bf16/fp16 or i8"); } - if (scaleTy.getElementType() != IntegerType::get(getContext(), 8)) { + if (scaleTy.getElementType() != b.getI8Type()) { return emitOpError("element type of the second operand must be uint8"); } @@ -381,44 +383,34 @@ LogicalResult UpcastMXFPOp::verify() { return success(); } -LogicalResult UpcastMXFPOp::inferReturnTypes( - MLIRContext *ctx, std::optional loc, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties opaqueProperties, - RegionRange regions, SmallVectorImpl &inferredReturnTypes) { - auto xTy = cast(operands[0].getType()); - auto properties = opaqueProperties.as(); - auto typeEncoded = properties->fp_type.getValue(); - auto xShape = xTy.getShape(); +RankedTensorType +UpcastMXFPOp::deduceOutputType(TypedValue inputTensor, + ScaleDotElemType inputElemType, + Type outputElemType) { + MLIRContext *ctx = inputTensor.getContext(); + auto xTy = inputTensor.getType(); + if (inputElemType != ScaleDotElemType::E2M1) + return xTy; + auto xShape = xTy.getShape(); + auto newShape = llvm::to_vector(xShape); auto encoding = xTy.getEncoding(); - - if (typeEncoded == ScaleDotElemType::E2M1) { - RankedTensorType retTy; - - auto newShape = SmallVector(xShape); - if (!encoding) { - newShape.back() *= 2; - retTy = RankedTensorType::get(xShape, FloatType::getBF16(ctx)); - } else { - auto oldEncoding = cast(encoding); - auto newVEncoding = DotOperandEncodingAttr::get( - ctx, oldEncoding.getOpIdx(), oldEncoding.getParent(), - oldEncoding.getKWidth() * 2); - // Figure out the K dimension for the input A/B, given that the return - // type is upcasted A/B type so we need to update the proper dim size. - const int opIdx = oldEncoding.getOpIdx(); - const bool hasBatch = xShape.size() == 3; - const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch; - newShape[kIdx] *= 2; - retTy = RankedTensorType::get(newShape, FloatType::getBF16(ctx), - newVEncoding); - } - inferredReturnTypes.push_back(retTy); - } else { - inferredReturnTypes.push_back(xTy); - } - - return success(); + if (!encoding) { + newShape.back() *= 2; + return RankedTensorType::get(xShape, outputElemType); + } + + auto oldEncoding = cast(encoding); + auto newVEncoding = DotOperandEncodingAttr::get(ctx, oldEncoding.getOpIdx(), + oldEncoding.getParent(), + oldEncoding.getKWidth() * 2); + // Figure out the K dimension for the input A/B, given that the return + // type is upcasted A/B type so we need to update the proper dim size. + const int opIdx = oldEncoding.getOpIdx(); + const bool hasBatch = xShape.size() == 3; + const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch; + newShape[kIdx] *= 2; + return RankedTensorType::get(newShape, outputElemType, newVEncoding); } OpFoldResult MemDescTransOp::fold(FoldAdaptor adaptor) { diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp index d2c9d14ba004..5f3dbf3cbeba 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -573,8 +573,10 @@ class DecomposeScaledBlocked maybeWithEncoding(scale.getType(), scaleEncoding); scale = rewriter.create(scale.getLoc(), newScaleDotElemType, scale); - ret = rewriter.create(v.getLoc(), ret, scale, - type); + auto retTy = triton::gpu::UpcastMXFPOp::deduceOutputType( + ret, type, Builder(v.getContext()).getBF16Type()); + ret = rewriter.create(v.getLoc(), retTy, ret, + scale, type); } return ret; } diff --git a/python/src/ir.cc b/python/src/ir.cc index 668365bbdad5..b6b0c846fe91 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -232,6 +232,7 @@ void init_triton_ir(py::module &&m) { .value("E3M2", ScaleDotElemType::E3M2) .value("E2M1", ScaleDotElemType::E2M1) .value("BF16", ScaleDotElemType::BF16) + .value("FP16", ScaleDotElemType::FP16) .export_values(); py::class_(m, "context", py::module_local()) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index b2838697d3d7..d49e02230aaf 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -3521,19 +3521,24 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid for col_a, col_b in itertools.product([True, False], repeat=2) for rhs_scale in [False, True] for mxfp_type in ["e2m1", "e4m3", "e5m2"] - for normal_type in ["e4m3", "e5m2", "bf16"] + for normal_type in ["e4m3", "e5m2", "bf16", "fp16"] for mma in (mma_nonk_sizes if is_hip() else [16]) for kpack in ([1, 2] if is_hip() else [1])]) def test_scaled_dot(M, N, K, col_a, col_b, rhs_scale, mxfp_type, normal_type, num_warps, mma, kpack, device): if is_cuda(): + if normal_type == "fp16": + pytest.skip("scaled_dot with fp16 input not supported on CUDA yet") cc = torch.cuda.get_device_capability() if cc < (8, 9): pytest.skip("float8e4nv not supported on CUDA < 8.9") if is_hip(): if not is_hip_cdna(): pytest.skip("scaled_dot only implemented for HIP CDNA") - if "e4m3" in (mxfp_type, normal_type) and not is_hip_mi300(): - pytest.skip(f"scaled_dot({mxfp_type}, {normal_type}) only implemented for MI300") + if "e4m3" in (mxfp_type, normal_type): + if not is_hip_mi300(): + pytest.skip(f"scaled_dot({mxfp_type}, {normal_type}) only implemented for MI300") + if normal_type == "fp16": + pytest.skip(f"scaled_dot({mxfp_type}, {normal_type}) not yet implemented for MI300") if mma == 16 and K == 64: pytest.skip(f"K == {K} too small for mfma {mma} in scaled_dot") @@ -3566,13 +3571,14 @@ def dot_scale_kernel(a_base, stride_a0, stride_a1, a_scale, b_base, stride_b0, s tl.store(out_ptr, c.to(tl.bfloat16)) @triton.jit - def mxfp_to_bf16_kernel( + def mxfp_upcast_kernel( x_ptr, scale_ptr, mxfp_ptr, N, e_bits: tl.constexpr, m_bits: tl.constexpr, + to_type: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): # x.shape == (N, 32) for fp8 or (N, 16) for fp4 @@ -3594,41 +3600,51 @@ def mxfp_to_bf16_kernel( tl.static_assert(scale.dtype == tl.uint8) tl.static_assert(x.dtype == tl.uint8) - scale_bf16 = (scale.to(tl.uint16) << 7).to(tl.bfloat16, bitcast=True) + if to_type == tl.bfloat16: + upcasted_scale = (scale.to(tl.uint16) << 7).to(tl.bfloat16, bitcast=True) + else: + tl.static_assert(to_type == tl.float16) + scale_fp32 = (scale.to(tl.uint32) << 23).to(tl.float32, bitcast=True) + upcasted_scale = scale_fp32.to(tl.float16) + + to_e_bits: tl.constexpr = 8 if to_type == tl.bfloat16 else 5 + to_m_bits: tl.constexpr = 7 if to_type == tl.bfloat16 else 10 if is_fp8: if e_bits == 5 and m_bits == 2: x_f8 = x.to(tl.float8e5, bitcast=True) - x_bf16 = x_f8.to(tl.bfloat16) + upcasted_x = x_f8.to(to_type) # Preserve infs and nans. FIXME Fp8E5M2_to_Bf16 doesn't preserve them! non_finite_mask: tl.constexpr = ((1 << e_bits) - 1) << m_bits - non_finite_mask_bf16: tl.constexpr = ((1 << 8) - 1) << 7 - x_bf16 = tl.where( + non_finite_mask_16bit: tl.constexpr = ((1 << to_e_bits) - 1) << to_m_bits + upcasted_x = tl.where( x & non_finite_mask == non_finite_mask, - (x_bf16.to(tl.uint16, bitcast=True) | non_finite_mask_bf16).to(tl.bfloat16, bitcast=True), - x_bf16, + (upcasted_x.to(tl.uint16, bitcast=True) | non_finite_mask_16bit).to(to_type, bitcast=True), + upcasted_x, ) else: tl.static_assert(e_bits == 4 and m_bits == 3) x_f8 = x.to(tl.float8e4nv, bitcast=True) - x_bf16 = x_f8.to(tl.bfloat16) + upcasted_x = x_f8.to(to_type) else: + to_bias: tl.constexpr = 127 if to_type == tl.bfloat16 else 15 + to_point5: tl.constexpr = 16128 if to_type == tl.bfloat16 else 0x3800 # e2m1 em0 = x & 0x7 em1 = x & 0x70 - x0 = (em0.to(tl.uint16) << 2 + 4) | ((x & 0x8).to(tl.uint16) << 8 + 4) - x1 = (em1.to(tl.uint16) << 2) | ((x & 0x80).to(tl.uint16) << (8)) + x0 = (em0.to(tl.uint16) << (to_m_bits - 1)) | ((x & 0x8).to(tl.uint16) << 12) + x1 = (em1.to(tl.uint16) << (to_m_bits - 1 - 4)) | ((x & 0x80).to(tl.uint16) << 8) # Three cases: # 1) x is normal and non-zero: Correct bias - x0 = tl.where((em0 & 0x6) != 0, x0 + ((127 - 1) << 7), x0) - x1 = tl.where((em1 & 0x60) != 0, x1 + ((127 - 1) << 7), x1) + x0 = tl.where((em0 & 0x6) != 0, x0 + ((to_bias - 1) << to_m_bits), x0) + x1 = tl.where((em1 & 0x60) != 0, x1 + ((to_bias - 1) << to_m_bits), x1) # 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in bf16 - x0 = tl.where(em0 == 0x1, 16128 | (x0 & 0x8000), x0) - x1 = tl.where(em1 == 0x10, 16128 | (x1 & 0x8000), x1) + x0 = tl.where(em0 == 0x1, to_point5 | (x0 & 0x8000), x0) + x1 = tl.where(em1 == 0x10, to_point5 | (x1 & 0x8000), x1) # 3) x is zero, do nothing - x_bf16 = tl.interleave(x0, x1).to(tl.bfloat16, bitcast=True) - # Multiplication preserves infs and NaNs in x_bf16 - mxfp = x_bf16 * scale_bf16 - # If scale is NaN, we encode it as an bf16 inf, so we need to correct for that + upcasted_x = tl.interleave(x0, x1).to(to_type, bitcast=True) + # Multiplication preserves infs and NaNs in upcasted_x + mxfp = upcasted_x * upcasted_scale + # If scale is NaN, we encode it as an inf, so we need to correct for that mxfp = tl.where(scale == 0xFF, float("nan"), mxfp) offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) @@ -3636,10 +3652,14 @@ def mxfp_to_bf16_kernel( def dot_scale_ref(x, scale_x, y, scale_y, type_x, type_y): - def upcast(v, scale, type, transposed): - comp_dtype = torch.bfloat16 + def upcast(v, scale, type, comp_dtype, transposed): if scale is None: - type = {"e4m3": torch.float8_e4m3fn, "e5m2": torch.float8_e5m2, "bf16": torch.bfloat16}[type] + type = { + "e4m3": torch.float8_e4m3fn, + "e5m2": torch.float8_e5m2, + "bf16": torch.bfloat16, + "fp16": torch.float16, + }[type] return v.view(type).to(comp_dtype) e_bits, m_bits = {"e2m1": (2, 1), "e4m3": (4, 3), "e5m2": (5, 2)}[type] # Packing is always on the K dimension so we transpose before upcasting then transpose back. @@ -3650,15 +3670,19 @@ def upcast(v, scale, type, transposed): N = v_upcast.numel() BLOCK_SIZE = 512 grid = ((N + BLOCK_SIZE - 1) // BLOCK_SIZE, ) - mxfp_to_bf16_kernel[grid](v, scale, v_upcast, scale.numel(), e_bits, m_bits, BLOCK_SIZE, - num_warps=num_warps) + comp_dtype = tl.float16 if comp_dtype == torch.float16 else tl.bfloat16 + mxfp_upcast_kernel[grid](v, scale, v_upcast, scale.numel(), e_bits, m_bits, comp_dtype, BLOCK_SIZE, + num_warps=num_warps) assert v_upcast.isfinite().all() if transposed: v_upcast = v_upcast.mT return v_upcast - x_upcast = upcast(x, scale_x, type_x, False) - y_upcast = upcast(y, scale_y, type_y, True) + # Upcast to fp16 if one of the input is fp16 + comp_dtype = torch.float16 if "fp16" in (type_x, type_y) else torch.bfloat16 + + x_upcast = upcast(x, scale_x, type_x, comp_dtype, False) + y_upcast = upcast(y, scale_y, type_y, comp_dtype, True) class AccumulateInFp32: @@ -3672,15 +3696,21 @@ def __exit__(self, exc_type, exc_val, exc_tb): with AccumulateInFp32(): return torch.matmul(x_upcast, y_upcast) + comp_dtype = torch.float16 if normal_type == "fp16" else torch.bfloat16 + comp_dtype_bias = 15 if normal_type == "fp16" else 127 + # The max exponent we use to initialize data in the x/y and associated scale tensor to avoid + # overflow when scaling. + comp_dtype_max_exp = 6 if normal_type == "fp16" else 15 + torch.manual_seed(0) def make_arg(shape, ty, col_major=False, max_val=255): if col_major: shape = shape[:-2] + (shape[-1], shape[-2]) - if ty == "bf16": - ret = torch.randn(shape, dtype=torch.bfloat16, device=device) + if ty == "bf16" or ty == "fp16": + ret = torch.randn(shape, dtype=comp_dtype, device=device) # Clamp to avoid relative error issues - ret.clamp_(-2**15, 2**15 - 1) + ret.clamp_(-2**comp_dtype_max_exp, 2**comp_dtype_max_exp - 1) else: ret = torch.randint(max_val + 1, shape, dtype=torch.uint8, device=device) if col_major: @@ -3696,9 +3726,8 @@ def make_arg(shape, ty, col_major=False, max_val=255): y = make_arg((K // DIV_FACTOR_B, N), type_b, col_major=col_b) # sample scales that don't overflow as otherwise it's implementation defined (underflowing is alright) - # Max scale= 2**15 - scale_x = make_arg((M, K // 32), "e8m0", max_val=127 + 15) - scale_y = make_arg((N, K // 32), "e8m0", max_val=127 + 15) + scale_x = make_arg((M, K // 32), "e8m0", max_val=comp_dtype_bias + comp_dtype_max_exp) + scale_y = make_arg((N, K // 32), "e8m0", max_val=comp_dtype_bias + comp_dtype_max_exp) if rhs_scale: scale_x = None else: @@ -3721,7 +3750,7 @@ def make_finite(x, dtype): if is_hip(): kernel_kwargs["kpack"] = kpack kernel_kwargs["matrix_instr_nonkdim"] = mma - z = x.new_empty((M, N), dtype=torch.bfloat16) + z = x.new_empty((M, N), dtype=comp_dtype) pgm = dot_scale_kernel[(1, )](x, *x.stride(), scale_x, y, *y.stride(), scale_y, z, M, N, K, type_a, type_b, **kernel_kwargs) z_ref = dot_scale_ref(x, scale_x, y, scale_y, type_a, type_b) diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 847b413ab1d8..a07cf5dc6864 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -1740,17 +1740,24 @@ def dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None, lhs and rhs use microscaling formats described here: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf + Software emulation enables targeting hardware architectures without native microscaling + operation support. Right now for such case, microscaled lhs/rhs are upcasted to + :code:`bf16` element type beforehand for dot computation, with one exception: + for AMD CDNA3 specifically, if one of the inputs is of :code:`fp16` element type, + the other input is also upcasted to :code:`fp16` element type instead. + This behavior is experimental and may be subject to change in the future. + :param lhs: The first tensor to be multiplied. :type lhs: 2D tensor representing fp4, fp8 or bf16 elements. Fp4 elements are packed into uint8 inputs with the first element in lower bits. Fp8 are stored as uint8 or the corresponding fp8 type. :param lhs_scale: Scale factor for lhs tensor. :type lhs_scale: e8m0 type represented as an uint8 tensor. - :param lhs_format: format of the lhs tensor. Available formats: {:code:`e2m1`, :code:`e4m3`, :code:`e5m2`, :code:`bf16`}. + :param lhs_format: format of the lhs tensor. Available formats: {:code:`e2m1`, :code:`e4m3`, :code:`e5m2`, :code:`bf16`, :code:`fp16`}. :type lhs_format: str :param rhs: The second tensor to be multiplied. :type rhs: 2D tensor representing fp4, fp8 or bf16 elements. Fp4 elements are packed into uint8 inputs with the first element in lower bits. Fp8 are stored as uint8 or the corresponding fp8 type. :param rhs_scale: Scale factor for rhs tensor. :type rhs_scale: e8m0 type represented as an uint8 tensor. - :param rhs_format: format of the rhs tensor. Available formats: {:code:`e2m1`, :code:`e4m3`, :code:`e5m2`, :code:`bf16`}. + :param rhs_format: format of the rhs tensor. Available formats: {:code:`e2m1`, :code:`e4m3`, :code:`e5m2`, :code:`bf16`, :code:`fp16`}. :type rhs_format: str :param acc: The accumulator tensor. If not None, the result is added to this tensor. """ diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 07d99fafc0a7..62e2538acca9 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1548,7 +1548,7 @@ def _bitcast_to_fp_type(val: tl.tensor, float_format: str, builder: ir.builder): If float_format is subbyte, make sure it's packed as uint8 and return it. Otherwise, return a tensor (perhaps bitcasting) of the specified float format. """ - triton_ty = {"e5m2": tl.float8e5, "e4m3": tl.float8e4nv, "bf16": tl.bfloat16}.get(float_format) + triton_ty = {"e5m2": tl.float8e5, "e4m3": tl.float8e4nv, "bf16": tl.bfloat16, "fp16": tl.float16}.get(float_format) if triton_ty is None: assert float_format == "e2m1", f"Internal Error: Unexpected float format: {float_format}" assert val.dtype == tl.uint8, f"e2m1 format must be packed as uint8. Got {val.dtype}" @@ -1556,7 +1556,7 @@ def _bitcast_to_fp_type(val: tl.tensor, float_format: str, builder: ir.builder): if val.dtype == triton_ty: return val else: - unsigned_ty = {"e5m2": tl.uint8, "e4m3": tl.uint8, "bf16": tl.uint16}[float_format] + unsigned_ty = {"e5m2": tl.uint8, "e4m3": tl.uint8, "bf16": tl.uint16, "fp16": tl.uint16}[float_format] assert val.dtype == unsigned_ty, f"Unexpected dtype for {float_format}. Got {val.dtype}" return bitcast(val, triton_ty, builder) @@ -1572,7 +1572,7 @@ def dot_scaled(lhs: tl.tensor, lhs_scale: tl.tensor, lhs_format: str, rhs: tl.te rhs_format: str = rhs_format.value lhs_format_enum = _str_to_fp_type(lhs_format) rhs_format_enum = _str_to_fp_type(rhs_format) - allowed_formats = {"e2m1", "e4m3", "e5m2", "bf16"} + allowed_formats = {"e2m1", "e4m3", "e5m2", "bf16", "fp16"} assert lhs_format in allowed_formats, f"NYI: lhs_format {lhs_format}" assert rhs_format in allowed_formats, f"NYI: rhs_format {rhs_format}" rhs_scale_is_none = isinstance(rhs_scale, tl.constexpr) and rhs_scale.value is None diff --git a/third_party/amd/include/TritonAMDGPUToLLVM/GCNAsmFormat.h b/third_party/amd/include/TritonAMDGPUToLLVM/GCNAsmFormat.h index 0c60759a8cb7..4e19b370c83a 100644 --- a/third_party/amd/include/TritonAMDGPUToLLVM/GCNAsmFormat.h +++ b/third_party/amd/include/TritonAMDGPUToLLVM/GCNAsmFormat.h @@ -232,9 +232,8 @@ struct GCNBuilder { std::string dump() const; - mlir::Value launch(ConversionPatternRewriter &rewriter, Location loc, - Type resTy, bool hasSideEffect = true, - bool isAlignStack = false, + mlir::Value launch(RewriterBase &rewriter, Location loc, Type resTy, + bool hasSideEffect = true, bool isAlignStack = false, ArrayRef attrs = {}) const; private: diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp index ce913c7127c7..1fcc6851fae1 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -99,25 +99,6 @@ static Value cvtFp16ToFp32(Location loc, ConversionPatternRewriter &rewriter, return builder.launch(rewriter, loc, f32_ty, false); } -static Value cvtFp32ToFp16(Location loc, ConversionPatternRewriter &rewriter, - const Value &v, const RoundingMode rounding) { - GCNBuilder builder; - - auto &cvt = *builder.create("v_cvt_f16_f32"); - auto res = builder.newOperand("=v"); - auto operand = builder.newOperand(v, "v"); - if (rounding == RoundingMode::RTZ) { - auto &setRTZ = *builder.create("s_setreg_imm32_b32 0x1801, 0xc"); - setRTZ(); - } - cvt(res, operand); - if (rounding == RoundingMode::RTZ) { - auto &resetRTZ = *builder.create("s_setreg_imm32_b32 0x1801, 0x0"); - resetRTZ(); - } - return builder.launch(rewriter, loc, f16_ty, false); -} - // convert fp8 to fp32 static SmallVector cvtFp8ToFp32(Location loc, ConversionPatternRewriter &rewriter, @@ -194,8 +175,8 @@ convert_val_Fp8_to_Fp16(Location loc, ConversionPatternRewriter &rewriter, SmallVector ret = cvtFp8ToFp32(loc, rewriter, v0, v1, fp8_format); // Convert fp32 to fp16 - ret[0] = cvtFp32ToFp16(loc, rewriter, ret[0], RoundingMode::RTNE); - ret[1] = cvtFp32ToFp16(loc, rewriter, ret[1], RoundingMode::RTNE); + ret[0] = LLVM::AMD::cvtFp32ToFp16(loc, rewriter, ret[0], RoundingMode::RTNE); + ret[1] = LLVM::AMD::cvtFp32ToFp16(loc, rewriter, ret[1], RoundingMode::RTNE); return ret; } @@ -1014,7 +995,7 @@ struct FpToFpOpConversion outVals.reserve(operands[0].size()); for (Value v : operands[0]) { outVals.push_back( - cvtFp32ToFp16(loc, rewriter, v, roundingMode.value())); + LLVM::AMD::cvtFp32ToFp16(loc, rewriter, v, roundingMode.value())); } return outVals; } @@ -1065,8 +1046,8 @@ struct FpToFpOpConversion } if (useFP16IntermediateSrc) for (Value &v : inVals) - v = cvtFp32ToFp16(loc, rewriter, v, - roundingMode.value_or(RoundingMode::RTNE)); + v = LLVM::AMD::cvtFp32ToFp16(loc, rewriter, v, + roundingMode.value_or(RoundingMode::RTNE)); inVals.resize(numElements, undef(typeConverter->convertType(srcType))); SmallVector outVals; if (srcType != dstType) { diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/GCNAsmFormat.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/GCNAsmFormat.cpp index b83707ee145f..2de1c0f3d23d 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/GCNAsmFormat.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/GCNAsmFormat.cpp @@ -72,9 +72,8 @@ SmallVector GCNBuilder::getAllArgs() const { return res; } -mlir::Value GCNBuilder::launch(ConversionPatternRewriter &rewriter, - Location loc, Type resTy, bool hasSideEffect, - bool isAlignStack, +mlir::Value GCNBuilder::launch(RewriterBase &rewriter, Location loc, Type resTy, + bool hasSideEffect, bool isAlignStack, ArrayRef attrs) const { auto *ctx = rewriter.getContext(); auto inlineAsm = rewriter.create( diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp index 86ddbbd1952c..cd02807db72c 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp @@ -1,5 +1,6 @@ #include "PatternTritonGPUOpToLLVM.h" +#include "Utility.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/TypeUtilities.h" @@ -19,6 +20,51 @@ using namespace mlir::triton::gpu; namespace { +SmallVector convertMxfp4x2ToFp16x2(RewriterBase &rewriter, Location loc, + ArrayRef values) { + SmallVector results; + for (auto v : values) { + auto em0 = and_(v, i8_val(0x7)); + auto em1 = and_(v, i8_val(0x70)); + // FP16 bits: sign = 1, exponent = 5, mantissa = 10 + Value v0 = or_(shl(zext(i16_ty, em0), i16_val(10 - 1)), + shl(zext(i16_ty, and_(v, i8_val(0x8))), i16_val(12))); + Value v1 = or_(shl(zext(i16_ty, em1), i16_val(10 - 1 - 4)), + shl(zext(i16_ty, and_(v, i8_val(0x80))), i16_val(8))); + + // Three cases: + // 1) x is normal and non-zero: Correct bias + v0 = select(icmp_ne(and_(em0, i8_val(0x6)), i8_val(0)), + add(v0, i16_val((15 - 1) << 10)), v0); + v1 = select(icmp_ne(and_(em1, i8_val(0x60)), i8_val(0)), + add(v1, i16_val((15 - 1) << 10)), v1); + + // 2) x is subnormal (x == 0bs001 where s is the sign): Map to fp16 +-0.5 + v0 = bitcast(select(icmp_eq(em0, i8_val(0x1)), + or_(i16_val(0x3800), and_(v0, i16_val(0x8000))), v0), + f16_ty); + v1 = bitcast(select(icmp_eq(em1, i8_val(0x10)), + or_(i16_val(0x3800), and_(v1, i16_val(0x8000))), v1), + f16_ty); + // 3) x is zero, nothing to do + results.push_back(v0); + results.push_back(v1); + } + return results; +} + +Value mxfpScaleFp16(RewriterBase &rewriter, Location loc, Value v, + Value scale) { + Value scaleF32 = bitcast(shl(zext(i32_ty, scale), i32_val(23)), f32_ty); + Value scaleF16 = + LLVM::AMD::cvtFp32ToFp16(loc, rewriter, scaleF32, RoundingMode::RTNE); + Value mulF16 = fmul(v, scaleF16); + // Account for NaN in the scale as per the mxfp specification. + Value scaleIsNan = icmp_eq(scale, i8_val(0xff)); + Value nanF16 = bitcast(i16_val(0x7c01), f16_ty); + return select(scaleIsNan, nanF16, bitcast(mulF16, f16_ty)); +}; + // Scales the given bf16 v using the given scale factor without relying on bf16 // multiplication. // @@ -55,7 +101,7 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { bool isPacked = fpType == ScaleDotElemType::E2M1; if (!(isPacked || fpType == ScaleDotElemType::E4M3 || fpType == ScaleDotElemType::E5M2)) - return rewriter.notifyMatchFailure(op, "NYI: non-mxfp8 cases"); + return rewriter.notifyMatchFailure(op, "NYI: non-mxfp4/mxfp8 cases"); Location loc = op.getLoc(); auto xVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); @@ -88,8 +134,11 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { Value warpId = udiv(tid, warpSize); Value laneId = urem(tid, warpSize); - if (isPacked) - xVals = LLVM::convertMxfp4x2ToBf16x2(rewriter, loc, xVals); + bool useFp16 = op.getType().getElementType().isF16(); + if (isPacked) { + xVals = useFp16 ? convertMxfp4x2ToFp16x2(rewriter, loc, xVals) + : LLVM::convertMxfp4x2ToBf16x2(rewriter, loc, xVals); + } // Given that MFMA layout for the A tensor arranges thread in a column-major // manner, for the current tid, it's at row (tid % mDim). When we set up @@ -117,7 +166,9 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { for (int j = 0; j < 32; ++j) { int index = 32 * i + j; xVals[index] = - mxfpScaleBf16ViaF32(rewriter, loc, xVals[index], si[j / 16]); + useFp16 ? mxfpScaleFp16(rewriter, loc, xVals[index], si[j / 16]) + : mxfpScaleBf16ViaF32(rewriter, loc, xVals[index], + si[j / 16]); } } } else { @@ -140,7 +191,9 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { for (int j = 0; j < 32; ++j) { int index = 32 * i + j; xVals[index] = - mxfpScaleBf16ViaF32(rewriter, loc, xVals[index], si[j / 8]); + useFp16 + ? mxfpScaleFp16(rewriter, loc, xVals[index], si[j / 16]) + : mxfpScaleBf16ViaF32(rewriter, loc, xVals[index], si[j / 8]); } } } diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp index 0bd401f1993a..71b5d6579eab 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp @@ -1,10 +1,8 @@ #include "Utility.h" -#include "PatternTritonGPUOpToLLVM.h" +#include "TritonAMDGPUToLLVM/GCNAsmFormat.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" -#include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "mlir/IR/PatternMatch.h" -#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" @@ -358,4 +356,23 @@ void llStore(RewriterBase &rewriter, Location loc, Value ptr, Value val, LLVM::createLLVMCallOp(rewriter, loc, funcOp, ValueRange({ptr, val, pred})); } +Value cvtFp32ToFp16(Location loc, RewriterBase &rewriter, const Value &v, + triton::RoundingMode rounding) { + GCNBuilder builder; + + auto &cvt = *builder.create("v_cvt_f16_f32"); + auto res = builder.newOperand("=v"); + auto operand = builder.newOperand(v, "v"); + if (rounding == triton::RoundingMode::RTZ) { + auto &setRTZ = *builder.create("s_setreg_imm32_b32 0x1801, 0xc"); + setRTZ(); + } + cvt(res, operand); + if (rounding == triton::RoundingMode::RTZ) { + auto &resetRTZ = *builder.create("s_setreg_imm32_b32 0x1801, 0x0"); + resetRTZ(); + } + return builder.launch(rewriter, loc, f16_ty, false); +} + } // namespace mlir::LLVM::AMD diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h index cba2db5a896b..f1c5d1497129 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h @@ -47,6 +47,9 @@ Value llLoad(RewriterBase &rewriter, Location loc, Value ptr, Type elemTy, void llStore(RewriterBase &rewriter, Location loc, Value ptr, Value val, Value pred, int64_t alignmentBytes = 0, triton::CacheModifier cm = triton::CacheModifier::NONE); + +Value cvtFp32ToFp16(Location loc, RewriterBase &rewriter, const Value &v, + triton::RoundingMode rounding); } // namespace mlir::LLVM::AMD #endif // TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_UTILITY_H_ diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp index f3f64d799575..4e507755119a 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp @@ -152,11 +152,12 @@ FailureOr chooseMfmaInstruction(tt::DotOp dot, int mfmaVersion, } FailureOr chooseMfmaInstruction(tt::DotScaledOp dot, int mfmaVersion, - int nonKDim) { - // For scaled dot, we handle it with bf16 emulation for now. - Type bf16Type = Builder(dot.getContext()).getBF16Type(); + int nonKDim, bool useFp16) { + // For scaled dot, we handle it with fp16 or bf16 emulation for now. + Builder b(dot.getContext()); + Type elemType = useFp16 ? b.getF16Type() : b.getBF16Type(); return chooseMfmaInstruction( - dot.getC().getType(), /*aElemType=*/bf16Type, /*bElemType=*/bf16Type, + dot.getC().getType(), /*aElemType=*/elemType, /*bElemType=*/elemType, dot.getLhs().getType().getShape().back(), mfmaVersion, nonKDim); } @@ -505,7 +506,8 @@ class ScaledBlockedToMFMA final : public OpRewritePattern { return elemType == ScaleDotElemType::E2M1 || elemType == ScaleDotElemType::E4M3 || elemType == ScaleDotElemType::E5M2 || - elemType == ScaleDotElemType::BF16; + elemType == ScaleDotElemType::BF16 || + elemType == ScaleDotElemType::FP16; }; if (!supportsTypes(aElemType) || !supportsTypes(bElemType)) return rewriter.notifyMatchFailure(dotOp, "NYI: mxfp6 operand"); @@ -518,11 +520,20 @@ class ScaledBlockedToMFMA final : public OpRewritePattern { int numThreads = ttg::TritonGPUDialect::getThreadsPerWarp(moduleOp); // Choose a suitable MFMA instruction for this scaled dot op. + bool useFp16 = dotOp.getLhsType() == ScaleDotElemType::FP16 || + dotOp.getRhsType() == ScaleDotElemType::FP16; FailureOr mfmaInstr = - chooseMfmaInstruction(dotOp, mfmaVersion, nonKDim); + chooseMfmaInstruction(dotOp, mfmaVersion, nonKDim, useFp16); if (failed(mfmaInstr)) return rewriter.notifyMatchFailure(dotOp, "cannot choose mfma intrinsic"); + if (useFp16) { + dotOp.emitRemark( + "Warning: detected one dot_scaled operand is fp16 tensor so " + "upcasting to fp16 for computation, which impacts precision; " + "experimental behavior and may change in future"); + } + unsigned mDim = mfmaInstr.value().getMDim(); unsigned nDim = mfmaInstr.value().getNDim(); unsigned kDim = mfmaInstr.value().getKDim(); @@ -560,8 +571,8 @@ class ScaledBlockedToMFMA final : public OpRewritePattern { auto newAcc = rewriter.create( dotOp.getC().getLoc(), newRetType, dotOp.getC()); - auto toMMABf16 = [&](TensorValue v, int idx, - ScaleDotElemType type) -> TensorValue { + auto upcastForMMA = [&](TensorValue v, int idx, + ScaleDotElemType type) -> TensorValue { auto vType = v.getType(); auto newVEncoding = DotOperandEncodingAttr::get( ctx, idx, newRetType.getEncoding(), kWdiths[idx]); @@ -570,16 +581,19 @@ class ScaledBlockedToMFMA final : public OpRewritePattern { v = rewriter.create(v.getLoc(), newVType, v); // Don't need to covert int8 holding mxfp4--the upcast_mxfp op can // take int8 tensor as input. - if (type == ScaleDotElemType::BF16 || type == ScaleDotElemType::E2M1) + if (type == ScaleDotElemType::BF16 || type == ScaleDotElemType::FP16 || + type == ScaleDotElemType::E2M1) return v; - auto vTypeBf16 = RankedTensorType::get( - vType.getShape(), rewriter.getBF16Type(), newVEncoding); + auto upcastedType = RankedTensorType::get( + vType.getShape(), + useFp16 ? rewriter.getF16Type() : rewriter.getBF16Type(), + newVEncoding); return cast( - rewriter.create(v.getLoc(), vTypeBf16, v).getResult()); + rewriter.create(v.getLoc(), upcastedType, v).getResult()); }; - a = toMMABf16(a, 0, aElemType); - b = toMMABf16(b, 1, bElemType); + a = upcastForMMA(a, 0, aElemType); + b = upcastForMMA(b, 1, bElemType); // We need to have "matching" encoding between the main tensor and scale // tensor to make sure the scale values needed is in the same warp. So we @@ -598,10 +612,10 @@ class ScaledBlockedToMFMA final : public OpRewritePattern { auto newScaleEncoding = triton::gpu::BlockedEncodingAttr::get( ctx, {1, 1}, threadsPerWarp, blockWarpsPerCTA, {1, 0}, ctaLayout); - auto upcastMXFP = [&](TensorValue main, TensorValue scale, + auto upcastMXFP = [&](TensorValue v, TensorValue scale, ScaleDotElemType elemType) -> Value { if (!scale) - return main; + return v; auto newScaleType = RankedTensorType::get( scale.getType().getShape(), scale.getType().getElementType(), @@ -609,8 +623,13 @@ class ScaledBlockedToMFMA final : public OpRewritePattern { auto convOp = rewriter.create(scale.getLoc(), newScaleType, scale); - return rewriter.create(dotOp.getLoc(), main, - convOp, elemType); + Builder b(v.getContext()); + // TODO: Emit device assert to check scale tensor range fitting into fp16? + Type outputElemType = useFp16 ? b.getF16Type() : b.getBF16Type(); + auto outputType = + ttg::UpcastMXFPOp::deduceOutputType(v, elemType, outputElemType); + return rewriter.create(dotOp.getLoc(), outputType, v, + convOp, elemType); }; Value scaledA = upcastMXFP(a, aScale, dotOp.getLhsType());