[Misc]: Implement CPU/GPU swapping in BlockManagerV2 #3834

Merged
merged 39 commits into main from dev_block_manager_v2_swap
Jun 3, 2024
Changes from 4 commits

Commits (39)
182d4a8
feat: support swap in/out for block manager v2
Kaiyang-Chen Apr 3, 2024
b6b4b8f
fix: linter
Kaiyang-Chen Apr 3, 2024
938d10e
fix: fix some bugs and add test
Kaiyang-Chen Apr 4, 2024
9181552
fix: address comment
Kaiyang-Chen Apr 4, 2024
e9a907f
fix: reduce overestimate for can_swap_in
Kaiyang-Chen Apr 5, 2024
dcff0e1
fix: reuse similar logic in can_swap_in to reduce overestimation in c…
Kaiyang-Chen Apr 5, 2024
205dda1
fix: refactor swap in/out logic
Kaiyang-Chen Apr 5, 2024
3bb125c
misc: remove useless code
Kaiyang-Chen Apr 5, 2024
403a9bd
fix: refactor can_swap_in/out
Kaiyang-Chen Apr 12, 2024
3237d63
fix: remove unused code
Kaiyang-Chen Apr 12, 2024
4131247
fix: remove unused code
Kaiyang-Chen Apr 12, 2024
0067ddf
fix: refactor swap in/out operations
Kaiyang-Chen Apr 12, 2024
b8aee85
fix
Kaiyang-Chen Apr 12, 2024
cba0f62
fix
Kaiyang-Chen Apr 12, 2024
0430758
doc: adding docstring
Kaiyang-Chen Apr 30, 2024
fbb3099
test: adding e2e correctness test for preemption by swapping
Kaiyang-Chen May 1, 2024
66a7bbd
fix
Kaiyang-Chen May 1, 2024
35d391e
remove import for __future__.annotations
Kaiyang-Chen May 1, 2024
13ab5f5
fix: address comments
Kaiyang-Chen May 2, 2024
a1e228c
feat: add preemption as a user input arg
Kaiyang-Chen May 2, 2024
9848419
nit
Kaiyang-Chen May 2, 2024
fc5726d
Merge branch 'main' into dev_block_manager_v2_swap
Kaiyang-Chen May 2, 2024
170d5a2
fix: format and test
Kaiyang-Chen May 3, 2024
c7a3484
fix: ruff
Kaiyang-Chen May 3, 2024
c252294
test: add enable_cache=True for test_swap
Kaiyang-Chen May 3, 2024
880b855
nit
Kaiyang-Chen May 3, 2024
f16e9f1
Merge branch 'main' into dev_block_manager_v2_swap
Kaiyang-Chen May 3, 2024
8b2217b
Merge branch 'main' into dev_block_manager_v2_swap
Kaiyang-Chen May 10, 2024
fe13a91
fix
Kaiyang-Chen May 10, 2024
773d331
fix: test
Kaiyang-Chen May 10, 2024
37d9b31
test: retry ci tests
Kaiyang-Chen May 10, 2024
a2f1df3
retry
Kaiyang-Chen May 10, 2024
228950a
Merge branch 'main' into dev_block_manager_v2_swap
Kaiyang-Chen May 13, 2024
216eb76
merge
Kaiyang-Chen May 13, 2024
e318e7e
Merge branch 'main' into dev_block_manager_v2_swap
Kaiyang-Chen May 14, 2024
4e1c511
Merge branch 'main' into dev_block_manager_v2_swap
Kaiyang-Chen May 28, 2024
862a5d4
fix: ci
Kaiyang-Chen May 28, 2024
cb28e8f
Merge branch 'main' into dev_block_manager_v2_swap
Kaiyang-Chen May 30, 2024
29df092
fix: merge
Kaiyang-Chen May 30, 2024
48 changes: 47 additions & 1 deletion tests/core/block/test_block_manager_v2.py
@@ -5,7 +5,7 @@
from vllm.sequence import Logprob, SequenceStatus
from vllm.utils import chunk_list

from ..utils import create_seq_group
from ..utils import create_dummy_prompt, create_seq_group


@pytest.mark.parametrize("block_size", [16])
@@ -101,3 +101,49 @@ def test_append_slots(block_size, prompt_len, num_slots_to_append,
range(prompt_len + num_slots_to_append + num_lookahead_slots)),
block_size)) - len(chunk_list(list(range(prompt_len)), block_size))
assert num_consumed_blocks == expected_consumed_blocks


@pytest.mark.parametrize("block_size", [8])
@pytest.mark.parametrize("num_cpu_blocks", [4])
@pytest.mark.parametrize("num_gpu_blocks", [4])
@pytest.mark.parametrize("num_lookahead_slots", [2])
def test_swap(block_size, num_cpu_blocks, num_gpu_blocks, num_lookahead_slots):
block_manager = BlockSpaceManagerV2(block_size,
num_cpu_blocks,
num_gpu_blocks,
watermark=0)
prompt, seq_group = create_dummy_prompt("1", prompt_length=block_size - 1)
prompt.status = SequenceStatus.WAITING
block_manager.allocate(seq_group)
# Emulate a forward pass by appending a single token.
# The block manager then knows how many unprocessed
# tokens will be written in the next forward pass.
token_id = 0
prompt.status = SequenceStatus.RUNNING
prompt.append_token_id(token_id, {token_id: Logprob(0.0)})

# Swap seq group from GPU -> CPU.
gpu_blocks = block_manager.get_block_table(prompt)
assert block_manager.can_swap_out(seq_group)
before_cpu_blocks = block_manager.get_num_free_cpu_blocks()
before_gpu_blocks = block_manager.get_num_free_gpu_blocks()
mapping = block_manager.swap_out(seq_group)
assert list(mapping.keys()) == gpu_blocks
after_cpu_blocks = block_manager.get_num_free_cpu_blocks()
after_gpu_blocks = block_manager.get_num_free_gpu_blocks()
assert before_cpu_blocks == after_cpu_blocks + len(gpu_blocks)
assert before_gpu_blocks + len(gpu_blocks) == after_gpu_blocks
prompt.status = SequenceStatus.SWAPPED

# Swap seq group from CPU -> GPU.
cpu_blocks = block_manager.get_block_table(prompt)
assert block_manager.can_swap_in(seq_group, num_lookahead_slots)
before_cpu_blocks = block_manager.get_num_free_cpu_blocks()
before_gpu_blocks = block_manager.get_num_free_gpu_blocks()
mapping = block_manager.swap_in(seq_group, num_lookahead_slots)
adjusted_cpu_blocks = [block - num_gpu_blocks for block in cpu_blocks]
assert list(mapping.keys()) == adjusted_cpu_blocks
after_cpu_blocks = block_manager.get_num_free_cpu_blocks()
after_gpu_blocks = block_manager.get_num_free_gpu_blocks()
assert before_cpu_blocks + len(cpu_blocks) == after_cpu_blocks
assert before_gpu_blocks == after_gpu_blocks + len(cpu_blocks)
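
The `adjusted_cpu_blocks` arithmetic above depends on the allocator's id layout. A minimal sketch of that convention, assuming CpuGpuBlockAllocator assigns GPU block ids first and CPU block ids after them (the same assumption behind the `cpu_block.block_id - self.num_total_gpu_blocks` shift further down in this PR):

```python
# Hedged sketch: assumes allocator-wide ids are [0, num_gpu) for GPU blocks
# and [num_gpu, num_gpu + num_cpu) for CPU blocks.
num_gpu_blocks = 4
num_cpu_blocks = 4

cpu_allocator_ids = range(num_gpu_blocks, num_gpu_blocks + num_cpu_blocks)

def to_cpu_cache_index(allocator_id: int) -> int:
    """Translate an allocator-wide CPU block id into an index into the
    CPU cache, which is what the cache engine expects."""
    return allocator_id - num_gpu_blocks

assert [to_cpu_cache_index(i) for i in cpu_allocator_ids] == [0, 1, 2, 3]
```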
30 changes: 24 additions & 6 deletions vllm/core/block/block_table.py
@@ -65,9 +65,13 @@ def get_num_required_blocks(token_ids: List[int], block_size: int) -> int:
"""
return cdiv(len(token_ids), block_size)

def get_blocks(self) -> Optional[List[Block]]:
return self._blocks

def allocate(self,
token_ids: List[int],
device: Device = Device.GPU) -> None:
device: Device = Device.GPU,
by_block: bool = False) -> Optional[Block]:
"""Allocates memory blocks for storing the given sequence of token IDs.

This method allocates the required number of blocks to store the given
@@ -77,13 +81,23 @@ def allocate(self,
token_ids (List[int]): The sequence of token IDs to be stored.
device (Device, optional): The device on which the blocks should be
allocated. Defaults to Device.GPU.
by_block (bool, optional): Whether to allocate block by block.
Set to True when doing cache swapping. Defaults to False.
"""
assert not self._is_allocated
assert not self._is_allocated or by_block
assert token_ids
self._blocks = self._allocate_blocks_for_token_ids(prev_block=None,
token_ids=token_ids,
device=device)
self._num_full_slots = len(token_ids)
blocks = self._allocate_blocks_for_token_ids(prev_block=None,
token_ids=token_ids,
device=device)
self._num_full_slots += len(token_ids)
if not (by_block and self._is_allocated):
self._blocks = blocks
else:
# Note: whenever allocate is called with by_block set to True
# (i.e., during swapping), the tokens must fit in a single block.
assert len(blocks) == 1
self._blocks.append(blocks[0])
return blocks[0]

def append_token_ids(self,
token_ids: List[int],
@@ -249,6 +263,10 @@ def _get_all_token_ids(self) -> List[int]:
def _is_allocated(self) -> bool:
return self._blocks is not None

@property
def _num_touched_blocks(self) -> int:
return len(self._blocks)

@property
def _num_empty_slots(self) -> int:
assert self._is_allocated
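
For context, a minimal sketch of the two `allocate()` paths the hunk above introduces; `MiniBlockTable` is a hypothetical stand-in, not the real `BlockTable`, and each "block" is just a list of token ids:

```python
from typing import List, Optional

class MiniBlockTable:
    """Hypothetical stand-in that mirrors only allocate()'s control flow."""

    def __init__(self) -> None:
        self._blocks: Optional[List[List[int]]] = None
        self._num_full_slots = 0

    def allocate(self, token_ids: List[int],
                 by_block: bool = False) -> List[int]:
        # by_block=True is only used while swapping, when blocks are
        # copied one at a time into an already-allocated table.
        assert self._blocks is None or by_block
        blocks = [token_ids]  # stand-in for _allocate_blocks_for_token_ids
        self._num_full_slots += len(token_ids)
        if not (by_block and self._blocks is not None):
            self._blocks = blocks        # first allocation takes the list
        else:
            assert len(blocks) == 1      # swapped tokens fit in one block
            self._blocks.append(blocks[0])
        return blocks[0]

table = MiniBlockTable()
table.allocate([1, 2, 3])               # normal prompt allocation
table.allocate([4, 5], by_block=True)   # swap path appends a single block
assert len(table._blocks) == 2
```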
67 changes: 62 additions & 5 deletions vllm/core/block/cpu_gpu_block_allocator.py
@@ -1,9 +1,13 @@
from __future__ import annotations

from typing import Dict, List, Optional

from vllm.core.block.block_table import BlockTable
from vllm.core.block.interfaces import (Block, BlockAllocator,
DeviceAwareBlockAllocator)
from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator
from vllm.core.block.prefix_caching_block import PrefixCachingBlockAllocator
from vllm.sequence import Sequence
from vllm.utils import Device


@@ -88,18 +92,17 @@ def create(
return CpuGpuBlockAllocator(
cpu_block_allocator=cpu_allocator,
gpu_block_allocator=gpu_allocator,
block_size=block_size,
)

def __init__(
self,
cpu_block_allocator: BlockAllocator,
gpu_block_allocator: BlockAllocator,
):
def __init__(self, cpu_block_allocator: BlockAllocator,
gpu_block_allocator: BlockAllocator, block_size: int):
assert not (
cpu_block_allocator.all_block_ids
& gpu_block_allocator.all_block_ids
), "cpu and gpu block allocators can't have intersection of block ids"

self._block_size = block_size
self._allocators = {
Device.CPU: cpu_block_allocator,
Device.GPU: gpu_block_allocator,
@@ -143,6 +146,16 @@ def allocate_immutable(self, prev_block: Optional[Block],
return self._allocators[device].allocate_immutable(
prev_block, token_ids)

def reference(self, block: Block) -> None:
"""Notify the device aware allocator there is new sequence reference
the given block.

Args:
block (Block): The block to be referenced.
"""
allocator = self._block_ids_to_allocator[block.block_id]
return allocator.reference(block)

def free(self, block: Block) -> None:
"""Frees the memory occupied by the given block.

@@ -204,3 +217,47 @@ def get_common_computed_block_ids(

def all_block_ids(self) -> frozenset[int]:
return frozenset(self._block_ids_to_allocator.keys())

def get_seq_swap_out_block_mapping(
self, seq: Sequence, block_table: BlockTable,
mapping: Dict[Block, Block]) -> BlockTable:
# Swap-out logic for a sequence: the mapping dict is updated in place
# and the new block table for the swapped-out sequence is returned.
new_block_table = BlockTable(
block_size=self._block_size,
block_allocator=self,
)
for gpu_block in block_table.get_blocks():
if gpu_block in mapping:
cpu_block = mapping[gpu_block]
self.reference(cpu_block)
else:
cpu_block = new_block_table.allocate(
token_ids=gpu_block.token_ids,
device=Device.CPU,
by_block=True)
mapping[gpu_block] = cpu_block
self.free(gpu_block)
return new_block_table

def get_seq_swap_in_block_mapping(
self, seq: Sequence, block_table: BlockTable,
mapping: Dict[Block, Block]) -> BlockTable:
# Swap-in logic for a sequence: the mapping dict is updated in place
# and the new block table for the swapped-in sequence is returned.
new_block_table = BlockTable(
block_size=self._block_size,
block_allocator=self,
)
for cpu_block in block_table.get_blocks():
if cpu_block in mapping:
gpu_block = mapping[cpu_block]
self.reference(gpu_block)
else:
gpu_block = new_block_table.allocate(
token_ids=cpu_block.token_ids,
device=Device.GPU,
by_block=True)
mapping[cpu_block] = gpu_block
self.free(cpu_block)
return new_block_table
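
Both helpers share a single `mapping` dict across all sequences in the group, so a block shared by forked sequences is copied only once; later sequences just take a reference. A minimal sketch of that dedup under a simple refcount model (`swap_out_block` is hypothetical):

```python
from typing import Dict

def swap_out_block(gpu_block: str, mapping: Dict[str, str],
                   refcounts: Dict[str, int]) -> str:
    """Hypothetical sketch of one iteration of the swap-out loop above."""
    if gpu_block in mapping:
        cpu_block = mapping[gpu_block]
        refcounts[cpu_block] += 1               # reference(): share the copy
    else:
        cpu_block = f"cpu_copy_of_{gpu_block}"  # allocate(by_block=True)
        refcounts[cpu_block] = 1
        mapping[gpu_block] = cpu_block
    return cpu_block

mapping: Dict[str, str] = {}
refcounts: Dict[str, int] = {}
# Two forked sequences share GPU block "g0": only one CPU copy is created.
swap_out_block("g0", mapping, refcounts)
swap_out_block("g0", mapping, refcounts)
assert len(mapping) == 1 and refcounts["cpu_copy_of_g0"] == 2
```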
2 changes: 2 additions & 0 deletions vllm/core/block/interfaces.py
@@ -1,3 +1,5 @@
from __future__ import annotations

from abc import ABC, abstractmethod, abstractproperty
from typing import Dict, List, Optional, Protocol

6 changes: 4 additions & 2 deletions vllm/core/block/naive_block.py
@@ -1,3 +1,5 @@
from __future__ import annotations

from typing import Dict, Iterable, List, Optional, Set

from vllm.core.block.common import (CopyOnWriteTracker, RefCounter,
@@ -90,8 +92,8 @@ def allocate_mutable(self, prev_block: Optional[Block]) -> Block:
def free(self, block: Block) -> None:
self._free_block_id(block.block_id)

# Mark the block as having no allocation.
block.block_id = None
def reference(self, block: Block) -> None:
self._refcounter.incr(block.block_id)

def fork(self, last_block: Block) -> List[Block]:
"""Creates a new sequence of blocks that shares the same underlying
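
`reference()` pairs with `free()` as a plain refcount increment/decrement: a block shared during swapping only returns to the free pool once every sequence referencing it has freed it. A hedged sketch of those semantics (`MiniRefCounter` is hypothetical, not vLLM's RefCounter):

```python
from collections import defaultdict

class MiniRefCounter:
    """Hypothetical counter mirroring the incr/decr calls above."""

    def __init__(self) -> None:
        self._counts = defaultdict(int)

    def incr(self, block_id: int) -> int:   # what reference() calls
        self._counts[block_id] += 1
        return self._counts[block_id]

    def decr(self, block_id: int) -> int:   # what free() would call
        assert self._counts[block_id] > 0
        self._counts[block_id] -= 1
        return self._counts[block_id]

rc = MiniRefCounter()
rc.incr(0)               # first sequence owns the block
rc.incr(0)               # a sibling sequence references the same block
assert rc.decr(0) == 1   # the block survives the first free
assert rc.decr(0) == 0   # now it can return to the free pool
```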
5 changes: 5 additions & 0 deletions vllm/core/block/prefix_caching_block.py
@@ -1,4 +1,6 @@
"""Token blocks."""
from __future__ import annotations

from itertools import takewhile
from os.path import commonprefix
from typing import Dict, Iterable, List, Optional
@@ -197,6 +199,9 @@ def _free_block_id_for_block(self, block_id: BlockId,
assert block.content_hash in self._cached_blocks
self._unused_cached_blocks[block.content_hash] = block_id

def reference(self, block: Block) -> None:
self._refcounter.incr(block.block_id)

def fork(self, last_block: Block) -> List[Block]:
"""Creates a new sequence of blocks that shares the same underlying
memory as the original sequence.
55 changes: 51 additions & 4 deletions vllm/core/block_manager_v2.py
@@ -3,6 +3,7 @@

from vllm.core.block.block_table import BlockTable
from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator
from vllm.core.block.interfaces import Block
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
from vllm.utils import Device
@@ -227,17 +228,63 @@ def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:

def can_swap_in(self, seq_group: SequenceGroup,
num_lookahead_slots: int) -> bool:
return False
"""
We go through all sequence in seq group to get their number of blocks
touched and sum them up to see whether there is enough memory to swap in
"""
num_touched_blocks = 0
for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
block_table = self.block_tables[seq.seq_id]
num_touched_blocks += (
block_table.get_num_blocks_touched_by_append_slots(
token_ids=seq.get_token_ids(),
num_lookahead_slots=num_lookahead_slots,
))
num_free_blocks = self.block_allocator.get_num_free_blocks(Device.GPU)
return num_free_blocks - num_touched_blocks >= self.watermark_blocks

def swap_in(self, seq_group: SequenceGroup,
num_lookahead_slots: int) -> Dict[int, int]:
raise NotImplementedError
mapping: Dict[Block, Block] = {}
for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
block_table = self.block_tables[seq.seq_id]
self.block_tables[
seq.
seq_id] = self.block_allocator.get_seq_swap_in_block_mapping(
seq, block_table, mapping)

# NOTE: memory operations on physical blocks need the CPU block's
# position relative to the start of the CPU cache, so we shift each
# CPU block id back to its relative position within the CPU cache.
block_number_mapping = {
cpu_block.block_id - self.num_total_gpu_blocks: gpu_block.block_id
for cpu_block, gpu_block in mapping.items()
}
return block_number_mapping

def can_swap_out(self, seq_group: SequenceGroup) -> bool:
return False
num_touched_blocks = 0
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
block_table = self.block_tables[seq.seq_id]
num_touched_blocks += block_table._num_touched_blocks
return num_touched_blocks <= self.block_allocator.get_num_free_blocks(
Device.CPU)

def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
raise NotImplementedError
mapping: Dict[Block, Block] = {}
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
block_table = self.block_tables[seq.seq_id]
self.block_tables[
seq.
seq_id] = self.block_allocator.get_seq_swap_out_block_mapping(
seq, block_table, mapping)

block_number_mapping = {
gpu_block.block_id: cpu_block.block_id - self.num_total_gpu_blocks
for gpu_block, cpu_block in mapping.items()
}
return block_number_mapping

def get_num_free_gpu_blocks(self) -> int:
return self.block_allocator.get_num_free_blocks(Device.GPU)
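
To summarize the admission logic above: `can_swap_in` is a watermark check over the summed touched blocks, while `can_swap_out` only requires the touched blocks to fit in free CPU blocks. A minimal numeric sketch with made-up block counts:

```python
# can_swap_in: enough free GPU blocks must remain above the watermark
# after accounting for every block the SWAPPED sequences would touch
# (including lookahead slots). Numbers are illustrative only.
num_free_gpu_blocks = 10
watermark_blocks = 2
num_touched_blocks = 7
assert num_free_gpu_blocks - num_touched_blocks >= watermark_blocks  # 3 >= 2

# can_swap_out: every block the RUNNING sequences touch must fit in the
# free CPU blocks; no watermark is applied on the CPU side.
num_free_cpu_blocks = 8
num_touched_blocks = 6
assert num_touched_blocks <= num_free_cpu_blocks
```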