Merge branch 'al0vya-matmul-backward-bias'
karpathy committed Apr 22, 2024
2 parents f813d63 + 9884895 commit 5f545ca
Showing 2 changed files with 94 additions and 25 deletions.
55 changes: 54 additions & 1 deletion dev/cuda/matmul_backward_bias.cu
@@ -7,13 +7,15 @@ nvcc -O3 matmul_backward_bias.cu -lineinfo -o matmul_backward_bias
./matmul_backward_bias 1
./matmul_backward_bias 2
./matmul_backward_bias 3
./matmul_backward_bias 4
ncu:
sudo ncu --set full --import-source yes -o bias -f ./matmul_backward_bias 1
*/

#include <stdio.h>
#include <stdlib.h>
#include <assert.h>
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <omp.h>
@@ -124,6 +126,46 @@ __global__ void matmul_backward_bias_kernel3(float* dbias, const float* dout, int B, int T, int OC) {
}
}

// this kernel performs a column-wise reduction over dout, in PyTorch equivalent to:
// dbias = dout.sum((0,1))
// the idea is to employ one block to reduce along several columns,
// where each block has a width of 32 columns to ensure coalesced access.
// at the end we accumulate the reductions performed by the warps in each block via shared memory
__global__ void matmul_backward_bias_kernel4(float* dbias, const float* dout, int B, int T, int OC) {
// this kernel is launched with 1D grid_dim of OC/32
// for example let's say block_size is 128
extern __shared__ float smem[]; // of size block_size (128)
const int warp_id = threadIdx.x / warpSize; // warp index in the block, 0,1,2,3
const int lane_id = threadIdx.x % warpSize; // thread index in the warp, 0,1,2,...,31
const int tl = blockIdx.x * warpSize; // index of the first column handled by this block
const int vstep = blockDim.x / warpSize; // number of warps in a block, e.g. 4

// pointer to the start of the column for one lane of threads
// so e.g. the 4 threads that share this lane_id (one from each warp) will reduce this one column
const float* dout_col = dout + tl + lane_id;

// column reductions by looping through the rows
// each of the 4 threads offsets by its warp_id and then skips by vstep
// together these 4 threads cover all B*T rows of this (lane_id) column
// importantly, consecutive threads (by threadIdx.x) process adjacent columns,
// leading to a coalesced memory access pattern
float dout_sum = 0.0f;
for (int row = warp_id; row < B * T; row += vstep) {
dout_sum += dout_col[row * OC];
}
smem[lane_id + warp_id * warpSize] = dout_sum;
__syncthreads();

// warp_id 0 reduces the shared memory column-wise, linearly
dout_sum = 0.0f;
if (warp_id == 0) {
for (int j = 0; j < vstep; j++) {
dout_sum += smem[lane_id + j * warpSize];
}
dbias[tl + lane_id] += dout_sum;
}
}
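For reference, the reduction this kernel implements has a simple CPU equivalent. The sketch below is illustrative only; the function name is made up and it is not the dev file's own CPU reference, which is not shown in this diff:

// illustrative CPU reference for dbias = dout.sum((0,1))
// dout has shape (B, T, OC), treated as a (B*T, OC) matrix; dbias has shape (OC,)
void matmul_backward_bias_cpu_sketch(float* dbias, const float* dout, int B, int T, int OC) {
    for (int o = 0; o < OC; o++) {
        double sum = 0.0; // double accumulator for a slightly more accurate reference
        for (int bt = 0; bt < B * T; bt++) {
            sum += dout[bt * OC + o]; // walk down column o
        }
        dbias[o] += (float)sum; // += to match the kernel's accumulate semantics
    }
}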

// ----------------------------------------------------------------------------
// kernel launcher

@@ -152,6 +194,14 @@ void matmul_backward_bias3(float* dinp, float* dweight, float* dbias,
matmul_backward_bias_kernel3<<<OC, block_size>>>(dbias, dout, B, T, OC);
}

void matmul_backward_bias4(float* dinp, float* dweight, float* dbias,
float* dout, float* inp, float* weight, float* ones,
int B, int T, int C, int OC, int block_size) {
assert(OC % 32 == 0); // OC must be divisible by 32 for this kernel
const int grid_size = OC / 32;
matmul_backward_bias_kernel4<<<grid_size, block_size, block_size * sizeof(float)>>>(dbias, dout, B, T, OC);
}
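To make the launch arithmetic concrete, a small worked example; the sizes are hypothetical and not taken from the harness:

// hypothetical sizes, for illustration only
const int OC = 768, block_size = 128;
const int grid_size = OC / 32;                        // 24 blocks, each owning 32 consecutive columns
const int warps_per_block = block_size / 32;          // 4 warps per block, i.e. vstep = 4
const size_t smem_bytes = block_size * sizeof(float); // 128 floats = 512 bytes of shared memory per block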

void matmul_backward_bias(int kernel_num,
float* dinp, float* dweight, float* dbias,
float* dout, float* inp, float* weight, float* ones,
@@ -166,6 +216,9 @@ void matmul_backward_bias(int kernel_num,
case 3:
matmul_backward_bias3(dinp, dweight, dbias, dout, inp, weight, ones, B, T, C, OC, block_size);
break;
case 4:
matmul_backward_bias4(dinp, dweight, dbias, dout, inp, weight, ones, B, T, C, OC, block_size);
break;
default:
printf("Invalid kernel number\n");
exit(1);
@@ -230,7 +283,7 @@ int main(int argc, char **argv) {
matmul_backward_bias(kernel_num, NULL, NULL, d_dbias, d_dout, NULL, NULL, NULL, B, T, C, OC, 128);
// compare
printf("Checking correctness...\n");
validate_result(d_dbias, dbias, "dbias", OC, 1e-3f);
validate_result(d_dbias, dbias, "dbias", OC, 5e-3f);
printf("All results match for block_size=%d.\n\n", block_size);
}
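The launcher signatures above still take a ones buffer that kernel 4 does not use; that argument hints at an alternative way to express the same reduction as a matrix-vector product, dbias += dout^T @ ones. A sketch of that formulation with cuBLAS follows; it assumes a valid cublas_handle and a device buffer ones filled with B*T ones, and is offered for comparison only, not as one of the numbered kernels:

// row-major dout of shape (B*T, OC) is, viewed column-major, an OC x (B*T) matrix with leading dimension OC,
// so a no-transpose SGEMV sums every column of dout into dbias
float one = 1.0f;
cublasSgemv(cublas_handle, CUBLAS_OP_N,
            OC, B * T,         // rows, cols of the column-major view of dout
            &one, dout, OC,    // alpha, A, lda
            ones, 1,           // x, incx
            &one, dbias, 1);   // beta = 1 accumulates into dbias; y, incy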

64 changes: 40 additions & 24 deletions train_gpt2.cu
@@ -532,27 +532,43 @@ __global__ void softmax_forward_kernel7(float* out, const float* inp, int N, int
}
}

// cooperative groups solution, one warp per output channel
__global__ void matmul_backward_bias_kernel2(float* dbias, const float* dout, int B, int T, int OC) {
// dout is (B, T, OC), dbias is (OC)
// e.g. if block_size = 128, then we have 4 warps per block, each in charge of one output channel
namespace cg = cooperative_groups;
cg::thread_block block = cg::this_thread_block();
cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
// meta_group_size is the number of warps in a block (e.g. 4), meta_group_rank is the warp index (0,1,2,3)
int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();
if(idx >= OC) { return; }
int BT = B * T; // number of elements to reduce in total, per channel
// first, thread coarsening to sum reduce the problem size from B*T to 32
float sum = 0.0f;
for(int i = warp.thread_rank(); i < BT; i += warp.size()) {
sum += dout[i * OC + idx];
}
// now do a warp-level reduce to get the sum across the 32 threads in this warp
sum = cg::reduce(warp, sum, cg::plus<float>{});
// write the result to output (global memory)
if(warp.thread_rank() == 0) {
dbias[idx] += sum;
// this kernel performs a column-wise reduction over dout, in PyTorch equivalent to:
// dbias = dout.sum((0,1))
// the idea is to employ one block to reduce along several columns,
// where each block has a width of 32 columns to ensure coalesced access.
// at the end we accumulate the reductions performed by the warps in each block via shared memory
__global__ void matmul_backward_bias_kernel4(float* dbias, const float* dout, int B, int T, int OC) {
// this kernel is launched with 1D grid_dim of OC/32
// for example let's say block_size is 128
extern __shared__ float smem[]; // of size block_size (128)
const int warp_id = threadIdx.x / warpSize; // warp index in the block, 0,1,2,3
const int lane_id = threadIdx.x % warpSize; // thread index in the warp, 0,1,2,...,31
const int tl = blockIdx.x * warpSize; // index of the first column handled by this block
const int vstep = blockDim.x / warpSize; // number of warps in a block, e.g. 4

// pointer to the start of the column for one lane of threads
// so e.g. the 4 threads that share this lane_id (one from each warp) will reduce this one column
const float* dout_col = dout + tl + lane_id;

// column reductions by looping through the rows
// each of the 4 threads offsets by its warp_id and then skips by vstep
// together these 4 threads cover all B*T rows of this (lane_id) column
// importantly, consecutive threads (by threadIdx.x) process adjacent columns,
// leading to a coalesced memory access pattern
float dout_sum = 0.0f;
for (int row = warp_id; row < B * T; row += vstep) {
dout_sum += dout_col[row * OC];
}
smem[lane_id + warp_id * warpSize] = dout_sum;
__syncthreads();

// warp_id 0 reduces the shared memory column-wise, linearly
dout_sum = 0.0f;
if (warp_id == 0) {
for (int j = 0; j < vstep; j++) {
dout_sum += smem[lane_id + j * warpSize];
}
dbias[tl + lane_id] += dout_sum;
}
}

@@ -973,9 +989,9 @@ void matmul_backward(float* dinp, float* dweight, float* dbias,
cublasCheck(cublasSgemm(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_T, C, OC, B*T, &one, inp, C, dout, OC, &one, dweight, C));
// backward to bias, if given, does a +=
if (dbias != NULL) {
const int block_size = 512;
const int grid_size = CEIL_DIV(OC * 32, block_size);
matmul_backward_bias_kernel2<<<grid_size, block_size>>>(dbias, dout, B, T, OC);
const int block_size = 1024;
const int grid_size = OC / 32; // for now, OC must be divisible by 32 for this kernel to work
matmul_backward_bias_kernel4<<<grid_size, block_size, block_size * sizeof(float)>>>(dbias, dout, B, T, OC);
cudaCheck(cudaGetLastError());
}
}
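As an aside, the cublasSgemm call above for dweight is easier to read once the row-major/column-major bookkeeping is unwound: it accumulates dweight += dout^T @ inp. A scalar sketch of the same computation, for reference only:

// element-wise view of the dweight update performed by the SGEMM above
// dweight (OC, C) += dout^T (OC, B*T) @ inp (B*T, C)
for (int o = 0; o < OC; o++) {
    for (int c = 0; c < C; c++) {
        float sum = 0.0f;
        for (int bt = 0; bt < B * T; bt++) {
            sum += dout[bt * OC + o] * inp[bt * C + c];
        }
        dweight[o * C + c] += sum;
    }
}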
