Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bugfix: Fix invalid kernel configuration for sm86 #385

Merged
merged 1 commit into from
Jul 20, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions include/flashinfer/attention/prefill.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1804,7 +1804,8 @@ cudaError_t SinglePrefillWithKVCacheDispatched(
FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(
&max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, dev_id));
// We expect each SM to execute two threadblocks when shared memory permits, otherwise one.
const int max_smem_per_threadblock = max_smem_per_sm / 2;
const int num_ctas_per_sm = max_smem_per_sm > (16 * HEAD_DIM * sizeof(DTypeIn) * 16) ? 2: 1;
const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm;

constexpr uint32_t num_warps_x = get_num_warps_x<WARP_LAYOUT>();
constexpr uint32_t num_warps_z = get_num_warps_z<WARP_LAYOUT>();
Expand Down Expand Up @@ -1949,7 +1950,8 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(
FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&max_smem_per_sm,
cudaDevAttrMaxSharedMemoryPerMultiprocessor, dev_id));
// We expect each SM to execute two threadblocks when shared memory permits, otherwise one.
const int max_smem_per_threadblock = max_smem_per_sm / 2;
const int num_ctas_per_sm = max_smem_per_sm > (16 * HEAD_DIM * sizeof(DTypeIn) * 16) ? 2: 1;
const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm;

const uint32_t max_num_frags_z_reg =
(HEAD_DIM >= 128 && num_frags_x == 2 && pos_encoding_mode == PosEncodingMode::kRoPELlama &&
Expand Down Expand Up @@ -2089,7 +2091,8 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(
FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&max_smem_per_sm,
cudaDevAttrMaxSharedMemoryPerMultiprocessor, dev_id));
// We expect each SM to execute two threadblocks when shared memory permits, otherwise one.
const int max_smem_per_threadblock = max_smem_per_sm / 2;
const int num_ctas_per_sm = max_smem_per_sm > (16 * HEAD_DIM * sizeof(DTypeIn) * 16) ? 2: 1;
const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm;

const uint32_t max_num_frags_z_reg =
(HEAD_DIM >= 128 && num_frags_x == 2 && pos_encoding_mode == PosEncodingMode::kRoPELlama &&
Expand Down