Skip to content

Commit

Permalink
[V1] Remove _get_cache_block_size (vllm-project#12214)
Browse files Browse the repository at this point in the history
Signed-off-by: Chen Zhang <[email protected]>
  • Loading branch information
heheda12345 authored Jan 20, 2025
1 parent c222f47 commit 5f0ec39
Showing 1 changed file with 1 addition and 23 deletions.
24 changes: 1 addition & 23 deletions vllm/v1/worker/gpu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,13 @@
import torch.nn as nn

import vllm.envs as envs
from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment,
set_custom_all_reduce)
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, get_dtype_size
from vllm.v1.core.scheduler import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
Expand Down Expand Up @@ -235,24 +234,3 @@ def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
f"of at least 8.0. Your {gpu_name} GPU {compute_str}. "
"You can use float16 instead by explicitly setting the"
"`dtype` flag in CLI, for example: --dtype=half.")


def _get_cache_block_size(
cache_config: CacheConfig,
model_config: ModelConfig,
parallel_config: ParallelConfig,
) -> int:
head_size = model_config.get_head_size()
num_heads = model_config.get_num_kv_heads(parallel_config)
num_attention_layers = model_config.get_num_layers_by_block_type(
parallel_config, LayerBlockType.attention)

key_cache_block = cache_config.block_size * num_heads * head_size
value_cache_block = key_cache_block
total = num_attention_layers * (key_cache_block + value_cache_block)
if cache_config.cache_dtype == "auto":
dtype = model_config.dtype
else:
dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
dtype_size = get_dtype_size(dtype)
return dtype_size * total

0 comments on commit 5f0ec39

Please sign in to comment.