Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix] Block manager v2 with preemption and lookahead slots #8824

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
64 commits
Select commit Hold shift + click to select a range
5650b95
Merge pull request #1 from vllm-project/main
sroy745 May 29, 2024
8f36146
Merge branch 'vllm-project:main' into main
sroy745 Jun 3, 2024
9e75057
Merge branch 'vllm-project:main' into main
sroy745 Jun 3, 2024
db2c679
Merge branch 'vllm-project:main' into main
sroy745 Jun 7, 2024
8d7512c
Merge branch 'vllm-project:main' into main
sroy745 Jun 10, 2024
1473f74
Merge branch 'vllm-project:main' into main
sroy745 Jun 12, 2024
4013e1a
Merge branch 'vllm-project:main' into main
sroy745 Jun 14, 2024
2dbdd78
Merge branch 'vllm-project:main' into main
sroy745 Jun 17, 2024
b3575e9
Merge branch 'vllm-project:main' into main
sroy745 Jun 20, 2024
94b0d43
Merge branch 'vllm-project:main' into main
sroy745 Jun 24, 2024
fa8fedf
Merge branch 'vllm-project:main' into main
sroy745 Jun 27, 2024
6ed96b4
Merge branch 'vllm-project:main' into main
sroy745 Jun 27, 2024
b71c533
Merge branch 'vllm-project:main' into main
sroy745 Jun 28, 2024
57babef
Merge branch 'vllm-project:main' into main
sroy745 Jun 29, 2024
4b19bac
Merge branch 'vllm-project:main' into main
sroy745 Jul 1, 2024
eb7a1c4
Merge branch 'vllm-project:main' into main
sroy745 Jul 6, 2024
7e2c87e
Merge branch 'vllm-project:main' into main
sroy745 Jul 10, 2024
6212d5f
Merge branch 'vllm-project:main' into main
sroy745 Jul 15, 2024
5491438
Merge branch 'vllm-project:main' into main
sroy745 Jul 17, 2024
68e080a
Merge branch 'vllm-project:main' into main
sroy745 Jul 31, 2024
55e4332
Merge branch 'vllm-project:main' into main
sroy745 Aug 13, 2024
532eb48
Merge branch 'vllm-project:main' into main
sroy745 Aug 22, 2024
7cea056
Merge branch 'vllm-project:main' into main
sroy745 Aug 22, 2024
185e056
Merge branch 'vllm-project:main' into main
sroy745 Aug 24, 2024
e2be95f
Merge branch 'vllm-project:main' into main
sroy745 Aug 27, 2024
2ed5473
Merge branch 'vllm-project:main' into main
sroy745 Aug 28, 2024
efa4714
Merge branch 'vllm-project:main' into main
sroy745 Aug 29, 2024
fb87d34
Merge branch 'vllm-project:main' into main
sroy745 Aug 29, 2024
5419e49
Merge branch 'vllm-project:main' into main
sroy745 Aug 31, 2024
9ba12f8
Merge branch 'vllm-project:main' into main
sroy745 Sep 2, 2024
25cef3d
Merge branch 'vllm-project:main' into main
sroy745 Sep 3, 2024
9d4cd09
Merge branch 'vllm-project:main' into main
sroy745 Sep 4, 2024
c48cacb
Merge branch 'vllm-project:main' into main
sroy745 Sep 5, 2024
c42c399
Merge branch 'vllm-project:main' into main
sroy745 Sep 7, 2024
3d13e43
Merge branch 'vllm-project:main' into main
sroy745 Sep 9, 2024
7479775
Merge branch 'vllm-project:main' into main
sroy745 Sep 11, 2024
df9b966
Merge branch 'vllm-project:main' into main
sroy745 Sep 17, 2024
9a7ed92
Merge branch 'vllm-project:main' into main
sroy745 Sep 17, 2024
118e838
Merge branch 'vllm-project:main' into main
sroy745 Sep 19, 2024
e640c69
Merge branch 'vllm-project:main' into main
sroy745 Sep 20, 2024
89fb6cd
Merge branch 'vllm-project:main' into main
sroy745 Sep 23, 2024
5d886cc
Merge branch 'vllm-project:main' into main
sroy745 Sep 24, 2024
56f2065
Merge branch 'vllm-project:main' into main
sroy745 Sep 24, 2024
28e103e
Merge branch 'vllm-project:main' into main
sroy745 Sep 25, 2024
2fc1490
Merge branch 'vllm-project:main' into main
sroy745 Sep 25, 2024
caef280
Fix test in test_preemption.py that fails with BlockManagerV2
sroy745 Sep 25, 2024
03276ae
Use unseen tokens during touched blocks computation for swap in
sroy745 Sep 26, 2024
de4acb7
Revert change
sroy745 Sep 26, 2024
8805750
Merge branch 'vllm-project:main' into main
sroy745 Sep 26, 2024
976dd06
Fix format and logic
sroy745 Sep 26, 2024
494bce4
Fix format
sroy745 Sep 26, 2024
b8b097c
Fix logic
sroy745 Sep 26, 2024
14de5f0
Fix tests and logic
sroy745 Sep 26, 2024
d094f4f
Fix tests, logic and format
sroy745 Sep 27, 2024
e69809e
Format
sroy745 Sep 27, 2024
49b8d2c
Add docstring
sroy745 Sep 27, 2024
da5b541
Fix comments
sroy745 Sep 27, 2024
114b1d0
Fix test
sroy745 Sep 27, 2024
81350e6
Format
sroy745 Sep 27, 2024
c0537af
Addressing comments
sroy745 Sep 28, 2024
962b17f
Dummy
sroy745 Sep 28, 2024
fa32ba8
Dummy
sroy745 Sep 28, 2024
b30e5af
Merge branch 'vllm-project:main' into main
sroy745 Sep 28, 2024
0f42599
Merge remote-tracking branch 'origin/main' into sroy-fix-premption-te…
sroy745 Sep 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions tests/basic_correctness/test_preemption.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@
@pytest.fixture(scope="module", autouse=True)
def check_settings():
assert ENABLE_ARTIFICIAL_PREEMPT is True, (
"Use an env var VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1. "
"`VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest "
"Use an env var VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1, "
"VLLM_ALLOW_DEPRECATED_BEAM_SEARCH=1. "
"`VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 "
"VLLM_ALLOW_DEPRECATED_BEAM_SEARCH=1 pytest "
"tests/basic_correctness/test_preemption.py`")


