diff --git a/tests/entrypoints/openai/test_lora_adapters.py b/tests/entrypoints/openai/test_lora_adapters.py
index 46a064f6d9e68..6ff99f6faa143 100644
--- a/tests/entrypoints/openai/test_lora_adapters.py
+++ b/tests/entrypoints/openai/test_lora_adapters.py
@@ -17,6 +17,33 @@
 # generation quality here
 LORA_NAME = "typeof/zephyr-7b-beta-lora"
 
+BADREQUEST_CASES = [
+    (
+        "test_rank",
+        {
+            "r": 1024
+        },
+        "is greater than max_lora_rank",
+    ),
+    (
+        "test_bias",
+        {
+            "bias": "all"
+        },
+        "Adapter bias cannot be used without bias_enabled",
+    ),
+    ("test_dora", {
+        "use_dora": True
+    }, "does not yet support DoRA"),
+    (
+        "test_modules_to_save",
+        {
+            "modules_to_save": ["lm_head"]
+        },
+        "only supports modules_to_save being None",
+    ),
+]
+
 
 @pytest.fixture(scope="module")
 def zephyr_lora_files():
@@ -138,32 +165,36 @@ async def test_dynamic_lora_invalid_files(client: openai.AsyncOpenAI,
 
 
 @pytest.mark.asyncio
-async def test_dynamic_lora_invalid_lora_rank(client: openai.AsyncOpenAI,
-                                              tmp_path, zephyr_lora_files):
-    invalid_rank = tmp_path / "invalid_rank"
-
-    # Copy adapter from zephyr_lora_files to invalid_rank
-    shutil.copytree(zephyr_lora_files, invalid_rank)
-
-    with open(invalid_rank / "adapter_config.json") as f:
+@pytest.mark.parametrize("test_name,config_change,expected_error",
+                         BADREQUEST_CASES)
+async def test_dynamic_lora_badrequests(client: openai.AsyncOpenAI, tmp_path,
+                                        zephyr_lora_files, test_name: str,
+                                        config_change: dict,
+                                        expected_error: str):
+    # Create test directory
+    test_dir = tmp_path / test_name
+
+    # Copy adapter files
+    shutil.copytree(zephyr_lora_files, test_dir)
+
+    # Load and modify configuration
+    config_path = test_dir / "adapter_config.json"
+    with open(config_path) as f:
         adapter_config = json.load(f)
+    # Apply configuration changes
+    adapter_config.update(config_change)
 
-    print(adapter_config)
-
-    # assert False
-
-    # Change rank to invalid value
-    adapter_config["r"] = 1024
-    with open(invalid_rank / "adapter_config.json", "w") as f:
+    # Save modified configuration
+    with open(config_path, "w") as f:
         json.dump(adapter_config, f)
 
-    with pytest.raises(openai.BadRequestError,
-                       match="is greater than max_lora_rank"):
+    # Test loading the adapter
+    with pytest.raises(openai.BadRequestError, match=expected_error):
         await client.post("load_lora_adapter",
                           cast_to=str,
                           body={
-                              "lora_name": "invalid-json",
-                              "lora_path": str(invalid_rank)
+                              "lora_name": test_name,
+                              "lora_path": str(test_dir)
                           })
 
 
diff --git a/tests/lora/test_lora_checkpoints.py b/tests/lora/test_lora_checkpoints.py
index 537d95b025a9d..b907af47d08d7 100644
--- a/tests/lora/test_lora_checkpoints.py
+++ b/tests/lora/test_lora_checkpoints.py
@@ -3,6 +3,7 @@
 import pytest
 
 from vllm.lora.models import LoRAModel
+from vllm.lora.peft_helper import PEFTHelper
 from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM
 from vllm.model_executor.models.utils import WeightsMapper
 
@@ -30,11 +31,14 @@ def test_load_checkpoints(
         else:
             expected_lora_modules.append(module)
     if lora_name == "baichuan7B":
+        peft_helper = PEFTHelper.from_local_dir(baichuan_lora_files,
+                                                max_position_embeddings=4096)
         # For the baichuan7B model, load it's LoRA,
         # and the test should pass.
         LoRAModel.from_local_checkpoint(
             baichuan_lora_files,
             expected_lora_modules,
+            peft_helper=peft_helper,
             lora_model_id=1,
             device="cpu",
             embedding_modules=embedding_modules,
@@ -43,9 +47,12 @@ def test_load_checkpoints(
         # Test that the target_modules contain prefix
         # such as "model.layers.0.self_atten.W_pack", and
         # the test should pass.
+        peft_helper = PEFTHelper.from_local_dir(baichuan_zero_lora_files,
+                                                max_position_embeddings=4096)
         LoRAModel.from_local_checkpoint(
             baichuan_zero_lora_files,
             expected_lora_modules,
+            peft_helper=peft_helper,
             lora_model_id=1,
             device="cpu",
             embedding_modules=embedding_modules,
@@ -53,9 +60,12 @@ def test_load_checkpoints(
     elif lora_name == "baichuan7B-zero-regex":
         # Test that the `target_modules` in the form of regular expressions,
         # such as `model\\..*(W_pack|o_proj)`, and the test should pass.
+        peft_helper = PEFTHelper.from_local_dir(baichuan_regex_lora_files,
+                                                max_position_embeddings=4096)
         LoRAModel.from_local_checkpoint(
             baichuan_regex_lora_files,
             expected_lora_modules,
+            peft_helper=peft_helper,
             lora_model_id=1,
             device="cpu",
             embedding_modules=embedding_modules,
@@ -64,10 +74,13 @@ def test_load_checkpoints(
         # For the baichuan7B model, load chatglm3-6b's LoRA,
         # and the test should raise the following error.
         expected_error = "Please verify that the loaded LoRA module is correct"  # noqa: E501
+        peft_helper = PEFTHelper.from_local_dir(chatglm3_lora_files,
+                                                max_position_embeddings=4096)
         with pytest.raises(ValueError, match=expected_error):
             LoRAModel.from_local_checkpoint(
                 chatglm3_lora_files,
                 expected_lora_modules,
+                peft_helper=peft_helper,
                 lora_model_id=1,
                 device="cpu",
                 embedding_modules=embedding_modules,
@@ -94,9 +107,12 @@ def test_lora_weights_mapping(baichuan_lora_files):
             ".layers.": ".baichuan_layers.",
         },
     )
+    peft_helper = PEFTHelper.from_local_dir(baichuan_lora_files,
+                                            max_position_embeddings=4096)
     lora_model = LoRAModel.from_local_checkpoint(
         baichuan_lora_files,
         expected_lora_modules,
+        peft_helper=peft_helper,
         lora_model_id=1,
         device="cpu",
         embedding_modules=embedding_modules,
diff --git a/tests/lora/test_lora_huggingface.py b/tests/lora/test_lora_huggingface.py
index e2daf9d135113..1c0ee01c038d0 100644
--- a/tests/lora/test_lora_huggingface.py
+++ b/tests/lora/test_lora_huggingface.py
@@ -3,6 +3,7 @@
 import pytest
 
 from vllm.lora.models import LoRAModel
+from vllm.lora.peft_helper import PEFTHelper
 from vllm.lora.utils import get_adapter_absolute_path
 from vllm.model_executor.models.llama import LlamaForCausalLM
 
@@ -27,9 +28,11 @@ def test_load_checkpoints_from_huggingface(lora_fixture_name, request):
     lora_path = get_adapter_absolute_path(lora_name)
 
     # lora loading should work for either absolute path and hugggingface id.
+    peft_helper = PEFTHelper.from_local_dir(lora_path, 4096)
     lora_model = LoRAModel.from_local_checkpoint(
         lora_path,
         expected_lora_modules,
+        peft_helper=peft_helper,
         lora_model_id=1,
         device="cpu",
         embedding_modules=embedding_modules,
diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py
index ca523c66abe42..9a5b9aabf5078 100644
--- a/tests/lora/test_lora_manager.py
+++ b/tests/lora/test_lora_manager.py
@@ -1,5 +1,3 @@
-import json
-import math
 import os
 from typing import Dict, List
 
@@ -34,56 +32,6 @@
 ] if current_platform.is_cuda_alike() else ["cpu"])
 
 
-def test_peft_helper(sql_lora_files):
-    lora_config_path = os.path.join(sql_lora_files, "adapter_config.json")
-    with open(lora_config_path) as f:
-        config = json.load(f)
-    peft_helper = PEFTHelper.from_dict(config)
-    assert peft_helper.r == 8
-    assert peft_helper.lora_alpha == 16
-    assert peft_helper.target_modules == [
-        "q_proj",
-        "v_proj",
-        "k_proj",
-        "o_proj",
-        "gate_proj",
-        "up_proj",
-        "down_proj",
-        "embed_tokens",
-        "lm_head",
-    ]
-    scaling = peft_helper.lora_alpha / peft_helper.r
-    assert abs(peft_helper.vllm_lora_scaling_factor - scaling) < 1e-3
-
-    # test RSLoRA
-    config = dict(r=8,
-                  lora_alpha=16,
-                  target_modules=["gate_proj"],
-                  use_rslora=True)
-    peft_helper = PEFTHelper.from_dict(config)
-
-    scaling = peft_helper.lora_alpha / math.sqrt(peft_helper.r)
-    assert abs(peft_helper.vllm_lora_scaling_factor - scaling) < 1e-3
-
-    expected_error = "vLLM only supports modules_to_save being None."
-    with pytest.raises(ValueError, match=expected_error):
-        config = dict(
-            r=8,
-            lora_alpha=16,
-            target_modules=["gate_proj"],
-            modules_to_save=["lm_head"],
-        )
-        PEFTHelper.from_dict(config)
-
-    expected_error = "vLLM does not yet support DoRA."
-    with pytest.raises(ValueError, match=expected_error):
-        config = dict(r=8,
-                      lora_alpha=16,
-                      target_modules=["gate_proj"],
-                      use_dora=True)
-        PEFTHelper.from_dict(config)
-
-
 @pytest.mark.parametrize("device", DEVICES)
 def test_from_lora_tensors(sql_lora_files, device):
     tensors = load_file(
@@ -91,11 +39,8 @@ def test_from_lora_tensors(sql_lora_files, device):
     new_embeddings = load_file(
         os.path.join(sql_lora_files, "new_embeddings.safetensors"))
 
-    lora_config_path = os.path.join(sql_lora_files, "adapter_config.json")
-    with open(lora_config_path) as f:
-        config = json.load(f)
-
-    peft_helper = PEFTHelper.from_dict(config)
+    peft_helper = PEFTHelper.from_local_dir(sql_lora_files,
+                                            max_position_embeddings=4096)
     lora_model = LoRAModel.from_lora_tensors(
         1,
         tensors,
diff --git a/tests/lora/test_peft_helper.py b/tests/lora/test_peft_helper.py
new file mode 100644
index 0000000000000..a524d5ce5f34a
--- /dev/null
+++ b/tests/lora/test_peft_helper.py
@@ -0,0 +1,109 @@
+import json
+import math
+import shutil
+
+import pytest
+
+from vllm.config import LoRAConfig
+from vllm.lora.peft_helper import PEFTHelper
+
+ERROR_CASES = [
+    (
+        "test_rank",
+        {
+            "r": 1024
+        },
+        "is greater than max_lora_rank",
+    ),
+    (
+        "test_bias",
+        {
+            "bias": "all"
+        },
+        "Adapter bias cannot be used without bias_enabled",
+    ),
+    ("test_dora", {
+        "use_dora": True
+    }, "does not yet support DoRA"),
+    (
+        "test_modules_to_save",
+        {
+            "modules_to_save": ["lm_head"]
+        },
+        "only supports modules_to_save being None",
+    ),
+]
+
+
+def test_peft_helper_pass(long_context_lora_files_16k_1, tmp_path):
+    peft_helper = PEFTHelper.from_local_dir(long_context_lora_files_16k_1,
+                                            max_position_embeddings=4096)
+    lora_config = LoRAConfig(max_lora_rank=16, max_cpu_loras=3, max_loras=2)
+    peft_helper.validate_legal(lora_config)
+    assert peft_helper.r == 8
+    assert peft_helper.lora_alpha == 16
+    assert peft_helper.target_modules == [
+        "q_proj",
+        "v_proj",
+        "k_proj",
+        "o_proj",
+        "gate_proj",
+        "up_proj",
+        "down_proj",
+        "embed_tokens",
+        "lm_head",
+    ]
+    assert peft_helper.context_length == 16384
+    assert peft_helper.vllm_max_position_embeddings == 4096
+    assert peft_helper.vllm_long_context_scaling_factor == float(
+        math.ceil(peft_helper.context_length /
+                  peft_helper.vllm_max_position_embeddings))
+    # test RSLoRA
+    rslora_config = dict(use_rslora=True)
+    test_dir = tmp_path / "test_rslora"
+    shutil.copytree(long_context_lora_files_16k_1, test_dir)
+
+    # Load and modify configuration
+    config_path = test_dir / "adapter_config.json"
+    with open(config_path) as f:
+        adapter_config = json.load(f)
+    # Apply configuration changes
+    adapter_config.update(rslora_config)
+
+    # Save modified configuration
+    with open(config_path, "w") as f:
+        json.dump(adapter_config, f)
+
+    peft_helper = PEFTHelper.from_local_dir(test_dir,
+                                            max_position_embeddings=4096)
+    peft_helper.validate_legal(lora_config)
+    scaling = peft_helper.lora_alpha / math.sqrt(peft_helper.r)
+    assert abs(peft_helper.vllm_lora_scaling_factor - scaling) < 1e-3
+
+
+@pytest.mark.parametrize("test_name,config_change,expected_error", ERROR_CASES)
+def test_peft_helper_error(
+    sql_lora_files,
+    tmp_path,
+    test_name: str,
+    config_change: dict,
+    expected_error: str,
+):
+    test_dir = tmp_path / test_name
+    shutil.copytree(sql_lora_files, test_dir)
+
+    # Load and modify configuration
+    config_path = test_dir / "adapter_config.json"
+    with open(config_path) as f:
+        adapter_config = json.load(f)
+    # Apply configuration changes
+    adapter_config.update(config_change)
+
+    # Save modified configuration
+    with open(config_path, "w") as f:
+        json.dump(adapter_config, f)
+    lora_config = LoRAConfig(max_lora_rank=16, max_cpu_loras=3, max_loras=2)
+    # Test loading the adapter
+    with pytest.raises(ValueError, match=expected_error):
+        PEFTHelper.from_local_dir(
+            test_dir, max_position_embeddings=4096).validate_legal(lora_config)
diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py
index 8f231de912c95..3aa9d30549f36 100644
--- a/vllm/engine/multiprocessing/engine.py
+++ b/vllm/engine/multiprocessing/engine.py
@@ -296,6 +296,7 @@ def _handle_load_adapter_request(self, request: RPCLoadAdapterRequest):
                 is_engine_errored=False,
                 exception=e)
             self._send_outputs(rpc_err)
+            return
         # Otherwise, send back the successful load message
         self._send_outputs(
             RPCAdapterLoadedResponse(request_id=request.request_id))
diff --git a/vllm/entrypoints/openai/serving_models.py b/vllm/entrypoints/openai/serving_models.py
index a222eafadcb68..fc422f0917bd5 100644
--- a/vllm/entrypoints/openai/serving_models.py
+++ b/vllm/entrypoints/openai/serving_models.py
@@ -157,24 +157,16 @@ async def load_lora_adapter(
         # This will also pre-load it for incoming requests
         try:
             await self.engine_client.add_lora(lora_request)
-        except ValueError as e:
-            # Adapter not found or lora configuration errors
-            if "No adapter found" in str(e):
-                return create_error_response(message=str(e),
-                                             err_type="NotFoundError",
-                                             status_code=HTTPStatus.NOT_FOUND)
-            else:
-                return create_error_response(
-                    message=str(e),
-                    err_type="BadRequestError",
-                    status_code=HTTPStatus.BAD_REQUEST)
         except BaseException as e:
-            # Some other unexpected problem loading the adapter, e.g. malformed
-            # input files.
-            # More detailed error messages for the user would be nicer here
+            error_type = "BadRequestError"
+            status_code = HTTPStatus.BAD_REQUEST
+            if isinstance(e, ValueError) and "No adapter found" in str(e):
+                error_type = "NotFoundError"
+                status_code = HTTPStatus.NOT_FOUND
+
             return create_error_response(message=str(e),
-                                         err_type="BadRequestError",
-                                         status_code=HTTPStatus.BAD_REQUEST)
+                                         err_type=error_type,
+                                         status_code=status_code)
 
         self.lora_requests.append(lora_request)
         logger.info("Loaded new LoRA adapter: name '%s', path '%s'", lora_name,
diff --git a/vllm/lora/models.py b/vllm/lora/models.py
index 5b7225bdc8f37..9809405ca9a61 100644
--- a/vllm/lora/models.py
+++ b/vllm/lora/models.py
@@ -1,5 +1,4 @@
 import copy
-import json
 import math
 import os
 import re
@@ -180,8 +179,8 @@ def from_local_checkpoint(
         cls,
         lora_dir: str,
         expected_lora_modules: List[str],
+        peft_helper: PEFTHelper,
         *,
-        max_position_embeddings: Optional[int] = None,
         lora_model_id: Optional[int] = None,
         device: str = "cuda",
         dtype: Optional[torch.dtype] = None,
@@ -196,9 +195,7 @@ def from_local_checkpoint(
             lora_dir: The local path that has lora data.
             expected_lora_modules: Name of modules that are expected to be
                 replaced by lora.
-            max_position_embeddings: Max position embedding length. Used to
-                scaling the largest context length. If None, the lora model's
-                context length is not scaled.
+            peft_helper: Loaded lora configuration information.
             lora_model_id: Lora model id. If not given, automatically set by
                 a global counter.
             device: Device where the lora model is loaded.
@@ -207,18 +204,13 @@ def from_local_checkpoint(
         Returns:
             Loaded LoRA Model.
         """
-        lora_config_path = os.path.join(lora_dir, "adapter_config.json")
         lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
         lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
         new_embeddings_tensor_path = os.path.join(
             lora_dir, "new_embeddings.safetensors")
         new_embeddings_bin_file_path = os.path.join(lora_dir,
                                                     "new_embeddings.bin")
-        with open(lora_config_path) as f:
-            config = json.load(f)
-
-        config["vllm_max_position_embeddings"] = max_position_embeddings
-        peft_helper = PEFTHelper.from_dict(config)
         unexpected_modules: List[Union[list[str], str]]
         if os.path.isfile(lora_tensor_path):
             tensors: Dict[str, torch.Tensor] = {}
diff --git a/vllm/lora/peft_helper.py b/vllm/lora/peft_helper.py
index dacfb9ebd1480..b9c506f6e0bfd 100644
--- a/vllm/lora/peft_helper.py
+++ b/vllm/lora/peft_helper.py
@@ -1,9 +1,12 @@
 # Adapted from: https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/config.py
 
+import json
 import math
+import os
 from dataclasses import MISSING, dataclass, field, fields
-from typing import Literal, Optional, Union
+from typing import List, Literal, Optional, Union
 
+from vllm.config import LoRAConfig
 from vllm.logger import init_logger
 
 logger = init_logger(__name__)
@@ -11,6 +14,12 @@
 
 @dataclass
 class PEFTHelper:
+    """
+    A helper class for PEFT configurations, specifically designed for LoRA.
+    This class handles configuration validation and compatibility checks for
+    various LoRA implementations.
+    """
+
     # Required fields
     r: int
     lora_alpha: int
@@ -29,20 +38,18 @@ class PEFTHelper:
     vllm_max_position_embeddings: Optional[int] = field(default=False)
     vllm_long_context_scaling_factor: Optional[float] = field(default=None)
 
-    def _validate_features(self):
+    def _validate_features(self) -> List[str]:
+        """
+        Check if there are any unsupported LoRA features.
+ """ error_msg = [] - if self.modules_to_save: error_msg.append("vLLM only supports modules_to_save being None.") - if self.use_dora: error_msg.append("vLLM does not yet support DoRA.") - - if error_msg: - raise ValueError(f"{', '.join(error_msg)}") + return error_msg def __post_init__(self): - self._validate_features() if self.use_rslora: logger.info_once("Loading LoRA weights trained with rsLoRA.") self.vllm_lora_scaling_factor = self.lora_alpha / math.sqrt(self.r) @@ -78,3 +85,29 @@ def from_dict(cls, config_dict: dict) -> "PEFTHelper": for k, v in config_dict.items() if k in class_fields } return cls(**filtered_dict) + + @classmethod + def from_local_dir(cls, lora_path: str, + max_position_embeddings: Optional[int]) -> "PEFTHelper": + lora_config_path = os.path.join(lora_path, "adapter_config.json") + + with open(lora_config_path) as f: + config = json.load(f) + config["vllm_max_position_embeddings"] = max_position_embeddings + return cls.from_dict(config) + + def validate_legal(self, lora_config: LoRAConfig) -> None: + """ + Validates the LoRA configuration settings against application + constraints and requirements. + """ + error_msg = self._validate_features() + if self.r > lora_config.max_lora_rank: + error_msg.append( + f"LoRA rank {self.r} is greater than max_lora_rank" + f" {lora_config.max_lora_rank}.") + if self.bias != "none" and not lora_config.bias_enabled: + error_msg.append( + "Adapter bias cannot be used without bias_enabled.") + if error_msg: + raise ValueError(f"{' '.join(error_msg)}") diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index eec462743fe9d..a64296f7fd902 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -12,6 +12,7 @@ from vllm.logger import init_logger from vllm.lora.models import (LoRAModel, LoRAModelManager, LRUCacheLoRAModelManager, create_lora_manager) +from vllm.lora.peft_helper import PEFTHelper from vllm.lora.request import LoRARequest from vllm.lora.utils import get_adapter_absolute_path @@ -95,6 +96,13 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel: expected_lora_modules = list(set(expected_lora_modules)) lora_path = get_adapter_absolute_path(lora_request.lora_path) + peft_helper = PEFTHelper.from_local_dir( + lora_path, self.max_position_embeddings) + + # Validates the LoRA configuration against requirements before + # loading weights, throwing an exception if validation fails. + peft_helper.validate_legal(self.lora_config) + # For some models like Qwen2VL, we need to use hf_to_vllm_mapper # to ensure correct loading of lora weights. 
             hf_to_vllm_mapper = None
@@ -105,7 +113,7 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
             lora = self._lora_model_cls.from_local_checkpoint(
                 lora_path,
                 expected_lora_modules,
-                max_position_embeddings=self.max_position_embeddings,
+                peft_helper=peft_helper,
                 lora_model_id=lora_request.lora_int_id,
                 device="cpu",
                 dtype=self.lora_config.lora_dtype,
@@ -120,15 +128,14 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
             #  - No adapter found to download from huggingface (or in
             #    offline mode)
             #  - No local adapter files found at `lora_request.lora_path`
+            # For NotFoundError
             raise ValueError(
                 f"Loading lora {lora_request.lora_name} failed: No adapter "
                 f"found for {lora_path}") from e
         except Exception as e:
-            raise RuntimeError(f"Loading lora {lora_path} failed") from e
-        if lora.rank > self.lora_config.max_lora_rank:
-            raise ValueError(
-                f"LoRA rank {lora.rank} is greater than max_lora_rank "
-                f"{self.lora_config.max_lora_rank}.")
+            # For BadRequestError
+            raise e
+
         if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size:
             raise ValueError(f"LoRA added vocab size {lora.extra_vocab_size} "
                              f"is greater than lora_extra_vocab_size "
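
Taken together, the change moves adapter-config parsing and validation in front of weight loading: PEFTHelper.from_local_dir reads adapter_config.json once, and validate_legal raises a ValueError before any tensors are touched, which the worker manager then lets propagate to the API layer. Below is a minimal sketch of that flow under assumed inputs; the adapter path and the LoRAConfig limits are placeholders for illustration, not values taken from this diff.

```python
from vllm.config import LoRAConfig
from vllm.lora.peft_helper import PEFTHelper

# Placeholder adapter directory and serving limits, for illustration only.
lora_path = "/path/to/local/adapter"
lora_config = LoRAConfig(max_lora_rank=16, max_loras=2, max_cpu_loras=3)

# Parse adapter_config.json once, up front.
peft_helper = PEFTHelper.from_local_dir(lora_path,
                                        max_position_embeddings=4096)

# Reject unsupported or out-of-range configs (rank too large, bias without
# bias_enabled, DoRA, modules_to_save) before any weights are read;
# raises ValueError with the accumulated messages on failure.
peft_helper.validate_legal(lora_config)

# The validated helper is then handed to the weight loader, e.g.:
#   LoRAModel.from_local_checkpoint(lora_path, expected_lora_modules,
#                                   peft_helper=peft_helper, ...)
```

Keeping parse/validate separate from LoRAModel.from_local_checkpoint is what lets a bad adapter be rejected without reading any safetensors, and it is also what the new tests/lora/test_peft_helper.py exercises directly.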
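On the serving side, the consolidated except BaseException handler maps every load failure to BadRequestError (HTTP 400), unless the message contains "No adapter found", which becomes NotFoundError (HTTP 404). A hedged sketch of the corresponding client-side handling, mirroring the dynamic-load request body used in the tests above; the base URL, API key, adapter name, and path are placeholders.

```python
import asyncio

import openai


async def try_load_adapter(name: str, path: str) -> None:
    # Placeholder endpoint; point this at a running vLLM OpenAI-compatible
    # server started with dynamic LoRA loading enabled.
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")
    try:
        await client.post("load_lora_adapter",
                          cast_to=str,
                          body={"lora_name": name, "lora_path": path})
    except openai.BadRequestError as exc:
        # Config rejected by PEFTHelper.validate_legal (rank, bias, DoRA, ...).
        print(f"adapter rejected: {exc}")
    except openai.NotFoundError as exc:
        # No adapter found at the given path / Hugging Face id.
        print(f"adapter missing: {exc}")


asyncio.run(try_load_adapter("my-adapter", "/path/to/local/adapter"))
```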