From 9276ccca92ed8703648c03ddc713990c168d6e96 Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Thu, 17 Oct 2024 15:23:49 +0200 Subject: [PATCH 01/17] Add WA for RuntimeError: "fill_cpu" not implemented for 'Float8_e4m3fn' (#402) --- vllm/worker/hpu_worker.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/vllm/worker/hpu_worker.py b/vllm/worker/hpu_worker.py index 752388e0d632f..8badc5f6bdb43 100644 --- a/vllm/worker/hpu_worker.py +++ b/vllm/worker/hpu_worker.py @@ -436,12 +436,14 @@ def _allocate_kv_cache( kv_cache_shape = self.attn_backend.get_kv_cache_shape( num_blocks, self.block_size, self.num_kv_heads, self.head_size) kv_cache: List[Tuple[torch.Tensor, torch.Tensor]] = [] + dtype = self.dtype + if device != 'hpu' and not is_fake_hpu() \ + and self.dtype == torch.float8_e4m3fn: + dtype = torch.uint8 for _ in range(self.num_attention_layers): - key_cache = torch.zeros(kv_cache_shape, - dtype=self.dtype, - device=device) + key_cache = torch.zeros(kv_cache_shape, dtype=dtype, device=device) value_cache = torch.zeros(kv_cache_shape, - dtype=self.dtype, + dtype=dtype, device=device) kv_layer = (key_cache, value_cache) kv_cache.append(kv_layer) From 07c98a5263967af6afd7eb58119c5c5504d9a9f2 Mon Sep 17 00:00:00 2001 From: Artur Fierka Date: Fri, 18 Oct 2024 09:13:23 +0200 Subject: [PATCH 02/17] Workaround for OOM during loading llama-405 (#396) Repeating missing code --- vllm/model_executor/models/llama.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 18ce8d7f7d164..a64edc94825f3 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -3,6 +3,7 @@ # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py # Copyright 2023 The vLLM team. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# Copyright 2024 Habana Labs, Ltd. an Intel Company # # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX # and OPT implementations in this library. It has been modified from its @@ -420,6 +421,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + if is_hpu: + torch.hpu.synchronize() # If this function is called, it should always initialize KV cache scale # factors (or else raise an exception). Thus, handled exceptions should From acde882c25f64d150c4ef0c60d27e010f78c8fd5 Mon Sep 17 00:00:00 2001 From: Karol Damaszke Date: Tue, 22 Oct 2024 10:24:52 +0200 Subject: [PATCH 03/17] Add HPU specific arguments to benchmark_throughput (#406) Modify `benchmark_throughput.py` to allow running with FP8 on HPU (KV cache dtype `fp8_inc`) and to use padding-aware scheduling. 
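For reference, a minimal sketch of how these new knobs map onto the engine arguments this fork exposes. The model name and the concrete values below are placeholders, not recommendations, and `fp8_inc` additionally assumes the INC/FP8 calibration setup described in README_GAUDI:

```python
# Illustrative only: HPU-specific engine arguments added by this patch,
# passed straight through to vllm.LLM. All values are placeholders.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-2-7b-hf",      # placeholder model
    kv_cache_dtype="fp8_inc",              # HPU FP8 KV cache (assumes INC calibration)
    weights_load_device="cpu",             # device used while loading checkpoints
    use_padding_aware_scheduling=True,     # account for bucket padding when scheduling
    max_num_seqs=256,                      # max requests in a single decode batch
    max_num_prefill_seqs=16,               # max requests in a single prefill batch
)
out = llm.generate(["Hello"], SamplingParams(max_tokens=32))
print(out[0].outputs[0].text)
```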
--- benchmarks/benchmark_throughput.py | 38 ++++++++++++++++++++++++++++-- 1 file changed, 36 insertions(+), 2 deletions(-) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index b7bc2a6402375..e1a359b871e71 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -90,6 +90,10 @@ def run_vllm( download_dir: Optional[str] = None, load_format: str = EngineArgs.load_format, disable_async_output_proc: bool = False, + weights_load_device: str = None, + use_padding_aware_scheduling: bool = False, + max_num_seqs: int = 256, + max_num_prefill_seqs: int = None, ) -> float: from vllm import LLM, SamplingParams llm = LLM( @@ -115,6 +119,10 @@ def run_vllm( num_scheduler_steps=num_scheduler_steps, use_v2_block_manager=use_v2_block_manager, disable_async_output_proc=disable_async_output_proc, + weights_load_device=weights_load_device, + use_padding_aware_scheduling=use_padding_aware_scheduling, + max_num_seqs=max_num_seqs, + max_num_prefill_seqs=max_num_prefill_seqs, ) # Add the requests to the engine. @@ -181,6 +189,10 @@ async def run_vllm_async( load_format: str = EngineArgs.load_format, disable_async_output_proc: bool = False, disable_frontend_multiprocessing: bool = False, + weights_load_device: str = None, + use_padding_aware_scheduling: bool = False, + max_num_seqs: int = 256, + max_num_prefill_seqs: int = None, ) -> float: from vllm import SamplingParams engine_args = AsyncEngineArgs( @@ -208,6 +220,9 @@ async def run_vllm_async( disable_async_output_proc=disable_async_output_proc, worker_use_ray=False, disable_log_requests=True, + weights_load_device=weights_load_device, + use_padding_aware_scheduling=use_padding_aware_scheduling, + max_num_prefill_seqs=max_num_prefill_seqs, ) async with build_async_engine_client_from_engine_args( @@ -342,7 +357,9 @@ def main(args: argparse.Namespace): args.max_num_batched_tokens, args.distributed_executor_backend, args.gpu_memory_utilization, args.num_scheduler_steps, args.use_v2_block_manager, args.download_dir, args.load_format, - args.disable_async_output_proc + args.disable_async_output_proc, args.weights_load_device, + args.use_padding_aware_scheduling, args.max_num_seqs, + args.max_num_prefill_seqs ] if args.async_engine: @@ -446,7 +463,7 @@ def main(args: argparse.Namespace): parser.add_argument( '--kv-cache-dtype', type=str, - choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'], + choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3', 'fp8_inc'], default="auto", help='Data type for kv cache storage. If "auto", will use model ' 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. 
' @@ -540,6 +557,23 @@ def main(args: argparse.Namespace): action='store_true', default=False, help="Disable decoupled async engine frontend.") + parser.add_argument("--weights-load-device", + type=str, + default=None, + choices=DEVICE_OPTIONS, + help='Device on which weights are loaded.') + parser.add_argument("--use-padding-aware-scheduling", + action='store_true', + default=False, + help="Enable padding-aware scheduling.") + parser.add_argument("--max-num-seqs", + type=int, + default=256, + help="Maximum number of requests for single decode.") + parser.add_argument("--max-num-prefill-seqs", + type=int, + default=None, + help="Maximum number of requests for single prefill.") args = parser.parse_args() if args.tokenizer is None: args.tokenizer = args.model From 8c43ff1fb2c5e1c62dfb1771f6d7f5665958ff85 Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Tue, 22 Oct 2024 10:25:43 +0200 Subject: [PATCH 04/17] Add forward_hpu to RotaryEmbedding, remove custom module (#404) This PR removes the usage of custom HPU RotaryEmbedding modules, and adds a forward_hpu method to existing RotaryEmbedding, for reusing multiple derived implementations without the need of adding them to HPU extension. Mark_steps should not be needed within the test, but for whatever reason, if they are not there, PT bridge crashes. To be investigated later on. It does not affect actual model execution in any way I could test/observe. --- tests/kernels/test_pos_encoding.py | 10 ++ .../model_executor/layers/rotary_embedding.py | 94 +++++++++++++------ 2 files changed, 73 insertions(+), 31 deletions(-) diff --git a/tests/kernels/test_pos_encoding.py b/tests/kernels/test_pos_encoding.py index ba9d2d4389b21..6ca3a645c7771 100644 --- a/tests/kernels/test_pos_encoding.py +++ b/tests/kernels/test_pos_encoding.py @@ -5,6 +5,7 @@ import torch from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.platforms import current_platform from vllm.utils import seed_everything from .allclose_default import get_default_atol, get_default_rtol @@ -20,6 +21,9 @@ CUDA_DEVICES = [ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) ] +if current_platform.is_hpu(): + import habana_frameworks.torch as htorch + CUDA_DEVICES = ['hpu'] @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) @@ -65,6 +69,8 @@ def test_rotary_embedding( # NOTE(woosuk): The reference implementation should be executed first # because the custom kernel is in-place. ref_query, ref_key = rope.forward_native(positions, query, key) + if current_platform.is_hpu(): + htorch.core.mark_step() out_query, out_key = rope.forward(positions, query, key) # Compare the results. torch.testing.assert_close(out_query, @@ -120,6 +126,8 @@ def test_batched_rotary_embedding( # NOTE(woosuk): The reference implementation should be executed first # because the custom kernel is in-place. ref_query, ref_key = rope.forward_native(positions, query, key) + if current_platform.is_hpu(): + htorch.core.mark_step() out_query, out_key = rope.forward(positions, query, key, @@ -193,6 +201,8 @@ def test_batched_rotary_embedding_multi_lora( # because the custom kernel is in-place. ref_query, ref_key = rope.forward_native(positions, query, key, query_offsets) + if current_platform.is_hpu(): + htorch.core.mark_step() out_query, out_key = rope.forward(positions, query, key, query_offsets.flatten()) # Compare the results. 
diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 85cd700c978ea..10626d53338e3 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -28,7 +28,6 @@ import torch.nn as nn from vllm.model_executor.custom_op import CustomOp -from vllm.platforms import current_platform def _rotate_neox(x: torch.Tensor) -> torch.Tensor: @@ -195,6 +194,61 @@ def forward_xpu( self.cos_sin_cache, self.is_neox_style) return query, key + def forward_hpu( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + from habana_frameworks.torch.hpex.kernels import ( + RotaryPosEmbeddingMode, apply_rotary_pos_emb) + positions = positions.flatten() + if offsets is not None: + positions = positions + offsets + num_tokens = positions.shape[0] + cos_sin = self.cos_sin_cache.index_select(0, positions).view( + num_tokens, 1, -1) + cos, sin = cos_sin.chunk(2, dim=-1) + # HPU RoPE kernel requires hidden dimension for cos and sin to be equal + # to query hidden dimension, so the original tensors need to be + # expanded + # GPT-NeoX kernel requires position_ids = None, offset, mode = BLOCKWISE + # and expansion of cos/sin tensors via concatenation + # GPT-J kernel requires position_ids = None, offset = 0, mode = PAIRWISE + # and expansion of cos/sin tensors via repeat_interleave + rope_mode: RotaryPosEmbeddingMode + if self.is_neox_style: + rope_mode = RotaryPosEmbeddingMode.BLOCKWISE + cos = torch.cat((cos, cos), dim=-1) + sin = torch.cat((sin, sin), dim=-1) + else: + rope_mode = RotaryPosEmbeddingMode.PAIRWISE + sin = torch.repeat_interleave(sin, + 2, + dim=-1, + output_size=cos_sin.shape[-1]) + cos = torch.repeat_interleave(cos, + 2, + dim=-1, + output_size=cos_sin.shape[-1]) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + query_rot = query[..., :self.rotary_dim] + query_pass = query[..., self.rotary_dim:] + query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, + rope_mode) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., :self.rotary_dim] + key_pass = key[..., self.rotary_dim:] + key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + def extra_repr(self) -> str: s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}" s += f", max_position_embeddings={self.max_position_embeddings}" @@ -918,17 +972,8 @@ def get_rope( return _ROPE_DICT[key] if rope_scaling is None: - if current_platform.is_hpu(): - from vllm_hpu_extension.rotary_embed import HpuRotaryEmbedding - rotary_emb = HpuRotaryEmbedding(head_size, - rotary_dim, - max_position, - base, - is_neox_style, - RoPEFallback=RotaryEmbedding) - else: - rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, - base, is_neox_style, dtype) + rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base, + is_neox_style, dtype) else: scaling_type = rope_scaling[ "type"] if "type" in rope_scaling else rope_scaling["rope_type"] @@ -941,25 +986,12 @@ def get_rope( high_freq_factor = rope_scaling["high_freq_factor"] original_max_position = rope_scaling[ "original_max_position_embeddings"] - if current_platform.is_hpu(): - from vllm_hpu_extension.rotary_embed import ( - 
HpuLlama3RotaryEmbedding) - rotary_emb = HpuLlama3RotaryEmbedding( - head_size, - rotary_dim, - max_position, - base, - is_neox_style, - scaling_factor, - low_freq_factor, - high_freq_factor, - original_max_position, - RoPEFallback=Llama3RotaryEmbedding) - else: - rotary_emb = Llama3RotaryEmbedding( - head_size, rotary_dim, max_position, base, is_neox_style, - dtype, scaling_factor, low_freq_factor, high_freq_factor, - original_max_position) + rotary_emb = Llama3RotaryEmbedding(head_size, rotary_dim, + max_position, base, + is_neox_style, dtype, + scaling_factor, low_freq_factor, + high_freq_factor, + original_max_position) elif scaling_type == "linear": rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim, max_position, base, From aecd6677f07ab493caf2ca78174e4b40c65aa163 Mon Sep 17 00:00:00 2001 From: Kamil Kaczor Date: Tue, 22 Oct 2024 10:29:33 +0200 Subject: [PATCH 05/17] Remove if blocks smaller than bs in generate_decode_buckets (#412) With this check while running decode_block_bucket_min=128 and bs>128 it will skip buckets smaller than bs. Then during the run buckets that got skipped can be used by vllm and are being warmed-up which is causing perf drop & they are not run as hpu graphs. This change is removing said check. --- vllm/worker/hpu_model_runner.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 785337478468f..888a9a9da942c 100644 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -202,8 +202,6 @@ def generate_decode_buckets(bs_bucket_config, blocks_bucket_config, last_bucket = round_up(max_blocks, bstep) for bs in bs_buckets: for blocks in block_buckets: - if blocks < bs: - continue if blocks > last_bucket: break buckets.append((bs, blocks)) From 0cf52619196d7f8f277aafe4af48624291b42d12 Mon Sep 17 00:00:00 2001 From: Karol Damaszke Date: Tue, 22 Oct 2024 10:43:03 +0200 Subject: [PATCH 06/17] Remove CPU sync before Sampler (#414) Currently before each Sampler call we have a CPU sync, which causes a host gap: image This PR is removing that sync, so the host gap is no longer visible: image NOTE: class `ApplyToppTopkScalar` still has some CPU syncs inside. It means that the biggest gain will be observed in the scenario without `top_p` or `top_k` parameters. I think it is worth to investigate if we can remove the syncs from this function too. 
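As a standalone illustration of the pattern (a sketch, not the sampler code itself): reading a device tensor back with `.item()` blocks the host until the accelerator catches up, whereas keeping both the comparison and the selection as tensor ops lets the host keep queuing work; the trade-off, as noted above, is that `torch.where` evaluates both candidate results.

```python
# Standalone sketch of the host-sync pattern; not the actual sampler code.
import torch

top_ps = torch.full((8,), 0.9)        # per-request top-p values (placeholder)
opt_path = torch.randn(8, 32000)      # stand-in for the scalar-optimized result
generic_path = torch.randn(8, 32000)  # stand-in for the generic top-k/top-p result

# Before: .item() copies the flag to the host and waits for the device,
# which shows up as the host gap described above.
scalar_p = bool(torch.all(top_ps == top_ps[0]).item())
logits = opt_path if scalar_p else generic_path

# After: the flag stays on the device; both candidate results are computed
# and torch.where selects between them without a host synchronization.
scalar_p_dev = torch.all(top_ps == top_ps[0])
logits = torch.where(scalar_p_dev, opt_path, generic_path)
```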
--- vllm/model_executor/layers/sampler.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 58c4940c12fb2..74c0416e4b379 100755 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -200,13 +200,13 @@ def _init_sampling_tensors( self._do_penalties = do_penalties self._do_top_p_top_k = do_top_p_top_k self._do_min_p = do_min_p - self._top_p_scalar = sampling_tensors.top_ps[0].item() - self._top_k_scalar = sampling_tensors.top_ks[0].item() + self._top_p_scalar = sampling_tensors.top_ps[0] + self._top_k_scalar = sampling_tensors.top_ks[0] scalar_p = torch.all(sampling_tensors.top_ps == self._top_p_scalar) scalar_k = torch.all(sampling_tensors.top_ks == self._top_k_scalar) - self._scalar_p_and_k = (scalar_p and scalar_k).item() - if self._scalar_p_and_k and self._do_top_p_top_k: - self._apply_top_k_top_p_opt = ApplyToppTopkScalar(5) + self._scalar_p_and_k = torch.logical_and(scalar_p, scalar_k) + + self._apply_top_k_top_p_opt = ApplyToppTopkScalar(5) def forward( self, @@ -266,13 +266,13 @@ def forward( logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1)) if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None: - if self._scalar_p_and_k: - logits = self._apply_top_k_top_p_opt(logits, - self._top_p_scalar, - self._top_k_scalar) - else: - logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps, - sampling_tensors.top_ks) + # If we have a scalar p and k, we can use the optimized version. + logits = torch.where( + self._scalar_p_and_k, + self._apply_top_k_top_p_opt(logits, self._top_p_scalar, + self._top_k_scalar), + _apply_top_k_top_p(logits, sampling_tensors.top_ps, + sampling_tensors.top_ks)) if do_min_p: logits = _apply_min_p(logits, sampling_tensors.min_ps) From 3af4b6ce54ccdfc91516b335c5331045d78c99a2 Mon Sep 17 00:00:00 2001 From: Sanju C Sudhakaran Date: Tue, 22 Oct 2024 19:03:58 +0530 Subject: [PATCH 07/17] Remove redundant set_active_loras call during warmup (#413) CUDA uses `capture` for warmup runs and `execute_model` for actual runs. During each phase they call `set_active_loras` only once. HPU uses `execute_model` for both warmup and actual runs. Since `execute_model` already takes care of `set_active_loras` internally, the redundant call can be removed. This special handling is redundant and incorrect, as it causes out-of-bound slicing in decode phase reported in https://github.com/HabanaAI/vllm-fork/issues/405. This PR removes special handling of `set_active_loras` function call from warmup runs and resolves the issue in https://github.com/HabanaAI/vllm-fork/issues/405. 
--- vllm/worker/hpu_model_runner.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 888a9a9da942c..f2875194e93a0 100644 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -1354,12 +1354,6 @@ def warmup_scenario(self, ] self.profiler.start('internal', scenario_name) times = 3 if use_graphs or is_pt_profiler_run else 1 - if self.lora_config and not is_lora_profile_run: - lora_mapping = LoRAMapping( - **dict(index_mapping=[0] * batch_size * seq_len, - prompt_mapping=[0] * batch_size * seq_len, - is_prefill=is_prompt)) - self.set_active_loras(set(), lora_mapping) if is_prompt: seqs = [ self.create_dummy_seq_group_metadata( From 892c09026833d85dd0fc408ee7d5e3c1b461394c Mon Sep 17 00:00:00 2001 From: Himangshu Lahkar <49579433+hlahkar@users.noreply.github.com> Date: Wed, 23 Oct 2024 15:03:53 +0530 Subject: [PATCH 08/17] Change profile Run batch based on max_seq_len (#415) Changes the profile_run batches based on the max sequence length. This avoids padding during prepare_prompt; thus avoiding breaking constraints based on batch_size * seq_len <= max_num_batch_tokens. Current logic for profile_run max_batch_size takes precedence. e.g. - > max_batch_size = 256, max_num_batch_tokens = 2048, block_size = 128, max_seq_len = 1024 with current logic max_seq_len is updated as 8; however in **prepare_prompt** seq_len is padded to 128, thus getting batch_size * seq_len as 256 * 128 > max_num_batch_tokens; thus violating the above mentioned constraint with the updated logic, we calculate max_batch_size as 2, this avoids the padding at **prepare_prompt**, thus keeping the constraints in place. Fixes: https://github.com/HabanaAI/vllm-fork/issues/405 --- vllm/worker/hpu_model_runner.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index f2875194e93a0..e8e76f6ab67ef 100644 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -1306,10 +1306,8 @@ def create_dummy_seq_group_metadata(self, def profile_run(self) -> None: num_layers = self.model_config.get_num_layers(self.parallel_config) kv_caches = [None] * num_layers - max_batch_size = self.bucketing_global_state.prompt_bs_bucket_cfg[-1] - max_seq_len = min( - self.bucketing_global_state.prompt_seq_bucket_cfg[-1], - self.max_num_batched_tokens // max_batch_size) + max_seq_len = self.bucketing_global_state.prompt_seq_bucket_cfg[-1] + max_batch_size = self.max_num_batched_tokens // max_seq_len self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches, False, True) From 7f58ad1583a2d11a07705ba9d88bda54c8f19843 Mon Sep 17 00:00:00 2001 From: Michal Adamczyk Date: Wed, 23 Oct 2024 15:04:08 +0200 Subject: [PATCH 09/17] Add support for various softmax normalization options (#420) Supporting PR for https://github.com/HabanaAI/vllm-hpu-extension/pull/14 --- requirements-hpu.txt | 2 +- vllm/attention/backends/hpu_attn.py | 1 + vllm/attention/ops/hpu_paged_attn.py | 1 + vllm/worker/hpu_model_runner.py | 9 ++++++++- 4 files changed, 11 insertions(+), 2 deletions(-) diff --git a/requirements-hpu.txt b/requirements-hpu.txt index 1a583974be151..7cefa4e631fa8 100644 --- a/requirements-hpu.txt +++ b/requirements-hpu.txt @@ -8,4 +8,4 @@ pandas tabulate setuptools>=61 setuptools-scm>=8 -vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@fd7f2e6 +vllm-hpu-extension @ 
git+https://github.com/HabanaAI/vllm-hpu-extension.git@c2801bb diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index a8f4b09b67274..f4674cedf01ce 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -223,6 +223,7 @@ def forward( block_mapping=attn_metadata.block_mapping, block_bias=attn_metadata.attn_bias, block_scales=attn_metadata.block_scales, + block_groups=attn_metadata.block_groups, scale=self.scale, matmul_qk_op=self.matmul_qk, matmul_av_op=self.matmul_av, diff --git a/vllm/attention/ops/hpu_paged_attn.py b/vllm/attention/ops/hpu_paged_attn.py index 4c0fb2a628361..603d3959377c4 100644 --- a/vllm/attention/ops/hpu_paged_attn.py +++ b/vllm/attention/ops/hpu_paged_attn.py @@ -21,6 +21,7 @@ class HPUPagedAttentionMetadata: block_indices: Optional[torch.Tensor] block_offsets: Optional[torch.Tensor] block_scales: Optional[torch.Tensor] + block_groups: Optional[torch.Tensor] class HPUPagedAttention: diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index e8e76f6ab67ef..382a0abb21240 100644 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -907,6 +907,7 @@ def _prepare_prompt( block_indices=block_indices, block_offsets=block_offsets, block_scales=None, + block_groups=None, attn_bias=None, seq_lens_tensor=seq_lens_tensor, num_prefills=real_num_seqs, @@ -1028,6 +1029,8 @@ def _prepare_decode( len(block_list), self.bucketing_global_state.decode_block_bucket_cfg) block_list = pad_list(block_list, block_bucket_size, _PAD_BLOCK_ID) + block_groups = pad_list(block_mapping, block_bucket_size, + len(block_tables)) block_mapping = pad_list(block_mapping, block_bucket_size, -1) block_usage = pad_list(block_usage, block_bucket_size, 1) block_scales = pad_list(block_scales, block_bucket_size, 0.0) @@ -1038,6 +1041,9 @@ def _prepare_decode( block_mapping = torch.tensor(block_mapping, dtype=torch.long, device=self.device) + block_groups = torch.tensor(block_groups, + dtype=torch.long, + device=self.device) block_usage = torch.tensor(block_usage, dtype=self.model_config.dtype, device=self.device) @@ -1060,6 +1066,7 @@ def _prepare_decode( block_indices=block_indices, block_offsets=block_offsets, block_scales=block_scales, + block_groups=block_groups, attn_bias=None, seq_lens_tensor=None, num_prefills=0, @@ -1271,7 +1278,7 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object: attention_metadata = subtuple(metadata, 'TrimmedAttentionMetadata', [ 'attn_bias', 'seq_lens_tensor', 'block_list', 'block_mapping', 'block_usage', 'slot_mapping', 'is_prompt', 'block_indices', - 'block_offsets', 'block_scales' + 'block_offsets', 'block_scales', 'block_groups' ]) return attention_metadata From f603353e2057808f46c86395334ef507fd2bb351 Mon Sep 17 00:00:00 2001 From: Artur Fierka Date: Fri, 25 Oct 2024 08:46:30 +0200 Subject: [PATCH 10/17] Update README_GAUDI about fp8 calibration procedure (#423) --- README_GAUDI.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README_GAUDI.md b/README_GAUDI.md index b9c744bd9e23f..6dd7837116d52 100644 --- a/README_GAUDI.md +++ b/README_GAUDI.md @@ -282,6 +282,10 @@ Additionally, there are HPU PyTorch Bridge environment variables impacting vLLM - `PT_HPU_LAZY_MODE`: if `0`, PyTorch Eager backend for Gaudi will be used, if `1` PyTorch Lazy backend for Gaudi will be used, `1` is default - `PT_HPU_ENABLE_LAZY_COLLECTIVES`: required to be `true` for tensor parallel inference with HPU Graphs +# Quantization and FP8 model 
calibration process + +The FP8 model calibration procedure has been described as a part of [vllm-hpu-extention](https://github.com/HabanaAI/vllm-hpu-extension/tree/main/calibration/README.md) package. + # Troubleshooting: Tweaking HPU Graphs If you experience device out-of-memory issues or want to attempt inference at higher batch sizes, try tweaking HPU Graphs by following the below: From a5136ec1fd78c2fb640cd89a48b479472bd5666a Mon Sep 17 00:00:00 2001 From: Michal Adamczyk Date: Fri, 25 Oct 2024 09:58:38 +0200 Subject: [PATCH 11/17] Set vllm-hpu-extension to 341a77f (#428) --- requirements-hpu.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-hpu.txt b/requirements-hpu.txt index 7cefa4e631fa8..20f4dc74a3955 100644 --- a/requirements-hpu.txt +++ b/requirements-hpu.txt @@ -8,4 +8,4 @@ pandas tabulate setuptools>=61 setuptools-scm>=8 -vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@c2801bb +vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@341a77f From 5b7f685c3416d735c8866ad4f3282fe86b978739 Mon Sep 17 00:00:00 2001 From: Marceli Fylcek Date: Fri, 25 Oct 2024 14:35:13 +0200 Subject: [PATCH 12/17] Contiguous PA (#424) Contiguous cache fetching to avoid using costly gather operation. Requires changes in vllm-hpu-extension (https://github.com/HabanaAI/vllm-hpu-extension/pull/17) to work. Introduces redundant calculations in decoding phase. In all tested cases improves performance over the entire run (5-12%). For even better performance cache defragmentation is required. Only compatible with v2-block-manager. --- requirements-hpu.txt | 2 +- vllm/worker/hpu_model_runner.py | 54 +++++++++++++++++---------------- 2 files changed, 29 insertions(+), 27 deletions(-) diff --git a/requirements-hpu.txt b/requirements-hpu.txt index 20f4dc74a3955..4719639da6188 100644 --- a/requirements-hpu.txt +++ b/requirements-hpu.txt @@ -8,4 +8,4 @@ pandas tabulate setuptools>=61 setuptools-scm>=8 -vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@341a77f +vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@6cb6e19 diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 382a0abb21240..4be0dc1a1abd8 100644 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -199,10 +199,11 @@ def generate_decode_buckets(bs_bucket_config, blocks_bucket_config, bs_buckets = warmup_range(bs_bucket_config) block_buckets = warmup_range(blocks_bucket_config) bmin, bstep, bmax = blocks_bucket_config - last_bucket = round_up(max_blocks, bstep) + last_bucket = max_blocks for bs in bs_buckets: for blocks in block_buckets: if blocks > last_bucket: + buckets.append((bs, last_bucket)) break buckets.append((bs, blocks)) return list(sorted(buckets, key=lambda b: (b[0] * b[1], b[1], b[0]))) @@ -1002,39 +1003,40 @@ def _prepare_decode( num_decode_tokens = sum(seq_lens) - blocks_used = [len(bt) for bt in block_tables if bt] - block_list = [] - block_scales = [] + block_list = list(itertools.chain(*block_tables)) + + max_idx = max(block_list) + max_blocks = max(max_idx + 1, len(block_list)) + block_bucket_size = find_bucket( + max_blocks, self.bucketing_global_state.decode_block_bucket_cfg) + block_bucket_size = min(block_bucket_size, + self.cache_config.num_gpu_blocks) + + block_mapping: List[Union[None, int]] = [None] * block_bucket_size + block_usage: List[Union[None, int]] = [None] * block_bucket_size + block_scales: List[Union[None, float]] 
= [None] * block_bucket_size + for i, bt in enumerate(block_tables): - block_list.extend(bt) - blocks_in_group = len(bt) - if blocks_in_group > 0: + if bt: + blocks_in_group = len(bt) scale = 1.0 / blocks_in_group - block_scales.extend([scale] * blocks_in_group) + for b in bt: + if block_mapping[b] is None: + block_mapping[b] = i + block_usage[b] = self.block_size + block_scales[b] = scale - block_mapping_nested: List[List[int]] = [ - [i] * b_u for i, b_u in enumerate(blocks_used) - ] - block_mapping: List[int] = list( - itertools.chain.from_iterable(block_mapping_nested)) + block_mapping = [b if b is not None else -1 for b in block_mapping] + block_scales = [b if b is not None else 0.0 for b in block_scales] - last_block = [ - sl % self.block_size + 1 for sl in itertools.chain(*slot_mapping) - ] - block_usage = [[self.block_size] * (b_u - 1) + [lb] - for b_u, lb in zip(blocks_used, last_block)] - block_usage = list(itertools.chain(*block_usage)) + for bt, sl in zip(block_tables, slot_mapping): + if bt: + block_usage[bt[-1]] = sl[-1] % self.block_size + 1 + block_usage = [u if u is not None else 1 for u in block_usage] - block_bucket_size = find_bucket( - len(block_list), - self.bucketing_global_state.decode_block_bucket_cfg) block_list = pad_list(block_list, block_bucket_size, _PAD_BLOCK_ID) block_groups = pad_list(block_mapping, block_bucket_size, len(block_tables)) - block_mapping = pad_list(block_mapping, block_bucket_size, -1) - block_usage = pad_list(block_usage, block_bucket_size, 1) - block_scales = pad_list(block_scales, block_bucket_size, 0.0) - block_list = torch.tensor(block_list, dtype=torch.int, device=self.device) From e3ae2ebffcb233a67c63ab1fe9acab3dad1d53dc Mon Sep 17 00:00:00 2001 From: Michal Adamczyk Date: Fri, 25 Oct 2024 14:49:54 +0200 Subject: [PATCH 13/17] Revert "Contiguous PA" (#432) Reverts HabanaAI/vllm-fork#424 --- requirements-hpu.txt | 2 +- vllm/worker/hpu_model_runner.py | 54 ++++++++++++++++----------------- 2 files changed, 27 insertions(+), 29 deletions(-) diff --git a/requirements-hpu.txt b/requirements-hpu.txt index 4719639da6188..20f4dc74a3955 100644 --- a/requirements-hpu.txt +++ b/requirements-hpu.txt @@ -8,4 +8,4 @@ pandas tabulate setuptools>=61 setuptools-scm>=8 -vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@6cb6e19 +vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@341a77f diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 4be0dc1a1abd8..382a0abb21240 100644 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -199,11 +199,10 @@ def generate_decode_buckets(bs_bucket_config, blocks_bucket_config, bs_buckets = warmup_range(bs_bucket_config) block_buckets = warmup_range(blocks_bucket_config) bmin, bstep, bmax = blocks_bucket_config - last_bucket = max_blocks + last_bucket = round_up(max_blocks, bstep) for bs in bs_buckets: for blocks in block_buckets: if blocks > last_bucket: - buckets.append((bs, last_bucket)) break buckets.append((bs, blocks)) return list(sorted(buckets, key=lambda b: (b[0] * b[1], b[1], b[0]))) @@ -1003,40 +1002,39 @@ def _prepare_decode( num_decode_tokens = sum(seq_lens) - block_list = list(itertools.chain(*block_tables)) - - max_idx = max(block_list) - max_blocks = max(max_idx + 1, len(block_list)) - block_bucket_size = find_bucket( - max_blocks, self.bucketing_global_state.decode_block_bucket_cfg) - block_bucket_size = min(block_bucket_size, - self.cache_config.num_gpu_blocks) - - block_mapping: 
List[Union[None, int]] = [None] * block_bucket_size - block_usage: List[Union[None, int]] = [None] * block_bucket_size - block_scales: List[Union[None, float]] = [None] * block_bucket_size - + blocks_used = [len(bt) for bt in block_tables if bt] + block_list = [] + block_scales = [] for i, bt in enumerate(block_tables): - if bt: - blocks_in_group = len(bt) + block_list.extend(bt) + blocks_in_group = len(bt) + if blocks_in_group > 0: scale = 1.0 / blocks_in_group - for b in bt: - if block_mapping[b] is None: - block_mapping[b] = i - block_usage[b] = self.block_size - block_scales[b] = scale + block_scales.extend([scale] * blocks_in_group) - block_mapping = [b if b is not None else -1 for b in block_mapping] - block_scales = [b if b is not None else 0.0 for b in block_scales] + block_mapping_nested: List[List[int]] = [ + [i] * b_u for i, b_u in enumerate(blocks_used) + ] + block_mapping: List[int] = list( + itertools.chain.from_iterable(block_mapping_nested)) - for bt, sl in zip(block_tables, slot_mapping): - if bt: - block_usage[bt[-1]] = sl[-1] % self.block_size + 1 - block_usage = [u if u is not None else 1 for u in block_usage] + last_block = [ + sl % self.block_size + 1 for sl in itertools.chain(*slot_mapping) + ] + block_usage = [[self.block_size] * (b_u - 1) + [lb] + for b_u, lb in zip(blocks_used, last_block)] + block_usage = list(itertools.chain(*block_usage)) + block_bucket_size = find_bucket( + len(block_list), + self.bucketing_global_state.decode_block_bucket_cfg) block_list = pad_list(block_list, block_bucket_size, _PAD_BLOCK_ID) block_groups = pad_list(block_mapping, block_bucket_size, len(block_tables)) + block_mapping = pad_list(block_mapping, block_bucket_size, -1) + block_usage = pad_list(block_usage, block_bucket_size, 1) + block_scales = pad_list(block_scales, block_bucket_size, 0.0) + block_list = torch.tensor(block_list, dtype=torch.int, device=self.device) From 93609a2a54aba5088f6fc94c49edad90ff4c3aa0 Mon Sep 17 00:00:00 2001 From: Tomasz Pawlowski Date: Fri, 25 Oct 2024 15:30:11 +0200 Subject: [PATCH 14/17] Enable Dynamic MoE for Mixtral on 1.19.0 (#425) Move Dynamic MoE implementation to habana_main. It was previously implemented for 1.18, but had to be modified as ops have been moved to [github.com/HabanaAI/vllm-hpu-extension](https://github.com/HabanaAI/vllm-hpu-extension). Works with bf16, uses static (legacy) mode when running with quantization. Related PRs: - https://github.com/HabanaAI/vllm-fork/pull/303 - https://github.com/HabanaAI/vllm-hpu-extension/pull/13 ---
--- requirements-hpu.txt | 2 +- vllm/model_executor/layers/fused_moe/layer.py | 22 +++++++++++++------ 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/requirements-hpu.txt b/requirements-hpu.txt index 20f4dc74a3955..4019950062efe 100644 --- a/requirements-hpu.txt +++ b/requirements-hpu.txt @@ -8,4 +8,4 @@ pandas tabulate setuptools>=61 setuptools-scm>=8 -vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@341a77f +vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@341a77f \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 457450cda2ce6..8f6bdaa7ab44a 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -226,9 +226,13 @@ def __init__( self.num_expert_group = num_expert_group self.topk_group = topk_group self.custom_routing_function = custom_routing_function - if current_platform.is_hpu(): - from vllm_hpu_extension.ops import StaticFusedMOE - self.hpu_static_fused_moe = StaticFusedMOE(self.num_experts) + if is_hpu: + from vllm_hpu_extension.ops import DynamicFusedMOE, StaticFusedMOE + + from vllm.model_executor.layers.quantization.inc import INCConfig + selected_fused_moe = (StaticFusedMOE if isinstance( + quant_config, INCConfig) else DynamicFusedMOE) + self.hpu_static_fused_moe = selected_fused_moe(self.num_experts) if quant_config is None: self.quant_method: Optional[QuantizeMethodBase] = ( @@ -321,8 +325,10 @@ def _load_w13(self, expert_data.copy_(loaded_weight) if is_hpu: - self.hpu_static_fused_moe.w13_list[expert_id].set_weight( - orig_exp_data) + from vllm_hpu_extension.ops import StaticFusedMOE + if isinstance(self.hpu_static_fused_moe, StaticFusedMOE): + self.hpu_static_fused_moe.w13_list[expert_id].set_weight( + orig_exp_data) def _load_w2(self, expert_data: torch.Tensor, @@ -341,8 +347,10 @@ def _load_w2(self, # w2, down_proj: Load into only logical weight of w2. 
expert_data.copy_(loaded_weight) if is_hpu: - self.hpu_static_fused_moe.w2_list[expert_id].set_weight( - expert_data) + from vllm_hpu_extension.ops import StaticFusedMOE + if isinstance(self.hpu_static_fused_moe, StaticFusedMOE): + self.hpu_static_fused_moe.w2_list[expert_id].set_weight( + expert_data) def _load_single_value(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int): From 3a55e77bc7d99b2ccbe3eb738fa9e8648dbf7f4e Mon Sep 17 00:00:00 2001 From: Sanju C Sudhakaran Date: Mon, 28 Oct 2024 14:16:58 +0530 Subject: [PATCH 15/17] Support long contexts with LoRA (#418) This PR enables long-contexts support with LoRA --- tests/lora/test_long_context_hpu.py | 304 ++++++++++++++++++++++++++++ vllm/lora/punica.py | 23 ++- vllm/worker/hpu_model_runner.py | 29 ++- 3 files changed, 346 insertions(+), 10 deletions(-) create mode 100644 tests/lora/test_long_context_hpu.py diff --git a/tests/lora/test_long_context_hpu.py b/tests/lora/test_long_context_hpu.py new file mode 100644 index 0000000000000..33250edde00d3 --- /dev/null +++ b/tests/lora/test_long_context_hpu.py @@ -0,0 +1,304 @@ +import ast +from typing import List, Optional, Tuple + +import numpy as np +import pytest + +import vllm +from vllm import SamplingParams +from vllm.lora.layers import LinearScalingRotaryEmbeddingWithLora +from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.rotary_embedding import ( + LinearScalingRotaryEmbedding) + +from .data.long_context_test_data import prompts_and_responses + +context_len_to_scaling_factor = { + "16k": 4, + "32k": 8, +} + +# We use the same sampling params for all requests +sampling_params = SamplingParams( + temperature=0, + max_tokens=100, +) + + +def _create_lora_request(lora_id, long_context_infos): + context_len = long_context_infos[lora_id]["context_length"] + scaling_factor = context_len_to_scaling_factor[context_len] + return LoRARequest(f'{context_len}_{lora_id}', lora_id, + long_context_infos[lora_id]["lora"], None, + 4096 * scaling_factor) + + +def evaluate_json_response(model_response, golden_response): + """Evaluates the model response against the golden response. + + Returns a score between 0 and 1, where 1 is a perfect match and 0 is no + match. The score quantifies how well the model is able to extract the + golden JSON from the long context. + """ + try: + model_response = ast.literal_eval(model_response) + except Exception as e: + raise ValueError( + f"Model response is not a valid JSON. Expected {golden_response}, " + f"got {model_response}") from e + + # Normally, we would flatten the dictionary and compare the values, but in + # this case, we know that the dictionary is only 2 levels deep + positive_values = 0 + total_values = 0 + # We look at all the attributes of the person that we are extracting a + # biography of and copmare them to the golden response + for person_attribute, person_attribute_value in golden_response.items(): + if person_attribute in model_response: + if isinstance(person_attribute_value, dict): + for (sub_attribute, + sub_attribute_value) in person_attribute_value.items(): + total_values += 1 + if sub_attribute in model_response[ + person_attribute] and model_response[ + person_attribute][ + sub_attribute] == sub_attribute_value: + positive_values += 1 + else: + total_values += 1 + if model_response[person_attribute] == person_attribute_value: + positive_values += 1 + else: + # We count a missing sub-dict as a single missed value. 
+ total_values += 1 + + # Return a score between 0 and 1 + return positive_values / total_values + + +def generate( + llm: vllm.LLM, + inputs: Tuple[str, SamplingParams, Optional[LoRARequest]], +): + prompts, sampling_param, lora_request = inputs + outputs = llm.generate(prompts, sampling_param, lora_request=lora_request) + return outputs[0].outputs[0].text.strip() + + +def batched_generate( + llm: vllm.LLM, + inputs: List[Tuple[str, SamplingParams, Optional[LoRARequest]]], +): + for input in inputs: + prompt, sampling_param, lora_req = input + # Add requests to the engine and run the engine + llm._validate_and_add_requests(prompt, + sampling_param, + lora_request=lora_req, + prompt_adapter_request=None) + + outputs = llm._run_engine(use_tqdm=True) + return [outputs[i].outputs[0].text.strip() for i in range(len(outputs))] + + +@pytest.fixture(scope="module") +def lora_llm(long_context_infos): + scaling_factors = [ + context_len_to_scaling_factor[info["context_length"]] + for info in long_context_infos.values() + ] + + llm = vllm.LLM( + "meta-llama/Llama-2-13b-chat-hf", + enable_lora=True, + max_num_seqs=16, + max_loras=2, + long_lora_scaling_factors=tuple(scaling_factors), + max_num_batched_tokens=4096 * 8, + tensor_parallel_size=1, + enforce_eager=True, # TODO Remove after SW-205153 is fixed + dtype="bfloat16", + disable_async_output_proc=True, # TODO Remove after SW-204469 is fixed. + distributed_executor_backend="mp") + yield llm + del llm + + +def test_rotary_emb_replaced(dist_init): + """Verify rotary emb in all the layers are replaced""" + from vllm.engine.arg_utils import EngineArgs + from vllm.platforms import current_platform + if current_platform.is_hpu(): + from vllm.worker.hpu_model_runner import HPUModelRunner as ModelRunner + else: + from vllm.worker.model_runner import ModelRunner + engine_args = EngineArgs("meta-llama/Llama-2-7b-hf", + long_lora_scaling_factors=(4.0, ), + enable_lora=True) + engine_config = engine_args.create_engine_config() + model_runner = ModelRunner( + model_config=engine_config.model_config, + parallel_config=engine_config.parallel_config, + scheduler_config=engine_config.scheduler_config, + device_config=engine_config.device_config, + cache_config=engine_config.cache_config, + load_config=engine_config.load_config, + lora_config=engine_config.lora_config, + is_driver_worker=True, + ) + model_runner.load_model() + rotary_emb_count = 0 + model = model_runner.model.model if current_platform.is_hpu( + ) else model_runner.model + for module_name, module in model.named_modules(remove_duplicate=False): + if "rotary_emb" in module_name: + if "base_layer" not in module_name: + rotary_emb_count += 1 + assert isinstance(module, LinearScalingRotaryEmbeddingWithLora) + else: + assert isinstance(module, LinearScalingRotaryEmbedding) + # Llama 2 has 32 layers. + assert rotary_emb_count == 32 + + +@pytest.mark.skip_global_cleanup +def test_batched_rope_kernel(lora_llm, long_context_infos): + """We test the batched kernel by comparing the results of batched an + non-batched generation. 
+ """ + # Create non batched results first to compare against batched results + non_batched_results: List[str] = [] + + for lora_id, info in long_context_infos.items(): + context_len = info["context_length"] + lora_prompt = (prompts_and_responses[context_len][0]["prompt"], + sampling_params, + _create_lora_request(lora_id, long_context_infos)) + lora_output = generate(lora_llm, lora_prompt) + non_batched_results.append(lora_output) + + # Create batched results + # Each element of the batch must be + # (prompt, prompt_sampling_params, prompt_lora_request) + batched_prompts: List[Tuple[str, SamplingParams, + Optional[LoRARequest]]] = [] + for lora_id, info in long_context_infos.items(): + context_len = info["context_length"] + batched_prompts.extend([ + (prompts_and_responses[context_len][0]["prompt"], sampling_params, + _create_lora_request(lora_id, long_context_infos)) + ]) + batched_results = batched_generate(lora_llm, batched_prompts) + + # Results should be the same + for non_batched, batched in zip(non_batched_results, batched_results): + assert non_batched == batched, ( + "Non batched and batched results should be the " + f"same:\n{batched}\n{non_batched}") + + +@pytest.mark.skip_global_cleanup +def test_self_consistency(lora_llm, long_context_infos): + """We test consistency of the batched kernel by permuting batched + inputs and comparing the results to the non-permuted batched results. + """ + num_loras = len(long_context_infos) + + # Create results in order of long_context_infos + batched_prompts: List[Tuple[str, SamplingParams, + Optional[LoRARequest]]] = [] + for lora_id, info in long_context_infos.items(): + context_len = info["context_length"] + batched_prompts.extend([ + (prompts_and_responses[context_len][0]["prompt"], sampling_params, + _create_lora_request(lora_id, long_context_infos)) + ]) + + batched_results = batched_generate(lora_llm, batched_prompts) + + permutation = np.random.default_rng(seed=42).permutation(num_loras) + + # Create results in random order of permutation + batched_prompts = [] + for i in permutation: + lora_id, info = list(long_context_infos.items())[i] + context_len = info["context_length"] + batched_prompts.extend([ + (prompts_and_responses[context_len][0]["prompt"], sampling_params, + _create_lora_request(lora_id, long_context_infos)) + ]) + + permutated_batched_results = batched_generate(lora_llm, batched_prompts) + + # Results should be the same + for i in range(num_loras): + assert batched_results[i] == permutated_batched_results[ + permutation[i]], ( + f"Results should be the same:\n{batched_results[i]}" + f"\n{permutated_batched_results[permutation[i]]}") + + +@pytest.mark.skip_global_cleanup +def test_quality(lora_llm, long_context_infos): + """We test the quality of the answers given by the LoRA model by + comparing the generated text to the merged model's outputs. + + This is effectively a mini-benchmark over four prompts. + If this test fails, this indicates that the quality of the LoRA model + is suboptimal compared to the merged model. For example, if the model + does not output valid dictionaries, this test will fail. + + If needed for testing, the merged versions of the models are available + as part of the `conftest`. + + The test is expected to run for about 1 minute on a p4de.24xlarge + instance. 
+ """ + scores: List[float] = [] + for lora_id, info in long_context_infos.items(): + context_len = info["context_length"] + for prompt_and_response in prompts_and_responses[context_len]: + lora_prompt = (prompt_and_response["prompt"], sampling_params, + _create_lora_request(lora_id, long_context_infos)) + response = generate(lora_llm, lora_prompt) + golden_answer = prompt_and_response["golden_answer"] + score = evaluate_json_response(response, golden_answer) + scores.append(score) + assert score > 0.3, ("Quality of the answer is not good enough. " + f"Expected {golden_answer}, got {response}") + assert np.mean(scores) > 0.5 + + +@pytest.mark.skip_global_cleanup +def test_max_len(lora_llm, long_context_infos): + """Test that we raise an ValueError when the input of a given LoRA + model exceeds the maximum length.""" + # Since each LoRA model has a different maximum length, we need to + # test each one separately + for lora_id, info in long_context_infos.items(): + context_len = info["context_length"] + lora_request = _create_lora_request(lora_id, long_context_infos) + # Good prompt should be fine + good_prompt = prompts_and_responses[context_len][0]["prompt"] + generate(lora_llm, (good_prompt, sampling_params, lora_request)) + # Bad prompt should raise an error + bad_prompt = good_prompt * 2 + with pytest.raises(ValueError): + generate(lora_llm, (bad_prompt, sampling_params, lora_request)) + + # Also test batched + batched_prompts: List[Tuple[str, SamplingParams, + Optional[LoRARequest]]] = [] + for lora_id_with_bad_inputs in long_context_infos: + for lora_id, info in long_context_infos.items(): + context_len = info["context_length"] + batched_prompts.extend([ + (prompts_and_responses[context_len][0]["prompt"] * + (2 if lora_id == lora_id_with_bad_inputs else 1), + sampling_params, + _create_lora_request(lora_id, long_context_infos)) + ]) + # Turn good prompt into bad prompt inside of batched prompts + + with pytest.raises(ValueError): + batched_generate(lora_llm, batched_prompts) diff --git a/vllm/lora/punica.py b/vllm/lora/punica.py index f22f92b6fe64b..1fdd15df99c19 100644 --- a/vllm/lora/punica.py +++ b/vllm/lora/punica.py @@ -103,10 +103,15 @@ def convert_mapping( embedding_indices = index_mapping_indices.copy() lora_indices = index_mapping_indices.copy() long_lora_offsets: Optional[torch.Tensor] = None + + from vllm.platforms import current_platform if long_lora_context: - long_lora_offsets = torch.zeros(len(index_mapping_indices), - device=get_device(), - dtype=torch.long) + if current_platform.is_hpu(): + long_lora_offsets_list: List[int] = [] + else: + long_lora_offsets = torch.zeros(len(index_mapping_indices), + device=get_device(), + dtype=torch.long) prompt_mapping: List[int] = [ lora_index_to_id.index(x) if x > 0 else -1 for x in mapping.prompt_mapping @@ -119,10 +124,18 @@ def convert_mapping( embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0 lora_indices[i] = lora_idx if long_lora_context: - assert long_lora_offsets is not None lora_offset: int = long_lora_context.offsets_by_lora_id.get( index_mapping_indices[i], 0) - long_lora_offsets[i] = lora_offset + if current_platform.is_hpu(): + long_lora_offsets_list.append(lora_offset) + else: + assert long_lora_offsets is not None + long_lora_offsets[i] = lora_offset + + if long_lora_context and current_platform.is_hpu(): + long_lora_offsets = torch.tensor(long_lora_offsets_list, + device=get_device(), + dtype=torch.long) indices_list: List[Union[List[int], torch.Tensor]] = [ index_mapping_indices, diff --git 
a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 382a0abb21240..b5100491c4135 100644 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -37,6 +37,7 @@ from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader import get_model +from vllm.model_executor.models import supports_multimodal from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, MultiModalInputs) from vllm.sampling_params import SamplingParams @@ -649,12 +650,30 @@ def load_model(self) -> None: assert hasattr( self.model, "embedding_padding_modules" ), "Model does not have embedding_padding_modules" + + if supports_multimodal(self.model): + logger.warning( + "Regarding multimodal models, vLLM currently " + "only supports adding LoRA to language model.") + # It's necessary to distinguish between the + # max_position_embeddings of VLMs and LLMs. + if hasattr(self.model.config, "max_position_embeddings"): + max_pos_embeddings = ( + self.model.config.max_position_embeddings) + else: + max_pos_embeddings = ( + self.model.config.text_config.max_position_embeddings) + self.lora_manager = LRUCacheWorkerLoRAManager( self.scheduler_config.max_num_seqs, self.scheduler_config.max_num_batched_tokens, - self.vocab_size, self.lora_config, self.device, + self.vocab_size, + self.lora_config, + self.device, self.model.embedding_modules, - self.model.embedding_padding_modules) + self.model.embedding_padding_modules, + max_position_embeddings=max_pos_embeddings, + ) self.model = self.lora_manager.create_lora_manager(self.model) if self.model_config.quantization == 'inc': @@ -1314,7 +1333,8 @@ def profile_run(self) -> None: num_layers = self.model_config.get_num_layers(self.parallel_config) kv_caches = [None] * num_layers max_seq_len = self.bucketing_global_state.prompt_seq_bucket_cfg[-1] - max_batch_size = self.max_num_batched_tokens // max_seq_len + max_batch_size = min(self.max_num_batched_tokens // max_seq_len, + self.scheduler_config.max_num_seqs) self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches, False, True) @@ -1333,7 +1353,6 @@ def warmup_scenario(self, f"bs{batch_size}_" f"seq{seq_len}_" f"graphs{'T' if use_graphs else 'F'}") - max_num_seqs = self.scheduler_config.max_num_seqs # This represents the maximum number of different requests # that will have unique loras, an therefore the max amount of memory # consumption create dummy lora request copies from the lora request @@ -1355,7 +1374,7 @@ def warmup_scenario(self, dummy_lora_requests.append(dummy_lora_request) dummy_lora_requests_per_seq = [ dummy_lora_requests[idx % len(dummy_lora_requests)] - for idx in range(max_num_seqs) + for idx in range(batch_size) ] self.profiler.start('internal', scenario_name) times = 3 if use_graphs or is_pt_profiler_run else 1 From 4fd5c4c9601c82bd9240ba974310b20c9535d11c Mon Sep 17 00:00:00 2001 From: Karol Damaszke Date: Mon, 28 Oct 2024 10:47:52 +0100 Subject: [PATCH 16/17] Add HPU specific changes to benchmark_latency.py (#436) Add support for HPU FP8 in `benchmark_latency.py` script. Limit `max_num_seqs` based on the `batch_size` as there will be no more requests. 
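For context, a minimal programmatic sketch of the same settings (placeholder model and values): with a fixed benchmark batch there are never more than `batch_size` requests in flight, so `max_num_seqs` is pinned to it instead of being left at its default.

```python
# Minimal sketch with placeholder values, mirroring the two changes above.
from vllm import LLM

batch_size = 32                            # corresponds to --batch-size
llm = LLM(
    model="meta-llama/Llama-2-7b-hf",      # placeholder model
    kv_cache_dtype="fp8_inc",              # now accepted by --kv-cache-dtype
    max_num_seqs=batch_size,               # no more concurrent requests than the batch
)
```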
--- benchmarks/benchmark_latency.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index 79a48b2a1a845..30373b119a2ca 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -47,6 +47,7 @@ def main(args: argparse.Namespace): distributed_executor_backend=args.distributed_executor_backend, otlp_traces_endpoint=args.otlp_traces_endpoint, enable_prefix_caching=args.enable_prefix_caching, + max_num_seqs=args.batch_size, ) sampling_params = SamplingParams( @@ -179,7 +180,7 @@ def run_to_completion(profile_dir: Optional[str] = None): parser.add_argument( '--kv-cache-dtype', type=str, - choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'], + choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3', 'fp8_inc'], default="auto", help='Data type for kv cache storage. If "auto", will use model ' 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' From 2a38e6f575f86e0853d2b057cb004c48109f8b77 Mon Sep 17 00:00:00 2001 From: Sayantan Sarkar Date: Mon, 28 Oct 2024 10:21:19 -0700 Subject: [PATCH 17/17] sarkar/Add htrandom generator for hpu (#246) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit To repro: start server: `VLLM_SKIP_WARMUP=true python -m vllm.entrypoints.openai.api_server` send a request (this works fine): ``` curl -v http://localhost:8000/v1/completions -H "Content-Type: application/json" -d '{"model": "facebook/opt-125m","prompt": "The future of AI is ","max_tokens": 100,"temperature": 0}' ``` if request has a seed it fails: ``` curl -v http://localhost:8000/v1/completions -H "Content-Type: application/json" -d '{"model": "facebook/opt-125m","prompt": "The future of AI is ","max_tokens": 100,"temperature": 0, "seed" : 37}' ``` Failure happens here: [vllm-fork/vllm/model_executor/sampling_metadata.py at habana_main · HabanaAI/vllm-fork](https://github.com/HabanaAI/vllm-fork/blob/habana_main/vllm/model_executor/sampling_metadata.py#L220) ``` if sampling_params.seed is not None: seq_group_metadata.state.generator = torch.Generator( device=device).manual_seed(sampling_params.seed) ``` `RuntimeError: Device type HPU is not supported for torch.Generator() api.` This PR fixes above issue by using htrandom [Intel Gaudi PyTorch Python API (habana_frameworks.torch) — Gaudi Documentation 1.17.1 documentation](https://docs.habana.ai/en/latest/PyTorch/Reference/Python_Packages.html?highlight=htrandom#random-number-generator-apis) --- vllm/model_executor/sampling_metadata.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 84f35f75a0c32..d4a8024095286 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -4,6 +4,7 @@ import torch +from vllm.platforms import current_platform from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData, SequenceGroupMetadata) @@ -266,8 +267,14 @@ def _prepare_seq_groups( if seq_group_metadata.is_prompt: if sampling_params.seed is not None: - generator = torch.Generator(device=device).manual_seed( - sampling_params.seed) + if current_platform.is_hpu(): + import habana_frameworks.torch.hpu.random as htrandom + generator = \ + htrandom.default_generators[ + 0].manual_seed(sampling_params.seed) + else: + generator = torch.Generator(device=device).manual_seed( + sampling_params.seed) if 
generators is not None: generators[seq_group_metadata.request_id] = generator