Fix apply_lora_packed_nslice for Multi-LoRA & Add LoRA layer test for HPU
JHLEE17 committed Aug 1, 2024
1 parent 301579d commit 4ef5a6d
Showing 5 changed files with 234 additions and 12 deletions.
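
The functional change is in _apply_lora_packed_nslice (vllm/lora/layers.py): output columns for each packed slice are now located by accumulating the true widths given in output_slices, rather than assuming all slices are equally sized, which breaks for packed QKV projections whose K/V width differs from Q (e.g. under grouped-query attention). The snippet below is a minimal standalone sketch of that difference, with made-up shapes and hypothetical helper names (split_equal, split_by_offsets); it is not the vLLM implementation.

    import torch

    def split_equal(output: torch.Tensor, output_slices: tuple) -> list:
        # Old behaviour: every packed slice is assumed to have the same width.
        slice_size = output.shape[-1] // len(output_slices)
        return [
            output[:, i * slice_size:(i + 1) * slice_size]
            for i in range(len(output_slices))
        ]

    def split_by_offsets(output: torch.Tensor, output_slices: tuple) -> list:
        # Fixed behaviour: walk the actual width of each packed slice.
        views, offset_left = [], 0
        for size in output_slices:
            views.append(output[:, offset_left:offset_left + size])
            offset_left += size
        return views

    qkv_sizes = (8192, 1024, 1024)          # e.g. packed QKV with narrower K/V
    out = torch.zeros(2, sum(qkv_sizes))
    assert [v.shape[-1] for v in split_by_offsets(out, qkv_sizes)] == list(qkv_sizes)
    assert [v.shape[-1] for v in split_equal(out, qkv_sizes)] != list(qkv_sizes)

The new tests in test_lora_hpu.py cover exactly this case through QKV_TENSOR_SIZES entries such as (8192, 1024, 1024).
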
225 changes: 225 additions & 0 deletions test_lora_hpu.py
@@ -0,0 +1,225 @@
import pytest
import torch

from vllm.lora.layers import _apply_lora, _apply_lora_packed_nslice

from .utils import DummyLoRAManager

TENSOR_SIZES = [128, 1024, 2048, 4096, 8192, 11008, 11008 // 2, 11008 // 4]
QKV_TENSOR_SIZES = [
    (8192, 1024, 1024),
    (8192 // 8, 1024 // 8, 1024 // 8),
    (4096, 4096, 4096),
    (4096 // 2, 4096 // 2, 4096 // 2),
]
BATCH_SIZES = [8, 32, 256]
RANKS = [8]
DTYPES = [torch.float16]
TOLERANCES = {
    torch.float16: (5e-3, 5e-3),
    torch.bfloat16: (3e-2, 2e-2),
}


@pytest.mark.parametrize("m", TENSOR_SIZES)
@pytest.mark.parametrize("n", TENSOR_SIZES)
@pytest.mark.parametrize("k", BATCH_SIZES)
@pytest.mark.parametrize("rank", RANKS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_apply_lora(m, n, k, rank, dtype) -> None:
    manager = DummyLoRAManager()

    module_name = "module"
    weight = torch.rand([m, n], device="hpu", dtype=dtype)

    manager.init_random_lora(module_name, weight, rank=rank)
    lora = manager.get_module_lora(module_name)

    input = torch.rand(k, n, device="hpu", dtype=dtype)
    expected = input @ lora.lora_a @ lora.lora_b * lora.scaling

    lora_a_stack = torch.zeros(8,
                               1,
                               lora.lora_a.shape[1],
                               lora.lora_a.shape[0],
                               device="hpu",
                               dtype=dtype)
    lora_b_stack = torch.zeros(8,
                               1,
                               lora.lora_b.shape[1],
                               lora.lora_b.shape[0],
                               device="hpu",
                               dtype=dtype)
    for i in range(lora_a_stack.shape[0]):
        lora_a_stack[i][0] = lora.lora_a.T
        lora_b_stack[i][0] = (lora.lora_b * lora.scaling).T

    output = torch.zeros(k, m, device="hpu", dtype=dtype)
    _apply_lora(
        input, lora_a_stack, lora_b_stack,
        torch.randint(0, lora_a_stack.shape[0], (len(input), ), device="hpu"),
        output)

    rtol, atol = TOLERANCES[dtype]
    assert torch.allclose(expected, output, rtol=rtol, atol=atol)

    output[:] = 0
    _apply_lora(input, lora_a_stack, lora_b_stack,
                torch.full((len(input), ), -1, device="hpu"), output)
    assert torch.allclose(torch.zeros_like(output), output)

    manager.reset_lora()


@pytest.mark.parametrize("m", TENSOR_SIZES)
@pytest.mark.parametrize("n", TENSOR_SIZES)
@pytest.mark.parametrize("k", BATCH_SIZES)
@pytest.mark.parametrize("rank", RANKS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_apply_lora_packed_2slice(m, n, k, rank, dtype) -> None:
    if m % 2 != 0:
        pytest.skip("m must be divisible by 2")
    if m // 2 not in TENSOR_SIZES:
        pytest.skip("m//2 must be in TENSOR_SIZES")

    manager = DummyLoRAManager()

    module_name = "module"
    weight = torch.rand([m // 2, n], device="hpu", dtype=dtype)

    manager.init_random_lora(module_name + "1", weight, rank=rank)
    lora_1 = manager.get_module_lora(module_name + "1")
    manager.init_random_lora(module_name + "2", weight, rank=rank)
    lora_2 = manager.get_module_lora(module_name + "2")

    input = torch.rand(k, n, device="hpu", dtype=dtype)
    expected = torch.cat([
        input @ lora_1.lora_a @ lora_1.lora_b * lora_1.scaling,
        input @ lora_2.lora_a @ lora_2.lora_b * lora_2.scaling
    ],
                         dim=1)

    lora_a_stacks = [
        torch.zeros(8,
                    1,
                    lora_1.lora_a.shape[1],
                    lora_1.lora_a.shape[0],
                    device="hpu",
                    dtype=dtype) for i in range(2)
    ]
    lora_b_stacks = [
        torch.zeros(8,
                    1,
                    lora_1.lora_b.shape[1],
                    lora_1.lora_b.shape[0],
                    device="hpu",
                    dtype=dtype) for i in range(2)
    ]
    for i in range(lora_a_stacks[0].shape[0]):
        lora_a_stacks[0][i][0] = lora_1.lora_a.T
        lora_b_stacks[0][i][0] = (lora_1.lora_b * lora_1.scaling).T
        lora_a_stacks[1][i][0] = lora_2.lora_a.T
        lora_b_stacks[1][i][0] = (lora_2.lora_b * lora_2.scaling).T

    output = torch.zeros(k, m, device="hpu", dtype=dtype)
    _apply_lora_packed_nslice(
        input, lora_a_stacks, lora_b_stacks,
        torch.randint(0,
                      lora_a_stacks[0].shape[0], (len(input), ),
                      device="hpu"), output, (m // 2, m // 2))

    rtol, atol = TOLERANCES[dtype]
    assert torch.allclose(expected, output, rtol=rtol, atol=atol)

    output[:] = 0
    _apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks,
                              torch.full((len(input), ), -1, device="hpu"),
                              output, (m // 2, m // 2))
    assert torch.allclose(torch.zeros_like(output), output)

    manager.reset_lora()


@pytest.mark.parametrize("qkv", QKV_TENSOR_SIZES)
@pytest.mark.parametrize("n", TENSOR_SIZES)
@pytest.mark.parametrize("k", BATCH_SIZES)
@pytest.mark.parametrize("rank", RANKS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_apply_lora_packed_3slice(qkv, n, k, rank, dtype) -> None:
    manager = DummyLoRAManager()

    module_name = "module"
    weight_q = torch.empty(qkv[0], n, device="hpu", dtype=dtype)
    weight_kv = torch.empty(qkv[1], n, device="hpu", dtype=dtype)

    manager.init_random_lora(module_name + "q", weight_q, rank=rank)
    lora_q = manager.get_module_lora(module_name + "q")
    manager.init_random_lora(module_name + "k", weight_kv, rank=rank)
    lora_k = manager.get_module_lora(module_name + "k")
    manager.init_random_lora(module_name + "v", weight_kv, rank=rank)
    lora_v = manager.get_module_lora(module_name + "v")

    input = torch.rand(k, n, device="hpu", dtype=dtype)
    expected = torch.cat([
        input @ lora_q.lora_a @ lora_q.lora_b * lora_q.scaling,
        input @ lora_k.lora_a @ lora_k.lora_b * lora_k.scaling,
        input @ lora_v.lora_a @ lora_v.lora_b * lora_v.scaling
    ],
                         dim=1)

    lora_a_stacks = [
        torch.zeros(8,
                    1,
                    lora_q.lora_a.shape[1],
                    lora_q.lora_a.shape[0],
                    device="hpu",
                    dtype=dtype)
    ] + [
        torch.zeros(8,
                    1,
                    lora_k.lora_a.shape[1],
                    lora_k.lora_a.shape[0],
                    device="hpu",
                    dtype=dtype) for i in range(2)
    ]
    lora_b_stacks = [
        torch.zeros(8,
                    1,
                    lora_q.lora_b.shape[1],
                    lora_q.lora_b.shape[0],
                    device="hpu",
                    dtype=dtype)
    ] + [
        torch.zeros(8,
                    1,
                    lora_k.lora_b.shape[1],
                    lora_k.lora_b.shape[0],
                    device="hpu",
                    dtype=dtype) for i in range(2)
    ]
    for i in range(lora_a_stacks[0].shape[0]):
        lora_a_stacks[0][i][0] = lora_q.lora_a.T
        lora_b_stacks[0][i][0] = (lora_q.lora_b * lora_q.scaling).T
        lora_a_stacks[1][i][0] = lora_k.lora_a.T
        lora_b_stacks[1][i][0] = (lora_k.lora_b * lora_k.scaling).T
        lora_a_stacks[2][i][0] = lora_v.lora_a.T
        lora_b_stacks[2][i][0] = (lora_v.lora_b * lora_v.scaling).T

    output = torch.zeros(k, sum(qkv), device="hpu", dtype=dtype)
    _apply_lora_packed_nslice(
        input, lora_a_stacks, lora_b_stacks,
        torch.randint(0,
                      lora_a_stacks[0].shape[0], (len(input), ),
                      device="hpu"), output, (qkv[0], qkv[1], qkv[2]))

    rtol, atol = TOLERANCES[dtype]
    # import pdb; pdb.set_trace()
    assert torch.allclose(expected, output, rtol=rtol, atol=atol)

    output[:] = 0
    _apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks,
                              torch.full((len(input), ), -1, device="hpu"),
                              output, (qkv[0], qkv[1], qkv[2]))
    assert torch.allclose(torch.zeros_like(output), output)

    manager.reset_lora()
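
As a reading aid, the assertions above reduce to a dense reference product. The sketch below (illustrative shapes, plain CPU tensors, hypothetical helper name lora_ref) restates what the stacked HPU kernels are expected to reproduce; the tests additionally check that an adapter index of -1 leaves the output at zero.

    import torch

    def lora_ref(x: torch.Tensor, lora_a: torch.Tensor, lora_b: torch.Tensor,
                 scaling: float) -> torch.Tensor:
        # Dense LoRA reference: (k, n) @ (n, r) @ (r, m), scaled.
        return x @ lora_a @ lora_b * scaling

    x = torch.rand(4, 16)
    lora_a = torch.rand(16, 8)   # rank 8, matching RANKS above
    lora_b = torch.rand(8, 32)
    expected = lora_ref(x, lora_a, lora_b, scaling=0.5)
    assert expected.shape == (4, 32)
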
2 changes: 1 addition & 1 deletion tests/lora/test_llama_hpu.py
@@ -6,7 +6,7 @@
 import vllm
 from vllm.lora.request import LoRARequest

-from conftest import cleanup
+from .conftest import cleanup

 MODEL_PATH = "meta-llama/Llama-2-7b-hf"

10 changes: 5 additions & 5 deletions tests/lora/utils.py
@@ -28,16 +28,16 @@ def init_random_lora(self,
             lora_alpha=1,
             lora_a=torch.rand([weight.shape[1], rank],
                               dtype=weight.dtype,
-                              device="cuda"),
+                              device="hpu"),
             lora_b=torch.rand([rank, weight.shape[0]],
                               dtype=weight.dtype,
-                              device="cuda"),
+                              device="hpu"),
         )
         if generate_embeddings_tensor:
             lora.embeddings_tensor = torch.rand(5,
                                                 generate_embeddings_tensor,
                                                 dtype=weight.dtype,
-                                                device="cuda")
+                                                device="hpu")
         self.set_module_lora(module_name, lora)

         return lora
@@ -53,8 +53,8 @@ def init_lora(self,
             module_name,
             rank=rank,
             lora_alpha=1,
-            lora_a=torch.rand([input_dim, rank], device="cuda"),
-            lora_b=torch.rand([rank, output_dim], device="cuda"),
+            lora_a=torch.rand([input_dim, rank], device="hpu"),
+            lora_b=torch.rand([rank, output_dim], device="hpu"),
             embeddings_tensor=embeddings_tensor,
         )
         self.set_module_lora(module_name, lora)
Expand Down
7 changes: 2 additions & 5 deletions vllm/lora/layers.py
@@ -156,16 +156,13 @@ def _apply_lora_packed_nslice(
     output = output.view(-1, output.shape[-1])
     indices = indices.view(-1)
     offset_left = 0
-    slice_size = output.shape[-1] // len(output_slices)
     for slice_idx in range(len(output_slices)):
         # add_lora_slice(output, x, lora_a_stacked[slice_idx],
         #                lora_b_stacked[slice_idx], indices, 0, 1.0, offset_left,
         #                output_slices[slice_idx])
-        # offset_left += output_slices[slice_idx]
-        start = slice_idx * slice_size
-        end = min((slice_idx + 1)* slice_size, output.shape[-1])
-        custom_bgmv(output[:, start:end], x, lora_a_stacked[slice_idx],
+        custom_bgmv(output[:, offset_left:offset_left+output_slices[slice_idx]], x, lora_a_stacked[slice_idx],
                     lora_b_stacked[slice_idx], indices, 0, 1.0)
+        offset_left += output_slices[slice_idx]
     return output.view_as(org_output)


2 changes: 1 addition & 1 deletion vllm/worker/habana_worker.py
@@ -162,7 +162,7 @@ def initialize_cache(self, num_gpu_blocks: int,
         self.cache_config.num_cpu_blocks = num_cpu_blocks

         self._init_cache_engine()
-        self._warm_up_model()
+        # self._warm_up_model()

     def _init_cache_engine(self):
         assert self.cache_config.num_gpu_blocks is not None
