[405B] Add performance data for 405B model #554

Merged · 6 commits · Aug 23, 2024
Changes from 4 commits
Binary file added assets/images/llama3_1_405B_loss_curves.png
20 changes: 18 additions & 2 deletions docs/performance.md
@@ -2,9 +2,25 @@ To demonstrate the effectiveness of PyTorch distributed training techniques used
We report infra metrics achieved by [FSDP2](fsdp.md) (1D parallelism) under various configurations, and loss curves for both 1D parallelism (FSDP2) and 2D parallelism (FSDP2 + Tensor Parallel) training.


## Llama 3.1 performance numbers

Below are the WPS (words per second, or more accurately, tokens per second) and MFU (model FLOPS utilization) results which torchtitan achieves on the 405B model released in [Llama 3.1](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1). The way we compute WPS and MFU can be found in `train.py`. Because this model is much larger, we run on 128 H100 GPUs to measure both performance and loss curves. Below are the performance results of the 405B model with the optimizations we have developed. We observe out-of-memory (OOM) errors with 1D parallelism (FSDP2), so we only tested 2D parallelism (FSDP2 + Tensor Parallel).
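
As a rough sketch of how these metrics are defined (the actual computation lives in `train.py`; the function and variable names below are illustrative, not torchtitan's API):

```python
# Illustrative only -- the real bookkeeping in train.py differs in detail.
# Assumes `ntokens` tokens were processed over `elapsed` seconds on one GPU,
# and `num_flop_per_token` estimates forward+backward FLOPs per token.
def compute_wps_mfu(ntokens: int, elapsed: float,
                    num_flop_per_token: float, gpu_peak_flops: float) -> tuple[float, float]:
    wps = ntokens / elapsed                          # tokens processed per second, per GPU
    mfu = wps * num_flop_per_token / gpu_peak_flops  # fraction of hardware peak achieved
    return wps, mfu

# Example call, using the H100 SXM peak from get_peak_flops (989e12 BF16 FLOPS):
# wps, mfu = compute_wps_mfu(ntokens, elapsed, num_flop_per_token, 989e12)
```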

| Model size | Batch size | Activation checkpointing | WPS | MFU | Optimizations |
| ----- | ----- | ----- | ----- | ----- | ----- |
| 405B | 2 | full | 118 | 37.1% | None |
| 405B | 2 | full | 177 | 27.77% | Float8 |
| 405B | 2 | full | 185 | 29.03% | Float8 + Async TP |

Here, we use a local batch size of 2 (global batch size = local batch size 2 * number of FSDP ranks 16 = 32; with 8-way TP on 128 GPUs there are 16 FSDP ranks).

Next, we show the loss curves. All models are trained for 3000 steps on the [C4 dataset](https://huggingface.co/datasets/allenai/c4) with global batch size 32. We have to use full activation checkpointing (AC) to keep memory usage manageable. The results are shown in the picture (a TensorBoard screenshot) below.

![image](../assets/images/llama3_1_405B_loss_curves.png)

## Llama 3 performance numbers

Below are the WPS (word per second, or more accurately, token per second) and MFU (model FLOPS utilization) results which torchtitan achieves on Llama 3 models with FSDP2 on 64 A100 (80GB) GPUs. The way we compute WPS and MFU can be found in `train.py`.
Below are the WPS and MFU results which torchtitan achieves on Llama 3 models with FSDP2 on 64 A100 (80GB) GPUs.

| Model size | Batch size | Activation checkpointing | WPS | MFU |
| ----- | ----- | ----- | ----- | ----- |
@@ -14,7 +30,7 @@ Below are the WPS (word per second, or more accurately, token per second) and MF

We use local batch size 1 (global batch size = local batch size 1 * number of FSDP ranks 64 = 64), because it mimics the small local batch size in large-scale training, and moreover allows us to compare 1D (FSDP) and 2D (FSDP + TP) training under the same global batch size on both 8B and 70B Llama 3 models, without the out-of-memory (OOM) issue.

Next we show the loss curves for Llama 3 8B and Llama 3 70B training with both 1D parallelism (FSDP2) and 2D parallelism (FSDP2 + Tensor Parallel). All four models are trained 3000 steps on the [C4 dataset](https://huggingface.co/datasets/allenai/c4), with global batch size 64. In terms of activation checkpointing (AC) configs, the Llama 3 8B training jobs use selective op AC, whereas the Llama 3 70B training jobs use full AC. The results are shown in the picture (a TensorBoard screenshot) below.
Next we show the loss curves for Llama 3 8B and Llama 3 70B training with both 1D parallelism (FSDP2) and 2D parallelism (FSDP2 + Tensor Parallel). All four models are trained the same way as mentioned above with global batch size 64. In terms of activation checkpointing (AC) configs, the Llama 3 8B training jobs use selective op AC, whereas the Llama 3 70B training jobs use full AC. The results are shown in the picture (a TensorBoard screenshot) below.

![image](../assets/images/llama3_loss_curves.png)

13 changes: 11 additions & 2 deletions torchtitan/utils.py
@@ -15,6 +15,7 @@
import torch.distributed.distributed_c10d as c10d
from torch.distributed.device_mesh import DeviceMesh
from torchtitan.logging import logger
from torchtitan.metrics import GPUMemoryMonitor


def dist_max(x: Union[int, float], mesh: DeviceMesh) -> float:
@@ -134,7 +135,9 @@ def get_num_flop_per_token(num_params: int, model_config, seq_len) -> int:


# hardcoded BF16 type peak flops for NVIDIA A100 and H100 GPU
def get_peak_flops(device_name: str) -> int:
def get_peak_flops(gpu_memory_monitor: GPUMemoryMonitor) -> int:
device_name = gpu_memory_monitor.device_name
device_mem = int(gpu_memory_monitor.device_capacity_gib)
if "A100" in device_name:
# data from https://www.nvidia.com/en-us/data-center/a100/
return 312e12
@@ -145,7 +148,13 @@ def get_peak_flops(device_name: str) -> int:
return 1979e12
elif "PCIe" in device_name:
return 756e12
else: # for SXM and other variants
# Sometimes, the device name is just literally "NVIDIA H100", so
# we won't hit the two conditions before; we then need
# to check the memory size of the device to determine the peak flops.
# H100 NVL has 94 GiB memory, source: https://www.nvidia.com/en-us/data-center/h100/
elif device_mem >= 94:
return 835e12
else: # for H100 SXM
return 989e12
else: # for other GPU types, assume A100
return 312e12
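
Purely as an illustration of the new memory-based branch (this stub is not part of the PR; `GPUMemoryMonitor` is normally constructed from a real CUDA device, so the fake monitor below is a hypothetical stand-in):

```python
from dataclasses import dataclass

from torchtitan.utils import get_peak_flops  # module path as in this PR


@dataclass
class _FakeGPUMemoryMonitor:
    device_name: str
    device_capacity_gib: float


# "NVIDIA H100" matches neither the "NVL" nor the "PCIe" branch, so the 94 GiB
# capacity routes it to the memory-based branch (835e12)...
assert get_peak_flops(_FakeGPUMemoryMonitor("NVIDIA H100", 94.0)) == 835e12
# ...while an 80 GiB H100 falls through to the SXM default (989e12).
assert get_peak_flops(_FakeGPUMemoryMonitor("NVIDIA H100", 80.0)) == 989e12
```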
3 changes: 2 additions & 1 deletion train.py
@@ -71,7 +71,8 @@ def main(job_config: JobConfig):
utils.init_distributed(job_config)
# initialize GPU memory monitor and get peak flops for MFU calculation
gpu_memory_monitor = build_gpu_memory_monitor()
gpu_peak_flops = utils.get_peak_flops(gpu_memory_monitor.device_name)
gpu_peak_flops = utils.get_peak_flops(gpu_memory_monitor)
logger.info(f"Peak FLOPS used for computing MFU: {gpu_peak_flops:.3e}")

# build meshes
world_mesh = parallel_dims.build_mesh(device_type="cuda")
8 changes: 7 additions & 1 deletion train_configs/llama3_405b.toml
@@ -34,11 +34,12 @@ steps = 3000
data_parallel_degree = -1
tensor_parallel_degree = 8 # 8-way TP
enable_float8_linear = false
compile = false
compile = true
dataset = "c4"

[experimental]
pipeline_parallel_degree = 1
enable_async_tensor_parallel = true

[checkpoint]
enable_checkpoint = false
@@ -51,3 +52,8 @@ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = 'full' # ['none', 'selective', 'full']

[float8]
enable_float8_linear = true
enable_fsdp_float8_all_gather = true
precompute_float8_dynamic_scale_for_fsdp = true