[BE][5/n] simplify pp vs. non-pp set up #510
Changes from 1 commit
@@ -115,17 +115,17 @@ def main(job_config: JobConfig):

     logger.info(f"Building {model_name} {job_config.model.flavor} with {model_config}")
     with torch.device("meta"):
-        whole_model = model_cls.from_model_args(model_config)
+        model = model_cls.from_model_args(model_config)

     # a no-op hander if float8 is not enabled
     float8_handler = Float8Handler(job_config, parallel_dims)
     # swap to Float8Linear based on float8 configs
-    float8_handler.convert_to_float8_training(whole_model)
+    float8_handler.convert_to_float8_training(model)

     # log model size
-    model_param_count = utils.get_num_params(whole_model)
+    model_param_count = utils.get_num_params(model)
     num_flop_per_token = utils.get_num_flop_per_token(
-        utils.get_num_params(whole_model, exclude_embedding=True),
+        utils.get_num_params(model, exclude_embedding=True),
         model_config,
         job_config.training.seq_len,
     )
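For context on the meta-device pattern in this hunk: the model is constructed on the meta device, so its parameters carry shapes and dtypes but no storage until they are materialized later with `to_empty` and explicitly initialized. A minimal sketch with a toy module (the `nn.Linear` and the init calls are illustrative, not torchtitan code):

```python
import torch
import torch.nn as nn

# build the module with no real storage allocated (meta device)
with torch.device("meta"):
    toy_model = nn.Linear(4096, 4096)

assert toy_model.weight.is_meta  # shapes/dtypes only, no memory yet

# after parallelism wrappers are applied, allocate uninitialized storage ...
device = "cuda" if torch.cuda.is_available() else "cpu"
toy_model.to_empty(device=device)

# ... then fill in real values (torchtitan does this via init_weights or a
# seed checkpoint; trunc_normal_ here is just for the toy module)
with torch.no_grad():
    nn.init.trunc_normal_(toy_model.weight, std=0.02)
    nn.init.zeros_(toy_model.bias)
```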
@@ -134,41 +134,46 @@ def main(job_config: JobConfig):
         f"{color.red}size: {model_param_count:,} total parameters{color.reset}"
     )

-    if parallel_dims.pp_enabled:
-        stages, model_parts = models_pipelining_fns[model_name](
-            whole_model, pp_mesh, parallel_dims, job_config, device, model_config
-        )
-    else:
-        # In 1D/2D cases or PP with simple schedules, model_parts is just one item
-        # for PP with looped schedules, each item is one stage-model-chunk
-        # we iterate all model_parts for applying SPMD parallelism, compilation, optimizer, and checkpointing
-        model_parts = [whole_model]

-    # apply PT-D DP/TP parallelisms and activation checkpointing
-    model_parts = [
-        models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config)
-        for m in model_parts
-    ]

-    init_device = "cpu" if job_config.checkpoint.create_seed_checkpoint else "cuda"
-    for model in model_parts:
-        model.to_empty(device=init_device)

-    # loss fn can be shared by pipeline-parallel or non-pp execution
+    # loss function to be shared by Pipeline Parallel and spmd training
Review comment: nit: capitalize SPMD
     def loss_fn(pred, labels):
         return torch.nn.functional.cross_entropy(
             pred.flatten(0, 1), labels.flatten(0, 1)
         )

+    # apply parallelisms and initialization
     if parallel_dims.pp_enabled:
+        # apply PT-D Pipeline Parallel
+        stages, model_parts = models_pipelining_fns[model_name](
+            model, pp_mesh, parallel_dims, job_config, device, model_config
+        )
         pp_schedule = build_pipeline_schedule(
             job_config, parallel_dims, stages, loss_fn
         )

+        # For PP with looped schedules, each item in model_parts is one stage-model-chunk.
Review comment: I guess one thing that could clean up train.py a bit is to move all this code inside of (not sure if we need stages, model_parts in train.py anymore actually). If we did that, I'd update the comment "apply PT-D Pipeline Parallel" to mention also applying PP/DP/TP.

Review reply: We don't need them. I think it's clearer if we separate PP and the other parallelisms.
+        # We need to iterate through model_parts to apply SPMD parallelisms, compilation,
+        # optimizer, and checkpointing
+        for m in model_parts:
+            # apply spmd-style PT-D techniques
+            models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config)

+            # In PP, we cannot call init_weights directly because some layers are missing.
+            # In the future, we may make init_weights handle missing layers, but also have
+            # to consider RNG seed propagation. For now, we rely on a seed checkpoint to
+            # initialize the model.
+            m.to_empty(device="cuda")
+            m.train()
     else:
-        # If PP is enabled, we can't rely on init_weights, because some layers are missing.
-        # In the future, we may make init_weights handle missing layers, but also have to consider RNG seed propagation.
-        # allocate sharded model on GPU and initialize weights via DTensor
-        whole_model.init_weights()
+        # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
+        models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config)

+        # move sharded model to CPU/GPU and initialize weights via DTensor
+        init_device = "cpu" if job_config.checkpoint.create_seed_checkpoint else "cuda"
+        model.to_empty(device=init_device)
+        model.init_weights()
+        model.train()

+        model_parts = [model]

     gpu_mem_stats = gpu_memory_monitor.get_peak_stats()
     logger.info(
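A note on the shared `loss_fn` above: the same callable is used directly in the non-PP forward/backward path and is handed to `build_pipeline_schedule` so the last pipeline stage can apply it per micro-batch. A runnable toy example of the flatten-then-cross-entropy shape handling (the batch/seq/vocab sizes are made up):

```python
import torch
import torch.nn.functional as F

def loss_fn(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # (batch, seq, vocab) logits and (batch, seq) targets are flattened so
    # cross_entropy sees (batch*seq, vocab) against (batch*seq,)
    return F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))

batch, seq, vocab = 2, 8, 32           # toy sizes for illustration
pred = torch.randn(batch, seq, vocab, requires_grad=True)
labels = torch.randint(0, vocab, (batch, seq))

loss = loss_fn(pred, labels)           # non-PP path: called right after the forward
loss.backward()
# Under PP, the identical function is instead passed to build_pipeline_schedule
# and invoked by the schedule on the last stage for each micro-batch.
```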
@@ -183,10 +188,6 @@ def loss_fn(pred, labels):

     train_state = TrainState()

-    # train loop
-    for model in model_parts:
-        model.train()

     # load initial checkpoint
     checkpoint = CheckpointManager(
         dataloader=data_loader,
@@ -301,9 +302,9 @@ def loss_fn(pred, labels):
                 loss.backward()

         # clip gradients
-        for model in model_parts:
+        for m in model_parts:
             torch.nn.utils.clip_grad_norm_(
-                model.parameters(), job_config.training.max_norm, foreach=True
+                m.parameters(), job_config.training.max_norm, foreach=True
             )

         # sync float8 amaxes and scales
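For context on the clipping hunk above: each entry of `model_parts` is clipped on its own, so the gradient norm is computed per part rather than jointly over all parts. A self-contained toy example of the call (the modules and `max_norm` value are invented for illustration):

```python
import torch
import torch.nn as nn

model_parts = [nn.Linear(8, 8), nn.Linear(8, 8)]  # stand-ins for stage chunks
max_norm = 1.0

# produce some gradients to clip
for m in model_parts:
    m(torch.randn(4, 8)).sum().backward()

for m in model_parts:
    # returns the pre-clip total norm of this part's gradients
    total_norm = torch.nn.utils.clip_grad_norm_(
        m.parameters(), max_norm, foreach=True
    )
    print(f"part grad norm before clipping: {total_norm:.3f}")
```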
@@ -393,14 +394,14 @@ def loss_fn(pred, labels):
             train_state.step, force=(train_state.step == job_config.training.steps)
         )

-        # signals the profiler that the next profiling step has started
+        # signal the profiler that the next profiling step has started
         if torch_profiler:
             torch_profiler.step()

         if memory_profiler:
             memory_profiler.step()

-        # Reduce timeout after first train step for faster signal (assumes lazy init, compile are finished)
+        # reduce timeout after first train step for faster signal
+        # (assuming lazy init and compilation are finished)
         if train_state.step == 1:
             utils.set_pg_timeouts(
                 timeout=timedelta(seconds=job_config.comm.train_timeout_seconds),
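On the profiler stepping above: the loop calls `step()` on the optional profilers once per iteration so their schedules advance. A minimal sketch using the public `torch.profiler` API (torchtitan's own profiler helpers are assumed here, not quoted):

```python
import torch
from torch.profiler import ProfilerActivity, profile, schedule

def train_step() -> None:
    # toy stand-in for one training iteration
    torch.randn(64, 64) @ torch.randn(64, 64)

with profile(
    activities=[ProfilerActivity.CPU],
    schedule=schedule(wait=1, warmup=1, active=2, repeat=1),
) as torch_profiler:
    for _ in range(8):
        train_step()
        # signal the profiler that the next profiling step has started
        torch_profiler.step()
```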
Review comment: Does this change break anything? IIRC one feature required us to keep returning the model. Maybe it was AC?
Review reply: Great point. Yes, torch.compile and AC require reassigning the model. But since we are doing per-block compile and AC, we achieve that in-place for the whole model by
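The exact mechanism referenced at the end of this reply is elided in the extract. As a purely hypothetical sketch of per-block, in-place wrapping, where each transformer block is swapped back into its parent module so the top-level model object never needs reassigning (the function and attribute names are assumptions, not torchtitan code):

```python
import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)

def apply_ac_and_compile_per_block(model: nn.Module) -> None:
    # assumes `model.layers` holds the transformer blocks
    for name, block in model.layers.named_children():
        block = checkpoint_wrapper(block)  # per-block activation checkpointing
        block = torch.compile(block)       # per-block torch.compile
        # swap the wrapped block back in under the same name; the parent
        # module is mutated, nothing is returned or reassigned
        model.layers.register_module(name, block)

class ToyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(8, 8) for _ in range(2))

toy = ToyModel()
apply_ac_and_compile_per_block(toy)  # `toy` is still the same object afterwards
```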
Review comment: Any motivation for removing the assignment? I thought an explicit assignment does not look bad.
Review reply: @wanchaol I actually don't know why these functions have redundant semantics, i.e. modifying a module in-place but at the same time returning it explicitly. I'm modifying it partly because of the `fully_shard` precedent, so in some sense I'm mimicking that PR.
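To make the two conventions discussed here concrete, a small illustrative sketch (toy functions, not the torchtitan APIs) contrasting a transform that only mutates its argument with one that also returns it:

```python
import torch.nn as nn

def parallelize_in_place(model: nn.Module) -> None:
    # mutate `model` directly, e.g. wrap or replace submodules via
    # register_module / setattr; nothing is returned
    for name, child in list(model.named_children()):
        model.register_module(name, child)  # placeholder for a real wrap/replace

def parallelize_and_return(model: nn.Module) -> nn.Module:
    parallelize_in_place(model)
    return model  # redundant if the top-level object is never replaced

m = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
parallelize_in_place(m)        # no reassignment needed
m = parallelize_and_return(m)  # explicit but redundant reassignment, same object
```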