From 6189ff23f96c498296eeb492897179a38dcb9662 Mon Sep 17 00:00:00 2001
From: Chen Zhang
Date: Fri, 17 Jan 2025 15:39:35 +0800
Subject: [PATCH] [V1] Move more control of kv cache initialization from
 model_executor to EngineCore (#11960)

Signed-off-by: Chen Zhang
Co-authored-by: Cody Yu
---
 tests/v1/test_utils.py                 |  62 +++++++++++++
 vllm/attention/layer.py                |   2 +
 vllm/v1/core/kv_cache_utils.py         | 124 +++++++++++++++++++++
 vllm/v1/engine/core.py                 |  31 ++++---
 vllm/v1/executor/abstract.py           |  11 ++-
 vllm/v1/executor/multiproc_executor.py |  25 +++--
 vllm/v1/executor/ray_executor.py       |  40 ++++----
 vllm/v1/executor/uniproc_executor.py   |  25 ++---
 vllm/v1/kv_cache_interface.py          | 111 ++++++++++++++++++++++
 vllm/v1/utils.py                       |  56 ++++++++++-
 vllm/v1/worker/gpu_model_runner.py     |  84 ++++++++++++++---
 vllm/v1/worker/gpu_worker.py           |  48 +++-------
 12 files changed, 515 insertions(+), 104 deletions(-)
 create mode 100644 tests/v1/test_utils.py
 create mode 100644 vllm/v1/kv_cache_interface.py

diff --git a/tests/v1/test_utils.py b/tests/v1/test_utils.py
new file mode 100644
index 0000000000000..ac773b611f406
--- /dev/null
+++ b/tests/v1/test_utils.py
@@ -0,0 +1,62 @@
+from typing import List
+
+import torch
+
+from vllm.v1.utils import bind_kv_cache
+
+
+def test_bind_kv_cache():
+    from vllm.attention import Attention
+
+    ctx = {
+        'layers.0.self_attn': Attention(32, 128, 0.1),
+        'layers.1.self_attn': Attention(32, 128, 0.1),
+        'layers.2.self_attn': Attention(32, 128, 0.1),
+        'layers.3.self_attn': Attention(32, 128, 0.1),
+    }
+    kv_cache = {
+        'layers.0.self_attn': torch.zeros((1, )),
+        'layers.1.self_attn': torch.zeros((1, )),
+        'layers.2.self_attn': torch.zeros((1, )),
+        'layers.3.self_attn': torch.zeros((1, )),
+    }
+    runner_kv_caches: List[torch.Tensor] = []
+    bind_kv_cache(kv_cache, ctx, runner_kv_caches)
+    assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[
+        'layers.0.self_attn']
+    assert ctx['layers.1.self_attn'].kv_cache[0] is kv_cache[
+        'layers.1.self_attn']
+    assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[
+        'layers.2.self_attn']
+    assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[
+        'layers.3.self_attn']
+
+    assert runner_kv_caches[0] is kv_cache['layers.0.self_attn']
+    assert runner_kv_caches[1] is kv_cache['layers.1.self_attn']
+    assert runner_kv_caches[2] is kv_cache['layers.2.self_attn']
+    assert runner_kv_caches[3] is kv_cache['layers.3.self_attn']
+
+
+def test_bind_kv_cache_non_attention():
+    from vllm.attention import Attention
+
+    # example from Jamba PP=2
+    ctx = {
+        'model.layers.20.attn': Attention(32, 128, 0.1),
+        'model.layers.28.attn': Attention(32, 128, 0.1),
+    }
+    kv_cache = {
+        'model.layers.20.attn': torch.zeros((1, )),
+        'model.layers.28.attn': torch.zeros((1, )),
+    }
+
+    runner_kv_caches: List[torch.Tensor] = []
+    bind_kv_cache(kv_cache, ctx, runner_kv_caches)
+
+    assert ctx['model.layers.20.attn'].kv_cache[0] is kv_cache[
+        'model.layers.20.attn']
+    assert ctx['model.layers.28.attn'].kv_cache[0] is kv_cache[
+        'model.layers.28.attn']
+
+    assert runner_kv_caches[0] is kv_cache['model.layers.20.attn']
+    assert runner_kv_caches[1] is kv_cache['model.layers.28.attn']
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 9b03fd73fe690..e2403306950a3 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -101,7 +101,9 @@ def __init__(
         self.num_heads = num_heads
         self.head_size = head_size
         self.num_kv_heads = num_kv_heads
+        self.sliding_window = sliding_window
         self.backend = backend_name_to_enum(attn_backend.get_name())
+        self.dtype = dtype
 
         # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
         # torch.compile works by registering the attention as one giant
diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py
index 22a5d2fb08a48..bab99fe37caee 100644
--- a/vllm/v1/core/kv_cache_utils.py
+++ b/vllm/v1/core/kv_cache_utils.py
@@ -3,7 +3,10 @@
 from dataclasses import dataclass
 from typing import Any, List, NamedTuple, Optional, Tuple
 
+from vllm.config import VllmConfig
 from vllm.logger import init_logger
+from vllm.v1.kv_cache_interface import (KVCacheConfig, KVCacheSpec,
+                                        KVCacheTensor)
 from vllm.v1.request import Request
 
 logger = init_logger(__name__)
@@ -305,3 +308,124 @@ def hash_request_tokens(block_size: int,
         ret.append(block_hash)
         parent_block_hash_value = block_hash.hash_value
     return ret
+
+
+def check_enough_kv_cache_memory(vllm_config: VllmConfig,
+                                 kv_cache_spec: KVCacheSpec,
+                                 available_memory: int):
+    """
+    Checks whether `available_memory` is enough for the KV cache to hold at
+    least one request with the model's max_model_len.
+
+    Args:
+        vllm_config: The global VllmConfig
+        kv_cache_spec: The kv cache spec of the model
+        available_memory: Memory available for KV cache in bytes.
+
+    Raises:
+        ValueError: If there is not enough memory available for the KV cache.
+    """
+
+    if available_memory <= 0:
+        raise ValueError("No available memory for the cache blocks. "
+                         "Try increasing `gpu_memory_utilization` when "
+                         "initializing the engine.")
+
+    max_model_len = vllm_config.model_config.max_model_len
+    needed_memory = 0
+    for layer_spec in kv_cache_spec.values():
+        needed_memory += layer_spec.bytes_for_tokens(max_model_len)
+
+    if needed_memory > available_memory:
+        raise ValueError(
+            f"To serve at least one request with the model's max seq len "
+            f"({max_model_len}), {needed_memory/1024/1024/1024:.2f} GB KV "
+            f"cache is needed, which is larger than the available KV cache "
+            f"memory ({available_memory/1024/1024/1024:.2f} GB). Try "
+            f"increasing `gpu_memory_utilization` or decreasing "
+            f"`max_model_len` when initializing the engine.")
+
+
+def is_kv_cache_type_uniform(kv_cache_spec: KVCacheSpec) -> bool:
+    """
+    Whether all layers in the given KVCacheSpec have the same type of KV cache.
+
+    Args:
+        kv_cache_spec: The KVCacheSpec of the model
+
+    Returns:
+        True if all layers have the same type, False otherwise.
+    """
+
+    layer_keys = set(layer.type_id for layer in kv_cache_spec.values())
+    return len(layer_keys) == 1
+
+
+def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
+                                      kv_cache_spec: KVCacheSpec,
+                                      available_memory: int) -> KVCacheConfig:
+    """
+    Generates the KV cache configuration for a model with one type of KV cache.
+    Divide the available memory equally among all layers.
+
+    Args:
+        vllm_config: The global VllmConfig
+        kv_cache_spec: The kv cache spec of the model
+        available_memory: Memory available for KV cache in bytes.
+
+    Returns:
+        The generated KVCacheConfig
+    """
+
+    page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()}
+    assert len(page_sizes) == 1
+    page_size = page_sizes.pop()
+
+    num_blocks = int(available_memory // page_size // len(kv_cache_spec))
+    num_blocks = max(num_blocks, 0)
+
+    if vllm_config.cache_config.num_gpu_blocks_override is not None:
+        num_gpu_blocks_override = \
+            vllm_config.cache_config.num_gpu_blocks_override
+        logger.info(
+            "Overriding num_gpu_blocks=%d with "
+            "num_gpu_blocks_override=%d", num_blocks, num_gpu_blocks_override)
+        num_blocks = num_gpu_blocks_override
+
+    logger.info("# GPU blocks: %d", num_blocks)
+
+    per_layer_size = page_size * num_blocks
+
+    kv_cache_config = KVCacheConfig(
+        num_blocks=num_blocks,
+        tensors={
+            layer_name: KVCacheTensor(size=per_layer_size)
+            for layer_name in kv_cache_spec
+        },
+        groups=[[layer_name for layer_name in kv_cache_spec]],
+        kv_cache_spec=kv_cache_spec)
+    return kv_cache_config
+
+
+def get_kv_cache_config(vllm_config: VllmConfig, kv_cache_spec: KVCacheSpec,
+                        available_memory: int) -> KVCacheConfig:
+    """
+    Generates the KV cache configuration for a model
+    TODO: support hybrid models with more than one type of KV cache.
+
+    Args:
+        vllm_config: The global VllmConfig
+        kv_cache_spec: The kv cache spec of the model
+        available_memory: Memory available for KV cache in bytes.
+
+    Returns:
+        The generated KVCacheConfig
+    """
+    check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory)
+    if is_kv_cache_type_uniform(kv_cache_spec):
+        # KV cache of all layers is the same, which is true for most models.
+        # Allocate the same amount of memory for each layer.
+        return _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec,
+                                                 available_memory)
+    else:
+        raise NotImplementedError
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index ef616229aa57b..26ebc7edcf03e 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -11,11 +11,12 @@
 import zmq.asyncio
 from msgspec import msgpack
 
