Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix] Fix marlin kernel crash on H100 #4218

Merged
merged 1 commit into from
Apr 24, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 8 additions & 15 deletions csrc/quantization/marlin/marlin_cuda_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -67,20 +67,13 @@ __device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr,
"r"(smem), "l"(glob_ptr), "n"(BYTES));
}

// Asynchronous global->shared copy with a cache hint indicating that the values
// may be evicted immediately; used for quantized weights B, which are only
// accessed precisely once and should thus not pollute the L2 cache which we
// need for inputs A and outputs C.
__device__ inline void cp_async4_stream(void *smem_ptr, const void *glob_ptr) {
// Asynchronous global->shared copy
__device__ inline void cp_async4(void *smem_ptr, const void *glob_ptr) {
const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile(
"{\n"
" .reg .b64 p;\n"
" createpolicy.fractional.L2::evict_first.b64 p, 1.0;"
" cp.async.cg.shared.global.L2::cache_hint [%0], [%1], %2, p;\n"
"}\n" ::"r"(smem),
"l"(glob_ptr), "n"(BYTES));
asm volatile("{\n"
" cp.async.cg.shared.global [%0], [%1], %2;\n"
"}\n" :: "r"(smem), "l"(glob_ptr), "n"(BYTES));
}

// Async copy fence.
Expand Down Expand Up @@ -448,14 +441,14 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
int4 *sh_b_stage = sh_b + b_sh_stage * pipe;
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) {
cp_async4_stream(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]);
cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]);
B_ptr[i] += b_gl_rd_delta_o;
}
// Only fetch scales if this tile starts a new group
if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) {
int4 *sh_s_stage = sh_s + s_sh_stage * pipe;
if (s_sh_wr_pred)
cp_async4_stream(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
s_gl_rd += s_gl_rd_delta;
}
}
Expand Down Expand Up @@ -750,7 +743,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// write-out
if (group_blocks == -1 && last) {
if (s_sh_wr_pred)
cp_async4_stream(&sh_s[s_sh_wr], &s[s_gl_rd]);
cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
cp_async_fence();
}
thread_block_reduce();
Expand Down
Loading