[AMD] Support fp16 upcast in scaled dot (#5543)
AMD CDNA3 architectures do not have native bf16 VALU instructions, so doing the
scaling in bf16 can be expensive.

This commit prototypes upcasting to fp16 for the computation instead. It means
relaxing the dot_scaled frontend and the upcast_mxfp op definition to also accept fp16.

For now, the fp16 path is enabled whenever one input is fp16; a more explicit
opt-in mechanism may be introduced in the future.
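
As a rough illustration (not the compiler's actual dispatch code; the helper name is made up), the current trigger boils down to:

def uses_fp16_upcast_path(lhs_format: str, rhs_format: str) -> bool:
    # Prototype rule on AMD CDNA3: compute in fp16 when either input is
    # already fp16; otherwise keep the existing bf16 upcast path.
    return "fp16" in (lhs_format, rhs_format)

uses_fp16_upcast_path("e2m1", "fp16")  # True -> both operands upcast to fp16
uses_fp16_upcast_path("e2m1", "bf16")  # False -> bf16 upcast as before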
antiagainst authored Jan 10, 2025
1 parent 5e337e0 commit f9d9fad
Showing 15 changed files with 250 additions and 143 deletions.
4 changes: 2 additions & 2 deletions include/triton/Dialect/Triton/IR/TritonAttrDefs.td
@@ -128,8 +128,8 @@ def TT_ScaleDotElemTypeAttr : I32EnumAttr<
I32EnumAttrCase<"E2M3", 2, "e2m3">,
I32EnumAttrCase<"E3M2", 3, "e3m2">,
I32EnumAttrCase<"E2M1", 4, "e2m1">,
I32EnumAttrCase<"BF16", 5, "bf16">

I32EnumAttrCase<"BF16", 5, "bf16">,
I32EnumAttrCase<"FP16", 6, "fp16">
]>{
let cppNamespace = "::mlir::triton";
}
9 changes: 7 additions & 2 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
@@ -283,8 +283,8 @@ def TTG_LocalStoreOp : TTG_Op<"local_store", [DeclareOpInterfaceMethods<MemoryEf
}];
}

def TTG_UpcastMXFPOp : TTG_Op<"upcast_mxfp", [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "Convert an mxfp tensor to bf16";
def TTG_UpcastMXFPOp : TTG_Op<"upcast_mxfp", [Pure]> {
let summary = "Convert an mxfp tensor to bf16/fp16";

let hasVerifier = 1;

@@ -301,6 +301,11 @@ def TTG_UpcastMXFPOp : TTG_Op<"upcast_mxfp", [Pure, DeclareOpInterfaceMethods<In
let assemblyFormat = [{
$src `,` $scale `fp_type` `=` $fp_type attr-dict `:` type($src) `,` type($scale) `->` type($result)
}];

let extraClassDeclaration = [{
static RankedTensorType deduceOutputType(
TypedValue<RankedTensorType> inputTensor, ScaleDotElemType inputElemType, Type outputElemType);
}];
}

// Allocate global memory
74 changes: 33 additions & 41 deletions lib/Dialect/TritonGPU/IR/Ops.cpp
@@ -301,13 +301,15 @@ LogicalResult UpcastMXFPOp::verify() {

auto xTy = getSrc().getType();
auto scaleTy = getScale().getType();

if (xTy.getElementType() != FloatType::getBF16(getContext()) &&
xTy.getElementType() != IntegerType::get(getContext(), 8)) {
return emitOpError("element type of the first operand must be bf16 or i8");
Builder b(getContext());
if (xTy.getElementType() != b.getBF16Type() &&
xTy.getElementType() != b.getF16Type() &&
xTy.getElementType() != b.getI8Type()) {
return emitOpError(
"element type of the first operand must be bf16/fp16 or i8");
}

if (scaleTy.getElementType() != IntegerType::get(getContext(), 8)) {
if (scaleTy.getElementType() != b.getI8Type()) {
return emitOpError("element type of the second operand must be uint8");
}

@@ -381,44 +383,34 @@ LogicalResult UpcastMXFPOp::verify() {
return success();
}

LogicalResult UpcastMXFPOp::inferReturnTypes(
MLIRContext *ctx, std::optional<Location> loc, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties opaqueProperties,
RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
auto xTy = cast<RankedTensorType>(operands[0].getType());
auto properties = opaqueProperties.as<const Properties *>();
auto typeEncoded = properties->fp_type.getValue();
auto xShape = xTy.getShape();
RankedTensorType
UpcastMXFPOp::deduceOutputType(TypedValue<RankedTensorType> inputTensor,
ScaleDotElemType inputElemType,
Type outputElemType) {
MLIRContext *ctx = inputTensor.getContext();
auto xTy = inputTensor.getType();
if (inputElemType != ScaleDotElemType::E2M1)
return xTy;

auto xShape = xTy.getShape();
auto newShape = llvm::to_vector(xShape);
auto encoding = xTy.getEncoding();

if (typeEncoded == ScaleDotElemType::E2M1) {
RankedTensorType retTy;

auto newShape = SmallVector<int64_t>(xShape);
if (!encoding) {
newShape.back() *= 2;
retTy = RankedTensorType::get(xShape, FloatType::getBF16(ctx));
} else {
auto oldEncoding = cast<DotOperandEncodingAttr>(encoding);
auto newVEncoding = DotOperandEncodingAttr::get(
ctx, oldEncoding.getOpIdx(), oldEncoding.getParent(),
oldEncoding.getKWidth() * 2);
// Figure out the K dimension for the input A/B, given that the return
// type is upcasted A/B type so we need to update the proper dim size.
const int opIdx = oldEncoding.getOpIdx();
const bool hasBatch = xShape.size() == 3;
const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;
newShape[kIdx] *= 2;
retTy = RankedTensorType::get(newShape, FloatType::getBF16(ctx),
newVEncoding);
}
inferredReturnTypes.push_back(retTy);
} else {
inferredReturnTypes.push_back(xTy);
}

return success();
if (!encoding) {
newShape.back() *= 2;
return RankedTensorType::get(xShape, outputElemType);
}

auto oldEncoding = cast<DotOperandEncodingAttr>(encoding);
auto newVEncoding = DotOperandEncodingAttr::get(ctx, oldEncoding.getOpIdx(),
oldEncoding.getParent(),
oldEncoding.getKWidth() * 2);
// Figure out the K dimension for the input A/B, given that the return
// type is upcasted A/B type so we need to update the proper dim size.
const int opIdx = oldEncoding.getOpIdx();
const bool hasBatch = xShape.size() == 3;
const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;
newShape[kIdx] *= 2;
return RankedTensorType::get(newShape, outputElemType, newVEncoding);
}

OpFoldResult MemDescTransOp::fold(FoldAdaptor adaptor) {
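
As a reading aid, here is a minimal Python sketch of the shape rule that deduceOutputType is meant to implement (illustrative names, not Triton API): e2m1 packs two fp4 values per byte along K, so the upcasted tensor doubles its K dimension, while all other element types keep their shape.

def deduce_upcast_shape(shape, elem_format, op_idx=None, has_batch=False):
    # fp8/bf16/fp16 inputs are not packed, so the shape is unchanged.
    if elem_format != "e2m1":
        return list(shape)
    new_shape = list(shape)
    if op_idx is None:
        # No dot-operand encoding: the packed dimension is the last one.
        new_shape[-1] *= 2
        return new_shape
    # With a dot-operand encoding, K is dim 1 for operand A (op_idx == 0) and
    # dim 0 for operand B, shifted by one when a batch dimension is present.
    k_idx = (1 if op_idx == 0 else 0) + int(has_batch)
    new_shape[k_idx] *= 2  # A: [M, K/2] -> [M, K]; B: [K/2, N] -> [K, N]
    return new_shape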
6 changes: 4 additions & 2 deletions lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -573,8 +573,10 @@ class DecomposeScaledBlocked
maybeWithEncoding(scale.getType(), scaleEncoding);
scale = rewriter.create<ConvertLayoutOp>(scale.getLoc(),
newScaleDotElemType, scale);
ret = rewriter.create<triton::gpu::UpcastMXFPOp>(v.getLoc(), ret, scale,
type);
auto retTy = triton::gpu::UpcastMXFPOp::deduceOutputType(
ret, type, Builder(v.getContext()).getBF16Type());
ret = rewriter.create<triton::gpu::UpcastMXFPOp>(v.getLoc(), retTy, ret,
scale, type);
}
return ret;
}
1 change: 1 addition & 0 deletions python/src/ir.cc
@@ -232,6 +232,7 @@ void init_triton_ir(py::module &&m) {
.value("E3M2", ScaleDotElemType::E3M2)
.value("E2M1", ScaleDotElemType::E2M1)
.value("BF16", ScaleDotElemType::BF16)
.value("FP16", ScaleDotElemType::FP16)
.export_values();

py::class_<MLIRContext>(m, "context", py::module_local())
99 changes: 64 additions & 35 deletions python/test/unit/language/test_core.py
@@ -3521,19 +3521,24 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid
for col_a, col_b in itertools.product([True, False], repeat=2)
for rhs_scale in [False, True]
for mxfp_type in ["e2m1", "e4m3", "e5m2"]
for normal_type in ["e4m3", "e5m2", "bf16"]
for normal_type in ["e4m3", "e5m2", "bf16", "fp16"]
for mma in (mma_nonk_sizes if is_hip() else [16])
for kpack in ([1, 2] if is_hip() else [1])])
def test_scaled_dot(M, N, K, col_a, col_b, rhs_scale, mxfp_type, normal_type, num_warps, mma, kpack, device):
if is_cuda():
if normal_type == "fp16":
pytest.skip("scaled_dot with fp16 input not supported on CUDA yet")
cc = torch.cuda.get_device_capability()
if cc < (8, 9):
pytest.skip("float8e4nv not supported on CUDA < 8.9")
if is_hip():
if not is_hip_cdna():
pytest.skip("scaled_dot only implemented for HIP CDNA")
if "e4m3" in (mxfp_type, normal_type) and not is_hip_mi300():
pytest.skip(f"scaled_dot({mxfp_type}, {normal_type}) only implemented for MI300")
if "e4m3" in (mxfp_type, normal_type):
if not is_hip_mi300():
pytest.skip(f"scaled_dot({mxfp_type}, {normal_type}) only implemented for MI300")
if normal_type == "fp16":
pytest.skip(f"scaled_dot({mxfp_type}, {normal_type}) not yet implemented for MI300")
if mma == 16 and K == 64:
pytest.skip(f"K == {K} too small for mfma {mma} in scaled_dot")

@@ -3566,13 +3571,14 @@ def dot_scale_kernel(a_base, stride_a0, stride_a1, a_scale, b_base, stride_b0, s
tl.store(out_ptr, c.to(tl.bfloat16))

@triton.jit
def mxfp_to_bf16_kernel(
def mxfp_upcast_kernel(
x_ptr,
scale_ptr,
mxfp_ptr,
N,
e_bits: tl.constexpr,
m_bits: tl.constexpr,
to_type: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
# x.shape == (N, 32) for fp8 or (N, 16) for fp4
@@ -3594,52 +3600,66 @@ def mxfp_to_bf16_kernel(
tl.static_assert(scale.dtype == tl.uint8)
tl.static_assert(x.dtype == tl.uint8)

scale_bf16 = (scale.to(tl.uint16) << 7).to(tl.bfloat16, bitcast=True)
if to_type == tl.bfloat16:
upcasted_scale = (scale.to(tl.uint16) << 7).to(tl.bfloat16, bitcast=True)
else:
tl.static_assert(to_type == tl.float16)
scale_fp32 = (scale.to(tl.uint32) << 23).to(tl.float32, bitcast=True)
upcasted_scale = scale_fp32.to(tl.float16)

to_e_bits: tl.constexpr = 8 if to_type == tl.bfloat16 else 5
to_m_bits: tl.constexpr = 7 if to_type == tl.bfloat16 else 10
if is_fp8:
if e_bits == 5 and m_bits == 2:
x_f8 = x.to(tl.float8e5, bitcast=True)
x_bf16 = x_f8.to(tl.bfloat16)
upcasted_x = x_f8.to(to_type)
# Preserve infs and nans. FIXME Fp8E5M2_to_Bf16 doesn't preserve them!
non_finite_mask: tl.constexpr = ((1 << e_bits) - 1) << m_bits
non_finite_mask_bf16: tl.constexpr = ((1 << 8) - 1) << 7
x_bf16 = tl.where(
non_finite_mask_16bit: tl.constexpr = ((1 << to_e_bits) - 1) << to_m_bits
upcasted_x = tl.where(
x & non_finite_mask == non_finite_mask,
(x_bf16.to(tl.uint16, bitcast=True) | non_finite_mask_bf16).to(tl.bfloat16, bitcast=True),
x_bf16,
(upcasted_x.to(tl.uint16, bitcast=True) | non_finite_mask_16bit).to(to_type, bitcast=True),
upcasted_x,
)
else:
tl.static_assert(e_bits == 4 and m_bits == 3)
x_f8 = x.to(tl.float8e4nv, bitcast=True)
x_bf16 = x_f8.to(tl.bfloat16)
upcasted_x = x_f8.to(to_type)
else:
to_bias: tl.constexpr = 127 if to_type == tl.bfloat16 else 15
to_point5: tl.constexpr = 16128 if to_type == tl.bfloat16 else 0x3800
# e2m1
em0 = x & 0x7
em1 = x & 0x70
x0 = (em0.to(tl.uint16) << 2 + 4) | ((x & 0x8).to(tl.uint16) << 8 + 4)
x1 = (em1.to(tl.uint16) << 2) | ((x & 0x80).to(tl.uint16) << (8))
x0 = (em0.to(tl.uint16) << (to_m_bits - 1)) | ((x & 0x8).to(tl.uint16) << 12)
x1 = (em1.to(tl.uint16) << (to_m_bits - 1 - 4)) | ((x & 0x80).to(tl.uint16) << 8)
# Three cases:
# 1) x is normal and non-zero: Correct bias
x0 = tl.where((em0 & 0x6) != 0, x0 + ((127 - 1) << 7), x0)
x1 = tl.where((em1 & 0x60) != 0, x1 + ((127 - 1) << 7), x1)
x0 = tl.where((em0 & 0x6) != 0, x0 + ((to_bias - 1) << to_m_bits), x0)
x1 = tl.where((em1 & 0x60) != 0, x1 + ((to_bias - 1) << to_m_bits), x1)
# 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in bf16
x0 = tl.where(em0 == 0x1, 16128 | (x0 & 0x8000), x0)
x1 = tl.where(em1 == 0x10, 16128 | (x1 & 0x8000), x1)
x0 = tl.where(em0 == 0x1, to_point5 | (x0 & 0x8000), x0)
x1 = tl.where(em1 == 0x10, to_point5 | (x1 & 0x8000), x1)
# 3) x is zero, do nothing
x_bf16 = tl.interleave(x0, x1).to(tl.bfloat16, bitcast=True)
# Multiplication preserves infs and NaNs in x_bf16
mxfp = x_bf16 * scale_bf16
# If scale is NaN, we encode it as an bf16 inf, so we need to correct for that
upcasted_x = tl.interleave(x0, x1).to(to_type, bitcast=True)
# Multiplication preserves infs and NaNs in upcasted_x
mxfp = upcasted_x * upcasted_scale
# If scale is NaN, we encode it as an inf, so we need to correct for that
mxfp = tl.where(scale == 0xFF, float("nan"), mxfp)

offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
tl.store(mxfp_ptr + offsets, tl.ravel(mxfp), mask=offsets < N * 32)

def dot_scale_ref(x, scale_x, y, scale_y, type_x, type_y):

def upcast(v, scale, type, transposed):
comp_dtype = torch.bfloat16
def upcast(v, scale, type, comp_dtype, transposed):
if scale is None:
type = {"e4m3": torch.float8_e4m3fn, "e5m2": torch.float8_e5m2, "bf16": torch.bfloat16}[type]
type = {
"e4m3": torch.float8_e4m3fn,
"e5m2": torch.float8_e5m2,
"bf16": torch.bfloat16,
"fp16": torch.float16,
}[type]
return v.view(type).to(comp_dtype)
e_bits, m_bits = {"e2m1": (2, 1), "e4m3": (4, 3), "e5m2": (5, 2)}[type]
# Packing is always on the K dimension so we transpose before upcasting then transpose back.
@@ -3650,15 +3670,19 @@ def upcast(v, scale, type, transposed):
N = v_upcast.numel()
BLOCK_SIZE = 512
grid = ((N + BLOCK_SIZE - 1) // BLOCK_SIZE, )
mxfp_to_bf16_kernel[grid](v, scale, v_upcast, scale.numel(), e_bits, m_bits, BLOCK_SIZE,
num_warps=num_warps)
comp_dtype = tl.float16 if comp_dtype == torch.float16 else tl.bfloat16
mxfp_upcast_kernel[grid](v, scale, v_upcast, scale.numel(), e_bits, m_bits, comp_dtype, BLOCK_SIZE,
num_warps=num_warps)
assert v_upcast.isfinite().all()
if transposed:
v_upcast = v_upcast.mT
return v_upcast

x_upcast = upcast(x, scale_x, type_x, False)
y_upcast = upcast(y, scale_y, type_y, True)
# Upcast to fp16 if one of the inputs is fp16
comp_dtype = torch.float16 if "fp16" in (type_x, type_y) else torch.bfloat16

x_upcast = upcast(x, scale_x, type_x, comp_dtype, False)
y_upcast = upcast(y, scale_y, type_y, comp_dtype, True)

class AccumulateInFp32:

@@ -3672,15 +3696,21 @@ def __exit__(self, exc_type, exc_val, exc_tb):
with AccumulateInFp32():
return torch.matmul(x_upcast, y_upcast)

comp_dtype = torch.float16 if normal_type == "fp16" else torch.bfloat16
comp_dtype_bias = 15 if normal_type == "fp16" else 127
# The max exponent we use to initialize data in the x/y and associated scale tensor to avoid
# overflow when scaling.
comp_dtype_max_exp = 6 if normal_type == "fp16" else 15

torch.manual_seed(0)

def make_arg(shape, ty, col_major=False, max_val=255):
if col_major:
shape = shape[:-2] + (shape[-1], shape[-2])
if ty == "bf16":
ret = torch.randn(shape, dtype=torch.bfloat16, device=device)
if ty == "bf16" or ty == "fp16":
ret = torch.randn(shape, dtype=comp_dtype, device=device)
# Clamp to avoid relative error issues
ret.clamp_(-2**15, 2**15 - 1)
ret.clamp_(-2**comp_dtype_max_exp, 2**comp_dtype_max_exp - 1)
else:
ret = torch.randint(max_val + 1, shape, dtype=torch.uint8, device=device)
if col_major:
@@ -3696,9 +3726,8 @@ def make_arg(shape, ty, col_major=False, max_val=255):
y = make_arg((K // DIV_FACTOR_B, N), type_b, col_major=col_b)

# sample scales that don't overflow as otherwise it's implementation defined (underflowing is alright)
# Max scale= 2**15
scale_x = make_arg((M, K // 32), "e8m0", max_val=127 + 15)
scale_y = make_arg((N, K // 32), "e8m0", max_val=127 + 15)
scale_x = make_arg((M, K // 32), "e8m0", max_val=comp_dtype_bias + comp_dtype_max_exp)
scale_y = make_arg((N, K // 32), "e8m0", max_val=comp_dtype_bias + comp_dtype_max_exp)
if rhs_scale:
scale_x = None
else:
@@ -3721,7 +3750,7 @@ def make_finite(x, dtype):
if is_hip():
kernel_kwargs["kpack"] = kpack
kernel_kwargs["matrix_instr_nonkdim"] = mma
z = x.new_empty((M, N), dtype=torch.bfloat16)
z = x.new_empty((M, N), dtype=comp_dtype)
pgm = dot_scale_kernel[(1, )](x, *x.stride(), scale_x, y, *y.stride(), scale_y, z, M, N, K, type_a, type_b,
**kernel_kwargs)
z_ref = dot_scale_ref(x, scale_x, y, scale_y, type_a, type_b)
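
As a host-side companion to the kernel above, the following is a hedged torch sketch of the e8m0 scale upcast used in the fp16 path (the function name is made up; it mirrors the bit trick in mxfp_upcast_kernel rather than defining any Triton API): an e8m0 byte s encodes 2**(s - 127), which is materialized by placing s in the fp32 exponent field and then narrowing to the compute type.

import torch

def upcast_e8m0_scale(scale: torch.Tensor, comp_dtype=torch.float16) -> torch.Tensor:
    assert scale.dtype == torch.uint8
    # Place the e8m0 byte into the fp32 exponent field: value == 2**(scale - 127).
    scale_fp32 = (scale.to(torch.int32) << 23).view(torch.float32)
    # Narrowing to fp16 overflows to inf once 2**(scale - 127) exceeds 65504,
    # which is why the test limits the sampled scale exponents for fp16.
    out = scale_fp32.to(comp_dtype)
    # e8m0 0xFF means NaN, but the bit trick yields +inf; patch it up like the kernel does.
    return torch.where(scale == 0xFF, torch.full_like(out, float("nan")), out)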
11 changes: 9 additions & 2 deletions python/triton/language/core.py
@@ -1740,17 +1740,24 @@ def dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None,
lhs and rhs use microscaling formats described here:
https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
Software emulation enables targeting hardware architectures without native microscaling
operation support. Currently, in that case, the microscaled lhs/rhs are upcast to the
:code:`bf16` element type before the dot computation, with one exception: on AMD CDNA3
specifically, if one of the inputs has the :code:`fp16` element type, the other input is
upcast to :code:`fp16` instead.
This behavior is experimental and may change in the future.
:param lhs: The first tensor to be multiplied.
:type lhs: 2D tensor representing fp4, fp8 or bf16 elements. Fp4 elements are packed into uint8 inputs with the first element in lower bits. Fp8 are stored as uint8 or the corresponding fp8 type.
:param lhs_scale: Scale factor for lhs tensor.
:type lhs_scale: e8m0 type represented as an uint8 tensor.
:param lhs_format: format of the lhs tensor. Available formats: {:code:`e2m1`, :code:`e4m3`, :code:`e5m2`, :code:`bf16`}.
:param lhs_format: format of the lhs tensor. Available formats: {:code:`e2m1`, :code:`e4m3`, :code:`e5m2`, :code:`bf16`, :code:`fp16`}.
:type lhs_format: str
:param rhs: The second tensor to be multiplied.
:type rhs: 2D tensor representing fp4, fp8 or bf16 elements. Fp4 elements are packed into uint8 inputs with the first element in lower bits. Fp8 are stored as uint8 or the corresponding fp8 type.
:param rhs_scale: Scale factor for rhs tensor.
:type rhs_scale: e8m0 type represented as an uint8 tensor.
:param rhs_format: format of the rhs tensor. Available formats: {:code:`e2m1`, :code:`e4m3`, :code:`e5m2`, :code:`bf16`}.
:param rhs_format: format of the rhs tensor. Available formats: {:code:`e2m1`, :code:`e4m3`, :code:`e5m2`, :code:`bf16`, :code:`fp16`}.
:type rhs_format: str
:param acc: The accumulator tensor. If not None, the result is added to this tensor.
"""
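
For illustration, a minimal kernel sketch exercising the relaxed format (hypothetical kernel name and layout; assumes power-of-two M/N/K, row-major contiguous uint8 storage for the packed e2m1 lhs and its e8m0 scales, and fp16 storage for the rhs):

import triton
import triton.language as tl

@triton.jit
def dot_scaled_fp16_kernel(a_ptr, a_scale_ptr, b_ptr, out_ptr,
                           M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
    # lhs: e2m1 values packed two per uint8 along K, with one e8m0 scale per 32 elements.
    # rhs: plain fp16 with no scale, which selects the fp16 compute path on AMD CDNA3.
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    a = tl.load(a_ptr + offs_m[:, None] * (K // 2) + tl.arange(0, K // 2)[None, :])
    a_scale = tl.load(a_scale_ptr + offs_m[:, None] * (K // 32) + tl.arange(0, K // 32)[None, :])
    b = tl.load(b_ptr + tl.arange(0, K)[:, None] * N + offs_n[None, :])
    c = tl.dot_scaled(a, a_scale, "e2m1", b, None, "fp16")  # fp32 accumulator
    tl.store(out_ptr + offs_m[:, None] * N + offs_n[None, :], c.to(tl.float16))

Launched as dot_scaled_fp16_kernel[(1,)](a, a_scale, b, out, M=M, N=N, K=K) with uint8 a and a_scale tensors and an fp16 b tensor, this mirrors the single-block dot_scale_kernel used by test_scaled_dot above.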