Skip to content

Commit

Permalink
A copy thread
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy committed Oct 21, 2024
1 parent 09603c6 commit 8685533
Showing 1 changed file with 17 additions and 5 deletions.
22 changes: 17 additions & 5 deletions python/sglang/srt/managers/tp_worker_overlap_thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def __init__(
(self.max_running_requests * 5,), dtype=torch.int32, device=self.device
)

# Launch a thread
# Launch threads
self.input_queue = Queue()
self.output_queue = Queue()
self.forward_stream = torch.cuda.Stream()
Expand All @@ -64,6 +64,12 @@ def __init__(
)
self.forward_thread.start()

self.copy_queue = Queue()
self.copy_thread = threading.Thread(
target=self.copy_thread_func,
)
self.copy_thread.start()

def get_worker_info(self):
    """Return whatever the wrapped ``self.worker`` reports as its worker info."""
    # Pure delegation: this object adds nothing to the wrapped worker's answer.
    worker_info = self.worker.get_worker_info()
    return worker_info

Expand Down Expand Up @@ -113,10 +119,16 @@ def forward_thread_func_(self):
torch.int32
)

# Set the result
next_token_ids = next_token_ids.tolist()
assert logits_output.next_token_logprobs is None, "Not supported"
self.output_queue.put((None, next_token_ids))
next_token_ids = next_token_ids.to("cpu", non_blocking=True)
copy_event = torch.cuda.Event(blocking=True)
copy_event.record()
self.copy_queue.put((copy_event, next_token_ids))

def copy_thread_func(self):
    """Turn finished device-to-host copies into Python lists for the consumer.

    Runs forever on the copy thread. Each item taken from ``self.copy_queue``
    is a ``(copy_event, next_token_ids)`` pair. Once the event has completed,
    ``(None, next_token_ids.tolist())`` is put on ``self.output_queue``; the
    leading ``None`` fills the logits-output slot of the result tuple.
    """
    while True:
        copy_event, next_token_ids = self.copy_queue.get()
        # Block the *host* thread until the work captured by the event is
        # done. ``Event.wait()`` would only make a CUDA stream wait on the
        # event and return immediately on the CPU side, so ``tolist()`` could
        # observe the destination buffer before the asynchronous copy has
        # finished writing it.
        copy_event.synchronize()
        self.output_queue.put((None, next_token_ids.tolist()))

def resulve_batch_result(self, bid: int):
logits_output, next_token_ids = self.output_queue.get()
Expand Down

0 comments on commit 8685533

Please sign in to comment.