[V1] LoRA Support (vllm-project#10957)
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
Signed-off-by: Felix Marty <[email protected]>
2 people authored and fxmarty-amd committed Feb 7, 2025
1 parent 4588bac commit 89ebc6b
Showing 16 changed files with 453 additions and 56 deletions.
17 changes: 17 additions & 0 deletions tests/lora/conftest.py
@@ -306,3 +306,20 @@ def get_model_patched(**kwargs):
def llama_2_7b_model_extra_embeddings(llama_2_7b_engine_extra_embeddings):
yield (llama_2_7b_engine_extra_embeddings.model_executor.driver_worker.
model_runner.model)


@pytest.fixture(params=[True, False])
def run_with_both_engines_lora(request, monkeypatch):
# Automatically runs tests twice, once with V1 and once without
use_v1 = request.param
# Tests decorated with `@skip_v1` are only run without v1
skip_v1 = request.node.get_closest_marker("skip_v1")

if use_v1:
if skip_v1:
pytest.skip("Skipping test on vllm V1")
monkeypatch.setenv('VLLM_USE_V1', '1')
else:
monkeypatch.setenv('VLLM_USE_V1', '0')

yield
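
The fixture above drives both engines by toggling the VLLM_USE_V1 environment variable before the LLM under test is built. A test module opts in with a small autouse wrapper, and individual tests opt out of V1 via the skip_v1 marker. The sketch below consolidates that pattern in one place; the module and test names are hypothetical, and it assumes the skip_v1 marker is registered in the project's pytest configuration, but the wrapper itself mirrors the `v1` fixtures added in the files below.

# Hypothetical module, e.g. tests/lora/test_example.py
import pytest


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
    # Pulls in the parametrized fixture so every test in this module runs
    # once with VLLM_USE_V1=1 and once with VLLM_USE_V1=0.
    pass


def test_runs_on_both_engines():
    # Collected twice, once per engine.
    ...


@pytest.mark.skip_v1
def test_runs_only_without_v1():
    # run_with_both_engines_lora calls pytest.skip() for the V1 run.
    ...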
8 changes: 8 additions & 0 deletions tests/lora/test_baichuan.py
@@ -42,6 +42,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
return generated_texts


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


def test_baichuan_lora(baichuan_lora_files):
llm = vllm.LLM(MODEL_PATH,
max_model_len=1024,
13 changes: 13 additions & 0 deletions tests/lora/test_chatglm3_tp.py
@@ -2,6 +2,8 @@

from typing import List

import pytest

import vllm
from tests.utils import fork_new_process_for_each_test
from vllm.lora.request import LoRARequest
@@ -47,6 +49,15 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
return generated_texts


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


@pytest.mark.skip_v1
@fork_new_process_for_each_test
def test_chatglm3_lora(chatglm3_lora_files):
llm = vllm.LLM(MODEL_PATH,
@@ -66,6 +77,7 @@ def test_chatglm3_lora(chatglm3_lora_files):
assert output2[i] == EXPECTED_LORA_OUTPUT[i]


@pytest.mark.skip_v1
@multi_gpu_test(num_gpus=4)
@fork_new_process_for_each_test
def test_chatglm3_lora_tp4(chatglm3_lora_files):
@@ -87,6 +99,7 @@ def test_chatglm3_lora_tp4(chatglm3_lora_files):
assert output2[i] == EXPECTED_LORA_OUTPUT[i]


@pytest.mark.skip_v1
@multi_gpu_test(num_gpus=4)
@fork_new_process_for_each_test
def test_chatglm3_lora_tp4_fully_sharded_loras(chatglm3_lora_files):
8 changes: 8 additions & 0 deletions tests/lora/test_gemma.py
@@ -33,6 +33,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
return generated_texts


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


@pytest.mark.xfail(current_platform.is_rocm(),
reason="There can be output mismatch on ROCm")
def test_gemma_lora(gemma_lora_files):
12 changes: 12 additions & 0 deletions tests/lora/test_llama_tp.py
@@ -2,6 +2,7 @@

from typing import List

import pytest
import ray

import vllm
@@ -73,6 +74,14 @@ def generate_and_test(llm, sql_lora_files):
print("removing lora")


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


@fork_new_process_for_each_test
def test_llama_lora(sql_lora_files):

@@ -85,6 +94,9 @@ def test_llama_lora(sql_lora_files):
generate_and_test(llm, sql_lora_files)


# Skipping for v1 as v1 doesn't have a good way to expose the num_gpu_blocks
# used by the engine yet.
@pytest.mark.skip_v1
@fork_new_process_for_each_test
def test_llama_lora_warmup(sql_lora_files):
"""Test that the LLM initialization works with a warmup LORA path and
11 changes: 11 additions & 0 deletions tests/lora/test_lora_bias_e2e.py
@@ -30,6 +30,17 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
return generated_texts


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


# Skipping for V1 for now as we are hitting,
# "Head size 80 is not supported by FlashAttention." error.
@pytest.mark.skip_v1
@pytest.mark.parametrize("lora_bias", [True])
@pytest.mark.parametrize("fully_sharded", [True, False])
def test_lora_bias(lora_bias_files: str, lora_bias: bool, fully_sharded: bool):
13 changes: 13 additions & 0 deletions tests/lora/test_phi.py
@@ -2,6 +2,8 @@

from typing import List

import pytest

import vllm
from vllm.lora.request import LoRARequest

@@ -48,6 +50,17 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
return generated_texts


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


# Skipping for V1 for now as we are hitting,
# "Head size 80 is not supported by FlashAttention." error.
@pytest.mark.skip_v1
def test_phi2_lora(phi2_lora_files):
# We enable enforce_eager=True here to reduce VRAM usage for lora-test CI,
# Otherwise, the lora-test will fail due to CUDA OOM.
8 changes: 8 additions & 0 deletions tests/lora/test_quant_model.py
@@ -70,6 +70,14 @@ def format_prompt_tuples(prompt):
return generated_texts


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tp_size", [1])
def test_quant_model_lora(tinyllama_lora_files, num_gpus_available, model,
2 changes: 1 addition & 1 deletion tests/v1/core/test_kv_cache_utils.py
@@ -163,7 +163,7 @@ def test_generate_block_hash_extra_keys():

# Test with no overlap
extra_keys, next_mm_idx = generate_block_hash_extra_keys(request, 6, 10, 0)
-    assert extra_keys == ()
+    assert extra_keys is None
assert next_mm_idx == 1

# Test with multiple extra keys
8 changes: 5 additions & 3 deletions vllm/lora/layers.py
@@ -16,8 +16,7 @@
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
-                              tensor_model_parallel_all_reduce,
-                              tensor_model_parallel_gather)
+                              tensor_model_parallel_all_reduce)
from vllm.distributed.utils import divide
# yapf: disable
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -1043,7 +1042,10 @@ def _get_logits(
logits = lm_head.linear_method.apply(lm_head, hidden_states)
if embedding_bias is not None:
logits += embedding_bias
-        logits = tensor_model_parallel_gather(logits)
+
+        # Gather logits for TP
+        logits = self.base_layer._gather_logits(logits)
+
if logits is None:
return None

28 changes: 17 additions & 11 deletions vllm/model_executor/layers/logits_processor.py
@@ -51,7 +51,6 @@ def __init__(self,
# Soft cap the logits. Used in Gemma 2.
self.soft_cap = soft_cap
# Whether to use gather or all-gather to gather the logits.

parallel_config = get_current_vllm_config().parallel_config
self.use_all_gather = current_platform.is_tpu() \
or envs.VLLM_USE_V1 \
@@ -88,6 +87,20 @@ def forward(

return logits

def _gather_logits(self, logits: torch.Tensor) -> torch.Tensor:
"""gather/all-gather the logits tensor across model parallel group."""
if self.use_all_gather:
# Gather is not supported for some devices such as TPUs.
# Use all-gather instead.
# NOTE(woosuk): Here, the outputs of every device should not be None
# because XLA requires strict SPMD among all devices. Every device
# should execute the same operations after gathering the logits.
logits = tensor_model_parallel_all_gather(logits)
else:
# None may be returned for rank > 0
logits = tensor_model_parallel_gather(logits)
return logits

def _get_logits(
self,
hidden_states: torch.Tensor,
@@ -99,16 +112,9 @@ def _get_logits(
hidden_states,
bias=embedding_bias)

-        if self.use_all_gather:
-            # Gather is not supported for some devices such as TPUs.
-            # Use all-gather instead.
-            # NOTE(woosuk): Here, the outputs of every device should not be None
-            # because XLA requires strict SPMD among all devices. Every device
-            # should execute the same operations after gathering the logits.
-            logits = tensor_model_parallel_all_gather(logits)
-        else:
-            # None may be returned for rank > 0
-            logits = tensor_model_parallel_gather(logits)
+        # Gather logits for TP
+        logits = self._gather_logits(logits)
+
# Remove paddings in vocab (if any).
if logits is not None:
logits = logits[..., :self.org_vocab_size]
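
For context on the gather/all-gather split that the new `_gather_logits` helper encapsulates: with plain gather, only the destination rank ends up with the concatenated logits and the other ranks see None (hence the `if logits is None` checks), whereas all-gather leaves every rank with the full tensor, which SPMD backends such as XLA require. The sketch below illustrates the difference using raw torch.distributed as an assumption; vLLM's own tensor_model_parallel_gather/all_gather wrappers and TP process-group handling are not reproduced here.

# Illustrative sketch only, not vLLM's implementation.
import torch
import torch.distributed as dist


def gather_logits_sketch(logits: torch.Tensor, use_all_gather: bool):
    world_size = dist.get_world_size()
    if use_all_gather:
        # Every rank receives every shard and concatenates them, so all
        # ranks keep executing the same ops afterwards (needed for XLA/SPMD).
        shards = [torch.empty_like(logits) for _ in range(world_size)]
        dist.all_gather(shards, logits)
        return torch.cat(shards, dim=-1)
    if dist.get_rank() == 0:
        # Only the destination rank materializes the full logits tensor.
        shards = [torch.empty_like(logits) for _ in range(world_size)]
        dist.gather(logits, gather_list=shards, dst=0)
        return torch.cat(shards, dim=-1)
    # Other ranks contribute their shard and get nothing back, mirroring
    # the "None may be returned for rank > 0" comment above.
    dist.gather(logits, gather_list=None, dst=0)
    return None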