-from vllm.config import CacheConfig, VllmConfig
+from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.transformers_utils.config import (
     maybe_register_config_serialize_by_value)
 from vllm.utils import get_exception_traceback, zmq_socket_ctx
+from vllm.v1.core.kv_cache_utils import get_kv_cache_config
 from vllm.v1.core.scheduler import Scheduler
 from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile,
                             EngineCoreRequest, EngineCoreRequestType,
@@ -49,7 +50,7 @@ def __init__(
         # Setup KV Caches and update CacheConfig after profiling.
         num_gpu_blocks, num_cpu_blocks = self._initialize_kv_caches(
-            vllm_config.cache_config)
+            vllm_config)
         vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
         vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
@@ -65,21 +66,25 @@ def __init__(
             vllm_config.model_config)
 
     def _initialize_kv_caches(self,
-                              cache_config: CacheConfig) -> Tuple[int, int]:
+                              vllm_config: VllmConfig) -> Tuple[int, int]:
         start = time.time()
-        num_gpu_blocks, _ = self.model_executor.determine_num_available_blocks(
-        )
 
-        if cache_config.num_gpu_blocks_override is not None:
-            num_gpu_blocks_override = cache_config.num_gpu_blocks_override
-            logger.info(
-                "Overriding num_gpu_blocks=%d with "
-                "num_gpu_blocks_override=%d", num_gpu_blocks,
-                num_gpu_blocks_override)
-            num_gpu_blocks = num_gpu_blocks_override
+        # Get all kv cache needed by the model
+        kv_cache_spec = self.model_executor.get_kv_cache_spec()
+
+        # Profiles the peak memory usage of the model to determine how much
+        # memory can be allocated for kv cache.
+        available_gpu_memory = self.model_executor.determine_available_memory()
 
