diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index 7089d59392c36..77cfa8490172b 100644 --- a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -89,8 +89,7 @@ class BlocksparseFlashAttentionBackend(AttentionBackend): @staticmethod def get_name() -> str: - # For attention layer compatibility - return "FLASH_ATTN" + return "BLOCK_SPARSE_FLASH_ATTN" @staticmethod def get_impl_cls() -> Type["BlocksparseFlashAttentionImpl"]: diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 3c2ec9636df91..85fde76796901 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -33,6 +33,7 @@ class _Backend(enum.Enum): HPU_ATTN = enum.auto() PALLAS = enum.auto() IPEX = enum.auto() + BLOCK_SPARSE_FLASH_ATTN = enum.auto() NO_ATTENTION = enum.auto()