Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix python package dispatch error message #182

Merged
merged 1 commit into from
Mar 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions python/csrc/batch_prefill.cu
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::Forward(
static_cast<int32_t*>(paged_kv_indices.data_ptr()),
static_cast<int32_t*>(paged_kv_indptr.data_ptr()),
static_cast<int32_t*>(paged_kv_last_page_len.data_ptr()));
return DISPATCH_group_size(num_qo_heads / num_kv_heads, [&] {
return DISPATCH_head_dim(head_dim, [&] {
bool success = DISPATCH_group_size(num_qo_heads / num_kv_heads, [&] {
bool success = DISPATCH_head_dim(head_dim, [&] {
DISPATCH_CAUSAL(causal, CAUSAL, {
DISPATCH_ALLOW_FP16_QK_REDUCTION(allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, {
DISPATCH_POS_ENCODING_MODE(PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, {
Expand All @@ -127,7 +127,11 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::Forward(
});
return true;
});
TORCH_CHECK(success, "BatchPrefillWithPagedKVCache failed to dispatch head_dim ", head_dim);
return success;
});
TORCH_CHECK(success, "BatchPrefillWithPagedKVCache failed to dispatch group_size ",
num_qo_heads / num_kv_heads);
});
return true;
});
Expand Down
12 changes: 10 additions & 2 deletions python/csrc/single_prefill.cu
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ std::vector<torch::Tensor> single_prefill_with_kv_cache(
}

bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] {
return DISPATCH_group_size(num_qo_heads / num_kv_heads, [&] {
return DISPATCH_head_dim(head_dim, [&] {
bool success = DISPATCH_group_size(num_qo_heads / num_kv_heads, [&] {
bool success = DISPATCH_head_dim(head_dim, [&] {
DISPATCH_CAUSAL(causal, CAUSAL, {
DISPATCH_LAYOUT(kv_layout, KV_LAYOUT, {
DISPATCH_ALLOW_FP16_QK_REDUCTION(allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, {
Expand All @@ -80,7 +80,15 @@ std::vector<torch::Tensor> single_prefill_with_kv_cache(
});
return true;
});
TORCH_CHECK(success,
"SinglePrefillWithKVCache kernel launch failed, error: unknown head_dim ",
head_dim);
return success;
});
TORCH_CHECK(success,
"SinglePrefillWithKVCache kernel launch failed, error: unknown group_size ",
num_qo_heads / num_kv_heads);
return success;
});

TORCH_CHECK(success, "SinglePrefillWithKVCache kernel launch failed, error: unknown dtype");
Expand Down