Let ModelRunner take InputMetadata as input, instead of ScheduleBatch #1541

Merged (9 commits, Sep 30, 2024)
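This PR moves the boundary between scheduling and model execution: ModelRunner.forward, the attention backends, the LoRA manager, and CudaGraphRunner.replay now receive an InputMetadata instead of the full ScheduleBatch, and callers build that metadata once via the new ScheduleBatch.get_input_metadata() (a thin wrapper around InputMetadata.from_schedule_batch). The caller-side pattern, as it appears in bench_latency.py and scheduler.py below; a minimal sketch that assumes an already-prepared ScheduleBatch and ModelRunner, so it is not runnable on its own:

```python
# Old: the model runner reached into the whole scheduling batch.
# logits_output = model_runner.forward(batch)

# New: the batch is condensed into an InputMetadata snapshot, and only that
# snapshot crosses into the model-execution layer.
input_metadata = batch.get_input_metadata()            # InputMetadata.from_schedule_batch(batch)
logits_output = model_runner.forward(input_metadata)
next_token_ids = model_runner.sample(logits_output, batch)  # sampling still reads the batch
```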
6 changes: 4 additions & 2 deletions python/sglang/bench_latency.py
@@ -225,14 +225,16 @@ def extend(reqs, model_runner):
tree_cache=None,
)
batch.prepare_for_extend(model_runner.model_config.vocab_size)
logits_output = model_runner.forward(batch)
input_metadata = batch.get_input_metadata()
logits_output = model_runner.forward(input_metadata)
next_token_ids = model_runner.sample(logits_output, batch).tolist()
return next_token_ids, logits_output.next_token_logits, batch


def decode(input_token_ids, batch, model_runner):
batch.prepare_for_decode(input_token_ids)
logits_output = model_runner.forward(batch)
input_metadata = batch.get_input_metadata()
logits_output = model_runner.forward(input_metadata)
next_token_ids = model_runner.sample(logits_output, batch).tolist()
return next_token_ids, logits_output.next_token_logits

37 changes: 21 additions & 16 deletions python/sglang/srt/layers/attention_backend.py
@@ -15,7 +15,7 @@

from sglang.global_config import global_config
from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
from sglang.srt.utils import is_hip

@@ -37,9 +37,7 @@ class AttentionBackend(ABC):
"""The base class of attention backends"""

@abstractmethod
def init_forward_metadata(
self, batch: ScheduleBatch, input_metadata: InputMetadata
):
def init_forward_metadata(self, input_metadata: InputMetadata):
"""Init the metadata for a forward pass."""
raise NotImplementedError()

@@ -133,12 +131,11 @@ def __init__(self, model_runner: ModelRunner):
self.forward_metadata = None
self.cuda_graph_metadata = {}

def init_forward_metadata(
self, batch: ScheduleBatch, input_metadata: InputMetadata
):
def init_forward_metadata(self, input_metadata: InputMetadata):
if input_metadata.forward_mode.is_decode():
prefix_lens = None
use_ragged = False
extend_no_prefix = False
total_num_tokens = None
else:
prefix_lens = input_metadata.extend_prefix_lens
@@ -152,6 +149,7 @@ def init_forward_metadata(
use_ragged = True

total_num_tokens = torch.sum(input_metadata.seq_lens).item()
extend_no_prefix = not torch.any(input_metadata.extend_prefix_lens).item()

update_flashinfer_indices(
input_metadata.forward_mode,
@@ -162,7 +160,12 @@
use_ragged=use_ragged,
)

self.forward_metadata = (use_ragged, total_num_tokens, self.decode_wrapper)
self.forward_metadata = (
use_ragged,
extend_no_prefix,
total_num_tokens,
self.decode_wrapper,
)

def init_cuda_graph_state(self, max_bs: int):
self.cuda_graph_kv_indptr = torch.zeros(
@@ -228,7 +231,7 @@ def init_forward_metadata_capture_cuda_graph(

self.cuda_graph_metadata[bs] = decode_wrapper

self.forward_metadata = (False, None, decode_wrapper)
self.forward_metadata = (False, False, None, decode_wrapper)

def init_forward_metadata_replay_cuda_graph(
self, bs: int, req_pool_indices, seq_lens
@@ -254,7 +257,9 @@ def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadat
else:
prefill_wrapper_paged = self.prefill_wrapper_paged[1]

use_ragged, total_num_tokens, decode_wrapper = self.forward_metadata
use_ragged, extend_no_prefix, total_num_tokens, decode_wrapper = (
self.forward_metadata
)

if not use_ragged:
if k is not None:
@@ -280,7 +285,7 @@ def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadat
logits_soft_cap=layer.logit_cap,
)

if input_metadata.extend_no_prefix:
if extend_no_prefix:
o = o1
else:
o2, s2 = prefill_wrapper_paged.forward_return_lse(
@@ -300,7 +305,9 @@ def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadat
return o.view(-1, layer.tp_q_head_num * layer.head_dim)

def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
use_ragged, total_num_tokens, decode_wrapper = self.forward_metadata
use_ragged, extend_no_prefix, total_num_tokens, decode_wrapper = (
self.forward_metadata
)

if isinstance(decode_wrapper, list):
if layer.sliding_window_size != -1:
@@ -351,9 +358,7 @@ def __init__(self, model_runner: ModelRunner):

self.cuda_graph_max_seq_len = model_runner.model_config.context_len

def init_forward_metadata(
self, batch: ScheduleBatch, input_metadata: InputMetadata
):
def init_forward_metadata(self, input_metadata: InputMetadata):
"""Init auxiliary variables for triton attention backend."""

if input_metadata.forward_mode.is_decode():
@@ -371,7 +376,7 @@ def init_forward_metadata(
max_extend_len = None
else:
start_loc = attn_logits = max_seq_len = None
prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
prefix_lens = input_metadata.extend_prefix_lens
max_extend_len = torch.max(input_metadata.seq_lens - prefix_lens).item()

self.forward_metadata = start_loc, attn_logits, max_seq_len, max_extend_len
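One detail worth noting in the FlashInfer backend above: forward_extend previously read extend_no_prefix from the batch, so the flag is now computed in init_forward_metadata from input_metadata.extend_prefix_lens and carried in the forward_metadata tuple, which grows from three to four elements. A small self-contained illustration of that computation and packing (toy values, plain torch; the decode wrapper is just a placeholder):

```python
import torch

# Toy per-request lengths for an extend batch (made-up values).
seq_lens = torch.tensor([5, 9, 3])
extend_prefix_lens = torch.tensor([0, 0, 0])  # cached-prefix length per request

# Mirrors the new logic in init_forward_metadata.
total_num_tokens = torch.sum(seq_lens).item()
extend_no_prefix = not torch.any(extend_prefix_lens).item()  # True: no request has a cached prefix

# The tuple now carries the flag, so forward_extend/forward_decode unpack four values.
use_ragged, decode_wrapper = True, None  # placeholder for the FlashInfer decode wrapper
forward_metadata = (use_ragged, extend_no_prefix, total_num_tokens, decode_wrapper)

use_ragged, extend_no_prefix, total_num_tokens, decode_wrapper = forward_metadata
print(extend_no_prefix, total_num_tokens)  # True 17
```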
19 changes: 11 additions & 8 deletions python/sglang/srt/lora/lora_manager.py
@@ -18,13 +18,12 @@


import re
from dataclasses import dataclass

import torch

from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer
from sglang.srt.lora.lora_config import LoRAConfig
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.utils import is_hip, replace_submodule

# ROCm: flashinfer available later
@@ -208,9 +207,9 @@ def load_lora(self, uid, buffer_id):
if lora_weight_name:
self.B_buffer[lora_weight_name][i][buffer_id].copy_(weights)

def prepare_lora_batch(self, batch, extend_seq_lens=None):
def prepare_lora_batch(self, input_metadata: InputMetadata):
# load active loras into lora memory pool
cur_uids = set([req.lora_path for req in batch.reqs])
cur_uids = set(input_metadata.lora_paths)
assert len(cur_uids) <= self.max_loras_per_batch
i = 0
evictable_uids = list(self.active_uids)
@@ -230,11 +229,15 @@ def prepare_lora_batch(self, batch, extend_seq_lens=None):
return

# setup lora in forward modules
bs = len(batch.reqs)
seg_lens = extend_seq_lens if batch.forward_mode.is_extend() else torch.ones(bs)
bs = input_metadata.batch_size
seg_lens = (
input_metadata.extend_seq_lens
if input_metadata.forward_mode.is_extend()
else torch.ones(bs)
)
weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda")
for i, req in enumerate(batch.reqs):
weight_indices[i] = self.buffer_id[req.lora_path]
for i, lora_path in enumerate(input_metadata.lora_paths):
weight_indices[i] = self.buffer_id[lora_path]

for module_name, module in self.lora_modules:
layer_id = get_layer_id(module_name)
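prepare_lora_batch now derives everything it needs (active adapter uids, batch size, segment lengths, per-request weight indices) from InputMetadata fields rather than from batch.reqs. A self-contained sketch of the index construction with made-up adapter paths and a hypothetical buffer_id mapping (the real code allocates weight_indices on CUDA):

```python
import torch

# Hypothetical per-request adapter paths and pool slots (illustrative only).
lora_paths = ["adapters/a", "adapters/a", "adapters/b"]
buffer_id = {"adapters/a": 0, "adapters/b": 1}

batch_size = len(lora_paths)
is_extend = True
extend_seq_lens = torch.tensor([7, 4, 12])

# Mirrors the new logic: segment lengths come from the metadata, and weight
# indices are looked up per lora_path instead of per req.lora_path.
seg_lens = extend_seq_lens if is_extend else torch.ones(batch_size)
weight_indices = torch.empty((batch_size,), dtype=torch.int64)
for i, lora_path in enumerate(lora_paths):
    weight_indices[i] = buffer_id[lora_path]

print(seg_lens.tolist(), weight_indices.tolist())  # [7, 4, 12] [0, 0, 1]
```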
5 changes: 4 additions & 1 deletion python/sglang/srt/managers/schedule_batch.py
@@ -29,7 +29,7 @@
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
@@ -511,6 +511,9 @@ def prepare_for_extend(self, vocab_size: int):
self.extend_logprob_start_lens_cpu = [r.extend_logprob_start_len for r in reqs]
self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)

def get_input_metadata(self):
return InputMetadata.from_schedule_batch(self)

def mix_with_running(self, running_batch: "ScheduleBatch"):
self.forward_mode = ForwardMode.MIXED
running_bs = running_batch.batch_size()
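get_input_metadata is just a convenience wrapper; InputMetadata.from_schedule_batch already lives in forward_batch_info.py and is not shown in this diff. For orientation, the sketch below lists only the fields this PR can be seen reading or constructing (the CUDA graph runner builds an InputMetadata directly; the backends and LoRA manager read the extend and LoRA fields). It is illustrative, not the real class definition, and omits fields such as req_to_token_pool, token_to_kv_pool, and attn_backend that the constructor call in cuda_graph_runner.py also passes:

```python
from dataclasses import dataclass
from typing import Any, List, Optional

import torch


@dataclass
class InputMetadataSketch:
    """Illustrative only: fields read or constructed somewhere in this diff."""

    forward_mode: Any                    # ForwardMode.DECODE / EXTEND / MIXED
    batch_size: int
    input_ids: torch.Tensor
    req_pool_indices: torch.Tensor
    seq_lens: torch.Tensor
    out_cache_loc: torch.Tensor
    positions: torch.Tensor
    return_logprob: bool
    top_logprobs_nums: List[int]
    # Extend-only fields read by the attention backends and LoRA manager.
    extend_prefix_lens: Optional[torch.Tensor] = None
    extend_seq_lens: Optional[torch.Tensor] = None
    # One adapter path (or None) per request, read by prepare_lora_batch.
    lora_paths: Optional[List[Optional[str]]] = None
```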
11 changes: 8 additions & 3 deletions python/sglang/srt/managers/scheduler.py
@@ -575,8 +575,9 @@ def forward_prefill_batch(self, batch: ScheduleBatch):
if self.is_generation:
# Forward and sample the next tokens
if batch.extend_num_tokens != 0:
input_metadata = batch.get_input_metadata()
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
batch
input_metadata, batch
)
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
next_token_ids
@@ -640,7 +641,8 @@ def forward_prefill_batch(self, batch: ScheduleBatch):
)
else:
assert batch.extend_num_tokens != 0
embeddings = self.tp_worker.forward_batch_embedding(batch)
input_metadata = batch.get_input_metadata()
embeddings = self.tp_worker.forward_batch_embedding(input_metadata)

# Check finish conditions
for i, req in enumerate(batch.reqs):
@@ -769,7 +771,10 @@ def forward_decode_batch(self, batch: ScheduleBatch):
batch.prepare_for_decode()

# Forward and sample the next tokens
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(batch)
input_metadata = batch.get_input_metadata()
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
input_metadata, batch
)
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
next_token_ids
)
9 changes: 5 additions & 4 deletions python/sglang/srt/managers/tp_worker.py
@@ -21,6 +21,7 @@
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.io_struct import UpdateWeightReqInput
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import broadcast_pyobj, is_multimodal_model, set_random_seed
@@ -105,13 +106,13 @@ def get_token_and_memory_info(self):
self.random_seed,
)

def forward_batch_generation(self, batch):
logits_output = self.model_runner.forward(batch)
def forward_batch_generation(self, input_metadata: InputMetadata, batch):
logits_output = self.model_runner.forward(input_metadata)
next_token_ids = self.model_runner.sample(logits_output, batch)
return logits_output, next_token_ids

def forward_batch_embedding(self, batch):
logits_output = self.model_runner.forward(batch)
def forward_batch_embedding(self, input_metadata: InputMetadata):
logits_output = self.model_runner.forward(input_metadata)
embeddings = logits_output.embeddings.tolist()
return embeddings

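The worker signatures above make the split explicit: the forward pass itself only needs InputMetadata, while sampling still consults the ScheduleBatch (its sampling_info and penalizer state), so forward_batch_generation keeps both arguments and forward_batch_embedding drops the batch entirely. The call sites in scheduler.py follow the same pattern; a sketch assuming a prepared batch and an initialized tp_worker, not runnable standalone:

```python
# Generation: the forward pass takes the metadata, sampling still takes the batch.
input_metadata = batch.get_input_metadata()
logits_output, next_token_ids = tp_worker.forward_batch_generation(input_metadata, batch)

# Embedding: only the metadata is needed.
embeddings = tp_worker.forward_batch_embedding(batch.get_input_metadata())
```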
28 changes: 12 additions & 16 deletions python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -31,7 +31,6 @@
LogitsProcessor,
LogitsProcessorOutput,
)
from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
from sglang.srt.utils import monkey_patch_vllm_all_gather

@@ -143,7 +142,6 @@ def __init__(self, model_runner: "ModelRunner"):
self.seq_lens = torch.full(
(self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
)
self.position_ids_offsets = torch.ones((self.max_bs,), dtype=torch.int32)
self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32)

# Capture
@@ -189,7 +187,6 @@ def capture_one_batch_size(self, bs: int, forward: Callable):
input_ids = self.input_ids[:bs]
req_pool_indices = self.req_pool_indices[:bs]
seq_lens = self.seq_lens[:bs]
position_ids_offsets = self.position_ids_offsets[:bs]
out_cache_loc = self.out_cache_loc[:bs]

# Attention backend
@@ -202,6 +199,7 @@ def run_once():
input_metadata = InputMetadata(
forward_mode=ForwardMode.DECODE,
batch_size=bs,
input_ids=input_ids,
req_pool_indices=req_pool_indices,
seq_lens=seq_lens,
req_to_token_pool=self.model_runner.req_to_token_pool,
Expand All @@ -210,7 +208,7 @@ def run_once():
out_cache_loc=out_cache_loc,
return_logprob=False,
top_logprobs_nums=[0] * bs,
positions=(seq_lens - 1 + position_ids_offsets).to(torch.int64),
positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64),
)
return forward(input_ids, input_metadata.positions, input_metadata)

@@ -235,24 +233,22 @@ def run_once():
self.graph_memory_pool = graph.pool()
return graph, out

def replay(self, batch: ScheduleBatch):
assert batch.out_cache_loc is not None
raw_bs = len(batch.reqs)
def replay(self, input_metadata: InputMetadata):
assert input_metadata.out_cache_loc is not None
raw_bs = input_metadata.batch_size

# Pad
index = bisect.bisect_left(self.capture_bs, raw_bs)
bs = self.capture_bs[index]
if bs != raw_bs:
self.seq_lens.fill_(self.seq_len_fill_value)
self.position_ids_offsets.fill_(1)
self.out_cache_loc.zero_()

# Common inputs
self.input_ids[:raw_bs] = batch.input_ids
self.req_pool_indices[:raw_bs] = batch.req_pool_indices
self.seq_lens[:raw_bs] = batch.seq_lens
self.position_ids_offsets[:raw_bs] = batch.position_ids_offsets
self.out_cache_loc[:raw_bs] = batch.out_cache_loc
self.input_ids[:raw_bs] = input_metadata.input_ids
self.req_pool_indices[:raw_bs] = input_metadata.req_pool_indices
self.seq_lens[:raw_bs] = input_metadata.seq_lens
self.out_cache_loc[:raw_bs] = input_metadata.out_cache_loc

# Attention backend
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
Expand All @@ -275,15 +271,15 @@ def replay(self, batch: ScheduleBatch):
)

# Extract logprobs
if batch.return_logprob:
if input_metadata.return_logprob:
logits_output.next_token_logprobs = torch.nn.functional.log_softmax(
logits_output.next_token_logits, dim=-1
)
return_top_logprob = any(x > 0 for x in batch.top_logprobs_nums)
return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
if return_top_logprob:
logits_metadata = LogitsMetadata(
forward_mode=ForwardMode.DECODE,
top_logprobs_nums=batch.top_logprobs_nums,
top_logprobs_nums=input_metadata.top_logprobs_nums,
)
logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
logits_output.next_token_logprobs, logits_metadata
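Because replay now reads everything from InputMetadata, the separate position_ids_offsets buffer is removed and decode positions are derived from seq_lens alone, with a clamp guarding the padded slots of a captured batch size. A small self-contained illustration (toy values; the fill value for padded seq_lens is assumed to be 0 here for clarity):

```python
import torch

# Toy decode batch padded up to a captured graph size: two real requests
# followed by two padded slots left at the (assumed) fill value of 0.
seq_lens = torch.tensor([8, 13, 0, 0], dtype=torch.int32)

# New computation from capture_one_batch_size: the token being decoded sits at
# position seq_len - 1; clamp keeps padded slots at 0 instead of going negative.
positions = torch.clamp(seq_lens - 1, min=0).to(torch.int64)
print(positions.tolist())  # [7, 12, 0, 0]
```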