diff --git a/dev/cuda/matmul_forward.cu b/dev/cuda/matmul_forward.cu
index ec13805a3..8f73e5401 100644
--- a/dev/cuda/matmul_forward.cu
+++ b/dev/cuda/matmul_forward.cu
@@ -82,6 +82,55 @@ __global__ void matmul_forward_kernel1(float* out,
     }
 }
+
+template <const int SQRTBLOCKSIZE>
+__global__ void sgemm_shared_mem_block(int BT, int OC, int C,
+                                       const float *inp, const float *weight,
+                                       float *out, const float* bias) {
+    // out is (B,T,OC). OC is short for "output channels", e.g. OC = 4 * C
+    // inp is (B,T,C), weight is (OC, C), bias is (OC)
+    // the output block that we want to compute in this threadblock
+    const uint bt = blockIdx.x;
+    const uint oc = blockIdx.y;
+
+    // allocate buffer for current block in fast shared mem
+    // shared mem is shared between all threads in a block
+    __shared__ float smemInp[SQRTBLOCKSIZE * SQRTBLOCKSIZE];
+    __shared__ float smemWeight[SQRTBLOCKSIZE * SQRTBLOCKSIZE];
+
+    // the inner row & col that we're accessing in this thread
+    const uint threadCol = threadIdx.x % SQRTBLOCKSIZE;
+    const uint threadRow = threadIdx.x / SQRTBLOCKSIZE;
+
+    // advance pointers to the starting positions
+    inp += bt * SQRTBLOCKSIZE * C;
+    weight += oc * SQRTBLOCKSIZE * C;
+    out += bt * SQRTBLOCKSIZE * OC + oc * SQRTBLOCKSIZE;
+
+    float tmp = (bias == nullptr) ? 0.0 : bias[oc * SQRTBLOCKSIZE + threadCol];
+    for (int bkIdx = 0; bkIdx < C; bkIdx += SQRTBLOCKSIZE) {
+        // Have each thread load one of the elements in inp and weight into shared memory
+        smemInp[threadRow * SQRTBLOCKSIZE + threadCol] = inp[threadRow * C + threadCol];
+        smemWeight[threadRow * SQRTBLOCKSIZE + threadCol] = weight[threadCol * C + threadRow];
+
+        // block threads in this block until cache is fully populated
+        __syncthreads();
+        inp += SQRTBLOCKSIZE;
+        weight += SQRTBLOCKSIZE;
+
+        // execute the dotproduct on the currently cached block
+        for (int dotIdx = 0; dotIdx < SQRTBLOCKSIZE; ++dotIdx) {
+            tmp += smemInp[threadRow * SQRTBLOCKSIZE + dotIdx] *
+                   smemWeight[dotIdx * SQRTBLOCKSIZE + threadCol];
+        }
+        // need to sync again at the end, to avoid faster threads
+        // fetching the next block into the cache before slower threads are done
+        __syncthreads();
+    }
+    out[threadRow * OC + threadCol] = tmp;
+}
+
+
 
 // is there no better way other than just adding bias with a whole separate kernel?
 // this is a highly memory-bound operation, should be fused into the matmul kernel
 // but i can't seem to find a cuBLAS function that does this
@@ -149,6 +198,45 @@ void matmul_forward2(float* out,
     }
 }
+
+void matmul_forward4(float *out,
+                     const float *inp, const float *weight, const float* bias,
+                     int B, int T, int C, int OC, int sqrt_block_size) {
+    dim3 gridDim(ceil_div((B * T), sqrt_block_size), ceil_div(OC, sqrt_block_size));
+    dim3 blockDim(sqrt_block_size * sqrt_block_size);
+    // L1 cache becomes useless, since we access GMEM only via SMEM, so we carve
+    // out all of L1 to SMEM. This doesn't currently make a difference, since
+    // occupancy is limited by reg and thread count, but it's good to do anyway.
+    switch (sqrt_block_size) {
+        case 4:
+            cudaFuncSetAttribute(sgemm_shared_mem_block<4>,
+                                 cudaFuncAttributePreferredSharedMemoryCarveout,
+                                 cudaSharedmemCarveoutMaxShared);
+            sgemm_shared_mem_block<4><<<gridDim, blockDim>>>(B*T, OC, C, inp, weight, out, bias);
+            break;
+        case 8:
+            cudaFuncSetAttribute(sgemm_shared_mem_block<8>,
+                                 cudaFuncAttributePreferredSharedMemoryCarveout,
+                                 cudaSharedmemCarveoutMaxShared);
+            sgemm_shared_mem_block<8><<<gridDim, blockDim>>>(B*T, OC, C, inp, weight, out, bias);
+            break;
+        case 16:
+            cudaFuncSetAttribute(sgemm_shared_mem_block<16>,
+                                 cudaFuncAttributePreferredSharedMemoryCarveout,
+                                 cudaSharedmemCarveoutMaxShared);
+            sgemm_shared_mem_block<16><<<gridDim, blockDim>>>(B*T, OC, C, inp, weight, out, bias);
+            break;
+        case 32:
+            cudaFuncSetAttribute(sgemm_shared_mem_block<32>,
+                                 cudaFuncAttributePreferredSharedMemoryCarveout,
+                                 cudaSharedmemCarveoutMaxShared);
+            sgemm_shared_mem_block<32><<<gridDim, blockDim>>>(B*T, OC, C, inp, weight, out, bias);
+            break;
+        default:
+            break;
+    }
+}
+
 
 // uses cublasLt to fuse the bias and gelu
 // https://docs.nvidia.com/cuda/cublas/#cublasltmatmul
 // https://github.com/NVIDIA/CUDALibrarySamples/blob/master/cuBLASLt/LtSgemm/sample_cublasLt_LtSgemm.cu
@@ -244,6 +332,9 @@ void matmul_forward(int kernel_num,
         case 3:
             matmul_forward3(out, inp, weight, bias, B, T, C, OC);
             break;
+        case 4:
+            matmul_forward4(out, inp, weight, bias, B, T, C, OC, sqrt_block_size);
+            break;
         default:
             printf("Invalid kernel number\n");
             exit(1);
@@ -307,6 +398,7 @@ int main(int argc, char **argv) {
 
     // first check the correctness of the kernel
     matmul_forward_cpu(out, inp, weight, bias, B, T, C, OC);
+    // matmul_forward_cpu(out, inp, weight, NULL, B, T, C, OC);
 
     // time the kernel at different block sizes
    int sqrt_block_sizes[] = {4, 8, 16, 32};
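
Note (reviewer sketch, not part of the patch): matmul_forward4 above calls a ceil_div helper that this diff does not show; it is assumed to be the usual integer ceiling-division utility from the repo's common header, and the reference loop below only restates the math the kernel computes (out[bt][oc] = bias[oc] + sum over c of inp[bt][c] * weight[oc][c], with weight stored as (OC, C)), mirroring matmul_forward_cpu in the test harness. Names and signatures here are illustrative, not the repo's exact definitions.

// assumed helper: integer ceiling division, e.g. ceil_div(10, 4) == 3
// (the repo's common header is expected to provide an equivalent)
int ceil_div(int dividend, int divisor) {
    return (dividend + divisor - 1) / divisor;
}

// host-side reference of the kernel's indexing, for comparison against sgemm_shared_mem_block;
// BT = B*T flattened rows of inp, weight is (OC, C) so each output channel uses a row of weight
void matmul_reference(float* out, const float* inp, const float* weight, const float* bias,
                      int BT, int C, int OC) {
    for (int bt = 0; bt < BT; bt++) {
        for (int o = 0; o < OC; o++) {
            float val = (bias == nullptr) ? 0.0f : bias[o];
            for (int c = 0; c < C; c++) {
                val += inp[bt * C + c] * weight[o * C + c];
            }
            out[bt * OC + o] = val;
        }
    }
}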