
[cp] apply fsdp to model when CP is enabled without DP for correct loss and lower mem usage #685

Merged: 22 commits merged into main on Dec 11, 2024

Conversation

@XilunWu (Contributor) commented Nov 20, 2024

Stack from ghstack (oldest at bottom):

**Summary**
Previously, when CP was enabled without DP, the model was not sharded via `apply_fsdp`. This led to high peak memory usage and diverging loss. A hedged sketch of the intended sharding is included after the test steps below.

**Test**
1. Modify `train_configs/llama3_8b.toml`:
```
steps = 20
context_parallel_degree = 8
```
2. Run training on 8xH100 GPUs:
`CONFIG_FILE="./train_configs/llama3_8b.toml" NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh`
Before: CUDA OutOfMemory
After: training completes all 20 steps successfully
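
For illustration, here is a minimal, hedged sketch of the kind of sharding this fix enables. It is not torchtitan's actual `apply_fsdp` code: the `shard_model_for_cp` helper, the mesh dim names, the dtypes, and the assumption that the model exposes an iterable `model.layers` are all illustrative. It uses PyTorch's FSDP2 `fully_shard` and the private `DeviceMesh._flatten()` API to shard parameters over the combined `dp_shard x cp` mesh even when DP is disabled.

```
# Hedged sketch (not torchtitan's apply_fsdp): shard the model over the
# flattened dp_shard x cp mesh so CP-without-DP still gets FSDP sharding
# and gradients are reduced across CP ranks.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard


def shard_model_for_cp(model: torch.nn.Module, dp_shard: int, cp: int) -> torch.nn.Module:
    # Build the mesh even if dp_shard == 1; the flattened "dp_shard_cp" dim is
    # then effectively just the CP dim, but FSDP still shards params over it.
    mesh = init_device_mesh(
        "cuda", (dp_shard, cp), mesh_dim_names=("dp_shard", "cp")
    )
    dp_shard_cp = mesh["dp_shard", "cp"]._flatten(mesh_dim_name="dp_shard_cp")

    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16, reduce_dtype=torch.float32
    )
    for block in model.layers:  # assumes an iterable of transformer blocks
        fully_shard(block, mesh=dp_shard_cp, mp_policy=mp_policy)
    fully_shard(model, mesh=dp_shard_cp, mp_policy=mp_policy)
    return model
```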

@XilunWu changed the base branch from gh/XilunWu/12/base to main on November 20, 2024 01:31
XilunWu added a commit that referenced this pull request Nov 29, 2024
@XilunWu requested review from tianyu-l and fegin on December 4, 2024 01:36
@tianyu-l (Contributor) left a comment

I wonder if we should take a more systematic approach to creating and naming sub-meshes in `parallel_dims.py`. There are two layers of sub-meshes: atomic ones and derived ones.

  1. The atomic ones are at minimal granularity: `dp_shard`, `dp_replicate`, `tp`, `pp`, `cp`.
  2. The derived ones include `dp` (for data loading) and `dp_shard_cp` (for FSDP param sharding).

We can create or skip the atomic ones first, and then always create the derived ones. E.g., if `dp_shard` is enabled but `dp_replicate` and `cp` are not, we still create `dp` and `dp_shard_cp` by flattening `dp_shard` alone.

This way the code would be simpler and more readable, and the mesh creation in `parallel_dims` would be less confusing.
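
As a rough illustration of this proposal (not torchtitan's actual `parallel_dims.py`), the sketch below builds the world mesh from the atomic dims and then always derives `dp` and `dp_shard_cp` by flattening. For brevity it keeps degree-1 dims in the world mesh rather than skipping them, and it relies on the private `DeviceMesh._flatten()` API mentioned in the follow-up PR #720; the `build_meshes` name is illustrative.

```
# Hedged sketch of the atomic-then-derived mesh construction described above.
from torch.distributed.device_mesh import init_device_mesh


def build_meshes(pp: int, dp_replicate: int, dp_shard: int, cp: int, tp: int):
    # Atomic dims (degree-1 dims kept for simplicity; real code may skip them).
    world_mesh = init_device_mesh(
        "cuda",
        (pp, dp_replicate, dp_shard, cp, tp),
        mesh_dim_names=("pp", "dp_replicate", "dp_shard", "cp", "tp"),
    )
    # Derived dims, always created by flattening (flattening a single dim is
    # just an alias):
    #   "dp"          -> consumed by the data loader
    #   "dp_shard_cp" -> consumed by FSDP for parameter sharding
    dp = world_mesh["dp_replicate", "dp_shard"]._flatten(mesh_dim_name="dp")
    dp_shard_cp = world_mesh["dp_shard", "cp"]._flatten(
        mesh_dim_name="dp_shard_cp"
    )
    return world_mesh, dp, dp_shard_cp
```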

@fegin (Contributor) commented Dec 4, 2024

After looking into the table, I feel we should use `dp_shard_cp` for the `fully_shard` purpose and leave `dp` for the data loader. This would be more consistent.

I also think @tianyu-l's comment describes a good approach, but I'm not sure whether it becomes more complicated once MoE is involved.

@tianyu-l (Contributor) commented Dec 4, 2024

> But I'm not sure if this will become more complicated when MoE is involved.

@fegin I feel the idea could carry over to any new parallelisms: always initialize the lowest-level dimensions, and then gradually build up, with carefully chosen names.

@XilunWu requested a review from tianyu-l on December 5, 2024 01:03
@tianyu-l (Contributor) left a comment

lgtm!

XilunWu added a commit that referenced this pull request Dec 11, 2024
…ent (#720)

Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):
* #684
* #685
* __->__ #720

**Summary**
This PR improves the design of the DeviceMesh hierarchy in torchtitan. All device meshes other than `world_mesh` now fall into 2 categories:
1. Basic meshes: meshes defined by users in the job `.toml` file. These include `pp` (`pipeline_parallel_degree`), `dp_replicate` (`data_parallel_replicate_degree`), `dp_shard` (`data_parallel_shard_degree`), `tp` (`tensor_parallel_degree`), and `cp` (`context_parallel_degree`).
2. Synthesized (or "derived") meshes: meshes synthesized from basic meshes via `_flatten()`. If a mesh is synthesized from a single basic mesh, this is equivalent to aliasing. So far we use 2 synthesized meshes: `dp` and `dp_shard_cp`. The `dp` mesh is used for data loading and the `dp_shard_cp` mesh is used for model parameter sharding.

**Test**
CI
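
For context, here is a hedged sketch of how the two synthesized meshes might be consumed. The `wire_up`, `build_dataloader`, and `shard_fn` names are placeholders rather than torchtitan APIs, and it assumes the flattened dims can be sliced back out of `world_mesh` by name (as in the sketch earlier in this thread).

```
# Illustrative only: how the two derived meshes are consumed.
def wire_up(world_mesh, model, build_dataloader, shard_fn):
    dp_mesh = world_mesh["dp"]                    # data loading
    dp_shard_cp_mesh = world_mesh["dp_shard_cp"]  # FSDP param sharding

    dataloader = build_dataloader(
        dp_rank=dp_mesh.get_local_rank(),
        dp_degree=dp_mesh.size(),
    )
    shard_fn(model, mesh=dp_shard_cp_mesh)  # e.g. fully_shard per block
    return dataloader, model
```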
@XilunWu merged commit 40a0873 into main on Dec 11, 2024
2 of 5 checks passed
Labels: CLA Signed (managed by the Meta Open Source bot)
4 participants