-
Notifications
You must be signed in to change notification settings - Fork 204
/
Copy pathstateful_float8_linear.py
439 lines (392 loc) · 17.1 KB
/
stateful_float8_linear.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
"""
Stateful version of Float8Linear, created to keep Float8Linear simple and
only require code readers to read the stateful code if they care about delayed
or static scaling.
"""
from typing import Optional
import torch
import torch.utils.checkpoint as checkpoint
from torchao.float8.config import Float8LinearConfig, ScalingType
from torchao.float8.distributed_utils import tensor_already_casted_to_fp8
from torchao.float8.float8_linear import (
Float8Linear,
)
from torchao.float8.float8_scaling_utils import (
NoopFwToFloat8BwDelayed,
NoopFwToFloat8BwDynamic,
NoopFwToFloat8BwStatic,
_maybe_initialize_amaxes_scales_for_float8_cast,
hp_tensor_to_float8_delayed,
hp_tensor_to_float8_dynamic,
hp_tensor_to_float8_static,
)
from torchao.float8.float8_tensor import (
GemmInputRole,
hp_tensor_and_scale_to_float8,
)
from torchao.float8.float8_utils import (
tensor_to_amax,
tensor_to_scale,
)
from torchao.float8.fsdp_utils import (
WeightWithDelayedFloat8CastTensor,
WeightWithDynamicFloat8CastTensor,
WeightWithStaticFloat8CastTensor,
)
@torch._dynamo.allow_in_graph
class manual_float8_matmul_with_args_in_float8(torch.autograd.Function):
"""
Like torch.matmul, but with the arguments in float8
Note: this function requires all arguments to already be Float8Tensor objects,
which only supports tensorwise scaling granularity. The reason we didn't just make this
function support axiswise scaling granularity is because that would need very
careful testing of delayed scaling, as delayed scaling modifies buffers inplace.
In the future we'll probably have to unify, just postponing that until a future PR.
"""
@staticmethod
def forward(
ctx,
input_fp8,
weight_fp8_t,
):
ctx.save_for_backward(input_fp8, weight_fp8_t)
# the reshapes are needed in order to make the shapes compatible with
# torch.mm
orig_shape = input_fp8.shape
input_fp8_reshaped = input_fp8.reshape(-1, orig_shape[-1])
res_bits = torch.mm(input_fp8_reshaped, weight_fp8_t)
res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1])
return res_bits
@staticmethod
def backward(ctx, grad_output_fp8):
input_fp8, weight_fp8_t = ctx.saved_tensors
# the reshapes are needed in order to make the shapes compatible with
# torch.mm
grad_output_fp8_orig_shape = grad_output_fp8.shape
grad_output_fp8_reshaped = grad_output_fp8.reshape(
-1, grad_output_fp8_orig_shape[-1]
)
# calculate grad_input
grad_input = torch.mm(
grad_output_fp8_reshaped,
weight_fp8_t.t(),
)
grad_input = grad_input.reshape(
*grad_output_fp8_orig_shape[:-1], grad_input.shape[-1]
)
input_fp8_orig_shape = input_fp8.shape
input_fp8_reshaped = input_fp8.reshape(-1, input_fp8_orig_shape[-1])
# calculate grad_weight
# Note: the variant below is slightly faster on LLaMa 3 8B pretraining
# compared to than calculating `grad_weight_t = input_fp8_t @ grad_output_fp8_reshaped`
grad_weight = torch.mm(
grad_output_fp8_reshaped.t(),
input_fp8_reshaped,
)
return grad_input, grad_weight.t()
class StatefulFloat8Linear(Float8Linear):
    """
    Float8Linear variant that also owns the module state needed for delayed
    and static scaling.

    On top of the parent class, this module:
      * registers float32 buffers for amaxes, amax histories and scales
        (delayed scaling) and for user-provided static scales, and
      * routes the input / weight / grad_output casts through the helper
        that matches each tensor's configured ScalingType
        (DELAYED, DYNAMIC or STATIC).
    """

    def __init__(self, *args, **kwargs):
        # Amax scales should always be kept as float32.
        # Names of buffers that `_apply` forces back to float32 after any
        # dtype/device conversion of the module. This is created before the
        # parent constructor runs, presumably so the attribute already exists
        # if that constructor ends up in `_apply` (TODO confirm).
        self.always_float32_buffers = set()
        super().__init__(*args, **kwargs)

        # Convenience flag to skip code related to delayed scaling
        self.has_any_delayed_scaling = (
            self.scaling_type_input is ScalingType.DELAYED
            or self.scaling_type_weight is ScalingType.DELAYED
            or self.scaling_type_grad_output is ScalingType.DELAYED
        )
        self.create_buffers()

        # Note: is_amax_initialized is not a buffer to avoid data dependent
        # control flow visible to dynamo
        # TODO(future PR): add serialization for this flag
        self.is_amax_initialized = not self.config.enable_amax_init

        # pre_forward and post_forward are currently broken with FSDP
        # and torch.compile, this option can disable them
        # Note that when using `self.config.enable_pre_and_post_forward = False`,
        # it's recommended to also set `self.config.enable_amax_init = False`.
        # Otherwise, the amax buffer would never be marked as initialized and
        # would be initialized in every iteration.
        # NOTE(review): nothing inside this class sets `is_amax_initialized`
        # to True (the pre/post forward hooks below are no-ops). Presumably an
        # external amax/scale sync helper flips it; confirm before relying on
        # the note above.
        self.enable_pre_and_post_forward = self.config.enable_pre_and_post_forward

    def create_buffers(self):
        """
        Register the float32 scaling-state buffers on the weight's device.

        * If any of input / weight / grad_output uses delayed scaling, register
          an amax (initialized to the target dtype's max), a zero-filled amax
          history of length `delayed_scaling_config.history_len`, and a scale
          (initialized to 1.0) for all three tensors.
        * For each tensor whose cast config carries a `static_scale`, register
          that scale (moved to the weight's device) as `fp8_static_scale_*`.

        `from_float` calls this a second time, after swapping in the real
        weight, so the buffers end up on the real device instead of meta.
        """
        # Initial amax values for the delayed-scaling amax buffers: the largest
        # representable value of each tensor's target float8 dtype.
        history_len = self.config.delayed_scaling_config.history_len
        device = self.weight.device
        default_input = torch.finfo(self.config.cast_config_input.target_dtype).max
        default_weight = torch.finfo(self.config.cast_config_weight.target_dtype).max
        default_grad_output = torch.finfo(
            self.config.cast_config_grad_output.target_dtype
        ).max

        # Note: for now, create all the buffers if any are needed, to postpone
        # the work to make the scale and amax syncing and history calculation
        # handle a heterogeneous setup. We can do that work later if benchmarks
        # show it is worth doing.
        if self.has_any_delayed_scaling:
            self.register_always_float32_buffer(
                "fp8_amax_input", torch.tensor([default_input], device=device)
            )
            self.register_always_float32_buffer(
                "fp8_amax_history_input", torch.zeros(history_len, device=device)
            )
            self.register_always_float32_buffer(
                "fp8_scale_input", torch.tensor([1.0], device=device)
            )
            self.register_always_float32_buffer(
                "fp8_amax_weight", torch.tensor([default_weight], device=device)
            )
            self.register_always_float32_buffer(
                "fp8_amax_history_weight", torch.zeros(history_len, device=device)
            )
            self.register_always_float32_buffer(
                "fp8_scale_weight", torch.tensor([1.0], device=device)
            )
            self.register_always_float32_buffer(
                "fp8_amax_grad_output",
                torch.tensor([default_grad_output], device=device),
            )
            self.register_always_float32_buffer(
                "fp8_amax_history_grad_output", torch.zeros(history_len, device=device)
            )
            self.register_always_float32_buffer(
                "fp8_scale_grad_output", torch.tensor([1.0], device=device)
            )

        if self.config.cast_config_input.static_scale is not None:
            self.register_always_float32_buffer(
                "fp8_static_scale_input",
                self.config.cast_config_input.static_scale.to(device),
            )
        if self.config.cast_config_weight.static_scale is not None:
            self.register_always_float32_buffer(
                "fp8_static_scale_weight",
                self.config.cast_config_weight.static_scale.to(device),
            )
        if self.config.cast_config_grad_output.static_scale is not None:
            self.register_always_float32_buffer(
                "fp8_static_scale_grad_output",
                self.config.cast_config_grad_output.static_scale.to(device),
            )

    def register_always_float32_buffer(
        self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True
    ) -> None:
        """
        Register a buffer and record its name so `_apply` keeps it float32.

        Args:
            name: buffer attribute name.
            tensor: buffer value (may be None, as with `register_buffer`).
            persistent: whether the buffer is part of the state dict.
        """
        self.register_buffer(name=name, tensor=tensor, persistent=persistent)
        self.always_float32_buffers.add(name)

    def _apply(self, fn, recurse=True):
        """
        Run the regular `nn.Module._apply`, then force the tracked buffers
        back to float32.

        Without this, a dtype cast of the whole module would also cast the
        amax/scale buffers.
        """
        ret = super()._apply(fn, recurse)
        self.convert_amax_buffer_to_float32()
        return ret

    def convert_amax_buffer_to_float32(self):
        """Cast every non-None buffer listed in `always_float32_buffers` to float32."""
        for key in self.always_float32_buffers:
            if self._buffers[key] is not None:
                self._buffers[key] = self._buffers[key].to(torch.float32)

    def cast_input_to_float8(self, input: torch.Tensor) -> torch.Tensor:
        """
        Cast the activation to float8 according to `scaling_type_input`.

        If autocast is enabled, the input is first converted to the GPU
        autocast dtype. An input that is already float8 is returned unchanged.
        With delayed scaling, this may initialize the input amax/history/scale
        buffers and passes `fp8_amax_input` to the cast helper.
        """
        is_amax_initialized = self.is_amax_initialized
        # Duplicate the autocast logic for F.linear, so that the output
        # of our module has the right original precision
        if torch.is_autocast_enabled():
            # For now, hardcode to GPU's autocast dtype
            # if we need CPU support in the future, we can add it
            autocast_dtype = torch.get_autocast_gpu_dtype()
            input = input.to(autocast_dtype)

        if tensor_already_casted_to_fp8(input):
            input_fp8 = input
        elif self.scaling_type_input is ScalingType.DELAYED:
            scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
            _maybe_initialize_amaxes_scales_for_float8_cast(
                input,
                self.fp8_amax_input,
                self.fp8_amax_history_input,
                self.fp8_scale_input,
                scale_fn_name,
                self.config.cast_config_input.target_dtype,
                is_amax_initialized,
                reduce_amax=True,
            )
            input_fp8 = hp_tensor_to_float8_delayed(
                input,
                self.fp8_scale_input,
                self.config.cast_config_input.target_dtype,
                self.fp8_amax_input,
                linear_mm_config=self.linear_mm_config,
                gemm_input_role=GemmInputRole.INPUT,
            )
        elif self.scaling_type_input is ScalingType.DYNAMIC:
            input_fp8 = hp_tensor_to_float8_dynamic(
                input,
                self.config.cast_config_input.target_dtype,
                self.linear_mm_config,
                gemm_input_role=GemmInputRole.INPUT,
            )
        else:
            assert self.scaling_type_input is ScalingType.STATIC
            input_fp8 = hp_tensor_to_float8_static(
                input,
                self.fp8_static_scale_input,
                self.config.cast_config_input.target_dtype,
                self.linear_mm_config,
            )

        return input_fp8

    def get_weight_scale(self, weight: torch.Tensor) -> Optional[torch.Tensor]:
        """
        Return the scale to use when casting `weight` to float8.

        Returns:
            None if `weight` is already float8. Otherwise, by
            `scaling_type_weight`:
              * DELAYED: the stored `fp8_scale_weight`. As a side effect, this
                may initialize the weight amax/history/scale buffers, and it
                overwrites `fp8_amax_weight` with the current weight's amax.
              * DYNAMIC: a scale computed from the current weight.
              * STATIC: the configured `fp8_static_scale_weight`.
        """
        if tensor_already_casted_to_fp8(weight):
            return None
        if self.scaling_type_weight is ScalingType.DELAYED:
            scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
            _maybe_initialize_amaxes_scales_for_float8_cast(
                weight,
                self.fp8_amax_weight,
                self.fp8_amax_history_weight,
                self.fp8_scale_weight,
                scale_fn_name,
                self.config.cast_config_weight.target_dtype,
                self.is_amax_initialized,
                reduce_amax=True,
            )
            # Record this step's amax in place; the returned scale is the one
            # already stored in the buffer (i.e. the "delayed" scale).
            self.fp8_amax_weight.fill_(tensor_to_amax(weight))
            return self.fp8_scale_weight
        elif self.scaling_type_weight is ScalingType.DYNAMIC:
            return tensor_to_scale(weight, self.config.cast_config_weight.target_dtype)
        else:
            assert self.scaling_type_weight is ScalingType.STATIC
            return self.fp8_static_scale_weight

    def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
        """
        Wrap `output` in the autograd function that matches
        `scaling_type_grad_output`.

        Judging by their names, the `NoopFwToFloat8Bw*` functions leave the
        forward value untouched and cast the incoming gradient to
        `cast_config_grad_output.target_dtype` in backward (confirm in
        float8_scaling_utils).
        """
        if self.scaling_type_grad_output is ScalingType.DELAYED:
            scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
            output = NoopFwToFloat8BwDelayed.apply(
                output,
                self.fp8_amax_grad_output,
                self.fp8_amax_history_grad_output,
                self.fp8_scale_grad_output,
                scale_fn_name,
                self.is_amax_initialized,
                self.linear_mm_config,
                self.config.cast_config_grad_output.target_dtype,
            )
        elif self.scaling_type_grad_output is ScalingType.DYNAMIC:
            output = NoopFwToFloat8BwDynamic.apply(
                output,
                self.linear_mm_config,
                self.config.cast_config_grad_output.target_dtype,
            )
        else:
            assert self.scaling_type_grad_output is ScalingType.STATIC
            output = NoopFwToFloat8BwStatic.apply(
                output,
                self.fp8_static_scale_grad_output,
                self.linear_mm_config,
                self.config.cast_config_grad_output.target_dtype,
            )
        return output

    def cast_weight_to_float8_t(
        self,
        weight: torch.Tensor,
        weight_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Cast `weight` to float8 with `weight_scale` and return its transpose.

        A weight that is already float8 is transposed without recasting.
        """
        if tensor_already_casted_to_fp8(weight):
            return weight.t()
        weight_fp8 = hp_tensor_and_scale_to_float8(
            weight,
            weight_scale,
            self.config.cast_config_weight.target_dtype,
            self.linear_mm_config,
            gemm_input_role=GemmInputRole.WEIGHT,
        )
        return weight_fp8.t()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Run the float8 linear layer.

        Steps: cast the input and weight to float8, run the float8 matmul,
        attach the backward-time grad_output cast, then add the bias converted
        to the output dtype.
        """
        if self.has_any_delayed_scaling:
            self.float8_pre_forward(input)

        input_fp8 = self.cast_input_to_float8(input)
        # If force_recompute_fp8_weight_in_bwd, we only recompute the fp8 weight,
        # weight_scale should be saved.
        weight_scale = self.get_weight_scale(self.weight)
        if self.config.force_recompute_fp8_weight_in_bwd:
            # Pass `use_reentrant` explicitly. Omitting it triggers a PyTorch
            # deprecation warning, and True is the value PyTorch currently
            # uses implicitly, so behavior is unchanged.
            weight_fp8_t = checkpoint.checkpoint(
                self.cast_weight_to_float8_t,
                self.weight,
                weight_scale,
                use_reentrant=True,
            )
        else:
            weight_fp8_t = self.cast_weight_to_float8_t(self.weight, weight_scale)

        output = manual_float8_matmul_with_args_in_float8.apply(input_fp8, weight_fp8_t)

        # Cast grad_output to float8 (cast_config_grad_output.target_dtype)
        # during backward
        output = self.cast_output_to_float8_in_bw(output)

        if self.bias is not None:
            output = output + self.bias.to(output.dtype)

        if self.has_any_delayed_scaling:
            self.float8_post_forward()
        return output

    def float8_pre_forward(self, input):
        """Pre-forward hook; currently a no-op whether or not it is enabled."""
        # TODO(future PR): deprecate these functions and the corresponding
        # config setting
        if not self.enable_pre_and_post_forward:
            return

    def float8_post_forward(self):
        """Post-forward hook; currently a no-op whether or not it is enabled."""
        # TODO(future PR): deprecate these functions and the corresponding
        # config setting
        if not self.enable_pre_and_post_forward:
            return

    @classmethod
    def from_float(
        cls,
        mod,
        config: Optional[Float8LinearConfig] = None,
    ):
        """
        Create an nn.Linear with fp8 compute from a regular nn.Linear

        The new module is constructed on the meta device and then reuses
        `mod`'s weight and bias Parameters (no copy).

        Args:
            mod (torch.nn.Linear): nn.Linear to convert
            config (Optional[Float8LinearConfig]): configuration for conversion
                to float8; defaults to `Float8LinearConfig()`.

        Returns:
            The converted module.
        """
        if config is None:
            config = Float8LinearConfig()
        with torch.device("meta"):
            new_mod = cls(
                mod.in_features,
                mod.out_features,
                bias=False,
                config=config,
            )
        new_mod.weight = mod.weight
        new_mod.bias = mod.bias
        # need to create buffers again when moving from meta device to
        # real device
        new_mod.create_buffers()

        # If FSDP float8 all-gather is on, wrap the weight in a float8-aware
        # tensor subclass. This must happen last because:
        # 1. weight needs to be on the correct device to create the buffers
        # 2. buffers need to be already created for the delayed scaling version
        #    of the weight wrapper to be initialized
        if config.enable_fsdp_float8_all_gather:
            if config.cast_config_weight.scaling_type is ScalingType.DYNAMIC:
                new_mod.weight = torch.nn.Parameter(
                    WeightWithDynamicFloat8CastTensor(
                        new_mod.weight,
                        new_mod.linear_mm_config,
                        new_mod.config.cast_config_weight.target_dtype,
                    )
                )
            elif config.cast_config_weight.scaling_type is ScalingType.DELAYED:
                new_mod.weight = torch.nn.Parameter(
                    WeightWithDelayedFloat8CastTensor(
                        new_mod.weight,
                        new_mod.fp8_amax_weight,
                        new_mod.fp8_amax_history_weight,
                        new_mod.fp8_scale_weight,
                        new_mod.linear_mm_config,
                        new_mod.config.cast_config_weight.target_dtype,
                        new_mod.is_amax_initialized,
                    )
                )
            else:
                assert config.cast_config_weight.scaling_type is ScalingType.STATIC
                new_mod.weight = torch.nn.Parameter(
                    WeightWithStaticFloat8CastTensor(
                        new_mod.weight,
                        new_mod.fp8_static_scale_weight,
                        new_mod.linear_mm_config,
                        new_mod.config.cast_config_weight.target_dtype,
                    )
                )

        return new_mod