diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py
index c522c972585..355960367c0 100644
--- a/python/sglang/srt/layers/radix_attention.py
+++ b/python/sglang/srt/layers/radix_attention.py
@@ -85,9 +85,9 @@ def decode_forward_triton(self, q, k, v, input_metadata: InputMetadata):
         return o
 
     def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
-        self.store_kv_cache(k, v, input_metadata)
+        if not input_metadata.use_ragged:
+            self.store_kv_cache(k, v, input_metadata)
 
-        if input_metadata.total_num_tokens <= 4096:
             o = input_metadata.flashinfer_prefill_wrapper_paged.forward(
                 q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
                 input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
@@ -122,6 +122,8 @@ def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
 
                 o, _ = merge_state(o1, s1, o2, s2)
 
+            self.store_kv_cache(k, v, input_metadata)
+
         if input_metadata.total_num_tokens >= global_config.layer_sync_threshold:
             torch.cuda.synchronize()
 
diff --git a/python/sglang/srt/managers/controller/infer_batch.py b/python/sglang/srt/managers/controller/infer_batch.py
index 58136d4b81c..e7c5ab5f713 100644
--- a/python/sglang/srt/managers/controller/infer_batch.py
+++ b/python/sglang/srt/managers/controller/infer_batch.py
@@ -726,6 +726,7 @@ class InputMetadata:
     flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
     flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
     flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
+    use_ragged: bool = False
 
     @classmethod
     def create(
@@ -741,7 +742,10 @@
         return_logprob=False,
         skip_flashinfer_init=False,
     ):
+        use_ragged = False
         if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
+            if forward_mode != ForwardMode.DECODE and int(torch.sum(seq_lens)) > 4096:
+                use_ragged = True
             init_flashinfer_args(
                 forward_mode,
                 model_runner,
@@ -749,6 +753,7 @@
                 seq_lens,
                 prefix_lens,
                 model_runner.flashinfer_decode_wrapper,
+                use_ragged,
             )
 
         batch_size = len(req_pool_indices)
@@ -803,6 +808,7 @@
             flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
             flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
             flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
+            use_ragged=use_ragged,
         )
 
         if model_runner.server_args.disable_flashinfer:
@@ -823,6 +829,7 @@ def init_flashinfer_args(
     seq_lens,
     prefix_lens,
     flashinfer_decode_wrapper,
+    use_ragged=False,
 ):
     """Init auxiliary variables for FlashInfer attention backend."""
     num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
@@ -831,10 +838,10 @@
     batch_size = len(req_pool_indices)
     total_num_tokens = int(torch.sum(seq_lens))
 
-    if forward_mode == ForwardMode.DECODE or total_num_tokens <= 4096:
-        paged_kernel_lens = seq_lens
-    else:
+    if use_ragged:
         paged_kernel_lens = prefix_lens
+    else:
+        paged_kernel_lens = seq_lens
 
     kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
     kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
@@ -867,14 +874,15 @@
         qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
         qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
 
-        model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
-        model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
-            qo_indptr,
-            qo_indptr,
-            num_qo_heads,
-            num_kv_heads,
-            head_dim,
-        )
+        if use_ragged:
+            model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
+            model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
+                qo_indptr,
+                qo_indptr,
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
+            )
 
         # cached part
         model_runner.flashinfer_prefill_wrapper_paged.end_forward()
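Note on the control flow this patch sets up: on the small-batch path the new tokens' KV is written to the paged pool up front and a single paged prefill kernel covers everything, while on the `use_ragged` path the paged wrapper is built over the prefix only (`paged_kernel_lens = prefix_lens`), the ragged kernel consumes the extend tokens' K/V directly, the two partial outputs are combined with `merge_state`, and only then is the new KV stored for later decode steps. Below is a minimal torch-only sketch of the dispatch rule and of what `merge_state(o1, s1, o2, s2)` does conceptually; `RAGGED_TOKEN_THRESHOLD`, `should_use_ragged`, `lse_merge`, and `attn` are illustrative names for this sketch, not the SGLang or FlashInfer API.

```python
import torch

RAGGED_TOKEN_THRESHOLD = 4096  # same constant the diff checks in InputMetadata.create


def should_use_ragged(is_decode: bool, seq_lens: torch.Tensor) -> bool:
    """Mirrors the new `use_ragged` decision: only prefill/extend batches whose
    total token count exceeds the threshold take the ragged path."""
    return (not is_decode) and int(torch.sum(seq_lens)) > RAGGED_TOKEN_THRESHOLD


def lse_merge(o1, s1, o2, s2):
    """Numerically stable combination of two partial attention outputs computed
    over disjoint key sets, given their log-sum-exp values -- conceptually what
    `merge_state(o1, s1, o2, s2)` returns. o*: [n, h, d], s*: [n, h]."""
    s_max = torch.maximum(s1, s2)
    w1 = torch.exp(s1 - s_max).unsqueeze(-1)
    w2 = torch.exp(s2 - s_max).unsqueeze(-1)
    return (w1 * o1 + w2 * o2) / (w1 + w2)


def attn(q, k, v):
    """Attention of each query token against one key/value chunk, returning the
    output and the log-sum-exp of the scaled scores. q: [n, h, d], k/v: [n, s, h, d].
    Causal masking inside the chunk is omitted to keep the merge illustration short."""
    scores = torch.einsum("nhd,nshd->nhs", q, k) / q.shape[-1] ** 0.5
    lse = torch.logsumexp(scores, dim=-1)
    o = torch.einsum("nhs,nshd->nhd", torch.softmax(scores, dim=-1), v)
    return o, lse


if __name__ == "__main__":
    torch.manual_seed(0)
    n, h, d = 4, 2, 64
    q = torch.randn(n, h, d)
    k_prefix, v_prefix = torch.randn(n, 48, h, d), torch.randn(n, 48, h, d)  # cached prefix (paged pass)
    k_new, v_new = torch.randn(n, 16, h, d), torch.randn(n, 16, h, d)        # extend tokens (ragged pass)

    o1, s1 = attn(q, k_new, v_new)          # ragged-style pass over the new tokens
    o2, s2 = attn(q, k_prefix, v_prefix)    # paged-style pass over the cached prefix
    o_ref, _ = attn(q, torch.cat([k_new, k_prefix], dim=1), torch.cat([v_new, v_prefix], dim=1))

    # Merging by log-sum-exp reproduces attention over the concatenated keys.
    assert torch.allclose(lse_merge(o1, s1, o2, s2), o_ref, atol=1e-5)
    assert should_use_ragged(is_decode=False, seq_lens=torch.tensor([2048, 3072]))
```

The sanity check at the end is the reason the prefix and extend passes can be computed by separate kernels and merged afterwards, which in turn is why `store_kv_cache` can safely move to after the merge on the ragged path.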