diff --git a/CMakeLists.txt b/CMakeLists.txt
old mode 100644
new mode 100755
index 5039ac2448f83..2f9da6fa3e1d3
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -576,7 +576,7 @@ else()
   FetchContent_Declare(
         vllm-flash-attn
         GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
-        GIT_TAG 90eacc1af2a7c3de62ea249e929ed5faccf38954
+        GIT_TAG 0aff05f577e8a10086066a00618609199b25231d
         GIT_PROGRESS TRUE
         # Don't share the vllm-flash-attn build between build types
         BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
diff --git a/setup.py b/setup.py
old mode 100644
new mode 100755
index 36c89d435c7b7..ee193e4693806
--- a/setup.py
+++ b/setup.py
@@ -598,7 +598,10 @@ def _read_requirements(filename: str) -> List[str]:
 
 if _is_cuda():
     ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa2_C"))
-    ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C"))
+    if envs.VLLM_USE_PRECOMPILED or get_nvcc_cuda_version() >= Version("12.0"):
+        # FA3 requires CUDA 12.0 or later
+        ext_modules.append(
+            CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C"))
     ext_modules.append(CMakeExtension(name="vllm.cumem_allocator"))
 
 if _build_custom_ops():
diff --git a/tests/kernels/test_cascade_flash_attn.py b/tests/kernels/test_cascade_flash_attn.py
old mode 100644
new mode 100755
index 00eb927205d46..8edfde42ede74
--- a/tests/kernels/test_cascade_flash_attn.py
+++ b/tests/kernels/test_cascade_flash_attn.py
@@ -6,7 +6,9 @@
 from vllm.platforms import current_platform
 from vllm.v1.attention.backends.flash_attn import (cascade_attention,
                                                     merge_attn_states)
-from vllm.vllm_flash_attn import flash_attn_varlen_func
+from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
+                                  flash_attn_varlen_func,
+                                  is_fa_version_supported)
 
 NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
 HEAD_SIZES = [128, 192, 256]
@@ -91,10 +93,9 @@ def test_cascade(
     fa_version: int,
 ) -> None:
     torch.set_default_device("cuda")
-    if fa_version == 3 and (torch.cuda.get_device_capability() == (8, 6)
-                            or torch.cuda.get_device_capability() == (8, 9)):
-        pytest.skip("Flash attention version 3 fails on 8.6 and 8.9 due to "
-                    "insufficient shared memory for some shapes")
+    if not is_fa_version_supported(fa_version):
+        pytest.skip(f"Flash attention version {fa_version} not supported due "
+                    f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
 
     current_platform.seed_everything(0)
diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py
index b22153c86b25f..0ee0bf6c6a374 100644
--- a/tests/kernels/test_flash_attn.py
+++ b/tests/kernels/test_flash_attn.py
@@ -4,8 +4,10 @@
 import torch
 
 from vllm.platforms import current_platform
-from vllm.vllm_flash_attn import (flash_attn_varlen_func,
-                                  flash_attn_with_kvcache)
+from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
+                                  flash_attn_varlen_func,
+                                  flash_attn_with_kvcache,
+                                  is_fa_version_supported)
 
 NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
 HEAD_SIZES = [128, 256]
@@ -95,10 +97,9 @@ def test_flash_attn_with_paged_kv(
     fa_version: int,
 ) -> None:
     torch.set_default_device("cuda")
-    if fa_version == 3 and (torch.cuda.get_device_capability() == (8, 6)
-                            or torch.cuda.get_device_capability() == (8, 9)):
-        pytest.skip("Flash attention version 3 fails on 8.6 and 8.9 due to "
-                    "insufficient shared memory for some shapes")
+    if not is_fa_version_supported(fa_version):
+        pytest.skip(f"Flash attention version {fa_version} not supported due "
+                    f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
     current_platform.seed_everything(0)
     num_seqs = len(kv_lens)
@@ -182,11 +183,9 @@ def test_varlen_with_paged_kv(
     fa_version: int,
 ) -> None:
     torch.set_default_device("cuda")
-    if fa_version == 3 and (torch.cuda.get_device_capability() == (8, 6)
-                            or torch.cuda.get_device_capability() == (8, 9)):
-        pytest.skip("Flash attention version 3 fails on 8.6 and 8.9 due to "
-                    "insufficient shared memory for some shapes")
-
+    if not is_fa_version_supported(fa_version):
+        pytest.skip(f"Flash attention version {fa_version} not supported due "
+                    f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
     current_platform.seed_everything(0)
     num_seqs = len(seq_lens)
     query_lens = [x[0] for x in seq_lens]
diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py
old mode 100644
new mode 100755
index 1be099283e472..4a9aa1e217365
--- a/vllm/attention/backends/flash_attn.py
+++ b/vllm/attention/backends/flash_attn.py
@@ -18,17 +18,20 @@
                                             get_seq_len_block_table_args,
                                             is_all_cross_attn_metadata_set,
                                             is_all_encoder_attn_metadata_set,
                                             is_block_tables_empty)
 from vllm.envs import VLLM_FLASH_ATTN_VERSION
+from vllm.logger import init_logger
 from vllm.multimodal import MultiModalPlaceholderMap
 from vllm.platforms import current_platform
 from vllm.utils import async_tensor_h2d, make_tensor_with_pad
+from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
+                                  flash_attn_varlen_func,
+                                  flash_attn_with_kvcache,
+                                  is_fa_version_supported)
 
 if TYPE_CHECKING:
     from vllm.worker.model_runner import (ModelInputForGPUBuilder,
                                           ModelInputForGPUWithSamplingMetadata)
 
-from vllm.vllm_flash_attn import (flash_attn_varlen_func,
-                                  flash_attn_with_kvcache,
-                                  is_fa_version_supported)
+logger = init_logger(__name__)
 
 
 class FlashAttentionBackend(AttentionBackend):
@@ -652,6 +655,11 @@ def __init__(
         assert VLLM_FLASH_ATTN_VERSION in [2, 3]
         self.fa_version = VLLM_FLASH_ATTN_VERSION
+        if not is_fa_version_supported(self.fa_version):
+            logger.error("Cannot use FA version %d; it is not supported: %s",
+                         self.fa_version,
+                         fa_version_unsupported_reason(self.fa_version))
+        assert is_fa_version_supported(self.fa_version)
 
     def forward(
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
old mode 100644
new mode 100755
index 7fe9b3a8f595a..ce83b1fac6c0b
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -10,11 +10,15 @@
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
 from vllm.envs import VLLM_FLASH_ATTN_VERSION
+from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils import cdiv
-from vllm.vllm_flash_attn import (flash_attn_varlen_func,
+from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
+                                  flash_attn_varlen_func,
                                   is_fa_version_supported)
 
+logger = init_logger(__name__)
+
 
 class FlashAttentionBackend(AttentionBackend):
@@ -143,6 +147,11 @@ def __init__(
         assert VLLM_FLASH_ATTN_VERSION in [2, 3]
         self.fa_version = VLLM_FLASH_ATTN_VERSION
+        if not is_fa_version_supported(self.fa_version):
+            logger.error("Cannot use FA version %d; it is not supported: %s",
+                         self.fa_version,
+                         fa_version_unsupported_reason(self.fa_version))
+        assert is_fa_version_supported(self.fa_version)
 
     def forward(
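
For context, below is a minimal, self-contained sketch of the support-check pattern the updated tests and backends rely on. The underscore-prefixed helper names are illustrative stand-ins, not the actual vllm.vllm_flash_attn implementation; the capability rule simply mirrors the hard-coded 8.6/8.9 skip that this diff removes.

# Illustrative sketch only -- not the real vllm.vllm_flash_attn helpers.
# It reproduces the check the old test code performed inline: FA3 is
# assumed unusable on compute capability 8.6 and 8.9 because some shapes
# need more shared memory than those GPUs provide.
from typing import Optional

import torch


def _fa_version_unsupported_reason(fa_version: int) -> Optional[str]:
    """Return None if `fa_version` is usable, else a human-readable reason."""
    if fa_version not in (2, 3):
        return f"unknown flash-attention version {fa_version}"
    if not torch.cuda.is_available():
        return "CUDA is not available"
    capability = torch.cuda.get_device_capability()
    if fa_version == 3 and capability in ((8, 6), (8, 9)):
        return ("insufficient shared memory for some shapes on compute "
                f"capability {capability[0]}.{capability[1]}")
    return None


def _is_fa_version_supported(fa_version: int) -> bool:
    return _fa_version_unsupported_reason(fa_version) is None


if __name__ == "__main__":
    for version in (2, 3):
        reason = _fa_version_unsupported_reason(version)
        print(f"FA{version}: {'supported' if reason is None else reason}")

With this shape of API, the pytest.skip calls and the backend assertions above are driven by a single library-side check instead of duplicated inline capability comparisons, so the skip and error messages stay in sync with whatever reason the library actually reports.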