From b6910975496ae7b83b672bfec746e943d5daa2d5 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sun, 29 Sep 2024 17:43:35 -0700 Subject: [PATCH 1/9] separate mem pool --- python/sglang/bench_latency.py | 8 +- python/sglang/srt/layers/attention_backend.py | 14 +-- python/sglang/srt/managers/schedule_batch.py | 8 +- .../srt/model_executor/cuda_graph_runner.py | 28 ++--- .../srt/model_executor/forward_batch_info.py | 103 ++++++------------ .../sglang/srt/model_executor/model_runner.py | 47 ++++---- 6 files changed, 86 insertions(+), 122 deletions(-) diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py index 406d91f18c9..a51a688a2d9 100644 --- a/python/sglang/bench_latency.py +++ b/python/sglang/bench_latency.py @@ -224,15 +224,15 @@ def extend(reqs, model_runner): token_to_kv_pool=model_runner.token_to_kv_pool, tree_cache=None, ) - batch.prepare_for_extend(model_runner.model_config.vocab_size) - logits_output = model_runner.forward(batch) + input_metadata = batch.prepare_for_extend(model_runner.model_config.vocab_size) + logits_output = model_runner.forward(input_metadata) next_token_ids = model_runner.sample(logits_output, batch).tolist() return next_token_ids, logits_output.next_token_logits, batch def decode(input_token_ids, batch, model_runner): - batch.prepare_for_decode(input_token_ids) - logits_output = model_runner.forward(batch) + input_metadata = batch.prepare_for_decode(input_token_ids) + logits_output = model_runner.forward(input_metadata) next_token_ids = model_runner.sample(logits_output, batch).tolist() return next_token_ids, logits_output.next_token_logits diff --git a/python/sglang/srt/layers/attention_backend.py b/python/sglang/srt/layers/attention_backend.py index d7c1cf39d82..72ec8270e0f 100644 --- a/python/sglang/srt/layers/attention_backend.py +++ b/python/sglang/srt/layers/attention_backend.py @@ -37,9 +37,7 @@ class AttentionBackend(ABC): """The base class of attention backends""" @abstractmethod - def init_forward_metadata( - self, batch: ScheduleBatch, input_metadata: InputMetadata - ): + def init_forward_metadata(self, input_metadata: InputMetadata): """Init the metadata for a forward pass.""" raise NotImplementedError() @@ -133,9 +131,7 @@ def __init__(self, model_runner: ModelRunner): self.forward_metadata = None self.cuda_graph_metadata = {} - def init_forward_metadata( - self, batch: ScheduleBatch, input_metadata: InputMetadata - ): + def init_forward_metadata(self, input_metadata: InputMetadata): if input_metadata.forward_mode.is_decode(): prefix_lens = None use_ragged = False @@ -351,9 +347,7 @@ def __init__(self, model_runner: ModelRunner): self.cuda_graph_max_seq_len = model_runner.model_config.context_len - def init_forward_metadata( - self, batch: ScheduleBatch, input_metadata: InputMetadata - ): + def init_forward_metadata(self, input_metadata: InputMetadata): """Init auxiliary variables for triton attention backend.""" if input_metadata.forward_mode.is_decode(): @@ -371,7 +365,7 @@ def init_forward_metadata( max_extend_len = None else: start_loc = attn_logits = max_seq_len = None - prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda") + prefix_lens = input_metadata.extend_prefix_lens max_extend_len = torch.max(input_metadata.seq_lens - prefix_lens).item() self.forward_metadata = start_loc, attn_logits, max_seq_len, max_extend_len diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 75b8b80ce92..3ab0aa2e794 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -29,7 +29,7 @@ from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.chunk_cache import ChunkCache from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool -from sglang.srt.model_executor.forward_batch_info import ForwardMode +from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.server_args import ServerArgs @@ -511,6 +511,9 @@ def prepare_for_extend(self, vocab_size: int): self.extend_logprob_start_lens_cpu = [r.extend_logprob_start_len for r in reqs] self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size) + input_metadata = InputMetadata.from_schedule_batch(self) + return input_metadata + def mix_with_running(self, running_batch: "ScheduleBatch"): self.forward_mode = ForwardMode.MIXED running_bs = running_batch.batch_size() @@ -716,6 +719,9 @@ def prepare_for_decode(self, input_ids=None): self.req_pool_indices, self.seq_lens - 1 ] = self.out_cache_loc + input_metadata = InputMetadata.from_schedule_batch(self) + return input_metadata + def filter_batch(self, unfinished_indices: List[int]): if unfinished_indices is None or len(unfinished_indices) == 0: # Filter out all requests diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 4eb2197aac0..763ee0e1531 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -31,7 +31,6 @@ LogitsProcessor, LogitsProcessorOutput, ) -from sglang.srt.managers.schedule_batch import ScheduleBatch from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata from sglang.srt.utils import monkey_patch_vllm_all_gather @@ -143,7 +142,6 @@ def __init__(self, model_runner: "ModelRunner"): self.seq_lens = torch.full( (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32 ) - self.position_ids_offsets = torch.ones((self.max_bs,), dtype=torch.int32) self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32) # Capture @@ -189,7 +187,6 @@ def capture_one_batch_size(self, bs: int, forward: Callable): input_ids = self.input_ids[:bs] req_pool_indices = self.req_pool_indices[:bs] seq_lens = self.seq_lens[:bs] - position_ids_offsets = self.position_ids_offsets[:bs] out_cache_loc = self.out_cache_loc[:bs] # Attention backend @@ -202,6 +199,7 @@ def run_once(): input_metadata = InputMetadata( forward_mode=ForwardMode.DECODE, batch_size=bs, + input_ids=input_ids, req_pool_indices=req_pool_indices, seq_lens=seq_lens, req_to_token_pool=self.model_runner.req_to_token_pool, @@ -210,7 +208,7 @@ def run_once(): out_cache_loc=out_cache_loc, return_logprob=False, top_logprobs_nums=[0] * bs, - positions=(seq_lens - 1 + position_ids_offsets).to(torch.int64), + positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64), ) return forward(input_ids, input_metadata.positions, input_metadata) @@ -235,24 +233,22 @@ def run_once(): self.graph_memory_pool = graph.pool() return graph, out - def replay(self, batch: ScheduleBatch): - assert batch.out_cache_loc is not None - raw_bs = len(batch.reqs) + def replay(self, input_metadata: InputMetadata): + assert input_metadata.out_cache_loc is not None + raw_bs = input_metadata.batch_size # Pad index = bisect.bisect_left(self.capture_bs, raw_bs) bs = self.capture_bs[index] if bs != raw_bs: self.seq_lens.fill_(self.seq_len_fill_value) - self.position_ids_offsets.fill_(1) self.out_cache_loc.zero_() # Common inputs - self.input_ids[:raw_bs] = batch.input_ids - self.req_pool_indices[:raw_bs] = batch.req_pool_indices - self.seq_lens[:raw_bs] = batch.seq_lens - self.position_ids_offsets[:raw_bs] = batch.position_ids_offsets - self.out_cache_loc[:raw_bs] = batch.out_cache_loc + self.input_ids[:raw_bs] = input_metadata.input_ids + self.req_pool_indices[:raw_bs] = input_metadata.req_pool_indices + self.seq_lens[:raw_bs] = input_metadata.seq_lens + self.out_cache_loc[:raw_bs] = input_metadata.out_cache_loc # Attention backend self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph( @@ -275,15 +271,15 @@ def replay(self, batch: ScheduleBatch): ) # Extract logprobs - if batch.return_logprob: + if input_metadata.return_logprob: logits_output.next_token_logprobs = torch.nn.functional.log_softmax( logits_output.next_token_logits, dim=-1 ) - return_top_logprob = any(x > 0 for x in batch.top_logprobs_nums) + return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums) if return_top_logprob: logits_metadata = LogitsMetadata( forward_mode=ForwardMode.DECODE, - top_logprobs_nums=batch.top_logprobs_nums, + top_logprobs_nums=input_metadata.top_logprobs_nums, ) logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs( logits_output.next_token_logprobs, logits_metadata diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 8421774f115..e727821ef1f 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -27,7 +27,6 @@ from sglang.srt.layers.attention_backend import AttentionBackend from sglang.srt.managers.schedule_batch import ImageInputs, ScheduleBatch from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool - from sglang.srt.model_executor.model_runner import ModelRunner class ForwardMode(IntEnum): @@ -37,7 +36,7 @@ class ForwardMode(IntEnum): EXTEND = auto() # Decode one token. DECODE = auto() - # Contains both PREFILL and EXTEND. + # Contains both EXTEND and DECODE. MIXED = auto() def is_prefill(self): @@ -57,15 +56,17 @@ def is_mixed(self): class InputMetadata: """Store all inforamtion of a forward pass.""" + # The forward mode forward_mode: ForwardMode + # The batch size batch_size: int + # The input ids + input_ids: torch.Tensor + # The indices of requests in the req_to_token_pool req_pool_indices: torch.Tensor + # The sequence length seq_lens: torch.Tensor - req_to_token_pool: ReqToTokenPool - token_to_kv_pool: BaseTokenToKVPool - attn_backend: AttentionBackend - - # Output location of the KV cache + # The indices of output tokens in the token_to_kv_pool out_cache_loc: torch.Tensor # Position information @@ -86,82 +87,48 @@ class InputMetadata: # For multimodal image_inputs: List[ImageInputs] = None - def init_multimuldal_info(self, batch: ScheduleBatch): - self.image_inputs = [r.image_inputs for r in batch.reqs] - - def compute_positions(self, batch: ScheduleBatch): - if self.forward_mode.is_decode(): - if True: - self.positions = self.seq_lens - 1 - else: - # Deprecated - self.positions = (self.seq_lens - 1) + batch.position_ids_offsets - else: - if True: - self.positions = torch.tensor( - np.concatenate( - [ - np.arange(batch.prefix_lens_cpu[i], len(req.fill_ids)) - for i, req in enumerate(batch.reqs) - ], - axis=0, - ), - device="cuda", - ) - else: - # Deprecated - position_ids_offsets_cpu = batch.position_ids_offsets.cpu().numpy() - self.positions = torch.tensor( - np.concatenate( - [ - np.arange( - batch.prefix_lens_cpu[i] + position_ids_offsets_cpu[i], - len(req.fill_ids) + position_ids_offsets_cpu[i], - ) - for i, req in enumerate(batch.reqs) - ], - axis=0, - ), - device="cuda", - ) - - # Positions should be in long type - self.positions = self.positions.to(torch.int64) - - def compute_extend_infos(self, batch: ScheduleBatch): - self.extend_seq_lens = torch.tensor(batch.extend_lens_cpu, device="cuda") - self.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda") - self.extend_start_loc = torch.zeros_like(self.extend_seq_lens) - self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0) - self.extend_no_prefix = all(x == 0 for x in batch.prefix_lens_cpu) - self.extend_seq_lens_cpu = batch.extend_lens_cpu - self.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens_cpu + # Attention backend + req_to_token_pool: ReqToTokenPool = None + token_to_kv_pool: BaseTokenToKVPool = None + attn_backend: AttentionBackend = None @classmethod def from_schedule_batch( cls, - model_runner: "ModelRunner", batch: ScheduleBatch, ): ret = cls( forward_mode=batch.forward_mode, batch_size=batch.batch_size(), + input_ids=batch.input_ids, req_pool_indices=batch.req_pool_indices, seq_lens=batch.seq_lens, - req_to_token_pool=model_runner.req_to_token_pool, - token_to_kv_pool=model_runner.token_to_kv_pool, - attn_backend=model_runner.attn_backend, out_cache_loc=batch.out_cache_loc, return_logprob=batch.return_logprob, top_logprobs_nums=batch.top_logprobs_nums, ) - ret.compute_positions(batch) - - if not batch.forward_mode.is_decode(): - ret.init_multimuldal_info(batch) - ret.compute_extend_infos(batch) - - model_runner.attn_backend.init_forward_metadata(batch, ret) + if ret.forward_mode.is_decode(): + ret.positions = (ret.seq_lens - 1).to(torch.int64) + else: + ret.positions = torch.tensor( + np.concatenate( + [ + np.arange(batch.prefix_lens_cpu[i], len(req.fill_ids)) + for i, req in enumerate(batch.reqs) + ], + axis=0, + ), + device="cuda", + ).to(torch.int64) + + ret.image_inputs = [r.image_inputs for r in batch.reqs] + ret.extend_seq_lens = torch.tensor(batch.extend_lens_cpu, device="cuda") + ret.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda") + ret.extend_start_loc = torch.zeros_like(ret.extend_seq_lens) + ret.extend_start_loc[1:] = torch.cumsum(ret.extend_seq_lens[:-1], dim=0) + ret.extend_no_prefix = all(x == 0 for x in batch.prefix_lens_cpu) + ret.extend_seq_lens_cpu = batch.extend_lens_cpu + ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens_cpu return ret diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 9dfe7005155..2c003c4e030 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -466,46 +466,47 @@ def init_cuda_graphs(self): logger.info("Capture cuda graph begin. This can take up to several minutes.") self.cuda_graph_runner = CudaGraphRunner(self) - def forward_decode(self, batch: ScheduleBatch): - if self.server_args.lora_paths is not None: - self.lora_manager.prepare_lora_batch(batch) - - if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)): - return self.cuda_graph_runner.replay(batch) - - input_metadata = InputMetadata.from_schedule_batch(self, batch) + def forward_decode(self, input_metadata: InputMetadata): + if self.cuda_graph_runner and self.cuda_graph_runner.can_run( + input_metadata.batch_size + ): + return self.cuda_graph_runner.replay(input_metadata) return self.model.forward( - batch.input_ids, input_metadata.positions, input_metadata + input_metadata.input_ids, input_metadata.positions, input_metadata ) - def forward_extend(self, batch: ScheduleBatch): - input_metadata = InputMetadata.from_schedule_batch(self, batch) - if self.server_args.lora_paths is not None: - self.lora_manager.prepare_lora_batch(batch, input_metadata.extend_seq_lens) - + def forward_extend(self, input_metadata: InputMetadata): if self.is_generation: return self.model.forward( - batch.input_ids, input_metadata.positions, input_metadata + input_metadata.input_ids, input_metadata.positions, input_metadata ) else: # Only embedding models have get_embedding parameter return self.model.forward( - batch.input_ids, + input_metadata.input_ids, input_metadata.positions, input_metadata, get_embedding=True, ) - def forward(self, batch: ScheduleBatch) -> Tuple[LogitsProcessorOutput]: - assert batch.forward_mode is not None + def forward(self, input_metadata: InputMetadata) -> LogitsProcessorOutput: + # Attach attention information + input_metadata.req_to_token_pool = self.req_to_token_pool + input_metadata.token_to_kv_pool = self.token_to_kv_pool + input_metadata.attn_backend = self.attn_backend + input_metadata.attn_backend.init_forward_metadata(input_metadata) + + # Attach lora information + if self.server_args.lora_paths is not None: + self.lora_manager.prepare_lora_batch(None) - if batch.forward_mode.is_decode(): - return self.forward_decode(batch) - elif batch.forward_mode.is_extend(): - return self.forward_extend(batch) + if input_metadata.forward_mode.is_decode(): + return self.forward_decode(input_metadata) + elif input_metadata.forward_mode.is_extend(): + return self.forward_extend(input_metadata) else: - raise ValueError(f"Invaid forward mode: {batch.forward_mode}") + raise ValueError(f"Invaid forward mode: {input_metadata.forward_mode}") def _apply_logits_bias( self, logits: torch.Tensor, sampling_info: SamplingBatchInfo From 5fab8fee3b445d0f6a6683f14320a7fd71431796 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sun, 29 Sep 2024 17:51:23 -0700 Subject: [PATCH 2/9] drop extend_no_prefix --- python/sglang/srt/layers/attention_backend.py | 20 ++++++++++++++----- .../srt/model_executor/forward_batch_info.py | 2 -- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/layers/attention_backend.py b/python/sglang/srt/layers/attention_backend.py index 72ec8270e0f..f6fdfc4ce89 100644 --- a/python/sglang/srt/layers/attention_backend.py +++ b/python/sglang/srt/layers/attention_backend.py @@ -148,6 +148,7 @@ def init_forward_metadata(self, input_metadata: InputMetadata): use_ragged = True total_num_tokens = torch.sum(input_metadata.seq_lens).item() + extend_no_prefix = not torch.any(input_metadata.extend_prefix_lens).item() update_flashinfer_indices( input_metadata.forward_mode, @@ -158,7 +159,12 @@ def init_forward_metadata(self, input_metadata: InputMetadata): use_ragged=use_ragged, ) - self.forward_metadata = (use_ragged, total_num_tokens, self.decode_wrapper) + self.forward_metadata = ( + use_ragged, + extend_no_prefix, + total_num_tokens, + self.decode_wrapper, + ) def init_cuda_graph_state(self, max_bs: int): self.cuda_graph_kv_indptr = torch.zeros( @@ -224,7 +230,7 @@ def init_forward_metadata_capture_cuda_graph( self.cuda_graph_metadata[bs] = decode_wrapper - self.forward_metadata = (False, None, decode_wrapper) + self.forward_metadata = (False, False, None, decode_wrapper) def init_forward_metadata_replay_cuda_graph( self, bs: int, req_pool_indices, seq_lens @@ -250,7 +256,9 @@ def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadat else: prefill_wrapper_paged = self.prefill_wrapper_paged[1] - use_ragged, total_num_tokens, decode_wrapper = self.forward_metadata + use_ragged, extend_no_prefix, total_num_tokens, decode_wrapper = ( + self.forward_metadata + ) if not use_ragged: if k is not None: @@ -276,7 +284,7 @@ def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadat logits_soft_cap=layer.logit_cap, ) - if input_metadata.extend_no_prefix: + if extend_no_prefix: o = o1 else: o2, s2 = prefill_wrapper_paged.forward_return_lse( @@ -296,7 +304,9 @@ def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadat return o.view(-1, layer.tp_q_head_num * layer.head_dim) def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata): - use_ragged, total_num_tokens, decode_wrapper = self.forward_metadata + use_ragged, extend_no_prefix, total_num_tokens, decode_wrapper = ( + self.forward_metadata + ) if isinstance(decode_wrapper, list): if layer.sliding_window_size != -1: diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index e727821ef1f..87ec6dceaf7 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -76,7 +76,6 @@ class InputMetadata: extend_seq_lens: torch.Tensor = None extend_prefix_lens: torch.Tensor = None extend_start_loc: torch.Tensor = None - extend_no_prefix: bool = None # For logprob return_logprob: bool = False @@ -127,7 +126,6 @@ def from_schedule_batch( ret.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda") ret.extend_start_loc = torch.zeros_like(ret.extend_seq_lens) ret.extend_start_loc[1:] = torch.cumsum(ret.extend_seq_lens[:-1], dim=0) - ret.extend_no_prefix = all(x == 0 for x in batch.prefix_lens_cpu) ret.extend_seq_lens_cpu = batch.extend_lens_cpu ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens_cpu From 4e8f9291b7d4425fb97612bfb435ac8d477be65a Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sun, 29 Sep 2024 17:53:11 -0700 Subject: [PATCH 3/9] fix --- python/sglang/srt/layers/attention_backend.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/attention_backend.py b/python/sglang/srt/layers/attention_backend.py index f6fdfc4ce89..c8fe52ed352 100644 --- a/python/sglang/srt/layers/attention_backend.py +++ b/python/sglang/srt/layers/attention_backend.py @@ -15,7 +15,7 @@ from sglang.global_config import global_config from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices -from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict +from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata from sglang.srt.utils import is_hip @@ -135,6 +135,7 @@ def init_forward_metadata(self, input_metadata: InputMetadata): if input_metadata.forward_mode.is_decode(): prefix_lens = None use_ragged = False + extend_no_prefix = False total_num_tokens = None else: prefix_lens = input_metadata.extend_prefix_lens From d4a0553d2c04b013f337f868d92ba52a1851233b Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sun, 29 Sep 2024 18:36:08 -0700 Subject: [PATCH 4/9] fix lora --- python/sglang/srt/lora/lora_manager.py | 19 +++++++++++-------- .../srt/model_executor/forward_batch_info.py | 6 +++++- .../sglang/srt/model_executor/model_runner.py | 2 +- 3 files changed, 17 insertions(+), 10 deletions(-) diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index d0a604fe64d..04f29feadb1 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -18,13 +18,12 @@ import re -from dataclasses import dataclass import torch from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer from sglang.srt.lora.lora_config import LoRAConfig -from sglang.srt.model_executor.forward_batch_info import ForwardMode +from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.utils import is_hip, replace_submodule # ROCm: flashinfer available later @@ -208,9 +207,9 @@ def load_lora(self, uid, buffer_id): if lora_weight_name: self.B_buffer[lora_weight_name][i][buffer_id].copy_(weights) - def prepare_lora_batch(self, batch, extend_seq_lens=None): + def prepare_lora_batch(self, input_metadata: InputMetadata): # load active loras into lora memory pool - cur_uids = set([req.lora_path for req in batch.reqs]) + cur_uids = set(input_metadata.lora_paths) assert len(cur_uids) <= self.max_loras_per_batch i = 0 evictable_uids = list(self.active_uids) @@ -230,11 +229,15 @@ def prepare_lora_batch(self, batch, extend_seq_lens=None): return # setup lora in forward modules - bs = len(batch.reqs) - seg_lens = extend_seq_lens if batch.forward_mode.is_extend() else torch.ones(bs) + bs = input_metadata.batch_size + seg_lens = ( + input_metadata.extend_seq_lens + if input_metadata.forward_mode.is_extend() + else torch.ones(bs) + ) weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda") - for i, req in enumerate(batch.reqs): - weight_indices[i] = self.buffer_id[req.lora_path] + for i, lora_path in enumerate(input_metadata.lora_paths): + weight_indices[i] = self.buffer_id[lora_path] for module_name, module in self.lora_modules: layer_id = get_layer_id(module_name) diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 87ec6dceaf7..c5b218a1b30 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -18,7 +18,7 @@ """Meta data for a forward pass.""" from dataclasses import dataclass from enum import IntEnum, auto -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING, List, Set import numpy as np import torch @@ -86,6 +86,9 @@ class InputMetadata: # For multimodal image_inputs: List[ImageInputs] = None + # For LoRA + lora_paths: List[str] = None + # Attention backend req_to_token_pool: ReqToTokenPool = None token_to_kv_pool: BaseTokenToKVPool = None @@ -105,6 +108,7 @@ def from_schedule_batch( out_cache_loc=batch.out_cache_loc, return_logprob=batch.return_logprob, top_logprobs_nums=batch.top_logprobs_nums, + lora_paths=[req.lora_path for req in batch.reqs], ) if ret.forward_mode.is_decode(): diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 2c003c4e030..9f65e58174d 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -499,7 +499,7 @@ def forward(self, input_metadata: InputMetadata) -> LogitsProcessorOutput: # Attach lora information if self.server_args.lora_paths is not None: - self.lora_manager.prepare_lora_batch(None) + self.lora_manager.prepare_lora_batch(input_metadata) if input_metadata.forward_mode.is_decode(): return self.forward_decode(input_metadata) From c248e3aec19ce73999d9e241e0505f2289c7fb82 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sun, 29 Sep 2024 18:46:26 -0700 Subject: [PATCH 5/9] fix scheduler --- python/sglang/bench_latency.py | 6 ++++-- python/sglang/srt/managers/schedule_batch.py | 7 ++----- python/sglang/srt/managers/scheduler.py | 11 ++++++++--- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py index a51a688a2d9..354559a089a 100644 --- a/python/sglang/bench_latency.py +++ b/python/sglang/bench_latency.py @@ -224,14 +224,16 @@ def extend(reqs, model_runner): token_to_kv_pool=model_runner.token_to_kv_pool, tree_cache=None, ) - input_metadata = batch.prepare_for_extend(model_runner.model_config.vocab_size) + batch.prepare_for_extend(model_runner.model_config.vocab_size) + input_metadata = batch.get_input_metadata() logits_output = model_runner.forward(input_metadata) next_token_ids = model_runner.sample(logits_output, batch).tolist() return next_token_ids, logits_output.next_token_logits, batch def decode(input_token_ids, batch, model_runner): - input_metadata = batch.prepare_for_decode(input_token_ids) + batch.prepare_for_decode(input_token_ids) + input_metadata = batch.get_input_metadata() logits_output = model_runner.forward(input_metadata) next_token_ids = model_runner.sample(logits_output, batch).tolist() return next_token_ids, logits_output.next_token_logits diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 3ab0aa2e794..6cf870ad7fd 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -511,8 +511,8 @@ def prepare_for_extend(self, vocab_size: int): self.extend_logprob_start_lens_cpu = [r.extend_logprob_start_len for r in reqs] self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size) - input_metadata = InputMetadata.from_schedule_batch(self) - return input_metadata + def get_input_metadata(self): + return InputMetadata.from_schedule_batch(self) def mix_with_running(self, running_batch: "ScheduleBatch"): self.forward_mode = ForwardMode.MIXED @@ -719,9 +719,6 @@ def prepare_for_decode(self, input_ids=None): self.req_pool_indices, self.seq_lens - 1 ] = self.out_cache_loc - input_metadata = InputMetadata.from_schedule_batch(self) - return input_metadata - def filter_batch(self, unfinished_indices: List[int]): if unfinished_indices is None or len(unfinished_indices) == 0: # Filter out all requests diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index f80fc9e3cc0..ca090438da7 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -575,8 +575,9 @@ def forward_prefill_batch(self, batch: ScheduleBatch): if self.is_generation: # Forward and sample the next tokens if batch.extend_num_tokens != 0: + input_metadata = batch.get_input_metadata() logits_output, next_token_ids = self.tp_worker.forward_batch_generation( - batch + input_metadata ) batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens( next_token_ids @@ -640,7 +641,8 @@ def forward_prefill_batch(self, batch: ScheduleBatch): ) else: assert batch.extend_num_tokens != 0 - embeddings = self.tp_worker.forward_batch_embedding(batch) + input_metadata = batch.get_input_metadata() + embeddings = self.tp_worker.forward_batch_embedding(input_metadata) # Check finish conditions for i, req in enumerate(batch.reqs): @@ -769,7 +771,10 @@ def forward_decode_batch(self, batch: ScheduleBatch): batch.prepare_for_decode() # Forward and sample the next tokens - logits_output, next_token_ids = self.tp_worker.forward_batch_generation(batch) + input_metadata = batch.get_input_metadata() + logits_output, next_token_ids = self.tp_worker.forward_batch_generation( + input_metadata + ) batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens( next_token_ids ) From c8f496fbd69da51558aa8591ecbd1cbd079aeb8e Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sun, 29 Sep 2024 18:48:20 -0700 Subject: [PATCH 6/9] fix tp worker --- python/sglang/srt/managers/tp_worker.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 9cee6aeaa93..b5d8b4f7f73 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -21,6 +21,7 @@ from sglang.srt.configs.model_config import ModelConfig from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer from sglang.srt.managers.io_struct import UpdateWeightReqInput +from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.server_args import ServerArgs from sglang.srt.utils import broadcast_pyobj, is_multimodal_model, set_random_seed @@ -105,13 +106,13 @@ def get_token_and_memory_info(self): self.random_seed, ) - def forward_batch_generation(self, batch): - logits_output = self.model_runner.forward(batch) + def forward_batch_generation(self, input_metadata: InputMetadata, batch): + logits_output = self.model_runner.forward(input_metadata) next_token_ids = self.model_runner.sample(logits_output, batch) return logits_output, next_token_ids - def forward_batch_embedding(self, batch): - logits_output = self.model_runner.forward(batch) + def forward_batch_embedding(self, input_metadata: InputMetadata): + logits_output = self.model_runner.forward(input_metadata) embeddings = logits_output.embeddings.tolist() return embeddings From 8d585088ccd14adbd7271c70cd4f3e278fc1d888 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sun, 29 Sep 2024 18:48:49 -0700 Subject: [PATCH 7/9] fix tp worker --- python/sglang/srt/managers/scheduler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index ca090438da7..093bcbe05b4 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -577,7 +577,7 @@ def forward_prefill_batch(self, batch: ScheduleBatch): if batch.extend_num_tokens != 0: input_metadata = batch.get_input_metadata() logits_output, next_token_ids = self.tp_worker.forward_batch_generation( - input_metadata + input_metadata, batch ) batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens( next_token_ids @@ -773,7 +773,7 @@ def forward_decode_batch(self, batch: ScheduleBatch): # Forward and sample the next tokens input_metadata = batch.get_input_metadata() logits_output, next_token_ids = self.tp_worker.forward_batch_generation( - input_metadata + input_metadata, batch ) batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens( next_token_ids From ba1b7b565b340bbcf3aed6ce175271d5ea37ce95 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sun, 29 Sep 2024 18:57:36 -0700 Subject: [PATCH 8/9] fix lora tests --- python/sglang/test/runners.py | 26 +++++++++++++------------- test/srt/models/test_lora.py | 5 ++--- test/srt/run_suite.py | 1 + 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index 023ff892990..8439aa8bbcd 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -71,10 +71,10 @@ class ModelOutput: class HFRunner: def __init__( self, - model_path, - torch_dtype, - model_type="generation", - output_str_only=False, + model_path: str, + torch_dtype: torch.dtype, + model_type: str = "generation", + output_str_only: bool = False, ): self.model_type = model_type self.output_str_only = output_str_only @@ -244,15 +244,15 @@ def __exit__(self, exc_type, exc_value, traceback): class SRTRunner: def __init__( self, - model_path, - torch_dtype, - model_type, - tp_size=1, - port=DEFAULT_PORT_FOR_SRT_TEST_RUNNER, - lora_paths=None, - max_loras_per_batch=4, - disable_cuda_graph=False, - disable_radix_cache=False, + model_path: str, + torch_dtype: torch.dtype, + model_type: str, + tp_size: int = 1, + port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER, + lora_paths: List[str] = None, + max_loras_per_batch: int = 4, + disable_cuda_graph: bool = False, + disable_radix_cache: bool = False, ): self.model_type = model_type self.is_generation = model_type == "generation" diff --git a/test/srt/models/test_lora.py b/test/srt/models/test_lora.py index 51f20e492af..e044c4c0bf4 100644 --- a/test/srt/models/test_lora.py +++ b/test/srt/models/test_lora.py @@ -15,7 +15,6 @@ import multiprocessing as mp import unittest -import uuid import torch @@ -85,9 +84,9 @@ def inference(self, prompts, lora_set, tp_size, torch_dtype, max_new_tokens): with SRTRunner( base_path, - tp_size=tp_size, torch_dtype=torch_dtype, - is_generation=True, + model_type="generation", + tp_size=tp_size, lora_paths=all_lora_paths, max_loras_per_batch=3, disable_cuda_graph=True, diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index b7b81f9dda9..bfa5f0cc7b1 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -7,6 +7,7 @@ "minimal": [ "models/test_embedding_models.py", "models/test_generation_models.py", + "models/test_lora.py", "models/test_reward_models.py", "sampling/penaltylib", "test_chunked_prefill.py", From 6ed6b6fbe2a3fe77001e13661832d97b52c02ee2 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sun, 29 Sep 2024 19:01:10 -0700 Subject: [PATCH 9/9] disable lora test --- test/srt/run_suite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index bfa5f0cc7b1..4e6ce73a519 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -7,7 +7,7 @@ "minimal": [ "models/test_embedding_models.py", "models/test_generation_models.py", - "models/test_lora.py", + # "models/test_lora.py", "models/test_reward_models.py", "sampling/penaltylib", "test_chunked_prefill.py",