diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py
index 08b62458fd7..a753233a9f6 100644
--- a/python/sglang/srt/layers/attention/flashinfer_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -14,7 +14,10 @@
 
 from sglang.global_config import global_config
 from sglang.srt.layers.attention import AttentionBackend
-from sglang.srt.layers.attention.flashinfer_utils import update_flashinfer_indices
+from sglang.srt.layers.attention.flashinfer_utils import (
+    WrapperDispatch,
+    update_flashinfer_indices,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import is_hip
 
@@ -53,10 +56,19 @@ def __init__(self, model_runner: ModelRunner):
             device="cuda",
         )
 
+        assert not (
+            model_runner.sliding_window_size is not None
+            and model_runner.has_cross_attention
+        ), "Sliding window and cross attention are not supported together"
+
+        self.num_wrappers = 1
+        self.dispatch_reason = None
         if model_runner.sliding_window_size is not None:
             self.num_wrappers = 2
-        else:
-            self.num_wrappers = 1
+            self.dispatch_reason = WrapperDispatch.SLIDING_WINDOW
+        elif model_runner.has_cross_attention:
+            self.num_wrappers = 2
+            self.dispatch_reason = WrapperDispatch.CROSS_ATTENTION
 
         # NOTE: we do not use ragged attention when there are multiple wrappers
         self.prefill_wrapper_ragged = (
@@ -88,8 +100,12 @@ def _get_wrapper_idx(self, layer: nn.Module):
         if self.num_wrappers == 1:
             return 0
 
-        # TODO: make sure the idx is related to sliding window size
-        return layer.sliding_window_size == -1
+        if self.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
+            return layer.sliding_window_size == -1
+        if self.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
+            return layer.is_cross_attention
+
+        raise ValueError(f"Unknown dispatch reason: {self.dispatch_reason}")
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         if forward_batch.forward_mode.is_decode():
diff --git a/python/sglang/srt/layers/attention/flashinfer_utils.py b/python/sglang/srt/layers/attention/flashinfer_utils.py
index 9568226ea3d..796203c933c 100644
--- a/python/sglang/srt/layers/attention/flashinfer_utils.py
+++ b/python/sglang/srt/layers/attention/flashinfer_utils.py
@@ -1,8 +1,15 @@
+from enum import Enum, auto
+
 import torch
 import triton
 import triton.language as tl
 
 
+class WrapperDispatch(Enum):
+    SLIDING_WINDOW = auto()
+    CROSS_ATTENTION = auto()
+
+
 @triton.jit
 def create_flashinfer_kv_indices_triton(
     req_to_token_ptr,  # [max_batch, max_context_len]
@@ -80,67 +87,6 @@ def __init__(
             (self.batch_size,), dtype=torch.int32, device="cuda"
         )
 
-    def _init_indices_no_sliding_window(self):
-        if self.use_ragged:
-            paged_kernel_lens = self.prefix_lens
-        else:
-            paged_kernel_lens = self.seq_lens
-
-        self.kv_indptr = torch.zeros(
-            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
-        )
-        self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-        self.kv_indices = torch.empty(
-            self.kv_indptr[-1], dtype=torch.int32, device="cuda"
-        )
-
-        create_flashinfer_kv_indices_triton[(self.batch_size,)](
-            self.model_runner.req_to_token_pool.req_to_token,
-            self.req_pool_indices,
-            paged_kernel_lens,
-            self.kv_indptr,
-            None,
-            self.kv_indices,
-            self.model_runner.req_to_token_pool.req_to_token.size(1),
-        )
-
-    def _init_indices_sliding_window(self, wrapper_id):
-        if wrapper_id == 0:
-            # window attention use paged only
-            if self.forward_mode.is_decode():
-                paged_kernel_lens = torch.minimum(
-                    self.seq_lens,
-                    torch.tensor(self.model_runner.sliding_window_size + 1),
-                )
-            else:
-                paged_kernel_lens = torch.minimum(
-                    self.seq_lens,
-                    torch.tensor(self.model_runner.sliding_window_size)
-                    + self.seq_lens
-                    - self.prefix_lens,
-                )
-        else:
-            # full attention
-            paged_kernel_lens = self.seq_lens
-
-        kv_start_idx = self.seq_lens - paged_kernel_lens
-        self.kv_indptr = torch.zeros(
-            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
-        )
-        self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-        self.kv_indices = torch.empty(
-            self.kv_indptr[-1], dtype=torch.int32, device="cuda"
-        )
-        create_flashinfer_kv_indices_triton[(self.batch_size,)](
-            self.model_runner.req_to_token_pool.req_to_token,
-            self.req_pool_indices,
-            paged_kernel_lens,
-            self.kv_indptr,
-            kv_start_idx,
-            self.kv_indices,
-            self.model_runner.req_to_token_pool.req_to_token.size(1),
-        )
-
     def _update_decode_indices(self, decode_wrapper):
         assert not isinstance(decode_wrapper, list)
         decode_wrapper.end_forward()
@@ -189,8 +135,53 @@ def _update_extend_indices(self, ragged_wrapper, paged_wrapper):
             1,
         )
 
-    def update_indices_no_sliding_window(self):
-        self._init_indices_no_sliding_window()
+    def _get_indices(self, dispatch_reason: WrapperDispatch = None, wrapper_id=0):
+        if dispatch_reason is None:
+            if self.use_ragged:
+                paged_kernel_lens = self.prefix_lens
+            else:
+                paged_kernel_lens = self.seq_lens
+            self.kv_start_idx = None
+        elif dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
+            if wrapper_id == 0:
+                # window attention use paged only
+                if self.forward_mode.is_decode():
+                    paged_kernel_lens = torch.minimum(
+                        self.seq_lens,
+                        torch.tensor(self.model_runner.sliding_window_size + 1),
+                    )
+                else:
+                    paged_kernel_lens = torch.minimum(
+                        self.seq_lens,
+                        torch.tensor(self.model_runner.sliding_window_size)
+                        + self.seq_lens
+                        - self.prefix_lens,
+                    )
+            else:
+                # full attention
+                paged_kernel_lens = self.seq_lens
+            self.kv_start_idx = self.seq_lens - paged_kernel_lens
+
+        self.kv_indptr = torch.zeros(
+            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
+        )
+        self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+        self.kv_indices = torch.empty(
+            self.kv_indptr[-1], dtype=torch.int32, device="cuda"
+        )
+
+        create_flashinfer_kv_indices_triton[(self.batch_size,)](
+            self.model_runner.req_to_token_pool.req_to_token,
+            self.req_pool_indices,
+            paged_kernel_lens,
+            self.kv_indptr,
+            self.kv_start_idx,
+            self.kv_indices,
+            self.model_runner.req_to_token_pool.req_to_token.size(1),
+        )
+
+    def _update_indicess_single_wrapper(self):
+        self._get_indices()
 
         if self.forward_mode.is_decode():
             self._update_decode_indices(self.decode_wrappers[0])
@@ -200,11 +191,13 @@ def update_indices_no_sliding_window(self):
                 self.prefill_wrappers_paged[0],
             )
 
-    def update_indices_sliding_window(self):
-        assert self.use_ragged is False
-
+    def _update_indices_cross_attention(self):
+        pass
+
+    def _update_indices_sliding_window(self):
+        assert self.use_ragged is False
         for wrapper_id in range(2):
-            self._init_indices_sliding_window(wrapper_id)
+            self._get_indices(WrapperDispatch.SLIDING_WINDOW, wrapper_id)
             if self.forward_mode.is_decode():
                 self._update_decode_indices(self.decode_wrappers[wrapper_id])
             else:
@@ -233,7 +226,12 @@ def update_flashinfer_indices(
         use_ragged,
     )
 
-    if model_runner.sliding_window_size is None:
-        updater.update_indices_no_sliding_window()
+    dispatch_reason = model_runner.attn_backend.dispatch_reason
+
+    if dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
+        updater._update_indices_sliding_window()
+    elif dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
+        updater._update_indices_cross_attention()
     else:
-        updater.update_indices_sliding_window()
+        assert model_runner.attn_backend.num_wrappers == 1
+        updater._update_indicess_single_wrapper()
diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py
index 25432660e39..61437362327 100644
--- a/python/sglang/srt/layers/radix_attention.py
+++ b/python/sglang/srt/layers/radix_attention.py
@@ -32,9 +32,10 @@ def __init__(
         scaling: float,
         num_kv_heads: int,
         layer_id: int,
-        sliding_window_size: int = -1,
         logit_cap: float = 0.0,
         v_head_dim: int = -1,
+        sliding_window_size: int = -1,
+        is_cross_attention: bool = False,
     ):
         super().__init__()
         self.tp_q_head_num = num_heads
@@ -47,6 +48,7 @@ def __init__(
         self.layer_id = layer_id
         self.logit_cap = logit_cap
         self.sliding_window_size = sliding_window_size or -1
+        self.is_cross_attention = is_cross_attention
 
     def forward(self, q, k, v, forward_batch: ForwardBatch):
         if k is not None:
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 63cd1d3d6a3..79d42f5469b 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -231,6 +231,7 @@ def load_model(self):
             if hasattr(self.model, "get_attention_sliding_window_size")
             else None
         )
+        self.has_cross_attention = getattr(self.model, "has_cross_attention", False)
         self.is_generation = is_generation_model(
             self.model_config.hf_config.architectures, self.server_args.is_embedding
         )
@@ -453,6 +454,10 @@ def init_attention_backend(self):
                 "Window attention is not supported in the triton attention backend. "
                 "Please use `--attention-backend flashinfer`."
             )
+            assert not self.has_cross_attention, (
+                "Cross attention is not supported in the triton attention backend. "
+                "Please use `--attention-backend flashinfer`."
+            )
             self.attn_backend = TritonAttnBackend(self)
         else:
             raise ValueError(
diff --git a/python/sglang/srt/models/gemma2.py b/python/sglang/srt/models/gemma2.py
index 59205416a93..47fbd6334c1 100644
--- a/python/sglang/srt/models/gemma2.py
+++ b/python/sglang/srt/models/gemma2.py
@@ -163,12 +163,12 @@ def __init__(
             self.scaling,
             num_kv_heads=self.num_kv_heads,
            layer_id=layer_idx,
+            logit_cap=self.config.attn_logit_softcapping,
             sliding_window_size=(
                 get_attention_sliding_window_size(config)
                 if use_sliding_window
                 else None
             ),
-            logit_cap=self.config.attn_logit_softcapping,
         )
 
     def forward(
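Illustrative note (not part of the patch): the sketch below shows how a model could opt into the cross-attention dispatch introduced above. The class names, layer names, and head sizes are invented for illustration; only `has_cross_attention`, `is_cross_attention`, and the `RadixAttention` constructor arguments are taken from this diff, assuming the usual `RadixAttention(num_heads, head_dim, scaling, ...)` calling convention.

# Hypothetical sketch. A model exposes `has_cross_attention`, which
# ModelRunner.load_model() reads via getattr(self.model, "has_cross_attention", False),
# and marks its cross-attention layers with `is_cross_attention=True` so that
# FlashInferAttnBackend._get_wrapper_idx() routes them to the second wrapper
# when dispatch_reason == WrapperDispatch.CROSS_ATTENTION.
import torch.nn as nn

from sglang.srt.layers.radix_attention import RadixAttention


class ToyDecoderLayer(nn.Module):
    def __init__(self, num_heads, head_dim, num_kv_heads, layer_id):
        super().__init__()
        # Self-attention: is_cross_attention defaults to False, so this layer
        # is served by wrapper 0.
        self.self_attn = RadixAttention(
            num_heads,
            head_dim,
            scaling=head_dim**-0.5,
            num_kv_heads=num_kv_heads,
            layer_id=layer_id,
        )
        # Cross-attention over encoder states: routed to wrapper 1 by
        # _get_wrapper_idx via `return layer.is_cross_attention`.
        self.cross_attn = RadixAttention(
            num_heads,
            head_dim,
            scaling=head_dim**-0.5,
            num_kv_heads=num_kv_heads,
            layer_id=layer_id,
            is_cross_attention=True,
        )


class ToyEncoderDecoderModel(nn.Module):
    # Read by ModelRunner.load_model(); together with the flashinfer backend
    # this sets num_wrappers = 2 and dispatch_reason = WrapperDispatch.CROSS_ATTENTION.
    has_cross_attention = True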