Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split the overlapped version of TpModelWorkerClient into a separate file #1726

Merged
merged 5 commits into from
Oct 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,17 +639,17 @@ def retract_decode(self):

if isinstance(self.tree_cache, ChunkCache):
# ChunkCache does not have eviction
token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx][
: seq_lens_cpu[idx]
token_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : seq_lens_cpu[idx]
]
self.token_to_kv_pool.free(token_indices)
self.req_to_token_pool.free(req.req_pool_idx)
del self.tree_cache.entries[req.rid]
else:
# TODO: apply more fine-grained retraction
last_uncached_pos = len(req.prefix_indices)
token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx][
last_uncached_pos : seq_lens_cpu[idx]
token_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
]
self.token_to_kv_pool.free(token_indices)
self.req_to_token_pool.free(req.req_pool_idx)
Expand Down
20 changes: 8 additions & 12 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
SchedulePolicy,
)
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.server_args import PortArgs, ServerArgs
Expand Down Expand Up @@ -146,26 +147,21 @@ def __init__(

# Launch a tensor parallel worker
if self.server_args.enable_overlap_schedule:
TpWorkerClass = TpModelWorker
TpWorkerClass = TpModelWorkerClient
self.resolve_next_token_ids = (
lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
)
else:
TpWorkerClass = TpModelWorker
self.resolve_next_token_ids = lambda bid, x: x.tolist()

self.tp_worker = TpWorkerClass(
server_args=server_args,
gpu_id=gpu_id,
tp_rank=tp_rank,
dp_rank=dp_rank,
nccl_port=port_args.nccl_port,
)
if self.server_args.enable_overlap_schedule:
self.resolve_next_token_ids = (
lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
)
self.forward_batch_generation = (
self.tp_worker.forward_batch_generation_non_blocking
)
else:
self.resolve_next_token_ids = lambda bid, x: x.tolist()
self.forward_batch_generation = self.tp_worker.forward_batch_generation

# Get token and memory info from the model worker
(
Expand Down Expand Up @@ -728,7 +724,7 @@ def run_batch(self, batch: ScheduleBatch):
if self.is_generation:
if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
model_worker_batch = batch.get_model_worker_batch()
logits_output, next_token_ids = self.forward_batch_generation(
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
model_worker_batch
)
else:
Expand Down
109 changes: 0 additions & 109 deletions python/sglang/srt/managers/tp_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,8 @@

import json
import logging
import threading
import time
from queue import Queue
from typing import Optional

import torch

from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.io_struct import UpdateWeightReqInput
Expand Down Expand Up @@ -108,9 +103,6 @@ def __init__(
)[0]
set_random_seed(self.random_seed)

if server_args.enable_overlap_schedule:
self.init_overlap_status()

def get_worker_info(self):
return (
self.max_total_num_tokens,
Expand All @@ -137,81 +129,6 @@ def get_memory_pool(self):
self.model_runner.token_to_kv_pool,
)

def init_overlap_status(self):
self.future_logits_output_dict = dict()
self.future_logits_output_ct = 0
self.future_token_ids_ct = 0
self.future_token_ids_map = torch.empty(
(self.max_running_requests * 5,), dtype=torch.int32, device=self.device
)
self.future_token_ids_limit = self.max_running_requests * 3
self.future_token_ids_output = dict()

self.future_event_map = dict()
self.forward_queue = Queue()
self.forward_stream = torch.cuda.Stream()
self.forward_thread = threading.Thread(
target=self.forward_thread_func,
)
self.forward_thread.start()

def forward_thread_func(self):
with torch.cuda.stream(self.forward_stream):
self.forward_thread_func_()

@torch.inference_mode()
def forward_thread_func_(self):
while True:
tic1 = time.time()
model_worker_batch, future_logits_output, future_next_token_ids = (
self.forward_queue.get()
)

# Resolve future tokens in the input
tic2 = time.time()
resolved_input_ids = model_worker_batch.input_ids
future_mask = resolved_input_ids < 0
resolved_input_ids[future_mask] = self.future_token_ids_map[
-resolved_input_ids[future_mask]
]

# Run forward
logits_output, next_token_ids = self.forward_batch_generation(
model_worker_batch
)

# Set future values
if model_worker_batch.return_logprob:
self.future_logits_output_dict[future_logits_output] = logits_output

# logger.info(f"set output {future_next_token_ids=}, {next_token_ids=}")
self.future_token_ids_map[-future_next_token_ids] = next_token_ids.to(
torch.int32
)
# logger.info("Set event")
self.future_token_ids_output[model_worker_batch.bid] = (
next_token_ids.tolist()
)
self.future_event_map[model_worker_batch.bid].set()

if False:
tic3 = time.time()
self.acc_time_with_waiting += tic3 - tic1
self.acc_time_without_waiting += tic3 - tic2
if self.forward_queue.qsize() == 0:
logger.info(
f"{self.acc_time_with_waiting=:.3f}, {self.acc_time_without_waiting=:.3f}, {self.forward_queue.qsize()=}"
)

def resolve_future_token_ids(self, bid: int):
self.future_event_map[bid].wait()
ret = self.future_token_ids_output[bid]
del self.future_event_map[bid]
return ret

def resolve_future_logits_output(self, future_obj):
return self.future_logits_output_dict.pop(future_obj)

def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
logits_output = self.model_runner.forward(forward_batch)
Expand All @@ -224,32 +141,6 @@ def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
embeddings = logits_output.embeddings
return embeddings

def forward_batch_generation_non_blocking(
self, model_worker_batch: ModelWorkerBatch
):
# Allocate output future objects
future_logits_output = self.future_logits_output_ct
self.future_logits_output_ct += 1

bs = len(model_worker_batch.seq_lens)
with torch.cuda.stream(self.forward_stream):
future_next_token_ids = -torch.arange(
self.future_token_ids_ct + 1,
self.future_token_ids_ct + 1 + bs,
dtype=torch.int32,
device=self.device,
)
self.future_token_ids_ct = (
self.future_token_ids_ct + bs
) % self.future_token_ids_limit
ret = future_logits_output, future_next_token_ids

self.future_event_map[model_worker_batch.bid] = threading.Event()
self.forward_queue.put(
(model_worker_batch.copy(), future_logits_output, future_next_token_ids)
)
return ret

def update_weights(self, recv_req: UpdateWeightReqInput):
success, message = self.model_runner.update_weights(
recv_req.model_path, recv_req.load_format
Expand Down
174 changes: 174 additions & 0 deletions python/sglang/srt/managers/tp_worker_overlap_thread.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""A tensor parallel worker."""

import logging
import threading
import time
from queue import Queue
from typing import Optional

import torch

from sglang.srt.managers.io_struct import UpdateWeightReqInput
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import ServerArgs

logger = logging.getLogger(__name__)


class TpModelWorkerClient:
"""A tensor parallel model worker."""

def __init__(
self,
server_args: ServerArgs,
gpu_id: int,
tp_rank: int,
dp_rank: Optional[int],
nccl_port: int,
):
# Load the model
self.worker = TpModelWorker(server_args, gpu_id, tp_rank, dp_rank, nccl_port)
self.max_running_requests = self.worker.max_running_requests
self.device = self.worker.device

# Create future mappings
self.future_logits_output_dict = dict()
self.future_logits_output_ct = 0
self.future_token_ids_ct = 0
self.future_token_ids_map = torch.empty(
(self.max_running_requests * 5,), dtype=torch.int32, device=self.device
)
self.future_token_ids_limit = self.max_running_requests * 3
self.future_token_ids_output = dict()

# Launch a thread
self.future_event_map = dict()
self.forward_queue = Queue()
self.forward_stream = torch.cuda.Stream()
self.forward_thread = threading.Thread(
target=self.forward_thread_func,
)
self.forward_thread.start()

def get_worker_info(self):
return self.worker.get_worker_info()

def get_pad_input_ids_func(self):
return self.worker.get_pad_input_ids_func()

def get_tp_cpu_group(self):
return self.worker.get_tp_cpu_group()

def get_memory_pool(self):
return (
self.worker.model_runner.req_to_token_pool,
self.worker.model_runner.token_to_kv_pool,
)

def forward_thread_func(self):
with torch.cuda.stream(self.forward_stream):
self.forward_thread_func_()

@torch.inference_mode()
def forward_thread_func_(self):
while True:
tic1 = time.time()
model_worker_batch, future_logits_output, future_next_token_ids = (
self.forward_queue.get()
)

# Resolve future tokens in the input
tic2 = time.time()
resolved_input_ids = model_worker_batch.input_ids
future_mask = resolved_input_ids < 0
resolved_input_ids[future_mask] = self.future_token_ids_map[
-resolved_input_ids[future_mask]
]

# Run forward
logits_output, next_token_ids = self.worker.forward_batch_generation(
model_worker_batch
)

# Set future values
if model_worker_batch.return_logprob:
self.future_logits_output_dict[future_logits_output] = logits_output

self.future_token_ids_map[-future_next_token_ids] = next_token_ids.to(
torch.int32
)
self.future_token_ids_output[model_worker_batch.bid] = (
next_token_ids.tolist()
)
self.future_event_map[model_worker_batch.bid].set()

if False:
tic3 = time.time()
self.acc_time_with_waiting += tic3 - tic1
self.acc_time_without_waiting += tic3 - tic2
if self.forward_queue.qsize() == 0:
logger.info(
f"{self.acc_time_with_waiting=:.3f}, {self.acc_time_without_waiting=:.3f}, {self.forward_queue.qsize()=}"
)

def resolve_future_token_ids(self, bid: int):
self.future_event_map[bid].wait()
ret = self.future_token_ids_output[bid]
del self.future_event_map[bid]
return ret

def resolve_future_logits_output(self, future_obj):
return self.future_logits_output_dict.pop(future_obj)

def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
# Allocate output future objects
future_logits_output = self.future_logits_output_ct
self.future_logits_output_ct += 1

bs = len(model_worker_batch.seq_lens)
with torch.cuda.stream(self.forward_stream):
future_next_token_ids = -torch.arange(
self.future_token_ids_ct + 1,
self.future_token_ids_ct + 1 + bs,
dtype=torch.int32,
device=self.device,
)
self.future_token_ids_ct = (
self.future_token_ids_ct + bs
) % self.future_token_ids_limit
ret = future_logits_output, future_next_token_ids

self.future_event_map[model_worker_batch.bid] = threading.Event()
self.forward_queue.put(
(model_worker_batch.copy(), future_logits_output, future_next_token_ids)
)
return ret

def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
logits_output = self.model_runner.forward(forward_batch)
embeddings = logits_output.embeddings
return embeddings

def update_weights(self, recv_req: UpdateWeightReqInput):
success, message = self.model_runner.update_weights(
recv_req.model_path, recv_req.load_format
)
return success, message
Loading
Loading