Expand Down Expand Up @@ -199,6 +201,7 @@ def test_swap(
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [96])
@pytest.mark.parametrize("beam_width", [4])
@pytest.mark.parametrize("use_v2_block_manager", [True, False])
def test_swap_infeasible(
vllm_runner,
example_prompts,
Expand All @@ -207,6 +210,7 @@ def test_swap_infeasible(
max_tokens: int,
beam_width: int,
worker_use_ray: bool,
use_v2_block_manager: bool,
) -> None:
"""Verify infeasible swap request will be ignored."""
BLOCK_SIZE = 16
Expand All @@ -223,6 +227,7 @@ def test_swap_infeasible(
num_gpu_blocks_override=prefill_blocks + decode_blocks,
max_model_len=(prefill_blocks + decode_blocks) * BLOCK_SIZE,
worker_use_ray=worker_use_ray,
use_v2_block_manager=use_v2_block_manager,
) as vllm_model:
sampling_params = SamplingParams(n=beam_width,
use_beam_search=True,
Expand Down
47 changes: 46 additions & 1 deletion tests/core/block/test_block_manager_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,52 @@ def test_can_swap(block_size, num_gpu_blocks, num_lookahead_slots,
seq_group, num_lookahead_slots) == AllocStatus.NEVER


@pytest.mark.parametrize("num_lookahead_slots", [0, 2, 10])
@pytest.mark.parametrize("enable_caching", [False, True])
def test_swap_in_infeasible(num_lookahead_slots, enable_caching):
    """Check that a CPU -> GPU swap is rejected when the GPU does not
    have enough free blocks for the unseen tokens plus lookahead slots.
    """
    block_size = 8
    num_cpu_blocks = 1
    num_gpu_blocks = 1
    manager = BlockSpaceManagerV2(block_size,
                                  num_cpu_blocks,
                                  num_gpu_blocks,
                                  watermark=0,
                                  enable_caching=enable_caching)

    # Build a prompt that leaves a few free slots in its single block.
    prompt_len = block_size - 3
    assert prompt_len > 0
    seq, group = create_dummy_prompt("1", prompt_length=prompt_len)
    seq.status = SequenceStatus.WAITING
    manager.allocate(group)

    # Emulate one decode step so the block manager knows that one more
    # (unseen) token will be written on the next forward pass.
    new_token = 0
    seq.status = SequenceStatus.RUNNING
    seq.append_token_id(new_token, {new_token: Logprob(0.0)})

    # Move the sequence group off the GPU.
    assert manager.can_swap_out(group)
    manager.swap_out(group)
    seq.status = SequenceStatus.SWAPPED

    # Swapping back in must fail whenever the existing tokens, the one
    # unseen token, and the lookahead slots together exceed the total
    # number of available GPU slots.
    num_unseen = 1
    fits = (num_lookahead_slots + num_unseen +
            prompt_len) <= block_size * num_gpu_blocks
    expected = AllocStatus.OK if fits else AllocStatus.NEVER
    assert manager.can_swap_in(group, num_lookahead_slots) == expected


# TODO(cade/kaiyang): add comprehensive tests for swapping at allocator level.


Expand Down Expand Up @@ -400,7 +446,6 @@ def check_used(min_n, max_n=None):
if max_n is None:
max_n = min_n
used = num_gpu_blocks - block_manager.get_num_free_gpu_blocks()
#print("check", min_n, used, max_n)
assert min_n <= used
assert used <= max_n

Expand Down
19 changes: 10 additions & 9 deletions tests/core/block/test_naive_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,9 @@ def test_get_num_free_blocks(allocate_type: str, num_blocks: int,
@staticmethod
@pytest.mark.parametrize("num_blocks", [4])
@pytest.mark.parametrize("block_size", [8])
def test_naive_block_get_num_blocks_touched(num_blocks, block_size):
def test_naive_block_get_num_full_blocks_touched(num_blocks, block_size):
""" Verify the allocator can correctly return the number of
blocks touched, with different lookahead slots.
full blocks touched.
"""
allocator_src = NaiveBlockAllocator(create_block=NaiveBlock,
num_blocks=num_blocks,
Expand All @@ -124,7 +124,7 @@ def test_naive_block_get_num_blocks_touched(num_blocks, block_size):
src_blocks = [allocate_block() for _ in range(num_blocks - 1)]

# All blocks are cached
assert allocator_dst.get_num_blocks_touched(
assert allocator_dst.get_num_full_blocks_touched(
src_blocks) == num_blocks - 1

# Insert one non-full block in the src
Expand All @@ -136,9 +136,10 @@ def test_naive_block_get_num_blocks_touched(num_blocks, block_size):
src_blocks.append(allocate_non_full_block())
src_blocks[-1].append_token_ids([0])

assert allocator_dst.get_num_blocks_touched(
src_blocks, num_lookahead_slots=1) == num_blocks
assert allocator_dst.get_num_blocks_touched(
src_blocks, num_lookahead_slots=block_size - 1) == num_blocks
assert allocator_dst.get_num_blocks_touched(
src_blocks, num_lookahead_slots=block_size) == (num_blocks + 1)
assert allocator_dst.get_num_full_blocks_touched(
src_blocks) == num_blocks - 1
# Fill up the last source block and then invoke
# get_num_blocks_touched
src_blocks[-1].append_token_ids([0] * (block_size - 1))
assert allocator_dst.get_num_full_blocks_touched(
src_blocks) == num_blocks
25 changes: 13 additions & 12 deletions tests/core/block/test_prefix_caching_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,11 +318,10 @@ def test_get_num_free_blocks(num_blocks: int, block_size: int, seed: int):
@staticmethod
@pytest.mark.parametrize("num_blocks", [4])
@pytest.mark.parametrize("block_size", [8])
def test_prefix_caching_block_get_num_blocks_touched(
def test_prefix_caching_block_get_num_full_blocks_touched(
num_blocks, block_size):
""" Verify the allocator can correctly return the number of
blocks touched, when there are cached prefixes and different
lookahead slots.
blocks touched, when there are cached prefixes.
"""
allocator_src = PrefixCachingBlockAllocator(num_blocks=num_blocks,
block_size=block_size)
Expand All @@ -346,28 +345,30 @@ def test_prefix_caching_block_get_num_blocks_touched(
token_ids=token_ids,
allocator=allocator_src,
)

# All blocks are cached
assert allocator_dst.get_num_blocks_touched(blocks_to_swap_in) == 0
assert allocator_dst.get_num_full_blocks_touched(
blocks_to_swap_in) == 0

# Free the first block in the dst
allocator_dst.free(cached_blocks[0])

# Now the first block becomes dangling, the swapped blocks need
# to reclaim the first block in the dst
assert allocator_dst.get_num_blocks_touched(blocks_to_swap_in) == 1
assert allocator_dst.get_num_full_blocks_touched(
blocks_to_swap_in) == 1

# Insert one non-full block in the src
non_full_block = allocator_src.allocate_mutable_block(
blocks_to_swap_in[-1])
non_full_block.append_token_ids([0])
blocks_to_swap_in.append(non_full_block)
assert allocator_dst.get_num_blocks_touched(blocks_to_swap_in,
num_lookahead_slots=1) == 2
assert allocator_dst.get_num_blocks_touched(
blocks_to_swap_in, num_lookahead_slots=block_size - 1) == 2
assert allocator_dst.get_num_blocks_touched(
blocks_to_swap_in, num_lookahead_slots=block_size) == 3
assert allocator_dst.get_num_full_blocks_touched(
blocks_to_swap_in) == 1
# Fill up the last mutable block and invoke get_num_blocks_touched.
# Note: The last block is not cached so it will be touched.
non_full_block.append_token_ids([0] * (block_size - 1))
assert allocator_dst.get_num_full_blocks_touched(
blocks_to_swap_in) == 2

@staticmethod
@pytest.mark.parametrize("num_blocks", [1024])
Expand Down
17 changes: 7 additions & 10 deletions vllm/core/block/cpu_gpu_block_allocator.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,25 +259,22 @@ def swap(self, blocks: List[Block], src_device: Device,
current_swap_mapping[src_block_id] = dst_block_id
return current_swap_mapping

def get_num_blocks_touched(self,
blocks: List[Block],
device: Device,
num_lookahead_slots: int = 0) -> int:
"""Returns the number of blocks that will be touched by
def get_num_full_blocks_touched(self, blocks: List[Block],
device: Device) -> int:
"""Returns the number of full blocks that will be touched by
swapping in/out the given blocks on to the 'device'.

Args:
blocks: List of blocks to be swapped.
device (Device): Device to swap the 'blocks' on.
num_lookahead_slots (int): Number of lookahead slots used in
speculative decoding, default to 0.

Returns:
int: the number of blocks that will be touched by
int: the number of full blocks that will be touched by
swapping in/out the given blocks on to the 'device'.
Non full blocks are ignored when deciding the number
of blocks to touch.
"""
return self._allocators[device].get_num_blocks_touched(
blocks, num_lookahead_slots)
return self._allocators[device].get_num_full_blocks_touched(blocks)

def clear_copy_on_writes(self) -> List[Tuple[int, int]]:
"""Clears the copy-on-write (CoW) state and returns the mapping of
Expand Down
10 changes: 3 additions & 7 deletions vllm/core/block/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,7 @@ def promote_to_immutable_block(self, block: Block) -> BlockId:
pass

@abstractmethod
def get_num_blocks_touched(self,
blocks: List[Block],
num_lookahead_slots: int = 0) -> int:
def get_num_full_blocks_touched(self, blocks: List[Block]) -> int:
pass

@abstractmethod
Expand Down Expand Up @@ -260,10 +258,8 @@ def get_common_computed_block_ids(
pass

@abstractmethod
def get_num_blocks_touched(self,
blocks: List[Block],
device: Device,
num_lookahead_slots: int = 0) -> int:
def get_num_full_blocks_touched(self, blocks: List[Block],
device: Device) -> int:
pass

@abstractmethod
Expand Down
35 changes: 10 additions & 25 deletions vllm/core/block/naive_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from vllm.core.block.common import (BlockPool, CopyOnWriteTracker, RefCounter,
get_all_blocks_recursively)
from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device
from vllm.utils import cdiv

Refcount = int

Expand Down Expand Up @@ -282,40 +281,26 @@ def get_common_computed_block_ids(
def promote_to_immutable_block(self, block: Block) -> BlockId:
raise NotImplementedError("There is no promotion for naive blocks")

def get_num_blocks_touched(self,
blocks: List[Block],
num_lookahead_slots: int = 0) -> int:
"""Determine the number of blocks that will be touched by
swapping in/out the given blocks from certain sequence
group with the provided num_lookahead_slots.
def get_num_full_blocks_touched(self, blocks: List[Block]) -> int:
"""Returns the number of full blocks that will be touched by
swapping in/out.

Args:
blocks (List[Block]): The potential blocks to swap.
num_lookahead_slots (int): number of lookahead slots (0 for swap
out).

blocks: List of blocks to be swapped.
Returns:
int: the number of blocks that will be touched by
swapping in/out the given blocks and num_lookahead_slots.
int: the number of full blocks that will be touched by
swapping in/out the given blocks. Non full blocks are ignored
when deciding the number of blocks to touch.
"""
# NOTE: for naive block, we use set to eliminate common blocks among
# seqs, also we compare the empty slots in the mutable blocks with
# lookahead slots to get the number of unique new block that are
# needed.
old_block_set = set()
new_block_count = 0
# TODO(cade): make sure the logic is correct and clean it up.
for block in blocks:
if not block.is_full and num_lookahead_slots != 0:
new_block_count += 1
if num_lookahead_slots > block.num_empty_slots:
new_block_count += cdiv(
num_lookahead_slots - block.num_empty_slots,
self._block_size)
else:
old_block_set.add(block.block_id)
num_touched_blocks = new_block_count + len(old_block_set)
return num_touched_blocks
if block.is_full:
old_block_set.add(block)
return len(old_block_set)

def swap_out(self, blocks: List[Block]) -> None:
for block in blocks:
Expand Down
41 changes: 15 additions & 26 deletions vllm/core/block/prefix_caching_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from vllm.core.block.naive_block import (BlockPool, NaiveBlock,
NaiveBlockAllocator)
from vllm.core.evictor_v2 import EvictionPolicy, Evictor, make_evictor
from vllm.utils import cdiv

PrefixHash = int

Expand Down Expand Up @@ -576,37 +575,27 @@ def get_common_computed_block_ids(
if ids
])

def get_num_blocks_touched(self,
blocks: List[Block],
num_lookahead_slots: int = 0) -> int:
"""Determine the number of blocks that will be touched by
swapping in/out the given blocks from certain sequence
group with the provided num_lookahead_slots.
def get_num_full_blocks_touched(self, blocks: List[Block]) -> int:
"""Returns the number of full blocks that will be touched by
swapping in/out.

Args:
blocks (List[Block]): The potential blocks to swap.
num_lookahead_slots (int): number of lookahead slots (0 for
swap out).

blocks: List of blocks to be swapped.
Returns:
int: the number of blocks that will be touched by
swapping in/out the given blocks and num_lookahead_slots.
int: the number of full blocks that will be touched by
swapping in/out the given blocks. Non full blocks are ignored
when deciding the number of blocks to touch.
"""
num_touched_blocks = 0
num_touched_blocks: int = 0
for block in blocks:
if not block.is_full:
# If the block has a match in the cache and the cached
# block is not referenced, then we still count it as a
# touched block
if block.is_full and (not self.is_block_cached(block) or \
(block.content_hash is not None and \
self._cached_blocks[block.content_hash] in \
self.evictor)):
num_touched_blocks += 1
if num_lookahead_slots > block.num_empty_slots:
num_touched_blocks += cdiv(
num_lookahead_slots - block.num_empty_slots,
self._block_size)
else:
# If the block has a match in the cache and the cached block
# is not referenced, then we still count it as a touched block
if not self.is_block_cached(block) or \
(block.content_hash is not None and \
self._cached_blocks[block.content_hash] in self.evictor):
num_touched_blocks += 1
return num_touched_blocks

def swap_out(self, blocks: List[Block]) -> None:
Expand Down
Loading
Loading