
Commit

[Misc] Set default backend to SDPA for get_vit_attn_backend (vllm-project#12235)

Signed-off-by: wangxiyuan <[email protected]>
wangxiyuan authored and tjtanaa committed Jan 28, 2025
1 parent 0db6a75 commit cbe2a73
Showing 1 changed file with 16 additions and 14 deletions.
30 changes: 16 additions & 14 deletions vllm/model_executor/models/vision.py
@@ -82,23 +82,25 @@ def get_vit_attn_backend(support_fa: bool = False) -> _Backend:
         if backend_by_env_var is not None:
             selected_backend = backend_name_to_enum(backend_by_env_var)
     if selected_backend is None:
-        # For Volta and Turing GPUs, use xformers instead.
-        device_available = current_platform.has_device_capability(80)
-        if device_available and support_fa:
-            from transformers.utils import is_flash_attn_2_available
-            if is_flash_attn_2_available():
-                selected_backend = _Backend.FLASH_ATTN
+        if current_platform.is_cuda():
+            device_available = current_platform.has_device_capability(80)
+            if device_available and support_fa:
+                from transformers.utils import is_flash_attn_2_available
+                if is_flash_attn_2_available():
+                    selected_backend = _Backend.FLASH_ATTN
+                else:
+                    logger.warning_once(
+                        "Current `vllm-flash-attn` has a bug inside vision "
+                        "module, so we use xformers backend instead. You can "
+                        "run `pip install flash-attn` to use flash-attention "
+                        "backend.")
+                    selected_backend = _Backend.XFORMERS
             else:
-                logger.warning_once(
-                    "Current `vllm-flash-attn` has a bug inside vision module, "
-                    "so we use xformers backend instead. You can run "
-                    "`pip install flash-attn` to use flash-attention backend.")
+                # For Volta and Turing GPUs, use xformers instead.
                 selected_backend = _Backend.XFORMERS
-        elif current_platform.is_cpu() or current_platform.is_rocm():
-            # ROCM doesn't support xformers
-            selected_backend = _Backend.TORCH_SDPA
         else:
-            selected_backend = _Backend.XFORMERS
+            # Default to torch SDPA for other non-GPU platforms.
+            selected_backend = _Backend.TORCH_SDPA
     return selected_backend
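
For reference, a minimal usage sketch (not part of the commit) of how the selection above is exercised. It assumes the module is importable as vllm.model_executor.models.vision at this revision and that no attention backend has been forced via environment variable or global override:

    from vllm.model_executor.models.vision import get_vit_attn_backend

    # With no forced backend, the result depends only on the platform:
    #   - CUDA: FLASH_ATTN or XFORMERS, as selected in the diff above
    #   - any non-CUDA platform (CPU, ROCm, other accelerators): TORCH_SDPA,
    #     where previously platforms other than CPU/ROCm fell back to XFORMERS
    backend = get_vit_attn_backend(support_fa=True)
    print(backend)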


