diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index a4204094106c0..6888caaf02bc4 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -241,6 +241,7 @@ def forward( block_list=attn_metadata.block_list, block_mapping=attn_metadata.block_mapping, block_bias=attn_metadata.attn_bias, + block_scales=attn_metadata.block_scales, scale=self.scale, matmul_qk_op=self.matmul_qk, matmul_av_op=self.matmul_av, diff --git a/vllm/attention/ops/hpu_paged_attn.py b/vllm/attention/ops/hpu_paged_attn.py index 7bce1a266cf22..15d7e72aafe26 100644 --- a/vllm/attention/ops/hpu_paged_attn.py +++ b/vllm/attention/ops/hpu_paged_attn.py @@ -20,6 +20,7 @@ class HPUPagedAttentionMetadata: block_usage: Optional[torch.Tensor] block_indices: Optional[torch.Tensor] block_offsets: Optional[torch.Tensor] + block_scales: Optional[torch.Tensor] class HPUPagedAttention: diff --git a/vllm/config.py b/vllm/config.py index 5499b349bcfc8..67a4ec0761cc3 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -940,6 +940,9 @@ class SchedulerConfig: a single iteration. max_num_seqs: Maximum number of sequences to be processed in a single iteration. + max_num_prefill_seqs: Maximum number of prefill sequences to be + processed in a single iteration. Used only with padding-aware + scheduling. max_model_len: Maximum length of a sequence (including prompt and generated text). use_v2_block_manager: Whether to use the BlockSpaceManagerV2 or not. @@ -963,11 +966,14 @@ class SchedulerConfig: when SPMD worker architecture is enabled. I.e., VLLM_USE_RAY_SPMD_WORKER=1 policy: The scheduling policy to use. "fcfs" (default) or "priority". + use_padding_aware_scheduling: If True, scheduler will consider padded + tokens in prefill. """ def __init__(self, max_num_batched_tokens: Optional[int], max_num_seqs: int, + max_num_prefill_seqs: Optional[int], max_model_len: int, use_v2_block_manager: bool = True, num_lookahead_slots: int = 0, @@ -979,7 +985,8 @@ def __init__(self, num_scheduler_steps: int = 1, multi_step_stream_outputs: bool = False, send_delta_data: bool = False, - policy: str = "fcfs") -> None: + policy: str = "fcfs", + use_padding_aware_scheduling=False) -> None: if max_num_batched_tokens is None: if enable_chunked_prefill: if num_scheduler_steps > 1: @@ -1018,6 +1025,7 @@ def __init__(self, self.max_num_batched_tokens) self.max_num_seqs = max_num_seqs + self.max_num_prefill_seqs = max_num_prefill_seqs self.max_model_len = max_model_len self.use_v2_block_manager = use_v2_block_manager self.num_lookahead_slots = num_lookahead_slots @@ -1029,6 +1037,7 @@ def __init__(self, self.multi_step_stream_outputs = multi_step_stream_outputs self.send_delta_data = send_delta_data self.policy = policy + self.use_padding_aware_scheduling = use_padding_aware_scheduling self._verify_args() def _verify_args(self) -> None: @@ -1059,6 +1068,13 @@ def _verify_args(self) -> None: "num_scheduler_steps " f"({self.num_scheduler_steps}) must be greater than or " "equal to 1.") + if self.max_num_prefill_seqs is not None \ + and not self.use_padding_aware_scheduling: + raise ValueError("max_num_prefill_seqs can be only " + "used with padding-aware-scheduling. 
") + if self.use_padding_aware_scheduling and self.chunked_prefill_enabled: + raise ValueError("Padding-aware scheduling currently " + "does not work with chunked prefill ") if (not self.use_v2_block_manager \ and not envs.VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1): diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 1f0a121711db5..1c69c72933b79 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -11,6 +11,7 @@ from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.platforms import current_platform from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import (Sequence, SequenceData, SequenceGroup, SequenceGroupMetadata, SequenceGroupMetadataDelta, @@ -101,6 +102,94 @@ def num_curr_seqs(self): return self._num_curr_seqs +@dataclass +class PaddingAwareSchedulingBudget(SchedulingBudget): + max_num_prefill_seqs: Optional[int] = None + _prefill_request_ids_max_seq_lens: Dict[str, + int] = field(default_factory=dict) + _max_seq_len: int = 0 + _num_curr_prefill_seqs: int = 0 + + def _generic_padding_fn(self, batch_size, max_seq_len) -> int: + return batch_size * max_seq_len + + def _hpu_padding_fn(self, batch_size, max_seq_len): + from vllm.worker.hpu_model_runner import (HPUBucketingGlobalState, + find_bucket) + padded_bs = batch_size + padded_seq = max_seq_len + + hpu_bucketing_global_state = HPUBucketingGlobalState() + + bs_cfg = hpu_bucketing_global_state.prompt_bs_bucket_cfg + if bs_cfg is not None: + padded_bs = find_bucket(batch_size, bs_cfg) + else: + logger.warning( + "prompt_bs_bucket_cfg was not set! Using unpadded batch size.") + seq_cfg = hpu_bucketing_global_state.prompt_seq_bucket_cfg + if seq_cfg is not None: + padded_seq = find_bucket(max_seq_len, seq_cfg) + else: + logger.warning("prompt_seq_bucket_cfg was not set! 
" + "Using unpadded sequence length.") + return padded_bs * padded_seq + + def _padding_fn_selector(self): + if current_platform.is_hpu(): + return self._hpu_padding_fn + return self._generic_padding_fn + + def _maybe_update_max_seq_len(self, + new_seq_max_seq_len: Optional[int] = None): + if new_seq_max_seq_len is not None \ + and new_seq_max_seq_len > self._max_seq_len: + self._max_seq_len = new_seq_max_seq_len + return + self._max_seq_len = max( + self._prefill_request_ids_max_seq_lens.values()) + + def add_prefill_seqs(self, req_id, num_curr_prefill_seqs, max_seq_len): + self._prefill_request_ids_max_seq_lens[req_id] = max_seq_len + self._num_curr_prefill_seqs += num_curr_prefill_seqs + self._maybe_update_max_seq_len(max_seq_len) + + def subtract_prefill_seqs(self, req_id, num_curr_prefill_seqs): + if req_id in self._prefill_request_ids_max_seq_lens: + popped_seq_len = self._prefill_request_ids_max_seq_lens.pop(req_id) + self._num_curr_prefill_seqs -= num_curr_prefill_seqs + if popped_seq_len == self._max_seq_len: + self._maybe_update_max_seq_len() + + def can_schedule(self, + *args, + num_new_tokens: int, + num_new_seqs: int, + is_prefill: bool = False, + max_seq_len: int = 0): + can_parent_schedule = super().can_schedule( + *args, num_new_tokens=num_new_tokens, num_new_seqs=num_new_seqs) + if not can_parent_schedule or not is_prefill: + return can_parent_schedule + new_batch_size = self._num_curr_prefill_seqs + num_new_seqs + new_max_seq_len = max(max(self._max_seq_len, max_seq_len), 1) + padding_fn = self._padding_fn_selector() + num_new_padded_tokens = padding_fn(new_batch_size, new_max_seq_len) + result = num_new_padded_tokens <= self.token_budget + if self.max_num_prefill_seqs is not None and result: + result = self._num_curr_prefill_seqs + num_new_seqs \ + <= self.max_num_prefill_seqs + return result + + @property + def max_seq_len(self): + return self._max_seq_len + + @property + def num_curr_prefill_seqs(self): + return self._num_curr_prefill_seqs + + @dataclass class ScheduledSequenceGroup: # A sequence group that's scheduled. @@ -938,9 +1027,18 @@ def _schedule_prefills( continue num_new_seqs = seq_group.get_max_num_running_seqs() + max_prefill_seq_len = None + can_schedule_kwargs = { + 'num_new_tokens': num_new_tokens, + 'num_new_seqs': num_new_seqs + } + if self.scheduler_config.use_padding_aware_scheduling: + max_prefill_seq_len = max( + [seq.get_num_new_tokens() for seq in seq_group.get_seqs()]) + can_schedule_kwargs['is_prefill'] = True + can_schedule_kwargs['max_seq_len'] = max_prefill_seq_len if (num_new_tokens == 0 - or not budget.can_schedule(num_new_tokens=num_new_tokens, - num_new_seqs=num_new_seqs)): + or not budget.can_schedule(**can_schedule_kwargs)): break # Can schedule this request. @@ -971,6 +1069,10 @@ def _schedule_prefills( token_chunk_size=num_new_tokens)) budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens) budget.add_num_seqs(seq_group.request_id, num_new_seqs) + if self.scheduler_config.use_padding_aware_scheduling: + assert isinstance(budget, PaddingAwareSchedulingBudget) + budget.add_prefill_seqs(seq_group.request_id, num_new_seqs, + max_prefill_seq_len) # Queue requests that couldn't be scheduled. waiting_queue.extendleft(leftover_waiting_sequences) @@ -992,10 +1094,18 @@ def _schedule_default(self) -> SchedulerOutputs: be swapped or preempted. """ # Include running requests to the budget. 
- budget = SchedulingBudget( - token_budget=self.scheduler_config.max_num_batched_tokens, - max_num_seqs=self.scheduler_config.max_num_seqs, - ) + budget: SchedulingBudget + if self.scheduler_config.use_padding_aware_scheduling: + budget = PaddingAwareSchedulingBudget( + token_budget=self.scheduler_config.max_num_batched_tokens, + max_num_seqs=self.scheduler_config.max_num_seqs, + max_num_prefill_seqs=self.scheduler_config.max_num_prefill_seqs + ) + else: + budget = SchedulingBudget( + token_budget=self.scheduler_config.max_num_batched_tokens, + max_num_seqs=self.scheduler_config.max_num_seqs, + ) # Make sure we include num running seqs before scheduling prefill, # so that we don't schedule beyond max_num_seqs for prefill. for seq_group in self.running: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 3c9f3d4fe4ab3..cdf1401816800 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -113,11 +113,13 @@ class EngineArgs: enable_prefix_caching: bool = False disable_sliding_window: bool = False use_v2_block_manager: bool = True + use_padding_aware_scheduling: bool = False swap_space: float = 4 # GiB cpu_offload_gb: float = 0 # GiB gpu_memory_utilization: float = 0.90 max_num_batched_tokens: Optional[int] = None max_num_seqs: int = 256 + max_num_prefill_seqs: Optional[int] = None max_logprobs: int = 20 # Default value for OpenAI Chat Completions API disable_log_stats: bool = False revision: Optional[str] = None @@ -391,6 +393,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: action='store_true', help='Use BlockSpaceMangerV2. By default this is set to True. ' 'Set to False to use BlockSpaceManagerV1') + parser.add_argument( + '--use-padding-aware-scheduling', + default=EngineArgs.use_padding_aware_scheduling, + action='store_true', + help=('Use padding-aware scheduling. If True, the scheduler ' + 'will consider padded tokens in prefill. ' + 'By default this is set to False. ')) parser.add_argument( '--num-lookahead-slots', type=int, @@ -445,6 +454,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: type=int, default=EngineArgs.max_num_seqs, help='Maximum number of sequences per iteration.') + parser.add_argument( + '--max-num-prefill-seqs', + type=int, + default=EngineArgs.max_num_prefill_seqs, + help=('Maximum number of prefill sequences per ' + 'iteration. Can be used only with padding-aware ' + 'scheduling. 
Must be <= max_num_seqs.')) parser.add_argument( '--max-logprobs', type=int, @@ -1036,6 +1052,7 @@ def create_engine_config(self) -> EngineConfig: scheduler_config = SchedulerConfig( max_num_batched_tokens=self.max_num_batched_tokens, max_num_seqs=self.max_num_seqs, + max_num_prefill_seqs=self.max_num_prefill_seqs, max_model_len=model_config.max_model_len, use_v2_block_manager=self.use_v2_block_manager, num_lookahead_slots=num_lookahead_slots, @@ -1049,7 +1066,7 @@ def create_engine_config(self) -> EngineConfig: send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER and parallel_config.use_ray), policy=self.scheduling_policy, - ) + use_padding_aware_scheduling=self.use_padding_aware_scheduling) lora_config = LoRAConfig( max_lora_rank=self.max_lora_rank, max_loras=self.max_loras, diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 9618585c8acb0..090f95e6e892c 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -6,7 +6,7 @@ from vllm.attention import get_attn_backend from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig from vllm.logger import init_logger -from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size, is_fake_hpu, +from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size, is_pin_memory_available) logger = init_logger(__name__) @@ -75,26 +75,14 @@ def _allocate_kv_cache( pin_memory = is_pin_memory_available() if device == "cpu" else False kv_cache: List[torch.Tensor] = [] for _ in range(self.num_attention_layers): - if device == 'hpu' or is_fake_hpu(): - key_cache = torch.zeros(kv_cache_shape, - dtype=self.dtype, - device=device) - value_cache = torch.zeros(kv_cache_shape, - dtype=self.dtype, - device=device) - kv_layer = (key_cache, value_cache) - kv_cache.append(kv_layer) - else: - # null block in CpuGpuBlockAllocator requires at least that - # block to be zeroed-out. - # We zero-out everything for simplicity. - dtype = torch.uint8 if self.dtype == torch.float8_e4m3fn else \ - self.dtype - kv_cache.append( - torch.zeros(kv_cache_shape, - dtype=dtype, - pin_memory=pin_memory, - device=device)) + # null block in CpuGpuBlockAllocator requires at least that + # block to be zeroed-out. + # We zero-out everything for simplicity. 
+ kv_cache.append( + torch.zeros(kv_cache_shape, + dtype=self.dtype, + pin_memory=pin_memory, + device=device)) return kv_cache def swap_in(self, src_to_dst: torch.Tensor) -> None: diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 89ee432a459f3..0054d27966eb4 100644 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -13,6 +13,7 @@ import os import time from array import array +from dataclasses import dataclass, field from enum import IntEnum from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple, Type, TypeVar, Union) @@ -64,6 +65,26 @@ LORA_WARMUP_RANK = 8 +class Singleton(type): + _instances: Dict[type, object] = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super(Singleton, + cls).__call__(*args, **kwargs) + return cls._instances[cls] + + +@dataclass +class HPUBucketingGlobalState(metaclass=Singleton): + prompt_bs_bucket_cfg: Tuple[int, int, int] = field(init=False) + decode_bs_bucket_cfg: Tuple[int, int, int] = field(init=False) + prompt_seq_bucket_cfg: Tuple[int, int, int] = field(init=False) + decode_block_bucket_cfg: Tuple[int, int, int] = field(init=False) + prompt_buckets: List[Tuple[int, int]] = field(init=False) + decode_buckets: List[Tuple[int, int]] = field(init=False) + + def subtuple(obj: object, typename: str, to_copy: List[str], @@ -315,9 +336,19 @@ def _set_block_mapping(self, metadata, batch_size, device, dtype): mask = mask >= metadata.block_usage.unsqueeze(-1) attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_( mask, -math.inf)) - block_mapping = torch.nn.functional.one_hot( - metadata.block_mapping.to(torch.long), - num_classes=batch_size).to(dtype) + if is_fake_hpu(): + # Unfortunately one_hot on CPU doesn't handle + # out of bounds classes. 
We need to mask those + # values manually + oob_values = metadata.block_mapping.lt(0) + block_mapping = metadata.block_mapping.masked_fill(oob_values, 0) + block_mapping = torch.nn.functional.one_hot(block_mapping, + num_classes=batch_size) + block_mapping.masked_fill_(oob_values.unsqueeze(-1), 0) + else: + block_mapping = torch.nn.functional.one_hot(metadata.block_mapping, + num_classes=batch_size) + block_mapping = block_mapping.to(dtype) metadata = metadata._replace(block_mapping=block_mapping, attn_bias=attn_bias) return metadata @@ -549,6 +580,9 @@ def __init__( self.device = self.device_config.device self.enforce_eager = self.model_config.enforce_eager self.max_num_seqs = self.scheduler_config.max_num_seqs + self.max_num_prefill_seqs = self.scheduler_config.max_num_prefill_seqs \ + if self.scheduler_config.max_num_prefill_seqs is not None \ + else self.max_num_seqs self.max_model_len = self.scheduler_config.max_model_len self.max_num_batched_tokens = \ self.scheduler_config.max_num_batched_tokens @@ -576,6 +610,7 @@ def __init__( self.profiler_counter_helper = HabanaProfilerCounterHelper() self.seen_configs: set = set() self._mem_margin: Optional[int] = None + self.bucketing_global_state = HPUBucketingGlobalState() self._setup_buckets() self._set_gc_threshold() @@ -687,27 +722,26 @@ def _is_valid_bucket(self, bucket): def _setup_buckets(self) -> None: align_bs = lambda x: min(self.max_num_seqs, x) - max_bucket_cfg = 64 #FIXME: The default values should be max_model_len max_prompt_seq = 1024 max_decode_seq = 2048 - self.prompt_bs_bucket_cfg = read_bucket_settings( + self.bucketing_global_state.prompt_bs_bucket_cfg = read_bucket_settings( 'prompt', 'bs', min=1, step=align_bs(32), - max=align_bs(max_bucket_cfg)) - self.decode_bs_bucket_cfg = read_bucket_settings('decode', - 'bs', - min=1, - step=align_bs(32), - max=self.max_num_seqs) - self.prompt_seq_bucket_cfg = read_bucket_settings('prompt', - 'seq', - min=self.block_size, - step=self.block_size, - max=max_prompt_seq) - self.decode_block_bucket_cfg = read_bucket_settings( + max=self.max_num_prefill_seqs) + self.bucketing_global_state.decode_bs_bucket_cfg = read_bucket_settings( + 'decode', 'bs', min=1, step=align_bs(32), max=self.max_num_seqs) + self.bucketing_global_state.prompt_seq_bucket_cfg = \ + read_bucket_settings( + 'prompt', + 'seq', + min=self.block_size, + step=self.block_size, + max=max_prompt_seq) + self.bucketing_global_state.decode_block_bucket_cfg = \ + read_bucket_settings( 'decode', 'block', min=self.block_size, @@ -717,13 +751,13 @@ def _setup_buckets(self) -> None: self.graphed_buckets: Set[Any] = set() msg = ("Prompt bucket config (min, step, max_warmup) " - f"bs:{self.prompt_bs_bucket_cfg}, " - f"seq:{self.prompt_seq_bucket_cfg}") + f"bs:{self.bucketing_global_state.prompt_bs_bucket_cfg}, " + f"seq:{self.bucketing_global_state.prompt_seq_bucket_cfg}") logger.info(msg) msg = ("Decode bucket config (min, step, max_warmup) " - f"bs:{self.decode_bs_bucket_cfg}, " - f"block:{self.decode_block_bucket_cfg}") + f"bs:{self.bucketing_global_state.decode_bs_bucket_cfg}, " + f"block:{self.bucketing_global_state.decode_block_bucket_cfg}") logger.info(msg) def _prepare_prompt( @@ -841,7 +875,8 @@ def _prepare_prompt( assert max_query_len > 0 max_prompt_len = max( - find_bucket(max_query_len, self.prompt_seq_bucket_cfg), + find_bucket(max_query_len, + self.bucketing_global_state.prompt_seq_bucket_cfg), self.block_size) lora_ids: List[int] = [] @@ -915,6 +950,7 @@ def _prepare_prompt( block_usage=None, 
block_indices=block_indices, block_offsets=block_offsets, + block_scales=None, attn_bias=None, seq_lens_tensor=seq_lens_tensor, context_lens_tensor=context_lens_tensor, @@ -1011,7 +1047,15 @@ def _prepare_decode( num_decode_tokens = sum(seq_lens) blocks_used = [len(bt) for bt in block_tables if bt] - block_list = list(itertools.chain(*block_tables)) + block_list = [] + block_scales = [] + for i, bt in enumerate(block_tables): + block_list.extend(bt) + blocks_in_group = len(bt) + if blocks_in_group > 0: + scale = 1.0 / blocks_in_group + block_scales.extend([scale] * blocks_in_group) + block_mapping_nested: List[List[int]] = [ [i] * b_u for i, b_u in enumerate(blocks_used) ] @@ -1025,20 +1069,22 @@ def _prepare_decode( for b_u, lb in zip(blocks_used, last_block)] block_usage = list(itertools.chain(*block_usage)) - block_bucket_size = find_bucket(len(block_list), - self.decode_block_bucket_cfg) - block_list = pad_list(block_list, block_bucket_size, _PAD_SLOT_ID) - block_mapping = pad_list(block_mapping, block_bucket_size, 0) - block_usage = pad_list(block_usage, block_bucket_size, 0) + block_bucket_size = find_bucket( + len(block_list), + self.bucketing_global_state.decode_block_bucket_cfg) + block_list = pad_list(block_list, block_bucket_size, _PAD_BLOCK_ID) + block_mapping = pad_list(block_mapping, block_bucket_size, -1) + block_usage = pad_list(block_usage, block_bucket_size, 1) + block_scales = pad_list(block_scales, block_bucket_size, 0.0) block_list = torch.tensor(block_list, dtype=torch.int, device=self.device) block_mapping = torch.tensor(block_mapping, - dtype=torch.int, + dtype=torch.long, device=self.device) block_usage = torch.tensor(block_usage, - dtype=torch.bfloat16, + dtype=self.model_config.dtype, device=self.device) slot_mapping = torch.tensor(slot_mapping, @@ -1047,6 +1093,10 @@ def _prepare_decode( block_indices, block_offsets = precompute_indices_and_offsets( self.block_size, slot_mapping, False) + block_scales = torch.tensor(block_scales, + dtype=self.model_config.dtype, + device=self.device) + attn_metadata = self.attn_backend.make_metadata( is_prompt=False, block_list=block_list, @@ -1054,6 +1104,7 @@ def _prepare_decode( block_usage=block_usage, block_indices=block_indices, block_offsets=block_offsets, + block_scales=block_scales, attn_bias=None, seq_lens_tensor=None, context_lens_tensor=None, @@ -1095,8 +1146,8 @@ def prepare_input_tensors( self.profiler.start('internal', base_event_name) real_batch_size = len(seq_group_metadata_list) - bucket_cfg = self.prompt_bs_bucket_cfg if is_prompt else \ - self.decode_bs_bucket_cfg + bucket_cfg = self.bucketing_global_state.prompt_bs_bucket_cfg \ + if is_prompt else self.bucketing_global_state.decode_bs_bucket_cfg batch_size_padded = find_bucket(real_batch_size, bucket_cfg) batch_size_padding = batch_size_padded - real_batch_size seq_group_metadata_list = seq_group_metadata_list.copy() @@ -1266,7 +1317,7 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object: attention_metadata = subtuple(metadata, 'TrimmedAttentionMetadata', [ 'attn_bias', 'seq_lens_tensor', 'context_lens_tensor', 'block_list', 'block_mapping', 'block_usage', 'slot_mapping', 'is_prompt', - 'block_indices', 'block_offsets' + 'block_indices', 'block_offsets', 'block_scales' ]) return attention_metadata @@ -1301,9 +1352,10 @@ def create_dummy_seq_group_metadata(self, def profile_run(self) -> None: num_layers = self.model_config.get_num_layers(self.parallel_config) kv_caches = [None] * num_layers - max_batch_size = self.prompt_bs_bucket_cfg[-1] - 
max_seq_len = min(self.prompt_seq_bucket_cfg[-1], - self.max_num_batched_tokens // max_batch_size) + max_batch_size = self.bucketing_global_state.prompt_bs_bucket_cfg[-1] + max_seq_len = min( + self.bucketing_global_state.prompt_seq_bucket_cfg[-1], + self.max_num_batched_tokens // max_batch_size) self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches, False, True) @@ -1517,13 +1569,15 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: self.profiler.start('internal', 'warmup') max_blocks = kv_caches[0][0].size(0) - self.prompt_buckets, prompt_omitted_buckets = generate_prompt_buckets( - self.prompt_bs_bucket_cfg, self.prompt_seq_bucket_cfg, + self.bucketing_global_state.prompt_buckets, prompt_omitted_buckets = \ + generate_prompt_buckets( + self.bucketing_global_state.prompt_bs_bucket_cfg, + self.bucketing_global_state.prompt_seq_bucket_cfg, self.max_num_batched_tokens) - msg = ( - f"Generated {len(self.prompt_buckets)} " - f"prompt buckets [bs, seq]: {list(sorted(self.prompt_buckets))}") + msg = (f"Generated {len(self.bucketing_global_state.prompt_buckets)} " + f"prompt buckets [bs, seq]: \ + {list(sorted(self.bucketing_global_state.prompt_buckets))}") logger.info(msg) msg = (f"Omitted {len(prompt_omitted_buckets)} " @@ -1534,16 +1588,17 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: msg = f"Omitted prompt buckets: {list(sorted(prompt_omitted_buckets))}" logger.debug(msg) - self.decode_buckets = generate_decode_buckets( - self.decode_bs_bucket_cfg, self.decode_block_bucket_cfg, - max_blocks) + self.bucketing_global_state.decode_buckets = generate_decode_buckets( + self.bucketing_global_state.decode_bs_bucket_cfg, + self.bucketing_global_state.decode_block_bucket_cfg, max_blocks) logger.info("Generated %d decode buckets [bs, total_blocks]: %s", - len(self.decode_buckets), - list(sorted(self.decode_buckets))) + len(self.bucketing_global_state.decode_buckets), + list(sorted(self.bucketing_global_state.decode_buckets))) if not htorch.utils.internal.is_lazy() and not self.enforce_eager: - cache_size_limit = len(self.prompt_buckets) + len( - self.decode_buckets) + 1 + cache_size_limit = len( + self.bucketing_global_state.prompt_buckets) + len( + self.bucketing_global_state.decode_buckets) + 1 torch._dynamo.config.cache_size_limit = max( cache_size_limit, torch._dynamo.config.cache_size_limit) # Multiply by 8 to follow the original default ratio between @@ -1570,8 +1625,10 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: 'Please update Gaudi Software Suite.') with compile_only_mode_context( ) if can_use_compile_only_mode else contextlib.nullcontext(): - self.warmup_all_buckets(self.prompt_buckets, True, kv_caches) - self.warmup_all_buckets(self.decode_buckets, False, kv_caches) + self.warmup_all_buckets(self.bucketing_global_state.prompt_buckets, + True, kv_caches) + self.warmup_all_buckets(self.bucketing_global_state.decode_buckets, + False, kv_caches) if not self.enforce_eager and htorch.utils.internal.is_lazy(): assert self.mem_margin is not None, \ @@ -1601,12 +1658,12 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: 'max_bs') mem_post_prompt, prompt_batch_seq, prompt_captured_all = \ self.warmup_graphs( - prompt_strategy, self.prompt_buckets, True, kv_caches, - prompt_available_memory) + prompt_strategy, self.bucketing_global_state.prompt_buckets, + True, kv_caches, prompt_available_memory) mem_post_decode, decode_batch_seq, decode_captured_all = \ self.warmup_graphs( - decode_strategy, self.decode_buckets, False, 
kv_caches, - decode_available_memory) + decode_strategy, self.bucketing_global_state.decode_buckets, + False, kv_caches, decode_available_memory) # Not all prompt buckets were captured, but all decode buckets # were captured and we have some free graph-allocated space @@ -1615,7 +1672,8 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: and not prompt_captured_all and decode_captured_all): mem_post_prompt, _, prompt_captured_all = ( self.warmup_graphs( - prompt_strategy, self.prompt_buckets, True, + prompt_strategy, + self.bucketing_global_state.prompt_buckets, True, kv_caches, graph_free_mem - mem_post_prompt - mem_post_decode, mem_post_prompt, prompt_batch_seq)) @@ -1627,14 +1685,18 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: and not decode_captured_all \ and prompt_captured_all: mem_post_decode, _, _ = self.warmup_graphs( - decode_strategy, self.decode_buckets, False, kv_caches, + decode_strategy, + self.bucketing_global_state.decode_buckets, False, + kv_caches, graph_free_mem - mem_post_prompt - mem_post_decode, mem_post_decode, decode_batch_seq) - self.log_graph_warmup_summary(self.prompt_buckets, True, - mem_post_prompt) - self.log_graph_warmup_summary(self.decode_buckets, False, - mem_post_decode) + self.log_graph_warmup_summary( + self.bucketing_global_state.prompt_buckets, True, + mem_post_prompt) + self.log_graph_warmup_summary( + self.bucketing_global_state.decode_buckets, False, + mem_post_decode) end_time = time.perf_counter() end_mem = HabanaMemoryProfiler.current_device_memory_usage() diff --git a/vllm/worker/hpu_worker.py b/vllm/worker/hpu_worker.py index 59a5adf65ebc1..752388e0d632f 100644 --- a/vllm/worker/hpu_worker.py +++ b/vllm/worker/hpu_worker.py @@ -93,7 +93,7 @@ def __init__( observability_config=observability_config) # Uninitialized cache engine. Will be initialized by # initialize_cache. - self.cache_engine: List[CacheEngine] + self.cache_engine: List[HPUCacheEngine] # Initialize gpu_cache as embedding models don't initialize kv_caches self.hpu_cache: Optional[List[List[torch.tensor]]] = None # Torch profiler. Enabled and configured through env vars: @@ -242,8 +242,8 @@ def initialize_cache(self, num_gpu_blocks: int, def _init_cache_engine(self): assert self.cache_config.num_gpu_blocks is not None self.cache_engine = [ - CacheEngine(self.cache_config, self.model_config, - self.parallel_config, self.device_config) + HPUCacheEngine(self.cache_config, self.model_config, + self.parallel_config, self.device_config) for _ in range(self.parallel_config.pipeline_parallel_size) ] self.hpu_cache = [ @@ -358,9 +358,9 @@ def vocab_size(self) -> int: def get_cache_block_size_bytes(self) -> int: """Get the size of the KV cache block size in bytes. """ - return CacheEngine.get_cache_block_size(self.cache_config, - self.model_config, - self.parallel_config) + return HPUCacheEngine.get_cache_block_size(self.cache_config, + self.model_config, + self.parallel_config) def init_worker_distributed_environment( @@ -423,3 +423,26 @@ def raise_if_cache_size_invalid(num_gpu_blocks, block_size, f"stored in KV cache ({max_seq_len}). 
Try increasing " "`gpu_memory_utilization` or decreasing `max_model_len` when " "initializing the engine.") + + +class HPUCacheEngine(CacheEngine): + + def _allocate_kv_cache( + self, + num_blocks: int, + device: str, + ) -> List[Tuple[torch.Tensor, torch.Tensor]]: + """Allocates KV cache on the specified device.""" + kv_cache_shape = self.attn_backend.get_kv_cache_shape( + num_blocks, self.block_size, self.num_kv_heads, self.head_size) + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]] = [] + for _ in range(self.num_attention_layers): + key_cache = torch.zeros(kv_cache_shape, + dtype=self.dtype, + device=device) + value_cache = torch.zeros(kv_cache_shape, + dtype=self.dtype, + device=device) + kv_layer = (key_cache, value_cache) + kv_cache.append(kv_layer) + return kv_cache
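
The PaddingAwareSchedulingBudget introduced in vllm/core/scheduler.py admits a prefill only if the padded batch, not the raw token count, fits the token budget: batch size and the longest prompt are each rounded up to an HPU bucket before multiplying. The sketch below is a standalone approximation of that accounting; the bucketing rule and the (min, step, max) configs are illustrative stand-ins for the real find_bucket / HPUBucketingGlobalState values, not the patch's implementation.

import math


def approx_find_bucket(value: int, cfg: tuple) -> int:
    # Approximation of HPU bucketing: round up to a power of two while below
    # `step`, then to a multiple of `step`, capped at `max`.
    bucket_min, bucket_step, bucket_max = cfg
    if value <= bucket_min:
        return bucket_min
    next_pow2 = 1 << (value - 1).bit_length()
    if next_pow2 <= bucket_step:
        return min(next_pow2, bucket_max)
    return min(math.ceil(value / bucket_step) * bucket_step, bucket_max)


def padded_prefill_tokens(batch_size: int,
                          max_seq_len: int,
                          bs_cfg=(1, 32, 64),
                          seq_cfg=(128, 128, 1024)) -> int:
    # Padding-aware cost: padded_batch_size * padded_max_seq_len.
    return (approx_find_bucket(batch_size, bs_cfg) *
            approx_find_bucket(max_seq_len, seq_cfg))


# Three prompts whose longest sequence is 300 tokens are budgeted as a padded
# 4 x 384 = 1536 tokens, not as 3 x 300 = 900 raw tokens, so the batch only
# fits if 1536 <= token_budget.
assert padded_prefill_tokens(3, 300) == 1536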
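
The Singleton metaclass added to vllm/worker/hpu_model_runner.py is what lets the scheduler's _hpu_padding_fn see bucket configs that the HPU model runner populated earlier: every HPUBucketingGlobalState() call returns the same instance. A minimal illustration with a stand-in state class:

class Singleton(type):
    _instances: dict = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]


class BucketingState(metaclass=Singleton):
    # Stand-in for HPUBucketingGlobalState's bucket config fields.
    prompt_seq_bucket_cfg = None


# The model runner writes the config once during _setup_buckets()...
BucketingState().prompt_seq_bucket_cfg = (128, 128, 1024)
# ...and any later caller (e.g. the scheduler) observes the same object.
assert BucketingState() is BucketingState()
assert BucketingState().prompt_seq_bucket_cfg == (128, 128, 1024)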
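
Usage-wise, the two new engine arguments are opt-in and only valid together (passing max_num_prefill_seqs without padding-aware scheduling now raises in SchedulerConfig._verify_args, and it cannot be combined with chunked prefill). With this patch applied, they map to the --use-padding-aware-scheduling and --max-num-prefill-seqs CLI flags; a programmatic sketch, with a placeholder model name, might look like:

from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="<model-name>",               # placeholder, not a real checkpoint
    max_num_seqs=256,
    max_num_prefill_seqs=16,            # must be <= max_num_seqs
    use_padding_aware_scheduling=True,  # required for max_num_prefill_seqs
)
# engine_args.create_engine_config() would then build a SchedulerConfig with
# use_padding_aware_scheduling=True, as wired up in create_engine_config above.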
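
The is_fake_hpu() branch added to _set_block_mapping works around torch.nn.functional.one_hot rejecting negative class indices on CPU: padded block_mapping entries (now filled with -1 in _prepare_decode) are clamped to class 0 before the one-hot and their rows zeroed afterwards. Reproduced as a standalone snippet:

import torch

block_mapping = torch.tensor([0, 0, 1, -1, -1])  # -1 marks decode padding
batch_size = 2

oob_values = block_mapping.lt(0)
safe_mapping = block_mapping.masked_fill(oob_values, 0)
one_hot = torch.nn.functional.one_hot(safe_mapping, num_classes=batch_size)
one_hot.masked_fill_(oob_values.unsqueeze(-1), 0)
# Rows belonging to the two padding blocks end up all zeros; real blocks map
# to their owning sequence index as before.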
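
block_scales, the new field threaded through HPUPagedAttentionMetadata and the HPU attention forward call, carries one weight per block: 1/len(block_table) for each block owned by a sequence and 0.0 for bucket padding, mirroring the loop and pad_list call in _prepare_decode. A minimal sketch of that construction (helper name and default dtype are illustrative; the patch uses self.model_config.dtype):

from typing import List

import torch


def build_block_scales(block_tables: List[List[int]],
                       block_bucket_size: int,
                       dtype=torch.bfloat16) -> torch.Tensor:
    scales: List[float] = []
    for block_table in block_tables:
        if block_table:
            # Every block of a sequence is weighted by 1/number_of_its_blocks.
            scales.extend([1.0 / len(block_table)] * len(block_table))
    # Pad with zeros up to the decode block bucket, as pad_list(..., 0.0) does.
    scales += [0.0] * (block_bucket_size - len(scales))
    return torch.tensor(scales, dtype=dtype)


# Two sequences holding 2 and 3 blocks, padded to an 8-block bucket:
# -> approximately [0.5, 0.5, 1/3, 1/3, 1/3, 0.0, 0.0, 0.0]
scales = build_block_scales([[0, 1], [7, 8, 9]], block_bucket_size=8)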