+        # Get the kv cache tensor size
+        kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
+                                              available_gpu_memory)
+        num_gpu_blocks = kv_cache_config.num_blocks
         num_cpu_blocks = 0
-        self.model_executor.initialize(num_gpu_blocks)
+
+        # Initialize kv cache and warmup the execution
+        self.model_executor.initialize(kv_cache_config)
+
         elapsed = time.time() - start
         logger.info(("init engine (profile, create kv cache, "
                      "warmup model) took %.2f seconds"), elapsed)
diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py
index 7c17f60510ae1..5240778ebf330 100644
--- a/vllm/v1/executor/abstract.py
+++ b/vllm/v1/executor/abstract.py
@@ -1,7 +1,8 @@
 from abc import ABC, abstractmethod
-from typing import Tuple, Type
+from typing import Type
 
 from vllm.config import VllmConfig
+from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import ModelRunnerOutput
 
 
@@ -31,11 +32,15 @@ def __init__(self, vllm_config: VllmConfig) -> None:
         raise NotImplementedError
 
     @abstractmethod
-    def initialize(self, num_gpu_blocks: int) -> None:
+    def initialize(self, kv_cache_config: KVCacheConfig) -> None:
         raise NotImplementedError
 
     @abstractmethod
-    def determine_num_available_blocks(self) -> Tuple[int, int]:
+    def determine_available_memory(self) -> int:  # in bytes
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_kv_cache_spec(self) -> KVCacheSpec:
         raise NotImplementedError
 
     @abstractmethod
diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py
index e111ac7ee8183..e92acc7cb5e41 100644
--- a/vllm/v1/executor/multiproc_executor.py
+++ b/vllm/v1/executor/multiproc_executor.py
@@ -23,6 +23,7 @@
 from vllm.utils import (get_distributed_init_method, get_mp_context,
                         get_open_port, get_open_zmq_ipc_path, zmq_socket_ctx)
 from vllm.v1.executor.abstract import Executor
+from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.worker.worker_base import WorkerWrapperBase
 
@@ -90,29 +91,33 @@ def sigusr1_handler(signum, frame):
         for w in self.workers:
             w.worker_response_mq.wait_until_ready()
 
-    def initialize(self, num_gpu_blocks: int) -> None:
+    def initialize(self, kv_cache_config: KVCacheConfig) -> None:
         """
         Initialize the KV caches and begin the model execution loop of the
         underlying workers.
""" - logger.info("# GPU blocks: %d", num_gpu_blocks) - self.collective_rpc("initialize_cache", args=(num_gpu_blocks, )) + self.collective_rpc("initialize_cache", args=(kv_cache_config, )) self.collective_rpc("compile_or_warm_up_model") - def determine_num_available_blocks(self) -> Tuple[int, int]: + def determine_available_memory(self) -> int: """ - Determine the number of available KV blocks by invoking the + Determine the available memory (in bytes) for KV cache by invoking the underlying worker. """ - num_blocks = self.collective_rpc("determine_num_available_blocks") + memory_sizes = self.collective_rpc("determine_available_memory") # Since we use a shared centralized controller, we take the minimum - # number of blocks across all workers to make sure all the memory + # memory size across all workers to make sure all the memory # operators can be applied to all workers. - num_gpu_blocks = min(b[0] for b in num_blocks) - num_cpu_blocks = min(b[1] for b in num_blocks) + return min(memory_sizes) - return num_gpu_blocks, num_cpu_blocks + def get_kv_cache_spec(self) -> KVCacheSpec: + """ + Get all kv cache needed by the model by invoking the underlying worker. + """ + kv_cache_specs = self.collective_rpc("get_kv_cache_spec") + assert all(s == kv_cache_specs[0] for s in kv_cache_specs) + return kv_cache_specs[0] def collective_rpc(self, method: str, diff --git a/vllm/v1/executor/ray_executor.py b/vllm/v1/executor/ray_executor.py index 79acc60001c99..fd67fa2235770 100644 --- a/vllm/v1/executor/ray_executor.py +++ b/vllm/v1/executor/ray_executor.py @@ -10,6 +10,7 @@ from vllm.v1.executor.abstract import Executor from vllm.v1.executor.ray_utils import (RayWorkerWrapper, initialize_ray_cluster, ray) +from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.outputs import ModelRunnerOutput if ray is not None: @@ -211,39 +212,40 @@ def _get_worker_kwargs( distributed_init_method=distributed_init_method, ) - def determine_num_available_blocks(self) -> Tuple[int, int]: + def determine_available_memory(self) -> int: """ - Determine the number of available KV blocks. + Determine the available GPU memory in bytes. - This invokes `determine_num_available_blocks` on each worker and takes + This invokes `determine_available_memory` on each worker and takes the min of the results, guaranteeing that the selected cache sizes are compatible with all workers. - - Returns: - - tuple[num_gpu_blocks, num_cpu_blocks] """ - # Get the maximum number of blocks that can be allocated on GPU and CPU. - num_blocks = self._run_workers("determine_num_available_blocks") + + memory_sizes = self._run_workers("determine_available_memory") # Since we use a shared centralized controller, we take the minimum - # number of blocks across all workers to make sure all the memory + # memory size across all workers to make sure all the memory # operators can be applied to all workers. - num_gpu_blocks = min(b[0] for b in num_blocks) - num_cpu_blocks = min(b[1] for b in num_blocks) + return min(memory_sizes) - return num_gpu_blocks, num_cpu_blocks - - def initialize(self, num_gpu_blocks: int) -> None: + def initialize(self, kv_cache_config: KVCacheConfig) -> None: """ Initialize the KV cache in all workers. """ - # NOTE: This is logged in the executor because there can be >1 worker - # with other executors. We could log in the engine level, but work - # remains to abstract away the device for non-GPU configurations. 
- logger.info("# GPU blocks: %d", num_gpu_blocks) - self._run_workers("initialize_cache", num_gpu_blocks) + self._run_workers("initialize_cache", kv_cache_config) self._run_workers("compile_or_warm_up_model") + def get_kv_cache_spec(self) -> KVCacheSpec: + """ + Get all kv cache needed by the model + + This invokes `get_kv_cache_spec` on each worker and asserts that + they are identical. The KVCacheSpec is then returned. + """ + kv_cache_specs = self._run_workers("get_kv_cache_spec") + assert all(s == kv_cache_specs[0] for s in kv_cache_specs) + return kv_cache_specs[0] + def _run_workers( self, method: str, diff --git a/vllm/v1/executor/uniproc_executor.py b/vllm/v1/executor/uniproc_executor.py index c63d7a4c47c15..b3997caac726b 100644 --- a/vllm/v1/executor/uniproc_executor.py +++ b/vllm/v1/executor/uniproc_executor.py @@ -1,10 +1,11 @@ import os -from typing import Optional, Tuple +from typing import Optional from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.utils import get_distributed_init_method, get_ip, get_open_port from vllm.v1.executor.abstract import Executor +from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.worker.gpu_worker import Worker @@ -49,20 +50,22 @@ def _create_worker( distributed_init_method=distributed_init_method, ) - def determine_num_available_blocks(self) -> Tuple[int, int]: - """Determine the number of available KV blocks by invoking the - underlying worker. + def determine_available_memory(self) -> int: + """Determine the available memory (in bytes) for KV cache by invoking + the underlying worker. """ - return self.worker.determine_num_available_blocks() + return self.worker.determine_available_memory() - def initialize(self, num_gpu_blocks: int) -> None: + def get_kv_cache_spec(self) -> KVCacheSpec: + """Get all kv cache needed by the model by invoking the underlying + worker. + """ + return self.worker.get_kv_cache_spec() + + def initialize(self, kv_cache_config: KVCacheConfig) -> None: """Initialize the KV cache by invoking the underlying worker. """ - # NOTE: This is logged in the executor because there can be >1 worker - # with other executors. We could log in the engine level, but work - # remains to abstract away the device for non-GPU configurations. - logger.info("# GPU blocks: %d", num_gpu_blocks) - self.worker.initialize_cache(num_gpu_blocks) + self.worker.initialize_cache(kv_cache_config) self.worker.compile_or_warm_up_model() def execute_model( diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py new file mode 100644 index 0000000000000..6d5cc32ffc5b8 --- /dev/null +++ b/vllm/v1/kv_cache_interface.py @@ -0,0 +1,111 @@ +from dataclasses import dataclass +from typing import Dict, List + +import torch + +from vllm.logger import init_logger +from vllm.utils import cdiv, get_dtype_size + +logger = init_logger(__name__) + + +@dataclass +class KVCacheSpecBase: + """ + A base class for specifying the KV cache format of one layer. + """ + + # number of tokens in a block + block_size: int + + @property + def type_id(self) -> str: + """ + The type identifier of this KV cache. + Return different strings for layers with different KV cache type (e.g., + different number of tokens like full attention vs sliding window + attention, different KV cache size per token like layers with different + number of heads) + + Returns: + The type identifier of this KV cache. 
+ """ + raise NotImplementedError + + @property + def page_size_bytes(self) -> int: + """ + The size of a page with `block_size` tokens in bytes. + + Returns: + The page size + """ + raise NotImplementedError + + def bytes_for_tokens(self, num_tokens: int) -> int: + """ + The KV cache size for `num_tokens` tokens in bytes. Returns the real + memory size after padding `num_tokens` to full blocks. + + Returns: + The KV cache size + """ + raise NotImplementedError + + +@dataclass +class FullAttentionSpec(KVCacheSpecBase): + num_kv_heads: int + head_size: int + dtype: torch.dtype + + @property + def type_id(self) -> str: + return f"full_attention_{self.block_size}_{self.page_size_bytes}" + + @property + def page_size_bytes(self) -> int: + return 2 * self.block_size * self.num_kv_heads * self.head_size \ + * get_dtype_size(self.dtype) + + def bytes_for_tokens(self, num_tokens: int) -> int: + return cdiv(num_tokens, self.block_size) * self.page_size_bytes + + +KVCacheSpec = Dict[str, KVCacheSpecBase] + + +@dataclass +class KVCacheTensor: + """ + A dataclass for specifying how the workers should initialize the KV cache + for a layer. Only contains the size of KV cache for that layer for now. Will + be extended to support multiple layers sharing the same memory pool. + """ + size: int # The size of KV cache Tensor in bytes + + +@dataclass +class KVCacheConfig: + """ + The KV cache configuration of a model. + """ + """The number of KV cache blocks""" + num_blocks: int + """layer_name -> how to initialize KV cache for that layer""" + tensors: Dict[str, KVCacheTensor] + """ + A list of kv-cache groups. Each group includes a set of layers with + the same kv-cache spec, and the total page_size of layers inside a group + is same across all groups (as the KVCacheManager only supports allocating + pages of the same size). For example: + 1. A model only uses full attention: one group with all layers in the model. + 2. (not implemented yet) A model with the same number of full attention + layers and sliding window attention layers: two groups, one for full + attention layers and one for sliding window attention layers. + 3. (not implemented yet) A model with 2 full attention layers and 4 sliding + window attention layers: three groups, (full * 2), (sw * 2), (sw * 2). 
+    """
+    groups: List[List[str]]
+    """the KVCacheSpec of the model"""
+    kv_cache_spec: KVCacheSpec
diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py
index b0a7affbebb7e..8dfcf2dd78606 100644
--- a/vllm/v1/utils.py
+++ b/vllm/v1/utils.py
@@ -1,13 +1,20 @@
 import multiprocessing
 import os
 import weakref
+from collections import defaultdict
 from collections.abc import Sequence
-from typing import (Any, Callable, Dict, Generic, List, Optional, TypeVar,
-                    Union, overload)
+from typing import (TYPE_CHECKING, Any, Callable, Dict, Generic, List,
+                    Optional, TypeVar, Union, overload)
+
+import torch
 
 from vllm.logger import init_logger
+from vllm.model_executor.models.utils import extract_layer_index
 from vllm.utils import get_mp_context, kill_process_tree
 
+if TYPE_CHECKING:
+    from vllm.attention.layer import Attention
+
 logger = init_logger(__name__)
 
 T = TypeVar("T")
@@ -134,3 +141,48 @@ def shutdown(proc: multiprocessing.Process, input_path: str, output_path: str):
         socket_file = ipc_socket.replace("ipc://", "")
         if os and os.path.exists(socket_file):
             os.remove(socket_file)
+
+
+def bind_kv_cache(
+    kv_caches: Dict[str, torch.Tensor],
+    forward_context: Dict[str, "Attention"],
+    runner_kv_caches: List[torch.Tensor],
+) -> None:
+    """
+    Bind the allocated KV cache to both ModelRunner and forward context so
+    that the KV cache can be used in the forward pass.
+
+    This function:
+      1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
+         kv_caches.
+      2) Associates each attention layer in the `forward_context` with its
+         corresponding KV cache in kv_caches.
+
+    Args:
+        kv_caches: The allocated kv_caches with layer names as keys.
+        forward_context: The global forward context containing all Attention
+            layers with layer names as keys.
+        runner_kv_caches: The kv_cache declared by ModelRunner.
+    """
+    # Bind kv_caches to ModelRunner
+    assert len(runner_kv_caches) == 0
+
+    # Convert kv_caches dict to a list of tensors in the order of layer_index.
+    index2name = defaultdict(list)
+    for layer_name in kv_caches:
+        index2name[extract_layer_index(layer_name)].append(layer_name)
+
+    for layer_index in sorted(index2name.keys()):
+        layer_names = index2name[layer_index]
+        if len(layer_names) > 1:
+            # One typical case is encoder-decoder model, e.g., bart.
+            # The cross attention and self attention in the same decoder layer
+            # have different layer_names but the same layer_index.
+            raise NotImplementedError
+        layer_name = layer_names[0]
+        runner_kv_caches.append(kv_caches[layer_name])
+
+    # Bind kv_caches to forward context
+    for layer_name, kv_cache in kv_caches.items():
+        # NOTE: Use list because of v0 PP virtual engine.
+        forward_context[layer_name].kv_cache = [kv_cache]
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index de83640b27cd6..aa63d9414c296 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -7,6 +7,8 @@
 import torch.distributed
 import torch.nn as nn
 
+from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.layer import Attention
 from vllm.config import CompilationLevel, VllmConfig
 from vllm.distributed.parallel_state import graph_capture
 from vllm.forward_context import set_forward_context
@@ -16,14 +18,16 @@
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
 from vllm.sampling_params import SamplingType
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
-                        LayerBlockType, bind_kv_cache, cdiv,
-                        is_pin_memory_available)
+                        LayerBlockType, cdiv, is_pin_memory_available)
 from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
                                                    FlashAttentionMetadata)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
+from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
+                                        KVCacheSpec)
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 
 if TYPE_CHECKING:
@@ -856,15 +860,71 @@ def capture_model(self) -> None:
         logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
                     elapsed_time, cuda_graph_size / (1 << 30))
 
-    def initialize_kv_cache(self, num_blocks: int) -> None:
-        assert len(self.kv_caches) == 0
-        kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
-            num_blocks, self.block_size, self.num_kv_heads, self.head_size)
-        for _ in range(self.num_attn_layers):
-            self.kv_caches.append(
-                torch.zeros(kv_cache_shape,
-                            dtype=self.kv_cache_dtype,
-                            device=self.device))
+    def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
+        """
+        Initialize KV cache based on `kv_cache_config`.
+        Args:
+            kv_cache_config: Configuration for the KV cache, including the KV
+                cache size of each layer
+        """
+        if len(kv_cache_config.groups) > 1:
+            raise NotImplementedError(
+                "Hybrid models with more than one KV cache type are not "
+                "supported yet.")
+
+        kv_caches: Dict[str, torch.Tensor] = {}
+
+        for layer_name, layer_spec in kv_cache_config.kv_cache_spec.items():
+            tensor_config = kv_cache_config.tensors[layer_name]
+            assert tensor_config.size % layer_spec.page_size_bytes == 0
+            num_blocks = tensor_config.size // layer_spec.page_size_bytes
+            if isinstance(layer_spec, FullAttentionSpec):
+                kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
+                    num_blocks, layer_spec.block_size, layer_spec.num_kv_heads,
+                    layer_spec.head_size)
+                dtype = layer_spec.dtype
+                kv_caches[layer_name] = torch.zeros(kv_cache_shape,
+                                                    dtype=dtype,
+                                                    device=self.device)
+            else:
+                raise NotImplementedError
+
         bind_kv_cache(
+            kv_caches,
             self.vllm_config.compilation_config.static_forward_context,
-            [self.kv_caches])
+            self.kv_caches)
+
+    def get_kv_cache_spec(self) -> KVCacheSpec:
+        """
+        Generates the KVCacheSpec by parsing the kv cache format from each
+        Attention module in the static forward context.
+        Returns:
+            KVCacheSpec: A dictionary mapping layer names to their KV cache
+            format. Layers that do not need KV cache are not included.
+ """ + + forward_ctx = self.vllm_config.compilation_config.static_forward_context + block_size = self.vllm_config.cache_config.block_size + kv_cache_spec: KVCacheSpec = {} + for layer_name, attn_module in forward_ctx.items(): + # TODO: Support other attention modules, e.g., sliding window, + # cross-attention, MLA. + assert isinstance(attn_module, Attention) + if attn_module.attn_type == AttentionType.DECODER: + kv_cache_spec[layer_name] = FullAttentionSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=attn_module.dtype, + ) + elif attn_module.attn_type in (AttentionType.ENCODER, + AttentionType.ENCODER_ONLY): + # encoder-only attention does not need KV cache. + continue + elif attn_module.attn_type == AttentionType.ENCODER_DECODER: + raise NotImplementedError + else: + raise ValueError( + f"Unknown attention type: {attn_module.attn_type}") + + return kv_cache_spec diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 81b247e07ef4a..4fb4197f1822f 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -1,7 +1,7 @@ """A GPU worker class.""" import gc import os -from typing import TYPE_CHECKING, Optional, Tuple +from typing import TYPE_CHECKING, Optional import torch import torch.distributed @@ -16,6 +16,7 @@ from vllm.platforms import current_platform from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, get_dtype_size from vllm.v1.core.scheduler import SchedulerOutput +from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.worker.gpu_model_runner import GPUModelRunner @@ -112,20 +113,18 @@ def load_model(self) -> None: self.model_runner.load_model() @torch.inference_mode() - def determine_num_available_blocks(self) -> Tuple[int, int]: - """Profiles the peak memory usage of the model to determine how many - KV blocks may be allocated without OOMs. + def determine_available_memory(self) -> int: + """Profiles the peak memory usage of the model to determine how much + memory can be used for KV cache without OOMs. The engine will first conduct a profiling of the existing memory usage. - Then, it calculate the maximum possible number of GPU and CPU blocks - that can be allocated with the remaining free memory. + Then, it calculate the free memory that can be used for KV cache in + bytes. .. tip:: You may limit the usage of GPU memory by adjusting the `gpu_memory_utilization` parameter. """ - # Profile the memory usage of the model and get the maximum number of - # cache blocks that can be allocated with the remaining free memory. torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() @@ -161,33 +160,14 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: total_gpu_memory * self.cache_config.gpu_memory_utilization - peak_memory) - # Calculate the number of blocks that can be allocated with the - # profiled peak memory. - cache_block_size = _get_cache_block_size(self.cache_config, - self.model_config, - self.parallel_config) - num_gpu_blocks = int(available_kv_cache_memory // cache_block_size) - num_gpu_blocks = max(num_gpu_blocks, 0) - return num_gpu_blocks, 0 - - def initialize_cache(self, num_gpu_blocks: int) -> None: - """Allocate GPU and CPU KV cache with the specified number of blocks.""" - if num_gpu_blocks <= 0: - raise ValueError("No available memory for the cache blocks. 
" - "Try increasing `gpu_memory_utilization` when " - "initializing the engine.") - - max_seq_len = self.cache_config.block_size * num_gpu_blocks - max_model_len = self.model_config.max_model_len - if max_model_len > max_seq_len: - raise ValueError( - f"The model's max seq len ({max_model_len}) " - "is larger than the maximum number of tokens that can be " - f"stored in KV cache ({max_seq_len}). Try increasing " - "`gpu_memory_utilization` or decreasing `max_model_len` when " - "initializing the engine.") + return int(available_kv_cache_memory) + + def get_kv_cache_spec(self) -> KVCacheSpec: + return self.model_runner.get_kv_cache_spec() - self.model_runner.initialize_kv_cache(num_gpu_blocks) + def initialize_cache(self, kv_cache_config: KVCacheConfig) -> None: + """Allocate GPU KV cache with the specified kv_cache_config.""" + self.model_runner.initialize_kv_cache(kv_cache_config) def compile_or_warm_up_model(self) -> None: if not self.model_config.enforce_eager: