Adding NVIDIA hardware performance detection #555

Open · wants to merge 4 commits into base: master
13 changes: 13 additions & 0 deletions Dockerfile
@@ -0,0 +1,13 @@
FROM pytorch/pytorch:2.4.1-cuda11.8-cudnn9-runtime

WORKDIR /workspace

COPY requirements.txt .

RUN pip install --no-cache-dir -r requirements.txt

RUN apt-get update \
    && apt-get install -y gcc

CMD ["/bin/bash"]
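
A quick sanity check (not part of this diff) that could be run inside the container to confirm the base image and GPU are visible before launching training; the printed device name is only an example:

```python
# Sanity-check sketch, not part of the PR: confirm PyTorch sees the CUDA runtime
# and at least one GPU inside a container started with --gpus all.
import torch

print(torch.__version__)             # expected to report a 2.4.1+cu118 build from the base image
print(torch.cuda.is_available())     # True if the container was given GPU access
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))  # e.g. an A100 or H100 name string
```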
16 changes: 16 additions & 0 deletions README.md
@@ -25,6 +25,18 @@ Dependencies:
- `wandb` for optional logging <3
- `tqdm` for progress bars <3

### Docker

Build the container:
```
docker build -t nanogpt-env . --no-cache
```

And run it interactively before launching training runs:
```
docker run -it --gpus all --network=host -v $(pwd):/workspace nanogpt-env
```

## quick start

If you are not a deep learning professional and you just want to feel the magic and get your feet wet, the fastest way to get started is to train a character-level GPT on the works of Shakespeare. First, we download it as a single (1MB) file and turn it from raw text into one large stream of integers:
@@ -225,3 +237,7 @@ For more questions/discussions feel free to stop by **#nanoGPT** on Discord:
## acknowledgements

All nanoGPT experiments are powered by GPUs on [Lambda labs](https://lambdalabs.com), my favorite Cloud GPU provider. Thank you Lambda labs for sponsoring nanoGPT!

## add mfu with h100

`torch.cuda.get_device_name()`
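
A minimal sketch of the idea behind this note, assuming the same peak bfloat16 figures used elsewhere in this PR (312 TFLOPS for A100, 989 TFLOPS for H100); the helper name below is hypothetical, not code from the PR:

```python
# Hypothetical helper: map torch.cuda.get_device_name() onto a peak bf16 FLOPS
# figure for MFU reporting. Values mirror those used in this PR.
import torch

PEAK_BF16_FLOPS = {"A100": 312e12, "H100": 989e12}

def peak_flops_for_current_gpu(device: int = 0):
    name = torch.cuda.get_device_name(device)
    for model, flops in PEAK_BF16_FLOPS.items():
        if model in name:
            return flops
    return None  # unknown GPU; the caller can skip MFU reporting
```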
6 changes: 3 additions & 3 deletions config/train_gpt2.py
@@ -2,15 +2,15 @@
# launch as the following (e.g. in a screen session) and wait ~5 days:
# $ torchrun --standalone --nproc_per_node=8 train.py config/train_gpt2.py

wandb_log = True
wandb_log = False
wandb_project = 'owt'
wandb_run_name='gpt2-124M'

# these make the total batch size be ~0.5M
# 12 batch size * 1024 block size * 5 gradaccum * 8 GPUs = 491,520
batch_size = 12
batch_size = 104
block_size = 1024
gradient_accumulation_steps = 5 * 8
gradient_accumulation_steps = 128

# this makes total number of tokens be 300B
max_iters = 600000
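For reference, the arithmetic behind the new defaults (a check, not part of the diff): the "~0.5M" total batch size in the unchanged comment applied to the old values; before any per-rank adjustment in train.py, the new values work out to roughly 13.6M tokens per optimizer step.

```python
# Arithmetic check, not part of the diff: tokens per optimizer step under the new
# defaults, before train.py adjusts gradient_accumulation_steps for DDP.
batch_size = 104
block_size = 1024
gradient_accumulation_steps = 128
print(batch_size * block_size * gradient_accumulation_steps)  # 13,631,488 (~13.6M tokens)
```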
5 changes: 2 additions & 3 deletions model.py
@@ -42,7 +42,7 @@ def __init__(self, config):
self.n_embd = config.n_embd
self.dropout = config.dropout
# flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
self.flash = True
if not self.flash:
print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
# causal mask to ensure that attention is only applied to the left in the input sequence
@@ -286,7 +286,7 @@ def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):

return optimizer

def estimate_mfu(self, fwdbwd_per_iter, dt):
def estimate_mfu(self, fwdbwd_per_iter, dt, flops_promised):
""" estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """
# first estimate the number of flops we do per iteration.
# see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
@@ -298,7 +298,6 @@ def estimate_mfu(self, fwdbwd_per_iter, dt):
flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
# express our flops throughput as ratio of A100 bfloat16 peak flops
flops_achieved = flops_per_iter * (1.0/dt) # per second
flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS
mfu = flops_achieved / flops_promised
return mfu

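With the new signature the A100 constant is no longer baked in; the caller passes `flops_promised`. A standalone sketch of the quantity being computed, using purely illustrative numbers (not measurements from this PR):

```python
# Standalone sketch (not the file's code): MFU is achieved FLOPS per second
# divided by the peak FLOPS value the caller now passes in as flops_promised.
def mfu(flops_per_iter: float, dt: float, flops_promised: float) -> float:
    return (flops_per_iter / dt) / flops_promised

# Illustrative numbers only: 4.2e14 FLOPs per iteration, 3.0 s per iteration,
# measured against an A100's 312 TFLOPS bf16 peak.
print(mfu(4.2e14, 3.0, 312e12))  # ≈ 0.45, i.e. ~45% utilization
```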
42 changes: 42 additions & 0 deletions multinode_launcher.sh
@@ -0,0 +1,42 @@
#!/bin/bash

# Parameters
#SBATCH --account=compute-account
#SBATCH --dependency=singleton
#SBATCH --error=/mnt/fsp/nanoGPT/results/nanoGPT_%j.err
#SBATCH --exclusive
#SBATCH --gpus-per-node=8
#SBATCH --job-name=nanoGPT
#SBATCH --mem=0
#SBATCH --nodes=16
#SBATCH --ntasks-per-node=1
#SBATCH --output=/mnt/fsp/nanoGPT/results/nanoGPT_%j.out
#SBATCH --partition=batch
#SBATCH --time=0-01:00:00

# setup
export TRANSFORMERS_OFFLINE=0
export TORCH_NCCL_AVOID_RECORD_STREAMS=1
export NCCL_NVLS_ENABLE=0
export NVTE_DP_AMAX_REDUCE_INTERVAL=0
export NVTE_ASYNC_AMAX_REDUCTION=1
export NVTE_FUSED_ATTN=0

# Get the IP address of the first node
MASTER_ADDR=$(srun --ntasks=1 --nodes=1 hostname -I | awk '{print $1}')
MASTER_PORT=$(shuf -i 10000-65500 -n 1)

echo "Using MASTER_ADDR: $MASTER_ADDR"
echo "Using MASTER_PORT: $MASTER_PORT"

# command 1
srun --container-image /mnt/fsp/nanoGPT/nvcr.io+nvidia+pytorch+24.08-py3.sqsh \
--container-mounts /mnt/fsp/nanoGPT:/workspace \
--no-container-mount-home \
bash -c "CUDA_DEVICE_MAX_CONNECTIONS=1 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun \
--nnodes=$SLURM_NNODES \
--nproc_per_node=$SLURM_GPUS_PER_NODE \
--node_rank=\$SLURM_PROCID \
--master_addr=$MASTER_ADDR \
--master_port=$MASTER_PORT \
train.py config/train_gpt2.py"
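
For context on how this torchrun invocation connects to the training script: torchrun exports RANK, LOCAL_RANK, and WORLD_SIZE on every worker, and train.py keys its DDP setup off those variables. A simplified sketch of that handshake (an illustration, not a copy of train.py):

```python
# Simplified sketch of the env-var handshake between torchrun and the training
# process; train.py's actual DDP setup follows the same pattern.
import os

ddp = int(os.environ.get("RANK", -1)) != -1  # torchrun sets RANK on each worker
if ddp:
    ddp_rank = int(os.environ["RANK"])              # global rank across all nodes
    ddp_local_rank = int(os.environ["LOCAL_RANK"])  # rank within this node
    ddp_world_size = int(os.environ["WORLD_SIZE"])  # nnodes * nproc_per_node
    device = f"cuda:{ddp_local_rank}"
else:
    ddp_rank, ddp_world_size, device = 0, 1, "cuda"
```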
4 changes: 4 additions & 0 deletions requirements.txt
@@ -0,0 +1,4 @@
datasets==2.21.0
tiktoken==0.7.0
transformers==4.44.2
wandb==0.16.6
46 changes: 39 additions & 7 deletions train.py
@@ -21,6 +21,8 @@
import math
import pickle
from contextlib import nullcontext
import subprocess
import re

import numpy as np
import torch
@@ -29,6 +31,29 @@

from model import GPTConfig, GPT


def get_nvidia_gpu_performance():
performance_map = {
"H100": 989e12, # 989 TFLOPS
"A100": 312e12 # 312 TFLOPS
}
try:
result = subprocess.run(['nvidia-smi', '--query-gpu=name', '--format=csv,noheader'],
capture_output=True, text=True, check=True)
gpu_name = result.stdout.strip()
match = re.search(r'(A100|H100)', gpu_name)
if match:
gpu_model = match.group(0)
return performance_map.get(gpu_model, "Unknown performance")
else:
return "Unknown performance"
except subprocess.CalledProcessError:
return "Unknown performance"
except Exception as e:
return "Unknown performance"

GPU_PERFORMANCE = get_nvidia_gpu_performance()

# -----------------------------------------------------------------------------
# default config values designed to train a gpt2 (124M) on OpenWebText
# I/O
Expand All @@ -45,7 +70,7 @@
wandb_run_name = 'gpt2' # 'run' + str(time.time())
# data
dataset = 'openwebtext'
gradient_accumulation_steps = 5 * 8 # used to simulate larger batch sizes
gradient_accumulation_steps = 16 # used to simulate larger batch sizes
batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size
block_size = 1024
# model
@@ -91,14 +116,18 @@
seed_offset = ddp_rank # each process gets a different seed
# world_size number of processes will be training simultaneously, so we can scale
# down the desired gradient accumulation iterations per process proportionally
assert gradient_accumulation_steps % ddp_world_size == 0
gradient_accumulation_steps //= ddp_world_size
# assert gradient_accumulation_steps % ddp_world_size == 0, f"Gradient accumulation steps {gradient_accumulation_steps} is not divisible by world size {ddp_world_size}"
# gradient_accumulation_steps //= ddp_world_size
gradient_accumulation_steps = ddp_world_size * 1 # FSP: 4 steps per GPU
else:
# if not ddp, we are running on a single gpu, and one process
master_process = True
seed_offset = 0
ddp_world_size = 1
tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * block_size
if master_process:
print(f"Effective batch size: {gradient_accumulation_steps * ddp_world_size * batch_size}")
print(f"Gradient accumulation steps: {gradient_accumulation_steps}")
print(f"tokens per iteration will be: {tokens_per_iter:,}")

if master_process:
@@ -321,10 +350,13 @@ def get_lr(it):
# get loss as float. note: this is a CPU-GPU sync point
# scale up to undo the division above, approximating the true total loss (exact would have been a sum)
lossf = loss.item() * gradient_accumulation_steps
if local_iter_num >= 5: # let the training loop settle a bit
mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%")
if GPU_PERFORMANCE != "Unknown performance":
if local_iter_num >= 5: # let the training loop settle a bit
mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt, GPU_PERFORMANCE)
running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%, tk/s {tokens_per_iter/dt:.2f}")
else:
print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, tk/s {tokens_per_iter/dt:.2f}")
iter_num += 1
local_iter_num += 1

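Finally, an illustration of how the regex in get_nvidia_gpu_performance() resolves nvidia-smi output to a peak-FLOPS value, and why anything other than an A100 or H100 takes the "Unknown performance" path. The sample GPU name strings below are assumptions for illustration, not captured output:

```python
# Illustration only; the sample GPU name strings are assumed, not captured output.
import re

performance_map = {"H100": 989e12, "A100": 312e12}

for name in ["NVIDIA A100-SXM4-80GB", "NVIDIA H100 80GB HBM3", "NVIDIA GeForce RTX 4090"]:
    match = re.search(r"(A100|H100)", name)
    value = performance_map[match.group(0)] if match else "Unknown performance"
    print(f"{name} -> {value}")
```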