Skip to content

Commit

Permalink
[V1] Add kv cache utils tests. (vllm-project#11513)
Browse files Browse the repository at this point in the history
Signed-off-by: xcnick <[email protected]>
  • Loading branch information
xcnick authored and abmfy committed Jan 24, 2025
1 parent a659bcf commit b9882fb
Showing 1 changed file with 241 additions and 0 deletions.
241 changes: 241 additions & 0 deletions tests/v1/core/test_kv_cache_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
import pytest

from vllm.inputs import token_inputs
from vllm.sampling_params import SamplingParams
from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
KVCacheBlock,
generate_block_hash_extra_keys,
hash_block_tokens,
hash_request_tokens)
from vllm.v1.request import Request


def make_request(request_id,
prompt_token_ids,
mm_positions=None,
mm_hashes=None):
return Request(
request_id=request_id,
inputs=token_inputs(
prompt_token_ids=prompt_token_ids,
multi_modal_placeholders={"image": mm_positions}
if mm_positions else None,
multi_modal_hashes=mm_hashes,
),
sampling_params=SamplingParams(max_tokens=17),
eos_token_id=100,
arrival_time=0,
lora_request=None,
)


def test_kv_cache_block():
# Test KVCacheBlock initialization
block = KVCacheBlock(block_id=0)
assert block.block_id == 0
assert block.ref_cnt == 0
assert block.block_hash is None

# Test reference count manipulation
block.incr_ref()
assert block.ref_cnt == 1
block.decr_ref()
assert block.ref_cnt == 0

# Test block hash setting and resetting
block_hash = BlockHashType(hash_value=123, token_ids=(1, 2, 3))
block.block_hash = block_hash
assert block.block_hash == block_hash

block.reset_hash()
assert block.block_hash is None


def test_free_kv_cache_block_queue_initialization():
# Test with a single block
block = KVCacheBlock(block_id=0)
queue = FreeKVCacheBlockQueue([block])
assert queue.num_free_blocks == 1
assert queue.free_list_head == block
assert queue.free_list_tail == block


def test_free_kv_cache_block_queue_operations():
# Create a list of KVCacheBlock objects
blocks = [KVCacheBlock(block_id=i) for i in range(5)]

# Create a FreeKVCacheBlockQueue with these blocks
queue = FreeKVCacheBlockQueue(blocks)

# Check initial state
assert queue.num_free_blocks == 5
assert queue.free_list_head == blocks[0]
assert queue.free_list_tail == blocks[4]

# Pop the first block
block1 = queue.popleft()
assert block1 == blocks[0]
assert queue.num_free_blocks == 4
assert queue.free_list_head == blocks[1]
assert queue.free_list_tail == blocks[4]

# Remove a block from the middle
block_to_remove = blocks[2]
queue.remove(block_to_remove)
assert queue.num_free_blocks == 3
assert blocks[1].next_free_block == blocks[3]
assert blocks[3].prev_free_block == blocks[1]

# Append a block back
queue.append(block_to_remove)
assert queue.num_free_blocks == 4
assert queue.free_list_tail == block_to_remove
assert block_to_remove.prev_free_block == blocks[4]
assert block_to_remove.next_free_block is None

# Pop blocks until empty
for _ in range(4):
queue.popleft()
assert queue.num_free_blocks == 0
assert queue.free_list_head is None
assert queue.free_list_tail is None

# Attempt to pop from an empty queue
with pytest.raises(ValueError) as e:
queue.popleft()
assert str(e.value) == "No free blocks available"


def test_free_kv_cache_block_queue_get_all_free_blocks():
# Create a list of KVCacheBlock objects
blocks = [KVCacheBlock(block_id=i) for i in range(5)]

# Create a FreeKVCacheBlockQueue with these blocks
queue = FreeKVCacheBlockQueue(blocks)

# Check all blocks are correctly retrieved
assert queue.get_all_free_blocks() == blocks

