run sdpa with dtensor
ghstack-source-id: b8b2b58ffc72fcb8bfc88f4ba2a3455e3cc92c0a
Pull Request resolved: #180
tianyu-l authored and wconstab committed May 2, 2024
1 parent 05f0802 commit c530a64
Showing 1 changed file with 12 additions and 10 deletions.
torchtitan/parallelisms/parallelize_llama.py: 12 additions & 10 deletions
@@ -22,6 +22,7 @@
     ColwiseParallel,
     parallelize_module,
     PrepareModuleInput,
+    PrepareModuleOutput,
     RowwiseParallel,
     SequenceParallel,
 )
@@ -181,15 +182,21 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
         loss_parallel = parallel_dims.loss_parallel_enabled
 
         # 1. Parallelize the first embedding and the last linear proj layer
-        # 2. Parallelize the root norm layer over the sequence dim
-        # 3. Shard the first transformer block's inputs
+        # 2. Prepare the freq_cis in rotary embedding as dtensor
+        # 3. Parallelize the root norm layer over the sequence dim
+        # 4. Shard the first transformer block's inputs
         model = parallelize_module(
             model,
             tp_mesh,
             {
                 "tok_embeddings": RowwiseParallel(
                     input_layouts=Replicate(),
                 ),
+                "embeddings": PrepareModuleOutput(
+                    output_layouts=(None, Replicate()),
+                    desired_output_layouts=(None, Replicate()),
+                    use_local_output=False,
+                ),
                 "output": col_parallel_strategy(
                     input_layouts=Shard(1),
                     output_layouts=(Shard(-1) if loss_parallel else Replicate()),
@@ -212,9 +219,9 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
                     input_layouts=(Shard(1), None),
                     desired_input_layouts=(Replicate(), None),
                 ),
-                "attention.wq": col_parallel_strategy(),
-                "attention.wk": col_parallel_strategy(),
-                "attention.wv": col_parallel_strategy(),
+                "attention.wq": col_parallel_strategy(use_local_output=False),
+                "attention.wk": col_parallel_strategy(use_local_output=False),
+                "attention.wv": col_parallel_strategy(use_local_output=False),
                 "attention.wo": row_parallel_strategy(output_layouts=Shard(1)),
                 "attention_norm": SequenceParallel(),
                 "feed_forward": PrepareModuleInput(
@@ -227,11 +234,6 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
                 "ffn_norm": SequenceParallel(),
             }
 
-            # Adjust attention module to use the local number of heads
-            attn_layer = transformer_block.attention
-            attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
-            attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()
-
             parallelize_module(
                 module=transformer_block,
                 device_mesh=tp_mesh,
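
The gist of the change: the Q/K/V projection outputs (and the rotary freq_cis prepared via PrepareModuleOutput) now stay as DTensors (use_local_output=False), so scaled_dot_product_attention runs on DTensor directly and the per-rank n_heads / n_kv_heads adjustment can be dropped. Below is a minimal sketch of that idea, not the torchtitan code: the TinyAttention module, its sizes, and the 2-GPU torchrun launch are illustrative assumptions; it only uses the public torch.distributed.tensor.parallel API.

# Minimal sketch (not the torchtitan code) of what this commit relies on:
# keep the q/k/v projection outputs as DTensors (use_local_output=False) so
# scaled_dot_product_attention and the head-dim reshapes run on DTensor with
# global shapes -- no manual n_heads // tp_mesh.size() adjustment needed.
# Assumed setup: `torchrun --nproc_per_node=2 sketch.py` on 2 CUDA devices;
# TinyAttention and all sizes here are illustrative, not torchtitan's model.
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class TinyAttention(nn.Module):
    def __init__(self, dim: int = 32, n_heads: int = 4):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.wq = nn.Linear(dim, dim, bias=False)
        self.wk = nn.Linear(dim, dim, bias=False)
        self.wv = nn.Linear(dim, dim, bias=False)
        self.wo = nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        bsz, seq, dim = x.shape
        # With use_local_output=False, q/k/v are DTensors sharded on the head
        # dim; the views below use the *global* n_heads, not a per-rank count.
        q = self.wq(x).view(bsz, seq, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.wk(x).view(bsz, seq, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.wv(x).view(bsz, seq, self.n_heads, self.head_dim).transpose(1, 2)
        out = F.scaled_dot_product_attention(q, k, v)  # SDPA on DTensor inputs
        out = out.transpose(1, 2).reshape(bsz, seq, dim)
        return self.wo(out)  # RowwiseParallel reduces the partial results


torch.manual_seed(0)  # same "replicated" input on every rank
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
tp_mesh = init_device_mesh("cuda", (2,))

attn = TinyAttention().cuda()
parallelize_module(
    attn,
    tp_mesh,
    {
        # Keep projection outputs as DTensors instead of local shards.
        "wq": ColwiseParallel(use_local_output=False),
        "wk": ColwiseParallel(use_local_output=False),
        "wv": ColwiseParallel(use_local_output=False),
        "wo": RowwiseParallel(),
    },
)

x = torch.randn(2, 8, 32, device="cuda")
y = attn(x)  # plain replicated tensor; the attention math ran on DTensors
print(y.shape)
torch.distributed.destroy_process_group()

Before this change, use_local_output defaulted to True, so each rank saw a plain local shard and the attention module had to divide n_heads and n_kv_heads by tp_mesh.size(); keeping the outputs as DTensors lets DTensor's sharding propagation handle the head-sharded layout inside SDPA, which is why the last hunk removes that adjustment.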
