keep pp vs. non-pp split #526

Merged
merged 1 commit on Aug 19, 2024

train.py: 19 changes (10 additions & 9 deletions)
@@ -11,7 +11,6 @@

import torch
from torch.distributed.elastic.multiprocessing.errors import record
from torch.fx import GraphModule

from torchtitan import utils
from torchtitan.checkpoint import CheckpointManager, TrainState
@@ -147,28 +146,30 @@ def loss_fn(pred, labels):
            model, pp_mesh, parallel_dims, job_config, device, model_config, loss_fn
        )

        pp_split_mode = job_config.experimental.pipeline_parallel_split_mode

        # For PP with looped schedules, each item in model_parts is one stage-model-chunk.
        # We need to iterate through model_parts to apply SPMD parallelisms, compilation,
        # optimizer, and checkpointing
        for m in model_parts:
            # apply SPMD-style PT-D techniques
            models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config)
            m.to_empty(device="cuda")
            # skip traced modules since we do not define init_weights in the traced module
            if pp_split_mode == "manual":
                m.init_weights()
            m.train()
    else:
        # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
        models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config)

        # move sharded model to CPU/GPU and initialize weights via DTensor
        init_device = "cpu" if job_config.checkpoint.create_seed_checkpoint else "cuda"
        model.to_empty(device=init_device)
        model_parts = [model]
        model.init_weights()
        model.train()

    for mod in model_parts:
        # skip traced modules since we do not define init_weights in the traced module
        if isinstance(mod, GraphModule):
            continue
        mod.init_weights()
        mod.train()
    model_parts = [model]

    gpu_mem_stats = gpu_memory_monitor.get_peak_stats()
    logger.info(
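To make the resulting control flow easier to follow, here is a condensed sketch of the split this hunk keeps. It is an illustration only: setup_model, parallelize, and the flat arguments are hypothetical simplifications, not torchtitan's actual API (parallelize stands in for models_parallelize_fns[model_name], and the config objects are flattened into plain arguments).

# Condensed sketch of the merged logic, not the real train.py.
def setup_model(model, model_parts, pp_enabled, pp_split_mode,
                create_seed_checkpoint, parallelize):
    if pp_enabled:
        # Each item in model_parts is one stage-model-chunk (looped schedules).
        for m in model_parts:
            parallelize(m)  # SPMD-style PT-D techniques per chunk
            m.to_empty(device="cuda")
            # Traced chunks do not define init_weights, so only manually
            # split stages are initialized here.
            if pp_split_mode == "manual":
                m.init_weights()
            m.train()
    else:
        parallelize(model)  # TP, activation checkpointing, compile, DP
        # Seed checkpoints are created on CPU; otherwise initialize on GPU.
        init_device = "cpu" if create_seed_checkpoint else "cuda"
        model.to_empty(device=init_device)
        model_parts = [model]
        model.init_weights()
        model.train()
    return model_parts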
@@ -204,7 +205,7 @@ def loss_fn(pred, labels):
    checkpoint_loaded = checkpoint.load()

    if parallel_dims.pp_enabled and not checkpoint_loaded:
        if job_config.experimental.pipeline_parallel_split_mode == "tracer":
        if pp_split_mode == "tracer":
            raise RuntimeError(
                "Pipeline parallelism with tracer mode is not supported without a seed checkpoint."
            )
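The last hunk reuses the pp_split_mode variable introduced earlier for the seed-checkpoint guard. A minimal sketch, again with hypothetical flat arguments rather than torchtitan's job_config:

# Sketch of the guard above: tracer-split stages cannot run init_weights,
# so training without a loaded (seed) checkpoint is rejected.
def require_seed_checkpoint(pp_enabled, checkpoint_loaded, pp_split_mode):
    if pp_enabled and not checkpoint_loaded and pp_split_mode == "tracer":
        raise RuntimeError(
            "Pipeline parallelism with tracer mode is not supported "
            "without a seed checkpoint."
        )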