This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Use mm in subclass #128

Closed
wants to merge 13 commits
149 changes: 37 additions & 112 deletions float8_experimental/float8_linear.py
@@ -18,7 +18,7 @@
)

from float8_experimental.float8_python_api import mm_float8
from float8_experimental.float8_tensor import Float8Tensor, to_float8
from float8_experimental.float8_tensor import Float8Tensor

from float8_experimental.float8_utils import (
amax_history_to_scale,
@@ -44,10 +44,12 @@ def forward(
fp8_scale_dL_dY,
scale_fn_name,
is_amax_initialized,
emulate: bool,
):
ctx.save_for_backward(fp8_amax_dL_dY, fp8_amax_history_dL_dY, fp8_scale_dL_dY)
ctx.scale_fn_name = scale_fn_name
ctx.is_amax_initialized = is_amax_initialized
ctx.emulate = emulate
return tensor

@staticmethod
@@ -69,99 +71,11 @@ def backward(ctx, go):
fp8_amax_dL_dY.fill_(tensor_to_amax(go))
go_scaled = go * fp8_scale_dL_dY
bits_fp8 = to_fp8_saturated(go_scaled, torch.float8_e5m2)
empty_grads = None, None, None, None, None
res = Float8Tensor(bits_fp8, fp8_scale_dL_dY, go.dtype)
empty_grads = None, None, None, None, None, None
res = Float8Tensor(bits_fp8, fp8_scale_dL_dY, go.dtype, emulate=ctx.emulate)
return res, *empty_grads


class float8_linear(torch.autograd.Function):
"""
Like F.linear, but with X and W in float8
"""

@staticmethod
def forward(
ctx,
x_fp8,
w_fp8,
is_amax_initialized,
scale_fn_name,
emulate: bool,
):
ctx.save_for_backward(x_fp8, w_fp8)
ctx.scale_fn_name = scale_fn_name
ctx.emulate = emulate
orig_shape = x_fp8._data.shape
x_fp8_reshaped = Float8Tensor(
x_fp8._data.reshape(-1, orig_shape[-1]), x_fp8._scale, x_fp8._orig_dtype
)
ctx.is_amax_initialized = is_amax_initialized

w_fp8_t = Float8Tensor(w_fp8._data.t(), w_fp8._scale, w_fp8._orig_dtype)

res_bits, _output_amax = mm_float8(
x_fp8_reshaped, w_fp8_t, output_dtype=x_fp8._orig_dtype, emulate=emulate
)
res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1])
return res_bits

@staticmethod
def backward(ctx, go_fp8):
x_fp8, w_fp8 = ctx.saved_tensors
scale_fn_name = ctx.scale_fn_name
emulate = ctx.emulate
is_amax_initialized = ctx.is_amax_initialized

go_fp8_orig_shape = go_fp8._data.shape
go_fp8_reshaped = Float8Tensor(
go_fp8._data.reshape(-1, go_fp8_orig_shape[-1]),
go_fp8._scale,
go_fp8._orig_dtype,
)

w_fp8_t_c_t = Float8Tensor(
w_fp8._data.t().contiguous().t(), w_fp8._scale, w_fp8._orig_dtype
)

#
# calculate dL/dX
#
dL_dX, _dL_dX_amax = mm_float8(
go_fp8_reshaped,
w_fp8_t_c_t,
output_dtype=x_fp8._orig_dtype,
emulate=emulate,
)
dL_dX = dL_dX.reshape(*go_fp8_orig_shape[:-1], dL_dX.shape[-1])

x_fp8_orig_shape = x_fp8._data.shape
x_fp8_reshaped_t_c = Float8Tensor(
x_fp8._data.reshape(-1, x_fp8_orig_shape[-1]).t().contiguous(),
x_fp8._scale,
x_fp8._orig_dtype,
)

go_fp8_reshaped_t_c_t = Float8Tensor(
go_fp8_reshaped._data.t().contiguous().t(),
go_fp8_reshaped._scale,
go_fp8_reshaped._orig_dtype,
)

#
# calculate dL/dW
#
dL_dW, _dL_dW_amax = mm_float8(
x_fp8_reshaped_t_c,
go_fp8_reshaped_t_c_t,
output_dtype=x_fp8._orig_dtype,
emulate=emulate,
)
dL_dW = dL_dW.t()

empty_grads = None, None, None, None, None, None, None, None, None
return dL_dX, dL_dW, *empty_grads


@dataclasses.dataclass
class DelayedScalingRecipe:
# Controls the history length of amax buffers
@@ -221,13 +135,17 @@ def __init__(self, *args, **kwargs):
# will access the scale when it has ensured that it is on GPU.
self._float8_tensor_ctor = lambda *args, **kwargs: Float8Tensor(*args, **kwargs)

def cast_x_to_float8(self, x, is_amax_initialized):
def cast_x_to_float8(
self, x: torch.Tensor, is_amax_initialized: bool
) -> torch.Tensor:
# Duplicate the autocast logic for F.linear, so that the output
# of our module has the right original precision
if torch.is_autocast_enabled():
# For now, hardcode to GPU's autocast dtype
# if we need CPU support in the future, we can add it
x = x.to(torch.get_autocast_gpu_dtype())
autocast_dtype = torch.get_autocast_gpu_dtype()
x = x.to(autocast_dtype)
self.bias_dtype = autocast_dtype

scale_fn_name = self.recipe.scale_fn_name
_maybe_initialize_amaxes_scales_for_float8_cast(
@@ -239,10 +157,14 @@ def cast_x_to_float8(self, x, is_amax_initialized):
torch.float8_e4m3fn,
is_amax_initialized,
)
x_fp8 = to_float8(x, self.fp8_scale_x, torch.float8_e4m3fn, self.fp8_amax_x)
x_fp8 = Float8Tensor.to_float8(
x, self.fp8_scale_x, torch.float8_e4m3fn, self.fp8_amax_x, self.emulate
)
return x_fp8

def cast_w_to_float8(self, w, is_amax_initialized):
def cast_w_to_float8(
self, w: torch.Tensor, is_amax_initialized: bool
) -> torch.Tensor:
scale_fn_name = self.recipe.scale_fn_name
_maybe_initialize_amaxes_scales_for_float8_cast(
w,
@@ -253,10 +175,14 @@ def cast_w_to_float8(self, w, is_amax_initialized):
torch.float8_e4m3fn,
is_amax_initialized,
)
w_fp8 = to_float8(w, self.fp8_scale_w, torch.float8_e4m3fn, self.fp8_amax_w)
w_fp8 = Float8Tensor.to_float8(
w, self.fp8_scale_w, torch.float8_e4m3fn, self.fp8_amax_w, self.emulate
)
return w_fp8

def cast_y_to_float8_in_bw(self, y):
def cast_y_to_float8_in_bw(
self, y: torch.Tensor, emulate: bool = False
) -> torch.Tensor:
scale_fn_name = self.recipe.scale_fn_name
y = NoopFwToFloat8E5M2Bw.apply(
y,
@@ -265,13 +191,7 @@ def cast_y_to_float8_in_bw(self, y):
self.fp8_scale_dL_dY,
scale_fn_name,
self.is_amax_initialized,
)
return y

def float8_mm(self, x_fp8, w_fp8, is_amax_initialized):
scale_fn_name = self.recipe.scale_fn_name
y = float8_linear.apply(
x_fp8, w_fp8, is_amax_initialized, scale_fn_name, self.emulate
emulate,
)
return y

@@ -292,6 +212,11 @@ def float8_post_forward(self):
self.is_amax_initialized = True
self.amax_and_scale_synced = False

def add_weight_tag(self):
# We add a tag to the weight nn.Parameter in order to signal
# To FSDP that this param is a weight
self.weight._is_fp8_weight = True
Contributor
I'm reviewing the subclass changes but probably not the right person to review this one

Contributor Author
This was added in a previous PR; it is just being moved to the Mixin so that it can also be added to the TP code.

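As a rough illustration of what the tag is for, wrapping code (for example, an FSDP integration) could partition parameters by checking the attribute. The helper below is hypothetical; only the _is_fp8_weight attribute name comes from this PR.

import torch

def partition_fp8_weights(module: torch.nn.Module):
    # Hypothetical consumer of the tag set by add_weight_tag(); only the
    # _is_fp8_weight attribute name is taken from this PR.
    fp8_weights, other_params = [], []
    for param in module.parameters():
        if getattr(param, "_is_fp8_weight", False):
            fp8_weights.append(param)
        else:
            other_params.append(param)
    return fp8_weights, other_params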


class Float8Linear(Float8LinearMixin, torch.nn.Linear):
"""
Expand All @@ -311,11 +236,14 @@ def forward(self, x):
w_fp8 = self._w_fp8
else:
w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized)
y = self.float8_mm(x_fp8, w_fp8, self.is_amax_initialized)
y = self.cast_y_to_float8_in_bw(y)

y = torch.matmul(x_fp8, w_fp8.t())

# Cast gradY to float8_e5m2 during backward
y = self.cast_y_to_float8_in_bw(y, self.emulate)
Contributor
Mentioned offline, but food for thought that I'll also note here: it would be interesting to think about what it would take to have aten.matmul(Float8Tensor, Float8Tensor) actually return another Float8Tensor, and then leave it to the subclass to know to upcast on future ops that don't want to handle float8 directly.

My understanding was:

(1) This is a pain mostly because the extra buffers for float8 live directly on the Float8Linear nn module today and not the subclass (probably for good reason)

(2) Doing this would provide benefit if we want to start increasing the number of ops that directly handle float8, but if all we care about is linear then this generality is probably not very useful.

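To make the idea concrete, here is a minimal toy sketch. It is not part of this PR and not the real Float8Tensor: a wrapper subclass whose __torch_dispatch__ keeps mm results wrapped and upcasts for every op it does not handle, with the scaling deliberately reduced to a no-op.

import torch
from torch.utils._pytree import tree_map

class ToyFloat8Tensor(torch.Tensor):
    # Toy wrapper: stores plain data plus a scale, pretends mm stays "float8",
    # and upcasts to the original dtype for every other op.
    @staticmethod
    def __new__(cls, data, scale, orig_dtype):
        return torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=orig_dtype, device=data.device
        )

    def __init__(self, data, scale, orig_dtype):
        self._data = data
        self._scale = scale
        self._orig_dtype = orig_dtype

    def to_original_precision(self):
        return self._data.to(self._orig_dtype) / self._scale

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.ops.aten.mm.default:
            # toy assumption: both operands are wrapped; keep the result wrapped
            a, b = args
            out = torch.mm(a.to_original_precision(), b.to_original_precision())
            return cls(out, torch.ones((), device=out.device), a._orig_dtype)
        # fallback: unwrap (upcast) any wrapped args and run in high precision
        def unwrap(t):
            return t.to_original_precision() if isinstance(t, cls) else t
        return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))

With something like this, the matmul in forward() could itself return a wrapped tensor and later ops would transparently upcast, which is the trade-off point (2) above is getting at.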

if self.bias is not None:
y = y + self.bias.to(x_fp8._orig_dtype)
y = y + self.bias.to(self.bias_dtype)

self.float8_post_forward()
return y
@@ -336,16 +264,13 @@ def from_float(cls, mod, emulate: bool = False):
new_mod.weight = mod.weight
new_mod.bias = mod.bias
new_mod.emulate = emulate
if mod.bias is not None:
new_mod.bias_dtype = mod.bias.dtype
# I think it's okay to send all params and buffers to device
new_mod.to(mod.weight.device)
new_mod.add_weight_tag()
return new_mod

def add_weight_tag(self):
# We add a tag to the weight nn.Parameter in order to signal
# To FSDP that this param is a weight
self.weight._is_fp8_weight = True


def swap_linear_with_float8_linear(
model,
84 changes: 84 additions & 0 deletions float8_experimental/float8_ops.py
@@ -0,0 +1,84 @@
from typing import Any, Dict

import torch
from float8_experimental.float8_python_api import mm_float8_unwrapped
from float8_experimental.float8_tensor import Float8Tensor
from float8_experimental.float8_utils import is_row_major

aten = torch.ops.aten
FLOAT8_OPS_TABLE: Dict[Any, Any] = {}


def implements(aten_ops):
"""Register aten ops to the float8 op table"""

def decorator(func):
for op in aten_ops:
FLOAT8_OPS_TABLE[op] = func
return func

return decorator


@implements(
[
aten.view.default,
aten._unsafe_view.default,
aten.t.default,
aten.as_strided.default,
aten.clone.default,
aten.detach.default,
]
)
def float8_desugar_op(aten_op, args, kwargs=None):
new_data = aten_op(args[0]._data, *args[1:], **kwargs)
return Float8Tensor(new_data, args[0]._scale, args[0]._orig_dtype, args[0]._emulate)


@implements([aten.mm.default])
def float8_mm(aten_op, args, kwargs=None):
assert isinstance(args[0], Float8Tensor) and isinstance(args[1], Float8Tensor)
a = args[0]
b = args[1]
a_data = a._data
a_scale = a._scale
b_data = b._data

if not is_row_major(a_data.stride()):
a_data = a_data.contiguous()
if is_row_major(b_data.stride()):
b_data = b_data.t().contiguous().t()
b_scale = b._scale
output_dtype = a._orig_dtype
if a._emulate:
assert a._emulate == b._emulate
return torch.ops.aten.mm_float8_emulated(
Contributor
Oh, also just thinking: should emulate just be a global config somewhere, instead of a flag that you have to plumb around?

Contributor Author
Talked about this with Brian offline. This is probably right, but I am going to do it in a follow-up. I also want to see if we can get plain torch.nn.functional.linear working in Float8Linear, and will make some matmul changes. (See the sketch after the float8_mm handler below for one possible shape of a global config.)

a._data, a._scale, b._data, b._scale, output_dtype
)[0]
tensor_out, amax = mm_float8_unwrapped(
a_data, a_scale, b_data, b_scale, output_dtype, output_scale=None
)
return tensor_out

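Picking up the review thread above about making emulate a global config rather than a plumbed flag, a minimal sketch of one possible shape for that follow-up is shown below; the module layout and names here are assumptions, not part of this PR.

import contextlib
from dataclasses import dataclass

@dataclass
class Float8Config:
    # global knob the fp8 ops could consult instead of a per-tensor flag
    emulate: bool = False

config = Float8Config()

@contextlib.contextmanager
def use_emulation(enabled: bool = True):
    # temporarily flip emulation on or off for a region of code
    prev = config.emulate
    config.emulate = enabled
    try:
        yield
    finally:
        config.emulate = prev

float8_mm could then branch on config.emulate instead of a._emulate, and callers would no longer need to thread the flag through Float8Tensor constructors.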

@implements([aten.is_same_size.default])
def float8_is_same_size(aten_op, args, kwargs=None):
return args[0].shape == args[1].shape


@implements([aten._to_copy.default])
def autocast_to_copy(aten_op, args, kwargs=None):
"""This gets called when running matmul under autocast
when the input is a Float8Tensor, presenting as a fp32
tensor.
"""
assert isinstance(args[0], Float8Tensor)
assert (
len(kwargs) == 1 and "dtype" in kwargs
), "Only support dtype kwarg for autocast"
assert (
kwargs["dtype"] == torch.float16
), "Only support floating point conversion for autocast w/ Float8Tensor"
return Float8Tensor(
args[0]._data, args[0]._scale, kwargs["dtype"], args[0]._emulate
)
2 changes: 1 addition & 1 deletion float8_experimental/float8_python_api.py
@@ -4,7 +4,7 @@
to simplify the product code.
"""

import warnings

from typing import Optional, Tuple

import float8_experimental.float8_aten_api