diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index 34d5c8ad089a3..e9fce917369b9 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -262,7 +262,7 @@ def __init__(
                                       prefix=f"{prefix}.proj")
 
         # Detect attention implementation.
-        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
+        self.attn_backend: _Backend = get_vit_attn_backend()
         if self.attn_backend not in {
                 _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS
         }:
diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py
index a1395982af44c..e867bc1b13d5c 100644
--- a/vllm/model_executor/models/vision.py
+++ b/vllm/model_executor/models/vision.py
@@ -71,34 +71,34 @@ def get_vision_encoder_info(
     raise NotImplementedError(msg)
 
 
-def get_vit_attn_backend(support_fa: bool = False) -> _Backend:
+def get_vit_attn_backend() -> _Backend:
     """
     Get the available attention backend for Vision Transformer.
     """
-    # TODO(Isotr0py): Remove `support_fa` after support FA for all ViTs attn.
     selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
     if selected_backend is None:
         backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
         if backend_by_env_var is not None:
             selected_backend = backend_name_to_enum(backend_by_env_var)
     if selected_backend is None:
-        # For Volta and Turing GPUs, use xformers instead.
-        device_available = current_platform.has_device_capability(80)
-        if device_available and support_fa:
-            from transformers.utils import is_flash_attn_2_available
-            if is_flash_attn_2_available():
-                selected_backend = _Backend.FLASH_ATTN
+        if current_platform.is_cuda():
+            if current_platform.has_device_capability(80):
+                from transformers.utils import is_flash_attn_2_available
+                if is_flash_attn_2_available():
+                    selected_backend = _Backend.FLASH_ATTN
+                else:
+                    logger.warning_once(
+                        "Current `vllm-flash-attn` has a bug inside vision "
+                        "module, so we use xformers backend instead. You can "
+                        "run `pip install flash-attn` to use flash-attention "
+                        "backend.")
+                    selected_backend = _Backend.XFORMERS
             else:
-                logger.warning_once(
-                    "Current `vllm-flash-attn` has a bug inside vision module, "
-                    "so we use xformers backend instead. You can run "
-                    "`pip install flash-attn` to use flash-attention backend.")
+                # For Volta and Turing GPUs, use xformers instead.
                 selected_backend = _Backend.XFORMERS
-        elif current_platform.is_cpu() or current_platform.is_rocm():
-            # ROCM doesn't support xformers
-            selected_backend = _Backend.TORCH_SDPA
         else:
-            selected_backend = _Backend.XFORMERS
+            # Default to torch SDPA for CPU, ROCm, and other non-GPU platforms.
+            selected_backend = _Backend.TORCH_SDPA
     return selected_backend
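
For reviewers, the decision tree that the patched get_vit_attn_backend() falls back to (once neither get_global_forced_attn_backend() nor VLLM_ATTENTION_BACKEND has pinned a backend) can be summarized with a standalone sketch. The code below is illustrative only and is not part of the patch: Backend, select_vit_backend, and the three boolean flags are stand-in names for _Backend, the selection logic, current_platform checks, and is_flash_attn_2_available(); only the branching order is taken from the diff above.

# Minimal sketch of the backend-selection order introduced in vision.py.
# All names here are stand-ins (assumptions), not vLLM APIs; only the
# order of the checks mirrors the patched get_vit_attn_backend().
from enum import Enum, auto


class Backend(Enum):
    FLASH_ATTN = auto()
    XFORMERS = auto()
    TORCH_SDPA = auto()


def select_vit_backend(is_cuda: bool,
                       has_ampere_or_newer: bool,
                       flash_attn_installed: bool) -> Backend:
    if not is_cuda:
        # CPU, ROCm, and other non-CUDA platforms fall back to torch SDPA.
        return Backend.TORCH_SDPA
    if not has_ampere_or_newer:
        # Volta/Turing (compute capability < 8.0) use xformers.
        return Backend.XFORMERS
    if flash_attn_installed:
        # Ampere or newer with upstream flash-attn available.
        return Backend.FLASH_ATTN
    # Ampere or newer without flash-attn falls back to xformers
    # (this is the branch that emits the warning in the patch).
    return Backend.XFORMERS


# Example: an A100 with flash-attn installed picks FLASH_ATTN.
assert select_vit_backend(True, True, True) is Backend.FLASH_ATTN
# Example: a non-CUDA platform (e.g. CPU or ROCm) now picks TORCH_SDPA.
assert select_vit_backend(False, False, False) is Backend.TORCH_SDPA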