[cp] apply fsdp to model when CP is enabled without DP for correct loss and lower mem usage #685
Conversation
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):
* #684
* #685
* __->__ #683
**Summary**

Previously CP forgot to shard the model via `apply_fsdp` when DP is not combined with CP. This led to high peak memory usage and diverging loss.

**Test**

1. Modify `train_configs/llama3_8b.toml`:
```
steps = 20
context_parallel_degree = 8
```
2. Run training on 8x H100 GPUs:
`CONFIG_FILE="./train_configs/llama3_8b.toml" NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh`

Before: CUDA OutOfMemory
After: successful 20-step training
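Conceptually, the fix makes sure `apply_fsdp` still runs when only CP is enabled, so parameters are sharded over the CP mesh. A minimal sketch, assuming hypothetical `parallel_dims` flags and mesh names (this is not the exact PR diff):

```python
# Minimal sketch, not the exact PR diff; the parallel_dims flags and
# mesh dim names below are assumptions for illustration.
if parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled:
    if parallel_dims.dp_shard_enabled and parallel_dims.cp_enabled:
        dp_mesh = world_mesh["dp_shard", "cp"]  # shard over DP and CP together
    elif parallel_dims.cp_enabled:
        dp_mesh = world_mesh["cp"]              # CP-only: shard over the CP ranks
    else:
        dp_mesh = world_mesh["dp_shard"]
    # Without this branch, a CP-only run keeps a full parameter replica per
    # rank (OOM on an 8B model) and never reduces gradients across CP ranks
    # (diverging loss).
    apply_fsdp(model, dp_mesh)
```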
… correct loss and lower mem usage" **Summary** Previously CP forgot to shard the model via `apply_fsdp` when DP is not combined with CP. This leads to high peak memory usage and diverging loss. **Test** 1. modify `train_configs/llama3_8b.toml` ``` steps = 20 context_parallel_degree = 8 ``` 2. run training on 8xH100 GPUs `CONFIG_FILE="./train_configs/llama3_8b.toml" NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh` Before: CUDA OutOfMemory After: successful 20-steps training [ghstack-poisoned]
…d without DP for correct loss and lower mem usage" **Summary** Previously CP forgot to shard the model via `apply_fsdp` when DP is not combined with CP. This leads to high peak memory usage and diverging loss. **Test** 1. modify `train_configs/llama3_8b.toml` ``` steps = 20 context_parallel_degree = 8 ``` 2. run training on 8xH100 GPUs `CONFIG_FILE="./train_configs/llama3_8b.toml" NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh` Before: CUDA OutOfMemory After: successful 20-steps training [ghstack-poisoned]
… correct loss and lower mem usage" **Summary** Previously CP forgot to shard the model via `apply_fsdp` when DP is not combined with CP. This leads to high peak memory usage and diverging loss. **Test** 1. modify `train_configs/llama3_8b.toml` ``` steps = 20 context_parallel_degree = 8 ``` 2. run training on 8xH100 GPUs `CONFIG_FILE="./train_configs/llama3_8b.toml" NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh` Before: CUDA OutOfMemory After: successful 20-steps training [ghstack-poisoned]
…d without DP for correct loss and lower mem usage" **Summary** Previously CP forgot to shard the model via `apply_fsdp` when DP is not combined with CP. This leads to high peak memory usage and diverging loss. **Test** 1. modify `train_configs/llama3_8b.toml` ``` steps = 20 context_parallel_degree = 8 ``` 2. run training on 8xH100 GPUs `CONFIG_FILE="./train_configs/llama3_8b.toml" NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh` Before: CUDA OutOfMemory After: successful 20-steps training [ghstack-poisoned]
… correct loss and lower mem usage" **Summary** Previously CP forgot to shard the model via `apply_fsdp` when DP is not combined with CP. This leads to high peak memory usage and diverging loss. **Test** 1. modify `train_configs/llama3_8b.toml` ``` steps = 20 context_parallel_degree = 8 ``` 2. run training on 8xH100 GPUs `CONFIG_FILE="./train_configs/llama3_8b.toml" NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh` Before: CUDA OutOfMemory After: successful 20-steps training [ghstack-poisoned]
…d without DP for correct loss and lower mem usage" **Summary** Previously CP forgot to shard the model via `apply_fsdp` when DP is not combined with CP. This leads to high peak memory usage and diverging loss. **Test** 1. modify `train_configs/llama3_8b.toml` ``` steps = 20 context_parallel_degree = 8 ``` 2. run training on 8xH100 GPUs `CONFIG_FILE="./train_configs/llama3_8b.toml" NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh` Before: CUDA OutOfMemory After: successful 20-steps training [ghstack-poisoned]
… correct loss and lower mem usage" **Summary** Previously CP forgot to shard the model via `apply_fsdp` when DP is not combined with CP. This leads to high peak memory usage and diverging loss. **Test** 1. modify `train_configs/llama3_8b.toml` ``` steps = 20 context_parallel_degree = 8 ``` 2. run training on 8xH100 GPUs `CONFIG_FILE="./train_configs/llama3_8b.toml" NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh` Before: CUDA OutOfMemory After: successful 20-steps training [ghstack-poisoned]
I wonder if we should take a more systematic approach for creating & naming sub-meshes in `parallel_dims.py`:

There are two layers of sub-meshes: atomic ones and derived ones.
- The atomic ones include those at minimal granularity: `dp_shard`, `dp_replicate`, `tp`, `pp`, `cp`.
- The derived ones include: `dp` (for data loading) and `dp_shard_cp` (for FSDP param sharding).

We can create / skip the atomic ones first, and then always create the derived ones. E.g. if `dp_shard` is enabled but `dp_replicate` and `cp` are not, we still create `dp` and `dp_shard_cp` by flattening `dp_shard` alone.

This way the code would be simpler and more readable. The creation in `parallel_dims` would also be less confusing.
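A minimal sketch of this two-layer scheme, using PyTorch's `init_device_mesh` and the private `DeviceMesh._flatten` API; the degrees below are illustrative, and a real implementation would read them from the job config:

```python
from torch.distributed.device_mesh import init_device_mesh

# Illustrative degrees only (world size 8 here); size-1 stands in for
# "not enabled" so the derivation step is always uniform.
pp, dp_replicate, dp_shard, cp, tp = 1, 1, 4, 2, 1

# Layer 1: atomic meshes at minimal granularity.
mesh = init_device_mesh(
    "cuda",
    (pp, dp_replicate, dp_shard, cp, tp),
    mesh_dim_names=("pp", "dp_replicate", "dp_shard", "cp", "tp"),
)

# Layer 2: always create the derived meshes by flattening, even when some
# factors are trivial (flattening a single dim is effectively aliasing).
dp_mesh = mesh["dp_replicate", "dp_shard"]._flatten(mesh_dim_name="dp")    # data loading
fsdp_mesh = mesh["dp_shard", "cp"]._flatten(mesh_dim_name="dp_shard_cp")   # param sharding
```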
After looking into the table, I feel we should use …

I also think @tianyu-l's comment is a good approach. But I'm not sure if this will become more complicated when MoE is involved.
@fegin
lgtm!
…ent (#720)

Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):
* #684
* #685
* __->__ #720

**Summary**

This PR improves the design of the DeviceMesh hierarchy in torchtitan. We now define all device meshes except `world_mesh` in 2 categories:
1. Basic mesh: meshes defined in the job `.toml` file by users. These include `pp` (`pipeline_parallel_degree`), `dp_replicate` (`data_parallel_replicate_degree`), `dp_shard` (`data_parallel_shard_degree`), `tp` (`tensor_parallel_degree`), and `cp` (`context_parallel_degree`).
2. Synthesized mesh (also called "derived mesh"): meshes synthesized from basic meshes by `_flatten()`. If a mesh is synthesized from a single mesh, it is equivalent to aliasing. So far we use 2 synthesized meshes: `dp` and `dp_shard_cp`. The `dp` mesh is used for data loading and the `dp_shard_cp` mesh is used for model param sharding.

**Test**

CI
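As a rough illustration of how the two synthesized meshes would be consumed (call sites, attribute names, and the model structure here are assumptions, not torchtitan's exact code):

```python
from torch.distributed._composable.fsdp import fully_shard

# Assumed consumers of the two synthesized meshes.
dp_mesh = world_mesh["dp"]             # data loading: pick this rank's data shard
dp_rank = dp_mesh.get_local_rank()
dp_degree = dp_mesh.size()

fsdp_mesh = world_mesh["dp_shard_cp"]  # FSDP parameter sharding
for block in model.layers.values():    # assuming layers is a ModuleDict
    fully_shard(block, mesh=fsdp_mesh)
fully_shard(model, mesh=fsdp_mesh)
```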