enable Context Parallel #592
Conversation
```diff
-        # (use 2x max sequence length to be safe)
-        self.model_args.max_seq_len * 2,
+        # Note: removed the 2x relaxing in CP enablement
+        self.model_args.max_seq_len,
```
cc @tianyu-l Want to understand whether this is okay? For a general use case, we can also expand CP to support a stride-like feature.
Could you please elaborate a bit on why this change was needed by CP?
@tianyu-l CP parallelizes on the sequence dimension, so anything related to the sequence dimension needs to be sharded. `freqs_cis` is the positional embedding and has to be sharded according to the sequence length, so it is easier to support CP if everything has the same sequence length.
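To make this concrete, here is a minimal, hypothetical sketch (not this PR's code, which relies on the CP context manager) of what sharding a position-indexed buffer like `freqs_cis` along the sequence dimension looks like with DTensor; the mesh size, tensor shape, and use of `distribute_tensor` are illustrative assumptions:

```python
# Illustrative sketch only: CP shards every sequence-dim tensor, so a
# position-indexed buffer like freqs_cis must be sharded on its sequence
# dimension as well. Assumes launch via torchrun with cp_degree GPUs.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor  # torch.distributed._tensor on older releases

cp_degree, max_seq_len, dim = 2, 8192, 64  # hypothetical sizes
cp_mesh = init_device_mesh("cuda", (cp_degree,), mesh_dim_names=("cp",))

# Stand-in for the precomputed rotary table; the real freqs_cis is complex-valued.
freqs_cis = torch.randn(max_seq_len, dim)

# Shard dim 0 (the sequence dimension) across the CP ranks; each rank keeps
# max_seq_len // cp_degree positions, matching its shard of the input tokens.
freqs_cis_sharded = distribute_tensor(freqs_cis, cp_mesh, [Shard(0)])

# If freqs_cis were precomputed for 2 * max_seq_len while the inputs use
# max_seq_len, the two sequence shardings would no longer line up, which is
# why the 2x relaxation was removed in this PR.
```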
Sounds reasonable to me. @awgu to confirm this is OK.
Also we need to add a note in docs/composability.md to clarify why this (model change) is needed. It can be addressed in a separate PR; in that case please create issue / leave TODO.
**Summary** (WIP) This PR adds the DTensor Context Parallel (pytorch/pytorch#131351) option to torchtitan.

TODO:
1. Add seq_len scalability, loss convergence, and WPS performance benchmarks.
2. An option to choose the CP tensor rotation implementation between all-to-all and all-gather will be added once pytorch/pytorch#132820 is landed.
Looks awesome! Still have a few questions, but I think it's mostly ready.
"FSDP+TP+CP", | ||
"fsdp+tp+cp", | ||
ngpu=8, | ||
), |
Question: Looks like FSDP/HSDP + TP + CP is working. How about PP? We can also mention progress in the .md doc later.
Right, the next step is to test 4D/5D (w/ PP and HSDP)
else ("dp",) | ||
) | ||
dp_mesh = ( | ||
world_mesh[(*dp_mesh_dim_names, "cp")]._flatten("dp_cp") |
`_flatten` seems to be called twice; there is another call in `parallel_dims.py`. I wonder how this API works?
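For context, a small sketch (assumed mesh shape and dim names, mirroring the call quoted above) of how this pattern is used; `_flatten` is a private `DeviceMesh` API, so its exact behavior may change across PyTorch versions:

```python
# Illustrative only: build a 3D world mesh and collapse ("dp", "cp") into a
# single "dp_cp" dim, the same pattern as the quoted torchtitan code.
# Assumes launch via torchrun with 8 GPUs.
from torch.distributed.device_mesh import init_device_mesh

world_mesh = init_device_mesh(
    "cuda", (2, 2, 2), mesh_dim_names=("dp", "cp", "tp")  # assumed 8-GPU layout
)

dp_mesh_dim_names = ("dp",)
# FSDP shards parameters/gradients over the flattened dp_cp dim, so data is
# effectively sharded across both DP and CP ranks.
dp_cp_mesh = world_mesh[(*dp_mesh_dim_names, "cp")]._flatten("dp_cp")
print(dp_cp_mesh.size())  # 4 ranks in the flattened dim for this layout
```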
else ("dp",) | ||
) | ||
dp_mesh = ( | ||
world_mesh[(*dp_mesh_dim_names, "cp")]._flatten("dp_cp") |
Another question: given the current implementation, does it mean DP and CP have to be adjacent to each other in the device mesh? E.g., it seems we can't do [TP, CP, PP, DP] (from inner to outer) as in the Llama 3.1 paper. Is that correct?
@awgu thanks for catching it! I forgot to paste the 5000-step loss curve screenshot and instead pasted a zoomed-in one of the warm-up steps. Updated.
I looked into how token/sec and MFU change when increasing seq_len together with the CP degree. There are two parts to the FLOPs computation, and when doubling seq_len together with the CP degree (which also doubles the number of GPUs, since the other degrees are fixed) they scale differently per GPU. My conclusion is that the perf scaling we observed makes sense and looks good to me.
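A back-of-envelope sketch of the kind of arithmetic behind this argument (the constants and model sizes below are illustrative assumptions, not the benchmark's exact configuration):

```python
# Illustrative FLOPs-per-token model: a seq_len-independent linear-layer term
# plus an attention term that grows with seq_len.
def flops_per_token(hidden: int, layers: int, seq_len: int) -> float:
    linear = 12 * layers * hidden * hidden     # rough dense/GEMM term
    attention = 2 * layers * hidden * seq_len  # rough attention term
    return linear + attention

base = flops_per_token(hidden=4096, layers=32, seq_len=8192)
doubled = flops_per_token(hidden=4096, layers=32, seq_len=16384)

# Doubling seq_len together with the CP degree doubles the GPUs, but each
# token now costs more FLOPs because the attention term grows. Tokens/sec per
# GPU therefore drops, while FLOPs/sec per GPU (and hence MFU) stays roughly flat.
print(f"FLOPs per token grows by {doubled / base:.2f}x when seq_len doubles")
```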
LGTM! Thanks!
@tianyu-l actually, the 8-GPU test has a failure. This is the HSDP + CP case. I think the flatten logic has a bug when handling the flattening in this case.
**Summary** (follow-up, landed as #667, a reland of #666 which was mistakenly merged into the wrong branch) pytorch/pytorch#138945 fixes DeviceMesh access on flattened meshes that are constructed from more than 2 meshes; refer to that PR for details if interested. In #592 we avoided this issue by calling `_flatten` instead of directly accessing the flattened mesh. Now that the fix has been merged in PyTorch, we want to switch back to mesh access, which is more straightforward.
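A small sketch of the difference (assumed mesh names; `_flatten` is a private API, and the by-name access is the behavior fixed by pytorch/pytorch#138945):

```python
# Illustrative only: once ("dp_replicate", "dp_shard", "cp") has been
# flattened into "dp_cp", the flattened dim can be looked up by name instead
# of re-deriving it via another _flatten call. Assumes an 8-GPU torchrun launch.
from torch.distributed.device_mesh import init_device_mesh

world_mesh = init_device_mesh(
    "cuda", (2, 2, 2), mesh_dim_names=("dp_replicate", "dp_shard", "cp")
)

# What #592 did as a workaround: call _flatten at the point of use.
dp_cp_mesh = world_mesh[("dp_replicate", "dp_shard", "cp")]._flatten("dp_cp")

# What the follow-up switches back to: plain access by the flattened name,
# which pytorch/pytorch#138945 made work for flattens of more than 2 dims.
dp_cp_mesh = world_mesh["dp_cp"]
```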
```python
from torch.nn.attention import sdpa_kernel, SDPBackend

# currently we only support these two SDP backends.
# TODO (xilunwu): support cuDNN backend
```
Just curious if you recall what the blocker for `CUDNN_ATTENTION` is?
Hi @tmm1, it's simply that cuDNN attention has a different op signature. I'm adding support now and should have the PR draft out by next week.
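For reference, a rough sketch (hypothetical tensor shapes; the PR's actual allowlist and call sites may differ) of how such an SDP backend allowlist is applied around the attention call:

```python
# Illustrative usage of the backend allowlist discussed above: restrict SDPA
# to the flash and memory-efficient kernels for the wrapped region.
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

q = torch.randn(1, 8, 2048, 64, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)

with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
# SDPBackend.CUDNN_ATTENTION exists upstream but has a different op signature,
# which is why it is left as a TODO in the quoted code.
```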
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):

**Summary**

This PR adds the DTensor Context Parallel (pytorch/pytorch#131351) option to torchtitan.

**Instructions**

Find the `context_parallel_degree` option in the `experimental` section of the `.toml` file and assign the desired Context Parallel degree (i.e. the number of GPUs along the Context Parallel dimension).

**Benchmark Settings**

The benchmark uses `train_configs/llama3_8b.toml` with the following changes:
- norm_type = "fused_rmsnorm"
- mode = 'full' (using full activation checkpointing allows us to find the longest possible input sequence length on a device)

**Benchmark Results**

1. *Max seq_len scalability*

The first set of benchmarks shows that with CP, adding more GPUs on the CP dimension allows longer input seq_len, and this growth scales linearly. We use FSDP to shard the model, fix `data_parallel_shard_degree` to 8, and use 8, 16, 32, and 64 H100 GPUs (i.e. CP degree = 1, 2, 4, 8) to measure the longest possible input seq_len without GPU OOM. We adjust the CP degree by modifying `context_parallel_degree` in the `experimental` section of the `.toml` file.

<img width="600" alt="image" src="https://github.com/user-attachments/assets/03f68062-3368-4ea5-a324-9d2346806af3">

*As the Context Parallel degree goes up (1, 2, 4, 8), the max seq_len also increases (32k, 80k, 144k, 300k).*

<img width="597" alt="image" src="https://github.com/user-attachments/assets/b48078d3-ab4d-40dc-97cf-2d477b713e00">

*While increasing the Context Parallel degree increases the max seq_len, it also decreases the WPS on each device. This is the price paid for longer input.*

<img width="593" alt="image" src="https://github.com/user-attachments/assets/edef69ef-f0f0-41ec-9114-f7c420929064">

*However, increasing the CP degree doesn't significantly affect MFU.*

2. *Loss curve convergence*

The second set of benchmarks shows that our implementation is correct: the loss curve converges just as with FSDP + TP. This experiment uses 64 H100 GPUs and fixes the TP degree and local batch size to 8 (`tensor_parallel_degree=8`, `batch_size=8`). As we increase the CP degree, we decrease the DP degree so that `CP * DP = 8`. We also scale `seq_len` so that `seq_len = 8192 * CP`.

<img width="1061" alt="image" src="https://github.com/user-attachments/assets/8a25e0b1-cbd4-46d5-a444-8dd6f5277227">

*The training loss curve matches among the 4 combinations of parallelisms: DP=8; DP=4, CP=2; DP=2, CP=4; CP=8 (TP is fixed to 8).*

<img width="2292" alt="Screenshot 2024-10-07 at 12 42 37 PM" src="https://github.com/user-attachments/assets/74a62cc8-562b-4441-babf-d601ec9bf375">

*The only observed difference is in the warm-up steps. This is due to our choice of the number of steps; picking an appropriate number of warm-up steps should smooth this out.*

3. *Max seq_len vs. local WPS on a fixed number of GPUs*

The third set of benchmarks gives users a guide on what CP degree to choose and what WPS to expect on a specific set of devices (e.g. number of devices, device types, etc.). We use H100 as an example, but similar trade-off curves are expected on other hardware.

<img width="830" alt="image" src="https://github.com/user-attachments/assets/b1ea2320-e0ae-439f-9434-4194bd4932a0">

*Increasing the CP degree allows longer input sequence length at the cost of reducing WPS on each device. The exception is pure CP: the seq_len increase is very limited compared to DP=2, CP=4 on 8 GPUs. This is due to our implementation, and we're working on enabling performant CP to support ultra-long input sequences.*

TODO:
1. An option to choose the CP tensor rotation implementation between all-to-all and all-gather will be added once pytorch/pytorch#132820 is landed.
2. Investigate the limitation in pure CP, to enable ultra-long input sequences.
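For readers who want to see what the underlying mechanism looks like, here is a rough sketch of the experimental PyTorch API this PR builds on (pytorch/pytorch#131351); the module path, signature, toy model, and buffer list below are assumptions and may differ across PyTorch versions and from torchtitan's wrapper code:

```python
# Illustrative sketch, not torchtitan's actual training loop: run one step's
# forward/backward under the experimental context-parallel context, which
# shards the listed buffers along their sequence dims across the CP mesh.
# Assumes launch via torchrun with cp_degree GPUs.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.experimental import context_parallel

cp_degree = 2  # would come from context_parallel_degree in the .toml file
cp_mesh = init_device_mesh("cuda", (cp_degree,), mesh_dim_names=("cp",))

vocab, seq_len = 32000, 8192
model = nn.Sequential(nn.Embedding(vocab, 256), nn.Linear(256, vocab)).cuda()
inputs = torch.randint(0, vocab, (1, seq_len), device="cuda")
labels = torch.randint(0, vocab, (1, seq_len), device="cuda")

# buffers lists the tensors that carry a sequence dimension; buffer_seq_dims
# says which dim that is. Inside the context each rank works on its own
# sequence shard of these buffers.
with context_parallel(cp_mesh, buffers=[inputs, labels], buffer_seq_dims=[1, 1]):
    logits = model(inputs)
    loss = F.cross_entropy(logits.flatten(0, 1), labels.flatten())
    loss.backward()
```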