# Pop a block and check again
queue.popleft()
assert queue.get_all_free_blocks() == blocks[1:]

# Remove a block and check again
block_to_remove = blocks[2]
queue.remove(block_to_remove)
assert queue.get_all_free_blocks() == blocks[1:2] + blocks[3:]

# Append a block back and check again
queue.append(block_to_remove)
assert queue.get_all_free_blocks() == \
blocks[1:2] + blocks[3:] + [block_to_remove]


def test_generate_block_hash_extra_keys():
request = make_request(
request_id=0,
prompt_token_ids=[_ for _ in range(20)],
mm_positions=[{
"offset": 0,
"length": 5
}, {
"offset": 10,
"length": 5
}],
mm_hashes=["hash1", "hash2"],
)

# Test with no extra keys
extra_keys, next_mm_idx = generate_block_hash_extra_keys(request, 0, 5, 0)
assert extra_keys == (("hash1", 0), )
assert next_mm_idx == 1

# Test with partial overlap
extra_keys, next_mm_idx = generate_block_hash_extra_keys(request, 3, 8, 0)
assert extra_keys == (("hash1", 3), )
assert next_mm_idx == 1

# Test with no overlap
extra_keys, next_mm_idx = generate_block_hash_extra_keys(request, 6, 10, 0)
assert extra_keys == ()
assert next_mm_idx == 1

# Test with multiple extra keys
extra_keys, next_mm_idx = generate_block_hash_extra_keys(request, 0, 15, 0)
assert extra_keys == (("hash1", 0), ("hash2", 0))
assert next_mm_idx == 2


def test_generate_block_hash_extra_keys_no_mm_inputs():
request = make_request(
request_id=0,
prompt_token_ids=[_ for _ in range(6)],
mm_positions=None,
mm_hashes=None,
)

extra_keys, next_mm_idx = generate_block_hash_extra_keys(request, 0, 5, 0)
assert extra_keys is None
assert next_mm_idx == 0


def test_hash_block_tokens():
parent_block_hash = 123
curr_block_token_ids = (1, 2, 3)
extra_keys = ("key1", "key2")

block_hash = hash_block_tokens(parent_block_hash, curr_block_token_ids,
extra_keys)
assert isinstance(block_hash, BlockHashType)
assert block_hash.hash_value == hash(
(parent_block_hash, *curr_block_token_ids))
assert block_hash.token_ids == curr_block_token_ids
assert block_hash.extra_keys == extra_keys


def test_hash_request_tokens():
request = make_request(
request_id=0,
prompt_token_ids=[_ for _ in range(6)],
mm_positions=[{
"offset": 0,
"length": 3
}, {
"offset": 3,
"length": 3
}],
mm_hashes=["hash1", "hash2"],
)

block_size = 3
block_hashes = hash_request_tokens(block_size, request)

assert len(block_hashes) == 2
assert isinstance(block_hashes[0], BlockHashType)
assert isinstance(block_hashes[1], BlockHashType)

# Check the first block
assert block_hashes[0].token_ids == (0, 1, 2)
assert block_hashes[0].extra_keys == (("hash1", 0), )

# Check the second block
assert block_hashes[1].token_ids == (3, 4, 5)
assert block_hashes[1].extra_keys == (("hash2", 0), )


def test_hash_request_tokens_no_mm_inputs():
request = make_request(
request_id=0,
prompt_token_ids=[_ for _ in range(6)],
mm_positions=None,
mm_hashes=None,
)

block_size = 3
block_hashes = hash_request_tokens(block_size, request)

assert len(block_hashes) == 2
assert block_hashes[0].token_ids == (0, 1, 2)
assert block_hashes[0].extra_keys is None
assert block_hashes[1].token_ids == (3, 4, 5)
assert block_hashes[1].extra_keys is None

0 comments on commit b9882fb

Please sign in to comment.