Let ModelRunner take InputMetadata as input, instead of ScheduleBatch
merrymercy authored Sep 30, 2024
1 parent 55b974f commit 3f0fe08
Showing 12 changed files with 143 additions and 158 deletions.
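The pattern of the refactor shows up at every call site below: the caller builds an InputMetadata from the ScheduleBatch and hands that to ModelRunner.forward, while sampling keeps taking the batch. A minimal sketch of the new call sequence, assuming `batch` is an already-prepared ScheduleBatch and `model_runner` a ModelRunner (as in bench_latency.py):

# Forward now consumes an InputMetadata built from the batch.
input_metadata = batch.get_input_metadata()  # InputMetadata.from_schedule_batch(batch)
logits_output = model_runner.forward(input_metadata)

# Sampling still needs the ScheduleBatch, since sampling_info lives on it.
next_token_ids = model_runner.sample(logits_output, batch)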
6 changes: 4 additions & 2 deletions python/sglang/bench_latency.py
@@ -225,14 +225,16 @@ def extend(reqs, model_runner):
tree_cache=None,
)
batch.prepare_for_extend(model_runner.model_config.vocab_size)
logits_output = model_runner.forward(batch)
input_metadata = batch.get_input_metadata()
logits_output = model_runner.forward(input_metadata)
next_token_ids = model_runner.sample(logits_output, batch).tolist()
return next_token_ids, logits_output.next_token_logits, batch


def decode(input_token_ids, batch, model_runner):
batch.prepare_for_decode(input_token_ids)
logits_output = model_runner.forward(batch)
input_metadata = batch.get_input_metadata()
logits_output = model_runner.forward(input_metadata)
next_token_ids = model_runner.sample(logits_output, batch).tolist()
return next_token_ids, logits_output.next_token_logits

37 changes: 21 additions & 16 deletions python/sglang/srt/layers/attention_backend.py
@@ -15,7 +15,7 @@

from sglang.global_config import global_config
from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
from sglang.srt.utils import is_hip

@@ -37,9 +37,7 @@ class AttentionBackend(ABC):
"""The base class of attention backends"""

@abstractmethod
def init_forward_metadata(
self, batch: ScheduleBatch, input_metadata: InputMetadata
):
def init_forward_metadata(self, input_metadata: InputMetadata):
"""Init the metadata for a forward pass."""
raise NotImplementedError()

@@ -133,12 +131,11 @@ def __init__(self, model_runner: ModelRunner):
self.forward_metadata = None
self.cuda_graph_metadata = {}

def init_forward_metadata(
self, batch: ScheduleBatch, input_metadata: InputMetadata
):
def init_forward_metadata(self, input_metadata: InputMetadata):
if input_metadata.forward_mode.is_decode():
prefix_lens = None
use_ragged = False
extend_no_prefix = False
total_num_tokens = None
else:
prefix_lens = input_metadata.extend_prefix_lens
@@ -152,6 +149,7 @@ def init_forward_metadata(
use_ragged = True

total_num_tokens = torch.sum(input_metadata.seq_lens).item()
extend_no_prefix = not torch.any(input_metadata.extend_prefix_lens).item()

update_flashinfer_indices(
input_metadata.forward_mode,
@@ -162,7 +160,12 @@ def init_forward_metadata(
use_ragged=use_ragged,
)

self.forward_metadata = (use_ragged, total_num_tokens, self.decode_wrapper)
self.forward_metadata = (
use_ragged,
extend_no_prefix,
total_num_tokens,
self.decode_wrapper,
)

def init_cuda_graph_state(self, max_bs: int):
self.cuda_graph_kv_indptr = torch.zeros(
@@ -228,7 +231,7 @@ def init_forward_metadata_capture_cuda_graph(

self.cuda_graph_metadata[bs] = decode_wrapper

self.forward_metadata = (False, None, decode_wrapper)
self.forward_metadata = (False, False, None, decode_wrapper)

def init_forward_metadata_replay_cuda_graph(
self, bs: int, req_pool_indices, seq_lens
@@ -254,7 +257,9 @@ def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
else:
prefill_wrapper_paged = self.prefill_wrapper_paged[1]

use_ragged, total_num_tokens, decode_wrapper = self.forward_metadata
use_ragged, extend_no_prefix, total_num_tokens, decode_wrapper = (
self.forward_metadata
)

if not use_ragged:
if k is not None:
@@ -280,7 +285,7 @@ def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
logits_soft_cap=layer.logit_cap,
)

if input_metadata.extend_no_prefix:
if extend_no_prefix:
o = o1
else:
o2, s2 = prefill_wrapper_paged.forward_return_lse(
@@ -300,7 +305,9 @@ def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
return o.view(-1, layer.tp_q_head_num * layer.head_dim)

def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
use_ragged, total_num_tokens, decode_wrapper = self.forward_metadata
use_ragged, extend_no_prefix, total_num_tokens, decode_wrapper = (
self.forward_metadata
)

if isinstance(decode_wrapper, list):
if layer.sliding_window_size != -1:
@@ -351,9 +358,7 @@ def __init__(self, model_runner: ModelRunner):

self.cuda_graph_max_seq_len = model_runner.model_config.context_len

def init_forward_metadata(
self, batch: ScheduleBatch, input_metadata: InputMetadata
):
def init_forward_metadata(self, input_metadata: InputMetadata):
"""Init auxiliary variables for triton attention backend."""

if input_metadata.forward_mode.is_decode():
@@ -371,7 +376,7 @@ def init_forward_metadata(
max_extend_len = None
else:
start_loc = attn_logits = max_seq_len = None
prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
prefix_lens = input_metadata.extend_prefix_lens
max_extend_len = torch.max(input_metadata.seq_lens - prefix_lens).item()

self.forward_metadata = start_loc, attn_logits, max_seq_len, max_extend_len
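With this change, forward_extend no longer reads extend_no_prefix from InputMetadata; init_forward_metadata computes it and stashes it in the forward_metadata tuple. A small, self-contained illustration of that flag (the tensor values are made up for the example, not taken from the diff):

import torch

# extend_no_prefix is True when no request in the batch reuses a cached prefix,
# so the ragged attention output needs no merge with the paged prefix output.
extend_prefix_lens = torch.tensor([0, 0, 0, 0])
extend_no_prefix = not torch.any(extend_prefix_lens).item()
print(extend_no_prefix)  # True: every request starts from an empty prefix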
19 changes: 11 additions & 8 deletions python/sglang/srt/lora/lora_manager.py
@@ -18,13 +18,12 @@


import re
from dataclasses import dataclass

import torch

from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer
from sglang.srt.lora.lora_config import LoRAConfig
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.utils import is_hip, replace_submodule

# ROCm: flashinfer available later
@@ -208,9 +207,9 @@ def load_lora(self, uid, buffer_id):
if lora_weight_name:
self.B_buffer[lora_weight_name][i][buffer_id].copy_(weights)

def prepare_lora_batch(self, batch, extend_seq_lens=None):
def prepare_lora_batch(self, input_metadata: InputMetadata):
# load active loras into lora memory pool
cur_uids = set([req.lora_path for req in batch.reqs])
cur_uids = set(input_metadata.lora_paths)
assert len(cur_uids) <= self.max_loras_per_batch
i = 0
evictable_uids = list(self.active_uids)
@@ -230,11 +229,15 @@ def prepare_lora_batch(self, batch, extend_seq_lens=None):
return

# setup lora in forward modules
bs = len(batch.reqs)
seg_lens = extend_seq_lens if batch.forward_mode.is_extend() else torch.ones(bs)
bs = input_metadata.batch_size
seg_lens = (
input_metadata.extend_seq_lens
if input_metadata.forward_mode.is_extend()
else torch.ones(bs)
)
weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda")
for i, req in enumerate(batch.reqs):
weight_indices[i] = self.buffer_id[req.lora_path]
for i, lora_path in enumerate(input_metadata.lora_paths):
weight_indices[i] = self.buffer_id[lora_path]

for module_name, module in self.lora_modules:
layer_id = get_layer_id(module_name)
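prepare_lora_batch now reads everything it needs from InputMetadata (lora_paths, batch_size, extend_seq_lens) instead of walking batch.reqs. A standalone sketch of the weight-index mapping it builds, with illustrative stand-in values rather than real pool state:

import torch

lora_paths = ["adapter_a", "adapter_a", "adapter_b"]  # one lora_path per request
buffer_id = {"adapter_a": 0, "adapter_b": 1}          # uid -> slot in the LoRA memory pool

bs = len(lora_paths)
weight_indices = torch.empty((bs,), dtype=torch.int64)
for i, lora_path in enumerate(lora_paths):
    weight_indices[i] = buffer_id[lora_path]
print(weight_indices)  # tensor([0, 0, 1])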
5 changes: 4 additions & 1 deletion python/sglang/srt/managers/schedule_batch.py
@@ -29,7 +29,7 @@
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
@@ -511,6 +511,9 @@ def prepare_for_extend(self, vocab_size: int):
self.extend_logprob_start_lens_cpu = [r.extend_logprob_start_len for r in reqs]
self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)

def get_input_metadata(self):
return InputMetadata.from_schedule_batch(self)

def mix_with_running(self, running_batch: "ScheduleBatch"):
self.forward_mode = ForwardMode.MIXED
running_bs = running_batch.batch_size()
11 changes: 8 additions & 3 deletions python/sglang/srt/managers/scheduler.py
@@ -575,8 +575,9 @@ def forward_prefill_batch(self, batch: ScheduleBatch):
if self.is_generation:
# Forward and sample the next tokens
if batch.extend_num_tokens != 0:
input_metadata = batch.get_input_metadata()
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
batch
input_metadata, batch
)
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
next_token_ids
@@ -640,7 +641,8 @@ def forward_prefill_batch(self, batch: ScheduleBatch):
)
else:
assert batch.extend_num_tokens != 0
embeddings = self.tp_worker.forward_batch_embedding(batch)
input_metadata = batch.get_input_metadata()
embeddings = self.tp_worker.forward_batch_embedding(input_metadata)

# Check finish conditions
for i, req in enumerate(batch.reqs):
@@ -769,7 +771,10 @@ def forward_decode_batch(self, batch: ScheduleBatch):
batch.prepare_for_decode()

# Forward and sample the next tokens
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(batch)
input_metadata = batch.get_input_metadata()
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
input_metadata, batch
)
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
next_token_ids
)
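Note that forward_batch_generation now takes both arguments: the InputMetadata drives the forward pass, while the ScheduleBatch is still needed for sampling state. A sketch of the decode step under that split, assuming `batch` and `self.tp_worker` as in scheduler.py:

batch.prepare_for_decode()
input_metadata = batch.get_input_metadata()

# Forward runs on InputMetadata; sampling inside forward_batch_generation
# still consumes the ScheduleBatch, hence the two arguments.
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
    input_metadata, batch
)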
9 changes: 5 additions & 4 deletions python/sglang/srt/managers/tp_worker.py
@@ -21,6 +21,7 @@
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.io_struct import UpdateWeightReqInput
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import broadcast_pyobj, is_multimodal_model, set_random_seed
@@ -105,13 +106,13 @@ def get_token_and_memory_info(self):
self.random_seed,
)

def forward_batch_generation(self, batch):
logits_output = self.model_runner.forward(batch)
def forward_batch_generation(self, input_metadata: InputMetadata, batch):
logits_output = self.model_runner.forward(input_metadata)
next_token_ids = self.model_runner.sample(logits_output, batch)
return logits_output, next_token_ids

def forward_batch_embedding(self, batch):
logits_output = self.model_runner.forward(batch)
def forward_batch_embedding(self, input_metadata: InputMetadata):
logits_output = self.model_runner.forward(input_metadata)
embeddings = logits_output.embeddings.tolist()
return embeddings

28 changes: 12 additions & 16 deletions python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -31,7 +31,6 @@
LogitsProcessor,
LogitsProcessorOutput,
)
from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
from sglang.srt.utils import monkey_patch_vllm_all_gather

@@ -143,7 +142,6 @@ def __init__(self, model_runner: "ModelRunner"):
self.seq_lens = torch.full(
(self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
)
self.position_ids_offsets = torch.ones((self.max_bs,), dtype=torch.int32)
self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32)

# Capture
@@ -189,7 +187,6 @@ def capture_one_batch_size(self, bs: int, forward: Callable):
input_ids = self.input_ids[:bs]
req_pool_indices = self.req_pool_indices[:bs]
seq_lens = self.seq_lens[:bs]
position_ids_offsets = self.position_ids_offsets[:bs]
out_cache_loc = self.out_cache_loc[:bs]

# Attention backend
@@ -202,6 +199,7 @@ def run_once():
input_metadata = InputMetadata(
forward_mode=ForwardMode.DECODE,
batch_size=bs,
input_ids=input_ids,
req_pool_indices=req_pool_indices,
seq_lens=seq_lens,
req_to_token_pool=self.model_runner.req_to_token_pool,
@@ -210,7 +208,7 @@ def run_once():
out_cache_loc=out_cache_loc,
return_logprob=False,
top_logprobs_nums=[0] * bs,
positions=(seq_lens - 1 + position_ids_offsets).to(torch.int64),
positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64),
)
return forward(input_ids, input_metadata.positions, input_metadata)

@@ -235,24 +233,22 @@ def run_once():
self.graph_memory_pool = graph.pool()
return graph, out

def replay(self, batch: ScheduleBatch):
assert batch.out_cache_loc is not None
raw_bs = len(batch.reqs)
def replay(self, input_metadata: InputMetadata):
assert input_metadata.out_cache_loc is not None
raw_bs = input_metadata.batch_size

# Pad
index = bisect.bisect_left(self.capture_bs, raw_bs)
bs = self.capture_bs[index]
if bs != raw_bs:
self.seq_lens.fill_(self.seq_len_fill_value)
self.position_ids_offsets.fill_(1)
self.out_cache_loc.zero_()

# Common inputs
self.input_ids[:raw_bs] = batch.input_ids
self.req_pool_indices[:raw_bs] = batch.req_pool_indices
self.seq_lens[:raw_bs] = batch.seq_lens
self.position_ids_offsets[:raw_bs] = batch.position_ids_offsets
self.out_cache_loc[:raw_bs] = batch.out_cache_loc
self.input_ids[:raw_bs] = input_metadata.input_ids
self.req_pool_indices[:raw_bs] = input_metadata.req_pool_indices
self.seq_lens[:raw_bs] = input_metadata.seq_lens
self.out_cache_loc[:raw_bs] = input_metadata.out_cache_loc

# Attention backend
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
@@ -275,15 +271,15 @@ def replay(self, batch: ScheduleBatch):
)

# Extract logprobs
if batch.return_logprob:
if input_metadata.return_logprob:
logits_output.next_token_logprobs = torch.nn.functional.log_softmax(
logits_output.next_token_logits, dim=-1
)
return_top_logprob = any(x > 0 for x in batch.top_logprobs_nums)
return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
if return_top_logprob:
logits_metadata = LogitsMetadata(
forward_mode=ForwardMode.DECODE,
top_logprobs_nums=batch.top_logprobs_nums,
top_logprobs_nums=input_metadata.top_logprobs_nums,
)
logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
logits_output.next_token_logprobs, logits_metadata
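With position_ids_offsets gone, the CUDA graph path derives decode positions directly from seq_lens. A tiny runnable illustration of that computation (the example tensor is made up, not from the diff):

import torch

seq_lens = torch.tensor([5, 1, 12], dtype=torch.int32)
positions = torch.clamp(seq_lens - 1, min=0).to(torch.int64)
print(positions)  # tensor([ 4,  0, 11])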