matmul using shared memory #258

Closed · wants to merge 1 commit

92 changes: 92 additions & 0 deletions dev/cuda/matmul_forward.cu
@@ -82,6 +82,55 @@ __global__ void matmul_forward_kernel1(float* out,
}
}


template <const int SQRTBLOCKSIZE>
__global__ void sgemm_shared_mem_block(int BT, int OC, int C,
                                       const float *inp, const float *weight,
                                       float *out, const float* bias) {
    // out is (B,T,OC). OC is short for "output channels", e.g. OC = 4 * C
    // inp is (B,T,C), weight is (OC, C), bias is (OC)
    // the output block that we want to compute in this threadblock
    const uint bt = blockIdx.x;
    const uint oc = blockIdx.y;
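    // each threadblock computes one SQRTBLOCKSIZE x SQRTBLOCKSIZE tile of out:
    // bt indexes tiles along the B*T dimension, oc indexes tiles along OC,
    // and each of the SQRTBLOCKSIZE*SQRTBLOCKSIZE threads owns one output element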

    // allocate buffer for current block in fast shared mem
    // shared mem is shared between all threads in a block
    __shared__ float smemInp[SQRTBLOCKSIZE * SQRTBLOCKSIZE];
    __shared__ float smemWeight[SQRTBLOCKSIZE * SQRTBLOCKSIZE];

    // the inner row & col that we're accessing in this thread
    const uint threadCol = threadIdx.x % SQRTBLOCKSIZE;
    const uint threadRow = threadIdx.x / SQRTBLOCKSIZE;

    // advance pointers to the starting positions
    inp += bt * SQRTBLOCKSIZE * C;
    weight += oc * SQRTBLOCKSIZE * C;
    out += bt * SQRTBLOCKSIZE * OC + oc * SQRTBLOCKSIZE;

    float tmp = (bias == nullptr) ? 0.0f : bias[oc * SQRTBLOCKSIZE + threadCol];
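    // tmp accumulates the dot product for this thread's single output element,
    // out[bt*SQRTBLOCKSIZE + threadRow, oc*SQRTBLOCKSIZE + threadCol], seeded with its bias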
    for (int bkIdx = 0; bkIdx < C; bkIdx += SQRTBLOCKSIZE) {
        // Have each thread load one of the elements in inp and weight into shared memory
        smemInp[threadRow * SQRTBLOCKSIZE + threadCol] = inp[threadRow * C + threadCol];
        smemWeight[threadRow * SQRTBLOCKSIZE + threadCol] = weight[threadCol * C + threadRow];
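        // note: the weight tile is stored transposed in shared memory (the C dimension
        // runs along smem rows), so the dot product below reads smemWeight at stride 1
        // across consecutive threadCol values, i.e. without shared-memory bank conflicts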

        // block threads in this block until cache is fully populated
        __syncthreads();
        inp += SQRTBLOCKSIZE;
        weight += SQRTBLOCKSIZE;

        // execute the dotproduct on the currently cached block
        for (int dotIdx = 0; dotIdx < SQRTBLOCKSIZE; ++dotIdx) {
            tmp += smemInp[threadRow * SQRTBLOCKSIZE + dotIdx] *
                   smemWeight[dotIdx * SQRTBLOCKSIZE + threadCol];
        }
        // need to sync again at the end, to avoid faster threads
        // fetching the next block into the cache before slower threads are done
        __syncthreads();
    }
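    // each thread writes its single accumulated output element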
    out[threadRow * OC + threadCol] = tmp;
}


// is there no better way other than just adding bias with a whole separate kernel?
// this is a highly memory-bound operation, should be fused into the matmul kernel
// but i can't seem to find a cuBLAS function that does this
@@ -149,6 +198,45 @@ void matmul_forward2(float* out,
}
}


void matmul_forward4(float *out,
                     const float *inp, const float *weight, const float* bias,
                     int B, int T, int C, int OC, int sqrt_block_size) {
    dim3 gridDim(ceil_div(B * T, sqrt_block_size), ceil_div(OC, sqrt_block_size));
    dim3 blockDim(sqrt_block_size * sqrt_block_size);
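    // one thread per output element of the tile: each block has sqrt_block_size^2 threads
    // and the grid covers ceil(B*T / sqrt_block_size) x ceil(OC / sqrt_block_size) tiles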
    // L1 cache becomes useless, since we access GMEM only via SMEM, so we carve
    // out all of L1 to SMEM. This doesn't currently make a difference, since
    // occupancy is limited by reg and thread count, but it's good to do anyway.
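    // (for reference, static shared memory use is 2 * sqrt_block_size^2 * sizeof(float)
    // per block, i.e. 8 KiB at sqrt_block_size = 32, well under the default 48 KiB limit)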
    switch (sqrt_block_size) {
        case 4:
            cudaFuncSetAttribute(sgemm_shared_mem_block<4>,
                                 cudaFuncAttributePreferredSharedMemoryCarveout,
                                 cudaSharedmemCarveoutMaxShared);
            sgemm_shared_mem_block<4><<<gridDim, blockDim>>>(B*T, OC, C, inp, weight, out, bias);
            break;
        case 8:
            cudaFuncSetAttribute(sgemm_shared_mem_block<8>,
                                 cudaFuncAttributePreferredSharedMemoryCarveout,
                                 cudaSharedmemCarveoutMaxShared);
            sgemm_shared_mem_block<8><<<gridDim, blockDim>>>(B*T, OC, C, inp, weight, out, bias);
            break;
        case 16:
            cudaFuncSetAttribute(sgemm_shared_mem_block<16>,
                                 cudaFuncAttributePreferredSharedMemoryCarveout,
                                 cudaSharedmemCarveoutMaxShared);
            sgemm_shared_mem_block<16><<<gridDim, blockDim>>>(B*T, OC, C, inp, weight, out, bias);
            break;
        case 32:
            cudaFuncSetAttribute(sgemm_shared_mem_block<32>,
                                 cudaFuncAttributePreferredSharedMemoryCarveout,
                                 cudaSharedmemCarveoutMaxShared);
            sgemm_shared_mem_block<32><<<gridDim, blockDim>>>(B*T, OC, C, inp, weight, out, bias);
            break;
        default:
            break;
    }
}

// uses cublasLt to fuse the bias and gelu
// https://docs.nvidia.com/cuda/cublas/#cublasltmatmul
// https://github.com/NVIDIA/CUDALibrarySamples/blob/master/cuBLASLt/LtSgemm/sample_cublasLt_LtSgemm.cu
@@ -244,6 +332,9 @@ void matmul_forward(int kernel_num,
        case 3:
            matmul_forward3(out, inp, weight, bias, B, T, C, OC);
            break;
        case 4:
            matmul_forward4(out, inp, weight, bias, B, T, C, OC, sqrt_block_size);
            break;
        default:
            printf("Invalid kernel number\n");
            exit(1);
@@ -307,6 +398,7 @@ int main(int argc, char **argv) {

    // first check the correctness of the kernel
    matmul_forward_cpu(out, inp, weight, bias, B, T, C, OC);
    // matmul_forward_cpu(out, inp, weight, NULL, B, T, C, OC);

    // time the kernel at different block sizes
    int sqrt_block_sizes[] = {4, 8, 16, 32};