
Fix graph breaks in tensor subclass (#131)
Summary:
For a more detailed understanding of the current status, see #106.

This change removes all graph breaks on the main work branch.
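In practice, this means a Float8Linear forward and backward can be traced by torch.compile without splitting into multiple graphs. A minimal sketch of how one might check for graph breaks is below; it is illustrative and not part of the commit, and it assumes a `Float8Linear.from_float(mod, emulate=...)` constructor as used elsewhere in this repo.

```python
# Illustrative check, not part of this commit: fullgraph=True makes dynamo
# raise on the first graph break instead of silently splitting the graph.
import torch
import torch.nn as nn

from float8_experimental.float8_linear import Float8Linear

m = Float8Linear.from_float(nn.Linear(16, 16), emulate=True)  # assumed constructor
m = torch.compile(m, fullgraph=True)
x = torch.randn(4, 16, requires_grad=True)
m(x).sum().backward()  # runs the compiled forward plus the autograd backward
```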

Pull Request resolved: #131

Reviewed By: albanD

Differential Revision: D50758815

Pulled By: drisspg

fbshipit-source-id: 1502601099988b1eba666306e327eb724eb14989
drisspg authored and facebook-github-bot committed Oct 27, 2023
1 parent 436102b commit 429a313
Showing 4 changed files with 29 additions and 29 deletions.
7 changes: 4 additions & 3 deletions float8_experimental/__init__.py
@@ -1,4 +1,5 @@
-# Lets wait to define the top level interface
-# from float8_experimental.float8_tensor import Float8Tensor
+# Lets define a few top level things here
+from float8_experimental.float8_tensor import Float8Tensor
+from float8_experimental.float8_linear import Float8Linear
 
-# __all__ = ["Float8Tensor"]
+__all__ = ["Float8Tensor", "Float8Linear"]
23 changes: 6 additions & 17 deletions float8_experimental/float8_linear.py
@@ -18,7 +18,7 @@
 )
 
 from float8_experimental.float8_python_api import mm_float8
-from float8_experimental.float8_tensor import Float8Tensor
+from float8_experimental.float8_tensor import Float8Tensor, to_float8
 
 from float8_experimental.float8_utils import (
     amax_history_to_scale,
@@ -174,9 +174,7 @@ class DelayedScalingRecipe:
 
 class Float8LinearMixin(object):
     def __init__(self, *args, **kwargs):
-        delayed_scaling_recipe = kwargs.pop(
-            "delayed_scaling_recipe", DelayedScalingRecipe()
-        )
+        delayed_scaling_recipe = kwargs.pop("delayed_scaling_recipe", DelayedScalingRecipe())
         super().__init__(*args, **kwargs)
 
         # TODO(future): have a unique recipe per buffer instead of one per
@@ -239,10 +237,7 @@ def cast_x_to_float8(self, x, is_amax_initialized):
             torch.float8_e4m3fn,
             is_amax_initialized,
         )
-        x_fp8 = Float8Tensor.to_float8(
-            x, self.fp8_scale_x, torch.float8_e4m3fn, self.fp8_amax_x
-        )
-
+        x_fp8 = to_float8(x, self.fp8_scale_x, torch.float8_e4m3fn, self.fp8_amax_x)
         return x_fp8
 
     def cast_w_to_float8(self, w, is_amax_initialized):
@@ -256,9 +251,7 @@ def cast_w_to_float8(self, w, is_amax_initialized):
             torch.float8_e4m3fn,
             is_amax_initialized,
         )
-        w_fp8 = Float8Tensor.to_float8(
-            w, self.fp8_scale_w, torch.float8_e4m3fn, self.fp8_amax_w
-        )
+        w_fp8 = to_float8(w, self.fp8_scale_w, torch.float8_e4m3fn, self.fp8_amax_w)
         return w_fp8
 
     def cast_y_to_float8_in_bw(self, y):
@@ -275,9 +268,7 @@ def cast_y_to_float8_in_bw(self, y):
 
     def float8_mm(self, x_fp8, w_fp8, is_amax_initialized):
         scale_fn_name = self.recipe.scale_fn_name
-        y = float8_linear.apply(
-            x_fp8, w_fp8, is_amax_initialized, scale_fn_name, self.emulate
-        )
+        y = float8_linear.apply(x_fp8, w_fp8, is_amax_initialized, scale_fn_name, self.emulate)
         return y
 
     def float8_pre_forward(self, x):
@@ -416,9 +407,7 @@ def sync_float8_amax_and_scale_history(model: torch.nn.Module) -> None:
         #
         _update_history_with_new_amax(child.fp8_amax_x, child.fp8_amax_history_x)
         _update_history_with_new_amax(child.fp8_amax_w, child.fp8_amax_history_w)
-        _update_history_with_new_amax(
-            child.fp8_amax_dL_dY, child.fp8_amax_history_dL_dY
-        )
+        _update_history_with_new_amax(child.fp8_amax_dL_dY, child.fp8_amax_history_dL_dY)
 
         #
         # 3. calculate the scales
20 changes: 15 additions & 5 deletions float8_experimental/float8_tensor.py
@@ -75,6 +75,20 @@ def backward(ctx, g):
         return g, None, None, None
 
 
+def to_float8(tensor: torch.Tensor, scale: torch.Tensor, float8_dtype: torch.dtype, amax_buffer:torch.Tensor =None) -> "Float8Tensor":
+    """ Converts a higher precision tensor to float8 in a differentiable way.
+
+    Args:
+        tensor: the tensor to convert
+        scale: the scale to use to convert the tensor
+        float8_dtype: the float8 dtype to use
+        amax_buffer: a buffer to store the amax value in prior to conversion
+    Returns:
+        Float8Tensor: a float8 tensor
+    """
+    return ToFloat8ConstrFunc.apply(tensor, scale, float8_dtype, amax_buffer)
+
+
 class FromFloat8ConstrFunc(torch.autograd.Function):
     """
     A differentiable conversion from fp8
@@ -86,7 +100,7 @@ def forward(ctx, tensor):
 
     @staticmethod
     def backward(ctx, g):
-        return Float8Tensor.to_float8(g), None, None
+        return to_float8(g), None, None
 
 
 class Float8Tensor(torch.Tensor):
@@ -154,10 +168,6 @@ def __tensor_unflatten__(inner_tensors: Dict, metadata):
     def to_original_precision(self):
         return FromFloat8ConstrFunc.apply(self)
 
-    @classmethod
-    def to_float8(cls, tensor, scale, float8_dtype, amax_buffer=None):
-        return ToFloat8ConstrFunc.apply(tensor, scale, float8_dtype, amax_buffer)
-
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs=None):
         # 1. tracing through __torch_function__ logic is not supported yet in
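The classmethod removed in this last hunk is what moved, unchanged, into the module-level to_float8 defined earlier in the file. A short usage sketch against the new function follows; shapes and dtypes are made up, and tensor_to_scale is taken from float8_experimental.float8_utils as in the tests below.

```python
# Usage sketch for the module-level to_float8; the call pattern mirrors the
# updated tests below, the concrete values are illustrative.
import torch

from float8_experimental.float8_tensor import Float8Tensor, to_float8
from float8_experimental.float8_utils import tensor_to_scale

x = torch.randn(4, 4, dtype=torch.bfloat16)
scale = tensor_to_scale(x, torch.float8_e4m3fn)
x_fp8 = to_float8(x, scale, torch.float8_e4m3fn)  # previously Float8Tensor.to_float8(...)

assert isinstance(x_fp8, Float8Tensor)
print(x_fp8.to_original_precision().dtype)  # torch.bfloat16
```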
8 changes: 4 additions & 4 deletions test/test_base.py
@@ -15,7 +15,7 @@
 )
 from float8_experimental.float8_linear_nots import Float8LinearNoTensorSubclass
 from float8_experimental.float8_python_api import mm_float8
-from float8_experimental.float8_tensor import Float8Tensor
+from float8_experimental.float8_tensor import Float8Tensor, to_float8
 
 from float8_experimental.float8_utils import (
     amax_to_scale,
@@ -39,7 +39,7 @@ def test_preserves_dtype(self):
         for hp_dtype, lp_dtype in itertools.product(hp_dtypes, lp_dtypes):
             x1_hp = torch.randn(4, 4, dtype=hp_dtype)
             x1_s = tensor_to_scale(x1_hp, lp_dtype)
-            x2_lp = Float8Tensor.to_float8(x1_hp, x1_s, lp_dtype)
+            x2_lp = to_float8(x1_hp, x1_s, lp_dtype)
             x3_hp = x2_lp.to_original_precision()
             self.assertTrue(x3_hp.dtype == hp_dtype)
 
@@ -248,8 +248,8 @@ def test_scaled_mm_vs_emulated(self, base_dtype):
         a_scale = tensor_to_scale(a, input_dtype).float()
         b_scale = tensor_to_scale(b, input_dtype).float()
 
-        a_fp8 = Float8Tensor.to_float8(a, a_scale, input_dtype)
-        b_fp8 = Float8Tensor.to_float8(b, b_scale, input_dtype)
+        a_fp8 = to_float8(a, a_scale, input_dtype)
+        b_fp8 = to_float8(b, b_scale, input_dtype)
 
         out_scaled_mm, output_amax_scaled = mm_float8(
             a_fp8, b_fp8, output_dtype=output_dtype, emulate=False
