From e13014494e299398c368d76b9ed9783a456daf96 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Fri, 26 Jul 2024 09:16:48 -0700
Subject: [PATCH 1/2] [3/x] clean up casting functions: delete to_fp8_no_autograd

Summary:

`ToFloat8ConstrFunc` was just calling `to_fp8_no_autograd`, unify them to
reduce confusion. We can rename the function in a future PR, keeping PRs
small for now.

Test Plan:

```
./test/test_everything.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
---
 float8_experimental/float8_scaling_utils.py |  13 +--
 float8_experimental/float8_tensor.py        | 117 ++++++++------
 2 files changed, 51 insertions(+), 79 deletions(-)

diff --git a/float8_experimental/float8_scaling_utils.py b/float8_experimental/float8_scaling_utils.py
index c1a58fb..f319e75 100644
--- a/float8_experimental/float8_scaling_utils.py
+++ b/float8_experimental/float8_scaling_utils.py
@@ -18,7 +18,6 @@
     LinearMMConfig,
     ScaledMMConfig,
     tensor_already_casted_to_fp8,
-    to_fp8_no_autograd,
     ToFloat8ConstrFunc,
 )
 
@@ -146,12 +145,12 @@ def backward(ctx, go):
 
         fp8_amax_grad_output.fill_(tensor_to_amax(go))
 
-        res = to_fp8_no_autograd(
+        res = ToFloat8ConstrFunc.apply(
             go,
             fp8_scale_grad_output,
             e5m2_dtype,
-            linear_mm_config=ctx.linear_mm_config,
-            gemm_input_role=GemmInputRole.GRAD_OUTPUT,
+            ctx.linear_mm_config,
+            GemmInputRole.GRAD_OUTPUT,
         )
         empty_grads = None, None, None, None, None, None
         return res, *empty_grads
@@ -178,11 +177,11 @@ def backward(ctx, gradY):
         if tensor_already_casted_to_fp8(gradY):
             return gradY, None
         gradY_scale = tensor_to_scale(gradY, e5m2_dtype)
-        fp8_tensor = to_fp8_no_autograd(
+        fp8_tensor = ToFloat8ConstrFunc.apply(
             gradY,
             gradY_scale,
             e5m2_dtype,
-            linear_mm_config=ctx.linear_mm_config,
-            gemm_input_role=GemmInputRole.GRAD_OUTPUT,
+            ctx.linear_mm_config,
+            GemmInputRole.GRAD_OUTPUT,
         )
         return fp8_tensor, None
diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py
index 428d5c9..a3c6ab3 100644
--- a/float8_experimental/float8_tensor.py
+++ b/float8_experimental/float8_tensor.py
@@ -128,71 +128,6 @@ def tensor_already_casted_to_fp8(tensor: torch.Tensor) -> bool:
     return False
 
-
-# TODO: rename to hp_tensor_and_scale_to_float8_tensor
-def to_fp8_no_autograd(
-    x: torch.Tensor,
-    x_scale: torch.Tensor,
-    float8_dtype: torch.dtype,
-    linear_mm_config: Optional[LinearMMConfig],
-    gemm_input_role: Optional[GemmInputRole],
-) -> "Float8Tensor":
-    """Convert a tensor to float8 without autograd
-    This is used in multiple places in the codebase to convert a tensor to float8
-
-    This function will apply the scaling, and then convert to a Float8Tensor
-
-    Note:
-    We will call this function with a DTensor subclass. Ideally this would be an aten OP
-    that DTensor could overload to ensure proper semantics. There are some techincal issues
-    with that composing with FakeTensor, so we special case here.
-
-    DTensor Invariant: DTensor must always be the outer most tensor subclass
-
-    Args:
-        x: the tensor to convert
-        scale: the scale to use to convert the tensor
-        float8_dtype: the float8 dtype to use
-        linear_mm_config: Defines the configuration for the scaled_mm for
-            the 3 fwd/bwd gemms of linear
-        gemm_input_role: Defines the role of this tensor (x, w or dL_dY) in
-            the 3 fwd/bwd gemms of linear
-    """
-    x_scaled = x * x_scale
-    bits_fp8 = to_fp8_saturated(x_scaled, float8_dtype)
-
-    if isinstance(bits_fp8, DTensor):
-        assert isinstance(
-            x, DTensor
-        ), "Expected Float8 scale to be a DTensor if bits_fp8 is a DTensor"
-        bits_mesh = bits_fp8.device_mesh
-        bits_placements = bits_fp8.placements
-        local_bits = bits_fp8.to_local()
-        local_scale = x_scale.to_local()
-        inner_float8_tensor = Float8Tensor(
-            local_bits,
-            local_scale,
-            x.dtype,
-            linear_mm_config=linear_mm_config,
-            gemm_input_role=gemm_input_role,
-        )
-        return DTensor.from_local(
-            inner_float8_tensor,
-            bits_mesh,
-            bits_placements,
-            run_check=False,
-            shape=bits_fp8.size(),
-            stride=bits_fp8.stride(),
-        )
-
-    return Float8Tensor(
-        bits_fp8,
-        x_scale,
-        x.dtype,
-        linear_mm_config=linear_mm_config,
-        gemm_input_role=gemm_input_role,
-    )
-
-
 @torch._dynamo.allow_in_graph
 class ToFloat8ConstrFunc(torch.autograd.Function):
     """
@@ -210,18 +145,56 @@ def forward(
         linear_mm_config: Optional[LinearMMConfig] = None,
         gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
     ):
-        """Autograd enabled wrapper around to_fp8_no_autograd that will also populate the amax buffer.
-        Args
+        """
+        This function will apply the scaling, and then convert to a Float8Tensor
+
+        Note:
+        We will call this function with a DTensor subclass. Ideally this would be an aten OP
+        that DTensor could overload to ensure proper semantics. There are some techincal issues
+        with that composing with FakeTensor, so we special case here.
+
+        DTensor Invariant: DTensor must always be the outer most tensor subclass
+
+        Args:
             tensor: the tensor to convert
             scale: the scale to use to convert the tensor
-            float8_dtype: the float8 dtype either, torch.float8_e4m3fn or torch.float8_e5m2fn
-            emulate: whether to emulate the matmuls in fp32
+            float8_dtype: the float8 dtype to use
+            linear_mm_config: Defines the configuration for the scaled_mm for
+                the 3 fwd/bwd gemms of linear
+            gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in
+                the 3 fwd/bwd gemms of linear
         """
+        tensor_scaled = tensor * scale
+        bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype)
+
+        if isinstance(bits_fp8, DTensor):
+            assert isinstance(
+                x, DTensor
+            ), "Expected Float8 scale to be a DTensor if bits_fp8 is a DTensor"
+            bits_mesh = bits_fp8.device_mesh
+            bits_placements = bits_fp8.placements
+            local_bits = bits_fp8.to_local()
+            local_scale = scale.to_local()
+            inner_float8_tensor = Float8Tensor(
+                local_bits,
+                local_scale,
+                tensor.dtype,
+                linear_mm_config=linear_mm_config,
+                gemm_input_role=gemm_input_role,
+            )
+            return DTensor.from_local(
+                inner_float8_tensor,
+                bits_mesh,
+                bits_placements,
+                run_check=False,
+                shape=bits_fp8.size(),
+                stride=bits_fp8.stride(),
+            )
 
-        return to_fp8_no_autograd(
-            tensor,
+        return Float8Tensor(
+            bits_fp8,
             scale,
-            float8_dtype,
+            tensor.dtype,
             linear_mm_config=linear_mm_config,
             gemm_input_role=gemm_input_role,
         )

From 03c91196ecd313b4c74da290ce6c48f7943f1401 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Fri, 26 Jul 2024 11:00:25 -0700
Subject: [PATCH 2/2] Update on "[3/x] clean up casting functions: delete to_fp8_no_autograd"

Summary:

`ToFloat8ConstrFunc` was just calling `to_fp8_no_autograd`, unify them to
reduce confusion. We can rename the function in a future PR, keeping PRs
small for now.

Test Plan:

```
./test/test_everything.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D60292694](https://our.internmc.facebook.com/intern/diff/D60292694)

[ghstack-poisoned]
---
 float8_experimental/float8_tensor.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py
index a3c6ab3..62ce38d 100644
--- a/float8_experimental/float8_tensor.py
+++ b/float8_experimental/float8_tensor.py
@@ -169,7 +169,7 @@ def forward(
 
         if isinstance(bits_fp8, DTensor):
             assert isinstance(
-                x, DTensor
+                scale, DTensor
             ), "Expected Float8 scale to be a DTensor if bits_fp8 is a DTensor"
             bits_mesh = bits_fp8.device_mesh
             bits_placements = bits_fp8.placements
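
For context, here is a minimal sketch of what a call site looks like once both patches in this stack are applied. It is an illustration, not code from the patches: the import paths and class names are taken from the diff above, `torch.float8_e5m2` stands in for the `e5m2_dtype` alias the call sites use (defined elsewhere in the library), and the toy tensor and scale values are made up for the example.

```
import torch

# GemmInputRole and ToFloat8ConstrFunc live in float8_experimental/float8_tensor.py,
# per the import block touched in float8_scaling_utils.py above.
from float8_experimental.float8_tensor import GemmInputRole, ToFloat8ConstrFunc

# Toy high-precision tensor standing in for a gradient (grad_output).
go = torch.randn(16, 32, dtype=torch.bfloat16)

# Simple max-based scale in fp32; the real call sites compute this with the
# library's own helpers (e.g. tensor_to_scale), which are not shown in this diff.
scale = torch.finfo(torch.float8_e5m2).max / go.abs().max().float()

# Old call site: to_fp8_no_autograd(go, scale, e5m2_dtype,
#     linear_mm_config=..., gemm_input_role=GemmInputRole.GRAD_OUTPUT)
# New call site: everything routes through the autograd.Function, with the
# config and role passed positionally, as in the hunks above.
res = ToFloat8ConstrFunc.apply(
    go,
    scale,
    torch.float8_e5m2,  # stands in for the e5m2_dtype alias
    None,  # linear_mm_config is Optional in the forward signature above
    GemmInputRole.GRAD_OUTPUT,
)
print(type(res))  # expected: the library's Float8Tensor subclass
```

Note that because the call sites now go through `ToFloat8ConstrFunc.apply`, the diff passes `linear_mm_config` and `gemm_input_role` positionally rather than by keyword.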