From c1487ef248bdd17be1875f4c39844a818aace093 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Tue, 16 Jul 2024 09:18:37 -0700
Subject: [PATCH 1/2] [TBD if for land] bring back torch.autograd.Function

Summary:

This approach is more readable as we add additional scaling options.

For now, seeing how many things break in 2024-07 with
torch.autograd.Function + subclasses + compile.

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
---
 float8_experimental/float8_linear.py | 102 ++++++++++++++++++++++++++-
 1 file changed, 101 insertions(+), 1 deletion(-)

diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py
index 7850738..787a54c 100644
--- a/float8_experimental/float8_linear.py
+++ b/float8_experimental/float8_linear.py
@@ -68,6 +68,101 @@ def _maybe_initialize_amaxes_scales_for_float8_cast(
     )
     scale.copy_(new_scale)
 
+# this code was resurrected from https://github.com/pytorch-labs/float8_experimental/pull/128/files
+# and modified to only support dynamic scaling
+@torch._dynamo.allow_in_graph
+class float8_linear(torch.autograd.Function):
+    """
+    Like F.linear, but with X and W in float8
+    """
+
+    @staticmethod
+    def forward(
+        ctx,
+        x_fp8,
+        w_fp8,
+        emulate: bool,
+        # TODO(this PR): split config into fwd/bwd
+        mm_config: ScaledMMConfig,
+    ):
+        ctx.save_for_backward(x_fp8, w_fp8)
+        ctx.emulate = emulate
+        ctx.mm_config = mm_config
+        # orig_shape = x_fp8._data.shape
+        orig_shape = x_fp8.shape
+        # x_fp8_reshaped = Float8Tensor(
+        #     x_fp8._data.reshape(-1, orig_shape[-1]), x_fp8._scale, x_fp8._orig_dtype, mm_config
+        # )
+        x_fp8_reshaped = x_fp8.reshape(-1, orig_shape[-1])
+
+        # w_fp8_t = Float8Tensor(w_fp8._data.t(), w_fp8._scale, w_fp8._orig_dtype, mm_config)
+        w_fp8_t = w_fp8.t()
+
+        res_bits = torch.mm(
+            x_fp8_reshaped, w_fp8_t
+        )
+        res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1])
+        return res_bits
+
+    @staticmethod
+    def backward(ctx, go_fp8):
+        x_fp8, w_fp8 = ctx.saved_tensors
+        emulate = ctx.emulate
+        mm_config = ctx.mm_config
+
+        go_fp8_orig_shape = go_fp8.shape
+        # go_fp8_reshaped = Float8Tensor(
+        #     go_fp8._data.reshape(-1, go_fp8_orig_shape[-1]),
+        #     go_fp8._scale,
+        #     go_fp8._orig_dtype,
+        #     mm_config,
+        # )
+        go_fp8_reshaped = go_fp8.reshape(-1, go_fp8_orig_shape[-1])
+
+        # w_fp8_t_c_t = Float8Tensor(
+        #     w_fp8._data.t().contiguous().t(), w_fp8._scale, w_fp8._orig_dtype, mm_config
+        # )
+        w_fp8_t_c_t = w_fp8.t().contiguous().t()
+
+        #
+        # calculate dL/dX
+        #
+        dL_dX = torch.mm(
+            go_fp8_reshaped,
+            w_fp8_t_c_t,
+        )
+        dL_dX = dL_dX.reshape(*go_fp8_orig_shape[:-1], dL_dX.shape[-1])
+
+        # x_fp8_orig_shape = x_fp8._data.shape
+        x_fp8_orig_shape = x_fp8.shape
+        # x_fp8_reshaped_t_c = Float8Tensor(
+        #     x_fp8._data.reshape(-1, x_fp8_orig_shape[-1]).t().contiguous(),
+        #     x_fp8._scale,
+        #     x_fp8._orig_dtype,
+        #     mm_config,
+        # )
+        x_fp8_reshaped_t_c = x_fp8.reshape(-1, x_fp8_orig_shape[-1]).t().contiguous()
+
+        # go_fp8_reshaped_t_c_t = Float8Tensor(
+        #     go_fp8_reshaped._data.t().contiguous().t(),
+        #     go_fp8_reshaped._scale,
+        #     go_fp8_reshaped._orig_dtype,
+        #     mm_config,
+        # )
+        go_fp8_reshaped_t_c_t = go_fp8_reshaped.t().contiguous().t()
+
+        #
+        # calculate dL/dW
+        #
+        dL_dW = torch.mm(
+            x_fp8_reshaped_t_c,
+            go_fp8_reshaped_t_c_t,
+        )
+        dL_dW = dL_dW.t()
+
+        empty_grads = None, None, None, None, None, None, None, None, None
+        return dL_dX, dL_dW, *empty_grads
+
 
 @torch._dynamo.allow_in_graph
 class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
@@ -394,7 +489,12 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         x_fp8 = self.cast_x_to_float8(input, self.is_amax_initialized)
         w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized)
 
-        y = torch.matmul(x_fp8, w_fp8.t())
+        if not self.has_any_delayed_scaling:
+            emulate = False
+            mm_config = self.forward_config
+            y = float8_linear.apply(x_fp8, w_fp8, emulate, mm_config)
+        else:
+            y = torch.matmul(x_fp8, w_fp8.t())
 
         # Cast gradY to float8_e5m2 during backward
         y = self.cast_y_to_float8_in_bw(y)

From 8505776f89131152887f1d0f27df731f81c346ab Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Tue, 16 Jul 2024 10:01:30 -0700
Subject: [PATCH 2/2] Update on "[TBD if for land] bring back torch.autograd.Function"

Summary:

This approach is more readable as we add additional scaling options.

For now, seeing how many things break in 2024-07 with
torch.autograd.Function + subclasses + compile.

```
# this code was resurrected from https://github.com/pytorch-labs/float8_experimental/pull/128/files
# and modified to only support dynamic scaling
#
# Why do we want a torch.autograd.Function here? Vasiliy's opinion is that
# as we add more scaling granularities, keeping the scaling code close to Float8Linear
# will be really useful for readability and debuggability of numerics.
#
# For example, a future PR to add rowwise scaling could do
#
# # forward
# x_bf16 = ...
# if scaling_granularity == ScalingGranularity.PER_TENSOR:
#     # we can scale the same way for fwd/bwd
#     x_maybe_fp8 = to_fp8(...)
# else:
#     assert scaling_granularity == ScalingGranularity.PER_ROW:
#     # defer scaling to float8_mm
#     x_maybe_fp8 = x_bf16
#
# # repeat for w
#
# y_bf16 = float8_mm(x_maybe_fp8, w_maybe_fp8)
#
# Requirements for float8_mm
# - composes with DTensor, compile, autograd
# - readable/debuggable
#
# Option 1 (this PR): float8_mm is a torch.autograd.Function
#   - pros
#   - cons
# Option 2 (current code without this PR): float8_mm is an override of torch.mm
#   - pros
#   - cons
#
```

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
---
 float8_experimental/float8_linear.py | 53 ++++------------------------
 1 file changed, 7 insertions(+), 46 deletions(-)

diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py
index 787a54c..16c1257 100644
--- a/float8_experimental/float8_linear.py
+++ b/float8_experimental/float8_linear.py
@@ -68,12 +68,13 @@ def _maybe_initialize_amaxes_scales_for_float8_cast(
     )
     scale.copy_(new_scale)
 
+
 # this code was resurrected from https://github.com/pytorch-labs/float8_experimental/pull/128/files
 # and modified to only support dynamic scaling
 @torch._dynamo.allow_in_graph
-class float8_linear(torch.autograd.Function):
+class float8_mm(torch.autograd.Function):
     """
-    Like F.linear, but with X and W in float8
+    Like torch.mm, but with X and W in float8
     """
 
     @staticmethod
@@ -81,47 +82,24 @@ def forward(
         ctx,
         x_fp8,
         w_fp8,
-        emulate: bool,
-        # TODO(this PR): split config into fwd/bwd
-        mm_config: ScaledMMConfig,
     ):
         ctx.save_for_backward(x_fp8, w_fp8)
-        ctx.emulate = emulate
-        ctx.mm_config = mm_config
-        # orig_shape = x_fp8._data.shape
         orig_shape = x_fp8.shape
-        # x_fp8_reshaped = Float8Tensor(
-        #     x_fp8._data.reshape(-1, orig_shape[-1]), x_fp8._scale, x_fp8._orig_dtype, mm_config
-        # )
         x_fp8_reshaped = x_fp8.reshape(-1, orig_shape[-1])
 
-        # w_fp8_t = Float8Tensor(w_fp8._data.t(), w_fp8._scale, w_fp8._orig_dtype, mm_config)
         w_fp8_t = w_fp8.t()
 
-        res_bits = torch.mm(
-            x_fp8_reshaped, w_fp8_t
-        )
+        res_bits = torch.mm(x_fp8_reshaped, w_fp8_t)
         res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1])
         return res_bits
 
     @staticmethod
     def backward(ctx, go_fp8):
         x_fp8, w_fp8 = ctx.saved_tensors
-        emulate = ctx.emulate
-        mm_config = ctx.mm_config
 
         go_fp8_orig_shape = go_fp8.shape
-        # go_fp8_reshaped = Float8Tensor(
-        #     go_fp8._data.reshape(-1, go_fp8_orig_shape[-1]),
-        #     go_fp8._scale,
-        #     go_fp8._orig_dtype,
-        #     mm_config,
-        # )
         go_fp8_reshaped = go_fp8.reshape(-1, go_fp8_orig_shape[-1])
 
-        # w_fp8_t_c_t = Float8Tensor(
-        #     w_fp8._data.t().contiguous().t(), w_fp8._scale, w_fp8._orig_dtype, mm_config
-        # )
         w_fp8_t_c_t = w_fp8.t().contiguous().t()
 
         #
@@ -133,22 +111,9 @@ def backward(ctx, go_fp8):
         )
         dL_dX = dL_dX.reshape(*go_fp8_orig_shape[:-1], dL_dX.shape[-1])
 
-        # x_fp8_orig_shape = x_fp8._data.shape
         x_fp8_orig_shape = x_fp8.shape
-        # x_fp8_reshaped_t_c = Float8Tensor(
-        #     x_fp8._data.reshape(-1, x_fp8_orig_shape[-1]).t().contiguous(),
-        #     x_fp8._scale,
-        #     x_fp8._orig_dtype,
-        #     mm_config,
-        # )
         x_fp8_reshaped_t_c = x_fp8.reshape(-1, x_fp8_orig_shape[-1]).t().contiguous()
 
-        # go_fp8_reshaped_t_c_t = Float8Tensor(
-        #     go_fp8_reshaped._data.t().contiguous().t(),
-        #     go_fp8_reshaped._scale,
-        #     go_fp8_reshaped._orig_dtype,
-        #     mm_config,
-        # )
         go_fp8_reshaped_t_c_t = go_fp8_reshaped.t().contiguous().t()
 
         #
@@ -160,7 +125,7 @@ def backward(ctx, go_fp8):
         )
         dL_dW = dL_dW.t()
 
-        empty_grads = None, None, None, None, None, None, None, None, None
+        empty_grads = (None,)
         return dL_dX, dL_dW, *empty_grads
 
 
@@ -489,12 +454,8 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         x_fp8 = self.cast_x_to_float8(input, self.is_amax_initialized)
         w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized)
 
-        if not self.has_any_delayed_scaling:
-            emulate = False
-            mm_config = self.forward_config
-            y = float8_linear.apply(x_fp8, w_fp8, emulate, mm_config)
-        else:
-            y = torch.matmul(x_fp8, w_fp8.t())
+        # y = float8_mm.apply(x_fp8, w_fp8)
+        y = float8_mm.apply(x_fp8, w_fp8)
 
         # Cast gradY to float8_e5m2 during backward
         y = self.cast_y_to_float8_in_bw(y)
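
For reference, a minimal standalone sketch of the autograd.Function pattern these patches reintroduce: flatten the leading dims, run torch.mm in forward, and compute dL/dX and dL/dW by hand in backward. It uses plain float tensors instead of Float8Tensor, drops emulate/mm_config, and omits the .t().contiguous().t() layout workarounds; the class and variable names below are illustrative and are not from the patch.

```python
# Illustrative sketch only (not the patch's float8_mm): same reshape + manual
# backward structure, but on plain tensors so it runs without float8 kernels.
import torch


class mm_with_manual_backward(torch.autograd.Function):
    """Computes y = x @ w.t() over the last dim of x, with a hand-written backward."""

    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(x, w)
        orig_shape = x.shape
        # flatten leading dims so torch.mm sees a 2d problem, as the patch does
        x_2d = x.reshape(-1, orig_shape[-1])
        out = torch.mm(x_2d, w.t())
        return out.reshape(*orig_shape[:-1], out.shape[-1])

    @staticmethod
    def backward(ctx, grad_out):
        x, w = ctx.saved_tensors
        go_shape = grad_out.shape
        go_2d = grad_out.reshape(-1, go_shape[-1])
        x_2d = x.reshape(-1, x.shape[-1])
        # dL/dX = dL/dY @ W, reshaped back to the original input shape
        dL_dX = torch.mm(go_2d, w).reshape(*go_shape[:-1], w.shape[-1])
        # dL/dW = dL/dY^T @ X, equivalent to the patch's (X^T @ dL/dY)^T
        dL_dW = torch.mm(go_2d.t(), x_2d)
        return dL_dX, dL_dW


if __name__ == "__main__":
    x = torch.randn(4, 3, 8, requires_grad=True)
    w = torch.randn(16, 8, requires_grad=True)
    y = mm_with_manual_backward.apply(x, w)
    y.sum().backward()
    print(x.grad.shape, w.grad.shape)  # torch.Size([4, 3, 8]) torch.Size([16, 8])
```

The patch's extra .t().contiguous().t() calls on the float8 operands appear to be memory-layout workarounds for the scaled-mm kernel; they do not change the math relative to the sketch above.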