[405B] Add performance data for 405B model #554

Merged (6 commits) on Aug 23, 2024
Binary file added assets/images/llama3_1_405B_loss_curves.png
36 changes: 28 additions & 8 deletions docs/performance.md
@@ -1,10 +1,26 @@
To demonstrate the effectiveness of PyTorch distributed training techniques used in torchtitan, we report both the infra metrics and loss curves of Llama 2 (13B and 70B) and Llama 3 (8B and 70B) training on 64 A100 (80GB memory) GPUs.
We report infra metrics achieved by [FSDP2](fsdp.md) (1D parallelism) under various configurations, and loss curves for both 1D parallelism (FSDP2) and 2D parallelism (FSDP2 + Tensor Parallel) training.
To demonstrate the effectiveness of PyTorch distributed training techniques used in torchtitan, we report both the infra metrics and loss curves of Llama 2 (13B and 70B) and Llama 3 (8B and 70B) training on 64 A100 (80GB memory) GPUs, and Llama 3.1 (405B) training on 128 H100 (94GB memory) GPUs.
We report infra metrics achieved by [FSDP2](fsdp.md) (1D parallelism) under various configurations, and loss curves for both 1D parallelism (FSDP2) and 2D parallelism (FSDP2 + Tensor Parallel) training. (For the 405B model, we only report 2D parallelism.)


## Llama 3.1 performance numbers

Below are the WPS (words per second, or more accurately, tokens per second) and MFU (model FLOPS utilization) results which torchtitan achieves on the 405B model released in [Llama 3.1](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1). The way we compute WPS and MFU can be found in `train.py`. Because this model is much larger, we run on 128 H100 GPUs to test both performance and loss curves. The table below shows the results for the 405B model with the optimizations we have developed. 1D parallelism (FSDP2) runs out of memory (OOM) at this scale, so we only tested 2D parallelism (FSDP2 + Tensor Parallel).

| Model size | Batch size | Activation checkpointing | WPS | MFU | Optimizations |
| ----- | ----- | ----- | ----- | ----- | ----- |
| 405B | 2 | full | 118 | 37.1%[^1] | None |
| 405B | 2 | full | 177 | 27.77%[^2] | Float8 |
| 405B | 2 | full | 185 | 29.03% | Float8 + Async TP |

Here, we use local batch size 2 (global batch size = local batch size 2 * number of FSDP ranks 16 = 32).
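For intuition, here is a minimal sketch of how WPS and MFU fit together. The exact computation lives in `train.py` and `torchtitan/utils.py`; the FLOPS-per-token approximation, the architecture numbers, and the throughput value below are illustrative assumptions, so this snippet is not expected to reproduce the table above.

```python
# Illustrative sketch only; not torchtitan's exact code (see train.py / torchtitan/utils.py).
# All concrete values below are assumptions for demonstration.
local_batch_size = 2
num_fsdp_ranks = 16
global_batch_size = local_batch_size * num_fsdp_ranks         # = 32, as stated above

# A common FLOPS-per-token approximation: 6 * params plus an attention term.
num_params = 405e9                                             # Llama 3.1 405B
layers, heads, head_dim, seq_len = 126, 128, 128, 8192         # assumed architecture/config
flop_per_token = 6 * num_params + 12 * layers * heads * head_dim * seq_len

tokens_per_sec_per_gpu = 100                                   # hypothetical per-GPU WPS
peak_flops = 989e12                                            # H100 SXM BF16 dense peak
mfu = tokens_per_sec_per_gpu * flop_per_token / peak_flops
print(f"global batch size = {global_batch_size}, MFU = {mfu:.1%}")
```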

Next, we show the loss curves. All runs are trained for 3000 steps on the [C4 dataset](https://huggingface.co/datasets/allenai/c4) with global batch size 32. We have to use full AC to reduce memory usage. The results are shown in the picture (a TensorBoard screenshot) below.

![image](../assets/images/llama3_1_405B_loss_curves.png)

## Llama 3 performance numbers

Below are the WPS (word per second, or more accurately, token per second) and MFU (model FLOPS utilization) results which torchtitan achieves on Llama 3 models with FSDP2 on 64 A100 (80GB) GPUs. The way we compute WPS and MFU can be found in `train.py`.
Below are the WPS and MFU results which torchtitan achieves on Llama 3 models with FSDP2 on 64 A100 (80GB) GPUs.

| Model size | Batch size | Activation checkpointing | WPS | MFU |
| ----- | ----- | ----- | ----- | ----- |
@@ -14,7 +30,7 @@ Below are the WPS (word per second, or more accurately, token per second) and MF

We use local batch size 1 (global batch size = local batch size 1 * number of FSDP ranks 64 = 64), because it mimics the small local batch size in large scaled training, and moreover allows us to compare 1D (FSDP) and 2D (FSDP + TP) training under the same global batch size on both 8B and 70B Llama 3 models, without the out-of-memory (OOM) issue.

Next we show the loss curves for Llama 3 8B and Llama 3 70B training with both 1D parallelism (FSDP2) and 2D parallelism (FSDP2 + Tensor Parallel). All four models are trained 3000 steps on the [C4 dataset](https://huggingface.co/datasets/allenai/c4), with global batch size 64. In terms of activation checkpointing (AC) configs, the Llama 3 8B training jobs use selective op AC, whereas the Llama 3 70B training jobs use full AC. The results are shown in the picture (a TensorBoard screenshot) below.
Next we show the loss curves for Llama 3 8B and Llama 3 70B training with both 1D parallelism (FSDP2) and 2D parallelism (FSDP2 + Tensor Parallel). All four models are trained the same way as mentioned above with global batch size 64. In terms of activation checkpointing (AC) configs, the Llama 3 8B training jobs use selective op AC, whereas the Llama 3 70B training jobs use full AC. The results are shown in the picture (a TensorBoard screenshot) below.

![image](../assets/images/llama3_loss_curves.png)

@@ -28,16 +44,20 @@ Below are the WPS and MFU results which torchtitan achieves on Llama 2 models wi
| 13B | 2 | no | 2162 | 61.1% |
| 13B | 2 | selective layer | 1914 | 54.1% |
| 13B | 2 | selective op | 1904 | 53.8% |
| 70B | 1[^1] | selective op | 355 | 50.8% |
| 70B | 1[^3] | selective op | 355 | 50.8% |
| 70B | 2 | full | 353 | 50.5% |

We primarily use local batch size 2 (global batch size 128) in the experiments, to keep the same number of tokens per training iteration between Llama 2 and Llama 3 (since the default sequence length in Llama 2 is 4096, which is half of that in Llama 3). In fact, for the Llama 2 70B model with full activation checkpointing, the MFU can go up to 54% when the local batch size is higher (but before an OOM happens).
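As a quick check of the tokens-per-iteration claim (assuming default sequence lengths of 4096 for Llama 2 and 8192 for Llama 3, per the sentence above):

```python
# Tokens per training iteration = local batch size * sequence length * number of FSDP ranks.
llama2_tokens_per_iter = 2 * 4096 * 64   # Llama 2: local batch size 2, seq len 4096, 64 GPUs
llama3_tokens_per_iter = 1 * 8192 * 64   # Llama 3: local batch size 1, seq len 8192, 64 GPUs
assert llama2_tokens_per_iter == llama3_tokens_per_iter == 524_288
```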

Next we show the loss curves for Llama 2 13B and Llama 2 70B training with both 1D parallelism (FSDP2) and 2D parallelism (FSDP2 + Tensor Parallel). All four models are trained for 3000 steps with global batch size 128.
In terms of activation checkpointing (AC) configs, the Llama 2 13B training jobs use selective op AC, whereas the Llama 70B training jobs use full AC. The results are shown in the picture (a TensorBoard screenshot) below[^2].
In terms of activation checkpointing (AC) configs, the Llama 2 13B training jobs use selective op AC, whereas the Llama 2 70B training jobs use full AC. The results are shown in the picture (a TensorBoard screenshot) below[^4].

![image](../assets/images/llama2_loss_curves.png)

[^1]: Since the 70B training with local batch size 2 will cause an OOM error when selective activation checkpointing is used, we report the local batch size 1 case instead.
[^1]: We used HBM2e-based, lower-TDP SXM H100 (95GB) GPUs for our test. Their actual peak TFLOPS lies between that of the SXM and NVL variants, and we do not know its exact value, so the reported MFU is lower than the true MFU because we divide by the SXM peak number directly.

[^2]: For Float8, not all matmuls are converted to Float8 because our fused attention implementation is not done in Float8, so this number is lower than expected.

[^3]: Since the 70B training with local batch size 2 will cause an OOM error when selective activation checkpointing is used, we report the local batch size 1 case instead.

[^2]: One may have noticed that for both 13B and 70B training, 1D parallelism has slightly better convergence than 2D parallelism in the first half of training. We believe this is caused by the stronger shuffling effect introduced by having more FSDP ranks in the 1D parallelism, and the difference in convergence speed should go away after switching to a randomized data loading solution.
[^4]: One may have noticed that for both 13B and 70B training, 1D parallelism has slightly better convergence than 2D parallelism in the first half of training. We believe this is caused by the stronger shuffling effect introduced by having more FSDP ranks in the 1D parallelism, and the difference in convergence speed should go away after switching to a randomized data loading solution.
16 changes: 14 additions & 2 deletions torchtitan/utils.py
@@ -6,6 +6,7 @@

import gc
import os
import subprocess
from dataclasses import dataclass
from datetime import timedelta
from typing import Union
@@ -135,17 +136,28 @@ def get_num_flop_per_token(num_params: int, model_config, seq_len) -> int:

# hardcoded BF16 type peak flops for NVIDIA A100 and H100 GPU
def get_peak_flops(device_name: str) -> int:
    # Run the lspci command and capture the output
    result = subprocess.run(["lspci"], stdout=subprocess.PIPE, text=True)
    # Filter the output for lines containing both "NVIDIA" and "H100"
    filtered_lines = [
        line
        for line in result.stdout.splitlines()
        if "NVIDIA" in line and "H100" in line
    ]
    # Join all filtered lines into a single string
    combined_output = " ".join(filtered_lines)
    device_name = combined_output or device_name
    if "A100" in device_name:
        # data from https://www.nvidia.com/en-us/data-center/a100/
        return 312e12
    elif "H100" in device_name:
        # data from https://www.nvidia.com/en-us/data-center/h100/
        # NOTE: Specifications are one-half lower without sparsity.
        if "NVL" in device_name:
            return 1979e12
            return 835e12
        elif "PCIe" in device_name:
            return 756e12
        else:  # for SXM and other variants
        else:  # for H100 SXM and other variants
            return 989e12
    else:  # for other GPU types, assume A100
        return 312e12
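A quick usage sketch for this helper. The device-name strings are hypothetical examples, and the expected values assume the lspci-based detection above finds no H100 on the (Linux) machine running the snippet; otherwise the detected H100 name overrides the argument.

```python
# Hypothetical usage; expected values follow the branches of get_peak_flops above.
from torchtitan.utils import get_peak_flops

print(f"{get_peak_flops('NVIDIA A100-SXM4-80GB'):.3e}")  # 3.120e+14
print(f"{get_peak_flops('NVIDIA H100 80GB HBM3'):.3e}")  # 9.890e+14 (SXM and other variants)
print(f"{get_peak_flops('NVIDIA H100 NVL'):.3e}")        # 8.350e+14 after this PR's NVL change
```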
1 change: 1 addition & 0 deletions train.py
@@ -72,6 +72,7 @@ def main(job_config: JobConfig):
    # initialize GPU memory monitor and get peak flops for MFU calculation
    gpu_memory_monitor = build_gpu_memory_monitor()
    gpu_peak_flops = utils.get_peak_flops(gpu_memory_monitor.device_name)
    logger.info(f"Peak FLOPS used for computing MFU: {gpu_peak_flops:.3e}")

    # build meshes
    world_mesh = parallel_dims.build_mesh(device_type="cuda")
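For reference, a standalone illustration of what the newly added log line would print, assuming `get_peak_flops` returned the H100 SXM value of 989e12:

```python
# Hypothetical standalone snippet mirroring the added logging statement.
gpu_peak_flops = 989e12
print(f"Peak FLOPS used for computing MFU: {gpu_peak_flops:.3e}")
# Output: Peak FLOPS used for computing MFU: 9.890e+14
```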
8 changes: 7 additions & 1 deletion train_configs/llama3_405b.toml
@@ -34,11 +34,12 @@ steps = 3000
data_parallel_degree = -1
tensor_parallel_degree = 8 # 8-way TP
enable_float8_linear = false
compile = false
compile = true
dataset = "c4"

[experimental]
pipeline_parallel_degree = 1
enable_async_tensor_parallel = true

[checkpoint]
enable_checkpoint = false
@@ -51,3 +52,8 @@ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = 'full' # ['none', 'selective', 'full']

[float8]
enable_float8_linear = true
enable_fsdp_float8_all_gather = true
precompute_float8_dynamic_scale_for_fsdp = true