Skip to content

Commit

Permalink
first commit of just the reference cpu fp32 gpt2 training
Browse files Browse the repository at this point in the history
  • Loading branch information
karpathy committed Apr 8, 2024
0 parents commit e8e1628
Show file tree
Hide file tree
Showing 8 changed files with 2,042 additions and 0 deletions.
43 changes: 43 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
CC = clang
CFLAGS = -O3 -Ofast
LDFLAGS =
LDLIBS = -lm
INCLUDES =

# Check if OpenMP is available
# This is done by attempting to compile an empty file with OpenMP flags
# OpenMP makes the code a lot faster so I advise installing it
# e.g. on MacOS: brew install libomp
# e.g. on Ubuntu: sudo apt-get install libomp-dev
# later, run the program by prepending the number of threads, e.g.: OMP_NUM_THREADS=8 ./gpt2
ifeq ($(shell echo | $(CC) -Xpreprocessor -fopenmp -x c -E - > /dev/null 2>&1; echo $$?), 0)
ifeq ($(shell uname), Darwin)
# macOS with Homebrew
CFLAGS += -Xclang -fopenmp
LDFLAGS += -L/opt/homebrew/opt/libomp/lib
LDLIBS += -lomp
INCLUDES += -I/opt/homebrew/opt/libomp/include
else
# Ubuntu or other Linux distributions
CFLAGS += -fopenmp
LDLIBS += -lgomp
endif
$(info NICE Compiling with OpenMP support)
else
$(warning OOPS Compiling without OpenMP support)
endif

# PHONY means these targets will always be executed
.PHONY: all train_gpt2 test_gpt2

# default target is all
all: train_gpt2 test_gpt2

train_gpt2: train_gpt2.c
$(CC) $(CFLAGS) $(INCLUDES) $(LDFLAGS) $< $(LDLIBS) -o $@

test_gpt2: test_gpt2.c
$(CC) $(CFLAGS) $(INCLUDES) $(LDFLAGS) $< $(LDLIBS) -o $@

clean:
rm -f train_gpt2 test_gpt2
113 changes: 113 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# llm.c

LLM training in simple, pure C/CUDA. There is no need for 245MB of PyTorch or 107MB of cPython. For example, training GPT-2 (CPU, fp32) is ~1,000 lines of clean code in a single file. It compiles and runs instantly, and exactly matches the PyTorch reference implementation. I chose GPT-2 as the first working example because it is the grand-daddy of LLMs, the first time the modern stack was put together.

Currently, I am working on:

- direct CUDA implementation, which will be significantly faster and probably come close to PyTorch.
- speed up the CPU version with SIMD instructions, AVX2 on x86 / NEON on ARM (e.g. Apple Silicon).
- more modern architectures, e.g. Llama2, Gemma, etc.

For the repo, I'd like to maintain both clean, simple reference implementations alongside a also lot more optimized versions that can come close to PyTorch, but in a tiny fraction of the code and dependencies.

## quick start

Download and tokenize a dataset. The [tinyshakespeare](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt) dataset is the fastest to download and tokenize:

```bash
python prepro_tinyshakespeare.py
```

This prints:

```
Saved 32768 tokens to data/tiny_shakespeare_val.bin
Saved 305260 tokens to data/tiny_shakespeare_train.bin
```

