diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py
index 5d78b97ce43..8b27d2a69a9 100644
--- a/python/sglang/srt/managers/tp_worker_overlap_thread.py
+++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py
@@ -55,7 +55,7 @@ def __init__(
             (self.max_running_requests * 5,), dtype=torch.int32, device=self.device
         )
 
-        # Launch a thread
+        # Launch threads
         self.input_queue = Queue()
         self.output_queue = Queue()
         self.forward_stream = torch.cuda.Stream()
@@ -64,6 +64,12 @@ def __init__(
         )
         self.forward_thread.start()
 
+        self.copy_queue = Queue()
+        self.copy_thread = threading.Thread(
+            target=self.copy_thread_func,
+        )
+        self.copy_thread.start()
+
     def get_worker_info(self):
         return self.worker.get_worker_info()
 
@@ -86,7 +92,10 @@ def forward_thread_func(self):
     @torch.inference_mode()
     def forward_thread_func_(self):
         while True:
+            self.has_inflight_batch = False
             model_worker_batch, future_token_ids_ct = self.input_queue.get()
+            self.has_inflight_batch = True
+            self.launch_event = threading.Event()
 
             # Resolve future tokens in the input
             input_ids = model_worker_batch.input_ids
@@ -100,6 +109,7 @@ def forward_thread_func_(self):
             logits_output, next_token_ids = self.worker.forward_batch_generation(
                 model_worker_batch
             )
+            self.launch_event.set()
 
             # Update the future token ids map
             bs = len(model_worker_batch.seq_lens)
@@ -113,13 +123,23 @@ def forward_thread_func_(self):
                 torch.int32
             )
 
-            # Set the result
-            next_token_ids = next_token_ids.tolist()
-            assert logits_output.next_token_logprobs is None, "Not supported"
-            self.output_queue.put((None, next_token_ids))
+            next_token_ids = next_token_ids.to("cpu", non_blocking=True)
+            copy_event = torch.cuda.Event(blocking=True)
+            copy_event.record()
+            self.copy_queue.put((copy_event, next_token_ids))
+
+    def copy_thread_func(self):
+        while True:
+            copy_event, next_token_ids = self.copy_queue.get()
+            while not copy_event.query():
+                time.sleep(1e-5)
+            self.output_queue.put((None, next_token_ids.tolist()))
 
     def resulve_batch_result(self, bid: int):
         logits_output, next_token_ids = self.output_queue.get()
+        if self.has_inflight_batch:
+            # Wait until the batch is launched
+            self.launch_event.wait()
        return logits_output, next_token_ids
 
     def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
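For readers outside the sglang codebase, below is a minimal standalone sketch of the copy-thread pattern this patch introduces: the forward side issues a non-blocking device-to-host copy, records a `torch.cuda.Event` right after it, and hands both to a dedicated thread that polls the event before calling `.tolist()`, so the forward thread is never blocked by the D2H transfer. The `producer`/`copy_worker` names, batch sizes, and the `None` shutdown sentinel are illustrative assumptions, not part of the patch.

```python
# Sketch of the copy-thread pattern (requires a CUDA device).
import threading
import time
from queue import Queue

import torch

copy_queue: Queue = Queue()
output_queue: Queue = Queue()

def producer(num_batches: int) -> None:
    for _ in range(num_batches):
        # Stand-in for forward_batch_generation: a GPU tensor of next tokens.
        next_token_ids = torch.randint(0, 32000, (8,), device="cuda")
        # Start the D2H copy without synchronizing; for true overlap the
        # destination should be pinned memory, otherwise CUDA may fall back
        # to a synchronous copy.
        host_ids = next_token_ids.to("cpu", non_blocking=True)
        copy_event = torch.cuda.Event(blocking=True)
        copy_event.record()  # marks the point after the copy on the current stream
        copy_queue.put((copy_event, host_ids))

def copy_worker() -> None:
    while True:
        copy_event, host_ids = copy_queue.get()
        if copy_event is None:  # shutdown sentinel (illustrative)
            break
        # Poll with a tiny sleep instead of event.synchronize(), mirroring
        # the patch: the GIL is released during sleep, and the producer's
        # stream is never blocked from this thread.
        while not copy_event.query():
            time.sleep(1e-5)
        # Safe to touch the CPU tensor only after the event has completed.
        output_queue.put(host_ids.tolist())

t = threading.Thread(target=copy_worker, daemon=True)
t.start()
producer(4)
for _ in range(4):
    print(output_queue.get())
copy_queue.put((None, None))  # stop the copy thread
```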