Skip to content

Commit

Permalink
Make token mapping non-blocking in the overlapped mode (#1740)
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy authored Oct 21, 2024
1 parent 45d5af2 commit cf470fe
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 16 deletions.
22 changes: 6 additions & 16 deletions python/sglang/srt/managers/tp_worker_overlap_thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,16 +86,15 @@ def forward_thread_func(self):
@torch.inference_mode()
def forward_thread_func_(self):
while True:
tic1 = time.time()
model_worker_batch, future_token_ids_ct = self.input_queue.get()

# Resolve future tokens in the input
tic2 = time.time()
resolved_input_ids = model_worker_batch.input_ids
future_mask = resolved_input_ids < 0
resolved_input_ids[future_mask] = self.future_token_ids_map[
-resolved_input_ids[future_mask]
]
input_ids = model_worker_batch.input_ids
input_ids[:] = torch.where(
input_ids < 0,
self.future_token_ids_map[torch.clamp(-input_ids, min=0)],
input_ids,
)

# Run forward
logits_output, next_token_ids = self.worker.forward_batch_generation(
Expand All @@ -119,15 +118,6 @@ def forward_thread_func_(self):
assert logits_output.next_token_logprobs is None, "Not supported"
self.output_queue.put((None, next_token_ids))

if False:
tic3 = time.time()
self.acc_time_with_waiting += tic3 - tic1
self.acc_time_without_waiting += tic3 - tic2
if self.forward_queue.qsize() == 0:
logger.info(
f"{self.acc_time_with_waiting=:.3f}, {self.acc_time_without_waiting=:.3f}, {self.forward_queue.qsize()=}"
)

def resulve_batch_result(self, bid: int):
    """Take the next forward result off ``self.output_queue`` and return it.

    Args:
        bid: Batch identifier. It is accepted for the caller's benefit but
            is not consulted here; the entry returned is simply whichever
            one ``self.output_queue.get()`` hands back next.

    Returns:
        A ``(logits_output, next_token_ids)`` tuple.
    """
    # NOTE(review): presumably output_queue is a queue.Queue, in which case
    # get() blocks until the forward thread has produced a result -- confirm.
    batch_result = self.output_queue.get()
    # Unpack explicitly so an entry that is not a 2-item sequence fails here.
    logits_output, next_token_ids = batch_result
    return logits_output, next_token_ids
Expand Down
1 change: 1 addition & 0 deletions test/killall_sglang.sh
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
# Force-kill leftover sglang processes: first the multiprocessing workers,
# then the server launcher (same order as before).
for pattern in 'multiprocessing.spawn' 'sglang.launch_server'; do
  # Collect matching PIDs; `grep -v 'grep'` drops the grep process itself.
  pids=$(ps aux | grep "$pattern" | grep -v 'grep' | awk '{print $2}')
  # Only call kill when something matched: a bare `kill -9` with no PID
  # prints a usage error and returns a non-zero status.
  if [ -n "$pids" ]; then
    # $pids is intentionally unquoted so each PID becomes its own argument.
    kill -9 $pids
  fi
done

0 comments on commit cf470fe

Please sign in to comment.