[Core] Support reset_prefix_cache #12284

Merged · 7 commits · Jan 22, 2025
Changes from all commits
38 changes: 38 additions & 0 deletions tests/core/block/test_prefix_caching_block.py
@@ -796,6 +796,44 @@ def test_find_cached_blocks_prefix():
            block_hashes=block_hashes_seq1)
        assert len(cached_blocks) == len(blocks_seq1) - num_evicted_blocks

    # Test reset prefix cache
    @staticmethod
    @pytest.mark.parametrize("num_blocks", [10])
    @pytest.mark.parametrize("block_size", [16])
    def test_reset_prefix_cache(num_blocks: int, block_size: int):
        """Verify that the prefix cache can only be reset once all
        blocks have been freed."""

        allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
                                                block_size=block_size)
        token_ids = list(range(3 * block_size))

        first_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
            block_size=block_size,
            token_ids=token_ids,
            allocator=allocator,
        )
        second_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
            block_size=block_size,
            token_ids=token_ids,
            allocator=allocator,
        )

        # Free each block in the first chain.
        for block in first_chain:
            allocator.free(block)

        # The reset fails because the second chain still holds its blocks.
        assert not allocator.reset_prefix_cache()
        assert allocator.get_prefix_cache_hit_rate() > 0.0

        # Free each block in the second chain.
        for block in second_chain:
            allocator.free(block)

        # Now that all blocks are free, the reset succeeds.
        assert allocator.reset_prefix_cache()
        assert allocator.get_prefix_cache_hit_rate() == 0.0

    @staticmethod
    def create_immutable_chain(
        block_size: int,
39 changes: 39 additions & 0 deletions tests/v1/core/test_prefix_caching.py
@@ -587,3 +587,42 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
    assert {block.ref_cnt for block in block_part1[:3]} == {1}
    # Blocks 3-5 are free.
    assert {block.ref_cnt for block in block_part1[3:]} == {0}


def test_reset_prefix_cache():
    manager = KVCacheManager(
        block_size=16,
        num_gpu_blocks=10,
        max_model_len=8192,
        sliding_window=None,
        enable_caching=True,
        num_preallocate_tokens=0,
    )

    full_block_token_ids = [i for i in range(3) for _ in range(16)]
    unique_token_ids = [3] * 7
    all_token_ids = full_block_token_ids + unique_token_ids
    req0 = make_request("0", all_token_ids)
    blocks = manager.allocate_slots(req0, 55, [])
    assert [b.block_id for b in blocks] == [0, 1, 2, 3]

    unique_token_ids = [4] * 7
    all_token_ids = full_block_token_ids + unique_token_ids
    req1 = make_request("1", all_token_ids)
    computed_blocks, _ = manager.get_computed_blocks(req1)
    assert len(req1.kv_block_hashes) == 3
    assert len(computed_blocks) == 3
    blocks = manager.allocate_slots(req1, 7, computed_blocks)
    assert [b.block_id for b in blocks] == [4]

    # The reset fails because req0 and req1 still hold their blocks.
    assert not manager.reset_prefix_cache()
    assert manager.cached_block_hash_to_block

    # Free the blocks.
    manager.free(req0)
    manager.free(req1)

    assert manager.reset_prefix_cache()
    assert not manager.cached_block_hash_to_block
    assert all(blk.block_hash is None for blk in manager.block_pool)
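
A possible follow-on check, not part of the diff, that pins down the user-visible effect of a successful reset: a request with the same token prefix as req0 should no longer find any cached blocks. It reuses only helpers already present in the test above.

    # Hypothetical extension of test_reset_prefix_cache (not in this PR):
    # an identical prefix finds nothing in the cache after the reset.
    req2 = make_request("2", all_token_ids)
    computed_blocks, _ = manager.get_computed_blocks(req2)
    assert len(computed_blocks) == 0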
7 changes: 7 additions & 0 deletions vllm/core/block/cpu_gpu_block_allocator.py
@@ -339,6 +339,13 @@ def get_prefix_cache_hit_rate(self, device: Device) -> float:
        assert device in self._allocators
        return self._allocators[device].get_prefix_cache_hit_rate()

    def reset_prefix_cache(self) -> bool:
        """Reset prefix cache for all devices."""
        success = True
        for allocator in self._allocators.values():
            success = success and allocator.reset_prefix_cache()
        return success

    def get_and_reset_swaps(self) -> List[Tuple[int, int]]:
        """Returns and clears the mapping of source to destination block IDs.
        Will be called after every swap operation for now, and after every
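
A note on the aggregation above (illustrative, not part of the diff): Python's `and` short-circuits, so once one allocator reports failure, reset_prefix_cache() is not called on the allocators after it, and a failed reset can leave devices in different states. A sketch of the two behaviors, with hypothetical helper names:

def reset_all_short_circuit(allocators) -> bool:
    # Mirrors the merged loop: stops issuing resets once success is False.
    success = True
    for allocator in allocators:
        success = success and allocator.reset_prefix_cache()
    return success

def reset_all_eager(allocators) -> bool:
    # Alternative: the list comprehension resets every allocator first,
    # then the results are combined.
    return all([allocator.reset_prefix_cache() for allocator in allocators])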
10 changes: 10 additions & 0 deletions vllm/core/block/interfaces.py
@@ -192,6 +192,11 @@ def get_prefix_cache_hit_rate(self) -> float:
"""Prefix cache hit rate. -1 means not supported or disabled."""
pass

@abstractmethod
def reset_prefix_cache(self) -> bool:
"""Reset prefix cache."""
pass

class NoFreeBlocksError(ValueError):
pass

@@ -297,6 +302,11 @@ def get_prefix_cache_hit_rate(self, device: Device) -> float:
"""Prefix cache hit rate. -1 means not supported or disabled."""
pass

@abstractmethod
def reset_prefix_cache(self) -> bool:
"""Reset prefix cache."""
pass

@abstractmethod
def find_cached_blocks_prefix(
self,
19 changes: 14 additions & 5 deletions vllm/core/block/naive_block.py
@@ -1,5 +1,5 @@
from collections import deque
from typing import Deque, FrozenSet, Iterable, List, Optional, Tuple, Union

from vllm.core.block.common import (BlockPool, CopyOnWriteTracker, RefCounter,
                                    get_all_blocks_recursively)
@@ -136,16 +136,18 @@ def _allocate_block_id(self) -> BlockId:
        self._refcounter.incr(block_id)
        return block_id

    def _free_block_id(self, block: Union[Block, BlockId]) -> None:
        if isinstance(block, Block):
            block_id = block.block_id
            block.block_id = None
        else:
            block_id = block
        assert block_id is not None

        refcount = self._refcounter.decr(block_id)
        if refcount == 0:
            self._free_block_indices.appendleft(block_id)

    def free(self, block: Block, keep_block_object: bool = False) -> None:
        # Release the physical block id.
        self._free_block_id(block)
@@ -154,6 +156,9 @@ def free(self, block: Block, keep_block_object: bool = False) -> None:
        if not keep_block_object:
            self._block_pool.free_block(block)

    def free_block_id(self, block_id: BlockId) -> None:
        self._free_block_id(block_id)

    def fork(self, last_block: Block) -> List[Block]:
        """Creates a new sequence of blocks that shares the same underlying
        memory as the original sequence.
@@ -325,6 +330,10 @@ def swap_in(self, blocks: List[Block]) -> None:
    def get_prefix_cache_hit_rate(self) -> float:
        return -1

    def reset_prefix_cache(self) -> bool:
        """No prefix cache for naive block allocator."""
        return True

    def find_cached_blocks_prefix(self, block_hashes: List[int]) -> List[int]:
        # Not applicable for naive block allocator.
        return []
44 changes: 43 additions & 1 deletion vllm/core/block/prefix_caching_block.py
@@ -12,6 +12,7 @@
from vllm.core.block.naive_block import (BlockPool, NaiveBlock,
                                         NaiveBlockAllocator)
from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor
from vllm.logger import init_logger
from vllm.sequence import Sequence

PrefixHash = int
@@ -21,6 +22,8 @@
# then we know this block hasn't been accessed yet.
_DEFAULT_LAST_ACCESSED_TIME = -1

logger = init_logger(__name__)


class BlockTracker:
    """Used to track the status of a block inside the prefix caching allocator
@@ -105,7 +108,8 @@ def __init__(

        # Evictor used to maintain how we want to handle those computed blocks
        # if we find memory pressure is high.
        self.eviction_policy = eviction_policy
        self.evictor: Evictor = make_evictor(self.eviction_policy)

        # We share the refcounter between allocators. This allows us to promote
        # blocks originally allocated in the hashless allocator to immutable
@@ -428,6 +432,44 @@ def all_block_ids(self) -> FrozenSet[int]:
    def get_prefix_cache_hit_rate(self) -> float:
        return self.metric_data.get_hit_rate()

    def reset_prefix_cache(self) -> bool:
        """Reset prefix cache. This function may be used in RLHF
        flows to invalidate prefix caching after the weights are updated,
        or to reset prefix caching status for benchmarking.

        Returns:
            bool: True if the prefix cache is successfully reset,
            False otherwise.
        """
        num_used_blocks = (self.get_num_total_blocks() -
                           self.get_num_free_blocks())
        if num_used_blocks > 0:
            logger.warning(
                "Failed to reset prefix cache because some "
                "blocks (%d) are not freed yet", num_used_blocks)
            return False

        # Free all blocks held by the evictor.
        while (block_id :=
               self._maybe_allocate_evicted_block_id()) is not None:
            self._hashless_allocator.free_block_id(block_id)

        # There should be no cached blocks left because all of them
        # have just been evicted.
        assert not self._cached_blocks

        # Reset the evictor.
        self.evictor = make_evictor(self.eviction_policy)

        # Reset the block tracker.
        for block_id in self._block_tracker:
            self._block_tracker[block_id] = BlockTracker()

        # Reset the metrics.
        self.metric_data = CacheMetricData()

        logger.info("Successfully reset prefix cache")
        return True

    def is_block_cached(self, block: Block) -> bool:
        assert block.content_hash is not None
        return block.content_hash in self._cached_blocks
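
To make the intended call order concrete, here is a minimal sketch at the allocator level (illustrative only; the constructor arguments mirror the unit test above, and the 16/16 sizing is arbitrary). The docstring's RLHF scenario reduces to: drain in-flight requests, free their blocks, then reset before serving with the updated weights.

from vllm.core.block.prefix_caching_block import PrefixCachingBlockAllocator

allocator = PrefixCachingBlockAllocator(num_blocks=16, block_size=16)

# ... serve requests; computed blocks accumulate in the cache ...

# reset_prefix_cache() refuses to run while any block is still allocated,
# so all blocks must have been freed first.
if allocator.reset_prefix_cache():
    assert allocator.get_prefix_cache_hit_rate() == 0.0
else:
    ...  # some blocks are still in use; drain requests and retry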
3 changes: 3 additions & 0 deletions vllm/core/block_manager.py
@@ -455,6 +455,9 @@ def get_num_free_cpu_blocks(self) -> int:
    def get_prefix_cache_hit_rate(self, device: Device) -> float:
        return self.block_allocator.get_prefix_cache_hit_rate(device)

    def reset_prefix_cache(self) -> bool:
        return self.block_allocator.reset_prefix_cache()

    def _can_swap(self,
                  seq_group: SequenceGroup,
                  device: Device,
5 changes: 5 additions & 0 deletions vllm/core/interfaces.py
@@ -122,6 +122,11 @@ def get_prefix_cache_hit_rate(self, device: Device) -> float:
"""Prefix cache hit rate. -1 means not supported or disabled."""
pass

@abstractmethod
def reset_prefix_cache(self) -> bool:
"""Reset prefix cache for all devices."""
pass

@abstractmethod
def get_num_cached_tokens(self, seq: Sequence) -> int:
pass
3 changes: 3 additions & 0 deletions vllm/core/placeholder_block_space_manager.py
@@ -90,5 +90,8 @@ def mark_blocks_as_computed(self, seq_group: SequenceGroup,
    def get_prefix_cache_hit_rate(self, device: Device) -> float:
        return -1

    def reset_prefix_cache(self) -> bool:
        return True

    def get_num_cached_tokens(self, seq: Sequence) -> int:
        return 0
3 changes: 3 additions & 0 deletions vllm/core/scheduler.py
@@ -504,6 +504,9 @@ def has_unfinished_seqs(self) -> bool:
    def get_prefix_cache_hit_rate(self, device: Device) -> float:
        return self.block_manager.get_prefix_cache_hit_rate(device)

    def reset_prefix_cache(self) -> bool:
        return self.block_manager.reset_prefix_cache()

    def get_num_unfinished_seq_groups(self) -> int:
        return len(self.waiting) + len(self.running) + len(self.swapped)

3 changes: 3 additions & 0 deletions vllm/engine/async_llm_engine.py
@@ -1182,6 +1182,9 @@ async def start_profile(self) -> None:
    async def stop_profile(self) -> None:
        self.engine.stop_profile()

    async def reset_prefix_cache(self) -> None:
        self.engine.reset_prefix_cache()

    async def add_lora(self, lora_request: LoRARequest) -> None:
        self.engine.add_lora(lora_request)

8 changes: 8 additions & 0 deletions vllm/engine/llm_engine.py
@@ -914,6 +914,14 @@ def has_unfinished_requests_for_virtual_engine(
"""
return self.scheduler[virtual_engine].has_unfinished_seqs()

def reset_prefix_cache(self) -> bool:
"""Reset prefix cache for all devices."""

success = True
for scheduler in self.scheduler:
success = success and scheduler.reset_prefix_cache()
return success

@staticmethod
def _process_sequence_group_outputs(
seq_group: SequenceGroup,
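
A hedged end-to-end sketch of the RLHF flow mentioned in the allocator docstring, using the engine-level entry point added here. update_weights() is a placeholder for whatever weight-sync mechanism the trainer uses; only reset_prefix_cache() comes from this PR.

from vllm import EngineArgs, LLMEngine

engine = LLMEngine.from_engine_args(
    EngineArgs(model="facebook/opt-125m", enable_prefix_caching=True))

# ... rollout phase: generate with the current weights ...

update_weights(engine)  # hypothetical weight-sync step

# Cached KV blocks were computed with the old weights, so drop them all.
# Returns False if any scheduler still holds unfreed blocks.
assert engine.reset_prefix_cache()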
7 changes: 6 additions & 1 deletion vllm/engine/multiprocessing/__init__.py
@@ -121,6 +121,10 @@ class RPCUProfileRequest(Enum):
    STOP_PROFILE = 2


class RPCResetPrefixCacheRequest(Enum):
    RESET_PREFIX_CACHE = 1


@dataclass
class RPCLoadAdapterRequest:
    lora_request: LoRARequest
@@ -134,7 +138,8 @@ class RPCAdapterLoadedResponse:


RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest,
                      RPCUProfileRequest, RPCLoadAdapterRequest,
                      RPCResetPrefixCacheRequest]

REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCAdapterLoadedResponse,
                          RPCError]
12 changes: 10 additions & 2 deletions vllm/engine/multiprocessing/client.py
@@ -27,8 +27,9 @@
                                         VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
                                         RPCAdapterLoadedResponse, RPCError,
                                         RPCLoadAdapterRequest,
                                         RPCProcessRequest,
                                         RPCResetPrefixCacheRequest,
                                         RPCStartupRequest, RPCStartupResponse,
                                         RPCUProfileRequest)
from vllm.engine.protocol import EngineClient
# yapf: enable
@@ -675,6 +676,13 @@ async def stop_profile(self) -> None:
        await self._send_one_way_rpc_request(
            request=RPCUProfileRequest.STOP_PROFILE, socket=self.input_socket)

    async def reset_prefix_cache(self) -> None:
        """Reset the prefix cache."""

        await self._send_one_way_rpc_request(
            request=RPCResetPrefixCacheRequest.RESET_PREFIX_CACHE,
            socket=self.input_socket)

    async def add_lora(self, lora_request: LoRARequest) -> None:
        """Load a new LoRA adapter into the engine for future requests."""
        # Uses the same I/O as generate requests.
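
In the multiprocessing front end the reset is a one-way RPC: the client enqueues RPCResetPrefixCacheRequest.RESET_PREFIX_CACHE and never sees the bool result, which surfaces only in the engine logs. A minimal async usage sketch, assuming an already-constructed client (AsyncLLMEngine or MQLLMEngineClient):

import asyncio

async def refresh_prefix_cache(client) -> None:
    # Fire-and-forget: success shows up in the engine log as
    # "Successfully reset prefix cache"; failure as the warning about
    # blocks that are not freed yet.
    await client.reset_prefix_cache()

# Usage: asyncio.run(refresh_prefix_cache(client))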
10 changes: 8 additions & 2 deletions vllm/engine/multiprocessing/engine.py
@@ -16,8 +16,9 @@
                                         VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
                                         RPCAdapterLoadedResponse, RPCError,
                                         RPCLoadAdapterRequest,
                                         RPCProcessRequest,
                                         RPCResetPrefixCacheRequest,
                                         RPCStartupRequest, RPCStartupResponse,
                                         RPCUProfileRequest)
# yapf: enable
from vllm.logger import init_logger
@@ -237,6 +238,8 @@ def handle_new_input(self):
                self.stop_profile()
            elif isinstance(request, RPCLoadAdapterRequest):
                self._handle_load_adapter_request(request)
            elif isinstance(request, RPCResetPrefixCacheRequest):
                self.reset_prefix_cache()
            else:
                raise ValueError("Unknown RPCRequest Type: "
                                 f"{type(request)}")
@@ -361,6 +364,9 @@ def start_profile(self) -> None:
    def stop_profile(self) -> None:
        self.engine.stop_profile()

    def reset_prefix_cache(self) -> bool:
        return self.engine.reset_prefix_cache()


def signal_handler(*_) -> None:
    raise KeyboardInterrupt("MQLLMEngine terminated")