[Misc][LoRA] Improve the readability of LoRA error messages (vllm-project#12102)

Signed-off-by: Jee Jee Li <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
jeejeelee authored and Isotr0py committed Feb 2, 2025
1 parent 4ece6d7 commit 68a7e95
Showing 10 changed files with 245 additions and 116 deletions.
69 changes: 50 additions & 19 deletions tests/entrypoints/openai/test_lora_adapters.py
@@ -17,6 +17,33 @@
# generation quality here
LORA_NAME = "typeof/zephyr-7b-beta-lora"

BADREQUEST_CASES = [
(
"test_rank",
{
"r": 1024
},
"is greater than max_lora_rank",
),
(
"test_bias",
{
"bias": "all"
},
"Adapter bias cannot be used without bias_enabled",
),
("test_dora", {
"use_dora": True
}, "does not yet support DoRA"),
(
"test_modules_to_save",
{
"modules_to_save": ["lm_head"]
},
"only supports modules_to_save being None",
),
]


@pytest.fixture(scope="module")
def zephyr_lora_files():
@@ -138,32 +165,36 @@ async def test_dynamic_lora_invalid_files(client: openai.AsyncOpenAI,


@pytest.mark.asyncio
async def test_dynamic_lora_invalid_lora_rank(client: openai.AsyncOpenAI,
tmp_path, zephyr_lora_files):
invalid_rank = tmp_path / "invalid_rank"

# Copy adapter from zephyr_lora_files to invalid_rank
shutil.copytree(zephyr_lora_files, invalid_rank)

with open(invalid_rank / "adapter_config.json") as f:
@pytest.mark.parametrize("test_name,config_change,expected_error",
BADREQUEST_CASES)
async def test_dynamic_lora_badrequests(client: openai.AsyncOpenAI, tmp_path,
zephyr_lora_files, test_name: str,
config_change: dict,
expected_error: str):
# Create test directory
test_dir = tmp_path / test_name

# Copy adapter files
shutil.copytree(zephyr_lora_files, test_dir)

# Load and modify configuration
config_path = test_dir / "adapter_config.json"
with open(config_path) as f:
adapter_config = json.load(f)
# Apply configuration changes
adapter_config.update(config_change)

print(adapter_config)

# assert False

# Change rank to invalid value
adapter_config["r"] = 1024
with open(invalid_rank / "adapter_config.json", "w") as f:
# Save modified configuration
with open(config_path, "w") as f:
json.dump(adapter_config, f)

with pytest.raises(openai.BadRequestError,
match="is greater than max_lora_rank"):
# Test loading the adapter
with pytest.raises(openai.BadRequestError, match=expected_error):
await client.post("load_lora_adapter",
cast_to=str,
body={
"lora_name": "invalid-json",
"lora_path": str(invalid_rank)
"lora_name": test_name,
"lora_path": str(test_dir)
})
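
As context for the cases above, a minimal, non-authoritative sketch of how a client would surface these improved messages when dynamically loading an adapter. It assumes a locally running vLLM OpenAI-compatible server with dynamic LoRA loading enabled; the base_url, api_key, adapter name, and adapter path are placeholders, and the call shape mirrors the test above.

import asyncio

import openai


async def load_adapter_example():
    # Placeholder client configuration for a local vLLM server.
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")
    try:
        # Same call shape as the test above; the adapter path is hypothetical.
        await client.post("load_lora_adapter",
                          cast_to=str,
                          body={
                              "lora_name": "my-adapter",
                              "lora_path": "/path/to/adapter"
                          })
    except openai.BadRequestError as err:
        # With this change, the message points at the offending field,
        # e.g. "... is greater than max_lora_rank".
        print(err)


asyncio.run(load_adapter_example())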


16 changes: 16 additions & 0 deletions tests/lora/test_lora_checkpoints.py
@@ -3,6 +3,7 @@
import pytest

from vllm.lora.models import LoRAModel
from vllm.lora.peft_helper import PEFTHelper
from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM
from vllm.model_executor.models.utils import WeightsMapper

@@ -30,11 +31,14 @@ def test_load_checkpoints(
else:
expected_lora_modules.append(module)
if lora_name == "baichuan7B":
peft_helper = PEFTHelper.from_local_dir(baichuan_lora_files,
max_position_embeddings=4096)
# For the baichuan7B model, load its LoRA,
# and the test should pass.
LoRAModel.from_local_checkpoint(
baichuan_lora_files,
expected_lora_modules,
peft_helper=peft_helper,
lora_model_id=1,
device="cpu",
embedding_modules=embedding_modules,
@@ -43,19 +47,25 @@
# Test that the target_modules contain a prefix
# such as "model.layers.0.self_atten.W_pack", and
# the test should pass.
peft_helper = PEFTHelper.from_local_dir(baichuan_zero_lora_files,
max_position_embeddings=4096)
LoRAModel.from_local_checkpoint(
baichuan_zero_lora_files,
expected_lora_modules,
peft_helper=peft_helper,
lora_model_id=1,
device="cpu",
embedding_modules=embedding_modules,
embedding_padding_modules=embed_padding_modules)
elif lora_name == "baichuan7B-zero-regex":
# Test that `target_modules` can be given as regular expressions,
# such as `model\\..*(W_pack|o_proj)`, and the test should pass.
peft_helper = PEFTHelper.from_local_dir(baichuan_regex_lora_files,
max_position_embeddings=4096)
LoRAModel.from_local_checkpoint(
baichuan_regex_lora_files,
expected_lora_modules,
peft_helper=peft_helper,
lora_model_id=1,
device="cpu",
embedding_modules=embedding_modules,
@@ -64,10 +74,13 @@
# For the baichuan7B model, load chatglm3-6b's LoRA,
# and the test should raise the following error.
expected_error = "Please verify that the loaded LoRA module is correct" # noqa: E501
peft_helper = PEFTHelper.from_local_dir(chatglm3_lora_files,
max_position_embeddings=4096)
with pytest.raises(ValueError, match=expected_error):
LoRAModel.from_local_checkpoint(
chatglm3_lora_files,
expected_lora_modules,
peft_helper=peft_helper,
lora_model_id=1,
device="cpu",
embedding_modules=embedding_modules,
@@ -94,9 +107,12 @@ def test_lora_weights_mapping(baichuan_lora_files):
".layers.": ".baichuan_layers.",
},
)
peft_helper = PEFTHelper.from_local_dir(baichuan_lora_files,
max_position_embeddings=4096)
lora_model = LoRAModel.from_local_checkpoint(
baichuan_lora_files,
expected_lora_modules,
peft_helper=peft_helper,
lora_model_id=1,
device="cpu",
embedding_modules=embedding_modules,
3 changes: 3 additions & 0 deletions tests/lora/test_lora_huggingface.py
@@ -3,6 +3,7 @@
import pytest

from vllm.lora.models import LoRAModel
from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.utils import get_adapter_absolute_path
from vllm.model_executor.models.llama import LlamaForCausalLM

@@ -27,9 +28,11 @@ def test_load_checkpoints_from_huggingface(lora_fixture_name, request):
lora_path = get_adapter_absolute_path(lora_name)

# lora loading should work for either an absolute path or a huggingface id.
peft_helper = PEFTHelper.from_local_dir(lora_path, 4096)
lora_model = LoRAModel.from_local_checkpoint(
lora_path,
expected_lora_modules,
peft_helper=peft_helper,
lora_model_id=1,
device="cpu",
embedding_modules=embedding_modules,
59 changes: 2 additions & 57 deletions tests/lora/test_lora_manager.py
@@ -1,5 +1,3 @@
import json
import math
import os
from typing import Dict, List

@@ -34,68 +32,15 @@
] if current_platform.is_cuda_alike() else ["cpu"])


def test_peft_helper(sql_lora_files):
lora_config_path = os.path.join(sql_lora_files, "adapter_config.json")
with open(lora_config_path) as f:
config = json.load(f)
peft_helper = PEFTHelper.from_dict(config)
assert peft_helper.r == 8
assert peft_helper.lora_alpha == 16
assert peft_helper.target_modules == [
"q_proj",
"v_proj",
"k_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
"embed_tokens",
"lm_head",
]
scaling = peft_helper.lora_alpha / peft_helper.r
assert abs(peft_helper.vllm_lora_scaling_factor - scaling) < 1e-3

# test RSLoRA
config = dict(r=8,
lora_alpha=16,
target_modules=["gate_proj"],
use_rslora=True)
peft_helper = PEFTHelper.from_dict(config)

scaling = peft_helper.lora_alpha / math.sqrt(peft_helper.r)
assert abs(peft_helper.vllm_lora_scaling_factor - scaling) < 1e-3

expected_error = "vLLM only supports modules_to_save being None."
with pytest.raises(ValueError, match=expected_error):
config = dict(
r=8,
lora_alpha=16,
target_modules=["gate_proj"],
modules_to_save=["lm_head"],
)
PEFTHelper.from_dict(config)

expected_error = "vLLM does not yet support DoRA."
with pytest.raises(ValueError, match=expected_error):
config = dict(r=8,
lora_alpha=16,
target_modules=["gate_proj"],
use_dora=True)
PEFTHelper.from_dict(config)


@pytest.mark.parametrize("device", DEVICES)
def test_from_lora_tensors(sql_lora_files, device):
tensors = load_file(
os.path.join(sql_lora_files, "adapter_model.safetensors"))
new_embeddings = load_file(
os.path.join(sql_lora_files, "new_embeddings.safetensors"))

lora_config_path = os.path.join(sql_lora_files, "adapter_config.json")
with open(lora_config_path) as f:
config = json.load(f)

peft_helper = PEFTHelper.from_dict(config)
peft_helper = PEFTHelper.from_local_dir(sql_lora_files,
max_position_embeddings=4096)
lora_model = LoRAModel.from_lora_tensors(
1,
tensors,
109 changes: 109 additions & 0 deletions tests/lora/test_peft_helper.py
@@ -0,0 +1,109 @@
import json
import math
import shutil

import pytest

from vllm.config import LoRAConfig
from vllm.lora.peft_helper import PEFTHelper

ERROR_CASES = [
(
"test_rank",
{
"r": 1024
},
"is greater than max_lora_rank",
),
(
"test_bias",
{
"bias": "all"
},
"Adapter bias cannot be used without bias_enabled",
),
("test_dora", {
"use_dora": True
}, "does not yet support DoRA"),
(
"test_modules_to_save",
{
"modules_to_save": ["lm_head"]
},
"only supports modules_to_save being None",
),
]


def test_peft_helper_pass(long_context_lora_files_16k_1, tmp_path):
peft_helper = PEFTHelper.from_local_dir(long_context_lora_files_16k_1,
max_position_embeddings=4096)
lora_config = LoRAConfig(max_lora_rank=16, max_cpu_loras=3, max_loras=2)
peft_helper.validate_legal(lora_config)
assert peft_helper.r == 8
assert peft_helper.lora_alpha == 16
assert peft_helper.target_modules == [
"q_proj",
"v_proj",
"k_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
"embed_tokens",
"lm_head",
]
assert peft_helper.context_length == 16384
assert peft_helper.vllm_max_position_embeddings == 4096
assert peft_helper.vllm_long_context_scaling_factor == float(
math.ceil(peft_helper.context_length /
peft_helper.vllm_max_position_embeddings))
# test RSLoRA
rslora_config = dict(use_rslora=True)
test_dir = tmp_path / "test_rslora"
shutil.copytree(long_context_lora_files_16k_1, test_dir)

# Load and modify configuration
config_path = test_dir / "adapter_config.json"
with open(config_path) as f:
adapter_config = json.load(f)
# Apply configuration changes
adapter_config.update(rslora_config)

# Save modified configuration
with open(config_path, "w") as f:
json.dump(adapter_config, f)

peft_helper = PEFTHelper.from_local_dir(test_dir,
max_position_embeddings=4096)
peft_helper.validate_legal(lora_config)
scaling = peft_helper.lora_alpha / math.sqrt(peft_helper.r)
assert abs(peft_helper.vllm_lora_scaling_factor - scaling) < 1e-3


@pytest.mark.parametrize("test_name,config_change,expected_error", ERROR_CASES)
def test_peft_helper_error(
sql_lora_files,
tmp_path,
test_name: str,
config_change: dict,
expected_error: str,
):
test_dir = tmp_path / test_name
shutil.copytree(sql_lora_files, test_dir)

# Load and modify configuration
config_path = test_dir / "adapter_config.json"
with open(config_path) as f:
adapter_config = json.load(f)
# Apply configuration changes
adapter_config.update(config_change)

# Save modified configuration
with open(config_path, "w") as f:
json.dump(adapter_config, f)
lora_config = LoRAConfig(max_lora_rank=16, max_cpu_loras=3, max_loras=2)
# Test loading the adapter
with pytest.raises(ValueError, match=expected_error):
PEFTHelper.from_local_dir(
test_dir, max_position_embeddings=4096).validate_legal(lora_config)
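
For reference, the validation flow these tests exercise can be reduced to a short sketch, assuming an adapter directory containing adapter_config.json (the path below is a placeholder): load the adapter metadata with PEFTHelper.from_local_dir, then check it against the engine's LoRAConfig so a misconfigured adapter fails fast with a descriptive ValueError.

from vllm.config import LoRAConfig
from vllm.lora.peft_helper import PEFTHelper

# Placeholder path to a directory holding adapter_config.json and the weights.
peft_helper = PEFTHelper.from_local_dir("/path/to/lora_adapter",
                                        max_position_embeddings=4096)
lora_config = LoRAConfig(max_lora_rank=16, max_cpu_loras=3, max_loras=2)
# Raises ValueError with a readable message when the adapter is invalid,
# e.g. "... is greater than max_lora_rank" if the adapter's rank exceeds it.
peft_helper.validate_legal(lora_config)
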
1 change: 1 addition & 0 deletions vllm/engine/multiprocessing/engine.py
@@ -296,6 +296,7 @@ def _handle_load_adapter_request(self, request: RPCLoadAdapterRequest):
is_engine_errored=False,
exception=e)
self._send_outputs(rpc_err)
return
# Otherwise, send back the successful load message
self._send_outputs(
RPCAdapterLoadedResponse(request_id=request.request_id))