diff --git a/python/pyproject.toml b/python/pyproject.toml index 1389822a34b..dd6f8ece7bb 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -27,7 +27,7 @@ srt = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "intere openai = ["openai>=1.0", "tiktoken"] anthropic = ["anthropic>=0.20.0"] litellm = ["litellm>=1.0.0"] -test = ["jsonlines", "matplotlib", "pandas", "sentence_transformers", "accelerate"] +test = ["jsonlines", "matplotlib", "pandas", "sentence_transformers", "accelerate", "peft"] all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"] dev = ["sglang[all]", "sglang[test]"] diff --git a/python/sglang/srt/lora/lora.py b/python/sglang/srt/lora/lora.py new file mode 100644 index 00000000000..dc90270da6d --- /dev/null +++ b/python/sglang/srt/lora/lora.py @@ -0,0 +1,393 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +# Integrates "S-LoRA: Serving Thousands of Concurrent LoRA Adapters" +# and "Punica: Multi-Tenant LoRA Serving" + +# LoRA layers class inheritance adapted from: +# https://github.com/vllm-project/vllm/blob/4abf6336ec65c270343eb895e7b18786e9274176/vllm/lora/layers.py + + +import json +import os +import re +from typing import Any, Dict, List, Optional, Tuple + +import safetensors.torch +import torch +from torch import nn +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.loader import DefaultModelLoader + +from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata + + +class BaseLayerWithLoRA(nn.Module): + def __init__(self, base_layer, segment_gemm, lora_rank, scaling): + super().__init__() + self.base_layer = base_layer + self.segment_gemm = segment_gemm + self.lora_rank = lora_rank + self.scaling = scaling + + def forward(self, x: torch.Tensor): + return self.base_layer.forward(x) + + def set_lora_info(self, *args): + pass + + +class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): + def __init__( + self, base_layer: VocabParallelEmbedding, segment_gemm, lora_rank, scaling + ) -> None: + super().__init__(base_layer, segment_gemm, lora_rank, scaling) + self.weight = base_layer.weight + + +class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): + def __init__( + self, base_layer: ColumnParallelLinear, segment_gemm, lora_rank, scaling + ) -> None: + super().__init__(base_layer, segment_gemm, lora_rank, scaling) + + def apply_lora(self, output: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + # TODO + return output + + def forward(self, input_: torch.Tensor): + # duplicate the logic in ColumnParallelLinear + bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None + output_parallel = self.base_layer.quant_method.apply( + self.base_layer, input_, bias + ) + + if hasattr(self, "A_buffer"): + 
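# set_lora_info() attaches the A/B buffers lazily; skip the LoRA delta when they are absent.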
output_parallel = self.apply_lora(output_parallel, input_) + + if self.base_layer.gather_output: + output = tensor_model_parallel_all_gather(output_parallel) + else: + output = output_parallel + output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None + return output, output_bias + + +class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): + def __init__( + self, base_layer: MergedColumnParallelLinear, segment_gemm, lora_rank, scaling + ) -> None: + super().__init__(base_layer, segment_gemm, lora_rank, scaling) + + def set_lora_info(self, A_buffer, B_buffer, bs, seq_lens, weight_indices): + self.A_buffer = A_buffer + self.B_buffer = B_buffer + self.bs = bs + self.seq_lens = seq_lens + self.weight_indices = weight_indices + + def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + lora_a_output = self.segment_gemm.run( + x=x, + weights=self.A_buffer, + batch_size=self.bs, + weight_column_major=True, + seg_lens=self.seq_lens, + weight_indices=self.weight_indices, + ) + # FIXME + assert lora_a_output.shape[-1] == self.lora_rank * 2 + lora_output = torch.empty_like(base_output) + output_dim = lora_output.shape[-1] // 2 + for i in range(2): + left = self.lora_rank * i + right = left + self.lora_rank + lora_output[:, output_dim * i : output_dim * (i + 1)] = ( + self.segment_gemm.run( + x=lora_a_output[:, left:right].contiguous(), + weights=self.B_buffer[:, :, left:right].contiguous(), + batch_size=self.bs, + weight_column_major=True, + seg_lens=self.seq_lens, + weight_indices=self.weight_indices, + ) + ) + return base_output + lora_output * self.scaling + + +class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): + def __init__( + self, base_layer: QKVParallelLinear, segment_gemm, lora_rank, scaling + ) -> None: + super().__init__(base_layer, segment_gemm, lora_rank, scaling) + + def set_lora_info(self, A_buffer, B_buffer, bs, seq_lens, weight_indices): + self.A_buffer = A_buffer + self.B_buffer = B_buffer + self.bs = bs + self.seq_lens = seq_lens + self.weight_indices = weight_indices + + def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + lora_a_output = self.segment_gemm.run( + x=x, + weights=self.A_buffer, + batch_size=self.bs, + weight_column_major=True, + seg_lens=self.seq_lens, + weight_indices=self.weight_indices, + ) + # FIXME parallelize qkv + assert lora_a_output.shape[-1] == self.lora_rank * 3 + lora_output = torch.empty_like(base_output) + output_dim = lora_output.shape[-1] // 3 + for i in range(3): + left = self.lora_rank * i + right = left + self.lora_rank + lora_output[:, output_dim * i : output_dim * (i + 1)] = ( + self.segment_gemm.run( + x=lora_a_output[:, left:right].contiguous(), + weights=self.B_buffer[:, :, left:right].contiguous(), + batch_size=self.bs, + weight_column_major=True, + seg_lens=self.seq_lens, + weight_indices=self.weight_indices, + ) + ) + return base_output + lora_output * self.scaling + + +class RowParallelLinearWithLoRA(BaseLayerWithLoRA): + def __init__( + self, base_layer: RowParallelLinear, segment_gemm, lora_rank, scaling + ) -> None: + super().__init__(base_layer, segment_gemm, lora_rank, scaling) + + def set_lora_info(self, A_buffer, B_buffer, bs, seq_lens, weight_indices): + self.A_buffer = A_buffer + self.B_buffer = B_buffer + self.bs = bs + self.seq_lens = seq_lens + self.weight_indices = weight_indices + + def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + lora_output = self.segment_gemm.run( + x=x, + 
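# shrink x to the adapter rank with the pooled lora_A weights, then expand with lora_B below +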
weights=self.A_buffer, + batch_size=self.bs, + weight_column_major=True, + seg_lens=self.seq_lens, + weight_indices=self.weight_indices, + ) + lora_output = self.segment_gemm.run( + x=lora_output, + weights=self.B_buffer, + batch_size=self.bs, + weight_column_major=True, + seg_lens=self.seq_lens, + weight_indices=self.weight_indices, + ) + return base_output + lora_output * self.scaling + + def forward(self, input_): + # duplicate the logic in RowParallelLinear + # Set up backprop all-reduce. + if self.base_layer.input_is_parallel: + input_parallel = input_ + else: + tp_rank = get_tensor_model_parallel_rank() + splitted_input = split_tensor_along_last_dim( + input_, num_partitions=self.base_layer.tp_size + ) + input_parallel = splitted_input[tp_rank].contiguous() + output_parallel = self.base_layer.quant_method.apply( + self.base_layer, input_parallel + ) + + if hasattr(self, "A_buffer"): + output_parallel = self.apply_lora(output_parallel, input_parallel) + + if self.base_layer.reduce_results and self.base_layer.tp_size > 1: + output_ = tensor_model_parallel_all_reduce(output_parallel) + else: + output_ = output_parallel + + if not self.base_layer.skip_bias_add: + output = ( + output_ + self.base_layer.bias + if self.base_layer.bias is not None + else output_ + ) + output_bias = None + else: + output = output_ + output_bias = self.base_layer.bias + return output, output_bias + + +def get_lora_layer( + layer: nn.Module, segment_gemm, lora_rank, scaling +) -> BaseLayerWithLoRA: + supported_layer_types = { + # the order matters + VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA, + QKVParallelLinear: QKVParallelLinearWithLora, + RowParallelLinear: RowParallelLinearWithLoRA, + ColumnParallelLinear: ColumnParallelLinearWithLoRA, + } + for src_layer_type, lora_layer_type in supported_layer_types.items(): + if isinstance(layer, src_layer_type): # pylint: disable=unidiomatic-typecheck + ret = lora_layer_type(layer, segment_gemm, lora_rank, scaling) + return ret + raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.") + + +def params_mapping(module_name): + params_mapping = { + "q_proj": "qkv_proj", + "k_proj": "qkv_proj", + "v_proj": "qkv_proj", + "gate_proj": "gate_up_proj", + "up_proj": "gate_up_proj", + } + if module_name in params_mapping: + return params_mapping[module_name] + return module_name + + +def get_mapped_params(module_names): + ret = set() + for module_name in module_names: + ret.add(params_mapping(module_name)) + return list(ret) + + +class LoRALayer(nn.Module): + def __init__(self, config, base_hf_config): + super().__init__() + self.config = config + self.base_hf_config = base_hf_config + self.weights = {} + self.weight_gpu = {} + + def load_to_gpu(self): + for name, weight in self.weights.items(): + self.weight_gpu[name] = weight.to(torch.float16).to("cuda") + + def offload_from_gpu(self): + for name, weight in self.weights.items(): + self.weight_gpu[name] = None + + +class LoRAAdapter(nn.Module): + def __init__(self, uid, config, base_hf_config, load_config): + super().__init__() + self.uid = uid + self.config = config + assert self.config.hf_config["peft_type"].lower() == "lora" + self.base_hf_config = base_hf_config + self.load_config = load_config + self.scaling = self.config.lora_alpha / self.config.r + + self.layers = nn.ModuleList( + [ + LoRALayer(config, base_hf_config) + for i in range(base_hf_config.num_hidden_layers) + ] + ) + + self.weights = {} + self.weights_gpu = {} + + def get_stacked_multiply(self, module_name): + stacked_rank = { + 
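# q/k/v and gate/up LoRA weights are stacked into a single buffer, hence 3x / 2x the rank. +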
"qkv_proj": 3, + "gate_up_proj": 2, + } + return stacked_rank[module_name] if module_name in stacked_rank else 1 + + def load_to_gpu(self): + for name, weight in self.weights.items(): + self.weights_gpu[name] = weight.to(torch.float16).to("cuda") + for layer in self.layers: + layer.load_to_gpu() + + def offload_from_gpu(self): + for name, weight in self.weights.items(): + self.weights_gpu[name] = None + for layer in self.layers: + layer.offload_from_gpu() + + # initialize the LoRA weights to cpu + def initialize_weights(self): + model_path = self.config.path + loader = DefaultModelLoader(self.load_config) + revision = getattr(self.config.hf_config, "revision", None) + for name, loaded_weight in loader._get_weights_iterator( + model_path, revision=revision, fall_back_to_pt=True + ): + match = re.search(r"layers\.(\d+)\.", name) + if match is not None: + layer_id = int(match.group(1)) + self.layers[layer_id].weights[name] = loaded_weight.cpu() + else: + self.weights[name] = loaded_weight.cpu() + + # stack qkv_proj and gate_up_proj + for i in range(self.base_hf_config.num_hidden_layers): + layer = self.layers[i] + weight_names = [name for name, _ in layer.weights.items()] + for weight_name in weight_names: + if "q_proj" in weight_name: + k_name = weight_name.replace("q_proj", "k_proj") + v_name = weight_name.replace("q_proj", "v_proj") + qkv_name = weight_name.replace("q_proj", "qkv_proj") + layer.weights[qkv_name] = torch.cat( + ( + layer.weights[weight_name], + layer.weights[k_name], + layer.weights[v_name], + ), + ( + 0 + if layer.weights[weight_name].shape[0] == self.config.r + else 1 + ), + ) + layer.weights.pop(weight_name) + layer.weights.pop(k_name) + layer.weights.pop(v_name) + elif "gate_proj" in weight_name: + up_name = weight_name.replace("gate_proj", "up_proj") + gate_up_name = weight_name.replace("gate_proj", "gate_up_proj") + layer.weights[gate_up_name] = torch.cat( + (layer.weights[weight_name], layer.weights[up_name]), + ( + 0 + if layer.weights[weight_name].shape[0] == self.config.r + else 1 + ), + ) + layer.weights.pop(weight_name) + layer.weights.pop(up_name) diff --git a/python/sglang/srt/lora/lora_config.py b/python/sglang/srt/lora/lora_config.py new file mode 100644 index 00000000000..59af0c3a9eb --- /dev/null +++ b/python/sglang/srt/lora/lora_config.py @@ -0,0 +1,43 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import json +import os + +from huggingface_hub import snapshot_download + + +class LoRAConfig: + def __init__( + self, + path: str, + ) -> None: + self.path = path + self.hf_config = self.get_lora_config() + self.target_modules = self.hf_config["target_modules"] + self.r = self.hf_config["r"] + self.lora_alpha = self.hf_config["lora_alpha"] + + def get_lora_config(self, dummy=False): + if dummy: + raise NotImplementedError() + else: + if not os.path.isdir(self.path): + weights_dir = snapshot_download(self.path, allow_patterns=["*.json"]) + else: + weights_dir = self.path + config_name = "adapter_config.json" + with open(os.path.join(weights_dir, config_name), "r") as f: + return json.load(f) diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py new file mode 100644 index 00000000000..4b73bd16b12 --- /dev/null +++ b/python/sglang/srt/lora/lora_manager.py @@ -0,0 +1,212 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +# Integrates "S-LoRA: Serving Thousands of Concurrent LoRA Adapters" +# and "Punica: Multi-Tenant LoRA Serving" + + +import re +from dataclasses import dataclass + +import torch +from flashinfer import SegmentGEMMWrapper + +from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer, params_mapping +from sglang.srt.lora.lora_config import LoRAConfig +from sglang.srt.model_executor.forward_batch_info import ForwardMode +from sglang.srt.utils import replace_submodule + + +class LoRAManager: + def __init__( + self, + base_model, + lora_paths, + base_hf_config, + max_loras_per_batch, + load_config, + dtype, + ): + self.base_model = base_model + self.lora_paths = lora_paths + self.base_hf_config = base_hf_config + self.max_loras_per_batch = max_loras_per_batch + self.load_config = load_config + self.dtype = dtype + + workspace_buffer = torch.empty(1 * 1024 * 1024, dtype=torch.int8, device="cuda") + self.segment_gemm = SegmentGEMMWrapper(workspace_buffer) + + self.init_loras() + self.init_lora_memory_pool() + self.init_lora_batch() + + def match_target_modules(self, module_name): + for target_module in self.target_modules: + if module_name.split(".")[-1] == target_module: + return True + return False + + def get_target_modules(self): + modules = [] + for module_name, module in self.base_model.named_modules(): + if self.match_target_modules(module_name): + modules.append((module_name, module)) + return modules + + def set_lora_module(self, module_name, module): + lora_module = get_lora_layer( + module, self.segment_gemm, self.max_lora_dim, self.scaling + ) + replace_submodule(self.base_model, module_name, lora_module) + return lora_module + + def init_loras(self): + # get configs and target modules + self.configs = {} + self.target_modules = set() + for path in self.lora_paths: + self.configs[path] = LoRAConfig(path) + self.target_modules = set(self.target_modules) | set( + self.configs[path].target_modules + ) + self.target_modules = set( + [params_mapping(module) for module in self.target_modules] + ) + + # load 
all weights to cpu + self.loras = [] + self.lora_id = {} + for path in self.lora_paths: + self.lora_id[path] = len(self.loras) + self.loras.append( + LoRAAdapter( + path, self.configs[path], self.base_hf_config, self.load_config + ) + ) + self.loras[-1].initialize_weights() + + # misc lora configs + self.max_lora_dim = max([x.hf_config["r"] for x in self.configs.values()]) + self.scaling = self.loras[0].scaling + # FIXME remove the restrictions + assert all(x.hf_config["r"] == self.max_lora_dim for x in self.configs.values()) + assert all(x.scaling == self.scaling for x in self.loras) + + # monkey patch to use Lora version + modules = self.get_target_modules() + self.lora_modules = [] + for module_name, module in modules: + self.lora_modules.append( + (module_name, self.set_lora_module(module_name, module)) + ) + + def init_lora_memory_pool(self): + # preallocate lora memory pool + self.A_buffer = {} + self.B_buffer = {} + num_layer = self.base_hf_config.num_hidden_layers + for module in self.target_modules: + c = self.loras[-1].get_stacked_multiply(module) + hidden_dim_A, hidden_dim_B = self.base_model.get_hidden_dim(module) + # init A tensor, column_major=True + self.A_buffer[module] = [ + torch.empty( + ( + self.max_loras_per_batch, + self.max_lora_dim * c, + hidden_dim_A, + ), + dtype=self.dtype, + device="cuda", + ) + for i in range(num_layer) + ] + # init B tensor, column_major=True + self.B_buffer[module] = [ + torch.empty( + ( + self.max_loras_per_batch, + hidden_dim_B, + self.max_lora_dim * c, + ), + dtype=self.dtype, + device="cuda", + ) + for i in range(num_layer) + ] + + def init_lora_batch(self): + self.active_uids = [None] * self.max_loras_per_batch # list of active loras + self.buffer_id = {} # lora uid -> idx in memory pool + + def get_target_module_name(self, module_name): + for module in self.target_modules: + if module in module_name: + return module + + def load_lora(self, uid, buffer_id): + num_layer = self.base_hf_config.num_hidden_layers + for i in range(num_layer): + layer_weights = self.loras[self.lora_id[uid]].layers[i].weights + for module_name, weights in layer_weights.items(): + target_module_name = self.get_target_module_name(module_name) + if "lora_A" in module_name: + self.A_buffer[target_module_name][i][buffer_id].copy_(weights) + else: + assert "lora_B" in module_name + self.B_buffer[target_module_name][i][buffer_id].copy_(weights) + + def prepare_lora_batch( + self, batch, forward_mode: ForwardMode, extend_seq_lens=None + ): + # load active loras into lora memory pool + cur_uids = set([req.lora_path for req in batch.reqs]) + assert len(cur_uids) <= self.max_loras_per_batch + i = 0 + for uid in cur_uids: + if uid not in self.active_uids: + while self.active_uids[i] in cur_uids: + i += 1 + self.load_lora(uid, i) + if self.active_uids[i] is not None: + self.buffer_id.pop(self.active_uids[i]) + self.active_uids[i] = uid + self.buffer_id[uid] = i + + if None in cur_uids: + return + + # setup lora in forward modules + bs = len(batch.reqs) + if forward_mode == ForwardMode.EXTEND: + seg_lens = extend_seq_lens + else: + seg_lens = torch.ones(bs) + weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda") + for i, req in enumerate(batch.reqs): + weight_indices[i] = self.buffer_id[req.lora_path] + + for module_name, module in self.lora_modules: + target_model_name = self.get_target_module_name(module_name) + match = re.search(r"layers\.(\d+)\.", module_name) + layer_id = int(match.group(1)) + module.set_lora_info( + 
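# hand the patched module its per-layer pool slices plus this batch's segment layout +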
self.A_buffer[target_model_name][layer_id], + self.B_buffer[target_model_name][layer_id], + bs, + seg_lens, + weight_indices, + ) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index f5279eb8d00..1fc518fe7b1 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -55,6 +55,9 @@ class GenerateReqInput: is_single: bool = True + # LoRA related + lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None + def post_init(self): if (self.text is None and self.input_ids is None) or ( self.text is not None and self.input_ids is not None @@ -184,6 +187,9 @@ class TokenizedGenerateReqInput: # Modalities of the input images modalites: Optional[List[str]] = None + # LoRA related + lora_path: Optional[str] + @dataclass class EmbeddingReqInput: diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index b6000734a26..8a9ecd8df5f 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -98,7 +98,7 @@ def __str__(self) -> str: class Req: """Store all inforamtion of a request.""" - def __init__(self, rid, origin_input_text, origin_input_ids): + def __init__(self, rid, origin_input_text, origin_input_ids, lora_path=None): # Input and output info self.rid = rid self.origin_input_text = origin_input_text @@ -106,6 +106,7 @@ def __init__(self, rid, origin_input_text, origin_input_ids): self.origin_input_ids = origin_input_ids self.output_ids = [] # Each decode stage's output ids self.fill_ids = None # fill_ids = origin_input_ids + output_ids + self.lora_path = lora_path # Memory info self.req_pool_idx = None diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index d2fa6760129..54b08b337e5 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -266,6 +266,11 @@ async def _handle_single_request( top_logprobs_num, obj.stream, modalities, + ( + obj.lora_path[index] + if isinstance(obj.lora_path, list) + else obj.lora_path + ), ) else: # is embedding tokenized_obj = TokenizedEmbeddingReqInput( @@ -364,6 +369,11 @@ async def _handle_batch_request( obj.top_logprobs_num[index], obj.stream, modalities, + ( + obj.lora_path[index] + if isinstance(obj.lora_path, list) + else obj.lora_path + ), ) else: tokenized_obj = TokenizedEmbeddingReqInput( diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index b1131b011fe..5c4b56c937d 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -323,7 +323,15 @@ def handle_generate_request( self, recv_req: TokenizedGenerateReqInput, ): - req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids) + if isinstance(recv_req, TokenizedGenerateReqInput): + req = Req( + recv_req.rid, + recv_req.input_text, + recv_req.input_ids, + recv_req.lora_path, + ) + else: + req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids) req.tokenizer = self.tokenizer req.sampling_params = recv_req.sampling_params req.pixel_values = recv_req.pixel_values diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index b04b0d7c019..97275cc68a8 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -47,6 +47,8 @@ from sglang.srt.configs.model_config import 
AttentionArch, ModelConfig from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.sampler import SampleOutput +from sglang.srt.lora.lora_config import LoRAConfig +from sglang.srt.lora.lora_manager import LoRAManager from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict from sglang.srt.mem_cache.memory_pool import ( MHATokenToKVPool, @@ -78,6 +80,7 @@ def __init__( tp_size: int, nccl_port: int, server_args: ServerArgs, + lora_config: Optional[LoRAConfig] = None, ): # Parse args self.model_config = model_config @@ -109,6 +112,8 @@ def __init__( min_per_gpu_memory = self.init_torch_distributed() self.load_model() + if server_args.lora_paths is not None: + self.init_lora_manager() self.init_memory_pool( min_per_gpu_memory, server_args.max_running_requests, @@ -314,6 +319,17 @@ def model_load_weights(model, iter): logger.info("Update weights end.") return True, "Succeeded to update model weights" + def init_lora_manager(self): + self.lora_manager = LoRAManager( + base_model=self.model, + lora_paths=self.server_args.lora_paths, + base_hf_config=self.model_config.hf_config, + max_loras_per_batch=self.server_args.max_loras_per_batch, + load_config=self.load_config, + dtype=self.dtype, + ) + logger.info("LoRA manager ready.") + def profile_max_num_token(self, total_gpu_memory: int): available_gpu_memory = get_available_gpu_memory( self.gpu_id, distributed=self.tp_size > 1 @@ -525,6 +541,8 @@ def init_cuda_graphs(self): @torch.inference_mode() def forward_decode(self, batch: ScheduleBatch): + if self.server_args.lora_paths is not None: + self.lora_manager.prepare_lora_batch(batch, ForwardMode.DECODE) if ( self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)) @@ -541,6 +559,11 @@ def forward_decode(self, batch: ScheduleBatch): @torch.inference_mode() def forward_extend(self, batch: ScheduleBatch): input_metadata = InputMetadata.from_schedule_batch(self, batch) + if self.server_args.lora_paths is not None: + self.lora_manager.prepare_lora_batch( + batch, ForwardMode.EXTEND, input_metadata.extend_seq_lens + ) + if self.is_generation: return self.model.forward( batch.input_ids, input_metadata.positions, input_metadata diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index ac53712fca4..8b72d7accb7 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -324,6 +324,37 @@ def forward( sample_output = self.sampler(logits_output, input_metadata.sampling_info) return sample_output, logits_output + def get_hidden_dim(self, module_name): + if module_name in ["qkv_proj", "o_proj"]: + return self.config.hidden_size, self.config.hidden_size + elif module_name == "gate_up_proj": + return self.config.hidden_size, self.config.intermediate_size + elif module_name == "down_proj": + return self.config.intermediate_size, self.config.hidden_size + else: + raise NotImplementedError() + + def get_module_name(self, name): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id, num_shard) + ("qkv_proj", "q_proj", "q", 3), + ("qkv_proj", "k_proj", "k", 3), + ("qkv_proj", "v_proj", "v", 3), + ("gate_up_proj", "gate_proj", 0, 2), + ("gate_up_proj", "up_proj", 1, 2), + ] + for param_name, weight_name, shard_id, num_shard in stacked_params_mapping: + if weight_name in name: + return ( + name.replace(weight_name, param_name)[: -len(".weight")], + num_shard, + ) + return name[: -len(".weight")], 1 + + def get_num_params(self): + params_dict = 
dict(self.named_parameters()) + return len(params_dict) + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 4aaf018a1bb..6c8e4e7d1b2 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -605,18 +605,20 @@ async def async_generate( def generate( self, - prompt: Union[str, List[str]], + prompts: Union[str, List[str]], sampling_params: Optional[Dict] = None, return_logprob: Optional[Union[List[bool], bool]] = False, logprob_start_len: Optional[Union[List[int], int]] = None, top_logprobs_num: Optional[Union[List[int], int]] = None, + lora_paths: Optional[List[Optional[str]]] = None, ): json_data = { - "text": prompt, + "text": prompts, "sampling_params": sampling_params, "return_logprob": return_logprob, "logprob_start_len": logprob_start_len, "top_logprobs_num": top_logprobs_num, + "lora_path": lora_paths, } response = requests.post( self.url + "/generate", @@ -626,10 +628,10 @@ def generate( def encode( self, - prompt: Union[str, List[str]], + prompts: Union[str, List[str]], ): json_data = { - "text": prompt, + "text": prompts, } response = requests.post( self.url + "/encode", diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 0881344c086..59a44fe438b 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -101,6 +101,10 @@ class ServerArgs: enable_mla: bool = False triton_attention_reduce_in_fp32: bool = False + # LoRA + lora_paths: Optional[List[str]] = None + max_loras_per_batch: Optional[int] = 8 + def __post_init__(self): # Set missing default values if self.tokenizer_path is None: @@ -509,6 +513,21 @@ def add_cli_args(parser: argparse.ArgumentParser): help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).", ) + # LoRA options + parser.add_argument( + "--lora-paths", + type=str, + nargs="*", + default=None, + help="The list of LoRA adapters.", + ) + parser.add_argument( + "--max-loras-per-batch", + type=int, + default=8, + help="Maximum number of adapters for a running batch", + ) + @classmethod def from_cli_args(cls, args: argparse.Namespace): args.tp_size = args.tensor_parallel_size diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 66a5679d756..125bb556f52 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -35,6 +35,7 @@ import torch.distributed as dist from fastapi.responses import JSONResponse from packaging import version as pkg_version +from torch import nn from torch.nn.parameter import Parameter from triton.runtime.cache import ( FileCacheManager, @@ -714,3 +715,14 @@ def configure_logger(server_args, prefix: str = ""): datefmt="%H:%M:%S", force=True, ) + + +# source: https://github.com/vllm-project/vllm/blob/93b38bea5dd03e1b140ca997dfaadef86f8f1855/vllm/lora/utils.py#L9 +def replace_submodule( + model: nn.Module, module_name: str, new_module: nn.Module +) -> nn.Module: + """Replace a submodule in a model with a new module.""" + parent = model.get_submodule(".".join(module_name.split(".")[:-1])) + target_name = module_name.split(".")[-1] + setattr(parent, target_name, new_module) + return new_module diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index 1d18d305fcb..9c27452f7e7 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -21,6 +21,7 @@ import torch import 
torch.nn.functional as F +from peft import PeftModel from transformers import AutoModelForCausalLM, AutoTokenizer from sglang.srt.server import Runtime @@ -95,7 +96,7 @@ def start_model_process(self, in_queue, out_queue, model_path, torch_dtype): ) if self.is_generation: - self.model = AutoModelForCausalLM.from_pretrained( + self.base_model = AutoModelForCausalLM.from_pretrained( model_path, torch_dtype=torch_dtype, trust_remote_code=False, @@ -110,13 +111,16 @@ def start_model_process(self, in_queue, out_queue, model_path, torch_dtype): ) while True: - prompts, max_new_tokens = in_queue.get() + prompts, max_new_tokens, lora_paths = in_queue.get() + if lora_paths is not None: + assert len(prompts) == len(lora_paths) + if prompts is not None: if self.is_generation: output_strs = [] top_input_logprobs = [] top_output_logprobs = [] - for p in prompts: + for i, p in enumerate(prompts): if isinstance(p, str): input_ids = self.tokenizer.encode( p, return_tensors="pt" @@ -124,6 +128,17 @@ def start_model_process(self, in_queue, out_queue, model_path, torch_dtype): else: input_ids = torch.tensor([p], device="cuda") + if lora_paths is not None: + self.model = PeftModel.from_pretrained( + self.base_model, + lora_paths[i], + torch_dtype=torch_dtype, + is_trainable=False, + ) + else: + self.model = self.base_model + + outputs = self.model.generate( input_ids, do_sample=False, @@ -167,8 +182,9 @@ def forward( self, prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS, max_new_tokens=8, + lora_paths=None, ): - self.in_queue.put((prompts, max_new_tokens)) + self.in_queue.put((prompts, max_new_tokens, lora_paths)) return self.out_queue.get() def terminate(self): @@ -191,6 +207,7 @@ def __init__( is_generation, tp_size=1, port=DEFAULT_PORT_FOR_SRT_TEST_RUNNER, + lora_paths=None, ): self.is_generation = is_generation self.runtime = Runtime( @@ -201,12 +218,15 @@ def __init__( mem_fraction_static=0.69, trust_remote_code=False, is_embedding=not self.is_generation, + lora_paths=lora_paths, + disable_cuda_graph=True, ) def forward( self, prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS, max_new_tokens=8, + lora_paths=None, ): if self.is_generation: # the return value contains logprobs from prefill @@ -214,9 +234,10 @@ def forward( top_input_logprobs = [] top_output_logprobs = [] sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0} - for prompt in prompts: + for i, prompt in enumerate(prompts): response = self.runtime.generate( prompt, + lora_paths=lora_paths[i] if lora_paths else None, sampling_params=sampling_params, return_logprob=True, logprob_start_len=0, diff --git a/test/srt/models/test_generation_models.py b/test/srt/models/test_generation_models.py index 46854b3e869..341b856e317 100644 --- a/test/srt/models/test_generation_models.py +++ b/test/srt/models/test_generation_models.py @@ -76,6 +76,7 @@ def assert_close_prefill_logits_and_output_strs( ) -> None: if model_path == "Alibaba-NLP/gte-Qwen2-1.5B-instruct": prompts = prompts[:-1] + with HFRunner( model_path, torch_dtype=torch_dtype, is_generation=True ) as hf_runner: diff --git a/test/srt/models/test_lora.py b/test/srt/models/test_lora.py new file mode 100644 index 00000000000..2bfda72285a --- /dev/null +++ b/test/srt/models/test_lora.py @@ -0,0 +1,205 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import unittest +import uuid + +import torch +from vllm.config import LoadConfig + +from sglang.srt.lora.lora import LoRAAdapter +from sglang.srt.lora.lora_config import LoRAConfig +from sglang.srt.model_config import ModelConfig +from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner + +LORA_SETS = [ + # { + # "base": "meta-llama/Llama-2-7b-hf", + # "loras": ["RuterNorway/Llama-2-7b-chat-norwegian-LoRa"], + # } + {"base": "meta-llama/Llama-2-7b-hf", "loras": ["winddude/wizardLM-LlaMA-LoRA-7B"]} + # {"base": "meta-llama/Llama-2-7b-hf", "loras": ["yard1/llama-2-7b-sql-lora-test"]} +] +TORCH_DTYPES = [torch.float16] + +PROMPTS = [ + """ +### Instruction: +Write a poem about the transformers Python library. +Mention the word "large language models" in that poem. +### Response: +The Transformers are large language models, +They're used to make predictions on text. +""", + """ +### Instruction: +Tell me about llamas and alpacas +### Response: +Llamas are large, long-necked animals with a woolly coat. They have two toes on each foot instead of three like other camelids (camels, dromedaries). Llamas live in the Andean mountains of South America where they graze on grasses and shrubs. Alpaca is another name for domesticated llama. The word "alpaca" comes from an Incan language meaning "golden fleece." Alpacas look very similar to llamas but are smaller than their wild relatives. Both species were used by ancient people as pack animals and for meat. Today both llamas and alpacas are raised primarily for their fiber which can be spun into yarn or knitted into clothing. +### Question 2: +What do you know about llamas? 
+### Answer: +""", +] + + +class TestLoRA(unittest.TestCase): + + def load_lora_adapter(self, lora_set, tp_size): + base_path = lora_set["base"] + lora_path = lora_set["loras"][0] + + base_config = ModelConfig(base_path) + lora_config = LoRAConfig(lora_path) + + uid = uuid.uuid4().hex + lora_adapter = LoRAAdapter( + uid, lora_config, base_config, LoadConfig(load_format="auto") + ) + lora_adapter.initialize_weights() + print(lora_adapter) + + def inference(self, prompts, lora_set, tp_size, torch_dtype, max_new_tokens): + base_path = lora_set["base"] + all_lora_paths = lora_set["loras"] + batch_lora_paths = [] + i = 0 + for _ in range(len(prompts)): + batch_lora_paths.append(all_lora_paths[i]) + i = (i + 1) % len(all_lora_paths) + + with HFRunner( + base_path, + torch_dtype=torch_dtype, + is_generation=True, + ) as hf_runner: + hf_outputs = hf_runner.forward( + prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths + ) + + with SRTRunner( + base_path, + tp_size=tp_size, + torch_dtype=torch_dtype, + is_generation=True, + lora_paths=all_lora_paths, + ) as srt_runner: + srt_outputs = srt_runner.forward( + prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths + ) + + with HFRunner( + base_path, + torch_dtype=torch_dtype, + is_generation=True, + ) as hf_runner: + hf_no_lora_outputs = hf_runner.forward( + prompts, max_new_tokens=max_new_tokens + ) + + with SRTRunner( + base_path, + tp_size=tp_size, + torch_dtype=torch_dtype, + is_generation=True, + ) as srt_runner: + srt_no_lora_outputs = srt_runner.forward( + prompts, max_new_tokens=max_new_tokens + ) + + for i in range(len(prompts)): + hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i]) + srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i]) + hf_no_lora_logprobs = torch.Tensor(hf_no_lora_outputs.top_input_logprobs[i]) + srt_no_lora_logprobs = torch.Tensor( + srt_no_lora_outputs.top_input_logprobs[i] + ) + print( + "max_diff between hf_lora and srt_lora", + torch.max(abs(hf_logprobs - srt_logprobs)), + ) + print( + "max_diff between srt_base and srt_lora", + torch.max(abs(srt_no_lora_logprobs - srt_logprobs)), + ) + print( + "max_diff between srt_base and hf_base", + torch.max(abs(srt_no_lora_logprobs - hf_no_lora_logprobs)), + ) + print( + "max_diff between hf_lora and hf_base", + torch.max(abs(hf_logprobs - hf_no_lora_logprobs)), + ) + + print(f"{hf_outputs.output_strs=}") + print(f"{srt_outputs.output_strs=}") + print(f"{hf_no_lora_outputs.output_strs=}") + print(f"{srt_no_lora_outputs.output_strs=}") + assert hf_outputs.output_strs == srt_outputs.output_strs + assert hf_no_lora_outputs.output_strs == srt_no_lora_outputs.output_strs + + def base_inference(self, prompts, lora_set, tp_size, torch_dtype, max_new_tokens): + base_path = lora_set["base"] + all_lora_paths = lora_set["loras"] + batch_lora_paths = [None] * len(prompts) + + with SRTRunner( + base_path, + tp_size=tp_size, + torch_dtype=torch_dtype, + is_generation=True, + ) as srt_runner: + srt_no_lora_outputs = srt_runner.forward( + prompts, max_new_tokens=max_new_tokens + ) + + with SRTRunner( + base_path, + tp_size=tp_size, + torch_dtype=torch_dtype, + is_generation=True, + lora_paths=all_lora_paths, + ) as srt_runner: + srt_outputs = srt_runner.forward( + prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths + ) + + for i in range(len(prompts)): + srt_no_lora_logprobs = torch.Tensor( + srt_no_lora_outputs.top_input_logprobs[i] + ) + srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i]) + print("max_diff", 
torch.max(abs(srt_no_lora_logprobs - srt_logprobs))) + + print(f"{srt_no_lora_outputs.output_strs=}") + print(f"{srt_outputs.output_strs=}") + + # batch_lora_paths is all None here, so the LoRA-enabled run should match the base model. + assert srt_outputs.output_strs == srt_no_lora_outputs.output_strs + + def test_all(self): + for lora_set in LORA_SETS: + # self.load_lora_adapter(lora_set, 1) + for torch_dtype in TORCH_DTYPES: + tp_size = 1 + max_new_tokens = 64 + self.inference(PROMPTS, lora_set, tp_size, torch_dtype, max_new_tokens) + # self.base_inference( + # PROMPTS, lora_set, tp_size, torch_dtype, max_new_tokens + # ) + + +if __name__ == "__main__": + unittest.main(warnings="ignore")
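Usage sketch (not part of the diff): the pieces added above can be exercised roughly as follows. Only `--lora-paths`, `--max-loras-per-batch`, and the per-request `lora_path` field come from this change; the launch command, default host/port, and exact payload shape are assumptions.

    # Launch the server with one or more adapters preloaded (local paths or HF repo ids):
    #   python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-hf \
    #       --lora-paths winddude/wizardLM-LlaMA-LoRA-7B --max-loras-per-batch 8
    import requests

    response = requests.post(
        "http://127.0.0.1:30000/generate",  # default host/port assumed
        json={
            "text": "### Instruction:\nTell me about llamas\n### Response:\n",
            "sampling_params": {"max_new_tokens": 64, "temperature": 0},
            # Must match an entry passed to --lora-paths; use None (or omit) to run the plain base model.
            "lora_path": "winddude/wizardLM-LlaMA-LoRA-7B",
        },
    )
    print(response.json())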