Merge remote-tracking branch 'upstream/master'
krrishnarraj committed Oct 2, 2024
2 parents 51d9c91 + 7ecd890 commit a0fa4a4
Showing 2 changed files with 106 additions and 60 deletions.
80 changes: 79 additions & 1 deletion README.md
@@ -56,6 +56,24 @@ It can be used to launch the subsequent run:

```bash
MATMUL_TILE_SIZE=16 MATMUL_LOCAL_MEM_PADDING_SIZE=1 MATMUL_VLOAD_SIZE=8 MATMUL_DO_PRELOAD=1 MATMUL_USE_MAD=1 ./train_gpt2cl
```

## quick start (1 GPU, fp32 only)

If you won't be training on multiple nodes, aren't interested in mixed precision, and are interested in learning CUDA, the fp32 (legacy) files might be of interest to you. These are files that were "checkpointed" early in the history of llm.c and frozen in time. They are simpler, more portable, and possibly easier to understand. Run the 1 GPU, fp32 code like this:

```bash
chmod u+x ./dev/download_starter_pack.sh
./dev/download_starter_pack.sh
make train_gpt2fp32cu
./train_gpt2fp32cu
```

The download_starter_pack.sh script is a quick & easy way to get started and it downloads a bunch of .bin files that help get you off the ground. These contain: 1) the GPT-2 124M model saved in fp32 and in bfloat16, 2) a "debug state" used in unit testing (a small batch of data, and target activations and gradients), 3) the GPT-2 tokenizer, and 4) the tokenized [tinyshakespeare](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt) dataset. Alternatively, instead of running the .sh script, you can re-create these artifacts manually as follows:

```bash
pip install -r requirements.txt
python dev/data/tinyshakespeare.py
python train_gpt2.py
```
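
To sanity-check the tokenized data produced by the commands above, you can peek at the resulting `.bin` file. This is an illustrative sketch, not part of the repo; it assumes the token-file layout written by the `dev/data` scripts (a 256-entry int32 header carrying a magic number, a version, and the token count, followed by uint16 token ids). If the format differs, adjust the offsets accordingly.

```python
import numpy as np

# Hypothetical sanity check; assumes the dev/data token-file layout:
# a 256 x int32 header (magic, version, token count, ...), then uint16 tokens.
with open("dev/data/tinyshakespeare/tiny_shakespeare_val.bin", "rb") as f:
    header = np.frombuffer(f.read(256 * 4), dtype=np.int32)
    tokens = np.frombuffer(f.read(), dtype=np.uint16)

print(f"magic={header[0]} version={header[1]} num_tokens={header[2]}")
print("first tokens:", tokens[:16])
assert header[2] == len(tokens), "header token count should match the payload"
```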

## quick start (CPU)
@@ -165,6 +183,60 @@ sudo apt-get -y install libcudnn9-dev-cuda-12

On top of this you need the [cuDNN frontend](https://github.com/NVIDIA/cudnn-frontend/tree/main), but these are just header files; simply clone the repo to your disk. The Makefile currently looks for it in either your home directory or the current directory. If you have put it elsewhere, add `CUDNN_FRONTEND_PATH=/path/to/your/cudnn-frontend/include` to the `make` command-line.

## multi-GPU training

Make sure you install MPI and NCCL, e.g. on Linux:

```bash
sudo apt install openmpi-bin openmpi-doc libopenmpi-dev
```

For NCCL, follow the instructions from the [official website](https://developer.nvidia.com/nccl/nccl-download) (e.g. the network installer), and then:

```bash
make train_gpt2cu
mpirun -np <number of GPUs> ./train_gpt2cu
```

or simply run one of our scripts under `./scripts/`.

## multi-node training

Make sure you've installed `NCCL` following the instructions from the [multi-GPU](#multi-gpu-training) section.
We currently support 3 ways to run multi-node training:
1) Use OpenMPI to exchange the NCCL id and initialize NCCL. See e.g. the `./scripts/multi_node/run_gpt2_124M_mpi.sh` script for details.
2) Use a shared file system to init NCCL. See the `./scripts/multi_node/run_gpt2_124M_fs.sbatch` script for details.
3) Use TCP sockets to init NCCL. See the `./scripts/multi_node/run_gpt2_124M_tcp.sbatch` script for details.
Note:
* If you're running in a slurm environment and your slurm doesn't support PMIx (which we assume will be a common situation given that `slurm-wlm` dropped PMIx support), you will have to use the FS (2) or TCP (3) approach. To test whether your slurm supports PMIx, run `srun --mpi=list` and see whether you get `pmix` in the output.
* If you don't have slurm set up, you can kick off a multi-node run using `mpirun` - MPI (1).

None of these 3 methods is superior; we just offer you options so that you can run in your specific environment.

## experiments / sweeps

As an example, here is how to sweep learning rates on a machine with 4 GPUs, training on TinyStories. Run a shell script `sweep.sh` (after you of course `chmod u+x sweep.sh`):

```bash
#!/bin/bash
learning_rates=(3e-5 1e-4 3e-4 1e-3)
for i in {0..3}; do
export CUDA_VISIBLE_DEVICES=$i
screen -dmS "tr$i" bash -c "./train_gpt2cu -i data/TinyStories -v 250 -s 250 -g 144 -l ${learning_rates[$i]} -o stories$i.log"
done
# you can bring these down with
# screen -ls | grep -E "tr[0-3]" | cut -d. -f1 | xargs -I {} screen -X -S {} quit
```

This example opens up 4 screen sessions and runs the four commands with different LRs. This writes the log files `stories$i.log` with all the losses, which you can plot as you wish in Python. A quick example of how to parse and plot these logfiles is in [dev/vislog.ipynb](dev/vislog.ipynb).
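
For reference, here is a minimal parsing and plotting sketch. It assumes log lines of the form `s:<step> trl:<train_loss>` (an assumption about the logger's output; adapt the prefix matching to whatever your log files actually contain, and see [dev/vislog.ipynb](dev/vislog.ipynb) for the canonical version).

```python
import matplotlib.pyplot as plt

# Minimal sketch: parse train-loss lines from the four sweep logs and plot them.
# Assumes lines like "s:<step> trl:<loss>"; adjust the keys to your log format.
learning_rates = ["3e-5", "1e-4", "3e-4", "1e-3"]
for i, lr in enumerate(learning_rates):
    steps, losses = [], []
    with open(f"stories{i}.log") as f:
        for line in f:
            parts = dict(kv.split(":", 1) for kv in line.split() if ":" in kv)
            if "s" in parts and "trl" in parts:
                steps.append(int(parts["s"]))
                losses.append(float(parts["trl"]))
    plt.plot(steps, losses, label=f"lr={lr}")
plt.xlabel("step")
plt.ylabel("train loss")
plt.legend()
plt.show()
```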

## repo

A few more words on what I want this repo to be:
@@ -188,10 +260,16 @@ Lastly, I will be a lot more sensitive to complexity in the root folder of the p

- CUDA C++
- [llm.cpp](https://github.com/gevtushenko/llm.c) by @[gevtushenko](https://github.com/gevtushenko): a port of this project using the [CUDA C++ Core Libraries](https://github.com/NVIDIA/cccl)
- A presentation this fork was covered in [this lecture](https://www.youtube.com/watch?v=WiB_3Csfj_Q) in the [CUDA MODE Discord Server](https://discord.gg/cudamode)
- A presentation this fork was covered in [this lecture](https://www.youtube.com/watch?v=WiB_3Csfj_Q) in the [GPU MODE Discord Server](https://discord.gg/cudamode)

- C++/CUDA
- [llm.cpp](https://github.com/zhangpiu/llm.cpp/tree/master/llmcpp) by @[zhangpiu](https://github.com/zhangpiu): a port of this project using the [Eigen](https://gitlab.com/libeigen/eigen) library, supporting CPU/CUDA.

- WebGPU C++
- [gpu.cpp](https://github.com/AnswerDotAI/gpu.cpp) by @[austinvhuang](https://github.com/austinvhuang): a library for portable GPU compute in C++ using native WebGPU. It aims to be a general-purpose library, but is also porting llm.c kernels to WGSL.

- C++
- [llm.cpp](https://github.com/GaoYusong/llm.cpp) by @[GaoYusong](https://github.com/GaoYusong): a port of this project featuring a C++ single-header [tinytorch.hpp](https://github.com/GaoYusong/llm.cpp/blob/main/tinytorch.hpp) library

- Go
- [llm.go](https://github.com/joshcarp/llm.go) by @[joshcarp](https://github.com/joshcarp): a Go port of this project
86 changes: 27 additions & 59 deletions train_llama3.py
@@ -16,17 +16,17 @@
TODO: add the actual commands
"""

import argparse
import os
import math
import glob
import inspect
from contextlib import nullcontext
from dataclasses import dataclass
import json
from pathlib import Path
import time
from typing import (
AbstractSet,
Callable,
Collection,
Dict,
Iterator,
@@ -55,9 +55,6 @@
# -----------------------------------------------------------------------------
# PyTorch nn.Module definitions for the LLaMA 3.x model

# using a global to toggle flash-attention
FLASH = 0

# Used in Grouped Query Attention (GQA), broadcasts the key and value tensors
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
@@ -157,6 +154,7 @@ def __init__(self, config):
self.n_rep = self.n_head // self.n_kv_head
self.hd = config.n_embd // config.n_head
self.use_kv = config.use_kv
self.flash = config.flash

self.c_attn = nn.Linear(config.n_embd, (config.n_head + 2 * config.n_kv_head) * self.hd, bias=False) # key, query, value projections
self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False) # output projection
@@ -186,9 +184,12 @@ def forward(self, x, freqs_cis=None, start_pos=None, mask=None):

q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v)) # (B, NH, T, HD)

if FLASH:
if self.flash:
# flashattention
y = F.scaled_dot_product_attention(q, k, v, mask)
# if T == 1 no need to mask, otherwise the function complains
# scaled_dot_product_attention expects a mask where value of True indicates that the element should take part in attention
# our mask is the opposite, so we need to invert it
y = F.scaled_dot_product_attention(q, k, v, mask == 0 if T > 1 else None)
else:
# manual implementation of attention
# this materializes the large (T,T) matrix for all the queries and keys
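
To make the inverted-mask comment above concrete (an editorial aside, not part of this file's diff): the sketch below assumes the usual LLaMA-style additive causal mask with 0 for allowed positions and -inf for blocked ones, which `mask == 0` converts into the boolean convention that `scaled_dot_product_attention` expects (True means the position may be attended to).

```python
import torch
import torch.nn.functional as F

B, NH, T, HD = 1, 2, 4, 8
q = k = v = torch.randn(B, NH, T, HD)

# Additive causal mask: 0 where attention is allowed, -inf where it is blocked.
additive_mask = torch.triu(torch.full((T, T), float("-inf")), diagonal=1)
# Boolean convention used by scaled_dot_product_attention: True = "may attend".
bool_mask = additive_mask == 0

y_bool = F.scaled_dot_product_attention(q, k, v, attn_mask=bool_mask)
y_add = F.scaled_dot_product_attention(q, k, v, attn_mask=additive_mask)
print(torch.allclose(y_bool, y_add, atol=1e-6))  # True: both conventions agree
```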
@@ -257,6 +258,7 @@ class LlamaConfig:
use_scaled_rope: bool = True
max_gen_batch_size: int = 4
use_kv: bool = True
flash: bool = False # use flashattention?

def __init__(self, **kwargs):
for k, v in kwargs.items():
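
Note (editorial aside): this diff moves the flash toggle from the module-level `FLASH` global into the config, so each `CausalSelfAttention` reads `config.flash`. A minimal sketch of how the flag might now be set, assuming (as the `__init__` above suggests) that `LlamaConfig` accepts keyword overrides for its fields:

```python
# Hypothetical usage, in the context of this module: enable flash attention
# per-model via the config instead of flipping a global.
config = LlamaConfig(flash=True)
print(config.flash)  # True; CausalSelfAttention picks this up as self.flash
```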
@@ -402,7 +404,7 @@ def unpermute(w, n_heads, dim1, dim2):
def from_pretrained_llama3_hf(cls, model_id):
"""Loads pretrained LLaMA model weights from HuggingFace"""
from transformers import AutoModelForCausalLM, AutoTokenizer
assert model_id == "meta-llama/Meta-Llama-3.1-8B", "Only the 8B-bae model is supported for now"
assert model_id == "meta-llama/Meta-Llama-3.1-8B", "Only the 8B-base model is supported for now"
model_args = LlamaConfig()

model = AutoModelForCausalLM.from_pretrained(model_id)
@@ -477,7 +479,6 @@ def generate(
max_gen_len: int,
temperature: float = 0.6,
top_p: float = 0.9,
logprobs: bool = False,
echo: bool = False,
) -> Tuple[List[List[int]], Optional[List[List[float]]]]:
"""
@@ -488,45 +489,35 @@
max_gen_len (int): Maximum length of the generated text sequence.
temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False.
echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.
Returns:
Tuple[List[List[int]], Optional[List[List[float]]]]: A tuple containing generated token sequences and, if logprobs is True, corresponding token log probabilities.
Tuple[List[List[int]], Optional[List[List[float]]]]: A tuple containing generated token sequences.
Note:
This method uses the provided prompts as a basis for generating text. It employs nucleus sampling to produce text with controlled randomness.
If logprobs is True, token log probabilities are computed for each generated token.
"""
bsz = len(prompt_tokens)
assert bsz <= self.config.max_gen_batch_size, (bsz, self.config.max_gen_batch_size)
assert bsz <= self.config.max_gen_batch_size, f"Batch size {bsz} exceeds the maximum generation batch size {self.config.max_gen_batch_size}"
device = next(self.parameters()).device

min_prompt_len = min(len(t) for t in prompt_tokens)
max_prompt_len = max(len(t) for t in prompt_tokens)
assert max_prompt_len <= self.config.block_size
assert max_prompt_len <= self.config.block_size, f"Prompt length {max_prompt_len} exceeds the maximum block size {self.config.block_size}"
total_len = min(self.config.block_size, max_gen_len + max_prompt_len)

pad_id = self.tokenizer.pad_id
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device=device)
for k, t in enumerate(prompt_tokens):
tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device=device)
if logprobs:
token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
for idx, t in enumerate(prompt_tokens):
tokens[idx, : len(t)] = torch.tensor(t, dtype=torch.long, device=device)

prev_pos = 0
eos_reached = torch.tensor([False] * bsz, device=device)
input_text_mask = tokens != pad_id

if min_prompt_len == total_len:
logits, _ = self.forward(tokens, start_pos=prev_pos)
token_logprobs = -F.cross_entropy(
input=logits.transpose(1, 2),
target=tokens,
reduction="none",
ignore_index=pad_id,
)

stop_tokens = torch.tensor(list(self.tokenizer.stop_tokens)).to(device)

@@ -542,41 +533,25 @@ def generate(
# only replace token if prompt has already been generated
next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)
tokens[:, cur_pos] = next_token
if logprobs:
token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
input=logits.transpose(1, 2),
target=tokens[:, prev_pos + 1 : cur_pos + 1],
reduction="none",
ignore_index=pad_id,
)
eos_reached |= (~input_text_mask[:, cur_pos]) & (
torch.isin(next_token, stop_tokens)
)
eos_reached |= ~input_text_mask[:, cur_pos] & torch.isin(next_token, stop_tokens)
prev_pos = cur_pos
if all(eos_reached):
break

if logprobs:
token_logprobs = token_logprobs.tolist()
out_tokens, out_logprobs = [], []
out_tokens = []
for i, toks in enumerate(tokens.tolist()):
# cut to max gen len
start = 0 if echo else len(prompt_tokens[i])
toks = toks[start : len(prompt_tokens[i]) + max_gen_len]
probs = None
if logprobs:
probs = token_logprobs[i][start : len(prompt_tokens[i]) + max_gen_len]
# cut to after eos tok if any
for stop_token in self.tokenizer.stop_tokens:
try:
eos_idx = toks.index(stop_token)
toks = toks[:eos_idx]
probs = probs[:eos_idx] if logprobs else None
except ValueError:
pass
out_tokens.append(toks)
out_logprobs.append(probs)
return (out_tokens, out_logprobs if logprobs else None)
return out_tokens

# -----------------------------------------------------------------------------
# sampling utils
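
As background for the sampling utils referenced here and for the `top_p` argument of `generate()` (an editorial aside): a generic nucleus (top-p) sampling sketch, illustrative only and not necessarily this repo's exact implementation.

```python
import torch

def sample_top_p_sketch(probs: torch.Tensor, top_p: float) -> torch.Tensor:
    """Nucleus sampling: keep the smallest set of tokens whose cumulative
    probability exceeds top_p, renormalize, and sample from that set.
    probs: (B, vocab_size) probabilities, already softmaxed / temperature-scaled."""
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Zero out tokens that lie entirely outside the top_p nucleus.
    outside_nucleus = cumulative - sorted_probs > top_p
    sorted_probs[outside_nucleus] = 0.0
    sorted_probs.div_(sorted_probs.sum(dim=-1, keepdim=True))
    next_sorted = torch.multinomial(sorted_probs, num_samples=1)
    return torch.gather(sorted_idx, -1, next_sorted)  # (B, 1) sampled token ids
```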
@@ -959,18 +934,16 @@ def print0(*args, **kwargs):
print(*args, **kwargs)

if __name__ == "__main__":
import time
import argparse
print0(f"Running pytorch {torch.version.__version__}")

# default settings will overfit a tiny batch of data
# and save model weights and debug state to disk on the first iteration
parser = argparse.ArgumentParser()
parser.add_argument("--use_hf", type=int, default=1, help="use HuggingFace (default) or use Meta's model")
parser.add_argument("--ckpt_dir", type=str, default=None, help="path to llama3 model checkpoint")
parser.add_argument("--tokenizer_path", type=str, default=None, help="path to llama3 tokenizer")
parser.add_argument("--ckpt_dir", type=str, default=None, help="path to llama3 model checkpoint (needed if use_hf=0)")
parser.add_argument("--tokenizer_path", type=str, default=None, help="path to llama3 tokenizer (needed if use_hf=0)")
# file system input / output
parser.add_argument("--input_bin", type=str, default="dev/data/tinystories/TinyStories_val.bin", help="input .bin to train on")
parser.add_argument("--input_bin", type=str, default="dev/data/tinyshakespeare/tiny_shakespeare_val.bin", help="input .bin to train on")
parser.add_argument("--input_val_bin", type=str, default="", help="input .bin to eval validation loss on")
parser.add_argument("--output_dir", type=str, default="", help="output directory to which to write logs and checkpoints")
parser.add_argument("--model", type=str, default="meta-llama/Meta-Llama-3.1-8B", help="chose the llama model")
@@ -982,7 +955,7 @@ def print0(*args, **kwargs):
parser.add_argument("--num_iterations", type=int, default=10, help="number of iterations to run")
parser.add_argument("--inference_only", type=int, default=0, help="only run inference")
# optimization
parser.add_argument("--learning_rate", type=float, default=1e-4, help="learning rate warmup iterations")
parser.add_argument("--learning_rate", type=float, default=1e-5, help="learning rate warmup iterations")
parser.add_argument("--warmup_iters", type=int, default=0, help="learning rate warmup iterations")
parser.add_argument("--learning_rate_decay_frac", type=float, default=1.0, help="learning rate warmup iterations")
parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay")
@@ -998,7 +971,6 @@ def print0(*args, **kwargs):
# memory management
parser.add_argument("--device", type=str, default="", help="by default we autodetect, or set it here")
parser.add_argument("--compile", type=int, default=0, help="torch.compile the model")
parser.add_argument("--flash", type=int, default=0, help="use flash attention")
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|float16|bfloat16")
parser.add_argument("--zero_stage", type=int, default=0, help="zero redundancy optimizer stage (0/1/2/3)")
# python -> C bridge
@@ -1052,9 +1024,9 @@ def print0(*args, **kwargs):
device = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
device = "mps"
print(f"using device: {device}")
device_type = 'cuda' if 'cuda' in device else 'cpu'
assert device_type in {'cuda'} # we need to load LLaMA as bf16 on CUDA
assert device_type in {'cuda'}, "GPU required to run LLaMA 3" # we need to load LLaMA as bf16 on CUDA
print(f"using device: {device}")

# calculate gradient accumulation from the desired total batch size and the current run configuration
tokens_per_fwdbwd = B * T * ddp_world_size
@@ -1077,16 +1049,12 @@ def print0(*args, **kwargs):
if args.tensorcores:
torch.set_float32_matmul_precision('high')

# turn on/off flash attention
assert args.flash in {0, 1}
FLASH = args.flash

# init the model
assert args.ckpt_dir is not None and os.path.exists(args.ckpt_dir), f"llama3 ckpt dir {args.ckpt_dir} does not exist"
assert args.tokenizer_path is not None and os.path.exists(args.tokenizer_path), f"llama3 tokenizer path {args.tokenizer_path} does not exist"
if args.use_hf:
model = LLaMA.from_pretrained_llama3_hf(args.model)
else: # use Meta's checkpoint
assert args.ckpt_dir is not None and os.path.exists(args.ckpt_dir), f"llama3 ckpt dir {args.ckpt_dir} does not exist"
assert args.tokenizer_path is not None and os.path.exists(args.tokenizer_path), f"llama3 tokenizer path {args.tokenizer_path} does not exist"
model = LLaMA.from_pretrained_llama3_meta(args.ckpt_dir, args.tokenizer_path)

model.train()
@@ -1201,7 +1169,7 @@ def get_lr(it):
else: # Meta
prompt_tokens = [model.tokenizer.encode(x, bos=True, eos=False) for x in prompts]

generation_tokens, _ = model.generate(prompt_tokens, max_gen_len=64, temperature=0.6, top_p=0.9, logprobs=False, echo=False)
generation_tokens = model.generate(prompt_tokens, max_gen_len=64, temperature=0.6, top_p=0.9, echo=False)
results = [{"generation": model.tokenizer.decode(t)} for t in generation_tokens]
for prompt, result in zip(prompts, results):
print(prompt, end="")