The .bin files are raw byte streams of int32 numbers indicating the token ids with the GPT-2 tokenizer. Alternatively you could also tokenize the [TinyStories](https://huggingface.co/datasets/roneneldan/TinyStories) dataset with `prepro_tinystories.py`.

In principle we'd be ready to the train the model right here. However the baseline CPU/fp32 reference code is so inefficient that it's not practical to train these models from scratch yet. Instead, we initialize with the GPT-2 weights released by OpenAI and just do finetuning. For that, we have to download the GPT-2 weights and save them as a checkpoint we can load in C:

```bash
python train_gpt2.py
```

You'll recognize this code from nanoGPT as a simple GPT-2 reference implementation in PyTorch. This script will download the GPT-2 (124M) model, overfit a single batch of data for 10 iterations, run a few steps of generation, and most importantly it will save two files: 1) the `gpt2_124M.bin` file that contains the raw model weights for loading in C, and `gpt2_124M_debug_state.bin`, which also contains more debug state: the inputs, targets, logits and loss. This is very useful for debugging C code, for unit testing, and making sure we're exactly matching the PyTorch reference implementation. For now all we care about are the model weights in `gpt2_124M.bin`. We can now initialize with them and train in raw C. First compile the code:

```bash
make train_gpt2
```

You can have a look inside the `Makefile` and its comments. It will try to autodetect if OpenMP is available on your system, which is very helpful for speeding up the code at very low cost of code complexity. Once `train_gpt2` is compiled, you can run it:

```bash
OMP_NUM_THREADS=8 ./train_gpt2
```

You should tune the number of threads depending on how many cores your CPU has. The program will load the model weights, the tokens, it will run a finetuning loop for a few iterations with Adam lr 1e-4, and then generate a sample from the model. The file is (I think) very readable and you should have a look. Simply, there are implementations for the forward and backward pass of all the layers, and they get strung together into a large, manual, forward/backward/update loop. The output looks like this on my MacBook Pro (Apple Silicon M3 Max):

```
[GPT-2]
max_seq_len: 1024
vocab_size: 50257
num_layers: 12
num_heads: 12
channels: 768
num_parameters: 124439808
train dataset num_batches: 1192
val dataset num_batches: 128
num_activations: 73323776
val loss 5.252026
step 0: train loss 5.356189 (took 1452.121000 ms)
step 1: train loss 4.301069 (took 1288.673000 ms)
step 2: train loss 4.623322 (took 1369.394000 ms)
step 3: train loss 4.600470 (took 1290.761000 ms)
... (trunctated) ...
step 39: train loss 3.970751 (took 1323.779000 ms)
val loss 4.107781
generated: 50256 16773 18162 21986 11 198 13681 263 23875 198 3152 262 11773 2910 198 1169 6002 6386 2583 286 262 11858 198 20424 428 3135 7596 995 3675 13 198 40 481 407 736 17903 11 329 703 6029 706 4082 198 42826 1028 1128 633 263 11 198 10594 407 198 2704 454 680 1028 262 1027 28860 286 198 3237 323
step 40: train loss 4.377757 (took 1366.368000 ms)
```

The generation just gives you the token ids for now, which we have to decode back to text. We can implement this in C quite easily also, because decoding is very straight-forward, it's just string chunk lookups and prints. For now we can use tiktoken:

```python
import tiktoken
enc = tiktoken.get_encoding("gpt2")
print(enc.decode(list(map(int, "50256 16773 18162 21986 11 198 13681 263 23875 198 3152 262 11773 2910 198 1169 6002 6386 2583 286 262 11858 198 20424 428 3135 7596 995 3675 13 198 40 481 407 736 17903 11 329 703 6029 706 4082 198 42826 1028 1128 633 263 11 198 10594 407 198 2704 454 680 1028 262 1027 28860 286 198 3237 323".split()))))
```

which prints:

```
<|endoftext|>Come Running Away,
Greater conquer
With the Imperial blood
the heaviest host of the gods
into this wondrous world beyond.
I will not back thee, for how sweet after birth
Netflix against repounder,
will not
flourish against the earlocks of
Allay
```

I like how Netflix comes up, it's clear that the shadow of the training past is still lurking in the model. I did not attempt to tune the finetuning hyperparameters so it's quite likely this can be improved quite a bit, most likely especially if one was to train a bit longer.

## test

I am also attaching a simple unit test for making sure our C code agrees with the PyTorch code. Compile and run with:

```
make test_gpt2
./test_gpt2
```

This now loads the `gpt2_124M_debug_state.bin` file, runs a forward pass, compares the logits and loss with the PyTorch reference implementation, then it does 10 iterations of training with Adam and makes sure the losses match PyTorch.

```
## license
MIT
82 changes: 82 additions & 0 deletions prepro_tinyshakespeare.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
"""
Downloads and tokenizes the TinyShakespeare dataset.
- The download is from Github.
- The tokenization is GPT-2 tokenizer with tiktoken
The output is written to a newly created data/ folder.
The script prints:
Saved 32768 tokens to data/tiny_shakespeare_val.bin
Saved 305260 tokens to data/tiny_shakespeare_train.bin
And runs in a few seconds depending on your internet
connection and computer. The .bin files are raw byte
streams of int32 numbers indicating the token ids.
"""

import os
import requests
from tqdm import tqdm

import tiktoken
import numpy as np

DATA_CACHE_DIR = "data"
enc = tiktoken.get_encoding("gpt2")
encode = lambda s: enc.encode(s, allowed_special={'<|endoftext|>'})

def download_file(url: str, fname: str, chunk_size=1024):
"""Helper function to download a file from a given url"""
resp = requests.get(url, stream=True)
total = int(resp.headers.get("content-length", 0))
with open(fname, "wb") as file, tqdm(
desc=fname,
total=total,
unit="iB",
unit_scale=True,
unit_divisor=1024,
) as bar:
for data in resp.iter_content(chunk_size=chunk_size):
size = file.write(data)
bar.update(size)

def download():
"""Downloads the TinyShakespeare dataset to DATA_CACHE_DIR"""
os.makedirs(DATA_CACHE_DIR, exist_ok=True)

# download the TinyStories dataset, unless it's already downloaded
data_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
data_filename = os.path.join(DATA_CACHE_DIR, "tiny_shakespeare.txt")
if not os.path.exists(data_filename):
print(f"Downloading {data_url} to {data_filename}...")
download_file(data_url, data_filename)
else:
print(f"{data_filename} already exists, skipping download...")

def tokenize():
eot = enc._special_tokens['<|endoftext|>'] # end of text token
data_filename = os.path.join(DATA_CACHE_DIR, "tiny_shakespeare.txt")
text = open(data_filename, 'r').read()
# let's treat every person's statement in the dialog as a separate document
text = "<|endoftext|>" + text
text = text.replace('\n\n', '\n\n<|endoftext|>')
# encode the text
tokens = encode(text)
tokens_np = np.array(tokens, dtype=np.int32)
# let's take the first 32,768 tokens as the validation split (~10%)
val_tokens_np = tokens_np[:32768]
train_tokens_np = tokens_np[32768:]
# save to file
val_filename = os.path.join(DATA_CACHE_DIR, "tiny_shakespeare_val.bin")
train_filename = os.path.join(DATA_CACHE_DIR, "tiny_shakespeare_train.bin")
with open(val_filename, "wb") as f:
f.write(val_tokens_np.tobytes())
with open(train_filename, "wb") as f:
f.write(train_tokens_np.tobytes())
# prints
print(f"Saved {len(val_tokens_np)} tokens to {val_filename}")
print(f"Saved {len(train_tokens_np)} tokens to {train_filename}")

if __name__ == "__main__":
download()
tokenize()
124 changes: 124 additions & 0 deletions prepro_tinystories.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
"""
Downloads and tokenizes the TinyStories dataset.
- The download is from HuggingFace datasets.
- The tokenization is GPT-2 tokenizer with tiktoken
The output is written to a newly created data/ folder.
The script prints:
Tokenizing val split...
Saved 19043638 tokens to data/TinyStories_val.bin
Tokenizing train split...
Saved 925653391 tokens to data/TinyStories_train.bin
And runs in 1-2 minutes two depending on your internet
connection and computer. The .bin files are raw byte
streams of int32 numbers indicating the token ids.
"""

import os
import glob
import json
import random
import requests
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed

import tiktoken
import numpy as np

DATA_CACHE_DIR = "data"
enc = tiktoken.get_encoding("gpt2")
encode = lambda s: enc.encode_ordinary(s)

def download_file(url: str, fname: str, chunk_size=1024):
"""Helper function to download a file from a given url"""
resp = requests.get(url, stream=True)
total = int(resp.headers.get("content-length", 0))
with open(fname, "wb") as file, tqdm(
desc=fname,
total=total,
unit="iB",
unit_scale=True,
unit_divisor=1024,
) as bar:
for data in resp.iter_content(chunk_size=chunk_size):
size = file.write(data)
bar.update(size)

def download():
"""Downloads the TinyStories dataset to DATA_CACHE_DIR"""
os.makedirs(DATA_CACHE_DIR, exist_ok=True)

# download the TinyStories dataset, unless it's already downloaded
data_url = "https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories_all_data.tar.gz"
data_filename = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data.tar.gz")
if not os.path.exists(data_filename):
print(f"Downloading {data_url} to {data_filename}...")
download_file(data_url, data_filename)
else:
print(f"{data_filename} already exists, skipping download...")

# unpack the tar.gz file into all the data shards (json files)
data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
if not os.path.exists(data_dir):
os.makedirs(data_dir, exist_ok=True)
print(f"Unpacking {data_filename}...")
os.system(f"tar -xzf {data_filename} -C {data_dir}")
else:
print(f"{data_dir} already exists, skipping unpacking...")

# print a single example just for debugging and such
shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json")))
with open(shard_filenames[0], "r") as f:
data = json.load(f)
print("Download done.")
print(f"Number of shards: {len(shard_filenames)}")
#print(f"Example story:\n{data[0]}")

def process_shard(shard_index, shard_filename):
with open(shard_filename, "r") as f:
data = json.load(f)
eot = enc._special_tokens['<|endoftext|>'] # end of text token
rng = random.Random(1337 + shard_index)
rng.shuffle(data)
all_tokens = []
for example in data:
text = example["story"]
text = text.strip() # get rid of leading/trailing whitespace
tokens = encode(text)
all_tokens.append(eot)
all_tokens.extend(tokens)
return all_tokens

def tokenize():
# shard 0 will be the val split, rest is train
data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json")))
val_shards = [shard_filenames[0]]
train_shards = shard_filenames[1:]
for split_name, split_shards in [("val", val_shards), ("train", train_shards)]:

print(f"Tokenizing {split_name} split...")
all_tokens = []
with ProcessPoolExecutor() as executor:
futures = [executor.submit(process_shard, shard_index, shard_filename)
for shard_index, shard_filename in enumerate(split_shards)]
for future in as_completed(futures):
all_tokens.extend(future.result())

all_tokens_np = np.array(all_tokens, dtype=np.int32)
split_filename = os.path.join(DATA_CACHE_DIR, f"TinyStories_{split_name}.bin")
with open(split_filename, "wb") as f:
f.write(all_tokens_np.tobytes())
print(f"Saved {len(all_tokens_np)} tokens to {split_filename}")

if __name__ == "__main__":
download()
tokenize()

# Prints:
# Tokenizing val split...
# Saved 19043638 tokens to data/TinyStories_val.bin
# Tokenizing train split...
# Saved 925653391 tokens to data/TinyStories_train.bin
3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
numpy
torch
tiktoken
Loading

0 comments on commit e8e1628

Please sign in to comment.