From 55fbd7e6adb78e3e2c46afe416852f8004e14881 Mon Sep 17 00:00:00 2001
From: cip19aac
Date: Mon, 22 Apr 2024 16:35:21 +0100
Subject: [PATCH 1/4] add a faster matmul backward bias kernel that uses coalesced reads and shared memory

---
 dev/cuda/matmul_backward_bias.cu | 51 +++++++++++++++++++++++++-
 train_gpt2.cu                    | 63 ++++++++++++++++++++------------
 2 files changed, 89 insertions(+), 25 deletions(-)

diff --git a/dev/cuda/matmul_backward_bias.cu b/dev/cuda/matmul_backward_bias.cu
index 15753d8bd..40cc727c2 100644
--- a/dev/cuda/matmul_backward_bias.cu
+++ b/dev/cuda/matmul_backward_bias.cu
@@ -124,6 +124,45 @@ __global__ void matmul_backward_bias_kernel3(float* dbias, const float* dout, in
     }
 }
 
+// this kernel essentially performs a column-wise reduction over dout
+// the philosophy of this kernel is to employ one block to reduce
+// along several columns and then to share and accumulate the
+// reductions performed by different warps via shared memory
+__global__ void matmul_backward_bias_kernel4(float* dbias, const float* dout, int B, int T, int OC) {
+
+    const int vstep = blockDim.x / warpSize;
+    const int row = threadIdx.x >> 5; // basically warp_id
+    const int tl = blockIdx.x * warpSize;
+    const int lane = threadIdx.x & (warpSize - 1);
+
+    const float* dout_col = dout + tl + lane;
+
+    extern __shared__ float smem[];
+
+    float dout_sum = 0.0f;
+    // column reductions by looping through the rows:
+    // the loop should not exceed B * T rows
+    for (int j = row; j < B * T; j += vstep) {
+        dout_sum += dout_col[j * OC];
+    }
+
+    smem[lane + row * warpSize] = dout_sum;
+
+    // our kernel assures that entire blocks are running
+    // inside the loop, so we can safely call sync I believe
+    __syncthreads();
+
+    dout_sum = 0.0f;
+
+    if (row == 0) {
+        for (int j = 0; j < vstep; j++) {
+            dout_sum += smem[lane + j * warpSize];
+        }
+
+        dbias[tl + lane] += dout_sum;
+    }
+}
+
 // ----------------------------------------------------------------------------
 // kernel launcher
 
@@ -152,6 +191,13 @@ void matmul_backward_bias3(float* dinp, float* dweight, float* dbias,
     matmul_backward_bias_kernel3<<>>(dbias, dout, B, T, OC);
 }
 
+void matmul_backward_bias4(float* dinp, float* dweight, float* dbias,
+                    float* dout, float* inp, float* weight, float* ones,
+                    int B, int T, int C, int OC, int block_size) {
+    const int grid_size = OC / 32;
+    matmul_backward_bias_kernel4<<<grid_size, block_size, block_size * sizeof(float)>>>(dbias, dout, B, T, OC);
+}
+
 void matmul_backward_bias(int kernel_num,
                     float* dinp, float* dweight, float* dbias,
                     float* dout, float* inp, float* weight, float* ones,
@@ -166,6 +212,9 @@ void matmul_backward_bias(int kernel_num,
         case 3:
             matmul_backward_bias3(dinp, dweight, dbias, dout, inp, weight, ones, B, T, C, OC, block_size);
            break;
+        case 4:
+            matmul_backward_bias4(dinp, dweight, dbias, dout, inp, weight, ones, B, T, C, OC, block_size);
+            break;
         default:
             printf("Invalid kernel number\n");
             exit(1);
@@ -230,7 +279,7 @@ int main(int argc, char **argv) {
         matmul_backward_bias(kernel_num, NULL, NULL, d_dbias, d_dout, NULL, NULL, NULL, B, T, C, OC, 128);
         // compare
         printf("Checking correctness...\n");
-        validate_result(d_dbias, dbias, "dbias", OC, 1e-3f);
+        validate_result(d_dbias, dbias, "dbias", OC, 5e-3f);
         printf("All results match for block_size=%d.\n\n", block_size);
     }
 
diff --git a/train_gpt2.cu b/train_gpt2.cu
index de7971392..0f75947ee 100644
--- a/train_gpt2.cu
+++ b/train_gpt2.cu
@@ -532,27 +532,42 @@ __global__ void softmax_forward_kernel7(float* out, const float* inp, int N, int
     }
 }
 
-// cooperative groups solution, one warp per output channel
-__global__ void matmul_backward_bias_kernel2(float* dbias, const float* dout, int B, int T, int OC) {
-    // dout is (B, T, OC), dbias is (OC)
-    // e.g. if block_size = 128, then we have 4 warps per block, each in charge of one output channel
-    namespace cg = cooperative_groups;
-    cg::thread_block block = cg::this_thread_block();
-    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
-    // meta_group_size is the number of warps in a block (e.g. 4), meta_group_rank is the warp index (0,1,2,3)
-    int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();
-    if(idx >= OC) { return; }
-    int BT = B * T; // number of elements to reduce in total, per channel
-    // first, thread coarsening to sum reduce the problem size from B*T to 32
-    float sum = 0.0f;
-    for(int i = warp.thread_rank(); i < BT; i += warp.size()) {
-        sum += dout[i * OC + idx];
-    }
-    // now do a warp-level reduce to get the sum across the 32 threads in this warp
-    sum = cg::reduce(warp, sum, cg::plus<float>{});
-    // write the result to output (global memory)
-    if(warp.thread_rank() == 0) {
-        dbias[idx] += sum;
+// this kernel essentially performs a column-wise reduction over dout
+// the philosophy of the kernel is to employ one block to reduce
+// along several columns and then to share and accumulate the
+// reductions performed by different warps via shared memory
+__global__ void matmul_backward_bias_kernel4(float* dbias, const float* dout, int B, int T, int OC) {
+
+    const int vstep = blockDim.x / warpSize;
+    const int row = threadIdx.x >> 5; // basically warp_id
+    const int tl = blockIdx.x * warpSize;
+    const int lane = threadIdx.x & (warpSize - 1);
+
+    const float* dout_col = dout + tl + lane;
+
+    extern __shared__ float smem[];
+
+    float dout_sum = 0.0f;
+    // column reductions by looping through the rows:
+    // the loop should not exceed B * T rows
+    for (int j = row; j < B * T; j += vstep) {
+        dout_sum += dout_col[j * OC];
+    }
+
+    smem[lane + row * warpSize] = dout_sum;
+
+    // our kernel assures that entire blocks are running
+    // inside the loop, so we can safely call sync I believe
+    __syncthreads();
+
+    dout_sum = 0.0f;
+
+    if (row == 0) {
+        for (int j = 0; j < vstep; j++) {
+            dout_sum += smem[lane + j * warpSize];
+        }
+
+        dbias[tl + lane] += dout_sum;
     }
 }
 
@@ -973,9 +988,9 @@ void matmul_backward(float* dinp, float* dweight, float* dbias,
     cublasCheck(cublasSgemm(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_T, C, OC, B*T, &one, inp, C, dout, OC, &one, dweight, C));
     // backward to bias, if given, does a +=
     if (dbias != NULL) {
-        const int block_size = 512;
-        const int grid_size = CEIL_DIV(OC * 32, block_size);
-        matmul_backward_bias_kernel2<<<grid_size, block_size>>>(dbias, dout, B, T, OC);
+        const int block_size = 1024;
+        const int grid_size = OC / 32; // for now, OC must be divisible by 32 for this kernel to work
+        matmul_backward_bias_kernel4<<<grid_size, block_size, block_size * sizeof(float)>>>(dbias, dout, B, T, OC);
         cudaCheck(cudaGetLastError());
     }
 }

From b82ec201beef935c0e5ffd90cec6daab55c258f5 Mon Sep 17 00:00:00 2001
From: cip19aac
Date: Mon, 22 Apr 2024 17:05:05 +0100
Subject: [PATCH 2/4] add comment

---
 dev/cuda/matmul_backward_bias.cu | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/dev/cuda/matmul_backward_bias.cu b/dev/cuda/matmul_backward_bias.cu
index 40cc727c2..9f085825a 100644
--- a/dev/cuda/matmul_backward_bias.cu
+++ b/dev/cuda/matmul_backward_bias.cu
@@ -194,7 +194,7 @@ void matmul_backward_bias3(float* dinp, float* dweight, float* dbias,
 void matmul_backward_bias4(float* dinp, float* dweight, float* dbias,
                     float* dout, float* inp, float* weight, float* ones,
                     int B, int T, int C, int OC, int block_size) {
-    const int grid_size = OC / 32;
+    const int grid_size = OC / 32; // for now, OC must be divisible by 32 for this kernel to work
     matmul_backward_bias_kernel4<<<grid_size, block_size, block_size * sizeof(float)>>>(dbias, dout, B, T, OC);
 }
 

From 35393b444265f89d562018f390ee1010866739d4 Mon Sep 17 00:00:00 2001
From: cip19aac
Date: Mon, 22 Apr 2024 17:29:11 +0100
Subject: [PATCH 3/4] add more comments to explain the philosophy behind the kernel

---
 dev/cuda/matmul_backward_bias.cu | 10 ++++++----
 train_gpt2.cu                    | 10 ++++++----
 2 files changed, 12 insertions(+), 8 deletions(-)

diff --git a/dev/cuda/matmul_backward_bias.cu b/dev/cuda/matmul_backward_bias.cu
index 9f085825a..c550c81f8 100644
--- a/dev/cuda/matmul_backward_bias.cu
+++ b/dev/cuda/matmul_backward_bias.cu
@@ -124,10 +124,12 @@ __global__ void matmul_backward_bias_kernel3(float* dbias, const float* dout, in
     }
 }
 
-// this kernel essentially performs a column-wise reduction over dout
-// the philosophy of this kernel is to employ one block to reduce
-// along several columns and then to share and accumulate the
-// reductions performed by different warps via shared memory
+// this kernel essentially performs a column-wise reduction over dout,
+// which in pytorch would simply look like: dbias = dout.sum((0,1))
+// the philosophy of this kernel is to employ one block to reduce along
+// several columns, whereby each block has a "width" of 32 columns to ensure
+// coalesced access. near the end of the column-wise reduction, we accumulate
+// the reductions performed by the warps in each block via shared memory
 __global__ void matmul_backward_bias_kernel4(float* dbias, const float* dout, int B, int T, int OC) {
 
     const int vstep = blockDim.x / warpSize;
diff --git a/train_gpt2.cu b/train_gpt2.cu
index 0f75947ee..079ecb42b 100644
--- a/train_gpt2.cu
+++ b/train_gpt2.cu
@@ -532,10 +532,12 @@ __global__ void softmax_forward_kernel7(float* out, const float* inp, int N, int
     }
 }
 
-// this kernel essentially performs a column-wise reduction over dout
-// the philosophy of the kernel is to employ one block to reduce
-// along several columns and then to share and accumulate the
-// reductions performed by different warps via shared memory
+// this kernel essentially performs a column-wise reduction over dout,
+// which in pytorch would simply look like: dbias = dout.sum((0,1))
+// the philosophy of this kernel is to employ one block to reduce along
+// several columns, whereby each block has a "width" of 32 columns to ensure
+// coalesced access. near the end of the column-wise reduction, we accumulate
+// the reductions performed by the warps in each block via shared memory
 __global__ void matmul_backward_bias_kernel4(float* dbias, const float* dout, int B, int T, int OC) {
 
     const int vstep = blockDim.x / warpSize;

From 988489519c8ea73a613990f7020b1c082c3b51c2 Mon Sep 17 00:00:00 2001
From: Andrej Karpathy
Date: Mon, 22 Apr 2024 17:51:06 +0000
Subject: [PATCH 4/4] nice new kernel for bias reduce, we are down by ~1ms/iter, to 76.13ms/iter on average

---
 dev/cuda/matmul_backward_bias.cu | 68 ++++++++++++++++++------------------
 train_gpt2.cu                    | 65 +++++++++++++++++-----------------
 2 files changed, 67 insertions(+), 66 deletions(-)

diff --git a/dev/cuda/matmul_backward_bias.cu b/dev/cuda/matmul_backward_bias.cu
index c550c81f8..7feab39ea 100644
--- a/dev/cuda/matmul_backward_bias.cu
+++ b/dev/cuda/matmul_backward_bias.cu
@@ -7,6 +7,7 @@ nvcc -O3 matmul_backward_bias.cu -lineinfo -o matmul_backward_bias
 ./matmul_backward_bias 1
 ./matmul_backward_bias 2
 ./matmul_backward_bias 3
+./matmul_backward_bias 4
 
 ncu:
 sudo ncu --set full --import-source yes -o bias -f ./matmul_backward_bias 1
@@ -14,6 +15,7 @@ sudo ncu --set full --import-source yes -o bias -f ./matmul_backward_bias 1
 
 #include <stdio.h>
 #include <stdlib.h>
+#include <assert.h>
 #include <cuda_runtime.h>
 #include <cooperative_groups.h>
 #include <cooperative_groups/reduce.h>
@@ -124,44 +126,43 @@ __global__ void matmul_backward_bias_kernel3(float* dbias, const float* dout, in
     }
 }
 
-// this kernel essentially performs a column-wise reduction over dout,
-// which in pytorch would simply look like: dbias = dout.sum((0,1))
-// the philosophy of this kernel is to employ one block to reduce along
-// several columns, whereby each block has a "width" of 32 columns to ensure
-// coalesced access. near the end of the column-wise reduction, we accumulate
-// the reductions performed by the warps in each block via shared memory
+// this kernel performs a column-wise reduction over dout, in PyTorch equivalent to:
+// dbias = dout.sum((0,1))
+// the idea is to employ one block to reduce along several columns,
+// where each block has a width of 32 columns to ensure coalesced access.
+// at the end we accumulate the reductions performed by the warps in each block via shared memory
 __global__ void matmul_backward_bias_kernel4(float* dbias, const float* dout, int B, int T, int OC) {
-
-    const int vstep = blockDim.x / warpSize;
-    const int row = threadIdx.x >> 5; // basically warp_id
-    const int tl = blockIdx.x * warpSize;
-    const int lane = threadIdx.x & (warpSize - 1);
-
-    const float* dout_col = dout + tl + lane;
-
-    extern __shared__ float smem[];
-
+    // this kernel is launched with 1D grid_dim of OC/32
+    // for example let's say block_size is 128
+    extern __shared__ float smem[]; // of size block_size (128)
+    const int warp_id = threadIdx.x / warpSize; // warp index in the block, 0,1,2,3
+    const int lane_id = threadIdx.x % warpSize; // thread index in the warp, 0,1,2,...,31
+    const int tl = blockIdx.x * warpSize; // pointer to the start column for this block
+    const int vstep = blockDim.x / warpSize; // number of warps in a block, e.g. 4
+
+    // pointer to the start of the column for one lane of threads
+    // so e.g. 4 threads (of the same lane_id) will reduce this one column
+    const float* dout_col = dout + tl + lane_id;
+
+    // column reductions by looping through the rows
+    // each of the 4 threads offsets by its warp_id and then skips by vstep
+    // together these 4 threads cover all B*T rows of this (lane_id) column
+    // importantly, consecutive threads (in threadId) are processing adjacent columns,
+    // leading to a coalesced memory access pattern
     float dout_sum = 0.0f;
-    // column reductions by looping through the rows:
-    // the loop should not exceed B * T rows
-    for (int j = row; j < B * T; j += vstep) {
-        dout_sum += dout_col[j * OC];
+    for (int row = warp_id; row < B * T; row += vstep) {
+        dout_sum += dout_col[row * OC];
     }
-
-    smem[lane + row * warpSize] = dout_sum;
-
-    // our kernel assures that entire blocks are running
-    // inside the loop, so we can safely call sync I believe
-    __syncthreads();
-
+    smem[lane_id + warp_id * warpSize] = dout_sum;
+    __syncthreads();
+
+    // warp_id 0 reduces the shared memory column-wise, linearly
     dout_sum = 0.0f;
-
-    if (row == 0) {
+    if (warp_id == 0) {
         for (int j = 0; j < vstep; j++) {
-            dout_sum += smem[lane + j * warpSize];
+            dout_sum += smem[lane_id + j * warpSize];
         }
-
-        dbias[tl + lane] += dout_sum;
+        dbias[tl + lane_id] += dout_sum;
     }
 }
 
@@ -196,7 +197,8 @@ void matmul_backward_bias3(float* dinp, float* dweight, float* dbias,
 void matmul_backward_bias4(float* dinp, float* dweight, float* dbias,
                     float* dout, float* inp, float* weight, float* ones,
                     int B, int T, int C, int OC, int block_size) {
-    const int grid_size = OC / 32; // for now, OC must be divisible by 32 for this kernel to work
+    assert(OC % 32 == 0); // OC must be divisible by 32 for this kernel
+    const int grid_size = OC / 32;
     matmul_backward_bias_kernel4<<<grid_size, block_size, block_size * sizeof(float)>>>(dbias, dout, B, T, OC);
 }
 
diff --git a/train_gpt2.cu b/train_gpt2.cu
index 079ecb42b..b66bdb55b 100644
--- a/train_gpt2.cu
+++ b/train_gpt2.cu
@@ -532,43 +532,42 @@ __global__ void softmax_forward_kernel7(float* out, const float* inp, int N, int
     }
 }
 
-// this kernel essentially performs a column-wise reduction over dout,
-// which in pytorch would simply look like: dbias = dout.sum((0,1))
-// the philosophy of this kernel is to employ one block to reduce along
-// several columns, whereby each block has a "width" of 32 columns to ensure
-// coalesced access. near the end of the column-wise reduction, we accumulate
-// the reductions performed by the warps in each block via shared memory
+// this kernel performs a column-wise reduction over dout, in PyTorch equivalent to:
+// dbias = dout.sum((0,1))
+// the idea is to employ one block to reduce along several columns,
+// where each block has a width of 32 columns to ensure coalesced access.
+// at the end we accumulate the reductions performed by the warps in each block via shared memory
 __global__ void matmul_backward_bias_kernel4(float* dbias, const float* dout, int B, int T, int OC) {
-
-    const int vstep = blockDim.x / warpSize;
-    const int row = threadIdx.x >> 5; // basically warp_id
-    const int tl = blockIdx.x * warpSize;
-    const int lane = threadIdx.x & (warpSize - 1);
-
-    const float* dout_col = dout + tl + lane;
-
-    extern __shared__ float smem[];
-
+    // this kernel is launched with 1D grid_dim of OC/32
+    // for example let's say block_size is 128
+    extern __shared__ float smem[]; // of size block_size (128)
+    const int warp_id = threadIdx.x / warpSize; // warp index in the block, 0,1,2,3
+    const int lane_id = threadIdx.x % warpSize; // thread index in the warp, 0,1,2,...,31
+    const int tl = blockIdx.x * warpSize; // pointer to the start column for this block
+    const int vstep = blockDim.x / warpSize; // number of warps in a block, e.g. 4
+
+    // pointer to the start of the column for one lane of threads
+    // so e.g. 4 threads (of the same lane_id) will reduce this one column
+    const float* dout_col = dout + tl + lane_id;
+
+    // column reductions by looping through the rows
+    // each of the 4 threads offsets by its warp_id and then skips by vstep
+    // together these 4 threads cover all B*T rows of this (lane_id) column
+    // importantly, consecutive threads (in threadId) are processing adjacent columns,
+    // leading to a coalesced memory access pattern
     float dout_sum = 0.0f;
-    // column reductions by looping through the rows:
-    // the loop should not exceed B * T rows
-    for (int j = row; j < B * T; j += vstep) {
-        dout_sum += dout_col[j * OC];
-    }
-
-    smem[lane + row * warpSize] = dout_sum;
-
-    // our kernel assures that entire blocks are running
-    // inside the loop, so we can safely call sync I believe
-    __syncthreads();
-
+    for (int row = warp_id; row < B * T; row += vstep) {
+        dout_sum += dout_col[row * OC];
+    }
+    smem[lane_id + warp_id * warpSize] = dout_sum;
+    __syncthreads();
+
+    // warp_id 0 reduces the shared memory column-wise, linearly
     dout_sum = 0.0f;
-
-    if (row == 0) {
+    if (warp_id == 0) {
         for (int j = 0; j < vstep; j++) {
-            dout_sum += smem[lane + j * warpSize];
+            dout_sum += smem[lane_id + j * warpSize];
         }
-
-        dbias[tl + lane] += dout_sum;
+        dbias[tl + lane_id] += dout_sum;
     }
 }
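
Note on the reduction these patches implement: dbias is the bias gradient of a matmul whose output activations have shape (B, T, OC), so the backward pass is a column-wise sum, dbias[o] += sum over all (b, t) of dout[b, t, o], i.e. the dbias = dout.sum((0,1)) mentioned in the kernel comments. Kernel 4 gives each block a 32-column stripe of dout (a grid of OC/32 blocks), lets each warp stride over the B*T rows of that stripe, and combines the per-warp partial sums through block_size floats of dynamic shared memory. Below is a minimal CPU sketch of the same reduction, useful as a mental model of what the kernel must reproduce; the function name and the small driver are illustrative assumptions, not code from the patches (the dev/cuda benchmark has its own CPU reference and validation harness).

    #include <stdio.h>
    #include <stdlib.h>

    // CPU reference for the column-wise reduction over dout of shape (B, T, OC):
    // dbias[o] += sum_{b,t} dout[(b*T + t)*OC + o], matching the kernels' "+=" semantics.
    void matmul_backward_bias_cpu_ref(float* dbias, const float* dout, int B, int T, int OC) {
        for (int o = 0; o < OC; o++) {
            double sum = 0.0; // accumulate in double so the reference is a bit more accurate
            for (int bt = 0; bt < B * T; bt++) {
                sum += dout[bt * OC + o]; // dout is row-major (B, T, OC)
            }
            dbias[o] += (float)sum; // += because the kernels accumulate into dbias
        }
    }

    int main(void) {
        const int B = 2, T = 4, OC = 64; // toy sizes for illustration
        float* dout  = (float*)malloc((size_t)B * T * OC * sizeof(float));
        float* dbias = (float*)calloc(OC, sizeof(float));
        for (int i = 0; i < B * T * OC; i++) { dout[i] = 0.01f * (float)(i % 13); }
        matmul_backward_bias_cpu_ref(dbias, dout, B, T, OC);
        printf("dbias[0] = %f, dbias[%d] = %f\n", dbias[0], OC - 1, dbias[OC - 1]);
        free(dout); free(dbias);
        return 0;
    }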