diff --git a/train.py b/train.py
index a6bcbdd85..3f07d3c7b 100644
--- a/train.py
+++ b/train.py
@@ -11,7 +11,6 @@
 
 import torch
 from torch.distributed.elastic.multiprocessing.errors import record
-from torch.fx import GraphModule
 
 from torchtitan import utils
 from torchtitan.checkpoint import CheckpointManager, TrainState
@@ -147,6 +146,8 @@ def loss_fn(pred, labels):
             model, pp_mesh, parallel_dims, job_config, device, model_config, loss_fn
         )
 
+        pp_split_mode = job_config.experimental.pipeline_parallel_split_mode
+
         # For PP with looped schedules, each item in model_parts is one stage-model-chunk.
         # We need to iterate through model_parts to apply SPMD parallelisms, compilation,
         # optimizer, and checkpointing
@@ -154,6 +155,10 @@ def loss_fn(pred, labels):
             # apply SPMD-style PT-D techniques
             models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config)
             m.to_empty(device="cuda")
+            # skip traced modules since we do not define init_weights in the traced module
+            if pp_split_mode == "manual":
+                m.init_weights()
+            m.train()
     else:
         # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
         models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config)
@@ -161,14 +166,10 @@ def loss_fn(pred, labels):
         # move sharded model to CPU/GPU and initialize weights via DTensor
         init_device = "cpu" if job_config.checkpoint.create_seed_checkpoint else "cuda"
         model.to_empty(device=init_device)
-        model_parts = [model]
+        model.init_weights()
+        model.train()
 
-    for mod in model_parts:
-        # skip traced modules since we do not define init_weights in the traced module
-        if isinstance(mod, GraphModule):
-            continue
-        mod.init_weights()
-        mod.train()
+        model_parts = [model]
 
     gpu_mem_stats = gpu_memory_monitor.get_peak_stats()
     logger.info(
@@ -204,7 +205,7 @@ def loss_fn(pred, labels):
     checkpoint_loaded = checkpoint.load()
 
     if parallel_dims.pp_enabled and not checkpoint_loaded:
-        if job_config.experimental.pipeline_parallel_split_mode == "tracer":
+        if pp_split_mode == "tracer":
             raise RuntimeError(
                 "Pipeline parallelism with tracer mode is not supported without a seed checkpoint."
             )