Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adjust default mem fraction to avoid OOM #823

Merged
merged 3 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/sglang/srt/layers/radix_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def decode_forward_triton(self, q, k, v, input_metadata: InputMetadata):
return o

def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
if not input_metadata.use_ragged:
if not input_metadata.flashinfer_use_ragged:
self.store_kv_cache(k, v, input_metadata)

o = input_metadata.flashinfer_prefill_wrapper_paged.forward(
Expand Down
16 changes: 8 additions & 8 deletions python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,7 +781,7 @@ class InputMetadata:
flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
use_ragged: bool = False
flashinfer_use_ragged: bool = False

@classmethod
def create(
Expand All @@ -797,18 +797,18 @@ def create(
return_logprob=False,
skip_flashinfer_init=False,
):
use_ragged = False
flashinfer_use_ragged = False
if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
if forward_mode != ForwardMode.DECODE and int(torch.sum(seq_lens)) > 4096:
use_ragged = True
flashinfer_use_ragged = True
init_flashinfer_args(
forward_mode,
model_runner,
req_pool_indices,
seq_lens,
prefix_lens,
model_runner.flashinfer_decode_wrapper,
use_ragged,
flashinfer_use_ragged,
)

batch_size = len(req_pool_indices)
Expand Down Expand Up @@ -863,7 +863,7 @@ def create(
flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
use_ragged=use_ragged,
flashinfer_use_ragged=flashinfer_use_ragged,
)

if model_runner.server_args.disable_flashinfer:
Expand All @@ -884,7 +884,7 @@ def init_flashinfer_args(
seq_lens,
prefix_lens,
flashinfer_decode_wrapper,
use_ragged=False,
flashinfer_use_ragged=False,
):
"""Init auxiliary variables for FlashInfer attention backend."""
num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
Expand All @@ -893,7 +893,7 @@ def init_flashinfer_args(
batch_size = len(req_pool_indices)
total_num_tokens = int(torch.sum(seq_lens))

if use_ragged:
if flashinfer_use_ragged:
paged_kernel_lens = prefix_lens
else:
paged_kernel_lens = seq_lens
Expand Down Expand Up @@ -929,7 +929,7 @@ def init_flashinfer_args(
qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)

if use_ragged:
if flashinfer_use_ragged:
model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
qo_indptr,
Expand Down
11 changes: 8 additions & 3 deletions python/sglang/srt/model_executor/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,9 +212,14 @@ def init_memory_pool(self, total_gpu_memory, max_num_reqs=None):
)

if max_num_reqs is None:
max_num_reqs = max(
int(self.max_total_num_tokens / self.model_config.context_len * 512),
2048,
max_num_reqs = min(
max(
int(
self.max_total_num_tokens / self.model_config.context_len * 512
),
2048,
),
5120,
)

self.req_to_token_pool = ReqToTokenPool(
Expand Down
10 changes: 5 additions & 5 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,15 @@ def __post_init__(self):
self.tokenizer_path = self.model_path
if self.mem_fraction_static is None:
if self.tp_size >= 16:
self.mem_fraction_static = 0.80
self.mem_fraction_static = 0.79
elif self.tp_size >= 8:
self.mem_fraction_static = 0.84
self.mem_fraction_static = 0.83
elif self.tp_size >= 4:
self.mem_fraction_static = 0.86
self.mem_fraction_static = 0.85
elif self.tp_size >= 2:
self.mem_fraction_static = 0.88
self.mem_fraction_static = 0.87
else:
self.mem_fraction_static = 0.89
self.mem_fraction_static = 0.88
if isinstance(self.additional_ports, int):
self.additional_ports = [self.additional_ports]
elif self.additional_ports is None:
Expand Down
Loading