Clean up batch data structures: Introducing ModelWorkerBatch #1544

Merged (17 commits) on Sep 30, 2024
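In short, pieced together from the hunks below: ScheduleBatch no longer produces a ForwardBatch directly. It now emits a slimmer ModelWorkerBatch, and the model runner builds the ForwardBatch from that. A minimal sketch of the new hand-off, using only calls that appear in this diff (how batch and model_runner are created is assumed rather than shown; the sequence mirrors the updated extend() in bench_latency.py):

# Sketch of the new batch hand-off; illustrative, not an excerpt of the diff.
# Scheduler side: ScheduleBatch holds high-level, mostly CPU-side state.
batch.prepare_for_extend(model_runner.model_config.vocab_size)

# Worker side: reduce the batch to the plain fields one model step needs.
model_worker_batch = batch.get_model_worker_batch()

# Model-runner side: materialize GPU tensors, then forward and sample.
forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
logits_output = model_runner.forward(forward_batch)
next_token_ids = model_runner.sample(logits_output, forward_batch).tolist()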
23 changes: 15 additions & 8 deletions python/sglang/bench_latency.py
@@ -62,11 +62,13 @@
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server import _set_envs_and_config
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
allocate_init_ports,
configure_logger,
kill_child_process,
suppress_other_loggers,
@@ -125,6 +127,11 @@ def load_model(server_args, tp_rank):
suppress_other_loggers()
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

server_args.port, server_args.additional_ports = allocate_init_ports(
server_args.port,
server_args.additional_ports,
server_args.dp_size,
)
model_config = ModelConfig(
server_args.model_path,
server_args.trust_remote_code,
@@ -136,7 +143,7 @@ def load_model(server_args, tp_rank):
gpu_id=tp_rank,
tp_rank=tp_rank,
tp_size=server_args.tp_size,
nccl_port=28888,
nccl_port=server_args.additional_ports[-1],
server_args=server_args,
)
rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
@@ -225,17 +232,19 @@ def extend(reqs, model_runner):
tree_cache=None,
)
batch.prepare_for_extend(model_runner.model_config.vocab_size)
forward_batch = batch.get_forward_batch()
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
logits_output = model_runner.forward(forward_batch)
next_token_ids = model_runner.sample(logits_output, batch).tolist()
next_token_ids = model_runner.sample(logits_output, forward_batch).tolist()
return next_token_ids, logits_output.next_token_logits, batch


def decode(input_token_ids, batch, model_runner):
batch.prepare_for_decode(input_token_ids)
forward_batch = batch.get_forward_batch()
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
logits_output = model_runner.forward(forward_batch)
next_token_ids = model_runner.sample(logits_output, batch).tolist()
next_token_ids = model_runner.sample(logits_output, forward_batch).tolist()
return next_token_ids, logits_output.next_token_logits


@@ -357,7 +366,6 @@ def latency_test(
tp_rank,
):
configure_logger(server_args, prefix=f" TP{tp_rank}")
_set_envs_and_config(server_args)
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

# Load the model
@@ -463,6 +471,7 @@ def plot_latency_test(


def main(server_args, bench_args):
_set_envs_and_config(server_args)

if server_args.model_path:
if bench_args.correctness_test:
@@ -513,8 +522,6 @@ def main(server_args, bench_args):
format="%(message)s",
)

multiprocessing.set_start_method("spawn", force=True)

try:
main(server_args, bench_args)
except Exception as e:
6 changes: 5 additions & 1 deletion python/sglang/srt/layers/logits_processor.py
@@ -62,7 +62,11 @@ class LogitsMetadata:

@classmethod
def from_forward_batch(cls, forward_batch: ForwardBatch):
return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
if forward_batch.return_logprob:
return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
else:
return_top_logprob = False

if forward_batch.forward_mode.is_extend():
extend_logprob_pruned_lens_cpu = [
extend_len - start_len
170 changes: 122 additions & 48 deletions python/sglang/srt/managers/schedule_batch.py
@@ -1,5 +1,3 @@
from __future__ import annotations

"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,7 +13,19 @@
limitations under the License.
"""

"""Meta data for requests and batches"""
"""
Store information about requests and batches.

The following is the flow of data structures for a batch:

ScheduleBatch -> ModelWorkerBatch -> ForwardBatch

- ScheduleBatch is managed by `scheduler.py::Scheduler`.
It contains high-level scheduling data. Most of the data is on the CPU.
- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
It contains low-level tensor data. Most of the data consists of GPU tensors.
"""

import logging
from dataclasses import dataclass
@@ -29,7 +39,7 @@
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
@@ -105,6 +115,8 @@ def to_json(self):

@dataclass
class ImageInputs:
"""The image related inputs."""

pixel_values: torch.Tensor
image_hash: int
image_sizes: Optional[list] = None
@@ -137,7 +149,7 @@ def from_dict(obj, vocab_size):


class Req:
"""Store all inforamtion of a request."""
"""The input and output status of a request."""

def __init__(
self,
@@ -393,20 +405,20 @@ class ScheduleBatch:
sampling_info: SamplingBatchInfo = None

# Batched arguments to model runner
input_ids: torch.Tensor = None
req_pool_indices: torch.Tensor = None
seq_lens: torch.Tensor = None
position_ids_offsets: torch.Tensor = None
input_ids: List[int] = None
req_pool_indices: List[int] = None
seq_lens: List[int] = None
out_cache_loc: torch.Tensor = None
extend_num_tokens: int = None

# For mixed chunekd prefill
prefix_lens_cpu: List[int] = None
running_bs: int = None

# For processing logprobs
return_logprob: bool = False
top_logprobs_nums: List[int] = None
top_logprobs_nums: Optional[List[int]] = None

# For extend and mixed chunekd prefill
prefix_lens: List[int] = None
extend_lens: List[int] = None
extend_num_tokens: int = None
running_bs: int = None

# Stream
has_stream: bool = False
@@ -466,12 +478,12 @@ def prepare_for_extend(self, vocab_size: int):
seq_lens = []

# Allocate memory
req_pool_indices_cpu = self.alloc_req_slots(bs)
req_pool_indices = self.alloc_req_slots(bs)
out_cache_loc = self.alloc_token_slots(extend_num_tokens)

pt = 0
for i, req in enumerate(reqs):
req.req_pool_idx = req_pool_indices_cpu[i]
req.req_pool_idx = req_pool_indices[i]
pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
seq_lens.append(seq_len)
assert seq_len - pre_len == req.extend_input_len
@@ -497,22 +509,19 @@
pt += req.extend_input_len

# Set fields
with torch.device("cuda"):
self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32)
self.req_pool_indices = torch.tensor(req_pool_indices_cpu)
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32)
self.position_ids_offsets = torch.zeros((bs,), dtype=torch.int64)
self.input_ids = sum(input_ids, [])
self.req_pool_indices = torch.tensor(req_pool_indices, device="cuda")
self.seq_lens = torch.tensor(seq_lens, device="cuda")

self.extend_num_tokens = extend_num_tokens
self.out_cache_loc = out_cache_loc
self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
self.prefix_lens_cpu = [len(r.prefix_indices) for r in reqs]
self.extend_lens_cpu = [r.extend_input_len for r in reqs]
self.extend_logprob_start_lens_cpu = [r.extend_logprob_start_len for r in reqs]
self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)
if self.return_logprob:
self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
self.prefix_lens = [len(r.prefix_indices) for r in reqs]
self.extend_lens = [r.extend_input_len for r in reqs]
self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]

def get_forward_batch(self):
return ForwardBatch.from_schedule_batch(self)
self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)

def mix_with_running(self, running_batch: "ScheduleBatch"):
self.forward_mode = ForwardMode.MIXED
@@ -522,24 +531,24 @@ def mix_with_running(self, running_batch: "ScheduleBatch"):
req.fill_ids = req.origin_input_ids + req.output_ids
req.extend_input_len = 1

input_ids = torch.cat([self.input_ids, running_batch.input_ids])
input_ids = self.input_ids + running_batch.input_ids
out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])
extend_num_tokens = self.extend_num_tokens + running_bs

self.merge(running_batch)
self.merge_batch(running_batch)
self.input_ids = input_ids
self.out_cache_loc = out_cache_loc
self.extend_num_tokens = extend_num_tokens

# NOTE: prefix_indices is what has been cached, but we don't cache each decode step
self.prefix_lens_cpu.extend(
self.prefix_lens.extend(
[
len(r.origin_input_ids) + len(r.output_ids) - 1
for r in running_batch.reqs
]
)
self.extend_lens_cpu.extend([1] * running_bs)
self.extend_logprob_start_lens_cpu.extend([0] * running_bs)
self.extend_lens.extend([1] * running_bs)
self.extend_logprob_start_lens.extend([0] * running_bs)

def check_decode_mem(self):
bs = len(self.reqs)
@@ -631,7 +640,7 @@ def retract_decode(self):

return retracted_reqs, new_estimate_ratio

def check_for_jump_forward(self, model_runner):
def check_for_jump_forward(self, pad_input_ids_func):
jump_forward_reqs = []
filter_indices = [i for i in range(len(self.reqs))]

@@ -688,7 +697,7 @@ def check_for_jump_forward(self, model_runner):

# re-applying image padding
if req.image_inputs is not None:
req.origin_input_ids = model_runner.model.pad_input_ids(
req.origin_input_ids = pad_input_ids_func(
req.origin_input_ids_unpadded, req.image_inputs
)

@@ -708,7 +717,7 @@ def prepare_for_decode(self, input_ids=None):
for r in self.reqs
]

self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
self.input_ids = input_ids
self.seq_lens.add_(1)

# Alloc mem
@@ -731,32 +740,97 @@ def filter_batch(self, unfinished_indices: List[int]):

self.reqs = [self.reqs[i] for i in unfinished_indices]
new_indices = torch.tensor(unfinished_indices, dtype=torch.int32, device="cuda")
self.seq_lens = self.seq_lens[new_indices]
self.input_ids = None
self.req_pool_indices = self.req_pool_indices[new_indices]
self.position_ids_offsets = self.position_ids_offsets[new_indices]
self.seq_lens = self.seq_lens[new_indices]
self.out_cache_loc = None
self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
self.return_logprob = any(req.return_logprob for req in self.reqs)
if self.return_logprob:
self.top_logprobs_nums = [
self.top_logprobs_nums[i] for i in unfinished_indices
]
self.has_stream = any(req.stream for req in self.reqs)

self.sampling_info.filter(unfinished_indices, new_indices)
self.sampling_info.filter_batch(unfinished_indices, new_indices)

def merge(self, other: "ScheduleBatch"):
def merge_batch(self, other: "ScheduleBatch"):
# Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
# orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it
# needs to be called with pre-merged Batch.reqs.
self.sampling_info.merge(other.sampling_info)
self.sampling_info.merge_batch(other.sampling_info)

self.reqs.extend(other.reqs)
self.req_pool_indices = torch.concat(
[self.req_pool_indices, other.req_pool_indices]
)
self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
self.position_ids_offsets = torch.concat(
[self.position_ids_offsets, other.position_ids_offsets]
)
self.out_cache_loc = None
self.top_logprobs_nums.extend(other.top_logprobs_nums)
self.return_logprob = any(req.return_logprob for req in self.reqs)
if self.return_logprob and other.return_logprob:
self.top_logprobs_nums.extend(other.top_logprobs_nums)
elif self.return_logprob:
self.top_logprobs_nums.extend([0] * len(other.reqs))
elif other.return_logprob:
self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
self.has_stream = any(req.stream for req in self.reqs)

def get_model_worker_batch(self):
if self.forward_mode.is_decode():
extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = (
image_inputs
) = None
else:
extend_seq_lens = self.extend_lens
extend_prefix_lens = self.prefix_lens
extend_logprob_start_lens = self.extend_logprob_start_lens
image_inputs = [r.image_inputs for r in self.reqs]

lora_paths = [req.lora_path for req in self.reqs]
self.sampling_info.regex_fsm_states = [req.regex_fsm_state for req in self.reqs]

Inline review comment from the PR author: possible overhead

return ModelWorkerBatch(
forward_mode=self.forward_mode,
input_ids=self.input_ids,
req_pool_indices=self.req_pool_indices,
seq_lens=self.seq_lens,
out_cache_loc=self.out_cache_loc,
return_logprob=self.return_logprob,
top_logprobs_nums=self.top_logprobs_nums,
extend_seq_lens=extend_seq_lens,
extend_prefix_lens=extend_prefix_lens,
extend_logprob_start_lens=extend_logprob_start_lens,
image_inputs=image_inputs,
lora_paths=lora_paths,
sampling_info=self.sampling_info,
)


@dataclass
class ModelWorkerBatch:
# The forward mode
forward_mode: ForwardMode
# The input ids
input_ids: List[int]
# The indices of requests in the req_to_token_pool
req_pool_indices: torch.Tensor
# The sequence length
seq_lens: torch.Tensor
# The indices of output tokens in the token_to_kv_pool
out_cache_loc: torch.Tensor

# For logprob
return_logprob: bool
top_logprobs_nums: Optional[List[int]]

# For extend
extend_seq_lens: Optional[List[int]]
extend_prefix_lens: Optional[List[int]]
extend_logprob_start_lens: Optional[List[int]]

# For multimodal
image_inputs: Optional[List[ImageInputs]]

# For LoRA
lora_paths: Optional[List[str]]

# Sampling info
sampling_info: SamplingBatchInfo
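Note the decode fast path in get_model_worker_batch above: for decode batches the extend-related fields and image_inputs are left as None, so consumers need to branch on forward_mode before reading them. An illustrative guard (the example values are invented):

# Illustration only: which optional fields are populated per forward mode.
if model_worker_batch.forward_mode.is_decode():
    assert model_worker_batch.extend_seq_lens is None
    assert model_worker_batch.image_inputs is None
else:
    # Extend / mixed chunked prefill: per-request extend metadata is set.
    prefix_lens = model_worker_batch.extend_prefix_lens   # e.g. [0, 12]
    extend_lens = model_worker_batch.extend_seq_lens      # e.g. [32, 20]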