From 433b0b39d3a556bb4e04f44682ea9708eda1004d Mon Sep 17 00:00:00 2001 From: sroy745 <142070531+sroy745@users.noreply.github.com> Date: Sat, 28 Sep 2024 18:17:45 -0700 Subject: [PATCH] [Bugfix] Block manager v2 with preemption and lookahead slots (#8824) Signed-off-by: Amit Garg --- tests/basic_correctness/test_preemption.py | 9 +++- tests/core/block/test_block_manager_v2.py | 47 ++++++++++++++++++- tests/core/block/test_naive_block.py | 19 ++++---- tests/core/block/test_prefix_caching_block.py | 25 +++++----- vllm/core/block/cpu_gpu_block_allocator.py | 17 +++---- vllm/core/block/interfaces.py | 10 ++-- vllm/core/block/naive_block.py | 35 ++++---------- vllm/core/block/prefix_caching_block.py | 41 ++++++---------- vllm/core/block_manager_v2.py | 46 +++++++++--------- 9 files changed, 133 insertions(+), 116 deletions(-) diff --git a/tests/basic_correctness/test_preemption.py b/tests/basic_correctness/test_preemption.py index 00806c3e129b1..05e7859759002 100644 --- a/tests/basic_correctness/test_preemption.py +++ b/tests/basic_correctness/test_preemption.py @@ -23,8 +23,10 @@ @pytest.fixture(scope="module", autouse=True) def check_settings(): assert ENABLE_ARTIFICIAL_PREEMPT is True, ( - "Use an env var VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1. " - "`VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest " + "Use an env var VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1, " + "VLLM_ALLOW_DEPRECATED_BEAM_SEARCH=1. " + "`VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 " + "VLLM_ALLOW_DEPRECATED_BEAM_SEARCH=1 pytest " "tests/basic_correctness/test_preemption.py`") @@ -199,6 +201,7 @@ def test_swap( @pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize("max_tokens", [96]) @pytest.mark.parametrize("beam_width", [4]) +@pytest.mark.parametrize("use_v2_block_manager", [True, False]) def test_swap_infeasible( vllm_runner, example_prompts, @@ -207,6 +210,7 @@ def test_swap_infeasible( max_tokens: int, beam_width: int, worker_use_ray: bool, + use_v2_block_manager: bool, ) -> None: """Verify infeasible swap request will be ignored.""" BLOCK_SIZE = 16 @@ -223,6 +227,7 @@ def test_swap_infeasible( num_gpu_blocks_override=prefill_blocks + decode_blocks, max_model_len=(prefill_blocks + decode_blocks) * BLOCK_SIZE, worker_use_ray=worker_use_ray, + use_v2_block_manager=use_v2_block_manager, ) as vllm_model: sampling_params = SamplingParams(n=beam_width, use_beam_search=True, diff --git a/tests/core/block/test_block_manager_v2.py b/tests/core/block/test_block_manager_v2.py index 30efe4437741d..e67883367879f 100644 --- a/tests/core/block/test_block_manager_v2.py +++ b/tests/core/block/test_block_manager_v2.py @@ -373,6 +373,52 @@ def test_can_swap(block_size, num_gpu_blocks, num_lookahead_slots, seq_group, num_lookahead_slots) == AllocStatus.NEVER +@pytest.mark.parametrize("num_lookahead_slots", [0, 2, 10]) +@pytest.mark.parametrize("enable_caching", [False, True]) +def test_swap_in_infeasible(num_lookahead_slots, enable_caching): + """Verifies that swapping fails if there is not enough free blocks + to account for unseen tokens and lookahead_slots. + """ + block_size = 8 + num_cpu_blocks = 1 + num_gpu_blocks = 1 + block_manager = BlockSpaceManagerV2(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0, + enable_caching=enable_caching) + prompt_length = block_size - 3 + assert prompt_length > 0 + prompt, seq_group = create_dummy_prompt("1", prompt_length=prompt_length) + prompt.status = SequenceStatus.WAITING + block_manager.allocate(seq_group) + # Emulate a forward pass by appending a single token. + # The block manager then knows how many unprocessed + # tokens will be written in the next forward pass. + token_id = 0 + prompt.status = SequenceStatus.RUNNING + prompt.append_token_id(token_id, {token_id: Logprob(0.0)}) + + # Swap seq group from GPU -> CPU. + assert block_manager.can_swap_out(seq_group) + block_manager.swap_out(seq_group) + prompt.status = SequenceStatus.SWAPPED + + # Swap seq group from CPU -> GPU. + # The number of unseen tokens is 1. If the number of existing + # tokens plus the unseen ones and number of lookahead slots exceeds + # the total number of available GPU blocks then the swap + # should fail. + num_unseen_tokens = 1 + if (num_lookahead_slots + num_unseen_tokens + + prompt_length) <= (block_size * num_gpu_blocks): + assert block_manager.can_swap_in(seq_group, + num_lookahead_slots) == AllocStatus.OK + else: + assert block_manager.can_swap_in( + seq_group, num_lookahead_slots) == AllocStatus.NEVER + + # TODO(cade/kaiyang): add comprehensive tests for swapping at allocator level. @@ -400,7 +446,6 @@ def check_used(min_n, max_n=None): if max_n is None: max_n = min_n used = num_gpu_blocks - block_manager.get_num_free_gpu_blocks() - #print("check", min_n, used, max_n) assert min_n <= used assert used <= max_n diff --git a/tests/core/block/test_naive_block.py b/tests/core/block/test_naive_block.py index e2e814c278603..10d5964dcfe8a 100644 --- a/tests/core/block/test_naive_block.py +++ b/tests/core/block/test_naive_block.py @@ -104,9 +104,9 @@ def test_get_num_free_blocks(allocate_type: str, num_blocks: int, @staticmethod @pytest.mark.parametrize("num_blocks", [4]) @pytest.mark.parametrize("block_size", [8]) - def test_naive_block_get_num_blocks_touched(num_blocks, block_size): + def test_naive_block_get_num_full_blocks_touched(num_blocks, block_size): """ Verify the allocator can correctly return the number of - blocks touched, with different lookahead slots. + full blocks touched. """ allocator_src = NaiveBlockAllocator(create_block=NaiveBlock, num_blocks=num_blocks, @@ -124,7 +124,7 @@ def test_naive_block_get_num_blocks_touched(num_blocks, block_size): src_blocks = [allocate_block() for _ in range(num_blocks - 1)] # All blocks are cached - assert allocator_dst.get_num_blocks_touched( + assert allocator_dst.get_num_full_blocks_touched( src_blocks) == num_blocks - 1 # Insert one non-full block in the src @@ -136,9 +136,10 @@ def test_naive_block_get_num_blocks_touched(num_blocks, block_size): src_blocks.append(allocate_non_full_block()) src_blocks[-1].append_token_ids([0]) - assert allocator_dst.get_num_blocks_touched( - src_blocks, num_lookahead_slots=1) == num_blocks - assert allocator_dst.get_num_blocks_touched( - src_blocks, num_lookahead_slots=block_size - 1) == num_blocks - assert allocator_dst.get_num_blocks_touched( - src_blocks, num_lookahead_slots=block_size) == (num_blocks + 1) + assert allocator_dst.get_num_full_blocks_touched( + src_blocks) == num_blocks - 1 + # Fill up the last source block and then invoke + # get_num_blocks_touched + src_blocks[-1].append_token_ids([0] * (block_size - 1)) + assert allocator_dst.get_num_full_blocks_touched( + src_blocks) == num_blocks diff --git a/tests/core/block/test_prefix_caching_block.py b/tests/core/block/test_prefix_caching_block.py index 25be2dd13f8bd..1a6e17ef7b445 100644 --- a/tests/core/block/test_prefix_caching_block.py +++ b/tests/core/block/test_prefix_caching_block.py @@ -318,11 +318,10 @@ def test_get_num_free_blocks(num_blocks: int, block_size: int, seed: int): @staticmethod @pytest.mark.parametrize("num_blocks", [4]) @pytest.mark.parametrize("block_size", [8]) - def test_prefix_caching_block_get_num_blocks_touched( + def test_prefix_caching_block_get_num_full_blocks_touched( num_blocks, block_size): """ Verify the allocator can correctly return the number of - blocks touched, when there are cached prefixes and different - lookahead slots. + blocks touched, when there are cached prefixes. """ allocator_src = PrefixCachingBlockAllocator(num_blocks=num_blocks, block_size=block_size) @@ -346,28 +345,30 @@ def test_prefix_caching_block_get_num_blocks_touched( token_ids=token_ids, allocator=allocator_src, ) - # All blocks are cached - assert allocator_dst.get_num_blocks_touched(blocks_to_swap_in) == 0 + assert allocator_dst.get_num_full_blocks_touched( + blocks_to_swap_in) == 0 # Free the first block in the dst allocator_dst.free(cached_blocks[0]) # Now the first block becomes dangling, the swapped blocks need # to reclaim the first block in the dst - assert allocator_dst.get_num_blocks_touched(blocks_to_swap_in) == 1 + assert allocator_dst.get_num_full_blocks_touched( + blocks_to_swap_in) == 1 # Insert one non-full block in the src non_full_block = allocator_src.allocate_mutable_block( blocks_to_swap_in[-1]) non_full_block.append_token_ids([0]) blocks_to_swap_in.append(non_full_block) - assert allocator_dst.get_num_blocks_touched(blocks_to_swap_in, - num_lookahead_slots=1) == 2 - assert allocator_dst.get_num_blocks_touched( - blocks_to_swap_in, num_lookahead_slots=block_size - 1) == 2 - assert allocator_dst.get_num_blocks_touched( - blocks_to_swap_in, num_lookahead_slots=block_size) == 3 + assert allocator_dst.get_num_full_blocks_touched( + blocks_to_swap_in) == 1 + # Fill up the last mutable block and invoke get_num_blocks_touched. + # Note: The last block is not cached so it will be touched. + non_full_block.append_token_ids([0] * (block_size - 1)) + assert allocator_dst.get_num_full_blocks_touched( + blocks_to_swap_in) == 2 @staticmethod @pytest.mark.parametrize("num_blocks", [1024]) diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py index c87246c1c6d6a..6eda5f99aa1c8 100644 --- a/vllm/core/block/cpu_gpu_block_allocator.py +++ b/vllm/core/block/cpu_gpu_block_allocator.py @@ -259,25 +259,22 @@ def swap(self, blocks: List[Block], src_device: Device, current_swap_mapping[src_block_id] = dst_block_id return current_swap_mapping - def get_num_blocks_touched(self, - blocks: List[Block], - device: Device, - num_lookahead_slots: int = 0) -> int: - """Returns the number of blocks that will be touched by + def get_num_full_blocks_touched(self, blocks: List[Block], + device: Device) -> int: + """Returns the number of full blocks that will be touched by swapping in/out the given blocks on to the 'device'. Args: blocks: List of blocks to be swapped. device (Device): Device to swap the 'blocks' on. - num_lookahead_slots (int): Number of lookahead slots used in - speculative decoding, default to 0. Returns: - int: the number of blocks that will be touched by + int: the number of full blocks that will be touched by swapping in/out the given blocks on to the 'device'. + Non full blocks are ignored when deciding the number + of blocks to touch. """ - return self._allocators[device].get_num_blocks_touched( - blocks, num_lookahead_slots) + return self._allocators[device].get_num_full_blocks_touched(blocks) def clear_copy_on_writes(self) -> List[Tuple[int, int]]: """Clears the copy-on-write (CoW) state and returns the mapping of diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py index f26bc761c9967..72bbab1dcea5d 100644 --- a/vllm/core/block/interfaces.py +++ b/vllm/core/block/interfaces.py @@ -181,9 +181,7 @@ def promote_to_immutable_block(self, block: Block) -> BlockId: pass @abstractmethod - def get_num_blocks_touched(self, - blocks: List[Block], - num_lookahead_slots: int = 0) -> int: + def get_num_full_blocks_touched(self, blocks: List[Block]) -> int: pass @abstractmethod @@ -260,10 +258,8 @@ def get_common_computed_block_ids( pass @abstractmethod - def get_num_blocks_touched(self, - blocks: List[Block], - device: Device, - num_lookahead_slots: int = 0) -> int: + def get_num_full_blocks_touched(self, blocks: List[Block], + device: Device) -> int: pass @abstractmethod diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py index 1643fd69c58ab..9341a518d11c6 100644 --- a/vllm/core/block/naive_block.py +++ b/vllm/core/block/naive_block.py @@ -4,7 +4,6 @@ from vllm.core.block.common import (BlockPool, CopyOnWriteTracker, RefCounter, get_all_blocks_recursively) from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device -from vllm.utils import cdiv Refcount = int @@ -282,40 +281,26 @@ def get_common_computed_block_ids( def promote_to_immutable_block(self, block: Block) -> BlockId: raise NotImplementedError("There is no promotion for naive blocks") - def get_num_blocks_touched(self, - blocks: List[Block], - num_lookahead_slots: int = 0) -> int: - """Determine the number of blocks that will be touched by - swapping in/out the given blocks from certain sequence - group with the provided num_lookahead_slots. + def get_num_full_blocks_touched(self, blocks: List[Block]) -> int: + """Returns the number of full blocks that will be touched by + swapping in/out. Args: - blocks (List[Block]): The potential blocks to swap. - num_lookahead_slots (int): number of lookahead slots (0 for swap - out). - + blocks: List of blocks to be swapped. Returns: - int: the number of blocks that will be touched by - swapping in/out the given blocks and num_lookahead_slots. + int: the number of full blocks that will be touched by + swapping in/out the given blocks. Non full blocks are ignored + when deciding the number of blocks to touch. """ # NOTE: for naive block, we use set to eliminate common blocks among # seqs, also we compare the empty slots in the mutable blocks with # lookahead slots to get the number of unique new block that are # needed. old_block_set = set() - new_block_count = 0 - # TODO(cade): make sure the logic is correct and clean it up. for block in blocks: - if not block.is_full and num_lookahead_slots != 0: - new_block_count += 1 - if num_lookahead_slots > block.num_empty_slots: - new_block_count += cdiv( - num_lookahead_slots - block.num_empty_slots, - self._block_size) - else: - old_block_set.add(block.block_id) - num_touched_blocks = new_block_count + len(old_block_set) - return num_touched_blocks + if block.is_full: + old_block_set.add(block) + return len(old_block_set) def swap_out(self, blocks: List[Block]) -> None: for block in blocks: diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py index db67c95c32429..7c8a2bc493513 100644 --- a/vllm/core/block/prefix_caching_block.py +++ b/vllm/core/block/prefix_caching_block.py @@ -8,7 +8,6 @@ from vllm.core.block.naive_block import (BlockPool, NaiveBlock, NaiveBlockAllocator) from vllm.core.evictor_v2 import EvictionPolicy, Evictor, make_evictor -from vllm.utils import cdiv PrefixHash = int @@ -576,37 +575,27 @@ def get_common_computed_block_ids( if ids ]) - def get_num_blocks_touched(self, - blocks: List[Block], - num_lookahead_slots: int = 0) -> int: - """Determine the number of blocks that will be touched by - swapping in/out the given blocks from certain sequence - group with the provided num_lookahead_slots. + def get_num_full_blocks_touched(self, blocks: List[Block]) -> int: + """Returns the number of full blocks that will be touched by + swapping in/out. Args: - blocks (List[Block]): The potential blocks to swap. - num_lookahead_slots (int): number of lookahead slots (0 for - swap out). - + blocks: List of blocks to be swapped. Returns: - int: the number of blocks that will be touched by - swapping in/out the given blocks and num_lookahead_slots. + int: the number of full blocks that will be touched by + swapping in/out the given blocks. Non full blocks are ignored + when deciding the number of blocks to touch. """ - num_touched_blocks = 0 + num_touched_blocks: int = 0 for block in blocks: - if not block.is_full: + # If the block has a match in the cache and the cached + # block is not referenced, then we still count it as a + # touched block + if block.is_full and (not self.is_block_cached(block) or \ + (block.content_hash is not None and \ + self._cached_blocks[block.content_hash] in \ + self.evictor)): num_touched_blocks += 1 - if num_lookahead_slots > block.num_empty_slots: - num_touched_blocks += cdiv( - num_lookahead_slots - block.num_empty_slots, - self._block_size) - else: - # If the block has a match in the cache and the cached block - # is not referenced, then we still count it as a touched block - if not self.is_block_cached(block) or \ - (block.content_hash is not None and \ - self._cached_blocks[block.content_hash] in self.evictor): - num_touched_blocks += 1 return num_touched_blocks def swap_out(self, blocks: List[Block]) -> None: diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index bb78b1e1c9138..0fad5fa99daf8 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -1,5 +1,4 @@ """A block manager that manages token blocks.""" -from itertools import chain from typing import Dict, List, Optional from typing import Sequence as GenericSequence from typing import Tuple @@ -470,12 +469,31 @@ def _can_swap(self, AllocStatus: The AllocStatus for swapping in/out the given sequence_group on to the 'device'. """ - blocks = self._get_blocks_for_swap(seq_group, status) - num_blocks_touched = self.block_allocator.get_num_blocks_touched( - blocks, device, num_lookahead_slots) + # First determine the number of blocks that will be touched by this + # swap. Then verify if there are available blocks in the device + # to perform the swap. + num_blocks_touched = 0 + blocks: List[Block] = [] + for seq in seq_group.get_seqs(status=status): + block_table = self.block_tables[seq.seq_id] + if block_table.blocks is not None: + # Compute the number blocks to touch for the tokens to be + # appended. This does NOT include the full blocks that need + # to be touched for the swap. + num_blocks_touched += \ + block_table.get_num_blocks_touched_by_append_slots( + block_table.get_unseen_token_ids(seq.get_token_ids()), + num_lookahead_slots=num_lookahead_slots) + blocks.extend(block_table.blocks) + # Compute the number of full blocks to touch and add it to the + # existing count of blocks to touch. + num_blocks_touched += self.block_allocator.get_num_full_blocks_touched( + blocks, device=device) + watermark_blocks = 0 if device == Device.GPU: watermark_blocks = self.watermark_blocks + if self.block_allocator.get_num_total_blocks( device) < num_blocks_touched: return AllocStatus.NEVER @@ -484,23 +502,3 @@ def _can_swap(self, return AllocStatus.OK else: return AllocStatus.LATER - - def _get_blocks_for_swap(self, seq_group: SequenceGroup, - status: SequenceStatus) -> List[Block]: - """Returns the list of blocks those are touched by the seq_group - - Args: - sequence_group (SequenceGroup): The sequence group to swap in. - status (SequenceStatus): The status of sequence which is needed - for action. RUNNING for swap out and SWAPPED for swap in - - Returns: - The list of blocks those are touched by the seq_group. - """ - blocks: Dict[int, List[Block]] = {} - for seq in seq_group.get_seqs(status=status): - block_table = self.block_tables[seq.seq_id] - if block_table.blocks is not None: - blocks[seq.seq_id] = block_table.blocks - combined_blocks = list(chain(*blocks.values())) - return combined_blocks