From 313befb74d8dcf29a58823dc242995915f1bd664 Mon Sep 17 00:00:00 2001
From: lmy
Date: Tue, 23 Jul 2024 22:04:13 +0000
Subject: [PATCH 1/2] remove hardcoded logic

---
 python/sglang/srt/layers/radix_attention.py  |  6 ++--
 .../srt/managers/controller/infer_batch.py   | 29 ++++++++++++-------
 2 files changed, 22 insertions(+), 13 deletions(-)

diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py
index c522c972585..355960367c0 100644
--- a/python/sglang/srt/layers/radix_attention.py
+++ b/python/sglang/srt/layers/radix_attention.py
@@ -85,9 +85,9 @@ def decode_forward_triton(self, q, k, v, input_metadata: InputMetadata):
         return o
 
     def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
-        self.store_kv_cache(k, v, input_metadata)
+        if not input_metadata.use_ragged:
+            self.store_kv_cache(k, v, input_metadata)
 
-        if input_metadata.total_num_tokens <= 4096:
             o = input_metadata.flashinfer_prefill_wrapper_paged.forward(
                 q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
                 input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
@@ -122,6 +122,8 @@ def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
 
             o, _ = merge_state(o1, s1, o2, s2)
 
+            self.store_kv_cache(k, v, input_metadata)
+
         if input_metadata.total_num_tokens >= global_config.layer_sync_threshold:
             torch.cuda.synchronize()
 
diff --git a/python/sglang/srt/managers/controller/infer_batch.py b/python/sglang/srt/managers/controller/infer_batch.py
index 58136d4b81c..4b0b8b076cb 100644
--- a/python/sglang/srt/managers/controller/infer_batch.py
+++ b/python/sglang/srt/managers/controller/infer_batch.py
@@ -726,6 +726,7 @@ class InputMetadata:
     flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
     flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
     flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
+    use_ragged = False
 
     @classmethod
     def create(
@@ -742,6 +743,8 @@ def create(
         skip_flashinfer_init=False,
     ):
         if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
+            if forward_mode != ForwardMode.DECODE and total_num_tokens > 4096:
+                use_ragged = True
             init_flashinfer_args(
                 forward_mode,
                 model_runner,
@@ -749,6 +752,7 @@ def create(
                 seq_lens,
                 prefix_lens,
                 model_runner.flashinfer_decode_wrapper,
+                use_ragged,
             )
 
         batch_size = len(req_pool_indices)
@@ -803,6 +807,7 @@ def create(
             flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
             flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
             flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
+            use_ragged=use_ragged,
         )
 
         if model_runner.server_args.disable_flashinfer:
@@ -823,6 +828,7 @@ def init_flashinfer_args(
     seq_lens,
     prefix_lens,
     flashinfer_decode_wrapper,
+    use_ragged=False,
 ):
     """Init auxiliary variables for FlashInfer attention backend."""
     num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
@@ -831,10 +837,10 @@ def init_flashinfer_args(
     batch_size = len(req_pool_indices)
     total_num_tokens = int(torch.sum(seq_lens))
 
-    if forward_mode == ForwardMode.DECODE or total_num_tokens <= 4096:
-        paged_kernel_lens = seq_lens
-    else:
+    if use_ragged:
         paged_kernel_lens = prefix_lens
+    else:
+        paged_kernel_lens = seq_lens
 
     kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
     kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
@@ -867,14 +873,15 @@ def init_flashinfer_args(
         qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
         qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
 
-        model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
-        model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
-            qo_indptr,
-            qo_indptr,
-            num_qo_heads,
-            num_kv_heads,
-            head_dim,
-        )
+        if use_ragged:
+            model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
+            model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
+                qo_indptr,
+                qo_indptr,
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
+            )
 
         # cached part
         model_runner.flashinfer_prefill_wrapper_paged.end_forward()

From 610ca501e2760399868e95492835a1951cfbca77 Mon Sep 17 00:00:00 2001
From: lmy
Date: Tue, 23 Jul 2024 23:38:10 +0000
Subject: [PATCH 2/2] update

---
 python/sglang/srt/managers/controller/infer_batch.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/python/sglang/srt/managers/controller/infer_batch.py b/python/sglang/srt/managers/controller/infer_batch.py
index 4b0b8b076cb..e7c5ab5f713 100644
--- a/python/sglang/srt/managers/controller/infer_batch.py
+++ b/python/sglang/srt/managers/controller/infer_batch.py
@@ -726,7 +726,7 @@ class InputMetadata:
     flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
     flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
     flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
-    use_ragged = False
+    use_ragged: bool = False
 
     @classmethod
     def create(
@@ -742,8 +742,9 @@ def create(
         return_logprob=False,
         skip_flashinfer_init=False,
     ):
+        use_ragged = False
         if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
-            if forward_mode != ForwardMode.DECODE and total_num_tokens > 4096:
+            if forward_mode != ForwardMode.DECODE and int(torch.sum(seq_lens)) > 4096:
                 use_ragged = True
             init_flashinfer_args(
                 forward_mode,
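
For reference, the dispatch rule the two patches converge on can be stated as a small
standalone sketch in plain PyTorch. The names RAGGED_THRESHOLD, should_use_ragged, and
forward_mode_is_decode are hypothetical simplifications (the diff keeps the 4096 cutoff
as an inline literal and checks forward_mode != ForwardMode.DECODE directly):

    import torch

    # Hypothetical name for the 4096-token cutoff that the diff keeps as an inline literal.
    RAGGED_THRESHOLD = 4096

    def should_use_ragged(forward_mode_is_decode: bool, seq_lens: torch.Tensor) -> bool:
        # Decode batches always take the paged path; large extend batches switch to the
        # ragged prefill path, where the KV cache is written only after the ragged and
        # paged results are merged (see extend_forward_flashinfer in radix_attention.py).
        return (not forward_mode_is_decode) and int(torch.sum(seq_lens)) > RAGGED_THRESHOLD

    # Example: an extend batch totaling 6000 tokens uses the ragged path,
    # while the same batch in decode mode does not.
    print(should_use_ragged(False, torch.tensor([2000, 2500, 1500])))  # True
    print(should_use_ragged(True, torch.tensor([2000, 2500, 1500])))   # False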