From 4ef5a6d76033087216aa8366521340e498983790 Mon Sep 17 00:00:00 2001 From: JonghoLee Date: Thu, 1 Aug 2024 10:38:06 +0000 Subject: [PATCH] Fix apply_lora_packed_nslice for Multi-LoRA & Add LoRA layer test for HPU --- test_lora_hpu.py | 225 +++++++++++++++++++++++++++++++++++ tests/lora/test_llama_hpu.py | 2 +- tests/lora/utils.py | 10 +- vllm/lora/layers.py | 7 +- vllm/worker/habana_worker.py | 2 +- 5 files changed, 234 insertions(+), 12 deletions(-) create mode 100644 test_lora_hpu.py diff --git a/test_lora_hpu.py b/test_lora_hpu.py new file mode 100644 index 0000000000000..b4d9009789197 --- /dev/null +++ b/test_lora_hpu.py @@ -0,0 +1,225 @@ +import pytest +import torch + +from vllm.lora.layers import _apply_lora, _apply_lora_packed_nslice + +from .utils import DummyLoRAManager + +TENSOR_SIZES = [128, 1024, 2048, 4096, 8192, 11008, 11008 // 2, 11008 // 4] +QKV_TENSOR_SIZES = [ + (8192, 1024, 1024), + (8192 // 8, 1024 // 8, 1024 // 8), + (4096, 4096, 4096), + (4096 // 2, 4096 // 2, 4096 // 2), +] +BATCH_SIZES = [8, 32, 256] +RANKS = [8] +DTYPES = [torch.float16] +TOLERANCES = { + torch.float16: (5e-3, 5e-3), + torch.bfloat16: (3e-2, 2e-2), +} + + +@pytest.mark.parametrize("m", TENSOR_SIZES) +@pytest.mark.parametrize("n", TENSOR_SIZES) +@pytest.mark.parametrize("k", BATCH_SIZES) +@pytest.mark.parametrize("rank", RANKS) +@pytest.mark.parametrize("dtype", DTYPES) +def test_apply_lora(m, n, k, rank, dtype) -> None: + manager = DummyLoRAManager() + + module_name = "module" + weight = torch.rand([m, n], device="hpu", dtype=dtype) + + manager.init_random_lora(module_name, weight, rank=rank) + lora = manager.get_module_lora(module_name) + + input = torch.rand(k, n, device="hpu", dtype=dtype) + expected = input @ lora.lora_a @ lora.lora_b * lora.scaling + + lora_a_stack = torch.zeros(8, + 1, + lora.lora_a.shape[1], + lora.lora_a.shape[0], + device="hpu", + dtype=dtype) + lora_b_stack = torch.zeros(8, + 1, + lora.lora_b.shape[1], + lora.lora_b.shape[0], + device="hpu", + dtype=dtype) + for i in range(lora_a_stack.shape[0]): + lora_a_stack[i][0] = lora.lora_a.T + lora_b_stack[i][0] = (lora.lora_b * lora.scaling).T + + output = torch.zeros(k, m, device="hpu", dtype=dtype) + _apply_lora( + input, lora_a_stack, lora_b_stack, + torch.randint(0, lora_a_stack.shape[0], (len(input), ), device="hpu"), + output) + + rtol, atol = TOLERANCES[dtype] + assert torch.allclose(expected, output, rtol=rtol, atol=atol) + + output[:] = 0 + _apply_lora(input, lora_a_stack, lora_b_stack, + torch.full((len(input), ), -1, device="hpu"), output) + assert torch.allclose(torch.zeros_like(output), output) + + manager.reset_lora() + + +@pytest.mark.parametrize("m", TENSOR_SIZES) +@pytest.mark.parametrize("n", TENSOR_SIZES) +@pytest.mark.parametrize("k", BATCH_SIZES) +@pytest.mark.parametrize("rank", RANKS) +@pytest.mark.parametrize("dtype", DTYPES) +def test_apply_lora_packed_2slice(m, n, k, rank, dtype) -> None: + if m % 2 != 0: + pytest.skip("m must be divisible by 2") + if m // 2 not in TENSOR_SIZES: + pytest.skip("m//2 must be in TENSOR_SIZES") + + manager = DummyLoRAManager() + + module_name = "module" + weight = torch.rand([m // 2, n], device="hpu", dtype=dtype) + + manager.init_random_lora(module_name + "1", weight, rank=rank) + lora_1 = manager.get_module_lora(module_name + "1") + manager.init_random_lora(module_name + "2", weight, rank=rank) + lora_2 = manager.get_module_lora(module_name + "2") + + input = torch.rand(k, n, device="hpu", dtype=dtype) + expected = torch.cat([ + input @ lora_1.lora_a @ lora_1.lora_b * 
lora_1.scaling, + input @ lora_2.lora_a @ lora_2.lora_b * lora_2.scaling + ], + dim=1) + + lora_a_stacks = [ + torch.zeros(8, + 1, + lora_1.lora_a.shape[1], + lora_1.lora_a.shape[0], + device="hpu", + dtype=dtype) for i in range(2) + ] + lora_b_stacks = [ + torch.zeros(8, + 1, + lora_1.lora_b.shape[1], + lora_1.lora_b.shape[0], + device="hpu", + dtype=dtype) for i in range(2) + ] + for i in range(lora_a_stacks[0].shape[0]): + lora_a_stacks[0][i][0] = lora_1.lora_a.T + lora_b_stacks[0][i][0] = (lora_1.lora_b * lora_1.scaling).T + lora_a_stacks[1][i][0] = lora_2.lora_a.T + lora_b_stacks[1][i][0] = (lora_2.lora_b * lora_2.scaling).T + + output = torch.zeros(k, m, device="hpu", dtype=dtype) + _apply_lora_packed_nslice( + input, lora_a_stacks, lora_b_stacks, + torch.randint(0, + lora_a_stacks[0].shape[0], (len(input), ), + device="hpu"), output, (m // 2, m // 2)) + + rtol, atol = TOLERANCES[dtype] + assert torch.allclose(expected, output, rtol=rtol, atol=atol) + + output[:] = 0 + _apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks, + torch.full((len(input), ), -1, device="hpu"), + output, (m // 2, m // 2)) + assert torch.allclose(torch.zeros_like(output), output) + + manager.reset_lora() + + +@pytest.mark.parametrize("qkv", QKV_TENSOR_SIZES) +@pytest.mark.parametrize("n", TENSOR_SIZES) +@pytest.mark.parametrize("k", BATCH_SIZES) +@pytest.mark.parametrize("rank", RANKS) +@pytest.mark.parametrize("dtype", DTYPES) +def test_apply_lora_packed_3slice(qkv, n, k, rank, dtype) -> None: + manager = DummyLoRAManager() + + module_name = "module" + weight_q = torch.empty(qkv[0], n, device="hpu", dtype=dtype) + weight_kv = torch.empty(qkv[1], n, device="hpu", dtype=dtype) + + manager.init_random_lora(module_name + "q", weight_q, rank=rank) + lora_q = manager.get_module_lora(module_name + "q") + manager.init_random_lora(module_name + "k", weight_kv, rank=rank) + lora_k = manager.get_module_lora(module_name + "k") + manager.init_random_lora(module_name + "v", weight_kv, rank=rank) + lora_v = manager.get_module_lora(module_name + "v") + + input = torch.rand(k, n, device="hpu", dtype=dtype) + expected = torch.cat([ + input @ lora_q.lora_a @ lora_q.lora_b * lora_q.scaling, + input @ lora_k.lora_a @ lora_k.lora_b * lora_k.scaling, + input @ lora_v.lora_a @ lora_v.lora_b * lora_v.scaling + ], + dim=1) + + lora_a_stacks = [ + torch.zeros(8, + 1, + lora_q.lora_a.shape[1], + lora_q.lora_a.shape[0], + device="hpu", + dtype=dtype) + ] + [ + torch.zeros(8, + 1, + lora_k.lora_a.shape[1], + lora_k.lora_a.shape[0], + device="hpu", + dtype=dtype) for i in range(2) + ] + lora_b_stacks = [ + torch.zeros(8, + 1, + lora_q.lora_b.shape[1], + lora_q.lora_b.shape[0], + device="hpu", + dtype=dtype) + ] + [ + torch.zeros(8, + 1, + lora_k.lora_b.shape[1], + lora_k.lora_b.shape[0], + device="hpu", + dtype=dtype) for i in range(2) + ] + for i in range(lora_a_stacks[0].shape[0]): + lora_a_stacks[0][i][0] = lora_q.lora_a.T + lora_b_stacks[0][i][0] = (lora_q.lora_b * lora_q.scaling).T + lora_a_stacks[1][i][0] = lora_k.lora_a.T + lora_b_stacks[1][i][0] = (lora_k.lora_b * lora_k.scaling).T + lora_a_stacks[2][i][0] = lora_v.lora_a.T + lora_b_stacks[2][i][0] = (lora_v.lora_b * lora_v.scaling).T + + output = torch.zeros(k, sum(qkv), device="hpu", dtype=dtype) + _apply_lora_packed_nslice( + input, lora_a_stacks, lora_b_stacks, + torch.randint(0, + lora_a_stacks[0].shape[0], (len(input), ), + device="hpu"), output, (qkv[0], qkv[1], qkv[2])) + + rtol, atol = TOLERANCES[dtype] + # import pdb; pdb.set_trace() + assert 
torch.allclose(expected, output, rtol=rtol, atol=atol) + + output[:] = 0 + _apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks, + torch.full((len(input), ), -1, device="hpu"), + output, (qkv[0], qkv[1], qkv[2])) + assert torch.allclose(torch.zeros_like(output), output) + + manager.reset_lora() diff --git a/tests/lora/test_llama_hpu.py b/tests/lora/test_llama_hpu.py index 79506e1b4d1ad..3c434f3786c2f 100644 --- a/tests/lora/test_llama_hpu.py +++ b/tests/lora/test_llama_hpu.py @@ -6,7 +6,7 @@ import vllm from vllm.lora.request import LoRARequest -from conftest import cleanup +from .conftest import cleanup MODEL_PATH = "meta-llama/Llama-2-7b-hf" diff --git a/tests/lora/utils.py b/tests/lora/utils.py index b73cf5bf55324..3a0c460b7a81d 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -28,16 +28,16 @@ def init_random_lora(self, lora_alpha=1, lora_a=torch.rand([weight.shape[1], rank], dtype=weight.dtype, - device="cuda"), + device="hpu"), lora_b=torch.rand([rank, weight.shape[0]], dtype=weight.dtype, - device="cuda"), + device="hpu"), ) if generate_embeddings_tensor: lora.embeddings_tensor = torch.rand(5, generate_embeddings_tensor, dtype=weight.dtype, - device="cuda") + device="hpu") self.set_module_lora(module_name, lora) return lora @@ -53,8 +53,8 @@ def init_lora(self, module_name, rank=rank, lora_alpha=1, - lora_a=torch.rand([input_dim, rank], device="cuda"), - lora_b=torch.rand([rank, output_dim], device="cuda"), + lora_a=torch.rand([input_dim, rank], device="hpu"), + lora_b=torch.rand([rank, output_dim], device="hpu"), embeddings_tensor=embeddings_tensor, ) self.set_module_lora(module_name, lora) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index acf62c0375000..af127529032d7 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -156,16 +156,13 @@ def _apply_lora_packed_nslice( output = output.view(-1, output.shape[-1]) indices = indices.view(-1) offset_left = 0 - slice_size = output.shape[-1] // len(output_slices) for slice_idx in range(len(output_slices)): # add_lora_slice(output, x, lora_a_stacked[slice_idx], # lora_b_stacked[slice_idx], indices, 0, 1.0, offset_left, # output_slices[slice_idx]) - # offset_left += output_slices[slice_idx] - start = slice_idx * slice_size - end = min((slice_idx + 1)* slice_size, output.shape[-1]) - custom_bgmv(output[:, start:end], x, lora_a_stacked[slice_idx], + custom_bgmv(output[:, offset_left:offset_left+output_slices[slice_idx]], x, lora_a_stacked[slice_idx], lora_b_stacked[slice_idx], indices, 0, 1.0) + offset_left += output_slices[slice_idx] return output.view_as(org_output) diff --git a/vllm/worker/habana_worker.py b/vllm/worker/habana_worker.py index efe6361b0e231..092790ea2e364 100644 --- a/vllm/worker/habana_worker.py +++ b/vllm/worker/habana_worker.py @@ -162,7 +162,7 @@ def initialize_cache(self, num_gpu_blocks: int, self.cache_config.num_cpu_blocks = num_cpu_blocks self._init_cache_engine() - self._warm_up_model() + # self._warm_up_model() def _init_cache_engine(self): assert self.cache_config.num_gpu_blocks is not None
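
Background on the `_apply_lora_packed_nslice` fix: the previous code split the packed output into equal chunks via `slice_size = output.shape[-1] // len(output_slices)`, which is only correct when every slice has the same width. For QKV projections with grouped KV heads (e.g. `(8192, 1024, 1024)` in `QKV_TENSOR_SIZES`), the q and k/v slices differ, so the per-slice results were written into the wrong column ranges. The patched loop instead advances `offset_left` by each entry of `output_slices`, which is what the new `test_apply_lora_packed_3slice` exercises. The sketch below is a minimal CPU-only reference of that offset logic; the helper name and the plain matmuls used in place of the HPU `custom_bgmv` kernel are illustrative assumptions, not vLLM code.

import torch

def apply_lora_packed_nslice_ref(x, lora_a_stacked, lora_b_stacked,
                                 indices, output, output_slices):
    # Hypothetical reference for the patched offset logic (CPU, dense matmul).
    # Slices may have different widths, so walk cumulative offsets taken from
    # output_slices instead of assuming equal slice_size chunks.
    offset_left = 0
    for slice_idx, slice_width in enumerate(output_slices):
        lora_a = lora_a_stacked[slice_idx]  # [max_loras, 1, rank, in_dim]
        lora_b = lora_b_stacked[slice_idx]  # [max_loras, 1, slice_width, rank]
        for row, lora_idx in enumerate(indices.tolist()):
            if lora_idx < 0:
                continue  # index -1 means "no LoRA" for this token
            a = lora_a[lora_idx, 0]          # [rank, in_dim]
            b = lora_b[lora_idx, 0]          # [slice_width, rank]
            delta = x[row] @ a.t() @ b.t()   # [slice_width]
            output[row, offset_left:offset_left + slice_width] += delta
        offset_left += slice_width
    return output

if __name__ == "__main__":
    torch.manual_seed(0)
    n, rank, bs, max_loras = 16, 4, 3, 2
    output_slices = (8, 4, 4)           # unequal widths, like a packed QKV layer
    x = torch.rand(bs, n)
    lora_a_stacked = [torch.rand(max_loras, 1, rank, n) for _ in output_slices]
    lora_b_stacked = [torch.rand(max_loras, 1, w, rank) for w in output_slices]
    indices = torch.tensor([0, -1, 1])  # -1 disables LoRA for the second row
    out = torch.zeros(bs, sum(output_slices))
    apply_lora_packed_nslice_ref(x, lora_a_stacked, lora_b_stacked,
                                 indices, out, output_slices)
    print(out.shape)  # torch.Size([3, 16])

With the stacking layout used by the tests above (lora_a stored transposed as [rank, in_dim], lora_b pre-scaled and transposed as [out_dim, rank]), this reference should agree with the tests' `expected` tensor within the same `TOLERANCES`, and rows whose index is -1 stay zero, mirroring the `torch.full((len(input), ), -1)` checks.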