diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index df21e711e..c354de6f7 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -205,7 +205,8 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
         )

         # Apply tensor + sequence parallelism to every transformer block
-        for layer_id, transformer_block in enumerate(model.layers):
+        for layer_name, transformer_block in model.layers.named_children():
+        # for layer_id, transformer_block in enumerate(model.layers):
            layer_plan = {
                "attention": PrepareModuleInput(
                    input_layouts=(Shard(1), None),
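
A minimal sketch (not part of this diff) of what the switch from enumerate() to named_children() changes, assuming model.layers is an nn.ModuleDict whose keys can be non-contiguous (for example after pipeline-parallel splitting removes some blocks); the Block class and the "0"/"3" keys below are hypothetical stand-ins for the real TransformerBlock entries.

import torch.nn as nn

class Block(nn.Module):
    def forward(self, x):
        return x

# Hypothetical container: only layers "0" and "3" live on this stage.
layers = nn.ModuleDict({"0": Block(), "3": Block()})

# Iterating an nn.ModuleDict directly yields its *keys*, so the old pattern
# would hand enumerate() strings rather than modules, and the positional
# indices (0, 1, ...) no longer match the original layer ids:
for layer_id, transformer_block in enumerate(layers):
    print(layer_id, type(transformer_block))   # 0 <class 'str'>, 1 <class 'str'>

# named_children() yields (name, module) pairs and preserves the stored keys:
for layer_name, transformer_block in layers.named_children():
    print(layer_name, type(transformer_block).__name__)   # "0" Block, "3" Block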