diff --git a/Makefile b/Makefile
index 06f7d960b..38156f2b2 100644
--- a/Makefile
+++ b/Makefile
@@ -66,7 +66,7 @@ else
 endif
 
 # PHONY means these targets will always be executed
-.PHONY: all train_gpt2 test_gpt2 train_gpt2cu test_gpt2cu
+.PHONY: all train_gpt2 test_gpt2 train_gpt2cu test_gpt2cu train_gpt2fp32cu
 
 # Add targets
 TARGETS = train_gpt2 test_gpt2
@@ -87,10 +87,12 @@ train_gpt2: train_gpt2.c
 test_gpt2: test_gpt2.c
 	$(CC) $(CFLAGS) $(INCLUDES) $(LDFLAGS) $< $(LDLIBS) -o $@
 
-# possibly may want to disable warnings? e.g. append -Xcompiler -Wno-unused-result
 train_gpt2cu: train_gpt2.cu
 	$(NVCC) $(NVCC_FLAGS) $< $(NVCC_LDFLAGS) -o $@
 
+train_gpt2fp32cu: train_gpt2_fp32.cu
+	$(NVCC) $(NVCC_FLAGS) $< $(NVCC_LDFLAGS) -o $@
+
 test_gpt2cu: test_gpt2.cu
 	$(NVCC) $(NVCC_FLAGS) $< $(NVCC_LDFLAGS) -o $@
 
@@ -98,5 +100,4 @@ profile_gpt2cu: profile_gpt2.cu
 	$(NVCC) $(NVCC_FLAGS) -lineinfo $< $(NVCC_LDFLAGS) -o $@
 
 clean:
-	rm -f train_gpt2 test_gpt2 train_gpt2cu test_gpt2cu
-
+	rm -f train_gpt2 test_gpt2 train_gpt2cu train_gpt2fp32cu test_gpt2cu
diff --git a/train_gpt2_fp32.cu b/train_gpt2_fp32.cu
new file mode 100644
index 000000000..c2027e82b
--- /dev/null
+++ b/train_gpt2_fp32.cu
@@ -0,0 +1,2097 @@
+/*
+GPT-2 Transformer Neural Net trained in raw CUDA
+Non-trivial notes to be aware of:
+
+We are being clever in the backward pass to conserve memory.
+In particular, all parameters use a += in the backward pass, so we
+can later do gradient accumulation. But all activations have = instead of +=
+because these are faster (just read, no write). This is okay for all activations
+except for those in the residual stream, where the gradients have to add. We make
+sure that those parts work out ok and that we do a += as necessary. E.g.,
+the layernorms are connected to the residuals so we += in layernorm backward.
+*/
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <ctype.h>
+#include <stdint.h>
+#include <math.h>
+#include <time.h>
+#include <assert.h>
+#include <float.h>
+#include <string.h>
+#include <unistd.h>
+#include <cuda_runtime.h>
+#include <cublas_v2.h>
+#include <cublasLt.h>
+#include <cooperative_groups.h>
+#include <cooperative_groups/reduce.h>
+
+// ----------------------------------------------------------------------------
+// CUDA utils
+
+// convenience macro for calculating grid/block dimensions for kernels
+#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
+
+// CUDA error checking
+void cudaCheck(cudaError_t error, const char *file, int line) {
+    if (error != cudaSuccess) {
+        printf("[CUDA ERROR] at file %s:%d:\n%s\n", file, line,
+               cudaGetErrorString(error));
+        exit(EXIT_FAILURE);
+    }
+};
+#define cudaCheck(err) (cudaCheck(err, __FILE__, __LINE__))
+
+// cuBLAS error checking
+void cublasCheck(cublasStatus_t status, const char *file, int line)
+{
+    if (status != CUBLAS_STATUS_SUCCESS) {
+        printf("[cuBLAS ERROR]: %d %s %d\n", status, file, line);
+        exit(EXIT_FAILURE);
+    }
+}
+#define cublasCheck(status) { cublasCheck((status), __FILE__, __LINE__); }
+
+// cuBLAS workspace.
Hardcoding to 32MiB but only Hopper needs 32, for others 4 is OK +static size_t cublaslt_workspace_size = 32 * 1024 * 1024; +static void* cublaslt_workspace = NULL; +static cublasComputeType_t cublas_compute_type; +cublasHandle_t cublas_handle; +cublasLtHandle_t cublaslt_handle; + +namespace cg = cooperative_groups; + +// ---------------------------------------------------------------------------- +// fread convenience utils, with nice handling of error checking using macros +// simple replace fopen, fread, fclose with fopenCheck, freadCheck, fcloseCheck + +FILE *fopen_check(const char *path, const char *mode, const char *file, int line) { + FILE *fp = fopen(path, mode); + if (fp == NULL) { + fprintf(stderr, "Error: Failed to open file '%s' at %s:%d\n", path, file, line); + fprintf(stderr, "Error details:\n"); + fprintf(stderr, " File: %s\n", file); + fprintf(stderr, " Line: %d\n", line); + fprintf(stderr, " Path: %s\n", path); + fprintf(stderr, " Mode: %s\n", mode); + exit(EXIT_FAILURE); + } + return fp; +} + +#define fopenCheck(path, mode) fopen_check(path, mode, __FILE__, __LINE__) + +void fread_check(void *ptr, size_t size, size_t nmemb, FILE *stream, const char *file, int line) { + size_t result = fread(ptr, size, nmemb, stream); + if (result != nmemb) { + if (feof(stream)) { + fprintf(stderr, "Error: Unexpected end of file at %s:%d\n", file, line); + } else if (ferror(stream)) { + fprintf(stderr, "Error: File read error at %s:%d\n", file, line); + } else { + fprintf(stderr, "Error: Partial read at %s:%d. Expected %zu elements, read %zu\n", + file, line, nmemb, result); + } + fprintf(stderr, "Error details:\n"); + fprintf(stderr, " File: %s\n", file); + fprintf(stderr, " Line: %d\n", line); + fprintf(stderr, " Expected elements: %zu\n", nmemb); + fprintf(stderr, " Read elements: %zu\n", result); + exit(EXIT_FAILURE); + } +} + +#define freadCheck(ptr, size, nmemb, stream) fread_check(ptr, size, nmemb, stream, __FILE__, __LINE__) + +void fclose_check(FILE *fp, const char *file, int line) { + if (fclose(fp) != 0) { + fprintf(stderr, "Error: Failed to close file at %s:%d\n", file, line); + fprintf(stderr, "Error details:\n"); + fprintf(stderr, " File: %s\n", file); + fprintf(stderr, " Line: %d\n", line); + exit(EXIT_FAILURE); + } +} + +#define fcloseCheck(fp) fclose_check(fp, __FILE__, __LINE__) + +// ---------------------------------------------------------------------------- +// malloc error-handling wrapper util + +void *malloc_check(size_t size, const char *file, int line) { + void *ptr = malloc(size); + if (ptr == NULL) { + fprintf(stderr, "Error: Memory allocation failed at %s:%d\n", file, line); + fprintf(stderr, "Error details:\n"); + fprintf(stderr, " File: %s\n", file); + fprintf(stderr, " Line: %d\n", line); + fprintf(stderr, " Size: %zu bytes\n", size); + exit(EXIT_FAILURE); + } + return ptr; +} + +#define mallocCheck(size) malloc_check(size, __FILE__, __LINE__) + +// ---------------------------------------------------------------------------- +// all the kernels + +// warp-level reduction for finding the maximum value +__device__ float warpReduceMax(float val) { + for (int offset = 16; offset > 0; offset /= 2) { + val = fmaxf(val, __shfl_down_sync(0xFFFFFFFF, val, offset)); + } + return val; +} + +// warp-level reduction for summing values +__device__ float warpReduceSum(float val) { + for (int offset = 16; offset > 0; offset /= 2) { + val += __shfl_down_sync(0xFFFFFFFF, val, offset); + } + return val; +} + +__global__ void encoder_forward_kernel2(float* out, + int* inp, 
float* wte, float* wpe, + int B, int T, int C) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int N = B * T * C; + + if (idx < N) { + int bt = idx / C; + int b = bt / T; + int t = bt % T; + int c = idx % C; + + int ix = inp[b * T + t]; + + float* out_btc = out + b * T * C + t * C + c; + float* wte_ix = wte + ix * C + c; + float* wpe_tc = wpe + t * C + c; + *out_btc = *wte_ix + *wpe_tc; + } +} + +// really bad naive kernel with atomicAdd +__global__ void encoder_backward_kernel(float* dwte, float* dwpe, + const float* dout, const int* inp, + int B, int T, int C) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int N = B * T * C; + + if (idx < N) { + int bt = idx / C; + int b = bt / T; + int t = bt % T; + int c = idx % C; + + int ix = inp[b * T + t]; + + const float* dout_btc = dout + b * T * C + t * C + c; + float* dwte_ix = dwte + ix * C + c; + float* dwpe_tc = dwpe + t * C + c; + + atomicAdd(dwte_ix, *dout_btc); + atomicAdd(dwpe_tc, *dout_btc); + } +} + +__global__ void layernorm_forward_kernel3(float* __restrict__ out, float* __restrict__ mean, float* __restrict__ rstd, + const float* __restrict__ inp, const float* __restrict__ weight, + const float* __restrict__ bias, int N, int C) { + cg::thread_block block = cg::this_thread_block(); + cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block); + int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank(); + if(idx >= N) { + return; + } + + // the row of input that this group of threads is responsible for + const float* x = inp + idx * C; + + // mean + float sum = 0.0f; + for (int i = warp.thread_rank(); i < C; i += warp.size()) { + sum += x[i]; + } + sum = cg::reduce(warp, sum, cg::plus{}); + float m = sum / C; + if(warp.thread_rank() == 0 && mean != nullptr) { + __stcs(mean + idx, m); + } + + // rstd + sum = 0.0f; + for (int i = warp.thread_rank(); i < C; i += warp.size()) { + float diff = x[i] - m; + sum += diff * diff; + } + sum = cg::reduce(warp, sum, cg::plus{}); + float s = rsqrtf(sum / C + 1e-5f); + if(warp.thread_rank() == 0 && rstd != nullptr) { + __stcs(rstd + idx, s); + } + + // final normalization and scaling by weight/bias + float* o = out + idx * C; + for (int c = warp.thread_rank(); c < C; c += warp.size()) { + // load and store using the .cs "streaming" hint to the compiler, + // indicating that this data will not be reused soon, and can be streamed through the caches + // this allows the threads to get more cache-hits for the (shared) weight and bias parameters + float n = s * (__ldcs(x+c) - m); + __stcs(o+c, n * weight[c] + bias[c]); + } +} + +__global__ void permute_kernel(float* q, float* k, float* v, + const float* inp, + int B, int N, int NH, int d) { + // okay so now, this kernel wants Q,K,V to all be of shape (B, NH, N, d) + // but instead, we have a single tensor QKV (inp) of shape (B, N, 3, NH, d) + int idx = blockIdx.x * blockDim.x + threadIdx.x; + // Q[b][nh_][n][d_] = inp[b][n][0][nh_][d_] + if (idx < B * NH * N * d) { + int b = idx / (NH * N * d); + int rest = idx % (NH * N * d); + int nh_ = rest / (N * d); + rest = rest % (N * d); + int n = rest / d; + int d_ = rest % d; + int inp_idx = (b * N * 3 * NH * d) + (n * 3 * NH * d) + (0 * NH * d) + (nh_ * d) + d_; + q[idx] = __ldcs(&inp[inp_idx]); + k[idx] = __ldcs(&inp[inp_idx + NH * d]); + v[idx] = __ldcs(&inp[inp_idx + 2 * (NH * d)]); + } +} + +__global__ void permute_kernel_backward(float* dinp, + const float* dq, const float* dk, const float* dv, + int B, int N, int NH, int d) { + int idx = blockIdx.x * blockDim.x + 
threadIdx.x;
+    if (idx < B * NH * N * d) {
+        int b = idx / (NH * N * d);
+        int rest = idx % (NH * N * d);
+        int nh_ = rest / (N * d);
+        rest = rest % (N * d);
+        int n = rest / d;
+        int d_ = rest % d;
+
+        int inp_idx = (b * N * 3 * NH * d) + (n * 3 * NH * d) + (0 * NH * d) + (nh_ * d) + d_;
+        dinp[inp_idx] = dq[idx];
+        dinp[inp_idx + NH * d] = dk[idx];
+        dinp[inp_idx + 2 * (NH * d)] = dv[idx];
+    }
+}
+
+__global__ void unpermute_kernel(float* inp, float *out, int B, int N, int NH, int d) {
+    // out has shape (B, nh, N, d) but we need to unpermute it to (B, N, nh, d)
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    // out[b][n][nh_][d_] <- inp[b][nh_][n][d_]
+    if (idx < B * NH * N * d) {
+        int b = idx / (NH * N * d);
+        int rest = idx % (NH * N * d);
+        int nh_ = rest / (N * d);
+        rest = rest % (N * d);
+        int n = rest / d;
+        int d_ = rest % d;
+        int other_idx = (b * NH * N * d) + (n * NH * d) + (nh_ * d) + d_;
+        out[other_idx] = __ldcs(&inp[idx]);
+    }
+}
+
+__global__ void unpermute_kernel_backward(float* dinp, const float *dout, int B, int N, int NH, int d) {
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx < B * NH * N * d) {
+        int b = idx / (NH * N * d);
+        int rest = idx % (NH * N * d);
+        int nh_ = rest / (N * d);
+        rest = rest % (N * d);
+        int n = rest / d;
+        int d_ = rest % d;
+        int other_idx = (b * NH * N * d) + (n * NH * d) + (nh_ * d) + d_;
+        dinp[idx] = dout[other_idx];
+    }
+}
+
+__device__ float& vec_at(float4& vec, int index) {
+    return reinterpret_cast<float*>(&vec)[index];
+}
+
+__device__ float vec_at(const float4& vec, int index) {
+    return reinterpret_cast<const float*>(&vec)[index];
+}
+
+__global__ void softmax_forward_kernel5(float* out, float inv_temperature, const float* inp, int N, int T) {
+    // inp, out shape: (N, T, T), where N = B * NH
+    // fuses the multiplication by scale inside attention
+    // directly autoregressive, so we only compute the lower triangular part
+    // uses the online softmax algorithm
+    assert(T % 4 == 0);
+    cg::thread_block block = cg::this_thread_block();
+    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
+    // micro-optimization: we iterate backwards so that
+    // after the softmax backward operation completes, the cache retains the
+    // part of the matrix close to the upper left corner, which benefits the
+    // matmul operation that immediately follows.
+    // int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank(); // forward order
+    int idx = (gridDim.x - blockIdx.x - 1) * warp.meta_group_size() + warp.meta_group_rank(); // backward order
+    if(idx >= N * T) {
+        return;
+    }
+    int own_pos = idx % T;
+    int pos_by_4 = own_pos / 4;
+
+    // one row of inp, i.e. inp[idx, :] of shape (T,)
+    const float* x = inp + idx * T;
+
+    // not INF, so we don't get NaNs accidentally when subtracting two values.
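+    // For intuition, the loop below implements the online softmax recurrence:
+    // keep a running max m and running sum s, and for each new value v
+    // (with the attention scale folded in via inv_temperature):
+    //     m_new = fmaxf(m, v)
+    //     s_new = s * expf(inv_temperature * (m - m_new)) + expf(inv_temperature * (v - m_new))
+    // so s always equals the sum of expf(inv_temperature * (v_i - m)) over the values seen so far;
+    // after the warp-level max reduction, each lane rescales its partial sum to the global max.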
+ float maxval = -FLT_MAX; + float sumval = 0.0f; + + const float4* x_vec = reinterpret_cast(x); + for (int i = warp.thread_rank(); i < pos_by_4; i += warp.size()) { + float4 v = x_vec[i]; + float old_maxval = maxval; + for(int k = 0; k < 4; ++k) { + maxval = fmaxf(maxval, vec_at(v, k)); + } + sumval *= expf(inv_temperature * (old_maxval - maxval)); + for(int k = 0; k < 4; ++k) { + sumval += expf(inv_temperature * (vec_at(v, k) - maxval)); + } + } + + if(4*pos_by_4 + warp.thread_rank() <= own_pos) { + float old_maxval = maxval; + maxval = fmaxf(maxval, x[4*pos_by_4 + warp.thread_rank()]); + sumval *= expf(inv_temperature * (old_maxval - maxval)); + sumval += expf(inv_temperature * (x[4*pos_by_4 + warp.thread_rank()] - maxval)); + } + + float global_maxval = cg::reduce(warp, maxval, cg::greater{}); + sumval *= expf(inv_temperature * (maxval - global_maxval)); + + float sum = cg::reduce(warp, sumval, cg::plus{}); + float norm = 1.f / sum; + + // divide the whole row by the sum + for (int i = warp.thread_rank(); i <= own_pos; i += warp.size()) { + // recalculation is faster than doing the round-trip through memory. + float ev = expf(inv_temperature * (__ldcs(x + i) - global_maxval)); + __stcs(out + idx * T + i, ev * norm); + } +} + +__global__ void residual_forward_kernel(float* out, float* inp1, float* inp2, int N) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < N) { + out[idx] = __ldcs(&inp1[idx]) + __ldcs(&inp2[idx]); + } +} + +#define GELU_SCALING_FACTOR sqrtf(2.0f / M_PI) +__global__ void gelu_forward_kernel(float* out, const float* inp, int N) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < N) { + float xi = inp[i]; + float cube = 0.044715f * xi * xi * xi; + out[i] = 0.5f * xi * (1.0f + tanhf(GELU_SCALING_FACTOR * (xi + cube))); + } +} + +__global__ void gelu_backward_kernel(float* dinp, const float* inp, const float* dout, const int N) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < N) { + float x = inp[i]; + float cube = 0.044715f * x * x * x; + float tanh_arg = GELU_SCALING_FACTOR * (x + cube); + float tanh_out = tanhf(tanh_arg); + float coshf_out = coshf(tanh_arg); + float sech_out = 1.0f / (coshf_out * coshf_out); + float local_grad = 0.5f * (1.0f + tanh_out) + x * 0.5f * sech_out * GELU_SCALING_FACTOR * (1.0f + 3.0f * 0.044715f * x * x); + dinp[i] = local_grad * dout[i]; + } +} + +__global__ void softmax_forward_kernel7(float* out, const float* inp, int N, int C) { + // out is (N, C) just like inp. Each row of inp will get softmaxed. 
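+    // For reference, each output row y is computed from its input row x as
+    //     y[i] = expf(x[i] - max_j x[j]) / sum_j expf(x[j] - max_j x[j])
+    // subtracting the row max keeps expf() in a safe range and cancels in the ratio,
+    // so the result is unchanged; the kernel does this in three phases:
+    // block-wide max reduction, exp + partial sums, then normalization by the block-wide sum.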
+ // same as kernel4, but optimised for very large Cs with advanced unrolling + + // The trick is to read into a register array (all indices known at compile time) + // and always read UNROLL_FACTOR values to maximise memory level parallelism + // even if we would be out of bounds, we set the index to min(C-1, idx) + // so we just do some unnecessary reads (obviously bad for small C) + // the writes are in a separate loop with a conditional check for out of bounds + // making it separate is necessary to convince the compiler to do the right thing + const int UNROLL_FACTOR = 8; + const int warpsPerBlock = blockDim.x / 32; + + extern __shared__ float shared[]; + int idx = blockIdx.x; + int tid = threadIdx.x; + int warpId = threadIdx.x / 32; // warp index within a block + int laneId = threadIdx.x % 32; // thread index within a warp + + // shared[] must be allocated to have 2 * warpsPerBlock elements + // first half for max values, the second half for sum values + float* maxvals = shared; + float* sumvals = &shared[warpsPerBlock]; + + if (tid >= C) { + maxvals[warpId] = -INFINITY; + sumvals[warpId] = 0.0f; + return; + } + + const float* x = inp + idx * C; // input + float* y = out + idx * C; // output + + // first, thread coarsening by directly accessing global memory in series + float maxval = -INFINITY; + for (int i = tid; i < C; i += blockDim.x * UNROLL_FACTOR) { + #pragma unroll + for (int u = 0; u < UNROLL_FACTOR; u++) { + maxval = fmaxf(maxval, x[min(C - 1, i + u*blockDim.x)]); + } + } + + // now within-warp reductions for maxval + maxval = warpReduceMax(maxval); + // the 0th thread of each warp writes the maxval of that warp to shared memory + if (laneId == 0) maxvals[warpId] = maxval; + __syncthreads(); + // now the 0th thread reduces the maxvals in shared memory, i.e. across warps + if (tid == 0) { + float val = maxvals[tid]; + #pragma unroll + for (int i = 1; i < warpsPerBlock; i++) { + val = fmaxf(val, maxvals[i]); + } + // store the final max in the first position + maxvals[0] = val; + } + __syncthreads(); + // broadcast the max to all threads + float offset = maxvals[0]; + + // compute expf and write the result to global memory + // + thread coarsening for sum + float sumval = 0.0f; + for (int i = tid; i < C; i += blockDim.x * UNROLL_FACTOR) { + float reg_array[UNROLL_FACTOR]; + #pragma unroll + for (int u = 0; u < UNROLL_FACTOR; u++) { + reg_array[u] = __ldcs(&x[min(C - 1, i + u*blockDim.x)]); + } + #pragma unroll + for (int u = 0; u < UNROLL_FACTOR; u++) { + if (i + u*blockDim.x < C) { + float output = expf(reg_array[u] - offset); + y[min(C - 1, i + u*blockDim.x)] = output; // compiler likes redundant min()?! 
+ sumval += output; // combined into the same loop unlike kernel3 + } + } + } + + // okay now we calculated exp(x - max(x)) + // step 2: sum all the values and divide by the sum + + // within-warp reduction for sumval + sumval = warpReduceSum(sumval); + // write sumval to shared memory + if (laneId == 0) sumvals[warpId] = sumval; + __syncthreads(); + // inter-thread reduction of sum + if (tid == 0) { + float val = sumvals[tid]; + #pragma unroll + for (int i = 1; i < warpsPerBlock; ++i) { + val += sumvals[i]; + } + sumvals[0] = val; + } + __syncthreads(); + // broadcast the sum to all threads + float sum = sumvals[0]; + + // divide the whole row by the sum + for (int i = tid; i < C; i += blockDim.x * UNROLL_FACTOR) { + float reg_array[UNROLL_FACTOR]; + #pragma unroll + for (int u = 0; u < UNROLL_FACTOR; u++) { + reg_array[u] = y[min(C - 1, i + u*blockDim.x)]; + } + #pragma unroll + for (int u = 0; u < UNROLL_FACTOR; u++) { + if (i + u*blockDim.x < C) { + y[i + u*blockDim.x] = reg_array[u] / sum; + } + } + } +} + +// this kernel performs a column-wise reduction over dout, in PyTorch equivalent to: +// dbias = dout.sum((0,1)) +// the idea is to employ one block to reduce along several columns, +// where each block has a width of 32 columns to ensure coalesced access. +// at the end we accumulate the reductions performed by the warps in each block via shared memory +__global__ void matmul_backward_bias_kernel4(float* dbias, const float* dout, int B, int T, int OC) { + // this kernel is launched with 1D grid_dim of OC/32 + // for example let's say block_size is 128 + extern __shared__ float smem[]; // of size block_size (128) + const int warp_id = threadIdx.x / warpSize; // warp index in the block, 0,1,2,3 + const int lane_id = threadIdx.x % warpSize; // thread index in the warp, 0,1,2,...,31 + const int tl = blockIdx.x * warpSize; // pointer to the start column for this block + const int vstep = blockDim.x / warpSize; // number of warps in a block, e.g. 4 + + // pointer to the start of the column for one lane of threads + // so e.g. 
4 threads (of the same lane_id) will reduce this one column + const float* dout_col = dout + tl + lane_id; + + // column reductions by looping through the rows + // each of the 4 threads offsets by its warp_id and then skips by vstep + // together these 4 threads cover all B*T rows of this (lane_id) column + // importantly, consecutive threads (in threadId) are processing adjacent columns, + // leading to a coalesced memory access pattern + float dout_sum = 0.0f; + for (int row = warp_id; row < B * T; row += vstep) { + dout_sum += dout_col[row * OC]; + } + smem[lane_id + warp_id * warpSize] = dout_sum; + __syncthreads(); + + // warp_id 0 reduces the shared memory column-wise, linearly + dout_sum = 0.0f; + if (warp_id == 0) { + for (int j = 0; j < vstep; j++) { + dout_sum += smem[lane_id + j * warpSize]; + } + dbias[tl + lane_id] += dout_sum; + } +} + +// uses shared memory instead for the reduces +__global__ void layernorm_backward_kernel2(float* dinp, float* dweight, float* dbias, + const float* dout, const float* inp, const float* weight, const float* mean, const float* rstd, + int B, int T, int C) { + extern __shared__ float shared[]; // size = 2 * C + + namespace cg = cooperative_groups; + cg::thread_block block = cg::this_thread_block(); + cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block); + int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank(); + int N = B * T; + if(idx >= N) { return; } // thread guards + + int b = idx / T; + int t = idx % T; + + const float* dout_bt = dout + b * T * C + t * C; + const float* inp_bt = inp + b * T * C + t * C; + float* dinp_bt = dinp + b * T * C + t * C; + const float mean_bt = mean[b * T + t]; + const float rstd_bt = rstd[b * T + t]; + + // the first half of shared memory is bias, second is weight + float* dbias_shared = shared; + float* dweight_shared = shared + C; + + // init shared memory to zero + #pragma unroll + for(int i = threadIdx.x; i < C; i+= blockDim.x){ + dbias_shared[i] = 0.0f; + dweight_shared[i] = 0.0f; + } + __syncthreads(); + + // first: two reduce operations + float dnorm_mean = 0.0f; + float dnorm_norm_mean = 0.0f; + for (int i = warp.thread_rank(); i < C; i += warp.size()) { + float norm_bti = (inp_bt[i] - mean_bt) * rstd_bt; + float dnorm_i = weight[i] * dout_bt[i]; + dnorm_mean += dnorm_i; + dnorm_norm_mean += dnorm_i * norm_bti; + } + dnorm_mean = cg::reduce(warp, dnorm_mean, cg::plus{}); + dnorm_norm_mean = cg::reduce(warp, dnorm_norm_mean, cg::plus{}); + dnorm_mean = dnorm_mean / C; + dnorm_norm_mean = dnorm_norm_mean / C; + + // now iterate again and accumulate all the gradients + for (int i = warp.thread_rank(); i < C; i += warp.size()) { + float norm_bti = (inp_bt[i] - mean_bt) * rstd_bt; + float dnorm_i = weight[i] * dout_bt[i]; + // gradient contribution to bias + atomicAdd(&dbias_shared[i], dout_bt[i]); + // gradient contribution to weight + atomicAdd(&dweight_shared[i], norm_bti * dout_bt[i]); + // gradient contribution to input + float dval = 0.0f; + dval += dnorm_i; // term 1 + dval -= dnorm_mean; // term 2 + dval -= norm_bti * dnorm_norm_mean; // term 3 + dval *= rstd_bt; // final scale + dinp_bt[i] += dval; + } + __syncthreads(); + + // write to global memory + for(int i = threadIdx.x; i < C; i+= blockDim.x){ + atomicAdd(&dbias[i], dbias_shared[i]); + atomicAdd(&dweight[i], dweight_shared[i]); + } +} + +__global__ void softmax_autoregressive_backward_kernel(float* dpreatt, const float* datt, const float* att, + int B, int T, int C, float scale) { + constexpr const int BlockSize = 
256; + constexpr int T_per_block = 4; + cg::thread_block block = cg::this_thread_block(); + cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block); + __shared__ float block_acc[32]; + + int idx = blockIdx.y; + // go through blocks in reverse order, so the slowest block starts first + int t0 = T - 1 - T_per_block*blockIdx.x; + + att += idx * T * T; + datt += idx * T * T; + dpreatt += idx * T * T; + + if (warp.meta_group_rank() == 0) { + block_acc[warp.thread_rank()] = 0; + } + + for(int to = 0; to < T_per_block; ++to) { + int t = t0 - to; + if(t < 0) return; + const float* att_bth = att + t * T; + const float* datt_bth = datt + t * T; + float* dpreatt_bth = dpreatt + t * T; + + float local_sum = 0; + for (int t2 = block.thread_rank(); t2 <= t; t2 += BlockSize) { + local_sum += att_bth[t2] * datt_bth[t2]; + } + + block_acc[warp.meta_group_rank()] = cg::reduce(warp, local_sum, cg::plus{}); + block.sync(); + local_sum = cg::reduce(warp, block_acc[warp.thread_rank()], cg::plus{}); + + for (int t3 = block.thread_rank(); t3 <= t; t3 += BlockSize) { + // don't touch the cache. Some parts will still be here from the previous loop, and + // we want to exploit those. + float acc = __ldcs(att_bth + t3) * (__ldcs(datt_bth + t3) - local_sum); + __stcs(dpreatt_bth + t3, scale * acc); + } + } +} + +// Implements linear interpolation using only two floating-point operations (as opposed to three in a naive implementation). +// Reference: https://developer.nvidia.com/blog/lerp-faster-cuda +__device__ inline float lerp(float start, float end, float weight) { + return fma(weight, end, fma(-weight, start, start)); +} + +__global__ void adamw_kernel2(float* params_memory, float* grads_memory, float* m_memory, float* v_memory, long num_parameters, + float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float weight_decay) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i >= num_parameters) return; // guard + float grad = grads_memory[i]; + float m = m_memory[i]; + float v = v_memory[i]; + // update the first moment (momentum) + m = lerp(grad, m, beta1); + m_memory[i] = m; + // update the second moment (RMSprop) + v = lerp(grad * grad, v, beta2); + v_memory[i] = v; + m /= beta1_correction; // m_hat + v /= beta2_correction; // v_hat + params_memory[i] -= learning_rate * (m / (sqrtf(v) + eps) + weight_decay * params_memory[i]); +} + +struct SoftmaxParams { + float Scale; + float Offset; +}; + + +__device__ SoftmaxParams prepare_softmax_blockwide_nofloat4(cg::thread_block_tile<32>& warp, + int idx, const float* inp, int V, int P) { + // same but not float4 + // one row of inp, i.e. 
inp[idx, :] of shape (V,) + + const float* x = inp + idx * P; + float thread_maxval = -INFINITY; + float thread_sumval = 0.0f; + // do the loop in reverse to maximise probability of L2 cache hits + // so even small L2s get some hits on the 2nd read of the same thread + for (int i = V + threadIdx.x - blockDim.x; i >= 0; i -= blockDim.x) { + float v = x[i]; + float old_maxval = thread_maxval; + thread_maxval = fmaxf(thread_maxval, v); + thread_sumval *= expf((old_maxval - thread_maxval)); + thread_sumval += expf(v - thread_maxval); + } + + // two reductions of up to 1024 threads: + // 1) inside warp (shuffle), 2) cross-warp (shared memory), 3) inside warp (shuffle) + // this results in much cleaner assembly than a multi-warp cg::reduce + __shared__ float shared_maxval[32]; + __shared__ float shared_sumval[32]; + int num_warps = blockDim.x / 32; + int warp_id = threadIdx.x / 32; + int lane_id = threadIdx.x % 32; + + // reduce maxval within each warp + float warp_maxval = cg::reduce(warp, thread_maxval, cg::greater{}); + // thread 0 in each warp writes to shared memory + if (lane_id == 0) { shared_maxval[warp_id] = warp_maxval; } + __syncthreads(); + // each thread now loads the maxval across previous warps + // if the thread is "out of range" of data, use -FLT_MAX as the maxval + warp_maxval = (lane_id < num_warps) ? shared_maxval[lane_id] : -FLT_MAX; + // now reduce the maxval among the warp threads + float block_maxval = cg::reduce(warp, warp_maxval, cg::greater{}); + // each thread uses maxval to scale sumval to avoid numerical instability / overflow + thread_sumval *= expf(thread_maxval - block_maxval); + // (warp-level) reduce sumval, thread 0 in each warp saves result in shared memory + float warp_sumval = cg::reduce(warp, thread_sumval, cg::plus{}); + if (lane_id == 0) { shared_sumval[warp_id] = warp_sumval; } + __syncthreads(); + // same strategy, now reduce sumval across warps + warp_sumval = (lane_id < num_warps) ? shared_sumval[lane_id] : 0.0f; + float block_sumval = cg::reduce(warp, warp_sumval, cg::plus{}); + // return the softmax parameters + return SoftmaxParams{1.f / block_sumval, block_maxval}; +} + +// same as 2 but not using float4 (see dev/cuda/classifier_fused.cu) +// will _update_ logits to logit gradients +__global__ void fused_classifier_kernel3(float* logits, float* losses, float* probs, + const float* dlosses, const int* targets, + int B, int T, int V, int P) { + namespace cg = cooperative_groups; + cg::thread_block block = cg::this_thread_block(); + cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block); + int idx = blockIdx.x; + int ix = targets[idx]; + + // softmax (reading B * T * V, same logits read again below, hopefully still in cache) + SoftmaxParams sp = prepare_softmax_blockwide_nofloat4(warp, idx, logits, V, P); + + // calculate the probability needed for the loss and update (single-threaded) + if(threadIdx.x == 0) { + float prob = expf(logits[idx * P + ix] - sp.Offset) * sp.Scale; + losses[idx] = -logf(prob); + } + + // very sensible default for dlosses is 1/(B*T), which is the uniform loss + float dloss = dlosses != NULL ? 
dlosses[idx] : 1.0f / (B*T); + // calculate the gradients directly, saves bandwidth from probs during training + // but also supports writing probs for inference-only and debugging + const float* logits_vec = logits + idx * P; + for (int i = threadIdx.x; i < V; i += blockDim.x) { + // this is the 2nd read of logits after the one in prepare_softmax2 + // this data will never be needed again, so we reduce cache persistence + float v = __ldcs(&logits_vec[i]); + float prob = expf(v - sp.Offset) * sp.Scale; + if (probs != NULL) { + probs[idx * P + i] = prob; + } + float indicator = (i == ix) ? 1.0f : 0.0f; + logits[idx * P + i] = (prob - indicator) * dloss; + } +} + +// ---------------------------------------------------------------------------- +// kernel launchers + +void encoder_forward(float* out, + int* inp, float* wte, float* wpe, + int B, int T, int C) { + const int N = B * T * C; + const int block_size = 256; + const int grid_size = CEIL_DIV(N, block_size); + encoder_forward_kernel2<<>>(out, inp, wte, wpe, B, T, C); + cudaCheck(cudaGetLastError()); +} + +void encoder_backward(float* dwte, float* dwpe, + const float* dout, const int* inp, + int B, int T, int C) { + const int N = B * T * C; + const int block_size = 256; + const int grid_size = CEIL_DIV(N, block_size); + encoder_backward_kernel<<>>(dwte, dwpe, dout, inp, B, T, C); + cudaCheck(cudaGetLastError()); +} + +void layernorm_forward(float* out, float* mean, float* rstd, + float* inp, float* weight, float* bias, + int B, int T, int C) { + const int block_size = 512; + const int N = B * T; + const int grid_size = CEIL_DIV(N * 32, block_size); + layernorm_forward_kernel3<<>>(out, mean, rstd, inp, weight, bias, N, C); + cudaCheck(cudaGetLastError()); +} + +// uses cuBLAS +void matmul_forward_cublas(float* out, + float* inp, float* weight, float* bias, + int B, int T, int C, int OC) { + assert(bias == NULL); // bias is not supported for this kernel + const float alpha = 1.0f; + const float beta = 0.0f; + cublasCheck(cublasSgemm(cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N, OC, B*T, C, &alpha, weight, C, inp, C, &beta, out, OC)); +} + +// uses cuBLASLt to fuse the bias and gelu. 
does not work with OC = 50257 (last layer) +// https://docs.nvidia.com/cuda/cublas/#cublasltmatmul +// https://github.com/NVIDIA/CUDALibrarySamples/blob/master/cuBLASLt/LtSgemm/sample_cublasLt_LtSgemm.cu +void matmul_forward_cublaslt(float* out, + float* inp, float* weight, float* bias, + int B, int T, int C, int OC) { + int has_bias = (bias != NULL); + + // check bias alignment + if(((uintptr_t)bias % 16) != 0) { + printf("Bias pointer is not aligned (cuBLASLt requirement)!\n"); + exit(EXIT_FAILURE); + } + + int returnedResults = 0; + cublasLtMatmulDesc_t operationDesc; + cublasLtMatmulPreference_t preference; + cublasLtMatrixLayout_t weightLayout; + cublasLtMatrixLayout_t inputLayout; + cublasLtMatrixLayout_t outputLayout; + cublasLtMatrixLayout_t biasLayout; + cublasLtMatmulHeuristicResult_t heuristic; + + // create the operation descriptor + cublasOperation_t opNoTranspose = CUBLAS_OP_N; + cublasOperation_t opTranspose = CUBLAS_OP_T; + cublasLtEpilogue_t epilogueBias = CUBLASLT_EPILOGUE_BIAS; + cublasCheck(cublasLtMatmulDescCreate(&operationDesc, cublas_compute_type, CUDA_R_32F)); + cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &opTranspose, sizeof(opTranspose))); + cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opNoTranspose, sizeof(opNoTranspose))); + cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogueBias, sizeof(epilogueBias))); + cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias))); + + // define matrix layouts + cublasCheck(cublasLtMatrixLayoutCreate(&weightLayout, CUDA_R_32F, C, OC, C)); + cublasCheck(cublasLtMatrixLayoutCreate(&inputLayout, CUDA_R_32F, C, B*T, C)); + cublasCheck(cublasLtMatrixLayoutCreate(&outputLayout, CUDA_R_32F, OC, B*T, OC)); + cublasCheck(cublasLtMatrixLayoutCreate(&biasLayout, CUDA_R_32F, OC, 1, OC)); + + // create a preference handle with specified max workspace + cublasCheck(cublasLtMatmulPreferenceCreate(&preference)); + cublasCheck(cublasLtMatmulPreferenceSetAttribute(preference, + CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &cublaslt_workspace_size, sizeof(cublaslt_workspace_size))); + + // find a suitable algorithm + cublasCheck(cublasLtMatmulAlgoGetHeuristic(cublaslt_handle, operationDesc, + weightLayout, inputLayout, outputLayout, outputLayout, + preference, 1, &heuristic, &returnedResults)); + if (returnedResults == 0) { + printf("No cuBLASLt algorithm: B: %d, T: %d, C: %d, OC: %d, bias: %d\n", B, T, C, OC, has_bias); + exit(EXIT_FAILURE); + } + + // call the matmul + const float alpha = 1.0f, beta = 0.0f; + cublasCheck(cublasLtMatmul(cublaslt_handle, operationDesc, + &alpha, weight, weightLayout, inp, inputLayout, &beta, + out, outputLayout, out, outputLayout, &heuristic.algo, + cublaslt_workspace, cublaslt_workspace_size, 0)); + + // cleanups + cublasCheck(cublasLtMatmulPreferenceDestroy(preference)); + cublasCheck(cublasLtMatmulDescDestroy(operationDesc)); + cublasCheck(cublasLtMatrixLayoutDestroy(weightLayout)); + cublasCheck(cublasLtMatrixLayoutDestroy(inputLayout)); + cublasCheck(cublasLtMatrixLayoutDestroy(outputLayout)); + cublasCheck(cublasLtMatrixLayoutDestroy(biasLayout)); +} + +void attention_forward(float* out, float* qkvr, float* att, + float* inp, + int B, int T, int C, int NH) { + // Note: `inp` is not needed for backward pass, so we re-use it as a scratch buffer. + // Its contents will be overwritten by this function. 
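+    // Sketch of the computation below (shapes only, not executable code):
+    //     q, k, v = permute(inp)               // each (B, NH, T, HS)
+    //     preatt  = q @ k^T                    // (B, NH, T, T), batched cublasSgemmStridedBatched
+    //     att     = softmax(preatt / sqrt(HS)) // causal; the scale is fused into softmax_forward_kernel5
+    //     vaccum  = att @ v                    // (B, NH, T, HS), batched cublasSgemmStridedBatched
+    //     out     = unpermute(vaccum)          // back to (B, T, C)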
+ const int block_size = 256; + const int softmax_block_size = 256; + + // inp is (B, T, 3C) QKV + // preatt, att are (B, NH, T, T) + // output is (B, T, C) + int HS = C / NH; // head size + + // permute and separate inp from (B, T, 3, NH, HS) to 3X (B, NH, T, HS) + float *q, *k, *v; + q = qkvr + 0 * B * T * C; + k = qkvr + 1 * B * T * C; + v = qkvr + 2 * B * T * C; + int total_threads = B * NH * T * HS; + int num_blocks = CEIL_DIV(total_threads, block_size); + permute_kernel<<>>(q, k, v, inp, B, T, NH, HS); + cudaCheck(cudaGetLastError()); + + // batched matrix multiply with cuBLAS + const float alpha = 1.0f; + const float beta = 0.0f; + float* preatt = inp; + cublasCheck(cublasSgemmStridedBatched(cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N, T, T, HS, &alpha, k, HS, T * HS, q, HS, T * HS, &beta, preatt, T, T * T, B * NH)); + + // multiply all elements of preatt elementwise by scale + float scale = 1.0 / sqrtf(HS); + int grid_size = CEIL_DIV(B * NH * T * 32, softmax_block_size); + softmax_forward_kernel5<<>>(att, scale, preatt, B * NH, T); + cudaCheck(cudaGetLastError()); + + // new approach: first cuBLAS another batched matmul + float* vaccum = inp; + // y = att @ v # (B, nh, T, T) @ (B, nh, T, hs) -> (B, nh, T, hs) + cublasCheck(cublasSgemmStridedBatched(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, HS, T, T, &alpha, v, HS, T * HS, att, T, T * T, &beta, vaccum, HS, T * HS, B * NH)); + + // now unpermute + // y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + num_blocks = CEIL_DIV(B * T * C, block_size); + unpermute_kernel<<>>(vaccum, out, B, T, NH, HS); + cudaCheck(cudaGetLastError()); +} + +void residual_forward(float* out, float* inp1, float* inp2, int N) { + const int block_size = 256; + const int grid_size = CEIL_DIV(N, block_size); + residual_forward_kernel<<>>(out, inp1, inp2, N); + cudaCheck(cudaGetLastError()); +} + +void gelu_forward(float* out, const float* inp, int N) { + const int block_size = 128; + const int grid_size = CEIL_DIV(N, block_size); + gelu_forward_kernel<<>>(out, inp, N); + cudaCheck(cudaGetLastError()); +} + +void gelu_backward(float* dinp, const float* inp, const float* dout, const int N) { + const int block_size = 128; + const int grid_size = CEIL_DIV(N, block_size); + gelu_backward_kernel<<>>(dinp, inp, dout, N); + cudaCheck(cudaGetLastError()); +} + +void softmax_forward(float* out, float* inp, int N, int C) { + int grid_size = N; + const int block_size = 512; + size_t shared_mem_size = 2 * block_size / 32 * sizeof(float); + softmax_forward_kernel7<<>>(out, inp, N, C); + cudaCheck(cudaGetLastError()); +} + +void matmul_backward(float* dinp, float* dweight, float* dbias, + float* dout, float* inp, float* weight, + int B, int T, int C, int OC) { + float one = 1.0f; + float zero = 0.0f; + // backward to input, uses = in the backward pass (set the gradient) + cublasCheck(cublasSgemm(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, C, B*T, OC, &one, weight, C, dout, OC, &zero, dinp, C)); + // backward to weight, uses += in the backward pass (accumulate the gradient) + cublasCheck(cublasSgemm(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_T, C, OC, B*T, &one, inp, C, dout, OC, &one, dweight, C)); + // backward to bias, if given, does a += + if (dbias != NULL) { + const int block_size = 1024; + const int grid_size = OC / 32; // for now, OC must be divisible by 32 for this kernel to work + matmul_backward_bias_kernel4<<>>(dbias, dout, B, T, OC); + cudaCheck(cudaGetLastError()); + } +} + +void layernorm_backward(float* dinp, float* dweight, float* 
dbias, + const float* dout, const float* inp, const float* weight, const float* mean, const float* rstd, + int B, int T, int C) { + const int block_size = 512; + const int N = B * T; + const int grid_size = CEIL_DIV(32*N, block_size); + size_t shared_mem_size = 2 * C * sizeof(float); + layernorm_backward_kernel2<<>>(dinp, dweight, dbias, dout, inp, weight, mean, rstd, B, T, C); + cudaCheck(cudaGetLastError()); +} + +// the sequence of transformations in this compound op is: +// inp (B,T,3C) -> qkvr (B,T,3C) -> preatt (B,NH,T,T) -> att (B,NH,T,T) -> vaccum (B,T,C) -> out (B,T,C) +void attention_backward(float* dinp, float* dqkvr, float* dpreatt, float* datt, float* scratch, + const float* dout, + const float* qkvr, const float* att, + int B, int T, int C, int NH) { + const int block_size = 256; + int HS = C / NH; // head size + const float one = 1.0f; + const float zero = 0.0f; // note beta = 1.0f so that we accumulate gradients (+=) + // unpack convenience pointers into q, k, v + const float *q, *k, *v; + q = qkvr + 0 * B * T * C; + k = qkvr + 1 * B * T * C; + v = qkvr + 2 * B * T * C; + float *dq, *dk, *dv; + dq = dqkvr + 0 * B * T * C; + dk = dqkvr + 1 * B * T * C; + dv = dqkvr + 2 * B * T * C; + // backward through the unpermute operation + int num_blocks = CEIL_DIV(B * T * C, block_size); + unpermute_kernel_backward<<>>(scratch, dout, B, T, NH, HS); + cudaCheck(cudaGetLastError()); + // backward into datt + cublasCheck(cublasSgemmStridedBatched(cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N, T, T, HS, &one, v, HS, T * HS, scratch, HS, T * HS, &zero, datt, T, T * T, B * NH)); + // backward into dv + cublasCheck(cublasSgemmStridedBatched(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_T, HS, T, T, &one, scratch, HS, T * HS, att, T, T * T, &zero, dv, HS, T * HS, B * NH)); + // backward into preatt + int hs = C / NH; // head size + float scale = 1.0f / sqrtf(hs); + softmax_autoregressive_backward_kernel<<>>(dpreatt, datt, att, B, T, C, scale); + cudaCheck(cudaGetLastError()); + // backward into q + cublasCheck(cublasSgemmStridedBatched(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, HS, T, T, &one, k, HS, T * HS, dpreatt, T, T * T, &zero, dq, HS, T * HS, B * NH)); + // backward into k + cublasCheck(cublasSgemmStridedBatched(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_T, HS, T, T, &one, q, HS, T * HS, dpreatt, T, T * T, &zero, dk, HS, T * HS, B * NH)); + // backward into inp + num_blocks = CEIL_DIV(B * NH * T * HS, block_size); + permute_kernel_backward<<>>(dinp, dq, dk, dv, B, T, NH, HS); + cudaCheck(cudaGetLastError()); +} + +// replaces logits with logit gradients +void fused_classifier3(float* logits, float* losses, + const float* dlosses, const int* targets, + int B, int T, int V, int P) { + const int block_size = 1024; + const int N = B * T; + const int grid_size = N; + fused_classifier_kernel3<<>>(logits, losses, NULL, dlosses, targets, B, T, V, P); + cudaCheck(cudaGetLastError()); +} + +// ---------------------------------------------------------------------------- +// GPT-2 model definition + +typedef struct { + int max_seq_len; // max sequence length, e.g. 1024 + int vocab_size; // vocab size, e.g. 50257 + int num_layers; // number of layers, e.g. 12 + int num_heads; // number of heads in attention, e.g. 12 + int channels; // number of channels, e.g. 
768 +} GPT2Config; + +// the parameters of the model +#define NUM_PARAMETER_TENSORS 16 +typedef struct { + float* wte; // (V, C) + float* wpe; // (maxT, C) + float* ln1w; // (L, C) + float* ln1b; // (L, C) + float* qkvw; // (L, 3*C, C) + float* qkvb; // (L, 3*C) + float* attprojw; // (L, C, C) + float* attprojb; // (L, C) + float* ln2w; // (L, C) + float* ln2b; // (L, C) + float* fcw; // (L, 4*C, C) + float* fcb; // (L, 4*C) + float* fcprojw; // (L, C, 4*C) + float* fcprojb; // (L, C) + float* lnfw; // (C) + float* lnfb; // (C) +} ParameterTensors; + +void fill_in_parameter_sizes(size_t* param_sizes, GPT2Config config) { + int V = config.vocab_size; + int C = config.channels; + int maxT = config.max_seq_len; + int L = config.num_layers; + param_sizes[0] = V * C; // wte + param_sizes[1] = maxT * C; // wpe + param_sizes[2] = L * C; // ln1w + param_sizes[3] = L * C; // ln1b + param_sizes[4] = L * (3 * C) * C; // qkvw + param_sizes[5] = L * (3 * C); // qkvb + param_sizes[6] = L * C * C; // attprojw + param_sizes[7] = L * C; // attprojb + param_sizes[8] = L * C; // ln2w + param_sizes[9] = L * C; // ln2b + param_sizes[10] = L * (4 * C) * C; // fcw + param_sizes[11] = L * (4 * C); // fcb + param_sizes[12] = L * C * (4 * C); // fcprojw + param_sizes[13] = L * C; // fcprojb + param_sizes[14] = C; // lnfw + param_sizes[15] = C; // lnfb +} + +// allocate memory for the parameters and point the individual tensors to the right places +float* malloc_and_point_parameters(ParameterTensors* params, size_t* param_sizes, int on_device) { + // on_device: 0 = CPU, 1 = GPU + // calculate the number of parameters + size_t num_parameters = 0; + for (size_t i = 0; i < NUM_PARAMETER_TENSORS; i++) { + num_parameters += param_sizes[i]; + } + // malloc all parameters all at once on the device + float* params_memory; + if (on_device) { + cudaCheck(cudaMalloc((void**)¶ms_memory, num_parameters * sizeof(float))); + } else { + params_memory = (float*)mallocCheck(num_parameters * sizeof(float)); + } + // assign all the tensors their place in the array + float** ptrs[] = { + ¶ms->wte, ¶ms->wpe, ¶ms->ln1w, ¶ms->ln1b, ¶ms->qkvw, ¶ms->qkvb, + ¶ms->attprojw, ¶ms->attprojb, ¶ms->ln2w, ¶ms->ln2b, ¶ms->fcw, ¶ms->fcb, + ¶ms->fcprojw, ¶ms->fcprojb, ¶ms->lnfw, ¶ms->lnfb + }; + float* params_memory_iterator = params_memory; + for (size_t i = 0; i < NUM_PARAMETER_TENSORS; i++) { + *(ptrs[i]) = params_memory_iterator; + params_memory_iterator += param_sizes[i]; + } + return params_memory; +} + +#define NUM_ACTIVATION_TENSORS 21 +typedef struct { + float* encoded; // (B, T, C) + float* ln1; // (L, B, T, C) + float* ln1_mean; // (L, B, T) + float* ln1_rstd; // (L, B, T) + float* atty; // (L, B, T, C) + float* att; // (L, B, NH, T, T) + float* attproj; // (L, B, T, C) + float* residual2; // (L, B, T, C) + float* ln2; // (L, B, T, C) + float* ln2_mean; // (L, B, T) + float* ln2_rstd; // (L, B, T) + float* fch; // (L, B, T, 4*C) + float* fch_gelu; // (L, B, T, 4*C) + float* fcproj; // (L, B, T, C) + float* residual3; // (L, B, T, C) + float* lnf; // (B, T, C) + float* lnf_mean; // (B, T) + float* lnf_rstd; // (B, T) + + float* losses; // (B, T) + // adding these two compared to the CPU .c code, needed for attention kernel as buffers + float* qkvr; // (L, B, T, 3*C) + // in inference mode, this buffer will store the logits + // in training mode, this buffer will contain the *gradients* of the logits. + // during the processing of transformer blocks, we will also use this as a + // general scratchpad buffer. 
Allocation is made large enough to hold (B, T, 3C), + // (B, NH, T, T), and (B, T, V) shaped tensors. + float* output; +} ActivationTensors; + +void fill_in_activation_sizes(size_t* act_sizes, int B, int T, GPT2Config config) { + size_t V = config.vocab_size; + size_t L = config.num_layers; + size_t NH = config.num_heads; + size_t C = config.channels; + act_sizes[0] = B * T * C; // encoded + act_sizes[1] = L * B * T * C; // ln1 + act_sizes[2] = L * B * T; // ln1_mean + act_sizes[3] = L * B * T; // ln1_rstd + act_sizes[4] = L * B * T * C; // atty + act_sizes[5] = L * B * NH * T * T; // att + act_sizes[6] = L * B * T * C; // attproj + act_sizes[7] = L * B * T * C; // residual2 + act_sizes[8] = L * B * T * C; // ln2 + act_sizes[9] = L * B * T; // ln2_mean + act_sizes[10] = L * B * T; // ln2_rstd + act_sizes[11] = L * B * T * 4*C; // fch + act_sizes[12] = L * B * T * 4*C; // fch_gelu + act_sizes[13] = L * B * T * C; // fcproj + act_sizes[14] = L * B * T * C; // residual3 + act_sizes[15] = B * T * C; // lnf + act_sizes[16] = B * T; // lnf_mean + act_sizes[17] = B * T; // lnf_rstd + act_sizes[18] = B * T; // losses + act_sizes[19] = L * B * T * 3*C; // qkvr + act_sizes[20] = B * T * max(3*C, max(NH*T, V)); // output / scratch +} + +// Backward pass is conceptually quite different from forward, because we can discard +// the activations of a layer as soon as we're done with it. This lets us aggressively +// reuse memory, so that we need far fewer tensors for backward state. +#define NUM_BACKWARD_TENSORS 3 +typedef struct { + float* bt4c; // (B, T, 4*C) + float* preatt; // (B, NH, T, T) + float* residual3; // (B, T, C) +} GradActTensors; + + +void fill_in_grad_act_sizes(size_t* act_sizes, int B, int T, GPT2Config config) { + size_t NH = config.num_heads; + size_t C = config.channels; + act_sizes[0] = B * T * 4 * C; // bt4c + act_sizes[1] = B * NH * T * T; // preatt + act_sizes[2] = B * T * C; // residual3 +} + + +float* malloc_and_point(float** targets[], const size_t* act_sizes, int n) { + size_t num_activations = 0; + for (size_t i = 0; i < n; i++) { + num_activations += act_sizes[i]; + } + float* acts_memory; + cudaCheck(cudaMalloc((void**)&acts_memory, num_activations * sizeof(float))); + float* acts_memory_iterator = acts_memory; + for (size_t i = 0; i < n; i++) { + *(targets[i]) = acts_memory_iterator; + acts_memory_iterator += act_sizes[i]; + } + return acts_memory; +} + +float* malloc_and_point_activations(ActivationTensors* acts, const size_t* act_sizes) { + float** ptrs[] = { + &acts->encoded, &acts->ln1, &acts->ln1_mean, &acts->ln1_rstd, &acts->atty, + &acts->att, &acts->attproj, &acts->residual2, &acts->ln2, &acts->ln2_mean, + &acts->ln2_rstd, &acts->fch, &acts->fch_gelu, &acts->fcproj, &acts->residual3, &acts->lnf, + &acts->lnf_mean, &acts->lnf_rstd, &acts->losses, &acts->qkvr, &acts->output + }; + return malloc_and_point(ptrs, act_sizes, NUM_ACTIVATION_TENSORS); +} + +float* malloc_and_point_backward(GradActTensors* acts, const size_t* act_sizes) { + float** ptrs[] = { + &acts->bt4c, &acts->preatt, &acts->residual3 + }; + return malloc_and_point(ptrs, act_sizes, NUM_BACKWARD_TENSORS); +} + +typedef struct { + GPT2Config config; + // the weights of the model, and their sizes + ParameterTensors params; + size_t param_sizes[NUM_PARAMETER_TENSORS]; + float* params_memory; + size_t num_parameters; + // gradients of the weights + ParameterTensors grads; + float* grads_memory; + // buffers for the AdamW optimizer + float* m_memory; + float* v_memory; + // the activations of the model, and 
their sizes + ActivationTensors acts; + size_t act_sizes[NUM_ACTIVATION_TENSORS]; + float* acts_memory; + size_t num_activations; + // gradients of the activations + GradActTensors grads_acts; + size_t num_grad_acts; + float* grads_acts_memory; + // other run state configuration + int batch_size; // the batch size (B) of current forward pass + int seq_len; // the sequence length (T) of current forward pass + int* inputs; // the input tokens for the current forward pass + int* targets; // the target tokens for the current forward pass + float mean_loss; // after a forward pass with targets, will be populated with the mean loss + float* cpu_losses; // CPU buffer to copy the losses to, allocated with cudaMallocHost +} GPT2; + +void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) { + + // read in model from a checkpoint file + FILE *model_file = fopenCheck(checkpoint_path, "rb"); + int model_header[256]; + freadCheck(model_header, sizeof(int), 256, model_file); + if (model_header[0] != 20240326) { printf("Bad magic model file"); exit(EXIT_FAILURE); } + if (model_header[1] != 1) { printf("Bad version in model file"); exit(EXIT_FAILURE); } + + // read in hyperparameters + model->config.max_seq_len = model_header[2]; + model->config.vocab_size = model_header[3]; + model->config.num_layers = model_header[4]; + model->config.num_heads = model_header[5]; + model->config.channels = model_header[6]; + + // allocate space for all the parameters and read them in + fill_in_parameter_sizes(model->param_sizes, model->config); + + // count the number of parameters + size_t num_parameters = 0; + for (size_t i = 0; i < NUM_PARAMETER_TENSORS; i++) { + num_parameters += model->param_sizes[i]; + } + model->num_parameters = num_parameters; + + // create memory for model parameters on the device + model->params_memory = malloc_and_point_parameters(&model->params, model->param_sizes, 1); + + // read in all the parameters from file and copy them to device + float* params_memory_cpu = (float*)mallocCheck(num_parameters * sizeof(float)); + freadCheck(params_memory_cpu, sizeof(float), num_parameters, model_file); + cudaCheck(cudaMemcpy(model->params_memory, params_memory_cpu, num_parameters * sizeof(float), cudaMemcpyHostToDevice)); + free(params_memory_cpu); + fcloseCheck(model_file); + + // other inits + model->acts_memory = NULL; + model->grads_memory = NULL; + model->m_memory = NULL; + model->v_memory = NULL; + model->grads_acts_memory = NULL; + model->inputs = NULL; + model->targets = NULL; + model->cpu_losses = NULL; + model->batch_size = 0; + model->seq_len = 0; + model->mean_loss = -1.0f; // -1.0f will designate no loss +} + +void gpt2_forward(GPT2 *model, int* inputs, int* targets, int B, int T) { + // targets are optional and could be NULL + + // ensure the model was initialized or error out + if (model->params_memory == NULL) { + printf("Error: model was not initialized properly.\n"); + exit(EXIT_FAILURE); + } + + // convenience parameters + int V = model->config.vocab_size; + int L = model->config.num_layers; + int NH = model->config.num_heads; + int C = model->config.channels; + + // validate inputs, all indices must be in the range [0, V) + for(int i = 0; i < B * T; i++) { + assert(0 <= inputs[i] && inputs[i] < V); + if (targets != NULL) { + assert(0 <= targets[i] && targets[i] < V); + } + } + + // allocate space for all the activations if needed (done here, lazily) + if(model->acts_memory == NULL) { + // record the current B,T as well + model->batch_size = B; + model->seq_len = T; + 
// and now allocate the space + fill_in_activation_sizes(model->act_sizes, B, T, model->config); + size_t num_activations = 0; + for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) { + num_activations += model->act_sizes[i]; + } + model->num_activations = num_activations; + model->acts_memory = malloc_and_point_activations(&model->acts, model->act_sizes); + printf("allocated %zu MiB for activations\n", (num_activations * sizeof(float)) >> 20); // >> 20 is /(1024*1024) + // also create memory for caching inputs and targets + cudaCheck(cudaMalloc((void**)&model->inputs, B * T * sizeof(int))); + cudaCheck(cudaMalloc((void**)&model->targets, B * T * sizeof(int))); + cudaCheck(cudaMallocHost((void**)&model->cpu_losses, B * T * sizeof(float))); + } else { + // validate B,T is consistent with how we've allocated the memory before + // in principle we could get more clever here in the future, for now this is safest + if (B != model->batch_size || T != model->seq_len) { + printf("Model: B=%d T=%d, Desired: B=%d T=%d\n", model->batch_size, model->seq_len, B, T); + exit(EXIT_FAILURE); + } + } + + // copy inputs/targets to the model + cudaCheck(cudaMemcpy(model->inputs, inputs, B * T * sizeof(int), cudaMemcpyHostToDevice)); + if (targets != NULL) { + cudaCheck(cudaMemcpy(model->targets, targets, B * T * sizeof(int), cudaMemcpyHostToDevice)); + } + + // forward pass + ParameterTensors params = model->params; // for brevity + ActivationTensors acts = model->acts; + float* residual; + encoder_forward(acts.encoded, model->inputs, params.wte, params.wpe, B, T, C); // encoding goes into residual[0] + + for (int l = 0; l < L; l++) { + + residual = l == 0 ? acts.encoded : acts.residual3 + (l-1) * B * T * C; + + // get the pointers of the weights for this layer + float* l_ln1w = params.ln1w + l * C; + float* l_ln1b = params.ln1b + l * C; + float* l_qkvw = params.qkvw + l * 3*C * C; + float* l_qkvb = params.qkvb + l * 3*C; + float* l_attprojw = params.attprojw + l * C * C; + float* l_attprojb = params.attprojb + l * C; + float* l_ln2w = params.ln2w + l * C; + float* l_ln2b = params.ln2b + l * C; + float* l_fcw = params.fcw + l * 4*C * C; + float* l_fcb = params.fcb + l * 4*C; + float* l_fcprojw = params.fcprojw + l * C * 4*C; + float* l_fcprojb = params.fcprojb + l * C; + + // get the pointers of the activations for this layer + float* l_ln1 = acts.ln1 + l * B * T * C; + float* l_ln1_mean = acts.ln1_mean + l * B * T; + float* l_ln1_rstd = acts.ln1_rstd + l * B * T; + float* l_qkvr = acts.qkvr + l * B * T * 3*C; + float* l_atty = acts.atty + l * B * T * C; + float* l_att = acts.att + l * B * NH * T * T; + float* l_attproj = acts.attproj + l * B * T * C; + float* l_residual2 = acts.residual2 + l * B * T * C; + float* l_ln2 = acts.ln2 + l * B * T * C; + float* l_ln2_mean = acts.ln2_mean + l * B * T; + float* l_ln2_rstd = acts.ln2_rstd + l * B * T; + float* l_fch = acts.fch + l * B * T * 4*C; + float* l_fch_gelu = acts.fch_gelu + l * B * T * 4*C; + float* l_fcproj = acts.fcproj + l * B * T * C; + float* l_residual3 = acts.residual3 + l * B * T * C; + // these are only needed as scratchpads for the forward pass, but + // need not be stored for backward + float* scratch = acts.output; + + // now do the forward pass + layernorm_forward(l_ln1, l_ln1_mean, l_ln1_rstd, residual, l_ln1w, l_ln1b, B, T, C); + matmul_forward_cublaslt(scratch, l_ln1, l_qkvw, l_qkvb, B, T, C, 3*C); + attention_forward(l_atty, l_qkvr, l_att, scratch, B, T, C, NH); + matmul_forward_cublaslt(l_attproj, l_atty, l_attprojw, l_attprojb, B, T, C, C); 
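+        // Taken together, the calls above and below compute one pre-LayerNorm
+        // transformer block (all residual tensors are (B, T, C)):
+        //     residual2 = residual  + attproj(attention(ln1(residual)))
+        //     residual3 = residual2 + fcproj(gelu(fc(ln2(residual2))))
+        // and residual3 becomes `residual` for the next layer.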
+ residual_forward(l_residual2, residual, l_attproj, B*T*C); + layernorm_forward(l_ln2, l_ln2_mean, l_ln2_rstd, l_residual2, l_ln2w, l_ln2b, B, T, C); + matmul_forward_cublaslt(l_fch, l_ln2, l_fcw, l_fcb, B, T, C, 4*C); + gelu_forward(l_fch_gelu, l_fch, B*T*4*C); + matmul_forward_cublaslt(l_fcproj, l_fch_gelu, l_fcprojw, l_fcprojb, B, T, 4*C, C); + residual_forward(l_residual3, l_residual2, l_fcproj, B*T*C); + } + + residual = acts.residual3 + (L-1) * B * T * C; // last residual is in residual3 + layernorm_forward(acts.lnf, acts.lnf_mean, acts.lnf_rstd, residual, params.lnfw, params.lnfb, B, T, C); + matmul_forward_cublas(acts.output, acts.lnf, params.wte, NULL, B, T, C, V); + + // also forward the cross-entropy loss function if we have the targets + if (targets != NULL) { + // fused classifier: does the forward pass and first part of the backward pass + // we're passing dlosses = NULL, which will default them to 1.0f/(B*T), i.e. uniform loss + fused_classifier3(acts.output, acts.losses, NULL, model->targets, B, T, V, V); + // for convenience also evaluate the mean loss (TODO re-think this compute+sync point) + // move the (B,T) losses to CPU + cudaCheck(cudaMemcpy(model->cpu_losses, acts.losses, B * T * sizeof(float), cudaMemcpyDeviceToHost)); + float mean_loss = 0.0f; + for (int i=0; icpu_losses[i]; } + mean_loss /= B*T; + model->mean_loss = mean_loss; + + } else { + // if we don't have targets, we don't have loss + model->mean_loss = -1.0f; + } +} + +void gpt2_zero_grad(GPT2 *model) { + if (model->grads_acts_memory != NULL) { cudaCheck(cudaMemset(model->grads_acts_memory, 0, model->num_grad_acts * sizeof(float))); } + if (model->grads_memory != NULL) { cudaCheck(cudaMemset(model->grads_memory, 0, model->num_parameters * sizeof(float))); } +} + +void gpt2_backward(GPT2 *model) { + + // double check we forwarded previously, with targets + if (model->mean_loss == -1.0f) { + printf("Error: must forward with targets before backward\n"); + exit(EXIT_FAILURE); + } + + // lazily allocate the memory for gradients of the weights and activations, if needed + if (model->grads_memory == NULL) { + // allocate buffers for weight gradients + model->grads_memory = malloc_and_point_parameters(&model->grads, model->param_sizes, 1); + printf("allocated %zu MiB for parameter gradients\n", (model->num_parameters * sizeof(float)) >> 20); + // we're going to be clever for the activations backward pass. we don't need to exactly + // mirror the forward pass acrtivations and we will save memory. 
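+        // In practice only NUM_BACKWARD_TENSORS = 3 gradient-activation buffers are allocated,
+        // each sized for a single layer and re-used by every layer of the backward loop below:
+        //     bt4c      : (B, T, 4*C)   scratch for the matmul/gelu backward
+        //     preatt    : (B, NH, T, T) scratch for the attention softmax backward
+        //     residual3 : (B, T, C)     running gradient of the residual stream
+        // hence the config copy right below overrides num_layers to 1 (sizes are per-layer, not per-model).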
+ size_t bw_act_sizes[NUM_ACTIVATION_TENSORS]; + GPT2Config cfg = model->config; + cfg.num_layers = 1; // copy the configuration but override number of layers to 1 + fill_in_grad_act_sizes(bw_act_sizes, model->batch_size, model->seq_len, cfg); + // count up and allocate the space + model->grads_acts_memory = malloc_and_point_backward(&model->grads_acts, bw_act_sizes); + model->num_grad_acts = 0; + for (int i = 0; i < NUM_BACKWARD_TENSORS; i++) { + model->num_grad_acts += bw_act_sizes[i]; + } + printf("allocated %zu MiB for activation gradients\n", (model->num_grad_acts * sizeof(float)) >> 20); + // init gradients of parameters and activations to zero + gpt2_zero_grad(model); + } + + // convenience shortcuts + int B = model->batch_size; + int T = model->seq_len; + int V = model->config.vocab_size; + int L = model->config.num_layers; + int NH = model->config.num_heads; + int C = model->config.channels; + + // backward pass: go in the reverse order of the forward pass, and call backward() functions + ParameterTensors params = model->params; // for brevity + ParameterTensors grads = model->grads; + ActivationTensors acts = model->acts; + GradActTensors grads_acts = model->grads_acts; + + // we kick off the chain rule by filling in dlosses with 1.0f/(B*T) + // this was done in the fused classifier kernel as last step of forward pass + // technically that is a small, inline backward() pass of calculating + // total, final loss as the mean over all losses over all (B,T) positions in the batch + // next: backward the classifier matmul + matmul_backward(grads_acts.bt4c, grads.wte, NULL, acts.output, acts.lnf, params.wte, B, T, C, V); + // backward the final layernorm + float* residual = acts.residual3 + (L-1) * B * T * C; // last residual is in residual3 + float* dresidual = grads_acts.residual3; // the main buffer holding the gradient in the backward pass + layernorm_backward(dresidual, grads.lnfw, grads.lnfb, grads_acts.bt4c, residual, params.lnfw, acts.lnf_mean, acts.lnf_rstd, B, T, C); + + // now backward all the layers + for (int l = L-1; l >= 0; l--) { + residual = l == 0 ? 
acts.encoded : acts.residual3 + (l-1) * B * T * C; + + // get the pointers of the weights for this layer + float* l_ln1w = params.ln1w + l * C; + float* l_qkvw = params.qkvw + l * 3*C * C; + float* l_attprojw = params.attprojw + l * C * C; + float* l_ln2w = params.ln2w + l * C; + float* l_fcw = params.fcw + l * 4*C * C; + float* l_fcprojw = params.fcprojw + l * C * 4*C; + // get the pointers of the gradients of the weights for this layer + float* dl_ln1w = grads.ln1w + l * C; + float* dl_ln1b = grads.ln1b + l * C; + float* dl_qkvw = grads.qkvw + l * 3*C * C; + float* dl_qkvb = grads.qkvb + l * 3*C; + float* dl_attprojw = grads.attprojw + l * C * C; + float* dl_attprojb = grads.attprojb + l * C; + float* dl_ln2w = grads.ln2w + l * C; + float* dl_ln2b = grads.ln2b + l * C; + float* dl_fcw = grads.fcw + l * 4*C * C; + float* dl_fcb = grads.fcb + l * 4*C; + float* dl_fcprojw = grads.fcprojw + l * C * 4*C; + float* dl_fcprojb = grads.fcprojb + l * C; + // get the pointers of the activations for this layer + float* l_ln1 = acts.ln1 + l * B * T * C; + float* l_ln1_mean = acts.ln1_mean + l * B * T; + float* l_ln1_rstd = acts.ln1_rstd + l * B * T; + float* l_qkvr = acts.qkvr + l * B * T * 3*C; + float* l_atty = acts.atty + l * B * T * C; + float* l_att = acts.att + l * B * NH * T * T; + float* l_residual2 = acts.residual2 + l * B * T * C; + float* l_ln2 = acts.ln2 + l * B * T * C; + float* l_ln2_mean = acts.ln2_mean + l * B * T; + float* l_ln2_rstd = acts.ln2_rstd + l * B * T; + float* l_fch = acts.fch + l * B * T * 4*C; + float* l_fch_gelu = acts.fch_gelu + l * B * T * 4*C; + // get the pointers of the gradients of the activations for this layer + // notice that there is no l *, because we just have a single copy, and keep + // re-using this memory in every Transformer block as we calculate backward pass + + // we need a B x T x C buffer; thankfully, the forward activation for lnf isn't needed anymore, + // so we can co-opt it here. + float* dl_btc = acts.lnf; + float* dl_bt4c = grads_acts.bt4c; + float* dl_preatt = grads_acts.preatt; + + // re-use scratch buffer of the forward pass + float* scratch = acts.output; + + // backprop this layer + matmul_backward(dl_bt4c, dl_fcprojw, dl_fcprojb, dresidual, l_fch_gelu, l_fcprojw, B, T, 4*C, C); + gelu_backward(dl_bt4c, l_fch, dl_bt4c, B*T*4*C); + matmul_backward(dl_btc, dl_fcw, dl_fcb, dl_bt4c, l_ln2, l_fcw, B, T, C, 4 * C); + // layernorm backward does += to the dresidual, so it correctly accumulates grad from the MLP block above + layernorm_backward(dresidual, dl_ln2w, dl_ln2b, dl_btc, l_residual2, l_ln2w, l_ln2_mean, l_ln2_rstd, B, T, C); + matmul_backward(dl_btc, dl_attprojw, dl_attprojb, dresidual, l_atty, l_attprojw, B, T, C, C); + // we more B x T x (4)C buffers. 
l_atty and l_fch aren't needed anymore at this point, so reuse their memory
+ float* buffer_a = l_atty;
+ float* buffer_b = l_fch; // this is B x T x 4C, so even larger than what we need
+
+ attention_backward(dl_bt4c, buffer_b, dl_preatt, scratch, buffer_a, dl_btc, l_qkvr, l_att, B, T, C, NH);
+ matmul_backward(dl_btc, dl_qkvw, dl_qkvb, dl_bt4c, l_ln1, l_qkvw, B, T, C, 3 * C);
+ // layernorm backward does += to dresidual, so it correctly accumulates gradient for the Attention block above
+ layernorm_backward(dresidual, dl_ln1w, dl_ln1b, dl_btc, residual, l_ln1w, l_ln1_mean, l_ln1_rstd, B, T, C);
+ }
+ encoder_backward(grads.wte, grads.wpe, dresidual, model->inputs, B, T, C);
+}
+
+void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, float eps, float weight_decay, int t) {
+ // reference: https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html
+
+ // lazily allocate the memory for m_memory and v_memory
+ if (model->m_memory == NULL) {
+ cudaCheck(cudaMalloc((void**)&model->m_memory, model->num_parameters * sizeof(float)));
+ cudaCheck(cudaMalloc((void**)&model->v_memory, model->num_parameters * sizeof(float)));
+ cudaCheck(cudaMemset(model->m_memory, 0, model->num_parameters * sizeof(float)));
+ cudaCheck(cudaMemset(model->v_memory, 0, model->num_parameters * sizeof(float)));
+ printf("allocated %zu MiB for AdamW optimizer state m\n", (model->num_parameters * sizeof(float)) >> 20);
+ printf("allocated %zu MiB for AdamW optimizer state v\n", (model->num_parameters * sizeof(float)) >> 20);
+ }
+
+ int block_size = 512;
+ int num_blocks = CEIL_DIV(model->num_parameters, block_size);
+ float beta1_correction = 1.0f - powf(beta1, t);
+ float beta2_correction = 1.0f - powf(beta2, t);
+ adamw_kernel2<<<num_blocks, block_size>>>(model->params_memory, model->grads_memory, model->m_memory, model->v_memory,
+ model->num_parameters,
+ learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay);
+ cudaCheck(cudaGetLastError());
+}
+
+void gpt2_free(GPT2 *model) {
+ cudaCheck(cudaFree(model->params_memory));
+ cudaCheck(cudaFree(model->grads_memory));
+ cudaCheck(cudaFree(model->m_memory));
+ cudaCheck(cudaFree(model->v_memory));
+ cudaCheck(cudaFree(model->acts_memory));
+ cudaCheck(cudaFree(model->grads_acts_memory));
+ cudaCheck(cudaFree(model->inputs));
+ cudaCheck(cudaFree(model->targets));
+ cudaFreeHost(model->cpu_losses);
+}
+
+#ifndef TESTING
+// if we are TESTING (see test_gpt2.cu), we'll skip the int main below
+
+// ----------------------------------------------------------------------------
+// data loader lite: returns random batches of data from a file of integers
+
+typedef struct {
+ // hyperparameters
+ int B;
+ int T;
+ // input handling and its state
+ FILE* tokens_file;
+ long file_size;
+ long current_position;
+ // output memory
+ int* batch;
+ int* inputs;
+ int* targets;
+ // convenience variables
+ long num_batches;
+} DataLoader;
+
+void dataloader_init(DataLoader *loader, const char* filename, int B, int T) {
+ loader->B = B;
+ loader->T = T;
+
+ // open the input file for reading
+ loader->tokens_file = fopenCheck(filename, "rb");
+
+ // determine the file size
+ fseek(loader->tokens_file, 0, SEEK_END);
+ loader->file_size = ftell(loader->tokens_file);
+ fseek(loader->tokens_file, 0, SEEK_SET);
+ if (loader->file_size < (B * T + 1) * sizeof(int)) {
+ printf("Error: file size is too small for the batch size and sequence length\n");
+ exit(EXIT_FAILURE);
+ }
+ loader->current_position = 0; // start at the beginning
+
+ // allocate space for B*T
+ 1 integers to store the inputs and targets + // Using CUDA CPU pinned memory for faster PCI Express transfers to GPU + // See: https://developer.nvidia.com/blog/how-optimize-data-transfers-cuda-cc/ + cudaMallocHost((void**)&loader->batch, (B * T + 1) * sizeof(int)); + loader->inputs = loader->batch; + loader->targets = loader->batch + 1; // targets are shifted by one + loader->num_batches = loader->file_size / (B * T * sizeof(int)); +} + +void dataloader_reset(DataLoader *loader) { + loader->current_position = 0; +} + +void dataloader_next_batch(DataLoader *loader) { + int B = loader->B; + int T = loader->T; + // if we are at the end of the file, loop back to the beginning + if (loader->current_position + (B*T+1) * sizeof(int) > loader->file_size) { + loader->current_position = 0; + } + // read the B*T+1 integers from the file into batch + fseek(loader->tokens_file, loader->current_position, SEEK_SET); + freadCheck(loader->batch, sizeof(int), B*T+1, loader->tokens_file); + // advance the current position by B*T integers + loader->current_position += B*T * sizeof(int); +} + +void dataloader_free(DataLoader *loader) { + fcloseCheck(loader->tokens_file); + cudaFreeHost(loader->batch); +} + +// ---------------------------------------------------------------------------- +// sampler: takes probabilities and samples integers from them + +#define GPT2_EOT 50256 + +unsigned int random_u32(unsigned long long *state) { + // xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A + *state ^= *state >> 12; + *state ^= *state << 25; + *state ^= *state >> 27; + return (*state * 0x2545F4914F6CDD1Dull) >> 32; +} +float random_f32(unsigned long long *state) { // random float32 in [0,1) + return (random_u32(state) >> 8) / 16777216.0f; +} + +int sample_softmax(const float* logits, int n, float coin) { + // sample index from logits (converted to probabilities using softmax) + // coin is a random number in [0, 1), usually from random_f32() + double norm = 0; + for (int i = 0; i < n; i++) { + norm += expf(logits[i]); + } + // instead of dividing all exp(logits), we can just multiply coin. + coin *= norm; + float cdf = 0.0f; + for (int i = 0; i < n; i++) { + cdf += expf(logits[i]); + if (coin < cdf) { + return i; + } + } + return n - 1; // in case of rounding errors +} + +// ---------------------------------------------------------------------------- +// Tokenizer (only supports decoding: tokens (integers) -> strings) + +typedef struct { + uint32_t vocab_size; + char **token_table; + int init_ok; +} Tokenizer; + +void safe_printf(const char *piece) { + // the tokens are raw bytes, and we we only want to print the printable ones + // many bytes can be various control codes, backspace, etc. 
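+ // e.g. a single-byte token holding 0x08 (backspace) is neither isprint() nor isspace(), so
+ // the check below skips it, while single-byte tokens like "A" or "\n" (and all multi-byte
+ // tokens) are printed as-is.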
+ if (piece == NULL) { return; } + if (piece[0] == '\0') { return; } + // handle individual byte tokens + // every token is asserted to be at least one byte so doing piece[1] is ok + if (piece[1] == '\0') { + unsigned char byte_val = piece[0]; + if (!(isprint(byte_val) || isspace(byte_val))) { + return; // weird byte, don't print it + } + } + printf("%s", piece); +} + +void tokenizer_init(Tokenizer *tokenizer, const char *filename) { + FILE *file = fopen(filename, "rb"); + if (file == NULL) { + // try to be more helpful as we just added this feature, erase later + printf("---\n"); + printf("WARNING: Failed to open the tokenizer file %s\n", filename); + printf("The Tokenizer is a new feature added April 14 2024.\n"); + printf("Re-run `python train_gpt2.py` to write it\n"); + printf("---\n"); + tokenizer->init_ok = 0; + return; + } + // read in the header + uint32_t header[256]; + freadCheck(header, sizeof(uint32_t), 256, file); + assert(header[0] == 20240328); + assert(header[1] == 1); + tokenizer->vocab_size = header[2]; + // read in all the tokens + unsigned char length; + tokenizer->token_table = (char **)mallocCheck(tokenizer->vocab_size * sizeof(char *)); + for (uint32_t i = 0; i < tokenizer->vocab_size; i++) { + freadCheck(&length, sizeof(unsigned char), 1, file); + assert(length > 0); // every token should be at least one character + char *token_bytes = (char *)mallocCheck(length + 1); + freadCheck(token_bytes, sizeof(char), length, file); + token_bytes[length] = '\0'; // Add null terminator for printing + tokenizer->token_table[i] = token_bytes; + } + // cleanups + fcloseCheck(file); + tokenizer->init_ok = 1; +} + +const char *tokenizer_decode(Tokenizer *tokenizer, uint32_t token_id) { + if (tokenizer->init_ok == 0) { + return NULL; + } + if (token_id < tokenizer->vocab_size) { + return tokenizer->token_table[token_id]; + } else { + printf("invalid token id %d!\n", token_id); + return NULL; + } +} + +void tokenizer_free(Tokenizer *tokenizer) { + if (tokenizer->init_ok) { + for (uint32_t i = 0; i < tokenizer->vocab_size; i++) { + free(tokenizer->token_table[i]); + } + free(tokenizer->token_table); + } +} + +// ---------------------------------------------------------------------------- +// Logger lite, will probably grow/change some over time + +typedef struct { + FILE *logfile; + int flush_every; // every how many steps to flush the log +} Logger; + +void logger_init(Logger *logger, const char *filename) { + logger->flush_every = 20; + logger->logfile = NULL; + if (filename != NULL) { logger->logfile = fopenCheck(filename, "w"); } +} + +void logger_log_val(Logger *logger, int step, float val_loss) { + if (logger->logfile != NULL) { + fprintf(logger->logfile, "s:%d tel:%.4f\n", step, val_loss); + } +} + +void logger_log_train(Logger *logger, int step, float train_loss) { + if (logger->logfile != NULL) { + fprintf(logger->logfile, "s:%d trl:%.4f\n", step, train_loss); + if (step % 10 == 0) { fflush(logger->logfile); } + } +} + +void logger_free(Logger *logger) { + if (logger->logfile != NULL) { fclose(logger->logfile); } +} + +// ---------------------------------------------------------------------------- +// CLI, poor man's argparse + +void error_usage() { + // default run = debugging run with TinyShakespeare + // bigger run = train on TinyStories! e.g. 
val/sample less often, but sample more tokens, write to logfile + fprintf(stderr, "Usage: ./train_gpt2cu [options]\n"); + fprintf(stderr, "Example: ./train_gpt2cu -i data/TinyStories -v 100 -s 100 -g 144 -o stories.log\n"); + fprintf(stderr, "Options:\n"); + fprintf(stderr, " -i input dataset prefix (default = data/tiny_shakespeare)\n"); + fprintf(stderr, " -o output log file (default = NULL)\n"); + fprintf(stderr, " -b batch size B (default = 4)\n"); + fprintf(stderr, " -t sequence length T (default = 1024)\n"); + fprintf(stderr, " -l learning rate (default = 3e-4f)\n"); + fprintf(stderr, " -v val_loss_every, how often we evaluate val loss (default = 20)\n"); + fprintf(stderr, " -m val_max_batches, up to how many val batches to estimate val loss? (default = 20)\n"); + fprintf(stderr, " -s sample_every, how often we inference the model (default = 20)\n"); + fprintf(stderr, " -g genT, how many steps of inference we do (default = 64)\n"); + exit(EXIT_FAILURE); +} + +// ---------------------------------------------------------------------------- +// main training loop +int main(int argc, char *argv[]) { + + // read in the (optional) command line arguments + const char* input_dataset_prefix = "data/tiny_shakespeare"; // or e.g. data/TinyStories + const char* output_log_file = NULL; + int B = 4; // batch size + int T = 1024; // sequence length max + float learning_rate = 3e-4f; + int val_loss_every = 20; // every how many steps do we eval validation loss? + int val_max_batches = 20; // how many batches max do we eval for validation loss? + int sample_every = 20; // every how many steps to do inference? + int genT = 64; // number of steps of inference we will do + for (int i = 1; i < argc; i+=2) { + if (i + 1 >= argc) { error_usage(); } // must have arg after flag + if (argv[i][0] != '-') { error_usage(); } // must start with dash + if (strlen(argv[i]) != 2) { error_usage(); } // must be -x (one dash, one letter) + // read in the args + if (argv[i][1] == 'i') { input_dataset_prefix = argv[i+1]; } + else if (argv[i][1] == 'o') { output_log_file = argv[i+1]; } + else if (argv[i][1] == 'b') { B = atoi(argv[i+1]); } + else if (argv[i][1] == 't') { T = atoi(argv[i+1]); } + else if (argv[i][1] == 'l') { learning_rate = atof(argv[i+1]); } + else if (argv[i][1] == 'v') { val_loss_every = atoi(argv[i+1]); } + else if (argv[i][1] == 'm') { val_max_batches = atoi(argv[i+1]); } + else if (argv[i][1] == 's') { sample_every = atoi(argv[i+1]); } + else if (argv[i][1] == 'g') { genT = atoi(argv[i+1]); } + else { error_usage(); } + } + printf("+-----------------------+----------------------------------------------------+\n"); + printf("| Parameter | Value |\n"); + printf("+-----------------------+----------------------------------------------------+\n"); + printf("| input dataset prefix | %-50s |\n", input_dataset_prefix); + printf("| output log file | %-50s |\n", output_log_file == NULL ? 
"NULL" : output_log_file); + printf("| batch size B | %-50d |\n", B); + printf("| sequence length T | %-50d |\n", T); + printf("| learning rate | %-50f |\n", learning_rate); + printf("| val_loss_every | %-50d |\n", val_loss_every); + printf("| val_max_batches | %-50d |\n", val_max_batches); + printf("| sample_every | %-50d |\n", sample_every); + printf("| genT | %-50d |\n", genT); + printf("+-----------------------+----------------------------------------------------+\n"); + + // set up the device + int deviceIdx = 0; + cudaCheck(cudaSetDevice(deviceIdx)); + cudaDeviceProp deviceProp; + cudaGetDeviceProperties(&deviceProp, deviceIdx); + // setup cuBLAS and cuBLASLt + cublasCheck(cublasCreate(&cublas_handle)); + cublasCheck(cublasLtCreate(&cublaslt_handle)); + // TF32 precision is equivalent to torch.set_float32_matmul_precision('high') + int enable_tf32 = deviceProp.major >= 8 ? 1 : 0; + cublas_compute_type = enable_tf32 ? CUBLAS_COMPUTE_32F_FAST_TF32 : CUBLAS_COMPUTE_32F; + cublasMath_t cublas_math_mode = enable_tf32 ? CUBLAS_TF32_TENSOR_OP_MATH : CUBLAS_DEFAULT_MATH; + cublasCheck(cublasSetMathMode(cublas_handle, cublas_math_mode)); + cudaCheck(cudaMalloc(&cublaslt_workspace, cublaslt_workspace_size)); + printf("| device | %-50s |\n", deviceProp.name); + printf("| TF32 | %-50s |\n", enable_tf32 ? "enabled" : "disabled"); + printf("+-----------------------+----------------------------------------------------+\n"); + + // build the GPT-2 model from a checkpoint + GPT2 model; + gpt2_build_from_checkpoint(&model, "gpt2_124M.bin"); + printf("| max_sequence_length T | %-50d |\n", model.config.max_seq_len); + printf("| vocab_size V | %-50d |\n", model.config.vocab_size); + printf("| num_layers L | %-50d |\n", model.config.num_layers); + printf("| num_heads NH | %-50d |\n", model.config.num_heads); + printf("| channels C | %-50d |\n", model.config.channels); + printf("| num_parameters | %-50zu |\n", model.num_parameters); + printf("+-----------------------+----------------------------------------------------+\n"); + + // build DataLoaders for both train and val + char train_tokens_filename[128]; + char val_tokens_filename[128]; + assert(strlen(input_dataset_prefix) < 100); // being bit lazy here, make sure we don't overflow + sprintf(train_tokens_filename, "%s_train.bin", input_dataset_prefix); + sprintf(val_tokens_filename, "%s_val.bin", input_dataset_prefix); + DataLoader train_loader; + dataloader_init(&train_loader, train_tokens_filename, B, T); + DataLoader val_loader; + dataloader_init(&val_loader, val_tokens_filename, B, T); + int train_num_batches = train_loader.num_batches; // let's do 1 epoch by default for now + int val_num_batches = train_loader.num_batches < val_max_batches ? 
train_loader.num_batches : val_max_batches; + printf("| train_num_batches | %-50d |\n", train_num_batches); + printf("| val_num_batches | %-50d |\n", val_num_batches); + printf("+-----------------------+----------------------------------------------------+\n"); + + // print model parameter allocations from gpt2_build_from_checkpoint down here to not mess up our table above + printf("allocated %d MiB for model parameters\n", (int)round(model.num_parameters * sizeof(float) / (1024 * 1024))); + + // set up the Logger + Logger logger; + logger_init(&logger, output_log_file); + + // build the Tokenizer + Tokenizer tokenizer; + tokenizer_init(&tokenizer, "gpt2_tokenizer.bin"); + + // some memory for generating samples from the model + unsigned long long rng_state = 1337; + int* gen_tokens = (int*)mallocCheck(B * T * sizeof(int)); + float* cpu_logits = (float*)mallocCheck(model.config.vocab_size * sizeof(float)); + + // train + struct timespec start, end; + double total_sum_iteration_time_s = 0.0; + for (int step = 0; step <= train_num_batches; step++) { + int last_step = step == train_num_batches; + + // once in a while estimate the validation loss + if (step % val_loss_every == 0 || last_step) { + float val_loss = 0.0f; + dataloader_reset(&val_loader); + for (int i = 0; i < val_num_batches; i++) { + dataloader_next_batch(&val_loader); + gpt2_forward(&model, val_loader.inputs, val_loader.targets, B, T); + val_loss += model.mean_loss; + } + val_loss /= val_num_batches; + printf("val loss %f\n", val_loss); + logger_log_val(&logger, step, val_loss); + } + + // once in a while do model inference to print generated text + if (step > 0 && step % sample_every == 0 || last_step) { + // fill up gen_tokens with the GPT2_EOT, which kicks off the generation + for(int i = 0; i < B * T; ++i) { + gen_tokens[i] = GPT2_EOT; + } + // now sample from the model autoregressively + printf("generating:\n---\n"); + for (int t = 1; t < genT; t++) { + // note that inference is very wasteful here because for each token + // we re-calculate the forward pass for all of (B,T) positions from scratch + // but the inference here is just for sanity checking anyway + // and we can maybe optimize a bit more later, with careful tests + gpt2_forward(&model, gen_tokens, NULL, B, T); + // furthermore, below we're only using b=0 (i.e. the first row) of all B rows + // we're in principle running B "inference streams" in parallel here + // only using position 0 because it's a bit faster (copy less probs from GPU -> CPU) + // get the V-dimensional vector probs[0, t-1, :] + float* logits = model.acts.output + (t - 1) * model.config.vocab_size; + // move probs back to CPU and sample + cudaCheck(cudaMemcpy(cpu_logits, logits, model.config.vocab_size * sizeof(float), cudaMemcpyDeviceToHost)); + float coin = random_f32(&rng_state); + int next_token = sample_softmax(cpu_logits, model.config.vocab_size, coin); + gen_tokens[t] = next_token; + // print the generated token, either using the Tokenizer or a fallback + if (tokenizer.init_ok) { + const char* token_str = tokenizer_decode(&tokenizer, next_token); + safe_printf(token_str); + } else { + // fall back to printing the token id + printf("%d ", next_token); + } + fflush(stdout); + } + printf("\n---\n"); + } + + // bit confusing: we want to make sure to eval and sample on 0th iteration + // but also after the very last iteration. 
so we loop for step <= train_num_batches + // instead of just < train_num_batches (one extra due to <=), only to do + // the validation/sampling one last time, and then we break right here as we're done. + if (last_step) { break; } + + // do a training step + clock_gettime(CLOCK_MONOTONIC, &start); + dataloader_next_batch(&train_loader); + gpt2_forward(&model, train_loader.inputs, train_loader.targets, B, T); + gpt2_zero_grad(&model); + gpt2_backward(&model); + gpt2_update(&model, learning_rate, 0.9f, 0.999f, 1e-8f, 0.0f, step+1); + cudaCheck(cudaDeviceSynchronize()); // finish all CUDA work to get correct precise timings + clock_gettime(CLOCK_MONOTONIC, &end); + double time_elapsed_s = (end.tv_sec - start.tv_sec) + (end.tv_nsec - start.tv_nsec) / 1e9; + total_sum_iteration_time_s += time_elapsed_s; + int tokens_per_second = (B * T) / time_elapsed_s; + printf("step %4d/%d: train loss %f (%f ms, %d tok/s)\n", step + 1, train_num_batches, model.mean_loss, time_elapsed_s * 1000, tokens_per_second); + logger_log_train(&logger, step, model.mean_loss); + } + // add a total average, for optimizations that are only mild improvements + printf("total average iteration time: %f ms\n", total_sum_iteration_time_s / train_num_batches * 1000); + + // free + dataloader_free(&train_loader); + dataloader_free(&val_loader); + tokenizer_free(&tokenizer); + gpt2_free(&model); + free(cpu_logits); + free(gen_tokens); + cudaCheck(cudaFree(cublaslt_workspace)); + cublasCheck(cublasDestroy(cublas_handle)); + cublasCheck(cublasLtDestroy(cublaslt_handle)); + logger_free(&logger); + + return 0; +} +#endif
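+
+// a quick note on the tok/s number printed in the training loop above, assuming the default
+// B=4, T=1024: each step consumes B*T = 4096 tokens, so a hypothetical 40 ms step would
+// report roughly 102,000 tok/s (4096 / 0.040). arithmetic is illustrative only.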