
Commit

Merge branch 'karpathy:master' into master
dagelf authored Apr 25, 2024
2 parents f368916 + bb56144 commit d195e1c
Showing 4 changed files with 141 additions and 42 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -35,7 +35,7 @@ jobs:
run: python prepro_tinyshakespeare.py

- name: Train model
run: python train_gpt2.py
run: python train_gpt2.py --device=cpu

- name: Compile training and testing program
run: make test_gpt2 train_gpt2
11 changes: 10 additions & 1 deletion README.md
@@ -135,7 +135,9 @@ I attached a very small tutorial here, in [doc/layernorm/layernorm.md](doc/layer

The full training loop is also implemented in pure CUDA in one file, but optimizations of the kernels are ongoing. Currently, we roughly match the speed of PyTorch. The way we organize code is that we have a growing collection of kernels of increasing complexity in the `dev/cuda` folder, see [dev/cuda/README.md](dev/cuda/README.md). We then copy paste the best kernels into the main training loop in the single training file `train_gpt2cu.cu`.

**WIP alert, April 23**. We merged the first version of mixed precision training code. I checkpointed the fp32 version to separate files that include `_fp32` in their filename, and would like to preserve this version in the root of the repo because it 1) doesn't require the most up-to-date CUDA, is a lot more likely to compile, and is more portable, and 2) it is a lot simpler and acts as a reference. The "mainline" development of the CUDA version will from here on move mostly to the [train_gpt2.cu](train_gpt2.cu) file, which includes mixed precision. In the descriptions below I will default to the fp32 version for now because it is currently more portable and stable, then at the end I will cover the new mixed precision version.
**WIP alert, April 23**. We merged the first version of mixed precision training code. I checkpointed the fp32 version to separate files that include `_fp32` in their filename, and would like to preserve this version in the root of the repo because it 1) doesn't require the most up-to-date CUDA, is a lot more likely to compile, and is more portable, and 2) it is a lot simpler and acts as a reference. In fact, we'd like to diverge the fp32 version in the direction of being pure CUDA (e.g. not even calling cuBLAS by default), to be used as an educational reference, maybe even the kernel of a course on CUDA. The "mainline" development concerned with speed will from here on move to the [train_gpt2.cu](train_gpt2.cu) file, which includes mixed precision training.

In the descriptions below I will default to the fp32 version for now because it is currently more portable and stable, then at the end I will cover the new mixed precision version.

**Correctness**. First, we can do 10 iterations of training and verify that our code exactly matches and reproduces the numbers from PyTorch:

@@ -269,6 +271,13 @@ Lastly, I will be a lot more sensitive to complexity in the root folder of the p

- Metal
- [llm.metal](https://github.com/regrettable-username/llm.metal) by @[regrettable-username](https://github.com/regrettable-username): LLM training in simple, raw C/Metal Shading Language

- Zig
- [llm.zig](https://github.com/Saimirbaci/llm.zig/) by @[saimirbaci](https://github.com/Saimirbaci): a Zig port of this project

- Go
- [llm.go](https://github.com/joshcarp/llm.go) by @[joshcarp](https://github.com/joshcarp): a Go port of this project

## discussions

Ways of organizing development:
8 changes: 7 additions & 1 deletion train_gpt2.cu
@@ -96,6 +96,7 @@ void cublasCheck(cublasStatus_t status, const char *file, int line)
#define cublasCheck(status) { cublasCheck((status), __FILE__, __LINE__); }

// GPU helper functions for atomicAdd on smaller than 32-bit types
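// (the bf16/half variants below align the address down to the containing 4-byte word and issue a paired 2-element atomicAdd, adding zero to the neighbouring element)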
#ifdef ENABLE_BF16
__device__ void atomicAddX(__nv_bfloat16* addr, __nv_bfloat16 val) {
uintptr_t ptr_val = reinterpret_cast<uintptr_t>(addr);
__nv_bfloat162* ptr_bf16 = reinterpret_cast<__nv_bfloat162*>(ptr_val & ~uintptr_t(0x3));
@@ -105,6 +106,9 @@ __device__ void atomicAddX(__nv_bfloat16* addr, __nv_bfloat16 val) {
: __halves2bfloat162(val, __ushort_as_bfloat16(0));
atomicAdd(ptr_bf16, add_val);
}
#endif

#ifdef ENABLE_FP16
__device__ void atomicAddX(half* addr, half val) {
uintptr_t ptr_val = reinterpret_cast<uintptr_t>(addr);
half2* ptr_fp16 = reinterpret_cast<half2*>(ptr_val & ~uintptr_t(0x3));
@@ -114,6 +118,8 @@ __device__ void atomicAddX(half* addr, half val) {
: __halves2half2(val, __ushort_as_half(0));
atomicAdd(ptr_fp16, add_val);
}
#endif

__device__ void atomicAddX(float* addr, float val) {
atomicAdd(addr, val);
}
@@ -1666,7 +1672,7 @@ void gpt2_forward(GPT2 *model, int* inputs, int* targets, int B, int T) {
fill_in_activation_sizes(model->act_sizes, B, T, model->config);
size_t num_activations = 0;
for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) {
num_activations += model->act_sizes[i] * sizeof(floatX);
num_activations += model->act_sizes[i];
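// note: act_sizes[] holds element counts, so num_activations counts elements, not bytes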
}
model->num_activations = num_activations;
model->acts_memory = malloc_and_point_activations(&model->acts, model->act_sizes);
162 changes: 123 additions & 39 deletions train_gpt2.py
@@ -12,6 +12,7 @@
import os
import math
import struct
from contextlib import nullcontext
from dataclasses import dataclass

import numpy as np
@@ -214,9 +215,19 @@ def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):

# a few utilities for saving params/grads/activations to files for loading in C
def write_fp32(tensor, file):
file.write(tensor.detach().cpu().numpy().astype("float32").tobytes())

def write_tensors(model_tensors, L, file):
t = tensor.detach().cpu().to(torch.float32)
b = t.numpy().tobytes()
file.write(b)

def write_bf16(tensor, file):
t = tensor.detach().cpu().to(torch.bfloat16)
# numpy can't convert bf16 to bytes
# this way below *i think* works, but is SUPER slow or broken
# TODO fix :'(
b = bytes(t.untyped_storage())
file.write(b)
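
Aside: the storage round-trip above can be slow. A possible bit-preserving alternative, sketched here under the assumption that the installed PyTorch supports `Tensor.view(dtype)` reinterpretation (the helper name below is hypothetical, not part of this commit), is to view the bf16 payload as int16 so numpy can serialize it directly:

```python
import torch

def write_bf16_via_int16(tensor, file):
    # reinterpret the bf16 bits as int16 (same 2-byte element size), so numpy
    # can emit the raw bytes without going through untyped_storage()
    t = tensor.detach().cpu().to(torch.bfloat16).contiguous()
    file.write(t.view(torch.int16).numpy().tobytes())
```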

def write_tensors_fp32(model_tensors, L, file):
write_fp32(model_tensors["transformer.wte.weight"], file) # (V, C)
write_fp32(model_tensors["transformer.wpe.weight"], file) # (T, C)
for i in range(L): # (L, C)
@@ -246,25 +257,65 @@ def write_tensors(model_tensors, L, file):
write_fp32(model_tensors["transformer.ln_f.weight"], file) # (C, )
write_fp32(model_tensors["transformer.ln_f.bias"], file) # (C, )

def write_model(model, filename):
def write_tensors_bf16(model_tensors, L, file):
# same as fp32, but note we will re-order the tensors
# because we keep the layernorm in fp32, we place them all at the end
write_bf16(model_tensors["transformer.wte.weight"], file) # (V, C)
write_bf16(model_tensors["transformer.wpe.weight"], file) # (T, C)
for i in range(L): # (L, 3C, C)
write_bf16(model_tensors[f"transformer.h.{i}.attn.c_attn.weight"], file)
for i in range(L): # (L, 3C)
write_bf16(model_tensors[f"transformer.h.{i}.attn.c_attn.bias"], file)
for i in range(L): # (L, C, C)
write_bf16(model_tensors[f"transformer.h.{i}.attn.c_proj.weight"], file)
for i in range(L): # (L, C)
write_bf16(model_tensors[f"transformer.h.{i}.attn.c_proj.bias"], file)
for i in range(L): # (L, 4C, C)
write_bf16(model_tensors[f"transformer.h.{i}.mlp.c_fc.weight"], file)
for i in range(L): # (L, 4C)
write_bf16(model_tensors[f"transformer.h.{i}.mlp.c_fc.bias"], file)
for i in range(L): # (L, C, 4C)
write_bf16(model_tensors[f"transformer.h.{i}.mlp.c_proj.weight"], file)
for i in range(L): # (L, C)
write_bf16(model_tensors[f"transformer.h.{i}.mlp.c_proj.bias"], file)
# LayerNorms are at the end and kept in fp32
for i in range(L): # (L, C)
write_fp32(model_tensors[f"transformer.h.{i}.ln_1.weight"], file)
for i in range(L): # (L, C)
write_fp32(model_tensors[f"transformer.h.{i}.ln_1.bias"], file)
for i in range(L): # (L, C)
write_fp32(model_tensors[f"transformer.h.{i}.ln_2.weight"], file)
for i in range(L): # (L, C)
write_fp32(model_tensors[f"transformer.h.{i}.ln_2.bias"], file)
write_fp32(model_tensors["transformer.ln_f.weight"], file) # (C, )
write_fp32(model_tensors["transformer.ln_f.bias"], file) # (C, )

def write_model(model, filename, dtype):
# everything we need to instantiate the model
# 1) header is: version int, GPTConfig ints, padding to 1024 bytes
assert dtype in {"float32", "bfloat16"} # float16 todo maybe later
version = {
"float32": 1,
"bfloat16": 2,
}[dtype]
header = torch.zeros(256, dtype=torch.int32)
header[0] = 20240326 # magic
header[1] = 1 # checkpoint version = 1
header[1] = version # checkpoint version
header[2] = model.config.block_size
header[3] = model.config.vocab_size
header[4] = model.config.n_layer
header[5] = model.config.n_head
header[6] = model.config.n_embd
# 2) the parameters on CPU are next
# 2) the parameters follow the header
params = {name: param.cpu() for name, param in model.named_parameters()}
# now write
with open(filename, "wb") as file:
# header
# write header
file.write(header.numpy().tobytes())
# model parameters
write_tensors(params, model.config.n_layer, file)
# write params
if dtype == "float32":
write_tensors_fp32(params, model.config.n_layer, file)
elif dtype == "bfloat16":
write_tensors_bf16(params, model.config.n_layer, file)
print(f"wrote {filename}")

def write_state(model, x, y, logits, loss, filename):
@@ -289,7 +340,7 @@ def write_state(model, x, y, logits, loss, filename):
# loss (single float, result of the cross entropy loss)
write_fp32(loss.cpu(), file)
# gradients
write_tensors(grads, model.config.n_layer, file)
write_tensors_fp32(grads, model.config.n_layer, file)
print(f"wrote {filename}")

def write_tokenizer(enc, filename):
@@ -321,6 +372,8 @@ def write_tokenizer(enc, filename):
parser = argparse.ArgumentParser()
parser.add_argument("--write_tensors", type=int, default=1, help="write tensors to disk")
parser.add_argument("--inference_only", type=int, default=0, help="only run inference")
parser.add_argument("--dtype", type=str, default="float32", help="float32|float16|bfloat16")
parser.add_argument("--device", type=str, default="", help="by default we autodetect, or set it here")
parser.add_argument("--compile", type=int, default=0, help="torch.compile the model")
parser.add_argument("--tensorcores", type=int, default=0, help="use tensorcores")
parser.add_argument("--num_iterations", type=int, default=10, help="number of iterations to run")
@@ -329,39 +382,53 @@ def write_tokenizer(enc, filename):
args = parser.parse_args()
B, T = args.batch_size, args.sequence_length
assert 1 <= T <= 1024
assert args.dtype in {"float32", "float16", "bfloat16"}

# select a reasonable device to run on
device = "cpu"
if torch.cuda.is_available():
device = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
device = "mps"
# select the device
if args.device:
device = args.device
else:
# attempt to autodetect the device
device = "cpu"
if torch.cuda.is_available():
device = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
device = "mps"
print(f"using device: {device}")

# set up a context manager following the desired dtype and device
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[args.dtype]
ctx = torch.amp.autocast(device_type="cuda", dtype=ptdtype) if device == "cuda" else nullcontext()

# seed the random number generators
torch.manual_seed(42)
if torch.cuda.is_available():
torch.cuda.manual_seed(42)

# init the tokenizer
# set the torch precision mode to use TensorFloat32 (TF32) for matmuls
# docs https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html
if args.tensorcores:
torch.set_float32_matmul_precision('high')

# init (and write) the tokenizer
enc = tiktoken.get_encoding("gpt2")
encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"})
decode = lambda l: enc.decode(l)

write_tokenizer(enc, "gpt2_tokenizer.bin")

if args.tensorcores:
torch.set_float32_matmul_precision('high')

# load the GPT-2 model weights
model = GPT.from_pretrained("gpt2")
model.train()
model.to(device)
if args.compile:
config.coordinate_descent_tuning = True # suggested by @Chillee
if hasattr(config, "coordinate_descent_tuning"):
config.coordinate_descent_tuning = True # suggested by @Chillee
print("compiling the model...")
model = torch.compile(model)

# -------------------------------------------------------------------------
# data loading related: long but it's just to get a single batch of data

# load the tokens
# prefer to use tiny_shakespeare if it's available, otherwise use tiny_stories
# we're using val instead of train split just because it is smaller/faster
@@ -392,45 +459,62 @@ def get_batch():
if i + B*T + 1 >= len(tokens):
i = 0 # in prod we'd want to randomize the start point a bit

# forward backward for a few iterations
# fetch one batch of data, which we will overfit to
data_iter = iter(get_batch())
x, y = next(data_iter) # we'll overfit this batch below

# -------------------------------------------------------------------------
# STAGE 1: weights / state logging for C to load later

# do one forward pass to generate ground truth for our C tests
if not args.inference_only and args.write_tensors:
logits, loss = model(x, y)
loss.backward()
write_model(model, "gpt2_124M.bin")
# save model params, in both float32 and bfloat16
write_model(model, "gpt2_124M.bin", dtype="float32")
# write_model(model, "gpt2_124M_bf16.bin", dtype="bfloat16")
# save x, y, logits, loss, and parameter gradients, for debugging C
# always store these in fp32 to have an accurate reference (?)
write_state(model, x, y, logits, loss, "gpt2_124M_debug_state.bin")

use_fused = device == "cuda" # only works on CUDA (?)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, fused=use_fused)
timings = []
# -------------------------------------------------------------------------
# STAGE 2: training loop to get timings

# init the optimizer
adam_use_fused = device == "cuda" # only works on CUDA (?)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, fused=adam_use_fused)

if device == "cuda":
torch.cuda.reset_peak_memory_stats()
timings = []
for i in range(args.num_iterations):
t0 = time.time()
logits, loss = model(x, y)
if not args.inference_only:
optimizer.zero_grad()
with ctx:
logits, loss = model(x, y)
del logits
loss.backward()
optimizer.step()
if not args.inference_only:
optimizer.zero_grad()
loss.backward()
optimizer.step()
# wait on the CPU for all device work to end so we get accurate per-iteration timings below
if device == "mps":
torch.mps.synchronize()
elif device == "cuda":
torch.cuda.synchronize()
# time and print
t1 = time.time()
if i > args.num_iterations - 20:
# the 0th iteration is often an outlier (much slower) => skip logging it
if i > 0 and i > args.num_iterations - 20:
timings.append(t1-t0)
print(f"iteration {i}, loss: {loss.item()}, time: {(t1-t0)*1000:.3f}ms")

if len(timings) > 20:
print(f"final 20 iters avg: {np.mean(timings[-20:])*1000:.3f}ms")
else:
print(f"final {len(timings)-1} iters avg: {np.mean(timings[1:])*1000:.3f}ms")
# print the average of the last 20 timings, to get something smooth-ish
timings = timings[-20:]
print(f"final {len(timings)} iters avg: {np.mean(timings)*1000:.3f}ms")
print(f"peak memory consumption: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB")

print(f"Peak memory consumption: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB")
# -------------------------------------------------------------------------
# STAGE 3: Few steps of inference

# before we end, let's also do one round of inference
# we'll kick off the generation with "<|endoftext|>", which designates the start of a new sequence
