Skip to content

Commit

Permalink
Add real page pool tests for trie_attention_cache (#902)
Browse files Browse the repository at this point in the history
Previously, we were testing with mocked page pools so that the tests would run
faster. In this PR, I split trie_attention_cache_tests.py into 2 files:

trie_attention_cache/mock_pool_tests.py contains the old tests, and we
continue to test with a mocked-up page pool to verify that the trie
correctly does accounting for the pages and the evictions.

trie_attention_cache/real_pool_tests.py will contain new tests for
page-copying prefix sharing, so that we won't have to recompute the
entire last page's worth of KV if branching on a token. Since we're
copying the page, the tests will need to not mock the page pool and
actually allocate the buffer, which will make them slower. I opted to do
this separately from the old tests so that we won't have to take 5-ish
seconds to set up the buffer for each of the 30-ish tests.

This PR also replaces some of the nuisance print statements with
logging.debug.

~~This is a step on the way to implement beam search (required by
MLPerf).~~

Edit: [MLPerf only requires beam search for
GPT-J](https://github.com/mlcommons/inference_policies/blob/master/inference_rules.adoc#:~:text=Q%3A%20What%20algorithm,uses%20greedy%20search.).
Thanks @stbaione
  • Loading branch information
renxida authored Feb 5, 2025
1 parent 6aeaaee commit 17c8369
Show file tree
Hide file tree
Showing 2 changed files with 167 additions and 48 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
"""
Trie attention cache tests with a mocked page-pool.
This file contains trie attention cache tests that don't require writing to the actual page.
Since we mock all dependencies of the page pool, we don't need to initialize systems, devices, and device-arrays for every test.
Everything runs A LOT faster this way.
"""

import pytest
from typing import List, Tuple
import shortfin as sf
Expand All @@ -7,6 +17,7 @@
import threading
import time
from dataclasses import dataclass
import logging

from shortfin_apps.llm.components.kvcache.trie_attention_cache import (
TriePagedAttentionCache,
Expand All @@ -20,6 +31,7 @@
PagePoolConfig,
)

logger = logging.getLogger(__name__)

# Test constants
TEST_PAGE_SIZE = 16 # Tokens per page
Expand Down Expand Up @@ -152,7 +164,7 @@ def _publish_sequence(tokens: List[int]) -> None:
def print_tree_state(cache, prefix=""):
"""Helper function to print current tree state in a readable format"""
if not hasattr(cache, "root"):
print(f"{prefix}Unable to access trie structure")
logger.debug(f"{prefix}Unable to access trie structure")
return

def node_info(node):
Expand All @@ -161,12 +173,12 @@ def node_info(node):

def print_node(node, depth=0):
indent = " " * depth
print(f"{prefix}{indent}- {node_info(node)}")
logger.debug(f"{prefix}{indent}- {node_info(node)}")
if node.children:
for child in node.children.values():
print_node(child, depth + 1)

print(f"{prefix}Tree state:")
logger.debug(f"{prefix}Tree state:")
print_node(cache.root)


Expand Down Expand Up @@ -284,66 +296,68 @@ def filled_cache(trie_cache, published_sequence):
)
def test_lru_eviction(trie_cache, access_count):
"""Test LRU eviction with different access patterns"""
print(f"\nStarting test_lru_eviction with access_count={access_count}")
logger.debug(f"\nStarting test_lru_eviction with access_count={access_count}")

# Create mix of published and unpublished sequences
keep_published = 3 # Number of sequences to keep published
sequences = []

# First add some sequences we'll keep published
print("\nPublishing sequences to keep active:")
logger.debug("\nPublishing sequences to keep active:")
for i in range(keep_published):
tokens = list(range(i * 100, i * 100 + TEST_PAGE_SIZE))
alloc = trie_cache.acquire_pages_for_tokens(tokens, extra_token_slots=0)
alloc.publish_pages_for_tokens(alloc.tokens[:TEST_PAGE_SIZE])
sequences.append(tokens)
print(f"Published sequence {i} (keeping active)")
logger.debug(f"Published sequence {i} (keeping active)")
print_tree_state(trie_cache, " ")

# Then add sequences we'll publish but release (evictable)
print("\nAdding releasable sequences:")
logger.debug("\nAdding releasable sequences:")
for i in range(keep_published, TEST_POOL_CAPACITY):
tokens = list(range(i * 100, i * 100 + TEST_PAGE_SIZE))
alloc = trie_cache.acquire_pages_for_tokens(tokens, extra_token_slots=0)
alloc.publish_pages_for_tokens(alloc.tokens[:TEST_PAGE_SIZE])
alloc.release_pages() # These can be evicted
sequences.append(tokens)
print(f"Added releasable sequence {i}")
logger.debug(f"Added releasable sequence {i}")
print_tree_state(trie_cache, " ")

print("\nCache state before accessing sequences:")
logger.debug("\nCache state before accessing sequences:")
print_tree_state(trie_cache, " ")

# Access some sequences to update their LRU status
print(f"\nAccessing {access_count} sequences to update LRU order:")
logger.debug(f"\nAccessing {access_count} sequences to update LRU order:")
for i in range(access_count):
print(f"\nAccessing sequence {i}:")
logger.debug(f"\nAccessing sequence {i}:")
alloc = trie_cache.acquire_pages_for_tokens(sequences[i], extra_token_slots=0)
print_tree_state(trie_cache, " ")
alloc.release_pages()
print(f"After releasing allocation {i}:")
logger.debug(f"After releasing allocation {i}:")
print_tree_state(trie_cache, " ")

print("\nCache state before attempting new allocation:")
logger.debug("\nCache state before attempting new allocation:")
print_tree_state(trie_cache, " ")
print("\nAvailable pages in pool:", len(trie_cache.page_pool.available_pages))
logger.debug(
"\nAvailable pages in pool:", len(trie_cache.page_pool.available_pages)
)

# Try to allocate new sequence - should evict least recently used unpublished sequence
new_tokens = list(range(1000, 1000 + TEST_PAGE_SIZE))
print(f"\nAttempting to allocate new sequence: {new_tokens}")
logger.debug(f"\nAttempting to allocate new sequence: {new_tokens}")
new_alloc = trie_cache.acquire_pages_for_tokens(new_tokens, extra_token_slots=0)
print("\nNew allocation succeeded:")
print("\nCache state after new allocation:")
logger.debug("\nNew allocation succeeded:")
logger.debug("\nCache state after new allocation:")
print_tree_state(trie_cache, " ")
new_alloc.release_pages()

# Verify recently accessed sequences AND published sequences weren't evicted
print("\nVerifying preserved sequences:")
logger.debug("\nVerifying preserved sequences:")
for i in range(max(access_count, keep_published)):
print(f"\nChecking sequence {i}:")
logger.debug(f"\nChecking sequence {i}:")
recheck = trie_cache.acquire_pages_for_tokens(sequences[i], extra_token_slots=0)
cached_pages = recheck.number_of_published_pages
print(f"- Cached pages found: {cached_pages}")
logger.debug(f"- Cached pages found: {cached_pages}")
assert (
cached_pages == 1
), f"Sequence {i} was evicted but should have been preserved"
Expand All @@ -353,61 +367,65 @@ def test_lru_eviction(trie_cache, access_count):
@pytest.mark.parametrize("publish_steps", [1, 2, 3])
def test_progressive_publish(trie_cache, publish_steps):
"""Test publishing pages progressively"""
print(f"\nStarting test_progressive_publish with publish_steps={publish_steps}")
logger.debug(
f"\nStarting test_progressive_publish with publish_steps={publish_steps}"
)

tokens = tuple(range(TEST_PAGE_SIZE * 3)) # Three pages
print(f"\nInitial tokens: {tokens}")
print(f"Tokens per page: {TEST_PAGE_SIZE}")
print(
logger.debug(f"\nInitial tokens: {tokens}")
logger.debug(f"Tokens per page: {TEST_PAGE_SIZE}")
logger.debug(
f"Expected total pages: {len(tokens) // TEST_PAGE_SIZE + (1 if len(tokens) % TEST_PAGE_SIZE else 0)}"
)

print("\nInitial cache state:")
logger.debug("\nInitial cache state:")
print_tree_state(trie_cache)

print("\nAcquiring initial allocation...")
logger.debug("\nAcquiring initial allocation...")
alloc = trie_cache.acquire_pages_for_tokens(tokens)
print(f"Initial allocation pages: {[p.index for p in alloc.pages]}")
print("\nCache state after initial allocation:")
logger.debug(f"Initial allocation pages: {[p.index for p in alloc.pages]}")
logger.debug("\nCache state after initial allocation:")
print_tree_state(trie_cache)

for step in range(1, publish_steps + 1):
print(f"\n--- Step {step} of {publish_steps} ---")
logger.debug(f"\n--- Step {step} of {publish_steps} ---")

# Publish next page
print(f"Publishing up to page {step}")
logger.debug(f"Publishing up to page {step}")
# Replace publishing with tokens
alloc.publish_pages_for_tokens(alloc.tokens[: (step) * TEST_PAGE_SIZE])
print("\nCache state after publish:")
logger.debug("\nCache state after publish:")
print_tree_state(trie_cache)

# Verify reuse up to published point
reuse_tokens = tokens[: (step) * TEST_PAGE_SIZE]
print(f"\nAttempting to reuse tokens: {reuse_tokens}")
print(f"Expected cached pages: {step}")
logger.debug(f"\nAttempting to reuse tokens: {reuse_tokens}")
logger.debug(f"Expected cached pages: {step}")

reuse_alloc = trie_cache.acquire_pages_for_tokens(reuse_tokens)
print(f"Reuse allocation total pages: {len(reuse_alloc.pages)}")
print(f"Reuse allocation cached pages: {reuse_alloc.number_of_published_pages}")
logger.debug(f"Reuse allocation total pages: {len(reuse_alloc.pages)}")
logger.debug(
f"Reuse allocation cached pages: {reuse_alloc.number_of_published_pages}"
)

print("\nCache state after reuse attempt:")
logger.debug("\nCache state after reuse attempt:")
print_tree_state(trie_cache)

try:
assert reuse_alloc.number_of_published_pages == step
except AssertionError:
print("\nASSERTION FAILED!")
print(
logger.debug("\nASSERTION FAILED!")
logger.debug(
f"Expected {step} cached pages but got {reuse_alloc.number_of_published_pages}"
)
raise

reuse_alloc.release_pages()
print("\nCache state after releasing reuse allocation:")
logger.debug("\nCache state after releasing reuse allocation:")
print_tree_state(trie_cache)

alloc.release_pages()
print("\nFinal cache state after releasing initial allocation:")
logger.debug("\nFinal cache state after releasing initial allocation:")
print_tree_state(trie_cache)


Expand All @@ -422,14 +440,14 @@ def test_reference_counting(trie_cache, ref_count):
# Replace publishing with tokens
first_alloc.publish_pages_for_tokens(first_alloc.tokens)
allocations.append(first_alloc)
print("\nInitial allocation created")
logger.debug("\nInitial allocation created")
print_tree_state(trie_cache, " ")

# Create additional references
for i in range(ref_count - 1):
alloc = trie_cache.acquire_pages_for_tokens(tokens, extra_token_slots=0)
allocations.append(alloc)
print(f"\nCreated reference {i+1}")
logger.debug(f"\nCreated reference {i+1}")
print_tree_state(trie_cache, " ")

# Fill remaining cache
Expand All @@ -442,22 +460,22 @@ def test_reference_counting(trie_cache, ref_count):
alloc = trie_cache.acquire_pages_for_tokens(fill_tokens, extra_token_slots=0)
alloc.publish_pages_for_tokens(alloc.tokens[:TEST_PAGE_SIZE])
fill_allocations.append(alloc)
print(f"\nFilled cache slot {i+1}/{remaining}")
logger.debug(f"\nFilled cache slot {i+1}/{remaining}")
print_tree_state(trie_cache, " ")

print("\nAttempting allocation that should fail...")
logger.debug("\nAttempting allocation that should fail...")
try:
new_tokens = list(range(1000, 1000 + TEST_PAGE_SIZE))
new_alloc = trie_cache.acquire_pages_for_tokens(new_tokens, extra_token_slots=0)
print("ERROR: Allocation succeeded when it should have failed!")
print("\nPost-allocation state:")
logger.debug("ERROR: Allocation succeeded when it should have failed!")
logger.debug("\nPost-allocation state:")
print_tree_state(trie_cache, " ")
new_alloc.release_pages()
pytest.fail("Expected CacheAllocationFailure was not raised")
except CacheAllocationFailure:
print("Success: CacheAllocationFailure raised as expected")
logger.debug("Success: CacheAllocationFailure raised as expected")

# Cleanup
print("\nCleaning up allocations...")
logger.debug("\nCleaning up allocations...")
for alloc in allocations + fill_allocations:
alloc.release_pages()
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
"""
Trie attention cache tests with a real page pool.
This file contains tests that involve writing data to the page. Tests that deal purely with trie cache structure should go in `mock_pool_tests.py`.
Each test requires us to initialize a new page pool & page table device array. Tests here will be a LOT slower.
"""


import pytest
from typing import List
import shortfin as sf
import shortfin.array as sfnp
import time
import logging
from dataclasses import dataclass

from shortfin_apps.llm.components.kvcache.trie_attention_cache import (
TriePagedAttentionCache,
)
from shortfin_apps.llm.components.kvcache.page_pool import (
PagePool,
PagePoolConfig,
)


# Test constants
TEST_PAGE_SIZE = 16  # Tokens per page

# Note: Using a very small block size (8 elements) for testing purposes.
# In real applications, this would typically be much larger for performance reasons.
TEST_BLOCK_SIZE = 8
TEST_POOL_CAPACITY = 256  # Total pages allocated in the real page pool


# Module-level logger; tests use logging.debug instead of print statements.
logger = logging.getLogger(__name__)


@pytest.fixture
def real_device():
    """Yield the first device of a freshly built CPU system.

    The system context manager tears everything down once the
    dependent test completes.
    """
    builder = sf.host.CPUSystemBuilder()
    with builder.create_system() as system:
        test_worker = system.create_worker("test-worker")
        test_fiber = system.create_fiber(test_worker)
        yield next(iter(test_fiber.devices_dict.values()))


@pytest.fixture
def page_pool(real_device):
    """Build a real (non-mocked) PagePool backed by the test device."""
    return PagePool(
        devices=[real_device],
        config=PagePoolConfig(
            dtype=sfnp.float32,  # float32 KV entries
            alloc_page_count=TEST_POOL_CAPACITY,  # 256 pages for the tests
            paged_kv_block_size_elements=TEST_BLOCK_SIZE,  # small block size (8) to keep tests fast
        ),
    )


@pytest.fixture
def trie_cache(page_pool):
    """Construct a TriePagedAttentionCache over the real page pool."""
    cache = TriePagedAttentionCache(
        page_pool=page_pool,
        tokens_per_page=TEST_PAGE_SIZE,
    )
    return cache


@pytest.fixture
def published_sequence(trie_cache):
    """Return a callable that publishes a token sequence into the cache.

    The callable acquires pages for the tokens, publishes them all, and
    immediately releases the allocation so the pages stay cached but evictable.
    """

    def _publish(tokens: List[int]) -> None:
        allocation = trie_cache.acquire_pages_for_tokens(tokens, extra_token_slots=0)
        allocation.publish_pages_for_tokens(allocation.tokens)
        allocation.release_pages()

    return _publish


@pytest.mark.xfail(reason="Partial page reuse is not yet implemented.", strict=True)
def test_partial_page_publishing(trie_cache):
    """Test that we can publish partial pages and match them correctly"""
    # Create a sequence that's 1.5 pages long and publish it
    tokens = list(range(TEST_PAGE_SIZE + TEST_PAGE_SIZE // 2))
    alloc1 = trie_cache.acquire_pages_for_tokens(tokens)
    # NOTE(review): a write to the page buffer is presumably intended here
    # before publishing — TODO confirm when partial-page reuse lands.

    alloc1.publish_pages_for_tokens(tokens)

    # Match the full 1.5-page sequence (one full page plus half of the second)
    match_tokens = tokens[: TEST_PAGE_SIZE + TEST_PAGE_SIZE // 2]
    alloc2 = trie_cache.acquire_pages_for_tokens(match_tokens)

    # We should match both the full first page and half of the second page
    assert (
        alloc2.number_of_published_pages == 2
    ), "Should match both pages, including the partial one"
    # The partial second page must be copied, not shared, so the new
    # allocation can extend it independently of alloc1's page.
    assert (
        alloc2.pages[1].index != alloc1.pages[1].index
    ), "Should not match the same second page"

0 comments on commit 17c8369

Please sign in to comment.