
Commit

integrate our rmsnorm backward and move the other rmsnorm functions into rmsnorm.cuh, which is a new file
karpathy committed Sep 26, 2024
1 parent 102067f commit 2c4b3cc
Showing 4 changed files with 419 additions and 214 deletions.
179 changes: 0 additions & 179 deletions llmc/layernorm.cuh
@@ -139,66 +139,6 @@ __global__ void layernorm_forward_kernel6(floatX* __restrict__ out, float* __res
}
}

__global__ void rmsnorm_forward_kernel6(floatX* __restrict__ out, float* __restrict__ rms,
const floatX* __restrict__ inp, const floatX* __restrict__ weight, int N, int C) {
// this kernel is a simplified version of layernorm_forward_kernel6
assert(blockDim.x == WARP_SIZE);

// load weights into shared memory
// do this before we allow any threads to exit!
extern __shared__ char* params[];
// load128/store128 sometimes generated multiple instructions when the types here were floatX*, so
// let's keep everything as x128
x128* s_weight = reinterpret_cast<x128*>(params);
x128* s_in = reinterpret_cast<x128*>(params) + ((1 + threadIdx.y) * C / x128::size);

int sidx = (threadIdx.x + WARP_SIZE * threadIdx.y) * x128::size;
for(int i = sidx; i < C; i += blockDim.y * WARP_SIZE * x128::size) {
s_weight[i/x128::size] = load128(weight + i);
}
__syncthreads();

int idx = blockIdx.x * blockDim.y + threadIdx.y;
if(idx >= N) { return; } // guard

// adjust pointers to current token
inp += idx * C;
out += idx * C;

const float eps = 1e-5f;
float acc = 0.f;

for(int c = threadIdx.x * x128::size; c < C; c += WARP_SIZE * x128::size) {
const x128 in_data = load128cs(inp + c);
s_in[c / x128::size] = in_data;
for(int k = 0; k < x128::size; ++k) {
float data_k = (float)in_data[k];
acc += data_k * data_k;
}
}

acc = warpReduceSum(acc) / C;
float s = rsqrtf(acc + eps);

for(int c = threadIdx.x * x128::size; c < C; c += WARP_SIZE * x128::size) {
const x128 in_data = s_in[c / x128::size];
const x128 w = s_weight[c / x128::size];
x128 out_data;
for(int k = 0; k < x128::size; ++k) {
float n = s * (float)in_data[k]; // normalized output
float o = n * (float)w[k]; // scale
out_data[k] = (floatX)o;
}

store128cs(out + c, out_data);
}

// store the reciprocal rms (s = 1/rms); streaming store, no need to keep it in cache
if(threadIdx.x == 0 && rms != nullptr) {
__stcs(rms + idx, s);
}
}
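
As a sanity reference for the kernel above, a minimal scalar sketch of the same per-token math (my own illustration, not part of this commit; plain float stands in for floatX, and rmsnorm_forward_cpu is a hypothetical name):

#include <math.h>
#include <stddef.h>

// Scalar version of what rmsnorm_forward_kernel6 computes: for each of N tokens,
// out = inp * rsqrt(mean(inp^2) + eps) * weight, with the reciprocal RMS
// optionally written out so the backward pass can reuse it.
void rmsnorm_forward_cpu(float* out, float* rms, const float* inp,
                         const float* weight, int N, int C) {
    const float eps = 1e-5f;
    for (int i = 0; i < N; i++) {
        const float* x = inp + i * C;
        float acc = 0.f;
        for (int c = 0; c < C; c++) { acc += x[c] * x[c]; } // sum of squares
        float s = 1.0f / sqrtf(acc / C + eps);              // reciprocal RMS
        for (int c = 0; c < C; c++) { out[i * C + c] = s * x[c] * weight[c]; }
        if (rms != NULL) { rms[i] = s; }
    }
}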

__global__ void fused_residual_forward_kernel5(floatX* residual, floatX* normed, float* mean, float* rstd,
const floatX* inp1, const floatX* inp2,
const floatX* weight, const floatX* bias,
@@ -278,77 +218,6 @@ __global__ void fused_residual_forward_kernel5(floatX* residual, floatX* normed,
}
}

__global__ void fused_residual_rmsnorm_forward_kernel5(floatX* residual, floatX* normed, float* rrms,
const floatX* inp1, const floatX* inp2,
const floatX* weight,
int N, int C) {
assert(blockDim.x == WARP_SIZE);

// load weights into shared memory (rmsnorm has no bias)
// do this before we allow any threads to exit!
extern __shared__ char* params[];
// load128/store128 sometimes generated multiple instructions when the types here were floatX*, so
// let's keep everything as x128
x128* s_weight = reinterpret_cast<x128*>(params);
x128* s_res = reinterpret_cast<x128*>(params) + ((1 + threadIdx.y) * C / x128::size);

int sidx = (threadIdx.x + WARP_SIZE * threadIdx.y) * x128::size;
for(int i = sidx; i < C; i += blockDim.y * WARP_SIZE * x128::size) {
s_weight[i/x128::size] = load128(weight + i);
}
__syncthreads();

int idx = blockIdx.x * blockDim.y + threadIdx.y;
if(idx >= N) { return; } // guard (>=, so idx == N cannot read out of bounds)

// adjust pointers to current token
residual += C * idx;
normed += C * idx;
inp1 += C * idx;
inp2 += C * idx;

const float eps = 1e-5f;
for(int c = threadIdx.x * x128::size; c < C; c += WARP_SIZE * x128::size) {
const x128 in1 = load128cs(inp1 + c);
const x128 in2 = load128cs(inp2 + c);
x128 out;
for(int k = 0; k < x128::size; ++k) {
out[k] = (float)in1[k] + (float)in2[k];
}
store128cs(residual + c, out);
s_res[c / x128::size] = out;
}

float v = 0.f;

for(int c = threadIdx.x * x128::size; c < C; c += WARP_SIZE * x128::size) {
const x128 res = s_res[c / x128::size];
for(int k = 0; k < x128::size; ++k) {
v += (float)res[k] * (float)res[k];
}
}

v = warpReduceSum(v) / C;
float s = rsqrtf(v + eps);

for(int c = threadIdx.x * x128::size; c < C; c += WARP_SIZE * x128::size) {
const x128 res = s_res[c / x128::size];
const x128 w = s_weight[c / x128::size];
x128 out;
for(int k = 0; k < x128::size; ++k) {
float n = s * (float)res[k]; // normalized output
float o = n * (float)w[k]; // scale
out[k] = (floatX)o;
}

store128cs(normed + c, out);
}
// cache the rrms for the backward pass later
if(threadIdx.x == 0) {
rrms[idx] = s;
}
}
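
In hedged scalar form, the fused kernel does the following (illustrative only; fused_residual_rmsnorm_cpu is an invented name): add the two input streams, write the sum back out as the residual, then RMS-normalize that sum.

#include <math.h>

// Scalar version of fused_residual_rmsnorm_forward_kernel5:
// residual = inp1 + inp2, normed = residual * rsqrt(mean(residual^2) + eps) * weight,
// and the per-token reciprocal RMS is cached in rrms.
void fused_residual_rmsnorm_cpu(float* residual, float* normed, float* rrms,
                                const float* inp1, const float* inp2,
                                const float* weight, int N, int C) {
    const float eps = 1e-5f;
    for (int i = 0; i < N; i++) {
        float v = 0.f;
        for (int c = 0; c < C; c++) {
            float r = inp1[i * C + c] + inp2[i * C + c];
            residual[i * C + c] = r;
            v += r * r;
        }
        float s = 1.0f / sqrtf(v / C + eps);
        for (int c = 0; c < C; c++) {
            normed[i * C + c] = s * residual[i * C + c] * weight[c];
        }
        rrms[i] = s; // reused by the backward pass
    }
}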

__global__ void residual_forward_kernel(floatX* out, const floatX* inp1, const floatX* inp2) {
int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size;

@@ -620,30 +489,6 @@ void fused_residual_forward5(floatX* residual, floatX* normed, float* mean, floa
cudaCheck(cudaGetLastError());
}

void fused_residual_rmsnorm_forward5(floatX* residual, floatX* normed, float* rrms,
const floatX* inp1, const floatX* inp2,
const floatX* weight,
int N, int C, cudaStream_t stream) {
const int block_size = 256;
int block_y = block_size / WARP_SIZE;
const int grid_size = CEIL_DIV(N, block_y);
size_t smem = (1 + block_y) * C * sizeof(floatX);

// in order to use more than 48 KiB of smem, need to call cudaFuncSetAttribute
// this can fail; there is no smem-free fallback for this kernel yet, so we error out below
cudaCheck(cudaGetLastError());
auto status = cudaFuncSetAttribute(fused_residual_rmsnorm_forward_kernel5, cudaFuncAttributeMaxDynamicSharedMemorySize, smem);
cudaCheck(cudaGetLastError());
if(status == cudaSuccess) {
fused_residual_rmsnorm_forward_kernel5<<<grid_size, dim3(WARP_SIZE, block_y), smem, stream>>>(residual, normed,
rrms, inp1, inp2,
weight, N, C);
} else {
assert(false);
}
cudaCheck(cudaGetLastError());
}
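
For a sense of scale, a worked shared-memory budget under assumed shapes (my numbers, not from this diff):

// block_size = 256 and WARP_SIZE = 32 give block_y = 8 warps per block; each warp
// needs its own residual buffer in smem, plus one shared copy of the weights:
//   smem = (1 + block_y) * C * sizeof(floatX)
//        = 9 * 768 * 2 bytes = 13824 bytes ~ 13.5 KiB   (C = 768, bf16)
// That fits under the default 48 KiB limit, so the cudaFuncSetAttribute opt-in
// mainly matters at larger widths (e.g. C = 4096 in bf16 needs 72 KiB).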

void layernorm_backward(floatX* dinp, floatX* dweight, floatX* dbias, float* scratch,
const floatX* dout, const floatX* inp, const floatX* weight, const float* mean, const float* rstd,
int B, int T, int C, cudaStream_t stream) {
@@ -658,27 +503,3 @@ void layernorm_backward(floatX* dinp, floatX* dweight, floatX* dbias, float* scr
layernorm_backward_kernel10<<<grid_size, block_size, shared_mem_size, stream>>>(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C);
cudaCheck(cudaGetLastError());
}

void rmsnorm_forward(floatX* out, float* rms,
floatX* inp, const floatX* weight,
int B, int T, int C, cudaStream_t stream) {
NVTX_RANGE_FN();
const int block_size = 256;
int block_y = block_size / WARP_SIZE;
const int N = B * T;
const int grid_size = CEIL_DIV(N, block_y);
size_t smem = (1 + block_y) * C * sizeof(floatX);

// in order to use more than 48 KiB of smem, need to call cudaFuncSetAttribute
// this can fail; there is no smem-free fallback for this kernel yet, so we error out below
cudaCheck(cudaGetLastError());
auto status = cudaFuncSetAttribute(rmsnorm_forward_kernel6, cudaFuncAttributeMaxDynamicSharedMemorySize, smem);
cudaCheck(cudaGetLastError());
if (status == cudaSuccess) {
rmsnorm_forward_kernel6<<<grid_size, dim3(WARP_SIZE, block_y), smem, stream>>>(out, rms, inp, weight, N, C);
} else {
// no smem-free fallback for now; fail loudly rather than silently accept the perf regression
assert(false);
}
cudaCheck(cudaGetLastError());
}
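
The rmsnorm backward that this commit integrates lives in the new rmsnorm.cuh and is not shown in this hunk. For orientation, a hedged scalar sketch of the standard RMSNorm gradients that the cached reciprocal RMS enables (rmsnorm_backward_cpu is an illustrative name, not the kernel from this commit):

// With y_c = s * x_c * w_c and s = rsqrt(mean(x^2) + eps), the gradients are
//   dL/dw_c = sum over tokens of dout_c * s * x_c
//   dL/dx_c = s * w_c * dout_c - (s^3 / C) * x_c * sum_j(dout_j * w_j * x_j)
void rmsnorm_backward_cpu(float* dinp, float* dweight,
                          const float* dout, const float* inp,
                          const float* weight, const float* rms,
                          int N, int C) {
    // caller is expected to zero dweight before accumulation
    for (int i = 0; i < N; i++) {
        const float s = rms[i]; // cached reciprocal RMS from the forward pass
        const float* x = inp + i * C;
        const float* dy = dout + i * C;
        float dot = 0.f;
        for (int c = 0; c < C; c++) { dot += dy[c] * weight[c] * x[c]; }
        for (int c = 0; c < C; c++) {
            dinp[i * C + c] = s * weight[c] * dy[c] - (s * s * s / C) * x[c] * dot;
            dweight[c] += dy[c] * s * x[c]; // accumulated across all tokens
        }
    }
}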