Fix some online scheduling delay (#1345) #3

Merged
merged 1 commit into from Sep 8, 2024
77 changes: 47 additions & 30 deletions python/sglang/srt/managers/policy_scheduler.py
@@ -119,6 +119,7 @@ def __init__(
         self.running_batch = running_batch
         self.new_token_ratio = new_token_ratio
         self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
+        self.rem_total_tokens_ = self.rem_total_tokens
         self.total_tokens = rem_total_tokens
         self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
         self.rem_chunk_tokens = rem_chunk_tokens
@@ -153,11 +154,18 @@ def remove_running_tokens(self, running_batch: ScheduleBatch):
                 for r in running_batch.reqs
             ]
         )
+        self.rem_total_tokens_ -= sum(
+            [
+                r.sampling_params.max_new_tokens - len(r.output_ids)
+                for r in running_batch.reqs
+            ]
+        )
 
     def _prefill_one_req(
         self, prefix_len: int, extend_input_len: int, max_new_tokens: int
     ):
         self.rem_total_tokens -= extend_input_len + max_new_tokens
+        self.rem_total_tokens_ -= extend_input_len + max_new_tokens
         self.rem_input_tokens -= extend_input_len
         if self.rem_chunk_tokens is not None:
             self.rem_chunk_tokens -= extend_input_len
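These two hunks give the PrefillAdder a second, stricter counter, rem_total_tokens_: it starts from the same remaining budget as rem_total_tokens, every running request subtracts its full remaining output (max_new_tokens - len(output_ids)), and every newly admitted prefill subtracts extend_input_len + max_new_tokens, whereas the existing rem_total_tokens update for running requests uses a different estimate whose full expression is collapsed above. A minimal sketch of the new counter's bookkeeping, using a hypothetical Req stand-in rather than sglang's real Req class:

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Req:
    """Hypothetical stand-in for sglang's Req; only the fields the counter reads."""
    max_new_tokens: int
    output_ids: List[int] = field(default_factory=list)


class TokenBudgetSketch:
    """Tracks the pessimistic counter (rem_total_tokens_) added by this PR."""

    def __init__(self, rem_total_tokens: int):
        # Mirrors __init__ above: the new counter starts from the same budget.
        self.rem_total_tokens_ = rem_total_tokens

    def remove_running_tokens(self, running_reqs: List[Req]) -> None:
        # Reserve the worst case: each running request may still emit
        # max_new_tokens - len(output_ids) more tokens.
        self.rem_total_tokens_ -= sum(
            r.max_new_tokens - len(r.output_ids) for r in running_reqs
        )

    def prefill_one_req(self, extend_input_len: int, max_new_tokens: int) -> None:
        # A newly admitted request consumes its prompt extension now and may
        # consume its full output budget later.
        self.rem_total_tokens_ -= extend_input_len + max_new_tokens
```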
@@ -231,43 +239,52 @@ def get_req_state(r):
 
             return None
 
-        if self.req_states is None:
-            self.req_states = []
-            if self.running_batch is not None:
-                for r in self.running_batch.reqs:
+        # Quick Check
+        can_run = False
+        if (
+            req.extend_input_len + req.sampling_params.max_new_tokens
+            <= self.rem_total_tokens
+        ):
+            can_run = True
+
+        if not can_run:
+            if self.req_states is None:
+                self.req_states = []
+                if self.running_batch is not None:
+                    for r in self.running_batch.reqs:
+                        state = get_req_state(r)
+                        if state is not None:
+                            self.req_states.append(state)
+                for r in self.can_run_list:
                     state = get_req_state(r)
                     if state is not None:
                         self.req_states.append(state)
-            for r in self.can_run_list:
-                state = get_req_state(r)
+                state = get_req_state(req)
                 if state is not None:
                     self.req_states.append(state)
-            state = get_req_state(req)
-            if state is not None:
-                self.req_states.append(state)
 
-            self.req_states.sort(key=lambda x: x[0])
-        else:
-            state = get_req_state(req)
-            if state is not None:
-                for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
-                    if tokens_left >= state[0]:
-                        self.req_states.insert(i, state)
-                        break
-                else:
-                    self.req_states.append(state)
+                self.req_states.sort(key=lambda x: x[0])
+            else:
+                state = get_req_state(req)
+                if state is not None:
+                    for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
+                        if tokens_left >= state[0]:
+                            self.req_states.insert(i, state)
+                            break
+                    else:
+                        self.req_states.append(state)
 
-        tokens_freed = 0
-        for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
-            decode_steps = (
-                self.req_states[i + 1][0]
-                if i + 1 < len(self.req_states)
-                else tokens_left
-            )
-            bs = len(self.req_states) - i
-            if self.total_tokens + tokens_freed - decode_steps * bs <= 0:
-                return False
-            tokens_freed += tokens_occupied
+            tokens_freed = 0
+            for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
+                decode_steps = (
+                    self.req_states[i + 1][0]
+                    if i + 1 < len(self.req_states)
+                    else tokens_left
+                )
+                bs = len(self.req_states) - i
+                if self.total_tokens + tokens_freed - decode_steps * bs <= 0:
+                    return False
+                tokens_freed += tokens_occupied
 
         if req.extend_input_len <= self.rem_chunk_tokens:
             self.can_run_list.append(req)
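The hunk above is the scheduler-side part of the latency fix: before building or scanning req_states, add_one_req first checks whether the incoming request's extend_input_len plus its max_new_tokens already fits into rem_total_tokens, and only falls back to the per-request feasibility loop when that quick check fails. The sketch below condenses that control flow; req_states, total_tokens, and the (tokens_left, tokens_occupied) tuples follow the diff, while can_admit and req_cost are illustrative names only:

```python
from typing import List, Tuple


def can_admit(
    req_cost: int,
    rem_total_tokens: int,
    total_tokens: int,
    req_states: List[Tuple[int, int]],
) -> bool:
    """Condensed admission check mirroring PrefillAdder.add_one_req after this PR.

    req_cost ~ extend_input_len + max_new_tokens of the incoming request;
    req_states holds one (tokens_left, tokens_occupied) pair per request,
    sorted by tokens_left ascending, as built in the hunk above.
    """
    # Quick check: the request obviously fits, so skip the expensive scan.
    if req_cost <= rem_total_tokens:
        return True

    # Slow path: requests finish in order of tokens_left; before each group of
    # decode steps, verify the token pool (plus tokens freed by finished
    # requests) can cover decode_steps tokens for every still-running request.
    tokens_freed = 0
    for i, (tokens_left, tokens_occupied) in enumerate(req_states):
        decode_steps = req_states[i + 1][0] if i + 1 < len(req_states) else tokens_left
        bs = len(req_states) - i
        if total_tokens + tokens_freed - decode_steps * bs <= 0:
            return False
        tokens_freed += tokens_occupied
    return True
```

For example, can_admit(req_cost=900, rem_total_tokens=1000, total_tokens=4096, req_states=[]) returns True through the quick check alone, without touching req_states.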
11 changes: 5 additions & 6 deletions python/sglang/srt/managers/tp_worker.py
@@ -231,6 +231,7 @@ def exposed_step(self, recv_reqs: List):
                     recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
                 ):
                     self.handle_generate_request(recv_req)
+                    self.do_not_get_new_batch = False
                 elif isinstance(recv_req, FlushCacheReq):
                     self.flush_cache()
                 elif isinstance(recv_req, AbortReq):
@@ -254,12 +255,10 @@ def exposed_step(self, recv_reqs: List):
 
     @torch.inference_mode()
     def forward_step(self):
-        if self.current_inflight_req is not None:
-            self.do_not_get_new_batch = False
-
-        new_batch = (
-            self.get_new_prefill_batch() if not self.do_not_get_new_batch else None
-        )
+        if self.do_not_get_new_batch and self.current_inflight_req is None:
+            new_batch = None
+        else:
+            new_batch = self.get_new_prefill_batch()
         self.do_not_get_new_batch = False
 
         if new_batch is not None:
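Note that the rewritten forward_step is logically equivalent to the old version: in both, get_new_prefill_batch is skipped exactly when do_not_get_new_batch is set and there is no in-flight request. The observable fix for the scheduling delay comes from the first tp_worker hunk, which clears the flag as soon as a new generate request arrives so the next forward_step tries to prefill it immediately. A toy sketch of the flag handling, with hypothetical names (SchedulerFlagsSketch, on_generate_request, pick_prefill_batch) that do not exist in sglang:

```python
class SchedulerFlagsSketch:
    """Toy model of the tp_worker scheduling flags touched by this PR."""

    def __init__(self):
        self.do_not_get_new_batch = False
        self.current_inflight_req = None

    def on_generate_request(self, req):
        # Mirrors the first tp_worker hunk: a newly arrived request clears the
        # skip flag so the next forward_step considers a new prefill batch.
        self.do_not_get_new_batch = False

    def pick_prefill_batch(self, get_new_prefill_batch):
        # Mirrors the rewritten forward_step: skip fetching only when the flag
        # is set and no chunked-prefill request is still in flight.
        if self.do_not_get_new_batch and self.current_inflight_req is None:
            new_batch = None
        else:
            new_batch = get_new_prefill_batch()
        self.do_not_get_new_batch = False
        return new_batch
```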