diff --git a/tests/kernels/test_flashinfer.py b/tests/kernels/test_flashinfer.py
index 1db47f8c8c2e0..f109792ad251b 100644
--- a/tests/kernels/test_flashinfer.py
+++ b/tests/kernels/test_flashinfer.py
@@ -4,7 +4,7 @@
 import pytest
 import torch
 
-NUM_HEADS = [(16, 16), (32, 8), (64, 8)]
+NUM_HEADS = [(16, 16), (32, 8), (64, 8), (6, 1)]
 HEAD_SIZES = [128, 256]
 BLOCK_SIZES = [16, 32]
 DTYPES = [torch.float16, torch.bfloat16]
@@ -123,7 +123,10 @@ def test_flashinfer_decode_with_paged_kv(kv_lens: List[int],
 
     workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
     wrapper = flashinfer.\
-        BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD")
+        BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD",
+                    use_tensor_cores=(
+                        (num_query_heads//num_kv_heads) not in (1, 2, 4, 8))
+                    )
     wrapper.begin_forward(kv_indptr,
                           kv_indices,
                           kv_last_page_lens,
diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py
index 2aa3bd79e4a64..ce7a7198dc400 100644
--- a/vllm/attention/backends/flashinfer.py
+++ b/vllm/attention/backends/flashinfer.py
@@ -113,7 +113,8 @@ def _get_decode_wrapper(self):
                 self.runner.parallel_config))
             num_kv_heads = self.runner.model_config.get_num_kv_heads(
                 self.runner.parallel_config)
-            use_tensor_cores = num_qo_heads // num_kv_heads >= 4
+            use_tensor_cores = (num_qo_heads // num_kv_heads) not in \
+                (1, 2, 4, 8)
             self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                 self._get_workspace_buffer(), "NHD",
                 use_tensor_cores=use_tensor_cores)
@@ -171,7 +172,8 @@ def graph_capture_get_metadata_for_batch(self, batch_size: int):
                 self.runner.parallel_config))
         num_kv_heads = self.runner.model_config.get_num_kv_heads(
             self.runner.parallel_config)
-        use_tensor_cores = num_qo_heads // num_kv_heads >= 4
+        use_tensor_cores = (num_qo_heads // num_kv_heads) not in \
+            (1, 2, 4, 8)
         self._graph_decode_wrapper = \
             CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
                 self._graph_decode_workspace_buffer, _indptr_buffer,
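
For reference, here is a minimal standalone sketch of the heuristic the diff introduces. The function name `pick_use_tensor_cores` is illustrative, not a vLLM API, and the stated rationale (that FlashInfer's CUDA-core decode kernel is specialized only for GQA group sizes 1, 2, 4, and 8, so every other query/KV head ratio must use the tensor-core path) is an assumption inferred from the diff, not confirmed by it:

```python
# Illustrative sketch, not vLLM code. Assumption: FlashInfer's CUDA-core
# decode kernel only covers GQA group sizes 1, 2, 4, and 8, so any other
# query/KV head ratio should fall back to the tensor-core path.

def pick_use_tensor_cores(num_qo_heads: int, num_kv_heads: int) -> bool:
    """Return True when the GQA group size has no CUDA-core decode path."""
    return (num_qo_heads // num_kv_heads) not in (1, 2, 4, 8)

# The head configurations exercised by the updated NUM_HEADS list:
assert pick_use_tensor_cores(16, 16) is False  # group size 1
assert pick_use_tensor_cores(32, 8) is False   # group size 4
assert pick_use_tensor_cores(64, 8) is False   # group size 8
assert pick_use_tensor_cores(6, 1) is True     # group size 6, the new test case
```

Compared with the previous `num_qo_heads // num_kv_heads >= 4` check, the new predicate routes unsupported group sizes below 4 (e.g. 3) to tensor cores and keeps the supported sizes 4 and 8 on the CUDA-core path; the test change passes the same predicate to the wrapper so the test and backend select the same kernel.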