enable Context Parallel #592

Merged: 25 commits, Oct 23, 2024
Conversation

@XilunWu (Contributor) commented on Sep 30, 2024

Stack from ghstack (oldest at bottom):

**Summary**
This PR adds the DTensor Context Parallel (pytorch/pytorch#131351) option to torchtitan.

**Instructions**
Find the `context_parallel_degree` option in the `experimental` section of the `.toml` config file and assign the desired Context Parallel degree (i.e., the number of GPUs along the Context Parallel dimension), e.g. `context_parallel_degree = 2` under `[experimental]`.
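For illustration, here is a rough sketch of what that degree translates to under the hood, based on the experimental DTensor CP API from pytorch/pytorch#131351. The mesh construction, argument names, and buffer choices below are my assumptions for a minimal example rather than torchtitan's exact wiring, and the experimental API may change across PyTorch versions.

```python
# Sketch only (assumptions noted above): map a configured CP degree onto a device
# mesh and run forward/backward under the experimental DTensor CP context manager.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.experimental import context_parallel  # experimental API

dp_degree, cp_degree = 4, 2  # e.g. 8 GPUs split as DP=4, CP=2 (illustrative)
world_mesh = init_device_mesh(
    "cuda", (dp_degree, cp_degree), mesh_dim_names=("dp", "cp")
)
cp_mesh = world_mesh["cp"]

def train_step(model, inputs, labels, freqs_cis, loss_fn):
    # Every buffer with a sequence dimension gets sharded across the CP ranks.
    with context_parallel(
        cp_mesh,
        buffers=[inputs, labels, freqs_cis],
        buffer_seq_dims=[1, 1, 0],  # which dim of each buffer is the sequence dim
        no_restore_buffers={inputs, labels},
    ):
        loss = loss_fn(model(inputs), labels)
        loss.backward()
    return loss
```

This assumes a distributed launch, e.g. torchrun with 8 processes.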

**Benchmark Settings**
The benchmark uses `train_configs/llama3_8b.toml` with the following extra changes:

- `norm_type = "fused_rmsnorm"`
- `mode = "full"` (full activation checkpointing lets us find the longest possible input sequence length on a device)

**Benchmark Results**

1. *max seq_len scalability*
The first set of benchmarks shows that with CP, adding more GPUs along the CP dimension allows a longer input seq_len, and this growth scales linearly. We use FSDP to shard the model, fix `data_parallel_shard_degree` to 8, and use 8, 16, 32, and 64 H100 GPUs (i.e. CP degree = 1, 2, 4, 8) to measure the longest possible input seq_len without GPU OOM. We adjust the CP degree by modifying `context_parallel_degree` in the `experimental` section of the `.toml` file.
<img width="600" alt="image" src="https://github.com/user-attachments/assets/03f68062-3368-4ea5-a324-9d2346806af3">

*As the Context Parallel degree goes up (1, 2, 4, 8), the max seq_len also increases (32k, 80k, 144k, 300k).*

<img width="597" alt="image" src="https://github.com/user-attachments/assets/b48078d3-ab4d-40dc-97cf-2d477b713e00">

*While increasing the Context Parallel degree increases the max seq_len, it also decreases the WPS on each device. This is the price paid for longer inputs.*

<img width="593" alt="image" src="https://github.com/user-attachments/assets/edef69ef-f0f0-41ec-9114-f7c420929064">

*However, increasing the CP degree doesn't significantly affect MFU.*

2. *loss curve convergence*
The second set of benchmarks shows that our implementation is correct: the loss curve converges just as with FSDP + TP. This experiment uses 64 H100 GPUs and fixes the TP degree and local batch size to 8 (`tensor_parallel_degree=8`, `batch_size=8`). As we increase the CP degree, we decrease the DP degree so that `CP * DP = 8`. We also scale the sequence length such that `seq_len = 8192 * CP`.

<img width="1061" alt="image" src="https://github.com/user-attachments/assets/8a25e0b1-cbd4-46d5-a444-8dd6f5277227">

*The training loss curves match among the 4 combinations of parallelisms: DP=8; DP=4, CP=2; DP=2, CP=4; CP=8 (TP is fixed to 8).*

<img width="2292" alt="Screenshot 2024-10-07 at 12 42 37 PM" src="https://github.com/user-attachments/assets/74a62cc8-562b-4441-babf-d601ec9bf375">

*The only observed difference is in the warm-up steps. This comes from our choice of the total number of steps; picking an appropriate number of warm-up steps should smooth this out.*

3. *max seq_len vs. local WPS on a fixed number of GPUs*
The third set of benchmarks gives users a guide on what CP degree to choose and what WPS to expect on a specific set of devices (e.g. number of devices, device type, etc.). Here we use H100s as an example, but we expect similar trade-off curves on other hardware.

<img width="830" alt="image" src="https://github.com/user-attachments/assets/b1ea2320-e0ae-439f-9434-4194bd4932a0">

*Increasing the CP degree allows a longer input sequence length at the cost of reducing the WPS on each device. The exception is pure CP: the seq_len increase is very limited compared to DP=2, CP=4 on 8 GPUs. This is due to our implementation, and we're working on enabling performant CP to support ultra-long input sequences.*

TODO:

1. An option to choose the CP tensor rotation implementation (all-to-all vs. all-gather) will be added once pytorch/pytorch#132820 ([CP] Implement AllGather based context parallelism) is landed.
2. Investigate the limitation of pure CP, to enable ultra-long input sequences.

[ghstack-poisoned]
XilunWu added a commit that referenced this pull request Sep 30, 2024
ghstack-source-id: b76e0d183826dad8c4c76426fe62abaf9ad43f2f
Pull Request resolved: #592
@facebook-github-bot added the "CLA Signed" label on Sep 30, 2024 (this label is managed by the Meta Open Source bot).
@XilunWu XilunWu marked this pull request as draft September 30, 2024 20:09
@XilunWu XilunWu requested a review from fegin September 30, 2024 20:21
XilunWu added a commit that referenced this pull request Sep 30, 2024
ghstack-source-id: 90f1bde378561c9bd1dee3ac82990f9d91ba59ab
Pull Request resolved: #592
-    # (use 2x max sequence length to be safe)
-    self.model_args.max_seq_len * 2,
+    # Note: removed the 2x relaxing in CP enablement
+    self.model_args.max_seq_len,
Contributor:

cc @tianyu-l, I want to understand whether this is okay.

For a general use case, we can also expand CP to support a stride-like feature.

Contributor:

Could you please elaborate a bit on why this change was needed by CP?

Contributor:

@tianyu-l CP parallelizes over the sequence dimension, so anything related to the sequence dimension needs to be sharded. `freqs_cis` is the positional embedding and must be sharded according to the sequence length, so it is easier to support CP if everything has the same sequence length.
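To make this concrete, here is a tiny sketch, in DTensor terms, of what sharding `freqs_cis` along the sequence dimension looks like. The shape and the `distribute_tensor`/`Shard` usage are illustrative assumptions, not the exact code path this PR adds:

```python
# Sketch only: freqs_cis is indexed by sequence position, so under CP it must be
# sharded on its sequence dim (dim 0) across the cp mesh, like the activations.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

cp_mesh = init_device_mesh("cuda", (2,), mesh_dim_names=("cp",))
freqs_cis = torch.randn(8192, 64)  # illustrative (max_seq_len, dim) shape
freqs_cis_sharded = distribute_tensor(freqs_cis, cp_mesh, [Shard(0)])
# Each CP rank now holds a (8192 // 2, 64) local shard that lines up with its local
# slice of the input sequence, which is why every sequence-dependent tensor needs
# to be built with the same max sequence length.
```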

Contributor:

Sounds reasonable to me. @awgu to confirm this is OK.

Also, we need to add a note in docs/composability.md to clarify why this model change is needed. It can be addressed in a separate PR; in that case, please create an issue / leave a TODO.

@fegin fegin mentioned this pull request Oct 1, 2024
XilunWu added a commit that referenced this pull request Oct 3, 2024
ghstack-source-id: f9bc24ff92c0abce98dc3a0f847fc874fa77788c
Pull Request resolved: #592
XilunWu added a commit that referenced this pull request Oct 3, 2024
ghstack-source-id: 51288d0a142c839291d6035e6dddcc915e5e5a08
Pull Request resolved: #592
XilunWu added a commit that referenced this pull request Oct 4, 2024
ghstack-source-id: 6126585b13e49131e8b2d9e05a5ef1f736a0c4d9
Pull Request resolved: #592
@XilunWu XilunWu mentioned this pull request Oct 21, 2024
XilunWu added a commit that referenced this pull request Oct 21, 2024
ghstack-source-id: 9107fbfa09b4cf858ae4943ce9cb8180b28e5ea8
Pull Request resolved: #592
XilunWu added a commit that referenced this pull request Oct 21, 2024
ghstack-source-id: a0832f24bf6cfacb5e74dcdc3bca3fb58caca4aa
Pull Request resolved: #592
@XilunWu XilunWu marked this pull request as ready for review October 21, 2024 09:18
**Summary** (WIP)
This PR adds the DTensor Context Parallel (pytorch/pytorch#131351) option to torchtitan.

TODO:
1. Add seq_len scalability, loss convergence, and WPS performance benchmarks.
2. An option to choose the CP tensor rotation implementation (all-to-all vs. all-gather) will be added once pytorch/pytorch#132820 is landed.

[ghstack-poisoned]
@tianyu-l (Contributor) left a comment:

Looks awesome! Still have a few questions, but I think it's mostly ready.

"FSDP+TP+CP",
"fsdp+tp+cp",
ngpu=8,
),
Contributor:

Question: Looks like FSDP/HSDP + TP + CP is working. How about PP? We can also mention progress in the .md doc later.

Contributor (author):

Right, the next step is to test 4D/5D (with PP and HSDP).

else ("dp",)
)
dp_mesh = (
world_mesh[(*dp_mesh_dim_names, "cp")]._flatten("dp_cp")
Contributor:

`_flatten` seems to be called twice; there is another call in parallel_dims.py. I wonder how this API works?

else ("dp",)
)
dp_mesh = (
world_mesh[(*dp_mesh_dim_names, "cp")]._flatten("dp_cp")
Contributor:

Another question: given the current implementation, does it mean DP and CP have to be adjacent to each other in the device mesh? E.g. it seems we can't do [TP, CP, PP, DP] (from inner to outer) as in the Llama 3.1 paper. Is that correct?

@XilunWu (Contributor, author) commented on Oct 22, 2024:

@awgu Thanks for catching it! I had pasted a zoomed-in screenshot of the warm-up steps instead of the 5000-step loss curve. Updated.

@tianyu-l (Contributor):

> While increasing Context Parallel degree increases the max seq_len, it also decreases the WPS on each device. This is the price paid for longer input. However, increasing CP degree doesn't significantly affect MFU.

By looking into how tokens/sec and MFU change as seq_len increases, here's my explanation of this phenomenon.

There are two parts to the FLOPs computation (see https://github.com/pytorch/torchtitan/blob/main/torchtitan/utils.py#L150):

- (I) matmul, which scales linearly with the number of tokens
- (II) attention, which scales quadratically with the number of tokens

When doubling seq_len together with the CP degree (which also doubles the number of GPUs, since the other degrees are fixed):

- (I) matmul computation doubles; (II) attention computation quadruples.
- Even if CP doesn't bring any overhead (e.g. extra exposed comm), we shouldn't expect the same WPS per GPU: the compute resources double, but the computation load grows between 2x and 4x. Assuming perfect scaling, the throughput per GPU should land between 0.5x and 1x.
- The MFU change depends on the actual config, specifically on the ratio between (I) and (II). If (I) dominates, the MFU/WPS ratio wouldn't change as we increase seq_len; if (II) dominates, the MFU/WPS ratio would increase linearly with seq_len. For 8B and seq_len around 32k/64k, (II) is larger than (I), which is why we see WPS drop while MFU holds (because the ratio increases).

My conclusion is that the perf scaling we observed makes sense and looks good to me.
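A quick back-of-the-envelope sketch of this reasoning. The per-token FLOPs split below uses the common 6·P + 12·L·H·D_h·T approximation, and the Llama-3-8B-like dimensions are illustrative assumptions rather than the exact config:

```python
# Sketch only: compare the linear (matmul) and quadratic (attention) parts of the
# per-token FLOPs as seq_len doubles, using illustrative 8B-like dimensions.
def flop_per_token(num_params, n_layers, n_heads, head_dim, seq_len):
    matmul = 6 * num_params                                   # part (I), linear in tokens
    attention = 12 * n_layers * n_heads * head_dim * seq_len  # part (II), grows with seq_len
    return matmul, attention

for seq_len in (32_768, 65_536):
    m, a = flop_per_token(8e9, 32, 32, 128, seq_len)
    print(f"seq_len={seq_len}: matmul={m:.2e}, attention={a:.2e}, ratio={a / m:.2f}")
# At ~32k the attention part is already comparable to the matmul part, and at ~64k
# it dominates, which is consistent with WPS dropping while MFU holds as CP degree
# and seq_len scale together.
```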

XilunWu added a commit that referenced this pull request Oct 22, 2024
ghstack-source-id: 5a8900c61f5b735aee005aefe95adda0cd678144
Pull Request resolved: #592
@XilunWu XilunWu changed the base branch from gh/XilunWu/6/base to main October 22, 2024 22:02
XilunWu added a commit that referenced this pull request Oct 22, 2024
ghstack-source-id: a0ad832fe452f1cc35c37139f498a82c4bbeeae0
Pull Request resolved: #592
@XilunWu XilunWu requested a review from tianyu-l October 22, 2024 22:44
@tianyu-l (Contributor) left a comment:

LGTM! Thanks!

@XilunWu (Contributor, author) commented on Oct 22, 2024:

@tianyu-l Actually, the 8-GPU test has a failure. This is the HSDP + CP case. I think the flatten logic has a bug when handling `mesh["dp_cp"]`.

Flattening `mesh["dp_replicate", "dp_shard", "cp"]` into `mesh["dp_cp"]` is a workaround without actual cost. I think we can land this first with a TODO and put the `__getitem__` call back once it's fixed in DeviceMesh.
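For reference, a minimal sketch of the two access patterns being discussed: the `_flatten` workaround used in this PR versus direct access to the flattened mesh. Degrees and dim names are illustrative for an 8-GPU HSDP + CP layout; running it requires a distributed launch:

```python
# Sketch only: the flatten workaround vs. direct flattened-mesh access.
from torch.distributed.device_mesh import init_device_mesh

world_mesh = init_device_mesh(
    "cuda",
    (2, 2, 2),  # illustrative: dp_replicate x dp_shard x cp on 8 GPUs
    mesh_dim_names=("dp_replicate", "dp_shard", "cp"),
)

# Workaround used in this PR: explicitly flatten the three dims into "dp_cp" and
# use the returned mesh (e.g. for data loading and loss all-reduce).
dp_cp_mesh = world_mesh["dp_replicate", "dp_shard", "cp"]._flatten("dp_cp")

# Direct access to the flattened mesh, to switch back to once DeviceMesh handles
# flattened meshes built from more than two dims:
# dp_cp_mesh = world_mesh["dp_cp"]
```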

XilunWu added a commit that referenced this pull request Oct 22, 2024
ghstack-source-id: 584d30016c4598e9a64595c3c2cb227f88de9b00
Pull Request resolved: #592
@XilunWu XilunWu merged commit b19456a into main Oct 23, 2024
5 checks passed
XilunWu added a commit that referenced this pull request Oct 31, 2024
… with mesh access"


**Summary**
pytorch/pytorch#138945 fixes DeviceMesh access on flattened meshes that are constructed from more than 2 meshes. Refer to the fix PR for details if interested.

In #592 we avoided this issue by calling `_flatten` instead of directly accessing the flattened mesh. We want to switch back to mesh access, which is more straightforward, now that the fix has been merged in PyTorch.


[ghstack-poisoned]
XilunWu added a commit that referenced this pull request Oct 31, 2024
XilunWu added a commit that referenced this pull request Oct 31, 2024
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):
* __->__ #666
XilunWu added a commit that referenced this pull request Oct 31, 2024
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):
* __->__ #667

Note: This PR is a reland of #666, where the PR was mistakenly merged into the wrong branch.
mori360 pushed a commit to mori360/torchtitan that referenced this pull request Nov 26, 2024
from torch.nn.attention import sdpa_kernel, SDPBackend

# currently we only support these two SDP backends.
# TODO (xilunwu): support cuDNN backend
@tmm1:

Just curious if you recall what the blocker for CUDNN_ATTENTION is.

Contributor (author):

Hi @tmm1, it's simply that cuDNN attention has a different op signature. I'm adding support now and should have the PR draft out by next week.
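For context, a minimal sketch of restricting SDPA to the two backends mentioned in the diff above. Whether the real code uses a `with` block exactly like this is an assumption on my part:

```python
# Sketch only: run scaled_dot_product_attention with the backend set limited to
# flash and efficient attention; cuDNN attention is excluded because its op
# signature differs (per the discussion above).
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16)

with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```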
