From 4eb13bdc09b23a9d7db705063301218a65844254 Mon Sep 17 00:00:00 2001
From: Ying Sheng
Date: Tue, 30 Jul 2024 01:27:41 -0700
Subject: [PATCH 1/3] adjust mem frac

---
 python/sglang/srt/model_executor/model_runner.py |  4 ++--
 python/sglang/srt/server_args.py                 | 10 +++++-----
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 10b1b40ded8..b43f2d04561 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -212,10 +212,10 @@ def init_memory_pool(self, total_gpu_memory, max_num_reqs=None):
         )
 
         if max_num_reqs is None:
-            max_num_reqs = max(
+            max_num_reqs = min(max(
                 int(self.max_total_num_tokens / self.model_config.context_len * 512),
                 2048,
-            )
+            ), 5120)
 
         self.req_to_token_pool = ReqToTokenPool(
             max_num_reqs,
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index e62987dd9d8..4940109d472 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -91,15 +91,15 @@ def __post_init__(self):
             self.tokenizer_path = self.model_path
         if self.mem_fraction_static is None:
             if self.tp_size >= 16:
-                self.mem_fraction_static = 0.80
+                self.mem_fraction_static = 0.79
             elif self.tp_size >= 8:
-                self.mem_fraction_static = 0.84
+                self.mem_fraction_static = 0.83
             elif self.tp_size >= 4:
-                self.mem_fraction_static = 0.86
+                self.mem_fraction_static = 0.85
             elif self.tp_size >= 2:
-                self.mem_fraction_static = 0.88
+                self.mem_fraction_static = 0.87
             else:
-                self.mem_fraction_static = 0.89
+                self.mem_fraction_static = 0.88
         if isinstance(self.additional_ports, int):
             self.additional_ports = [self.additional_ports]
         elif self.additional_ports is None:

From 470b2d07d49dea27f20869b3e860f60c92e3b5da Mon Sep 17 00:00:00 2001
From: Ying Sheng
Date: Tue, 30 Jul 2024 01:29:13 -0700
Subject: [PATCH 2/3] update flashinfer_use_ragged

---
 python/sglang/srt/layers/radix_attention.py  |  2 +-
 python/sglang/srt/managers/schedule_batch.py | 16 ++++++++--------
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py
index ab3a650290f..45b80b8f23e 100644
--- a/python/sglang/srt/layers/radix_attention.py
+++ b/python/sglang/srt/layers/radix_attention.py
@@ -103,7 +103,7 @@ def decode_forward_triton(self, q, k, v, input_metadata: InputMetadata):
         return o
 
     def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
-        if not input_metadata.use_ragged:
+        if not input_metadata.flashinfer_use_ragged:
             self.store_kv_cache(k, v, input_metadata)
 
             o = input_metadata.flashinfer_prefill_wrapper_paged.forward(
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 6cfd2f6509e..157cfd77886 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -781,7 +781,7 @@ class InputMetadata:
     flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
     flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
     flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
-    use_ragged: bool = False
+    flashinfer_use_ragged: bool = False
 
     @classmethod
     def create(
@@ -797,10 +797,10 @@ def create(
         return_logprob=False,
        skip_flashinfer_init=False,
     ):
-        use_ragged = False
+        flashinfer_use_ragged = False
         if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
             if forward_mode != ForwardMode.DECODE and int(torch.sum(seq_lens)) > 4096:
-                use_ragged = True
+                flashinfer_use_ragged = True
             init_flashinfer_args(
                 forward_mode,
                 model_runner,
@@ -808,7 +808,7 @@ def create(
                 seq_lens,
                 prefix_lens,
                 model_runner.flashinfer_decode_wrapper,
-                use_ragged,
+                flashinfer_use_ragged,
             )
 
         batch_size = len(req_pool_indices)
@@ -863,7 +863,7 @@ def create(
            flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
            flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
            flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
-            use_ragged=use_ragged,
+            flashinfer_use_ragged=flashinfer_use_ragged,
         )
 
         if model_runner.server_args.disable_flashinfer:
@@ -884,7 +884,7 @@ def init_flashinfer_args(
     seq_lens,
     prefix_lens,
     flashinfer_decode_wrapper,
-    use_ragged=False,
+    flashinfer_use_ragged=False,
 ):
     """Init auxiliary variables for FlashInfer attention backend."""
     num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
@@ -893,7 +893,7 @@ def init_flashinfer_args(
     batch_size = len(req_pool_indices)
     total_num_tokens = int(torch.sum(seq_lens))
 
-    if use_ragged:
+    if flashinfer_use_ragged:
         paged_kernel_lens = prefix_lens
     else:
         paged_kernel_lens = seq_lens
@@ -929,7 +929,7 @@ def init_flashinfer_args(
         qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
         qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
 
-        if use_ragged:
+        if flashinfer_use_ragged:
             model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
             model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
                 qo_indptr,

From 394b308ad692d6f880839df59b6edcf89b75d57c Mon Sep 17 00:00:00 2001
From: Ying Sheng
Date: Tue, 30 Jul 2024 01:31:24 -0700
Subject: [PATCH 3/3] lint

---
 python/sglang/srt/model_executor/model_runner.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index b43f2d04561..e68c2e1b9a3 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -212,10 +212,15 @@ def init_memory_pool(self, total_gpu_memory, max_num_reqs=None):
         )
 
         if max_num_reqs is None:
-            max_num_reqs = min(max(
-                int(self.max_total_num_tokens / self.model_config.context_len * 512),
-                2048,
-            ), 5120)
+            max_num_reqs = min(
+                max(
+                    int(
+                        self.max_total_num_tokens / self.model_config.context_len * 512
+                    ),
+                    2048,
+                ),
+                5120,
+            )
 
         self.req_to_token_pool = ReqToTokenPool(
             max_num_reqs,
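
Note on the net effect of this series, stated outside diff context: patch 1 lowers each default mem_fraction_static tier by 0.01 and caps the derived request-pool size at 5120 slots, patch 2 is a pure rename (use_ragged -> flashinfer_use_ragged), and patch 3 only reflows the clamp expression for the linter. The sketch below restates the two sizing rules as standalone functions; the function names and the sample numbers are illustrative only and do not appear in the patches.

    # Illustrative helpers restating the sizing rules above; not sglang code.
    def default_mem_fraction_static(tp_size: int) -> float:
        # Defaults after patch 1: each tensor-parallel tier reserves a
        # slightly smaller static fraction than before (all values -0.01).
        if tp_size >= 16:
            return 0.79
        elif tp_size >= 8:
            return 0.83
        elif tp_size >= 4:
            return 0.85
        elif tp_size >= 2:
            return 0.87
        return 0.88

    def derived_max_num_reqs(max_total_num_tokens: int, context_len: int) -> int:
        # Roughly 512 request slots per full context window of KV budget,
        # clamped to [2048, 5120]: the existing floor keeps small budgets
        # usable, and the new ceiling stops very large budgets from
        # over-allocating the req-to-token pool.
        return min(max(int(max_total_num_tokens / context_len * 512), 2048), 5120)

    # Example: a 400k-token KV budget with an 8k context previously yielded
    # 400000 / 8192 * 512 = 25000 slots; the clamp now caps it at 5120.
    assert derived_max_num_reqs(400_000, 8192) == 5120
    assert derived_max_num_reqs(20_000, 8192) == 2048  # 1250 raised to the floor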