[BE][5/n] simplify pp vs. non-pp set up #510

Merged: 3 commits, Aug 8, 2024
Changes from 1 commit
22 changes: 8 additions & 14 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -51,7 +51,7 @@ def parallelize_llama(
and not job_config.training.compile
):
raise RuntimeError("Async TP requires --training.compile")
model = apply_tp(
apply_tp(
model,
world_mesh["tp"],
loss_parallel=parallel_dims.loss_parallel_enabled,
@@ -60,7 +60,7 @@
)

if job_config.activation_checkpoint.mode != "none":
model = apply_ac(model, job_config.activation_checkpoint)
apply_ac(model, job_config.activation_checkpoint)

# turn on per-TransformerBlock compile after AC wrapping and before FSDP
if job_config.training.compile:
@@ -69,14 +69,14 @@
"fused_rmsnorm is not compatible with torch.compile yet. "
"Please use rmsnorm or layernorm."
)
model = apply_compile(model)
apply_compile(model)

if parallel_dims.dp_enabled:
if parallel_dims.dp_type == "fsdp":
dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names

model = apply_fsdp(
apply_fsdp(
Contributor:
does this change break anything? IIRC one feature required us to keep returning the model. Maybe it was AC?

Contributor Author:
> one feature required us to keep returning the model.

Great point. Yes, torch.compile and AC require reassigning the model. But since we are doing per-block compile and AC, we achieve that in place for the whole model by

transformer_block = compile/AC(transformer_block)
model.layers.register_module(layer_id, transformer_block)
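For illustration, a minimal self-contained sketch of this per-block, in-place pattern. The helper name is invented for the example, and the AC wrapper shown is PyTorch's stock checkpoint_wrapper rather than torchtitan's configurable apply_ac; only the register_module swap mirrors the code in this diff.

import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)

def apply_ac_and_compile_per_block(model: nn.Module) -> None:
    # Wrap each TransformerBlock individually, then swap the wrapped block back
    # into the parent container. Because register_module mutates model.layers
    # in place, the change is visible through the caller's existing `model`
    # reference, so the helper has nothing to return.
    for layer_id, transformer_block in model.layers.named_children():
        transformer_block = checkpoint_wrapper(transformer_block)  # AC
        transformer_block = torch.compile(transformer_block)       # compile
        model.layers.register_module(layer_id, transformer_block)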

Collaborator:
Any motivation for removing the assignment? I thought an explicit assignment doesn't look bad.

Contributor Author:
@wanchaol
I actually don't know why these functions have redundant semantics, i.e. modifying a module in place while also returning it explicitly. I'm changing it because:

  1. Before this PR, the PP branch explicitly reassigned the returned module, but the SPMD branch didn't. I think we should use the minimum viable code to reduce confusion.
  2. IIRC @awgu had a PR which removed the reassignment for FSDP2 fully_shard, so in some sense I'm mimicking that PR. (A standalone sketch of this in-place call style follows this file's diff.)

model,
dp_mesh,
param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
@@ -88,15 +88,13 @@ def parallelize_llama(
else:
if world_mesh.ndim > 1:
raise RuntimeError("DDP has not supported > 1D parallelism")
model = apply_ddp(
apply_ddp(
model,
world_mesh,
enable_compile=job_config.training.compile,
enable_compiled_autograd=job_config.experimental.enable_compiled_autograd,
)

return model


def apply_tp(
model: nn.Module,
@@ -110,7 +108,7 @@ def apply_tp(
# transformer block's inputs)
# 2. Parallelize the root norm layer over the sequence dim
# 3. Parallelize the final linear output layer
model = parallelize_module(
parallelize_module(
model,
tp_mesh,
{
@@ -192,7 +190,6 @@ def apply_tp(
f"Applied {'Float8 ' if enable_float8 else ''}{'Async ' if enable_async_tp else ''}"
"Tensor Parallelism to the model"
)
return model


# for selective op activation checkpointing
@@ -273,7 +270,6 @@ def apply_ac(model: nn.Module, ac_config):
model.layers.register_module(layer_id, transformer_block)

logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")
return model


def apply_compile(model: nn.Module):
@@ -286,7 +282,6 @@ def apply_compile(model: nn.Module):
model.layers.register_module(layer_id, transformer_block)

logger.info("Compiling each TransformerBlock with torch.compile")
return model


def apply_fsdp(
@@ -329,8 +324,8 @@ def apply_fsdp(
module._load_state_dict_pre_hooks.clear()
assert len(module._state_dict_pre_hooks) <= 1
module._state_dict_pre_hooks.clear()

logger.info("Applied FSDP to the model")
return model


def apply_ddp(
@@ -347,7 +342,6 @@ def apply_ddp(
else:
torch._dynamo.config.optimize_ddp = "ddp_optimizer"

model = replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)

logger.info("Applied DDP to the model")
return model
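For reference, a standalone sketch of the resulting call style, using apply_ddp as the example; the replicate import path is an assumption and the signature is trimmed. The point is that the composable APIs (replicate, fully_shard, parallelize_module) mutate the module they receive, which is why both the trailing `return model` and the call-site reassignment can be dropped.

import torch
import torch.nn as nn
from torch.distributed._composable.replicate import replicate  # assumed import path

def apply_ddp(model: nn.Module, dp_mesh, enable_compile: bool = False) -> None:
    if enable_compile:
        # Let Dynamo's DDP optimizer split graphs at bucket boundaries.
        torch._dynamo.config.optimize_ddp = "ddp_optimizer"
    # replicate() applies DDP-style data parallelism to `model` in place,
    # so the caller keeps its existing reference and nothing is returned.
    replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)

# Call site after this PR: no `model = ...` reassignment.
# apply_ddp(model, world_mesh, enable_compile=job_config.training.compile)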
77 changes: 39 additions & 38 deletions train.py
@@ -115,17 +115,17 @@ def main(job_config: JobConfig):

logger.info(f"Building {model_name} {job_config.model.flavor} with {model_config}")
with torch.device("meta"):
whole_model = model_cls.from_model_args(model_config)
model = model_cls.from_model_args(model_config)

# a no-op handler if float8 is not enabled
float8_handler = Float8Handler(job_config, parallel_dims)
# swap to Float8Linear based on float8 configs
float8_handler.convert_to_float8_training(whole_model)
float8_handler.convert_to_float8_training(model)

# log model size
model_param_count = utils.get_num_params(whole_model)
model_param_count = utils.get_num_params(model)
num_flop_per_token = utils.get_num_flop_per_token(
utils.get_num_params(whole_model, exclude_embedding=True),
utils.get_num_params(model, exclude_embedding=True),
model_config,
job_config.training.seq_len,
)
@@ -134,41 +134,46 @@ def main(job_config: JobConfig):
f"{color.red}size: {model_param_count:,} total parameters{color.reset}"
)

if parallel_dims.pp_enabled:
stages, model_parts = models_pipelining_fns[model_name](
whole_model, pp_mesh, parallel_dims, job_config, device, model_config
)
else:
# In 1D/2D cases or PP with simple schedules, model_parts is just one item
# for PP with looped schedules, each item is one stage-model-chunk
# we iterate all model_parts for applying SPMD parallelism, compilation, optimizer, and checkpointing
model_parts = [whole_model]

# apply PT-D DP/TP parallelisms and activation checkpointing
model_parts = [
models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config)
for m in model_parts
]

init_device = "cpu" if job_config.checkpoint.create_seed_checkpoint else "cuda"
for model in model_parts:
model.to_empty(device=init_device)

# loss fn can be shared by pipeline-parallel or non-pp execution
# loss function to be shared by Pipeline Parallel and spmd training
Contributor:
nit: capitalize SPMD

def loss_fn(pred, labels):
return torch.nn.functional.cross_entropy(
pred.flatten(0, 1), labels.flatten(0, 1)
)

# apply parallelisms and initialization
if parallel_dims.pp_enabled:
# apply PT-D Pipeline Parallel
stages, model_parts = models_pipelining_fns[model_name](
model, pp_mesh, parallel_dims, job_config, device, model_config
)
pp_schedule = build_pipeline_schedule(
job_config, parallel_dims, stages, loss_fn
)

# For PP with looped schedules, each item in model_parts is one stage-model-chunk.
Contributor:
I guess one thing that could clean up train.py a bit is to move all this code inside of models_pipelining_fns and have it return (pp_schedule, stages, model_parts).

(Not sure if we need stages, model_parts in train.py anymore, actually.)

If we did that, I'd update the comment "apply PT-D Pipeline Parallel" to mention also applying PP/DP/TP.

Contributor Author:
We don't need stages anymore after build_pipeline_schedule, but we still need model_parts (for optimizer, checkpointing, fp8 updates, etc.).

I think it's clearer to keep the PP setup separate from the rest of parallelize_llama. (A hypothetical sketch of the suggested refactor follows.)
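Purely as an illustration of the suggestion above, and not something this PR implements: a hypothetical helper that folds schedule construction into the pipelining step and returns only what the train loop still needs. The function name and signature are invented; models_pipelining_fns and build_pipeline_schedule are the names already used in train.py.

def pipeline_and_schedule(model, model_name, pp_mesh, parallel_dims,
                          job_config, device, model_config, loss_fn):
    # Split the model into pipeline-stage chunks, as train.py does today.
    stages, model_parts = models_pipelining_fns[model_name](
        model, pp_mesh, parallel_dims, job_config, device, model_config
    )
    # `stages` is only needed to build the schedule; `model_parts` is still
    # needed afterwards for optimizers, checkpointing, and float8 updates.
    pp_schedule = build_pipeline_schedule(job_config, parallel_dims, stages, loss_fn)
    return pp_schedule, model_parts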

# We need to iterate through model_parts to apply SPMD parallelisms, compilation,
# optimizer, and checkpointing
for m in model_parts:
# apply spmd-style PT-D techniques
models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config)

# In PP, we cannot call init_weights directly because some layers are missing.
# In the future, we may make init_weights handle missing layers, but also have
# to consider RNG seed propagation. For now, we rely on a seed checkpoint to
# initialize the model.
m.to_empty(device="cuda")
m.train()
else:
# If PP is enabled, we can't rely on init_weights, because some layers are missing.
# In the future, we may make init_weights handle missing layers, but also have to consider RNG seed propagation.
# allocate sharded model on GPU and initialize weights via DTensor
whole_model.init_weights()
# apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config)

# move sharded model to CPU/GPU and initialize weights via DTensor
init_device = "cpu" if job_config.checkpoint.create_seed_checkpoint else "cuda"
model.to_empty(device=init_device)
model.init_weights()
model.train()

model_parts = [model]

gpu_mem_stats = gpu_memory_monitor.get_peak_stats()
logger.info(
@@ -183,10 +188,6 @@ def loss_fn(pred, labels):

train_state = TrainState()

# train loop
for model in model_parts:
model.train()

# load initial checkpoint
checkpoint = CheckpointManager(
dataloader=data_loader,
@@ -301,9 +302,9 @@ def loss_fn(pred, labels):
loss.backward()

# clip gradients
for model in model_parts:
for m in model_parts:
torch.nn.utils.clip_grad_norm_(
model.parameters(), job_config.training.max_norm, foreach=True
m.parameters(), job_config.training.max_norm, foreach=True
)

# sync float8 amaxes and scales
@@ -393,14 +394,14 @@ def loss_fn(pred, labels):
train_state.step, force=(train_state.step == job_config.training.steps)
)

# signals the profiler that the next profiling step has started
# signal the profiler that the next profiling step has started
if torch_profiler:
torch_profiler.step()

if memory_profiler:
memory_profiler.step()

# Reduce timeout after first train step for faster signal (assumes lazy init, compile are finished)
# reduce timeout after first train step for faster signal
# (assuming lazy init and compilation are finished)
if train_state.step == 1:
utils.set_pg_timeouts(
timeout=timedelta(seconds=job_config.comm.train_timeout_seconds),