Commit

Split the overlapped version of TpModelWorkerClient into a separate file (#1726)
merrymercy authored Oct 20, 2024
1 parent 593b19f commit b48edff
Showing 7 changed files with 217 additions and 131 deletions.
8 changes: 4 additions & 4 deletions python/sglang/srt/managers/schedule_batch.py
@@ -639,17 +639,17 @@ def retract_decode(self):

if isinstance(self.tree_cache, ChunkCache):
# ChunkCache does not have eviction
token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx][
: seq_lens_cpu[idx]
token_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : seq_lens_cpu[idx]
]
self.token_to_kv_pool.free(token_indices)
self.req_to_token_pool.free(req.req_pool_idx)
del self.tree_cache.entries[req.rid]
else:
# TODO: apply more fine-grained retraction
last_uncached_pos = len(req.prefix_indices)
token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx][
last_uncached_pos : seq_lens_cpu[idx]
token_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
]
self.token_to_kv_pool.free(token_indices)
self.req_to_token_pool.free(req.req_pool_idx)
20 changes: 8 additions & 12 deletions python/sglang/srt/managers/scheduler.py
@@ -59,6 +59,7 @@
SchedulePolicy,
)
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.server_args import PortArgs, ServerArgs
@@ -146,26 +147,21 @@ def __init__(

# Launch a tensor parallel worker
if self.server_args.enable_overlap_schedule:
TpWorkerClass = TpModelWorker
TpWorkerClass = TpModelWorkerClient
self.resolve_next_token_ids = (
lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
)
else:
TpWorkerClass = TpModelWorker
self.resolve_next_token_ids = lambda bid, x: x.tolist()

self.tp_worker = TpWorkerClass(
server_args=server_args,
gpu_id=gpu_id,
tp_rank=tp_rank,
dp_rank=dp_rank,
nccl_port=port_args.nccl_port,
)
if self.server_args.enable_overlap_schedule:
self.resolve_next_token_ids = (
lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
)
self.forward_batch_generation = (
self.tp_worker.forward_batch_generation_non_blocking
)
else:
self.resolve_next_token_ids = lambda bid, x: x.tolist()
self.forward_batch_generation = self.tp_worker.forward_batch_generation

# Get token and memory info from the model worker
(
@@ -728,7 +724,7 @@ def run_batch(self, batch: ScheduleBatch):
if self.is_generation:
if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
model_worker_batch = batch.get_model_worker_batch()
logits_output, next_token_ids = self.forward_batch_generation(
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
model_worker_batch
)
else:
109 changes: 0 additions & 109 deletions python/sglang/srt/managers/tp_worker.py
@@ -17,13 +17,8 @@

import json
import logging
import threading
import time
from queue import Queue
from typing import Optional

import torch

from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.io_struct import UpdateWeightReqInput
@@ -108,9 +103,6 @@ def __init__(
)[0]
set_random_seed(self.random_seed)

if server_args.enable_overlap_schedule:
self.init_overlap_status()

def get_worker_info(self):
return (
self.max_total_num_tokens,
@@ -137,81 +129,6 @@ def get_memory_pool(self):
self.model_runner.token_to_kv_pool,
)

def init_overlap_status(self):
self.future_logits_output_dict = dict()
self.future_logits_output_ct = 0
self.future_token_ids_ct = 0
self.future_token_ids_map = torch.empty(
(self.max_running_requests * 5,), dtype=torch.int32, device=self.device
)
self.future_token_ids_limit = self.max_running_requests * 3
self.future_token_ids_output = dict()

self.future_event_map = dict()
self.forward_queue = Queue()
self.forward_stream = torch.cuda.Stream()
self.forward_thread = threading.Thread(
target=self.forward_thread_func,
)
self.forward_thread.start()

def forward_thread_func(self):
with torch.cuda.stream(self.forward_stream):
self.forward_thread_func_()

@torch.inference_mode()
def forward_thread_func_(self):
while True:
tic1 = time.time()
model_worker_batch, future_logits_output, future_next_token_ids = (
self.forward_queue.get()
)

# Resolve future tokens in the input
tic2 = time.time()
resolved_input_ids = model_worker_batch.input_ids
future_mask = resolved_input_ids < 0
resolved_input_ids[future_mask] = self.future_token_ids_map[
-resolved_input_ids[future_mask]
]

# Run forward
logits_output, next_token_ids = self.forward_batch_generation(
model_worker_batch
)

# Set future values
if model_worker_batch.return_logprob:
self.future_logits_output_dict[future_logits_output] = logits_output

# logger.info(f"set output {future_next_token_ids=}, {next_token_ids=}")
self.future_token_ids_map[-future_next_token_ids] = next_token_ids.to(
torch.int32
)
# logger.info("Set event")
self.future_token_ids_output[model_worker_batch.bid] = (
next_token_ids.tolist()
)
self.future_event_map[model_worker_batch.bid].set()

if False:
tic3 = time.time()
self.acc_time_with_waiting += tic3 - tic1
self.acc_time_without_waiting += tic3 - tic2
if self.forward_queue.qsize() == 0:
logger.info(
f"{self.acc_time_with_waiting=:.3f}, {self.acc_time_without_waiting=:.3f}, {self.forward_queue.qsize()=}"
)

def resolve_future_token_ids(self, bid: int):
self.future_event_map[bid].wait()
ret = self.future_token_ids_output[bid]
del self.future_event_map[bid]
return ret

def resolve_future_logits_output(self, future_obj):
return self.future_logits_output_dict.pop(future_obj)

def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
logits_output = self.model_runner.forward(forward_batch)
@@ -224,32 +141,6 @@ def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
embeddings = logits_output.embeddings
return embeddings

def forward_batch_generation_non_blocking(
self, model_worker_batch: ModelWorkerBatch
):
# Allocate output future objects
future_logits_output = self.future_logits_output_ct
self.future_logits_output_ct += 1

bs = len(model_worker_batch.seq_lens)
with torch.cuda.stream(self.forward_stream):
future_next_token_ids = -torch.arange(
self.future_token_ids_ct + 1,
self.future_token_ids_ct + 1 + bs,
dtype=torch.int32,
device=self.device,
)
self.future_token_ids_ct = (
self.future_token_ids_ct + bs
) % self.future_token_ids_limit
ret = future_logits_output, future_next_token_ids

self.future_event_map[model_worker_batch.bid] = threading.Event()
self.forward_queue.put(
(model_worker_batch.copy(), future_logits_output, future_next_token_ids)
)
return ret

def update_weights(self, recv_req: UpdateWeightReqInput):
success, message = self.model_runner.update_weights(
recv_req.model_path, recv_req.load_format
174 changes: 174 additions & 0 deletions python/sglang/srt/managers/tp_worker_overlap_thread.py
@@ -0,0 +1,174 @@
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""A tensor parallel worker."""

import logging
import threading
import time
from queue import Queue
from typing import Optional

import torch

from sglang.srt.managers.io_struct import UpdateWeightReqInput
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import ServerArgs

logger = logging.getLogger(__name__)


class TpModelWorkerClient:
"""A tensor parallel model worker."""

def __init__(
self,
server_args: ServerArgs,
gpu_id: int,
tp_rank: int,
dp_rank: Optional[int],
nccl_port: int,
):
# Load the model
self.worker = TpModelWorker(server_args, gpu_id, tp_rank, dp_rank, nccl_port)
self.max_running_requests = self.worker.max_running_requests
self.device = self.worker.device

# Create future mappings
self.future_logits_output_dict = dict()
self.future_logits_output_ct = 0
self.future_token_ids_ct = 0
self.future_token_ids_map = torch.empty(
(self.max_running_requests * 5,), dtype=torch.int32, device=self.device
)
self.future_token_ids_limit = self.max_running_requests * 3
self.future_token_ids_output = dict()

# Launch a thread
self.future_event_map = dict()
self.forward_queue = Queue()
self.forward_stream = torch.cuda.Stream()
self.forward_thread = threading.Thread(
target=self.forward_thread_func,
)
self.forward_thread.start()

def get_worker_info(self):
return self.worker.get_worker_info()

def get_pad_input_ids_func(self):
return self.worker.get_pad_input_ids_func()

def get_tp_cpu_group(self):
return self.worker.get_tp_cpu_group()

def get_memory_pool(self):
return (
self.worker.model_runner.req_to_token_pool,
self.worker.model_runner.token_to_kv_pool,
)

def forward_thread_func(self):
with torch.cuda.stream(self.forward_stream):
self.forward_thread_func_()

@torch.inference_mode()
def forward_thread_func_(self):
while True:
tic1 = time.time()
model_worker_batch, future_logits_output, future_next_token_ids = (
self.forward_queue.get()
)

# Resolve future tokens in the input
tic2 = time.time()
resolved_input_ids = model_worker_batch.input_ids
future_mask = resolved_input_ids < 0
resolved_input_ids[future_mask] = self.future_token_ids_map[
-resolved_input_ids[future_mask]
]

# Run forward
logits_output, next_token_ids = self.worker.forward_batch_generation(
model_worker_batch
)

# Set future values
if model_worker_batch.return_logprob:
self.future_logits_output_dict[future_logits_output] = logits_output

self.future_token_ids_map[-future_next_token_ids] = next_token_ids.to(
torch.int32
)
self.future_token_ids_output[model_worker_batch.bid] = (
next_token_ids.tolist()
)
self.future_event_map[model_worker_batch.bid].set()

if False:
tic3 = time.time()
self.acc_time_with_waiting += tic3 - tic1
self.acc_time_without_waiting += tic3 - tic2
if self.forward_queue.qsize() == 0:
logger.info(
f"{self.acc_time_with_waiting=:.3f}, {self.acc_time_without_waiting=:.3f}, {self.forward_queue.qsize()=}"
)

def resolve_future_token_ids(self, bid: int):
self.future_event_map[bid].wait()
ret = self.future_token_ids_output[bid]
del self.future_event_map[bid]
return ret

def resolve_future_logits_output(self, future_obj):
return self.future_logits_output_dict.pop(future_obj)

def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
# Allocate output future objects
future_logits_output = self.future_logits_output_ct
self.future_logits_output_ct += 1

bs = len(model_worker_batch.seq_lens)
with torch.cuda.stream(self.forward_stream):
future_next_token_ids = -torch.arange(
self.future_token_ids_ct + 1,
self.future_token_ids_ct + 1 + bs,
dtype=torch.int32,
device=self.device,
)
self.future_token_ids_ct = (
self.future_token_ids_ct + bs
) % self.future_token_ids_limit
ret = future_logits_output, future_next_token_ids

self.future_event_map[model_worker_batch.bid] = threading.Event()
self.forward_queue.put(
(model_worker_batch.copy(), future_logits_output, future_next_token_ids)
)
return ret

def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
logits_output = self.model_runner.forward(forward_batch)
embeddings = logits_output.embeddings
return embeddings

def update_weights(self, recv_req: UpdateWeightReqInput):
success, message = self.model_runner.update_weights(
recv_req.model_path, recv_req.load_format
)
return success, message
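
As context for the new file above, here is a minimal, self-contained sketch of the negative-index "future token" scheme that TpModelWorkerClient relies on. This demo is not part of the commit: the names allocate_placeholders, the map size of 16, and the sample token values are illustrative, and the real code keeps the map in int32 on the GPU and drives it from a background thread, a queue, and a dedicated CUDA stream.

import torch

# Placeholder token IDs are negative 1-based offsets into a preallocated map.
# The scheduler hands them out immediately; the forward thread later writes
# the real next-token IDs into the map, and the following batch resolves any
# negative input IDs before running.

future_token_ids_map = torch.zeros(16, dtype=torch.int64)  # demo-sized, index 0 unused
future_token_ids_ct = 0


def allocate_placeholders(batch_size: int) -> torch.Tensor:
    """Hand out placeholder IDs -(ct + 1) .. -(ct + batch_size)."""
    global future_token_ids_ct
    ids = -torch.arange(
        future_token_ids_ct + 1,
        future_token_ids_ct + 1 + batch_size,
        dtype=torch.int64,
    )
    future_token_ids_ct += batch_size
    return ids


# Scheduler side: schedule the next decode step before its tokens exist.
placeholders = allocate_placeholders(batch_size=3)  # tensor([-1, -2, -3])

# Forward-thread side: once the forward pass finishes, publish the real tokens.
real_next_tokens = torch.tensor([101, 202, 303], dtype=torch.int64)
future_token_ids_map[-placeholders] = real_next_tokens

# Next batch: negative input IDs are looked up in the map before the forward pass.
input_ids = placeholders.clone()
mask = input_ids < 0
input_ids[mask] = future_token_ids_map[-input_ids[mask]]
print(input_ids.tolist())  # [101, 202, 303]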