matmul using shared memory
patricxu committed Apr 26, 2024
1 parent 8ac4b47 commit 8e70e15
Showing 1 changed file with 92 additions and 0 deletions.
dev/cuda/matmul_forward.cu
@@ -82,6 +82,55 @@ __global__ void matmul_forward_kernel1(float* out,
    }
}


template <const int SQRTBLOCKSIZE>
__global__ void sgemm_shared_mem_block(int BT, int OC, int C,
                                       const float *inp, const float *weight,
                                       float *out, const float* bias) {
    // out is (B,T,OC). OC is short for "output channels", e.g. OC = 4 * C
    // inp is (B,T,C), weight is (OC, C), bias is (OC)
    // note: there are no bounds checks (BT is unused), so BT, OC and C
    // must all be multiples of SQRTBLOCKSIZE
    // the output tile that we want to compute in this threadblock
    const uint bt = blockIdx.x;
    const uint oc = blockIdx.y;

    // allocate buffers for the current tiles in fast shared memory
    // shared memory is shared between all threads in a block
    __shared__ float smemInp[SQRTBLOCKSIZE * SQRTBLOCKSIZE];
    __shared__ float smemWeight[SQRTBLOCKSIZE * SQRTBLOCKSIZE];

    // the inner row & col that we're accessing in this thread
    // (the block has SQRTBLOCKSIZE * SQRTBLOCKSIZE threads, laid out in 1D)
    const uint threadCol = threadIdx.x % SQRTBLOCKSIZE;
    const uint threadRow = threadIdx.x / SQRTBLOCKSIZE;

    // advance pointers to the starting positions of this block's tiles
    inp += bt * SQRTBLOCKSIZE * C;
    weight += oc * SQRTBLOCKSIZE * C;
    out += bt * SQRTBLOCKSIZE * OC + oc * SQRTBLOCKSIZE;

    // start the accumulator at the bias (fused here), or at 0 if there is no bias
    float tmp = (bias == nullptr) ? 0.0f : bias[oc * SQRTBLOCKSIZE + threadCol];
    for (int bkIdx = 0; bkIdx < C; bkIdx += SQRTBLOCKSIZE) {
        // have each thread load one element of inp and one of weight into shared memory;
        // the weight tile is loaded transposed:
        // smemWeight[r * SQRTBLOCKSIZE + c] = weight[(oc tile + c) * C + (bkIdx + r)]
        smemInp[threadRow * SQRTBLOCKSIZE + threadCol] = inp[threadRow * C + threadCol];
        smemWeight[threadRow * SQRTBLOCKSIZE + threadCol] = weight[threadCol * C + threadRow];

        // block threads in this block until the shared memory tiles are fully populated
        __syncthreads();
        // advance the pointers to the next tile along the C dimension
        inp += SQRTBLOCKSIZE;
        weight += SQRTBLOCKSIZE;

        // execute the dot product on the currently cached tiles
        for (int dotIdx = 0; dotIdx < SQRTBLOCKSIZE; ++dotIdx) {
            tmp += smemInp[threadRow * SQRTBLOCKSIZE + dotIdx] *
                   smemWeight[dotIdx * SQRTBLOCKSIZE + threadCol];
        }
        // need to sync again at the end, to avoid faster threads
        // fetching the next tile into shared memory before slower threads are done
        __syncthreads();
    }
    out[threadRow * OC + threadCol] = tmp;
}
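
Note (not part of the commit): each thread block computes one SQRTBLOCKSIZE x SQRTBLOCKSIZE tile of the (B*T, OC) output. To make the indexing easier to verify, here is a plain-C sketch of what a single block (bt, oc) with tile size SB computes; like the kernel, it assumes B*T, OC and C are all divisible by SB.

// illustrative sketch only, not part of this file: the work of one thread block
// (bt, oc) with tile size SB, written as sequential C
void tile_reference(float* out, const float* inp, const float* weight, const float* bias,
                    int BT, int C, int OC, int SB, int bt, int oc) {
    for (int r = 0; r < SB; r++) {           // corresponds to threadRow
        for (int c = 0; c < SB; c++) {       // corresponds to threadCol
            float tmp = (bias == NULL) ? 0.0f : bias[oc * SB + c];
            for (int k = 0; k < C; k++) {    // the bkIdx/dotIdx loops, un-tiled
                tmp += inp[(bt * SB + r) * C + k] * weight[(oc * SB + c) * C + k];
            }
            out[(bt * SB + r) * OC + oc * SB + c] = tmp;
        }
    }
}

The shared-memory kernel computes exactly this, but streams inp and weight through SQRTBLOCKSIZE-wide tiles so that each element loaded from global memory is reused SQRTBLOCKSIZE times out of shared memory.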


// is there no better way other than just adding bias with a whole separate kernel?
// this is a highly memory-bound operation, should be fused into the matmul kernel
// but i can't seem to find a cuBLAS function that does this
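
For reference, the standalone bias pass that this comment refers to can be sketched as below; this is an illustration under assumed names (add_bias_sketch is hypothetical, not the kernel actually defined in this file), with out viewed as (B*T, OC).

// hypothetical sketch, not part of this commit: one thread per output element,
// each adding the bias of its output channel; this is purely memory-bound
__global__ void add_bias_sketch(float* out, const float* bias, int BT, int OC) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < BT * OC) {
        out[idx] += bias[idx % OC];
    }
}
// possible launch: add_bias_sketch<<<ceil_div(B * T * OC, 256), 256>>>(out, bias, B * T, OC);
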
@@ -149,6 +198,45 @@ void matmul_forward2(float* out,
}
}


void matmul_forward4(float *out,
                     const float *inp, const float *weight, const float* bias,
                     int B, int T, int C, int OC, int sqrt_block_size) {
    // each block computes one sqrt_block_size x sqrt_block_size tile of the (B*T, OC) output
    dim3 gridDim(ceil_div(B * T, sqrt_block_size), ceil_div(OC, sqrt_block_size));
    dim3 blockDim(sqrt_block_size * sqrt_block_size);
    // L1 cache becomes useless, since we access GMEM only via SMEM, so we carve
    // out all of L1 to SMEM. This doesn't currently make a difference, since
    // occupancy is limited by reg and thread count, but it's good to do anyway.
    // the tile size is a template parameter (the shared memory arrays need a
    // compile-time size), so we dispatch to the matching instantiation here.
    switch (sqrt_block_size) {
        case 4:
            cudaFuncSetAttribute(sgemm_shared_mem_block<4>,
                                 cudaFuncAttributePreferredSharedMemoryCarveout,
                                 cudaSharedmemCarveoutMaxShared);
            sgemm_shared_mem_block<4><<<gridDim, blockDim>>>(B*T, OC, C, inp, weight, out, bias);
            break;
        case 8:
            cudaFuncSetAttribute(sgemm_shared_mem_block<8>,
                                 cudaFuncAttributePreferredSharedMemoryCarveout,
                                 cudaSharedmemCarveoutMaxShared);
            sgemm_shared_mem_block<8><<<gridDim, blockDim>>>(B*T, OC, C, inp, weight, out, bias);
            break;
        case 16:
            cudaFuncSetAttribute(sgemm_shared_mem_block<16>,
                                 cudaFuncAttributePreferredSharedMemoryCarveout,
                                 cudaSharedmemCarveoutMaxShared);
            sgemm_shared_mem_block<16><<<gridDim, blockDim>>>(B*T, OC, C, inp, weight, out, bias);
            break;
        case 32:
            cudaFuncSetAttribute(sgemm_shared_mem_block<32>,
                                 cudaFuncAttributePreferredSharedMemoryCarveout,
                                 cudaSharedmemCarveoutMaxShared);
            sgemm_shared_mem_block<32><<<gridDim, blockDim>>>(B*T, OC, C, inp, weight, out, bias);
            break;
        default:
            printf("Invalid sqrt_block_size: %d\n", sqrt_block_size);
            exit(1);
    }
}
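
A usage sketch (hypothetical pointers and shapes, not taken from this file): sqrt_block_size must be one of 4, 8, 16 or 32, and 32 already gives 1024 threads per block, the hardware maximum, so larger tiles would need a different thread mapping.

// hypothetical usage: d_out, d_inp, d_weight, d_bias are suitably sized device pointers
int B = 8, T = 1024, C = 768, OC = 4 * C;      // illustrative shapes, all multiples of 32
matmul_forward4(d_out, d_inp, d_weight, d_bias, B, T, C, OC, /*sqrt_block_size=*/32);
cudaError_t err = cudaGetLastError();          // catches bad launch configurations
if (err != cudaSuccess) {
    printf("CUDA error: %s\n", cudaGetErrorString(err));
    exit(1);
}
cudaDeviceSynchronize();                       // wait for the kernel to finish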

// uses cublasLt to fuse the bias and gelu
// https://docs.nvidia.com/cuda/cublas/#cublasltmatmul
// https://github.com/NVIDIA/CUDALibrarySamples/blob/master/cuBLASLt/LtSgemm/sample_cublasLt_LtSgemm.cu
@@ -244,6 +332,9 @@ void matmul_forward(int kernel_num,
        case 3:
            matmul_forward3(out, inp, weight, bias, B, T, C, OC);
            break;
        case 4:
            matmul_forward4(out, inp, weight, bias, B, T, C, OC, sqrt_block_size);
            break;
        default:
            printf("Invalid kernel number\n");
            exit(1);
@@ -307,6 +398,7 @@ int main(int argc, char **argv) {

    // first check the correctness of the kernel
    matmul_forward_cpu(out, inp, weight, bias, B, T, C, OC);
    // matmul_forward_cpu(out, inp, weight, NULL, B, T, C, OC);

    // time the kernel at different block sizes
    int sqrt_block_sizes[] = {4, 8, 16, 32};
