From ef17881edb71d46cf039ff317bea18ddea2ea888 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 26 Jul 2024 08:17:38 -0700 Subject: [PATCH] [1/x] clean up casting functions Summary: This is the start of a cleanup of private casting functions in preparation for rowwise scaling. In this PR: 1. create `float8_scaling_utils.py` to unify the functions that take a high-precision tensor and return a float8 tensor, taking care of scaling 2. delete `Float8Tensor.to_float8` and move callsites to `ToFloat8ConstrFunc`, since the two functions do the same thing The end result is a slightly cleaner state; future PRs will do more cleanups. A callsite migration sketch for reviewers follows the diff. Test Plan: ``` ./test/test_everything.sh ``` Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- benchmarks/bench_padding.py | 29 ++- float8_experimental/float8_dynamic_utils.py | 71 ------- float8_experimental/float8_linear.py | 112 +---------- float8_experimental/float8_scaling_utils.py | 189 ++++++++++++++++++ float8_experimental/float8_tensor.py | 32 +-- float8_experimental/float8_tensor_parallel.py | 8 +- float8_experimental/fsdp_utils.py | 14 +- float8_experimental/inference.py | 14 +- test/test_base.py | 77 +++---- test/test_compile.py | 4 +- test/test_dtensor.py | 34 ++-- 11 files changed, 310 insertions(+), 274 deletions(-) delete mode 100644 float8_experimental/float8_dynamic_utils.py create mode 100644 float8_experimental/float8_scaling_utils.py diff --git a/benchmarks/bench_padding.py b/benchmarks/bench_padding.py index af036d6..e02d2d5 100644 --- a/benchmarks/bench_padding.py +++ b/benchmarks/bench_padding.py @@ -4,7 +4,12 @@ import fire import torch -from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig +from float8_experimental.float8_tensor import ( + GemmInputRole, + LinearMMConfig, + ScaledMMConfig, + ToFloat8ConstrFunc, +) from float8_experimental.float8_utils import pad_tensor_for_matmul from tabulate import tabulate from torch._inductor.utils import do_bench_using_profiling @@ -50,9 +55,25 @@ def do_fp8_matmul(A, B, fp8_dtype, out_dtype): b_config = ScaledMMConfig( emulate=False, use_fast_accum=True, fp8_output=True, pad_inner_dim=True ) - - a_fp8 = Float8Tensor.to_float8(A, scale_a, fp8_dtype, mm_config=a_config) - b_fp8 = Float8Tensor.to_float8(B, scale_b, fp8_dtype, mm_config=b_config) + a_config = LinearMMConfig(a_config, a_config, a_config) + b_config = LinearMMConfig(b_config, b_config, b_config) + + a_fp8 = ToFloat8ConstrFunc.apply( + A, + scale_a, + fp8_dtype, + None, # amax_buffer + a_config, + GemmInputRole.INPUT, + ) + b_fp8 = ToFloat8ConstrFunc.apply( + B, + scale_b, + fp8_dtype, + None, # amax_buffer + b_config, + GemmInputRole.WEIGHT, + ) return a_fp8 @ b_fp8 diff --git a/float8_experimental/float8_dynamic_utils.py b/float8_experimental/float8_dynamic_utils.py deleted file mode 100644 index bfacd65..0000000 --- a/float8_experimental/float8_dynamic_utils.py +++ /dev/null @@ -1,71 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD 3-Clause license found in the -# LICENSE file in the root directory of this source tree.
- -import torch - -from float8_experimental.float8_tensor import ( - Float8Tensor, - GemmInputRole, - LinearMMConfig, - tensor_already_casted_to_fp8, - to_fp8_no_autograd, -) -from float8_experimental.float8_utils import e4m3_dtype, e5m2_dtype, tensor_to_scale - - -@torch._dynamo.allow_in_graph -class NoopFwToFloat8E5M2Bw(torch.autograd.Function): - """ - Forward: no-op - Backward: convert to float8_e5m2, initialize if needed - """ - - @staticmethod - def forward( - ctx, - tensor, - linear_mm_config: LinearMMConfig, - ): - ctx.linear_mm_config = linear_mm_config - return tensor - - @staticmethod - def backward(ctx, gradY): - if tensor_already_casted_to_fp8(gradY): - return gradY, None - gradY_scale = tensor_to_scale(gradY, e5m2_dtype) - fp8_tensor = to_fp8_no_autograd( - gradY, - gradY_scale, - e5m2_dtype, - linear_mm_config=ctx.linear_mm_config, - gemm_input_role=GemmInputRole.GRAD_OUTPUT, - ) - return fp8_tensor, None - - -def cast_to_float8_e4m3_dynamic( - inpt_tensor: torch.Tensor, - linear_mm_config: LinearMMConfig, - reduce_amax: bool = False, - gemm_input_role: GemmInputRole = GemmInputRole.INPUT, -) -> Float8Tensor: - if tensor_already_casted_to_fp8(inpt_tensor): - return inpt_tensor - scale = tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax) - return Float8Tensor.to_float8( - inpt_tensor, - scale, - e4m3_dtype, - linear_mm_config=linear_mm_config, - gemm_input_role=gemm_input_role, - ) - - -def cast_to_float8_e5m2_dynamic_bw( - gradY: torch.Tensor, linear_mm_config: LinearMMConfig -) -> torch.Tensor: - return NoopFwToFloat8E5M2Bw.apply(gradY, linear_mm_config) diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py index 81a5c52..6e184c2 100644 --- a/float8_experimental/float8_linear.py +++ b/float8_experimental/float8_linear.py @@ -16,9 +16,12 @@ from float8_experimental.config import Float8LinearConfig, ScalingType -from float8_experimental.float8_dynamic_utils import ( +from float8_experimental.float8_scaling_utils import ( + _maybe_initialize_amaxes_scales_for_float8_cast, + cast_to_float8_delayed, cast_to_float8_e4m3_dynamic, - cast_to_float8_e5m2_dynamic_bw, + NoopFwToFloat8E5M2BwDelayed, + NoopFwToFloat8E5M2BwDynamic, ) from float8_experimental.float8_tensor import ( @@ -26,15 +29,9 @@ GemmInputRole, LinearMMConfig, ScaledMMConfig, - to_fp8_no_autograd, ) -from float8_experimental.float8_utils import ( - amax_history_to_scale, - e4m3_dtype, - e5m2_dtype, - tensor_to_amax, -) +from float8_experimental.float8_utils import e4m3_dtype, e5m2_dtype, tensor_to_amax from float8_experimental.fsdp_utils import ( WeightWithDelayedFloat8CastTensor, @@ -42,35 +39,6 @@ ) -def _maybe_initialize_amaxes_scales_for_float8_cast( - x, - cur_amax, - amax_history, - scale, - scale_fn_name, - float8_dtype, - is_initialized, - reduce_amax, -): - """ - If x is about to be cast to `float8` and the amax buffers are not initialized, - initializes them inplace. 
- """ - if is_initialized: - return - with torch.no_grad(): - # Note: we need to enable distributed reduction here in order - # to match numerics between single GPU and multi GPU code for - # activations and gradients - new_amax = tensor_to_amax(x, reduce_amax=reduce_amax) - cur_amax.fill_(new_amax) - amax_history[0] = new_amax - new_scale = amax_history_to_scale( - amax_history, float8_dtype, x.dtype, scale_fn_name - ) - scale.copy_(new_scale) - - # this code was resurrected from https://github.com/pytorch-labs/float8_experimental/pull/128/files @torch._dynamo.allow_in_graph class manual_float8_matmul(torch.autograd.Function): @@ -127,66 +95,6 @@ def backward(ctx, grad_output_fp8): return grad_input, grad_weight.t() -@torch._dynamo.allow_in_graph -class NoopFwToFloat8E5M2Bw(torch.autograd.Function): - """ - Forward: no-op - Backward: convert to float8_e5m2, initialize if needed - """ - - @staticmethod - def forward( - ctx, - tensor, - fp8_amax_grad_output, - fp8_amax_history_grad_output, - fp8_scale_grad_output, - scale_fn_name, - is_amax_initialized, - linear_mm_config: LinearMMConfig, - ): - ctx.save_for_backward( - fp8_amax_grad_output, fp8_amax_history_grad_output, fp8_scale_grad_output - ) - ctx.scale_fn_name = scale_fn_name - ctx.is_amax_initialized = is_amax_initialized - ctx.linear_mm_config = linear_mm_config - return tensor - - @staticmethod - def backward(ctx, go): - ( - fp8_amax_grad_output, - fp8_amax_history_grad_output, - fp8_scale_grad_output, - ) = ctx.saved_tensors - scale_fn_name = ctx.scale_fn_name - is_amax_initialized = ctx.is_amax_initialized - - _maybe_initialize_amaxes_scales_for_float8_cast( - go, - fp8_amax_grad_output, - fp8_amax_history_grad_output, - fp8_scale_grad_output, - scale_fn_name, - e5m2_dtype, - is_amax_initialized, - reduce_amax=True, - ) - - fp8_amax_grad_output.fill_(tensor_to_amax(go)) - - res = to_fp8_no_autograd( - go, - fp8_scale_grad_output, - e5m2_dtype, - linear_mm_config=ctx.linear_mm_config, - gemm_input_role=GemmInputRole.GRAD_OUTPUT, - ) - empty_grads = None, None, None, None, None, None - return res, *empty_grads - - class Float8Linear(torch.nn.Linear): """ Note: this is **not** a public API and is only intended to be used @@ -352,7 +260,7 @@ def cast_input_to_float8( is_amax_initialized, reduce_amax=True, ) - input_fp8 = Float8Tensor.to_float8( + input_fp8 = cast_to_float8_delayed( input, self.fp8_scale_input, e4m3_dtype, @@ -384,7 +292,7 @@ def cast_weight_to_float8( reduce_amax=False, ) - weight_fp8 = Float8Tensor.to_float8( + weight_fp8 = cast_to_float8_delayed( weight, self.fp8_scale_weight, e4m3_dtype, @@ -407,7 +315,7 @@ def cast_weight_to_float8( def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor: if self.scaling_type_grad_output is ScalingType.DELAYED: scale_fn_name = self.config.delayed_scaling_config.scale_fn_name - output = NoopFwToFloat8E5M2Bw.apply( + output = NoopFwToFloat8E5M2BwDelayed.apply( output, self.fp8_amax_grad_output, self.fp8_amax_history_grad_output, @@ -418,7 +326,7 @@ def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor: ) else: assert self.scaling_type_grad_output is ScalingType.DYNAMIC - output = cast_to_float8_e5m2_dynamic_bw(output, self.linear_mm_config) + output = NoopFwToFloat8E5M2BwDynamic.apply(output, self.linear_mm_config) return output def float8_pre_forward(self, input): diff --git a/float8_experimental/float8_scaling_utils.py b/float8_experimental/float8_scaling_utils.py new file mode 100644 index 0000000..81910e2 --- /dev/null +++ 
b/float8_experimental/float8_scaling_utils.py @@ -0,0 +1,189 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +""" +Utilities for scaling high precision tensors to float8. +""" + +from typing import Optional + +import torch + +from float8_experimental.float8_tensor import ( + Float8Tensor, + GemmInputRole, + LinearMMConfig, + ScaledMMConfig, + tensor_already_casted_to_fp8, + to_fp8_no_autograd, + ToFloat8ConstrFunc, +) + +from float8_experimental.float8_utils import ( + amax_history_to_scale, + e4m3_dtype, + e5m2_dtype, + tensor_to_amax, + tensor_to_scale, +) + + +def cast_to_float8_e4m3_dynamic( + inpt_tensor: torch.Tensor, + linear_mm_config: LinearMMConfig, + reduce_amax: bool = False, + gemm_input_role: GemmInputRole = GemmInputRole.INPUT, +) -> Float8Tensor: + if tensor_already_casted_to_fp8(inpt_tensor): + return inpt_tensor + scale = tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax) + return ToFloat8ConstrFunc.apply( + inpt_tensor, + scale, + e4m3_dtype, + None, # amax_buffer + linear_mm_config, + gemm_input_role, + ) + + +# TODO(future PR): align name with cast_to_float8_e4m3_dynamic +def cast_to_float8_delayed( + tensor: torch.Tensor, + scale: torch.Tensor, + float8_dtype: torch.dtype, + amax_buffer: torch.Tensor, + linear_mm_config: Optional[LinearMMConfig] = None, + gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT, +): + return ToFloat8ConstrFunc.apply( + tensor, + scale, + float8_dtype, + amax_buffer, + linear_mm_config, + gemm_input_role, + ) + + +def _maybe_initialize_amaxes_scales_for_float8_cast( + x, + cur_amax, + amax_history, + scale, + scale_fn_name, + float8_dtype, + is_initialized, + reduce_amax, +): + """ + If x is about to be cast to `float8` and the amax buffers are not initialized, + initializes them inplace. 
+ """ + if is_initialized: + return + with torch.no_grad(): + # Note: we need to enable distributed reduction here in order + # to match numerics between single GPU and multi GPU code for + # activations and gradients + new_amax = tensor_to_amax(x, reduce_amax=reduce_amax) + cur_amax.fill_(new_amax) + amax_history[0] = new_amax + new_scale = amax_history_to_scale( + amax_history, float8_dtype, x.dtype, scale_fn_name + ) + scale.copy_(new_scale) + + +@torch._dynamo.allow_in_graph +class NoopFwToFloat8E5M2BwDelayed(torch.autograd.Function): + """ + Forward: no-op + Backward: convert to float8_e5m2 with delayed scaling, initialize if needed + """ + + @staticmethod + def forward( + ctx, + tensor, + fp8_amax_grad_output, + fp8_amax_history_grad_output, + fp8_scale_grad_output, + scale_fn_name, + is_amax_initialized, + linear_mm_config: LinearMMConfig, + ): + ctx.save_for_backward( + fp8_amax_grad_output, fp8_amax_history_grad_output, fp8_scale_grad_output + ) + ctx.scale_fn_name = scale_fn_name + ctx.is_amax_initialized = is_amax_initialized + ctx.linear_mm_config = linear_mm_config + return tensor + + @staticmethod + def backward(ctx, go): + ( + fp8_amax_grad_output, + fp8_amax_history_grad_output, + fp8_scale_grad_output, + ) = ctx.saved_tensors + scale_fn_name = ctx.scale_fn_name + is_amax_initialized = ctx.is_amax_initialized + + _maybe_initialize_amaxes_scales_for_float8_cast( + go, + fp8_amax_grad_output, + fp8_amax_history_grad_output, + fp8_scale_grad_output, + scale_fn_name, + e5m2_dtype, + is_amax_initialized, + reduce_amax=True, + ) + + fp8_amax_grad_output.fill_(tensor_to_amax(go)) + + res = to_fp8_no_autograd( + go, + fp8_scale_grad_output, + e5m2_dtype, + linear_mm_config=ctx.linear_mm_config, + gemm_input_role=GemmInputRole.GRAD_OUTPUT, + ) + empty_grads = None, None, None, None, None, None + return res, *empty_grads + + +@torch._dynamo.allow_in_graph +class NoopFwToFloat8E5M2BwDynamic(torch.autograd.Function): + """ + Forward: no-op + Backward: convert to float8_e5m2 with dynamic scaling + """ + + @staticmethod + def forward( + ctx, + tensor, + linear_mm_config: LinearMMConfig, + ): + ctx.linear_mm_config = linear_mm_config + return tensor + + @staticmethod + def backward(ctx, gradY): + if tensor_already_casted_to_fp8(gradY): + return gradY, None + gradY_scale = tensor_to_scale(gradY, e5m2_dtype) + fp8_tensor = to_fp8_no_autograd( + gradY, + gradY_scale, + e5m2_dtype, + linear_mm_config=ctx.linear_mm_config, + gemm_input_role=GemmInputRole.GRAD_OUTPUT, + ) + return fp8_tensor, None diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py index a46e7ce..fd37482 100644 --- a/float8_experimental/float8_tensor.py +++ b/float8_experimental/float8_tensor.py @@ -128,6 +128,7 @@ def tensor_already_casted_to_fp8(tensor: torch.Tensor) -> bool: return False +# TODO: rename to hp_tensor_and_scale_to_float8_tensor def to_fp8_no_autograd( x: torch.Tensor, x_scale: torch.Tensor, @@ -341,37 +342,6 @@ def __tensor_unflatten__(inner_tensors: Dict, metadata, outer_size, outer_stride def to_original_precision(self): return FromFloat8ConstrFunc.apply(self) - @staticmethod - @torch._dynamo.allow_in_graph - def to_float8( - tensor: torch.Tensor, - scale: torch.Tensor, - float8_dtype: torch.dtype, - amax_buffer: Optional[torch.Tensor] = None, - linear_mm_config: Optional[LinearMMConfig] = None, - gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT, - ): - """Converts a higher precision tensor to float8 in a differentiable way. 
- - Args: - tensor: the tensor to convert - scale: the scale to use to convert the tensor - float8_dtype: the float8 dtype to use - amax_buffer: a buffer to store the amax value in prior to conversion - linearmm_config: Defines the configuration for 3 gemms in fwd/bwd of linear - - Returns: - Float8Tensor: a float8 tensor - """ - return ToFloat8ConstrFunc.apply( - tensor, - scale, - float8_dtype, - amax_buffer, - linear_mm_config, - gemm_input_role, - ) - @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): # 1. tracing through __torch_function__ logic is not supported yet in diff --git a/float8_experimental/float8_tensor_parallel.py b/float8_experimental/float8_tensor_parallel.py index eea7376..54127af 100644 --- a/float8_experimental/float8_tensor_parallel.py +++ b/float8_experimental/float8_tensor_parallel.py @@ -1,9 +1,9 @@ import torch import torch.nn as nn from float8_experimental.config import ScalingType -from float8_experimental.float8_dynamic_utils import ( +from float8_experimental.float8_scaling_utils import ( cast_to_float8_e4m3_dynamic, - cast_to_float8_e5m2_dynamic_bw, + NoopFwToFloat8E5M2BwDynamic, ) from float8_experimental.float8_tensor import GemmInputRole from torch.distributed._tensor import DTensor @@ -67,7 +67,7 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me ) # DTensor(torch.Tensor) # fwd noop bwd cast to DTensor(Float8Tensor) - outputs = cast_to_float8_e5m2_dynamic_bw(outputs, mod.linear_mm_config) + outputs = NoopFwToFloat8E5M2BwDynamic.apply(outputs, mod.linear_mm_config) # back to local tensor return outputs.to_local() if use_local_output else outputs @@ -119,7 +119,7 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me outputs = outputs.redistribute(placements=output_layouts, async_op=True) # fwd noop bwd cast to DTensor(Float8Tensor) - outputs = cast_to_float8_e5m2_dynamic_bw(outputs, mod.linear_mm_config) + outputs = NoopFwToFloat8E5M2BwDynamic.apply(outputs, mod.linear_mm_config) # back to local tensor if use_local_output is True return outputs.to_local() if use_local_output else outputs diff --git a/float8_experimental/fsdp_utils.py b/float8_experimental/fsdp_utils.py index 5fbefc9..ca0812c 100644 --- a/float8_experimental/fsdp_utils.py +++ b/float8_experimental/fsdp_utils.py @@ -10,12 +10,13 @@ import torch import torch.nn as nn import torch.utils._pytree as pytree -from float8_experimental.float8_dynamic_utils import cast_to_float8_e4m3_dynamic +from float8_experimental.float8_scaling_utils import cast_to_float8_e4m3_dynamic from float8_experimental.float8_tensor import ( Float8Tensor, GemmInputRole, LinearMMConfig, + ToFloat8ConstrFunc, ) from float8_experimental.float8_utils import e4m3_dtype, EPS @@ -163,12 +164,13 @@ def __repr__(self): def fsdp_pre_all_gather(self, mesh): if self._precomputed_scale is not None: - float8_tensor = Float8Tensor.to_float8( + float8_tensor = ToFloat8ConstrFunc.apply( self._tensor, self._precomputed_scale, torch.float8_e4m3fn, - linear_mm_config=self._linear_mm_config, - gemm_input_role=GemmInputRole.WEIGHT, + None, # amax_buffer + self._linear_mm_config, + GemmInputRole.WEIGHT, ) else: float8_tensor = cast_to_float8_e4m3_dynamic( @@ -355,13 +357,13 @@ def fsdp_pre_all_gather(self, mesh): # 2. 
populate `_amax_buffer` inplace # TODO(future PR): clean up all the casting functions and clearly # separate dynamic vs delayed, tech debt has accumulated - float8_tensor = Float8Tensor.to_float8( + float8_tensor = ToFloat8ConstrFunc.apply( self._tensor, self._scale_buffer, e4m3_dtype, self._amax_buffer, self._linear_mm_config, - gemm_input_role=GemmInputRole.WEIGHT, + GemmInputRole.WEIGHT, ) return (float8_tensor._data,), (float8_tensor._scale,) diff --git a/float8_experimental/inference.py b/float8_experimental/inference.py index 0c10589..4807ac6 100644 --- a/float8_experimental/inference.py +++ b/float8_experimental/inference.py @@ -22,7 +22,7 @@ LinearMMConfig, ScaledMMConfig, tensor_already_casted_to_fp8, - to_fp8_no_autograd, + ToFloat8ConstrFunc, ) from float8_experimental.float8_utils import e4m3_dtype, tensor_to_scale @@ -127,12 +127,13 @@ def quantize_weight(self, dtype: torch.dtype = e4m3_dtype) -> None: self.weight, Float8Tensor ), "Weight has already been quantized, cannot quantize again." scale = tensor_to_scale(self.weight, dtype) - quantized_weight = to_fp8_no_autograd( + quantized_weight = ToFloat8ConstrFunc.apply( self.weight, scale, dtype, + None, # amax_buffer self.linear_mm_config, - gemm_input_role=GemmInputRole.WEIGHT, + GemmInputRole.WEIGHT, ) self.weight = nn.Parameter(quantized_weight) self.weight.requires_grad = False @@ -200,12 +201,13 @@ def cast_to_float8_e4m3_inference( if static_quantization_scale is not None else tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax) ) - return Float8Tensor.to_float8( + return ToFloat8ConstrFunc.apply( inpt_tensor, scale, e4m3_dtype, - linear_mm_config=linear_mm_config, - gemm_input_role=GemmInputRole.INPUT, + None, # amax_buffer + linear_mm_config, + GemmInputRole.INPUT, ) diff --git a/test/test_base.py b/test/test_base.py index 2f7c717..b1f4781 100644 --- a/test/test_base.py +++ b/test/test_base.py @@ -17,7 +17,6 @@ import torch.nn as nn from float8_experimental.config import CastConfig, Float8LinearConfig, ScalingType -from float8_experimental.float8_dynamic_utils import cast_to_float8_e4m3_dynamic from float8_experimental.float8_linear import Float8Linear from float8_experimental.float8_linear_utils import ( convert_to_float8_training, @@ -25,11 +24,13 @@ sync_float8_amax_and_scale_history, ) from float8_experimental.float8_python_api import addmm_float8_unwrapped +from float8_experimental.float8_scaling_utils import cast_to_float8_e4m3_dynamic from float8_experimental.float8_tensor import ( Float8Tensor, GemmInputRole, LinearMMConfig, ScaledMMConfig, + ToFloat8ConstrFunc, ) from float8_experimental.float8_utils import ( compute_error, @@ -65,7 +66,7 @@ def test_preserves_dtype(self) -> None: for hp_dtype, lp_dtype in itertools.product(hp_dtypes, lp_dtypes): x1_hp = torch.randn(4, 4, dtype=hp_dtype) x1_s = tensor_to_scale(x1_hp, lp_dtype) - x2_lp = Float8Tensor.to_float8(x1_hp, x1_s, lp_dtype) + x2_lp = ToFloat8ConstrFunc.apply(x1_hp, x1_s, lp_dtype) x3_hp = x2_lp.to_original_precision() self.assertTrue(x3_hp.dtype == hp_dtype) @@ -75,7 +76,7 @@ def test_differentiable_casts(self) -> None: x = torch.randn(1).requires_grad_() grad = torch.randn(1) x_s = tensor_to_scale(x, f8_dtype) - x_f8 = Float8Tensor.to_float8(x, x_s, f8_dtype) + x_f8 = ToFloat8ConstrFunc.apply(x, x_s, f8_dtype) x_f8_hp = x_f8.to_original_precision() x_f8_hp.backward(grad) # the gradient should be unchanged through both casts @@ -84,7 +85,7 @@ def test_differentiable_casts(self) -> None: def test_split_cat(self): a = torch.rand(16, 16, 
dtype=torch.bfloat16) scale = tensor_to_scale(a, e4m3_dtype) - fp8_a = Float8Tensor.to_float8(a, scale, e4m3_dtype) + fp8_a = ToFloat8ConstrFunc.apply(a, scale, e4m3_dtype) splits = torch.split(fp8_a, 16) catted = torch.cat(splits, dim=0) @@ -93,14 +94,14 @@ def test_split_cat(self): def test_index_put(self): a = torch.rand(16, dtype=torch.bfloat16) scale_a = tensor_to_scale(a, torch.float8_e4m3fn) - fp8_a = Float8Tensor.to_float8(a, scale_a, torch.float8_e4m3fn) + fp8_a = ToFloat8ConstrFunc.apply(a, scale_a, torch.float8_e4m3fn) index = torch.randint(0, 15, (16,), dtype=torch.long) b = torch.rand(16, 16, dtype=torch.bfloat16) scale_b = tensor_to_scale(b, torch.float8_e4m3fn) - fp8_b = Float8Tensor.to_float8(b, scale_a, torch.float8_e4m3fn) - fp8_b_bad = Float8Tensor.to_float8(b, scale_b, torch.float8_e4m3fn) + fp8_b = ToFloat8ConstrFunc.apply(b, scale_a, torch.float8_e4m3fn) + fp8_b_bad = ToFloat8ConstrFunc.apply(b, scale_b, torch.float8_e4m3fn) with self.assertRaises(AssertionError): b[index] = fp8_a @@ -111,7 +112,7 @@ def test_index_put(self): def test_copy_(self): a = torch.rand(16, dtype=torch.bfloat16) scale_a = tensor_to_scale(a, torch.float8_e4m3fn) - fp8_a = Float8Tensor.to_float8(a, scale_a, torch.float8_e4m3fn) + fp8_a = ToFloat8ConstrFunc.apply(a, scale_a, torch.float8_e4m3fn) b = torch.empty(16, dtype=torch.bfloat16) b.copy_(fp8_a) # Should work @@ -406,8 +407,8 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum): a_scale = tensor_to_scale(a, input_dtype).float() b_scale = tensor_to_scale(b, input_dtype).float() - a_fp8 = Float8Tensor.to_float8(a, a_scale, input_dtype) - b_fp8 = Float8Tensor.to_float8(b, b_scale, input_dtype) + a_fp8 = ToFloat8ConstrFunc.apply(a, a_scale, input_dtype) + b_fp8 = ToFloat8ConstrFunc.apply(b, b_scale, input_dtype) out_scaled_mm = addmm_float8_unwrapped( a_fp8._data, @@ -446,19 +447,21 @@ def test_different_configs_error(self): ScaledMMConfig(True, False, False, False), ScaledMMConfig(True, False, False, False), ) - a = Float8Tensor.to_float8( + a = ToFloat8ConstrFunc.apply( x_fp32, x_scale, fp8_dtype, - linear_mm_config=linear_config_a, - gemm_input_role=GemmInputRole.INPUT, + None, # amax_buffer + linear_config_a, + GemmInputRole.INPUT, ) - b = Float8Tensor.to_float8( + b = ToFloat8ConstrFunc.apply( x_fp32, x_scale, fp8_dtype, - linear_mm_config=linear_config_b, - gemm_input_role=GemmInputRole.WEIGHT, + None, # amax_buffer + linear_config_b, + GemmInputRole.WEIGHT, ) with pytest.raises( AssertionError, @@ -485,11 +488,11 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum): a_scale = tensor_to_scale(a, input_dtype).float() b_scale = tensor_to_scale(b, input_dtype).float() - a_fp8 = Float8Tensor.to_float8( - a, a_scale, input_dtype, gemm_input_role=GemmInputRole.INPUT + a_fp8 = ToFloat8ConstrFunc.apply( + a, a_scale, input_dtype, None, None, GemmInputRole.INPUT ) - b_fp8 = Float8Tensor.to_float8( - b, b_scale, input_dtype, gemm_input_role=GemmInputRole.WEIGHT + b_fp8 = ToFloat8ConstrFunc.apply( + b, b_scale, input_dtype, None, None, GemmInputRole.WEIGHT ) with pytest.raises( @@ -505,19 +508,21 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum): scaled_mm_config, scaled_mm_config, scaled_mm_config ) - a_fp8 = Float8Tensor.to_float8( + a_fp8 = ToFloat8ConstrFunc.apply( a, a_scale, input_dtype, - linear_mm_config=pad_config, - gemm_input_role=GemmInputRole.INPUT, + None, # amax_buffer + pad_config, + GemmInputRole.INPUT, ) - b_fp8 = Float8Tensor.to_float8( + b_fp8 = ToFloat8ConstrFunc.apply( b, b_scale, input_dtype, - 
linear_mm_config=pad_config, - gemm_input_role=GemmInputRole.WEIGHT, + None, # amax_buffer + pad_config, + GemmInputRole.WEIGHT, ) out_padded = a_fp8 @ b_fp8 out_padded.to(compare_type) @@ -528,19 +533,21 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum): emulated_scaled_mm_config, emulated_scaled_mm_config, ) - a_fp8 = Float8Tensor.to_float8( + a_fp8 = ToFloat8ConstrFunc.apply( a, a_scale, input_dtype, - linear_mm_config=emulated_config, - gemm_input_role=GemmInputRole.INPUT, + None, # amax_buffer + emulated_config, + GemmInputRole.INPUT, ) - b_fp8 = Float8Tensor.to_float8( + b_fp8 = ToFloat8ConstrFunc.apply( b, b_scale, input_dtype, - linear_mm_config=emulated_config, - gemm_input_role=GemmInputRole.WEIGHT, + None, # amax_buffer + emulated_config, + GemmInputRole.WEIGHT, ) out_emualted = a_fp8 @ b_fp8 out_emualted.to(compare_type) @@ -694,19 +701,19 @@ def test_fp8_tensor_statistics(self): # Overflow caused by a too large scaling factor s_overflow = torch.tensor(1e9) - fp8_overflow = Float8Tensor.to_float8(x1_hp, s_overflow, lp_dtype) + fp8_overflow = ToFloat8ConstrFunc.apply(x1_hp, s_overflow, lp_dtype) (zero_cnt, max_cnt) = fp8_tensor_statistics(fp8_overflow, lp_dtype) self.assertEqual((zero_cnt, max_cnt), (0, tensor_len)) # Underflow caused by a too small scaling factor s_underflow = torch.tensor(1e-9) - fp8_underflow = Float8Tensor.to_float8(x1_hp, s_underflow, lp_dtype) + fp8_underflow = ToFloat8ConstrFunc.apply(x1_hp, s_underflow, lp_dtype) (zero_cnt, max_cnt) = fp8_tensor_statistics(fp8_underflow, lp_dtype) self.assertEqual((zero_cnt, max_cnt), (tensor_len, 0)) # Both overflow and underflow x2_hp = torch.cat((x1_hp * 1e9, x1_hp * 1.0, x1_hp * 1e-9), 0) - fp8_over_underflow = Float8Tensor.to_float8( + fp8_over_underflow = ToFloat8ConstrFunc.apply( x2_hp, torch.tensor(1.0), lp_dtype ) (zero_cnt, max_cnt) = fp8_tensor_statistics(fp8_over_underflow, lp_dtype) diff --git a/test/test_compile.py b/test/test_compile.py index a71b879..7f6ab68 100644 --- a/test/test_compile.py +++ b/test/test_compile.py @@ -20,7 +20,7 @@ get_float8_layers, sync_float8_amax_and_scale_history, ) -from float8_experimental.float8_tensor import Float8Tensor, LinearMMConfig +from float8_experimental.float8_tensor import LinearMMConfig, ToFloat8ConstrFunc from float8_experimental.float8_utils import e4m3_dtype from torch._dynamo.test_case import TestCase as DynamoTestCase @@ -178,7 +178,7 @@ def __init__(self, graph_break: bool): self.graph_break = graph_break def forward(self, x): - x_fp8 = Float8Tensor.to_float8( + x_fp8 = ToFloat8ConstrFunc.apply( x, self.fp8_scale_x, e4m3_dtype, diff --git a/test/test_dtensor.py b/test/test_dtensor.py index eeca6df..cfa0445 100644 --- a/test/test_dtensor.py +++ b/test/test_dtensor.py @@ -14,13 +14,14 @@ import torch.nn as nn import torch.nn.functional as F from float8_experimental import Float8LinearConfig - -from float8_experimental.float8_dynamic_utils import NoopFwToFloat8E5M2Bw from float8_experimental.float8_linear_utils import convert_to_float8_training + +from float8_experimental.float8_scaling_utils import NoopFwToFloat8E5M2BwDynamic from float8_experimental.float8_tensor import ( Float8Tensor, GemmInputRole, LinearMMConfig, + ToFloat8ConstrFunc, ) from float8_experimental.float8_tensor_parallel import ( Float8ColwiseParallel, @@ -86,11 +87,11 @@ def test_scaled_mm(mesh: DeviceMesh, size=16): x_scale = tensor_to_scale(x_fp32, fp8_dtype).float() y_scale = tensor_to_scale(y_fp32, fp8_dtype).float() - x_fp8 = Float8Tensor.to_float8( - x_fp32, x_scale, fp8_dtype, 
gemm_input_role=GemmInputRole.INPUT + x_fp8 = ToFloat8ConstrFunc.apply( + x_fp32, x_scale, fp8_dtype, None, None, GemmInputRole.INPUT ) - y_fp8 = Float8Tensor.to_float8( - y_fp32, y_scale, fp8_dtype, gemm_input_role=GemmInputRole.WEIGHT + y_fp8 = ToFloat8ConstrFunc.apply( + y_fp32, y_scale, fp8_dtype, None, None, GemmInputRole.WEIGHT ) dist_x_fp8 = DTensor.from_local(x_fp8, mesh, [lhs_placement], run_check=False) @@ -116,7 +117,7 @@ def test_fp8_redistribute(mesh: DeviceMesh, size=16): x_scale = tensor_to_scale(x_fp32, fp8_dtype).float() - x_fp8 = Float8Tensor.to_float8(x_fp32, x_scale, fp8_dtype) + x_fp8 = ToFloat8ConstrFunc.apply(x_fp32, x_scale, fp8_dtype) dist_x_fp8 = DTensor.from_local(x_fp8, mesh, [Shard(0)], run_check=False) out_dist = dist_x_fp8.redistribute(placements=[Replicate()]) @@ -144,7 +145,7 @@ def test_dtensor_cast_to_fp8(mesh: DeviceMesh, size=16): dist_x_scale = tensor_to_scale(dist_x_fp32, fp8_dtype).float() assert isinstance(dist_x_scale, DTensor) - dist_x_fp8 = Float8Tensor.to_float8(dist_x_fp32, dist_x_scale, fp8_dtype) + dist_x_fp8 = ToFloat8ConstrFunc.apply(dist_x_fp32, dist_x_scale, fp8_dtype) assert isinstance(dist_x_fp8, DTensor) @@ -163,18 +164,25 @@ def test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16): dist_weight_scale = tensor_to_scale(dist_wight_fp32, fp8_dtype).float() dist_target = distribute_tensor(target, mesh, [Shard(0)]) - dist_x_fp8 = Float8Tensor.to_float8( - dist_x_fp32, dist_x_scale, fp8_dtype, gemm_input_role=GemmInputRole.INPUT + dist_x_fp8 = ToFloat8ConstrFunc.apply( + dist_x_fp32, + dist_x_scale, + fp8_dtype, + None, + None, + GemmInputRole.INPUT, ) - dist_weight_fp8 = Float8Tensor.to_float8( + dist_weight_fp8 = ToFloat8ConstrFunc.apply( dist_wight_fp32, dist_weight_scale, fp8_dtype, - gemm_input_role=GemmInputRole.WEIGHT, + None, + None, + GemmInputRole.WEIGHT, ) out = torch.nn.functional.linear(dist_x_fp8, dist_weight_fp8) - out = NoopFwToFloat8E5M2Bw.apply(out, LinearMMConfig()) + out = NoopFwToFloat8E5M2BwDynamic.apply(out, LinearMMConfig()) assert isinstance(out, DTensor), f"Expected DTensor, got {type(out)}" loss = torch.sum(torch.abs(out - dist_target)) loss.backward()
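Note for reviewers: the sketch below illustrates the callsite migration described in the summary (item 2). It is not part of the patch; the tensor, its shape, and the ScaledMMConfig values are hypothetical placeholders, and the positional argument order for `ToFloat8ConstrFunc.apply` (tensor, scale, float8_dtype, amax_buffer, linear_mm_config, gemm_input_role) is taken from the call sites in the diff. Because `ToFloat8ConstrFunc` is a `torch.autograd.Function`, it is invoked via `apply` with positional arguments, which is why the keyword arguments at the old `Float8Tensor.to_float8` call sites become positional, with `None` passed explicitly for `amax_buffer` in the dynamic-scaling cases.

```
# Minimal sketch of the callsite migration in this PR; not part of the patch.
# `hp_tensor`, its shape, and the ScaledMMConfig values are hypothetical
# placeholders chosen only for illustration.
import torch

from float8_experimental.float8_tensor import (
    GemmInputRole,
    LinearMMConfig,
    ScaledMMConfig,
    ToFloat8ConstrFunc,
)
from float8_experimental.float8_utils import e4m3_dtype, tensor_to_scale

hp_tensor = torch.randn(16, 16, dtype=torch.bfloat16)
scale = tensor_to_scale(hp_tensor, e4m3_dtype)
# LinearMMConfig holds one ScaledMMConfig per gemm in the fwd/bwd of linear
mm_config = LinearMMConfig(
    ScaledMMConfig(False, False, False, False),
    ScaledMMConfig(False, False, False, False),
    ScaledMMConfig(False, False, False, False),
)

# before this PR (method removed by the diff):
#   hp_fp8 = Float8Tensor.to_float8(
#       hp_tensor, scale, e4m3_dtype,
#       linear_mm_config=mm_config, gemm_input_role=GemmInputRole.INPUT,
#   )

# after this PR: call the autograd function directly, passing every argument
# positionally and None for amax_buffer when there is no delayed-scaling buffer
hp_fp8 = ToFloat8ConstrFunc.apply(
    hp_tensor,
    scale,
    e4m3_dtype,
    None,  # amax_buffer
    mm_config,
    GemmInputRole.INPUT,
)
```

For dynamic scaling, the new `cast_to_float8_e4m3_dynamic` helper in `float8_scaling_utils.py` wraps this same `ToFloat8ConstrFunc.apply` call and computes the scale internally, so non-delayed call sites can use it instead of calling the autograd function directly.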