Skip to content

Commit

Permalink
fix schedule bug (#1450)
Browse files Browse the repository at this point in the history
  • Loading branch information
hnyls2002 authored Sep 17, 2024
1 parent b3710d2 commit 36078fb
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 99 deletions.
141 changes: 48 additions & 93 deletions python/sglang/srt/managers/policy_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,19 +119,32 @@ def __init__(
self.running_batch = running_batch
self.new_token_ratio = new_token_ratio
self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
self.rem_total_tokens_ = self.rem_total_tokens
self.total_tokens = rem_total_tokens
self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
self.rem_chunk_tokens = rem_chunk_tokens
if self.rem_chunk_tokens is not None:
self.rem_chunk_tokens -= mixed_with_decode_tokens

self.cur_rem_tokens = rem_total_tokens - mixed_with_decode_tokens

self.req_states = None
self.can_run_list = []
self.new_inflight_req = None
self.log_hit_tokens = 0
self.log_input_tokens = 0

if running_batch is not None:
# Pre-remove the tokens which will be occupied by the running requests
self.rem_total_tokens -= sum(
[
min(
(r.sampling_params.max_new_tokens - len(r.output_ids)),
CLIP_MAX_NEW_TOKENS,
)
* self.new_token_ratio
for r in running_batch.reqs
]
)

def no_remaining_tokens(self):
return (
self.rem_total_tokens <= 0
Expand All @@ -141,61 +154,22 @@ def no_remaining_tokens(self):
if self.rem_chunk_tokens is not None
else False
)
)

def remove_running_tokens(self, running_batch: ScheduleBatch):
self.rem_total_tokens -= sum(
[
min(
(r.sampling_params.max_new_tokens - len(r.output_ids)),
CLIP_MAX_NEW_TOKENS,
)
* self.new_token_ratio
for r in running_batch.reqs
]
)
self.rem_total_tokens_ -= sum(
[
r.sampling_params.max_new_tokens - len(r.output_ids)
for r in running_batch.reqs
]
or self.cur_rem_tokens <= 0
)

def _prefill_one_req(
self, prefix_len: int, extend_input_len: int, max_new_tokens: int
):
self.rem_total_tokens -= extend_input_len + max_new_tokens
self.rem_total_tokens_ -= extend_input_len + max_new_tokens
self.cur_rem_tokens -= extend_input_len
self.rem_input_tokens -= extend_input_len
if self.rem_chunk_tokens is not None:
self.rem_chunk_tokens -= extend_input_len

self.log_hit_tokens += prefix_len
self.log_input_tokens += extend_input_len

def add_inflight_req_ignore_eos(self, req: Req):
truncated = req.extend_input_len > self.rem_chunk_tokens
req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
self.can_run_list.append(req)

self._prefill_one_req(
0,
req.extend_input_len,
(
min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
if not truncated
else 0
),
)

# Return if chunked prefill not finished
return req if truncated else None

def add_inflight_req(self, req: Req):
if req.sampling_params.ignore_eos:
return self.add_inflight_req_ignore_eos(req)

truncated = req.extend_input_len > self.rem_chunk_tokens
req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
Expand Down Expand Up @@ -225,7 +199,7 @@ def _lock_node(self, last_node: TreeNode):
self.rem_total_tokens += delta

def add_one_req_ignore_eos(self, req: Req):
def get_req_state(r):
def add_req_state(r, insert_sort=False):
new_token_ratio = (
1.0 if r.sampling_params.ignore_eos else self.new_token_ratio
)
Expand All @@ -235,56 +209,37 @@ def get_req_state(r):
tokens_occupied = len(r.origin_input_ids) + len(r.output_ids)

if tokens_left > 0:
return (tokens_left, tokens_occupied)

return None

# Quick Check
can_run = False
if (
req.extend_input_len + req.sampling_params.max_new_tokens
<= self.rem_total_tokens
):
can_run = True

if not can_run:
if self.req_states is None:
self.req_states = []
if self.running_batch is not None:
for r in self.running_batch.reqs:
state = get_req_state(r)
if state is not None:
self.req_states.append(state)
for r in self.can_run_list:
state = get_req_state(r)
if state is not None:
self.req_states.append(state)
state = get_req_state(req)
if state is not None:
self.req_states.append(state)

self.req_states.sort(key=lambda x: x[0])
else:
state = get_req_state(req)
if state is not None:
for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
if tokens_left >= state[0]:
self.req_states.insert(i, state)
if not insert_sort:
self.req_states.append((tokens_left, tokens_occupied))
else:
for i in range(len(self.req_states)):
if tokens_left <= self.req_states[i][0]:
break
else:
self.req_states.append(state)

tokens_freed = 0
for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
decode_steps = (
self.req_states[i + 1][0]
if i + 1 < len(self.req_states)
else tokens_left
)
bs = len(self.req_states) - i
if self.total_tokens + tokens_freed - decode_steps * bs <= 0:
return False
tokens_freed += tokens_occupied
self.req_states.insert(i, (tokens_left, tokens_occupied))

if self.req_states is None:
self.req_states = []
add_req_state(req)
if self.running_batch is not None:
for r in self.running_batch.reqs:
add_req_state(r)
for r in self.can_run_list:
add_req_state(r)
self.req_states.sort(key=lambda x: x[0])
else:
add_req_state(req, insert_sort=True)

tokens_freed = 0
for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
decode_steps = (
self.req_states[i + 1][0]
if i + 1 < len(self.req_states)
else tokens_left
)
bs = len(self.req_states) - i
if self.cur_rem_tokens + tokens_freed - decode_steps * bs <= 0:
return False
tokens_freed += tokens_occupied

if req.extend_input_len <= self.rem_chunk_tokens:
self.can_run_list.append(req)
Expand Down
17 changes: 11 additions & 6 deletions python/sglang/srt/managers/tp_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,9 +445,6 @@ def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
num_mixed_running,
)

if self.running_batch is not None:
adder.remove_running_tokens(self.running_batch)

has_inflight = self.current_inflight_req is not None
if self.current_inflight_req is not None:
self.current_inflight_req.init_next_round_input(
Expand All @@ -465,9 +462,6 @@ def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
)

for req in self.waiting_queue:
if adder.no_remaining_tokens():
break
req.init_next_round_input(None if prefix_computed else self.tree_cache)
if (
self.lora_paths is not None
and len(
Expand All @@ -478,6 +472,10 @@ def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
> self.max_loras_per_batch
):
break

if adder.no_remaining_tokens():
break
req.init_next_round_input(None if prefix_computed else self.tree_cache)
res = adder.add_one_req(req)
if (
not res
Expand Down Expand Up @@ -507,6 +505,11 @@ def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
else:
tree_cache_hit_rate = 0.0

num_used = self.max_total_num_tokens - (
self.token_to_kv_pool.available_size()
+ self.tree_cache.evictable_size()
)

if num_mixed_running > 0:
logger.info(
f"Prefill batch"
Expand All @@ -515,6 +518,7 @@ def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
f"#new-token: {adder.log_input_tokens}, "
f"#cached-token: {adder.log_hit_tokens}, "
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
)
else:
Expand All @@ -524,6 +528,7 @@ def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
f"#new-token: {adder.log_input_tokens}, "
f"#cached-token: {adder.log_hit_tokens}, "
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"#running-req: {running_bs}, "
f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
)
Expand Down

0 comments on commit 36078fb

Please sign in to comment.