
Commit

Enable TP+PP support
ghstack-source-id: fb342bde03ad6fe6cdb7615eb5ab8f4b2ce5ace4
Pull Request resolved: #285
wconstab committed May 2, 2024
1 parent 805ff75 commit 2a567d3
Showing 1 changed file with 2 additions and 1 deletion.
torchtitan/parallelisms/parallelize_llama.py (2 additions, 1 deletion)
@@ -205,7 +205,8 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
     )
 
     # Apply tensor + sequence parallelism to every transformer block
-    for layer_id, transformer_block in enumerate(model.layers):
+    for layer_name, transformer_block in model.layers.named_children():
+    # for layer_id, transformer_block in enumerate(model.layers):
         layer_plan = {
             "attention": PrepareModuleInput(
                 input_layouts=(Shard(1), None),
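
Note: switching from enumerate(model.layers) to model.layers.named_children() presumably lets the same tensor-parallel loop run once pipeline parallelism has split the model, when the layer container may no longer be a plain, zero-indexed list of blocks. Below is a minimal sketch under that assumption; the containers, sizes, and names are hypothetical and not taken from torchtitan.

    import torch.nn as nn

    # Before any PP split: layers live in a ModuleList indexed 0..N-1.
    full_model_layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(8)])

    # After a hypothetical PP split: a stage may hold only a slice of the layers,
    # keyed by their original names, so integer enumeration no longer lines up.
    stage_local_layers = nn.ModuleDict({str(i): nn.Linear(8, 8) for i in range(4, 8)})

    for container in (full_model_layers, stage_local_layers):
        # named_children() yields (name, module) pairs for both containers,
        # so the same TP/SP plan can be applied uniformly before and after splitting.
        for layer_name, transformer_block in container.named_children():
            print(layer_name, type(transformer_block).__name__)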
