From f77435d44bf6f8fc129ffc58fc206d4d7bc8a81f Mon Sep 17 00:00:00 2001 From: Michal Adamczyk Date: Wed, 16 Oct 2024 09:40:20 +0200 Subject: [PATCH 1/3] Softmax: add weighted-sum normalization (#378) Supporting PR for https://github.com/HabanaAI/vllm-hpu-extension/pull/10 --- requirements-hpu.txt | 2 +- vllm/attention/backends/hpu_attn.py | 1 + vllm/attention/ops/hpu_paged_attn.py | 1 + vllm/worker/hpu_model_runner.py | 45 +++++++++++++++++++++------- 4 files changed, 38 insertions(+), 11 deletions(-) diff --git a/requirements-hpu.txt b/requirements-hpu.txt index 8495d63ce72fa..1a583974be151 100644 --- a/requirements-hpu.txt +++ b/requirements-hpu.txt @@ -8,4 +8,4 @@ pandas tabulate setuptools>=61 setuptools-scm>=8 -vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@7531cc6 +vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@fd7f2e6 diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index 17201fe6e1cd6..a8f4b09b67274 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -222,6 +222,7 @@ def forward( block_list=attn_metadata.block_list, block_mapping=attn_metadata.block_mapping, block_bias=attn_metadata.attn_bias, + block_scales=attn_metadata.block_scales, scale=self.scale, matmul_qk_op=self.matmul_qk, matmul_av_op=self.matmul_av, diff --git a/vllm/attention/ops/hpu_paged_attn.py b/vllm/attention/ops/hpu_paged_attn.py index 7fbe26d83f320..4c0fb2a628361 100644 --- a/vllm/attention/ops/hpu_paged_attn.py +++ b/vllm/attention/ops/hpu_paged_attn.py @@ -20,6 +20,7 @@ class HPUPagedAttentionMetadata: block_usage: Optional[torch.Tensor] block_indices: Optional[torch.Tensor] block_offsets: Optional[torch.Tensor] + block_scales: Optional[torch.Tensor] class HPUPagedAttention: diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index f81e4aa59b289..d8150a56844a2 100644 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -298,9 +298,19 @@ def _set_block_mapping(self, metadata, batch_size, device, dtype): mask = mask >= metadata.block_usage.unsqueeze(-1) attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_( mask, -math.inf)) - block_mapping = torch.nn.functional.one_hot( - metadata.block_mapping.to(torch.long), - num_classes=batch_size).to(dtype) + if is_fake_hpu(): + # Unfortunately one_hot on CPU doesn't handle + # out of bounds classes. 
We need to mask those
+            # values manually.
+            oob_values = metadata.block_mapping.lt(0)
+            block_mapping = metadata.block_mapping.masked_fill(oob_values, 0)
+            block_mapping = torch.nn.functional.one_hot(block_mapping,
+                                                        num_classes=batch_size)
+            block_mapping.masked_fill_(oob_values.unsqueeze(-1), 0)
+        else:
+            block_mapping = torch.nn.functional.one_hot(metadata.block_mapping,
+                                                        num_classes=batch_size)
+        block_mapping = block_mapping.to(dtype)
         metadata = metadata._replace(block_mapping=block_mapping,
                                      attn_bias=attn_bias)
         return metadata
@@ -873,6 +883,7 @@ def _prepare_prompt(
             block_usage=None,
             block_indices=block_indices,
             block_offsets=block_offsets,
+            block_scales=None,
             attn_bias=None,
             seq_lens_tensor=seq_lens_tensor,
             num_prefills=real_num_seqs,
@@ -968,7 +979,15 @@ def _prepare_decode(
         num_decode_tokens = sum(seq_lens)

         blocks_used = [len(bt) for bt in block_tables if bt]
-        block_list = list(itertools.chain(*block_tables))
+        block_list = []
+        block_scales = []
+        for i, bt in enumerate(block_tables):
+            block_list.extend(bt)
+            blocks_in_group = len(bt)
+            if blocks_in_group > 0:
+                scale = 1.0 / blocks_in_group
+                block_scales.extend([scale] * blocks_in_group)
+
         block_mapping_nested: List[List[int]] = [
             [i] * b_u for i, b_u in enumerate(blocks_used)
         ]
@@ -984,18 +1003,19 @@ def _prepare_decode(

         block_bucket_size = find_bucket(len(block_list),
                                         self.decode_block_bucket_cfg)
-        block_list = pad_list(block_list, block_bucket_size, _PAD_SLOT_ID)
-        block_mapping = pad_list(block_mapping, block_bucket_size, 0)
-        block_usage = pad_list(block_usage, block_bucket_size, 0)
+        block_list = pad_list(block_list, block_bucket_size, _PAD_BLOCK_ID)
+        block_mapping = pad_list(block_mapping, block_bucket_size, -1)
+        block_usage = pad_list(block_usage, block_bucket_size, 1)
+        block_scales = pad_list(block_scales, block_bucket_size, 0.0)

         block_list = torch.tensor(block_list,
                                   dtype=torch.int,
                                   device=self.device)
         block_mapping = torch.tensor(block_mapping,
-                                     dtype=torch.int,
+                                     dtype=torch.long,
                                      device=self.device)
         block_usage = torch.tensor(block_usage,
-                                   dtype=torch.bfloat16,
+                                   dtype=self.model_config.dtype,
                                    device=self.device)

         slot_mapping = torch.tensor(slot_mapping,
@@ -1004,6 +1024,10 @@ def _prepare_decode(
         block_indices, block_offsets = precompute_indices_and_offsets(
             self.block_size, slot_mapping, False)

+        block_scales = torch.tensor(block_scales,
+                                    dtype=self.model_config.dtype,
+                                    device=self.device)
+
         attn_metadata = self.attn_backend.make_metadata(
             is_prompt=False,
             block_list=block_list,
@@ -1011,6 +1035,7 @@
             block_usage=block_usage,
             block_indices=block_indices,
             block_offsets=block_offsets,
+            block_scales=block_scales,
             attn_bias=None,
             seq_lens_tensor=None,
             num_prefills=0,
@@ -1222,7 +1247,7 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object:
         attention_metadata = subtuple(metadata, 'TrimmedAttentionMetadata', [
             'attn_bias', 'seq_lens_tensor', 'block_list', 'block_mapping',
             'block_usage', 'slot_mapping', 'is_prompt', 'block_indices',
-            'block_offsets'
+            'block_offsets', 'block_scales'
         ])
         return attention_metadata

From a59fc7b481b1807f27de1165383b6e10476850d2 Mon Sep 17 00:00:00 2001
From: Konrad Zawora
Date: Wed, 16 Oct 2024 19:30:58 +0200
Subject: [PATCH 2/3] Remove HPU changes from cache_engine.py (#400)

We were asked on the upstream PR to remove our changes from
cache_engine.py. This PR does just that: it creates HPUCacheEngine,
which inherits from CacheEngine and overrides only the
_allocate_kv_cache method.
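
For context, a minimal sketch of the per-layer difference (the shape
below is an illustrative placeholder, not the real HPU cache shape):

    import torch

    kv_cache_shape = (16, 128, 8, 64)  # illustrative placeholder shape

    # Upstream CacheEngine: one zero-initialized tensor per attention layer.
    upstream_layer = torch.zeros(kv_cache_shape)

    # HPUCacheEngine: a (key_cache, value_cache) tuple per attention layer,
    # as in the _allocate_kv_cache override in the diff below.
    hpu_layer = (torch.zeros(kv_cache_shape),   # key cache
                 torch.zeros(kv_cache_shape))   # value cache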
--- vllm/worker/cache_engine.py | 30 +++++++++--------------------- vllm/worker/hpu_worker.py | 35 +++++++++++++++++++++++++++++------ 2 files changed, 38 insertions(+), 27 deletions(-) diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 9618585c8acb0..090f95e6e892c 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -6,7 +6,7 @@ from vllm.attention import get_attn_backend from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig from vllm.logger import init_logger -from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size, is_fake_hpu, +from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size, is_pin_memory_available) logger = init_logger(__name__) @@ -75,26 +75,14 @@ def _allocate_kv_cache( pin_memory = is_pin_memory_available() if device == "cpu" else False kv_cache: List[torch.Tensor] = [] for _ in range(self.num_attention_layers): - if device == 'hpu' or is_fake_hpu(): - key_cache = torch.zeros(kv_cache_shape, - dtype=self.dtype, - device=device) - value_cache = torch.zeros(kv_cache_shape, - dtype=self.dtype, - device=device) - kv_layer = (key_cache, value_cache) - kv_cache.append(kv_layer) - else: - # null block in CpuGpuBlockAllocator requires at least that - # block to be zeroed-out. - # We zero-out everything for simplicity. - dtype = torch.uint8 if self.dtype == torch.float8_e4m3fn else \ - self.dtype - kv_cache.append( - torch.zeros(kv_cache_shape, - dtype=dtype, - pin_memory=pin_memory, - device=device)) + # null block in CpuGpuBlockAllocator requires at least that + # block to be zeroed-out. + # We zero-out everything for simplicity. + kv_cache.append( + torch.zeros(kv_cache_shape, + dtype=self.dtype, + pin_memory=pin_memory, + device=device)) return kv_cache def swap_in(self, src_to_dst: torch.Tensor) -> None: diff --git a/vllm/worker/hpu_worker.py b/vllm/worker/hpu_worker.py index 59a5adf65ebc1..752388e0d632f 100644 --- a/vllm/worker/hpu_worker.py +++ b/vllm/worker/hpu_worker.py @@ -93,7 +93,7 @@ def __init__( observability_config=observability_config) # Uninitialized cache engine. Will be initialized by # initialize_cache. - self.cache_engine: List[CacheEngine] + self.cache_engine: List[HPUCacheEngine] # Initialize gpu_cache as embedding models don't initialize kv_caches self.hpu_cache: Optional[List[List[torch.tensor]]] = None # Torch profiler. Enabled and configured through env vars: @@ -242,8 +242,8 @@ def initialize_cache(self, num_gpu_blocks: int, def _init_cache_engine(self): assert self.cache_config.num_gpu_blocks is not None self.cache_engine = [ - CacheEngine(self.cache_config, self.model_config, - self.parallel_config, self.device_config) + HPUCacheEngine(self.cache_config, self.model_config, + self.parallel_config, self.device_config) for _ in range(self.parallel_config.pipeline_parallel_size) ] self.hpu_cache = [ @@ -358,9 +358,9 @@ def vocab_size(self) -> int: def get_cache_block_size_bytes(self) -> int: """Get the size of the KV cache block size in bytes. """ - return CacheEngine.get_cache_block_size(self.cache_config, - self.model_config, - self.parallel_config) + return HPUCacheEngine.get_cache_block_size(self.cache_config, + self.model_config, + self.parallel_config) def init_worker_distributed_environment( @@ -423,3 +423,26 @@ def raise_if_cache_size_invalid(num_gpu_blocks, block_size, f"stored in KV cache ({max_seq_len}). 
Try increasing "
            "`gpu_memory_utilization` or decreasing `max_model_len` when "
            "initializing the engine.")
+
+
+class HPUCacheEngine(CacheEngine):
+
+    def _allocate_kv_cache(
+        self,
+        num_blocks: int,
+        device: str,
+    ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
+        """Allocates KV cache on the specified device."""
+        kv_cache_shape = self.attn_backend.get_kv_cache_shape(
+            num_blocks, self.block_size, self.num_kv_heads, self.head_size)
+        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
+        for _ in range(self.num_attention_layers):
+            key_cache = torch.zeros(kv_cache_shape,
+                                    dtype=self.dtype,
+                                    device=device)
+            value_cache = torch.zeros(kv_cache_shape,
+                                      dtype=self.dtype,
+                                      device=device)
+            kv_layer = (key_cache, value_cache)
+            kv_cache.append(kv_layer)
+        return kv_cache

From 05bcdf5e169be9d746ff4c9d6163fff9f4b310b9 Mon Sep 17 00:00:00 2001
From: Konrad Zawora
Date: Thu, 17 Oct 2024 12:18:10 +0200
Subject: [PATCH 3/3] [bucketing overhaul 1/n] Add padding-aware scheduling
 and option to limit prefill batch size (#394)

This PR adds the following functionality, which can be enabled via
engine flags:
- use_padding_aware_scheduling - the vLLM scheduler will now calculate
token cost based on the padded prefill shape (similar to
https://github.com/HabanaAI/vllm-fork/pull/109).
- max_num_prefill_seqs - the padding-aware scheduler will perform an
additional check on the prefill batch size, effectively capping it at
`max_num_prefill_seqs`. If unset, the maximum prefill batch size is
`max_num_seqs`.

Both features are generic and do not require an HPU, although they may
be specialized for a particular vendor's usage. Padding-aware
scheduling includes a padding-function selector that picks the HPU
padding function (taking the currently used HPU buckets into account)
when the current device is an HPU; otherwise, it uses the product
batch_size x max_seq_len. An illustrative sketch of this cost model is
appended after the diff.

---
 vllm/config.py                  |  18 ++++-
 vllm/core/scheduler.py          | 122 ++++++++++++++++++++++++--
 vllm/engine/arg_utils.py        |  19 ++++-
 vllm/worker/hpu_model_runner.py | 137 ++++++++++++++++++++------------
 4 files changed, 238 insertions(+), 58 deletions(-)

diff --git a/vllm/config.py b/vllm/config.py
index 5499b349bcfc8..67a4ec0761cc3 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -940,6 +940,9 @@ class SchedulerConfig:
         a single iteration.
         max_num_seqs: Maximum number of sequences to be processed in a single
             iteration.
+        max_num_prefill_seqs: Maximum number of prefill sequences to be
+            processed in a single iteration. Used only with padding-aware
+            scheduling.
         max_model_len: Maximum length of a sequence (including prompt
             and generated text).
         use_v2_block_manager: Whether to use the BlockSpaceManagerV2 or not.
@@ -963,11 +966,14 @@ class SchedulerConfig:
             when SPMD worker architecture is enabled. I.e.,
             VLLM_USE_RAY_SPMD_WORKER=1
         policy: The scheduling policy to use. "fcfs" (default) or "priority".
+        use_padding_aware_scheduling: If True, the scheduler will consider
+            padded token counts when scheduling prefills.
""" def __init__(self, max_num_batched_tokens: Optional[int], max_num_seqs: int, + max_num_prefill_seqs: Optional[int], max_model_len: int, use_v2_block_manager: bool = True, num_lookahead_slots: int = 0, @@ -979,7 +985,8 @@ def __init__(self, num_scheduler_steps: int = 1, multi_step_stream_outputs: bool = False, send_delta_data: bool = False, - policy: str = "fcfs") -> None: + policy: str = "fcfs", + use_padding_aware_scheduling=False) -> None: if max_num_batched_tokens is None: if enable_chunked_prefill: if num_scheduler_steps > 1: @@ -1018,6 +1025,7 @@ def __init__(self, self.max_num_batched_tokens) self.max_num_seqs = max_num_seqs + self.max_num_prefill_seqs = max_num_prefill_seqs self.max_model_len = max_model_len self.use_v2_block_manager = use_v2_block_manager self.num_lookahead_slots = num_lookahead_slots @@ -1029,6 +1037,7 @@ def __init__(self, self.multi_step_stream_outputs = multi_step_stream_outputs self.send_delta_data = send_delta_data self.policy = policy + self.use_padding_aware_scheduling = use_padding_aware_scheduling self._verify_args() def _verify_args(self) -> None: @@ -1059,6 +1068,13 @@ def _verify_args(self) -> None: "num_scheduler_steps " f"({self.num_scheduler_steps}) must be greater than or " "equal to 1.") + if self.max_num_prefill_seqs is not None \ + and not self.use_padding_aware_scheduling: + raise ValueError("max_num_prefill_seqs can be only " + "used with padding-aware-scheduling. ") + if self.use_padding_aware_scheduling and self.chunked_prefill_enabled: + raise ValueError("Padding-aware scheduling currently " + "does not work with chunked prefill ") if (not self.use_v2_block_manager \ and not envs.VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1): diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 1f0a121711db5..1c69c72933b79 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -11,6 +11,7 @@ from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.platforms import current_platform from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import (Sequence, SequenceData, SequenceGroup, SequenceGroupMetadata, SequenceGroupMetadataDelta, @@ -101,6 +102,94 @@ def num_curr_seqs(self): return self._num_curr_seqs +@dataclass +class PaddingAwareSchedulingBudget(SchedulingBudget): + max_num_prefill_seqs: Optional[int] = None + _prefill_request_ids_max_seq_lens: Dict[str, + int] = field(default_factory=dict) + _max_seq_len: int = 0 + _num_curr_prefill_seqs: int = 0 + + def _generic_padding_fn(self, batch_size, max_seq_len) -> int: + return batch_size * max_seq_len + + def _hpu_padding_fn(self, batch_size, max_seq_len): + from vllm.worker.hpu_model_runner import (HPUBucketingGlobalState, + find_bucket) + padded_bs = batch_size + padded_seq = max_seq_len + + hpu_bucketing_global_state = HPUBucketingGlobalState() + + bs_cfg = hpu_bucketing_global_state.prompt_bs_bucket_cfg + if bs_cfg is not None: + padded_bs = find_bucket(batch_size, bs_cfg) + else: + logger.warning( + "prompt_bs_bucket_cfg was not set! Using unpadded batch size.") + seq_cfg = hpu_bucketing_global_state.prompt_seq_bucket_cfg + if seq_cfg is not None: + padded_seq = find_bucket(max_seq_len, seq_cfg) + else: + logger.warning("prompt_seq_bucket_cfg was not set! 
" + "Using unpadded sequence length.") + return padded_bs * padded_seq + + def _padding_fn_selector(self): + if current_platform.is_hpu(): + return self._hpu_padding_fn + return self._generic_padding_fn + + def _maybe_update_max_seq_len(self, + new_seq_max_seq_len: Optional[int] = None): + if new_seq_max_seq_len is not None \ + and new_seq_max_seq_len > self._max_seq_len: + self._max_seq_len = new_seq_max_seq_len + return + self._max_seq_len = max( + self._prefill_request_ids_max_seq_lens.values()) + + def add_prefill_seqs(self, req_id, num_curr_prefill_seqs, max_seq_len): + self._prefill_request_ids_max_seq_lens[req_id] = max_seq_len + self._num_curr_prefill_seqs += num_curr_prefill_seqs + self._maybe_update_max_seq_len(max_seq_len) + + def subtract_prefill_seqs(self, req_id, num_curr_prefill_seqs): + if req_id in self._prefill_request_ids_max_seq_lens: + popped_seq_len = self._prefill_request_ids_max_seq_lens.pop(req_id) + self._num_curr_prefill_seqs -= num_curr_prefill_seqs + if popped_seq_len == self._max_seq_len: + self._maybe_update_max_seq_len() + + def can_schedule(self, + *args, + num_new_tokens: int, + num_new_seqs: int, + is_prefill: bool = False, + max_seq_len: int = 0): + can_parent_schedule = super().can_schedule( + *args, num_new_tokens=num_new_tokens, num_new_seqs=num_new_seqs) + if not can_parent_schedule or not is_prefill: + return can_parent_schedule + new_batch_size = self._num_curr_prefill_seqs + num_new_seqs + new_max_seq_len = max(max(self._max_seq_len, max_seq_len), 1) + padding_fn = self._padding_fn_selector() + num_new_padded_tokens = padding_fn(new_batch_size, new_max_seq_len) + result = num_new_padded_tokens <= self.token_budget + if self.max_num_prefill_seqs is not None and result: + result = self._num_curr_prefill_seqs + num_new_seqs \ + <= self.max_num_prefill_seqs + return result + + @property + def max_seq_len(self): + return self._max_seq_len + + @property + def num_curr_prefill_seqs(self): + return self._num_curr_prefill_seqs + + @dataclass class ScheduledSequenceGroup: # A sequence group that's scheduled. @@ -938,9 +1027,18 @@ def _schedule_prefills( continue num_new_seqs = seq_group.get_max_num_running_seqs() + max_prefill_seq_len = None + can_schedule_kwargs = { + 'num_new_tokens': num_new_tokens, + 'num_new_seqs': num_new_seqs + } + if self.scheduler_config.use_padding_aware_scheduling: + max_prefill_seq_len = max( + [seq.get_num_new_tokens() for seq in seq_group.get_seqs()]) + can_schedule_kwargs['is_prefill'] = True + can_schedule_kwargs['max_seq_len'] = max_prefill_seq_len if (num_new_tokens == 0 - or not budget.can_schedule(num_new_tokens=num_new_tokens, - num_new_seqs=num_new_seqs)): + or not budget.can_schedule(**can_schedule_kwargs)): break # Can schedule this request. @@ -971,6 +1069,10 @@ def _schedule_prefills( token_chunk_size=num_new_tokens)) budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens) budget.add_num_seqs(seq_group.request_id, num_new_seqs) + if self.scheduler_config.use_padding_aware_scheduling: + assert isinstance(budget, PaddingAwareSchedulingBudget) + budget.add_prefill_seqs(seq_group.request_id, num_new_seqs, + max_prefill_seq_len) # Queue requests that couldn't be scheduled. waiting_queue.extendleft(leftover_waiting_sequences) @@ -992,10 +1094,18 @@ def _schedule_default(self) -> SchedulerOutputs: be swapped or preempted. """ # Include running requests to the budget. 
- budget = SchedulingBudget( - token_budget=self.scheduler_config.max_num_batched_tokens, - max_num_seqs=self.scheduler_config.max_num_seqs, - ) + budget: SchedulingBudget + if self.scheduler_config.use_padding_aware_scheduling: + budget = PaddingAwareSchedulingBudget( + token_budget=self.scheduler_config.max_num_batched_tokens, + max_num_seqs=self.scheduler_config.max_num_seqs, + max_num_prefill_seqs=self.scheduler_config.max_num_prefill_seqs + ) + else: + budget = SchedulingBudget( + token_budget=self.scheduler_config.max_num_batched_tokens, + max_num_seqs=self.scheduler_config.max_num_seqs, + ) # Make sure we include num running seqs before scheduling prefill, # so that we don't schedule beyond max_num_seqs for prefill. for seq_group in self.running: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 3c9f3d4fe4ab3..cdf1401816800 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -113,11 +113,13 @@ class EngineArgs: enable_prefix_caching: bool = False disable_sliding_window: bool = False use_v2_block_manager: bool = True + use_padding_aware_scheduling: bool = False swap_space: float = 4 # GiB cpu_offload_gb: float = 0 # GiB gpu_memory_utilization: float = 0.90 max_num_batched_tokens: Optional[int] = None max_num_seqs: int = 256 + max_num_prefill_seqs: Optional[int] = None max_logprobs: int = 20 # Default value for OpenAI Chat Completions API disable_log_stats: bool = False revision: Optional[str] = None @@ -391,6 +393,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: action='store_true', help='Use BlockSpaceMangerV2. By default this is set to True. ' 'Set to False to use BlockSpaceManagerV1') + parser.add_argument( + '--use-padding-aware-scheduling', + default=EngineArgs.use_padding_aware_scheduling, + action='store_true', + help=('Use padding-aware scheduling. If True, the scheduler ' + 'will consider padded tokens in prefill. ' + 'By default this is set to False. ')) parser.add_argument( '--num-lookahead-slots', type=int, @@ -445,6 +454,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: type=int, default=EngineArgs.max_num_seqs, help='Maximum number of sequences per iteration.') + parser.add_argument( + '--max-num-prefill-seqs', + type=int, + default=EngineArgs.max_num_prefill_seqs, + help=('Maximum number of prefill sequences per ' + 'iteration. Can be used only with padding-aware ' + 'scheduling. 
Must be <= max_num_seqs.')) parser.add_argument( '--max-logprobs', type=int, @@ -1036,6 +1052,7 @@ def create_engine_config(self) -> EngineConfig: scheduler_config = SchedulerConfig( max_num_batched_tokens=self.max_num_batched_tokens, max_num_seqs=self.max_num_seqs, + max_num_prefill_seqs=self.max_num_prefill_seqs, max_model_len=model_config.max_model_len, use_v2_block_manager=self.use_v2_block_manager, num_lookahead_slots=num_lookahead_slots, @@ -1049,7 +1066,7 @@ def create_engine_config(self) -> EngineConfig: send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER and parallel_config.use_ray), policy=self.scheduling_policy, - ) + use_padding_aware_scheduling=self.use_padding_aware_scheduling) lora_config = LoRAConfig( max_lora_rank=self.max_lora_rank, max_loras=self.max_loras, diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index d8150a56844a2..785337478468f 100644 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -13,6 +13,7 @@ import os import time from array import array +from dataclasses import dataclass, field from enum import IntEnum from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple, Type, TypeVar, Union) @@ -64,6 +65,26 @@ LORA_WARMUP_RANK = 8 +class Singleton(type): + _instances: Dict[type, object] = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super(Singleton, + cls).__call__(*args, **kwargs) + return cls._instances[cls] + + +@dataclass +class HPUBucketingGlobalState(metaclass=Singleton): + prompt_bs_bucket_cfg: Tuple[int, int, int] = field(init=False) + decode_bs_bucket_cfg: Tuple[int, int, int] = field(init=False) + prompt_seq_bucket_cfg: Tuple[int, int, int] = field(init=False) + decode_block_bucket_cfg: Tuple[int, int, int] = field(init=False) + prompt_buckets: List[Tuple[int, int]] = field(init=False) + decode_buckets: List[Tuple[int, int]] = field(init=False) + + def subtuple(obj: object, typename: str, to_copy: List[str], @@ -542,6 +563,9 @@ def __init__( self.device = self.device_config.device self.enforce_eager = self.model_config.enforce_eager self.max_num_seqs = self.scheduler_config.max_num_seqs + self.max_num_prefill_seqs = self.scheduler_config.max_num_prefill_seqs \ + if self.scheduler_config.max_num_prefill_seqs is not None \ + else self.max_num_seqs self.max_model_len = self.scheduler_config.max_model_len self.max_num_batched_tokens = \ self.scheduler_config.max_num_batched_tokens @@ -569,6 +593,7 @@ def __init__( self.profiler_counter_helper = HabanaProfilerCounterHelper() self.seen_configs: set = set() self._mem_margin: Optional[int] = None + self.bucketing_global_state = HPUBucketingGlobalState() self._setup_buckets() self._set_gc_threshold() @@ -680,27 +705,26 @@ def _is_valid_bucket(self, bucket): def _setup_buckets(self) -> None: align_bs = lambda x: min(self.max_num_seqs, x) - max_bucket_cfg = 64 #FIXME: The default values should be max_model_len max_prompt_seq = 1024 max_decode_seq = 2048 - self.prompt_bs_bucket_cfg = read_bucket_settings( + self.bucketing_global_state.prompt_bs_bucket_cfg = read_bucket_settings( 'prompt', 'bs', min=1, step=align_bs(32), - max=align_bs(max_bucket_cfg)) - self.decode_bs_bucket_cfg = read_bucket_settings('decode', - 'bs', - min=1, - step=align_bs(32), - max=self.max_num_seqs) - self.prompt_seq_bucket_cfg = read_bucket_settings('prompt', - 'seq', - min=self.block_size, - step=self.block_size, - max=max_prompt_seq) - self.decode_block_bucket_cfg = read_bucket_settings( + 
max=self.max_num_prefill_seqs) + self.bucketing_global_state.decode_bs_bucket_cfg = read_bucket_settings( + 'decode', 'bs', min=1, step=align_bs(32), max=self.max_num_seqs) + self.bucketing_global_state.prompt_seq_bucket_cfg = \ + read_bucket_settings( + 'prompt', + 'seq', + min=self.block_size, + step=self.block_size, + max=max_prompt_seq) + self.bucketing_global_state.decode_block_bucket_cfg = \ + read_bucket_settings( 'decode', 'block', min=self.block_size, @@ -710,13 +734,13 @@ def _setup_buckets(self) -> None: self.graphed_buckets: Set[Any] = set() msg = ("Prompt bucket config (min, step, max_warmup) " - f"bs:{self.prompt_bs_bucket_cfg}, " - f"seq:{self.prompt_seq_bucket_cfg}") + f"bs:{self.bucketing_global_state.prompt_bs_bucket_cfg}, " + f"seq:{self.bucketing_global_state.prompt_seq_bucket_cfg}") logger.info(msg) msg = ("Decode bucket config (min, step, max_warmup) " - f"bs:{self.decode_bs_bucket_cfg}, " - f"block:{self.decode_block_bucket_cfg}") + f"bs:{self.bucketing_global_state.decode_bs_bucket_cfg}, " + f"block:{self.bucketing_global_state.decode_block_bucket_cfg}") logger.info(msg) def _prepare_prompt( @@ -834,7 +858,8 @@ def _prepare_prompt( assert max_query_len > 0 max_prompt_len = max( - find_bucket(max(seq_lens), self.prompt_seq_bucket_cfg), + find_bucket(max(seq_lens), + self.bucketing_global_state.prompt_seq_bucket_cfg), self.block_size) lora_ids: List[int] = [] @@ -1001,8 +1026,9 @@ def _prepare_decode( for b_u, lb in zip(blocks_used, last_block)] block_usage = list(itertools.chain(*block_usage)) - block_bucket_size = find_bucket(len(block_list), - self.decode_block_bucket_cfg) + block_bucket_size = find_bucket( + len(block_list), + self.bucketing_global_state.decode_block_bucket_cfg) block_list = pad_list(block_list, block_bucket_size, _PAD_BLOCK_ID) block_mapping = pad_list(block_mapping, block_bucket_size, -1) block_usage = pad_list(block_usage, block_bucket_size, 1) @@ -1076,8 +1102,8 @@ def prepare_input_tensors( self.profiler.start('internal', base_event_name) real_batch_size = len(seq_group_metadata_list) - bucket_cfg = self.prompt_bs_bucket_cfg if is_prompt else \ - self.decode_bs_bucket_cfg + bucket_cfg = self.bucketing_global_state.prompt_bs_bucket_cfg \ + if is_prompt else self.bucketing_global_state.decode_bs_bucket_cfg batch_size_padded = find_bucket(real_batch_size, bucket_cfg) batch_size_padding = batch_size_padded - real_batch_size seq_group_metadata_list = seq_group_metadata_list.copy() @@ -1282,9 +1308,10 @@ def create_dummy_seq_group_metadata(self, def profile_run(self) -> None: num_layers = self.model_config.get_num_layers(self.parallel_config) kv_caches = [None] * num_layers - max_batch_size = self.prompt_bs_bucket_cfg[-1] - max_seq_len = min(self.prompt_seq_bucket_cfg[-1], - self.max_num_batched_tokens // max_batch_size) + max_batch_size = self.bucketing_global_state.prompt_bs_bucket_cfg[-1] + max_seq_len = min( + self.bucketing_global_state.prompt_seq_bucket_cfg[-1], + self.max_num_batched_tokens // max_batch_size) self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches, False, True) @@ -1498,13 +1525,15 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: self.profiler.start('internal', 'warmup') max_blocks = kv_caches[0][0].size(0) - self.prompt_buckets, prompt_omitted_buckets = generate_prompt_buckets( - self.prompt_bs_bucket_cfg, self.prompt_seq_bucket_cfg, + self.bucketing_global_state.prompt_buckets, prompt_omitted_buckets = \ + generate_prompt_buckets( + self.bucketing_global_state.prompt_bs_bucket_cfg, + 
self.bucketing_global_state.prompt_seq_bucket_cfg, self.max_num_batched_tokens) - msg = ( - f"Generated {len(self.prompt_buckets)} " - f"prompt buckets [bs, seq]: {list(sorted(self.prompt_buckets))}") + msg = (f"Generated {len(self.bucketing_global_state.prompt_buckets)} " + f"prompt buckets [bs, seq]: \ + {list(sorted(self.bucketing_global_state.prompt_buckets))}") logger.info(msg) msg = (f"Omitted {len(prompt_omitted_buckets)} " @@ -1515,16 +1544,17 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: msg = f"Omitted prompt buckets: {list(sorted(prompt_omitted_buckets))}" logger.debug(msg) - self.decode_buckets = generate_decode_buckets( - self.decode_bs_bucket_cfg, self.decode_block_bucket_cfg, - max_blocks) + self.bucketing_global_state.decode_buckets = generate_decode_buckets( + self.bucketing_global_state.decode_bs_bucket_cfg, + self.bucketing_global_state.decode_block_bucket_cfg, max_blocks) logger.info("Generated %d decode buckets [bs, total_blocks]: %s", - len(self.decode_buckets), - list(sorted(self.decode_buckets))) + len(self.bucketing_global_state.decode_buckets), + list(sorted(self.bucketing_global_state.decode_buckets))) if not htorch.utils.internal.is_lazy() and not self.enforce_eager: - cache_size_limit = len(self.prompt_buckets) + len( - self.decode_buckets) + 1 + cache_size_limit = len( + self.bucketing_global_state.prompt_buckets) + len( + self.bucketing_global_state.decode_buckets) + 1 torch._dynamo.config.cache_size_limit = max( cache_size_limit, torch._dynamo.config.cache_size_limit) # Multiply by 8 to follow the original default ratio between @@ -1551,8 +1581,10 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: 'Please update Gaudi Software Suite.') with compile_only_mode_context( ) if can_use_compile_only_mode else contextlib.nullcontext(): - self.warmup_all_buckets(self.prompt_buckets, True, kv_caches) - self.warmup_all_buckets(self.decode_buckets, False, kv_caches) + self.warmup_all_buckets(self.bucketing_global_state.prompt_buckets, + True, kv_caches) + self.warmup_all_buckets(self.bucketing_global_state.decode_buckets, + False, kv_caches) if not self.enforce_eager and htorch.utils.internal.is_lazy(): assert self.mem_margin is not None, \ @@ -1582,12 +1614,12 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: 'max_bs') mem_post_prompt, prompt_batch_seq, prompt_captured_all = \ self.warmup_graphs( - prompt_strategy, self.prompt_buckets, True, kv_caches, - prompt_available_memory) + prompt_strategy, self.bucketing_global_state.prompt_buckets, + True, kv_caches, prompt_available_memory) mem_post_decode, decode_batch_seq, decode_captured_all = \ self.warmup_graphs( - decode_strategy, self.decode_buckets, False, kv_caches, - decode_available_memory) + decode_strategy, self.bucketing_global_state.decode_buckets, + False, kv_caches, decode_available_memory) # Not all prompt buckets were captured, but all decode buckets # were captured and we have some free graph-allocated space @@ -1596,7 +1628,8 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: and not prompt_captured_all and decode_captured_all): mem_post_prompt, _, prompt_captured_all = ( self.warmup_graphs( - prompt_strategy, self.prompt_buckets, True, + prompt_strategy, + self.bucketing_global_state.prompt_buckets, True, kv_caches, graph_free_mem - mem_post_prompt - mem_post_decode, mem_post_prompt, prompt_batch_seq)) @@ -1608,14 +1641,18 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: and not decode_captured_all \ and 
prompt_captured_all: mem_post_decode, _, _ = self.warmup_graphs( - decode_strategy, self.decode_buckets, False, kv_caches, + decode_strategy, + self.bucketing_global_state.decode_buckets, False, + kv_caches, graph_free_mem - mem_post_prompt - mem_post_decode, mem_post_decode, decode_batch_seq) - self.log_graph_warmup_summary(self.prompt_buckets, True, - mem_post_prompt) - self.log_graph_warmup_summary(self.decode_buckets, False, - mem_post_decode) + self.log_graph_warmup_summary( + self.bucketing_global_state.prompt_buckets, True, + mem_post_prompt) + self.log_graph_warmup_summary( + self.bucketing_global_state.decode_buckets, False, + mem_post_decode) end_time = time.perf_counter() end_mem = HabanaMemoryProfiler.current_device_memory_usage()
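
---

Illustrative sketch of the padding-aware cost model referenced in the
PATCH 3/3 commit message. `round_up` is a simplified stand-in that
rounds up to the next bucket step; it approximates, but is not
identical to, the HPU `find_bucket` implementation. The config tuples
follow the (min, step, max) layout used by HPUBucketingGlobalState.

    def round_up(value: int, cfg: tuple) -> int:
        # cfg is (min, step, max): round up to the next multiple of
        # step, then clamp into the [min, max] bucket range.
        bmin, bstep, bmax = cfg
        return max(bmin, min(bmax, -(-value // bstep) * bstep))

    def padded_prefill_cost(batch_size: int, max_seq_len: int,
                            bs_cfg: tuple, seq_cfg: tuple) -> int:
        # Mirrors PaddingAwareSchedulingBudget._hpu_padding_fn: the cost
        # of a prefill batch is padded_batch_size * padded_seq_len,
        # not the raw number of prompt tokens.
        return (round_up(batch_size, bs_cfg) *
                round_up(max_seq_len, seq_cfg))

    # Example: with bs_cfg=(1, 32, 64) and seq_cfg=(128, 128, 1024),
    # three prompts of length 200 cost 32 * 256 = 8192 padded tokens,
    # so the scheduler compares 8192 (not 600) against token_budget.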