
Commit 89b6d64

keep pp vs. non-pp split
ghstack-source-id: 9894aa1bc6d6026f59d6a4cc28b573dbb87d20d0
Pull Request resolved: #526
tianyu-l committed Aug 17, 2024
1 parent 81c555f commit 89b6d64
Showing 1 changed file with 10 additions and 9 deletions.
train.py (19 changes: 10 additions & 9 deletions)

@@ -11,7 +11,6 @@
 
 import torch
 from torch.distributed.elastic.multiprocessing.errors import record
-from torch.fx import GraphModule
 
 from torchtitan import utils
 from torchtitan.checkpoint import CheckpointManager, TrainState
@@ -147,28 +146,30 @@ def loss_fn(pred, labels):
             model, pp_mesh, parallel_dims, job_config, device, model_config, loss_fn
         )
 
+        pp_split_mode = job_config.experimental.pipeline_parallel_split_mode
+
         # For PP with looped schedules, each item in model_parts is one stage-model-chunk.
         # We need to iterate through model_parts to apply SPMD parallelisms, compilation,
         # optimizer, and checkpointing
         for m in model_parts:
             # apply SPMD-style PT-D techniques
             models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config)
             m.to_empty(device="cuda")
+            # skip traced modules since we do not define init_weights in the traced module
+            if pp_split_mode == "manual":
+                m.init_weights()
+            m.train()
     else:
         # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
         models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config)
 
         # move sharded model to CPU/GPU and initialize weights via DTensor
         init_device = "cpu" if job_config.checkpoint.create_seed_checkpoint else "cuda"
         model.to_empty(device=init_device)
-        model_parts = [model]
+        model.init_weights()
+        model.train()
 
-    for mod in model_parts:
-        # skip traced modules since we do not define init_weights in the traced module
-        if isinstance(mod, GraphModule):
-            continue
-        mod.init_weights()
-        mod.train()
+        model_parts = [model]
 
     gpu_mem_stats = gpu_memory_monitor.get_peak_stats()
     logger.info(
@@ -204,7 +205,7 @@ def loss_fn(pred, labels):
     checkpoint_loaded = checkpoint.load()
 
     if parallel_dims.pp_enabled and not checkpoint_loaded:
-        if job_config.experimental.pipeline_parallel_split_mode == "tracer":
+        if pp_split_mode == "tracer":
             raise RuntimeError(
                 "Pipeline parallelism with tracer mode is not supported without a seed checkpoint."
             )
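For readers who want the resulting control flow at a glance rather than as a diff, here is a minimal, self-contained sketch of the split this commit keeps: the pipeline-parallel branch parallelizes, materializes, and (for manual split mode only) initializes each stage chunk, while the non-PP branch handles the single whole model. Everything other than the branching logic itself (the config dataclasses, ToyModel, split_into_stages, parallelize) is a hypothetical stand-in, not a torchtitan API.

# A minimal, runnable sketch of the control flow this commit keeps in train.py.
# JobConfig/ExperimentalConfig/CheckpointConfig, ToyModel, split_into_stages, and
# parallelize are hypothetical stand-ins, not torchtitan APIs.
from dataclasses import dataclass
from typing import Callable, List

import torch
import torch.nn as nn


@dataclass
class ExperimentalConfig:
    pipeline_parallel_split_mode: str = "manual"  # "manual" or "tracer"


@dataclass
class CheckpointConfig:
    create_seed_checkpoint: bool = False


@dataclass
class JobConfig:
    experimental: ExperimentalConfig
    checkpoint: CheckpointConfig


class ToyModel(nn.Module):
    def __init__(self, dim: int = 16) -> None:
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def init_weights(self) -> None:
        nn.init.trunc_normal_(self.proj.weight, std=0.02)
        nn.init.zeros_(self.proj.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


def setup_model_parts(
    model: nn.Module,
    pp_enabled: bool,
    job_config: JobConfig,
    split_into_stages: Callable[[nn.Module], List[nn.Module]],
    parallelize: Callable[[nn.Module], None],
) -> List[nn.Module]:
    if pp_enabled:
        pp_split_mode = job_config.experimental.pipeline_parallel_split_mode

        # Each item is one stage-model-chunk: parallelize, materialize, and
        # (for manual split mode) initialize it. Tracer-mode chunks are traced
        # modules without init_weights and rely on a seed checkpoint instead.
        model_parts = split_into_stages(model)
        for m in model_parts:
            parallelize(m)
            m.to_empty(device="cuda")
            if pp_split_mode == "manual":
                m.init_weights()
            m.train()
    else:
        # Non-PP path: one whole model, initialized on CPU when creating a
        # seed checkpoint, otherwise directly on GPU.
        parallelize(model)
        init_device = "cpu" if job_config.checkpoint.create_seed_checkpoint else "cuda"
        model.to_empty(device=init_device)
        model.init_weights()
        model.train()

        model_parts = [model]

    return model_parts


if __name__ == "__main__":
    cfg = JobConfig(
        experimental=ExperimentalConfig(),
        checkpoint=CheckpointConfig(create_seed_checkpoint=True),  # keep this demo on CPU
    )
    with torch.device("meta"):
        model = ToyModel()
    parts = setup_model_parts(
        model,
        pp_enabled=False,
        job_config=cfg,
        split_into_stages=lambda m: [m],  # unused in the non-PP path
        parallelize=lambda m: None,  # no-op stand-in for SPMD parallelization
    )
    print([p.training for p in parts])  # -> [True]

Keeping the two paths separate is what lets the commit drop the GraphModule import and the isinstance check: tracer-mode chunks are excluded by the pp_split_mode guard instead.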
