Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PP] Bypass seed checkpoint by init-ing model parts separately #516

Merged
merged 4 commits into from
Aug 12, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions torchtitan/models/llama/model.py
Original file line number Diff line number Diff line change
@@ -394,19 +394,23 @@ def init_weights(self):
"""
with torch.device(self.freqs_cis.device):
self.freqs_cis = self._precompute_freqs_cis()
nn.init.normal_(self.tok_embeddings.weight)
if self.tok_embeddings is not None:
nn.init.normal_(self.tok_embeddings.weight)
for layer in self.layers.values():
layer.init_weights()
self.norm.reset_parameters()
if layer is not None:
layer.init_weights()
if self.norm is not None:
self.norm.reset_parameters()
final_out_std = self.model_args.dim**-0.5
cutoff_factor = 3
nn.init.trunc_normal_(
self.output.weight,
mean=0.0,
std=final_out_std,
a=-cutoff_factor * final_out_std,
b=cutoff_factor * final_out_std,
)
if self.output is not None:
nn.init.trunc_normal_(
self.output.weight,
mean=0.0,
std=final_out_std,
a=-cutoff_factor * final_out_std,
b=cutoff_factor * final_out_std,
)

def _precompute_freqs_cis(self) -> torch.Tensor:
return precompute_freqs_cis(
24 changes: 12 additions & 12 deletions train.py
Original file line number Diff line number Diff line change
@@ -11,6 +11,7 @@

import torch
from torch.distributed.elastic.multiprocessing.errors import record
from torch.fx import GraphModule

from torchtitan import utils
from torchtitan.checkpoint import CheckpointManager, TrainState
@@ -152,25 +153,23 @@ def loss_fn(pred, labels):
for m in model_parts:
# apply SPMD-style PT-D techniques
models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config)

# In PP, we cannot call init_weights directly because some layers are missing.
# In the future, we may make init_weights handle missing layers, but also have
# to consider RNG seed propagation. For now, we rely on a seed checkpoint to
# initialize the model.
m.to_empty(device="cuda")
m.train()
else:
# apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config)

# move sharded model to CPU/GPU and initialize weights via DTensor
init_device = "cpu" if job_config.checkpoint.create_seed_checkpoint else "cuda"
model.to_empty(device=init_device)
model.init_weights()
model.train()

model_parts = [model]

for mod in model_parts:
# skip traced modules since we do not define init_weights in the traced module
if isinstance(mod, GraphModule):
continue
mod.init_weights()
mod.train()

gpu_mem_stats = gpu_memory_monitor.get_peak_stats()
logger.info(
f"GPU memory usage for model: "
@@ -205,9 +204,10 @@ def loss_fn(pred, labels):
checkpoint_loaded = checkpoint.load()

if parallel_dims.pp_enabled and not checkpoint_loaded:
raise RuntimeError(
"Pipeline Parallelism requires meta-initialization and loading seed checkpoint. "
"Please run `./create_seed_checkpoint.sh` and rerun training with `--checkpoint.enable_checkpoint`"
# TODO: fix this by allowing each rank to set their own seed
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

iiuc, init_weights would be called all the time, and then when seed ckpt is specified, it would overwrite the init. that sgtm.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep, that is correct

logger.warning(
"Pipeline Parallelism is being used without a seed checkpoint. "
"All the substages will be initialized with random weights with same RNG state which can affect convergence."
)

metric_logger = build_metric_logger(job_config, parallel_dims)