enable FSDP2 + fp8 all-gather and fix TP fp8 all-gather #413
@@ -273,6 +273,39 @@ def build_test_list():
            "fsdp2_mem_tracker",
            ngpu=4,
        ),
        OverrideDefinitions(
            [
                [
                    "--training.fp8_linear",
                ]
            ],
            "FSDP2 with bf16 all-gather",
Comment: maybe noob question: does all-gather always happen in bf16, or is it determined by [...]?
Reply: good question! I should change this to "all-gather in original dtype" when mixed_precision is turned off, and [...] when mixed_precision is turned on. (See the sketch after this test-list hunk for how FSDP2's mixed-precision policy sets the all-gather dtype.)
            "fp8_fsdp2_bf16_all_gather",
            ngpu=4,
        ),
        OverrideDefinitions(
            [
                [
                    "--training.fp8_linear",
                    "--training.enable_fsdp_fp8_all_gather",
                ]
            ],
            "FSDP2 with fp8 all-gather",
            "fp8_fsdp2_fp8_all_gather",
            ngpu=4,
        ),
        OverrideDefinitions(
            [
                [
                    "--training.fp8_linear",
                    "--training.enable_fsdp_fp8_all_gather",
                    "--training.precompute_float8_dynamic_scale_for_fsdp",
                ]
            ],
            "FSDP2 with fp8 all-gather and precomputed dynamic scales",
Comment: nit: comment for 2D
Reply: in the end I have to remove 2D from this PR. The current CI tokenizer has vocab size = 2556, but fp8 gemm needs the vocab size to be divisible by 16 (#461). I can follow up with you on how to get a tokenizer with vocab size = 2560 to unblock 1D TP + fp8 and 2D + fp8 in CI.
            "fp8_fsdp2_fp8_all_gather_precompute_dynamic_scales",
            ngpu=4,
        ),
    ]
    return integration_tests_flavors
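To ground the bf16-vs-fp8 all-gather question from the thread above: with FSDP2, the all-gather dtype for ordinary parameters is governed by the mixed-precision policy passed to `fully_shard`, while the fp8 all-gather path only kicks in once `enable_fsdp_fp8_all_gather` swaps in Float8Linear weights. A minimal sketch, assuming PyTorch's composable FSDP2 API (`fully_shard`, `MixedPrecisionPolicy`) and a process group launched via torchrun; this wrapping code is not part of the PR:

```python
# Sketch only (not part of this diff): FSDP2 all-gather dtype via mixed precision.
# Assumes torch.distributed's composable FSDP2 ("fully_shard") API; run under
# torchrun so the process group / device mesh can be initialized.
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard

dist.init_process_group(backend="nccl")
model = nn.Sequential(nn.Linear(256, 256), nn.Linear(256, 256)).cuda()

# param_dtype controls the dtype parameters are cast to for all-gather when the
# weights are plain tensors; leaving it unset keeps the original parameter dtype.
policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16)
for layer in model:
    fully_shard(layer, mp_policy=policy)
fully_shard(model, mp_policy=policy)
```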
@@ -347,6 +347,18 @@ def __init__(self):
            here: https://github.com/pytorch-labs/float8_experimental
            """,
        )
        self.parser.add_argument(
Comment: As discussed offline, let's refactor the fp8 configs, e.g. have a dedicated field for enabling fp8 or not.
Reply: renamed
Comment: I think one thing to note is that right now this is a boolean which will swap to the default float8 recipe. I think we should brainstorm on an elegant solution for users to express their desired config here.
Reply: good question. Eventually we might have to expose args/kwargs from [...]
            "--training.enable_fsdp_fp8_all_gather",
            action="store_true",
            default=False,
            help="Whether to enable fp8 all-gather in FSDP",
        )
        self.parser.add_argument(
            "--training.precompute_float8_dynamic_scale_for_fsdp",
            action="store_true",
            default=False,
            help="Whether to precompute fp8 scales dynamically for FSDP",
        )
        self.parser.add_argument(
            "--training.gc_freq",
            type=int,
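A quick sketch of how these two flags surface after parsing. This assumes JobConfig's `parse_args` accepts an argument list and maps dotted `--training.*` flags onto nested attributes, which is the pattern the train.py hunk further below relies on:

```python
# Hedged sketch: exercise the two new store_true flags through JobConfig.
# Assumes JobConfig.parse_args(list_of_args) works as in torchtitan's
# config_manager; attribute access mirrors job_config.training.* in train.py.
from torchtitan.config_manager import JobConfig

job_config = JobConfig()
job_config.parse_args(
    [
        "--training.fp8_linear",
        "--training.enable_fsdp_fp8_all_gather",
        "--training.precompute_float8_dynamic_scale_for_fsdp",
    ]
)
assert job_config.training.enable_fsdp_fp8_all_gather is True
assert job_config.training.precompute_float8_dynamic_scale_for_fsdp is True
```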
@@ -12,13 +12,30 @@

# Note: Performance
# Float8 experimental is intended to be run under `torch.compile` for competitive performance
import contextlib

import float8_experimental.config as config

import torch
import torch.nn as nn
from float8_experimental.float8_linear import TensorScalingType

from torchtitan.config_manager import JobConfig
from torchtitan.logging_utils import logger


@contextlib.contextmanager
def set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather: bool):
    prev = config.enable_fsdp_fp8_all_gather
    torch.distributed.barrier()
    config.enable_fsdp_fp8_all_gather = enable_fsdp_fp8_all_gather
    try:
        yield
    finally:
        torch.distributed.barrier()
        config.enable_fsdp_fp8_all_gather = prev


def build_fp8_linear(model: nn.Module, job_config: JobConfig):
    """
    This function converts the linear layers to `Float8Linear`. Note that today,
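The context manager above toggles a module-level flag in `float8_experimental.config` and restores it even if the body raises; the barriers keep all ranks toggling in lockstep. A single-process illustration of the same toggle-and-restore pattern, using a stand-in namespace instead of `float8_experimental.config` and omitting the barriers:

```python
# Single-process illustration of the flag-toggling pattern used by
# set_enable_fsdp_fp8_all_gather above. `fake_config` is a stand-in for
# float8_experimental.config; barriers are omitted since there is no process group.
import contextlib
from types import SimpleNamespace

fake_config = SimpleNamespace(enable_fsdp_fp8_all_gather=False)


@contextlib.contextmanager
def set_flag(value: bool):
    prev = fake_config.enable_fsdp_fp8_all_gather
    fake_config.enable_fsdp_fp8_all_gather = value
    try:
        yield
    finally:
        # Restored even if the body raises, which keeps repeated test runs clean.
        fake_config.enable_fsdp_fp8_all_gather = prev


with set_flag(True):
    assert fake_config.enable_fsdp_fp8_all_gather
assert not fake_config.enable_fsdp_fp8_all_gather
```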
@@ -27,8 +44,8 @@ def build_fp8_linear(model: nn.Module, job_config: JobConfig):
    This will mutate the model inplace.
    """
    use_fp8_linear = job_config.training.fp8_linear
    enable_fsdp_fp8_all_gather = job_config.training.enable_fsdp_fp8_all_gather
Comment: discussed offline: please check if it makes sense to enable it only when [...]
Reply: added a check on [...]
    try:
        from float8_experimental.float8_linear import Float8Linear
        from float8_experimental.float8_linear_utils import (
            swap_linear_with_float8_linear,
        )

@@ -38,5 +55,10 @@ def build_fp8_linear(model: nn.Module, job_config: JobConfig):
        ) from exc
    if use_fp8_linear:
        # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
        swap_linear_with_float8_linear(model, Float8Linear)
        logger.info("Swapped to Float8Linear layers")
        with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
Comment: noob Q: do we need this in a context manager to make testing + resetting easier?
Reply: hmm. EDIT: I also see you mentioned "make testing + resetting easier", which answered why, so I am not sure if it's a question for me.
            swap_linear_with_float8_linear(
                model, scaling_type_w=TensorScalingType.DYNAMIC
            )
        logger.info(
            f"Swapped to Float8Linear layers with {enable_fsdp_fp8_all_gather=}"
        )
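A hedged sketch of the intended call site. The import path of `build_fp8_linear` is an assumption (the diff does not show the file name), and the toy model stands in for torchtitan's Transformer; the swap is done before any FSDP wrapping so that `fully_shard` later sees `Float8Linear` weights:

```python
# Hedged call-site sketch; build_fp8_linear's module path is assumed and the
# model is a toy stand-in. Needs an initialized process group (e.g. torchrun),
# since set_enable_fsdp_fp8_all_gather above issues dist barriers.
import torch.nn as nn

from torchtitan.config_manager import JobConfig
from torchtitan.float8_linear import build_fp8_linear  # assumed module path

job_config = JobConfig()
job_config.parse_args(["--training.fp8_linear", "--training.enable_fsdp_fp8_all_gather"])

model = nn.Sequential(nn.Linear(256, 256), nn.Linear(256, 256))
build_fp8_linear(model, job_config)  # mutates in place: nn.Linear -> Float8Linear
# ... apply fully_shard / other parallelisms afterwards
```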
@@ -19,6 +19,7 @@

import torch
import torch.nn.functional as F
from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
Comment: @weifengpy I think we should hide this import in the path where [...]. The problem here is that for every feature that requires an additional install of another dependency, we should try to hide the import in the path that uses it instead of importing it globally; otherwise, for users who didn't install float8_experimental, if they rebase, it would just fail to train for them. Please submit a follow-up PR to fix this.
Reply: got you. I am moving it from top-level to if-else now in #464. Thanks for the timely reminder. (A sketch of the guarded-import pattern follows this hunk.)
from torch.distributed import destroy_process_group
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.elastic.multiprocessing.errors import record
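The requested follow-up could look roughly like the helper below; a sketch only, not the actual #464 change. `maybe_precompute_fp8_scales` is a hypothetical name, and `model` / `job_config` are the objects train.py already has in scope:

```python
# Hypothetical helper sketching the guarded-import pattern requested above:
# float8_experimental is only imported when the flag is enabled, so users
# without that dependency can still train. Not the literal #464 diff.
def maybe_precompute_fp8_scales(model, job_config) -> None:
    if not job_config.training.precompute_float8_dynamic_scale_for_fsdp:
        return
    # Lazy import at the point of use instead of at the top of train.py.
    from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp

    precompute_float8_dynamic_scale_for_fsdp(model)
```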
@@ -398,6 +399,9 @@ def loss_fn(pred, labels):
            optimizers.step()
            lr_schedulers.step()

Comment: add a comment to explain
Reply: done
            if job_config.training.precompute_float8_dynamic_scale_for_fsdp:
Comment: discussed offline: can refactor to make it simpler
Reply: removed
                precompute_float8_dynamic_scale_for_fsdp(model)
Comment: maybe a noob question: could you briefly explain what this is doing?
Comment: do you refer to [...]
Comment: per suggestion, raise error if [...]
Comment: noob q: do we eventually want to just put this in fsdp2?
Comment: It has to be done after the optimizer step (since parameter values change). Are you suggesting to run this in the root module's pre-forward?
Comment: Yeah, anywhere between the (n-1)th optimizer step and the first all-gather in the nth step where fsdp2 has control (if there's any).
Comment: That makes sense. I think one concern is that FSDP is agnostic to the fp8 all-gather. FSDP does not know that the [...]
Comment: Ah I see. Somehow I thought fsdp2 was fp8-aware.

            losses_since_last_log.append(loss)

            # log metrics
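Tying the thread above together: `optimizer.step()` changes the weight values, so the dynamic fp8 scales must be refreshed after it and before the next iteration's forward, whose FSDP all-gather consumes them. A simplified sketch of that ordering; `compute_loss`, `next_batch`, and `num_steps` are placeholders, while the other names follow the hunk above:

```python
# Simplified step ordering for the precompute path; placeholder helpers are
# noted in the lead-in, other names as in train.py above.
for step in range(num_steps):
    loss = compute_loss(model, next_batch())  # forward + backward (placeholder)
    optimizers.step()                         # weight values change here
    lr_schedulers.step()

    if job_config.training.precompute_float8_dynamic_scale_for_fsdp:
        # Refresh dynamic fp8 scales for the Float8Linear weights so the next
        # iteration's fp8 all-gather matches the updated parameters.
        precompute_float8_dynamic_scale_for_fsdp(model)

    losses_since_last_log.append(loss)
```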
Comment: added the following to CI: the FSDP2 bf16 all-gather, fp8 all-gather, and fp8 all-gather + precomputed dynamic scales test flavors above.
Need follow-ups to enable TP fp8 all-gather in CI: the current CI tokenizer has vocab size 2556, which is not divisible by 16 as fp8 gemm requires; see #461.