[Core] Chunked Prefill support for Multi Step Scheduling #7814
---
@@ -1,16 +1,41 @@
# Test the AsyncLLMEngine with multi-step-decoding

import os
from collections import namedtuple
from enum import Enum
from typing import List

import pytest

from ..utils import RemoteOpenAIServer


class MultiStepChunkedPrefillPolicy(Enum):
    # When prompt and decode sequences are scheduled together,
    # the DEFAULT policy is to run the prompt and decode sequences
    # together only for the first step and run just the decode sequences
    # in the rest of the steps.
    DEFAULT = 1
    # In the FORCE_SINGLE_STEP policy, we force the scheduled sequences to
    # run a single step and then re-schedule.
    FORCE_SINGLE_STEP = 2
    INVALID = 3


ChunkedPrefillTestArgType = namedtuple('ChunkedPrefillTestArgType',
                                       ['enabled', 'policy'])

MODELS = [
    "JackFram/llama-160m",
]
NUM_SCHEDULER_STEPS = [8]  # Multi-step decoding steps
NUM_PROMPTS = [10]
CHUNKED_PREFILL_ARGS = [
    ChunkedPrefillTestArgType(False, MultiStepChunkedPrefillPolicy.INVALID),
    ChunkedPrefillTestArgType(True, MultiStepChunkedPrefillPolicy.DEFAULT),
    ChunkedPrefillTestArgType(True,
                              MultiStepChunkedPrefillPolicy.FORCE_SINGLE_STEP)
]

DEFAULT_SERVER_ARGS: List[str] = [
    "--disable-log-requests",
@@ -23,17 +48,36 @@
]


async def completions_with_server_args(prompts: List[str], model_name: str,
                                       server_cli_args: List[str]):
class EnvContextManager():

Review comment: Can you look around the tests and see if we already have a similar utility? If so, please reuse it; otherwise please move this to the right place so that other tests in the future can use it.

    def __init__(self, env: dict):
        self.os_env = dict(os.environ)
        self.add_env = dict(env)

    def __enter__(self):
        os.environ.update(self.add_env)

    def __exit__(self, *args, **kwargs):
        os.environ.clear()
        os.environ.update(self.os_env)


async def completions_with_server_args(prompts: List[str],

Review comment: #7651 also introduces this utility. Please coordinate with @afeldman-nm to better organize these.
Reply: Oops yes, sorry for the typo.

                                       model_name: str,
                                       server_cli_args: List[str],
                                       with_env: dict = {}):  # noqa: B006
    # env setup
    os.environ.update(with_env)

Review comment: TODO: fix stray update.

    outputs = None
    with RemoteOpenAIServer(model_name, server_cli_args) as server:
        client = server.get_async_client()
        outputs = await client.completions.create(model=model_name,
                                                  prompt=prompts,
                                                  temperature=0,
                                                  stream=False,
                                                  max_tokens=5)
    with EnvContextManager(with_env) as _:  # noqa: SIM117
        with RemoteOpenAIServer(model_name, server_cli_args) as server:
            client = server.get_async_client()
            outputs = await client.completions.create(model=model_name,
                                                      prompt=prompts,
                                                      temperature=0,
                                                      stream=False,
                                                      max_tokens=5)
    assert outputs is not None

    return outputs

@@ -47,10 +91,12 @@ async def completions_with_server_args(prompts: List[str], model_name: str,
@pytest.mark.parametrize("eager_mode", [False, True])
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
@pytest.mark.parametrize("chunked_prefill", CHUNKED_PREFILL_ARGS)
@pytest.mark.asyncio
async def test_multi_step(example_prompts, model: str, tp_size: int,
                          pp_size: int, eager_mode: int,
                          num_scheduler_steps: int, num_prompts: int):
                          pp_size: int, eager_mode: bool,
                          num_scheduler_steps: int, num_prompts: int,
                          chunked_prefill: ChunkedPrefillTestArgType):

    prompts = example_prompts
    if len(prompts) < num_prompts:

@@ -65,6 +111,14 @@ async def test_multi_step(example_prompts, model: str, tp_size: int,
    if eager_mode:
        ms_server_args.append("--enforce-eager")

    test_env = {}
    if chunked_prefill.enabled:
        ms_server_args.append("--enable-chunked-prefill")
        if chunked_prefill.policy == \
                MultiStepChunkedPrefillPolicy.FORCE_SINGLE_STEP:
            test_env[
                'VLLM_MULTI_STEP_CHUNKED_PREFILL_SINGLE_STEP_POLICY'] = '1'

    distributed_args = [
        "--tensor-parallel-size",
        str(tp_size),

@@ -75,7 +129,7 @@ async def test_multi_step(example_prompts, model: str, tp_size: int,
    ref_completions = await completions_with_server_args(
        prompts, model, server_args + distributed_args)
    test_completions = await completions_with_server_args(
        prompts, model, ms_server_args + distributed_args)
        prompts, model, ms_server_args + distributed_args, test_env)

    def get_text_generations(completions):
        return [x.text for x in completions.choices]
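Illustration (not part of the PR): a minimal sketch of how the helpers above might be called directly. The prompt strings and env value are made up, and it assumes the names defined in this test file (`DEFAULT_SERVER_ARGS`, `completions_with_server_args`, `EnvContextManager`) are in scope.

```python
# Illustrative only: exercises completions_with_server_args with a scoped env.
import asyncio

prompts = ["Hello, my name is", "The capital of France is"]
server_args = DEFAULT_SERVER_ARGS + ["--enable-chunked-prefill"]
env = {"VLLM_MULTI_STEP_CHUNKED_PREFILL_SINGLE_STEP_POLICY": "1"}

# EnvContextManager applies `env` only while the server is alive and then
# restores the original os.environ, so tests do not leak settings.
outputs = asyncio.run(
    completions_with_server_args(prompts, "JackFram/llama-160m",
                                 server_args, with_env=env))
print([choice.text for choice in outputs.choices])
```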
---
@@ -1,4 +1,5 @@
"""Attention layer with FlashAttention."""
import dataclasses
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type


@@ -300,8 +301,60 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
            block_tables=self.block_tables[self.num_prefills:],
            use_cuda_graph=self.use_cuda_graph,
        )

        return self._cached_decode_metadata

    # TODO (varun) : Try using decode_metadata here. We hit some asserts in
    # advance_step - but that seems resolvable.

Review comment on lines +307 to +308: Can you elaborate more on this? Specifically, how can it be solved and what's the plan?

    @staticmethod
    def without_prefills(m: "FlashAttentionMetadata") \
            -> "FlashAttentionMetadata":
        """
        Extract all information related to decodes from the given attention
        metadata.
        """

        num_prefills = m.num_prefills
        num_prefill_tokens = m.num_prefill_tokens
        if num_prefills == 0:
            # Simply return a copy
            return dataclasses.replace(m)

        # Slice into GPU tensors to remove prefill related information
        query_start_loc = None
        seq_start_loc = None
        if m.query_start_loc is not None and m.seq_start_loc is not None:
            query_start_loc = m.query_start_loc[num_prefills:]
            seq_start_loc = m.seq_start_loc[num_prefills:]
            # query_start_loc and seq_start_loc store indices for
            # indexing into some other tensor. As we are removing
            # all the prefill related information from all the tensors,
            # the decode information would now start from 0. Therefore,
            # offset the indices in query_start_loc and seq_start_loc
            query_start_loc = query_start_loc - query_start_loc[0]
            seq_start_loc = seq_start_loc - seq_start_loc[0]

        # All the other tensors can be sliced in-place
        return FlashAttentionMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=m.num_decode_tokens,
            slot_mapping=m.slot_mapping[num_prefill_tokens:],
            seq_lens=m.seq_lens[num_prefills:]
            if m.seq_lens is not None else None,
            seq_lens_tensor=m.seq_lens_tensor[num_prefills:]
            if m.seq_lens_tensor is not None else None,
            max_query_len=1,
            max_prefill_seq_len=0,
            max_decode_seq_len=m.max_decode_seq_len,
            query_start_loc=query_start_loc,
            seq_start_loc=seq_start_loc,
            context_lens_tensor=m.context_lens_tensor[num_prefills:]
            if m.context_lens_tensor is not None else None,
            block_tables=m.block_tables[num_prefills:]
            if m.block_tables is not None else None,
            use_cuda_graph=False)

    def advance_step(self, num_seqs: int, num_queries: int):
        """
        Update metadata in-place to advance one decode step.
---
@@ -6,6 +6,7 @@
from dataclasses import dataclass, field
from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union

import vllm.envs as envs
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.logger import init_logger

@@ -983,6 +984,20 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs:
            [s.seq_group for s in swapped_in.prefill_seq_groups])
        # Update swapped requests.
        self.swapped.extend(running_scheduled.swapped_out)

        if self.scheduler_config.is_multi_step and \
                envs.VLLM_MULTI_STEP_CHUNKED_PREFILL_SINGLE_STEP_POLICY:
            # When prefill sequences are scheduled together with decode
            # sequences, force all sequences to take a single step.
            has_prefills = len(prefills.seq_groups) + \
                len(running_scheduled.prefill_seq_groups) + \
                len(swapped_in.prefill_seq_groups) > 0
            if has_prefills:
                for sg in running_scheduled.decode_seq_groups:
                    sg.seq_group.init_multi_step(1)
                for sg in swapped_in.decode_seq_groups:
                    sg.seq_group.init_multi_step(1)

        return SchedulerOutputs(
            scheduled_seq_groups=(prefills.seq_groups +
                                  running_scheduled.prefill_seq_groups +

@@ -1202,7 +1217,8 @@ def _append_slots(
            the new source and destination block indices for the appended
            slots.
        """
        num_lookahead_slots = self._get_num_lookahead_slots(is_prefill=False)
        num_lookahead_slots = self._get_num_lookahead_slots(
            is_prefill=seq_group.is_prefill())

Review comment: Spent some time looking into PP. The assertion issue may be related to this change. Something is causing the batch to generate logits for the prefills in the batch, and since [...]
Reply: I see. Thanks @SolitaryThinker. I believe it should be [...] instead of just [...]. I haven't been able to reproduce this assertion yet. I'll keep at it. When you find some time, can you try this out as well? Thanks.
Reply: 976d032 fixes the PP issue. The SamplingMetadata objects were being clobbered due to the SamplingMetadataCache reset.

        seq_group.init_multi_step(num_scheduler_steps=num_lookahead_slots + 1)

        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
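To make the two policies concrete, here is a standalone sketch in plain Python that mirrors the scheduler branch above without calling into the scheduler; the function name is made up for illustration.

```python
# Mirrors the scheduler branch above: how many steps a decode sequence group
# runs when it is scheduled in the same batch as prefills.
def steps_for_decode_groups(num_scheduler_steps: int, has_prefills: bool,
                            force_single_step: bool) -> int:
    if has_prefills and force_single_step:
        # FORCE_SINGLE_STEP: init_multi_step(1) above; every group runs one
        # step and is then re-scheduled.
        return 1
    # DEFAULT: decode groups keep the full multi-step budget; the prefills in
    # the batch only participate in the first of those steps.
    return num_scheduler_steps


assert steps_for_decode_groups(8, has_prefills=True, force_single_step=True) == 1
assert steps_for_decode_groups(8, has_prefills=True, force_single_step=False) == 8
```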
---
@@ -300,10 +300,11 @@ async def step_async(
        seq_group_metadata_list, scheduler_outputs = self.scheduler[
            virtual_engine].schedule()

        if (self.scheduler_config.is_multi_step
                and scheduler_outputs.num_lookahead_slots > 0):
        remaining_steps = self._remaining_steps(seq_group_metadata_list)
        if self.scheduler_config.is_multi_step and \
                remaining_steps is not None and remaining_steps > 1:
            # cache the scheduler outputs for the next iteration if we have
            # lookahead slots
            # one.
            self._cache_scheduler_outputs_for_multi_step(
                virtual_engine, seq_group_metadata_list, scheduler_outputs)

@@ -346,7 +347,8 @@ async def step_async(
        # Finish the current step for all the sequence groups.
        if self.scheduler_config.is_multi_step:
            for seq_group in seq_group_metadata_list:
                seq_group.finish_step()
                if seq_group.state.remaining_steps > 0:
                    seq_group.finish_step()

Review comment on lines +350 to +351: Could you comment on this?

        if not self._has_remaining_steps(seq_group_metadata_list):
            # clear the cache if we have finished all the steps

@@ -367,25 +369,59 @@ async def step_async(

        return request_outputs

    def _has_remaining_steps(
    def _remaining_steps(
        self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
    ) -> bool:
    ) -> Optional[int]:
        if (not self.scheduler_config.is_multi_step
                or not seq_group_metadata_list):
            return None

        # Get the remaining steps of the last sequence. This is motivated
        # by a few assumptions that are generally true.
        # 1. The sequences in seq_group_metadata_list are always sorted
        #    "prefills-then-decodes".
        # 2. All the prefill sequences have the same num_steps.
        # 3. All the decode sequences have the same num_steps.
        # 4. The num_steps of the decode sequences >= num_steps of the
        #    prefill sequences.

        remaining_steps = seq_group_metadata_list[-1].state.remaining_steps

        if self.scheduler_config.chunked_prefill_enabled:
            # When chunked prefill is enabled, the prompt and decode sequences
            # may be scheduled together.
            #
            # The decode sequences should have `remaining_steps` steps to go.
            # The prefill sequences' remaining_steps is 1 when they are
            # scheduled initially. After the first step their remaining_steps
            # becomes 0.
            if any([sgml.state.remaining_steps not in [0, 1, remaining_steps] \
                    for sgml in seq_group_metadata_list]):
                raise AssertionError(
                    "Running sequences violate assumptions about "
                    "remaining_step counts.")
        else:
            # In the normal case, the sequences in seq_group_metadata_list are
            # either all prefills or all decodes, and therefore all sequences
            # must have the same number of remaining_steps.
            if any([
                    seq_group.state.remaining_steps != remaining_steps
                    for seq_group in seq_group_metadata_list[1:]
            ]):
                raise AssertionError(("All running sequence groups should "
                                      "have the same remaining steps."))

        return remaining_steps

    def _has_remaining_steps(
        self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
    ) -> bool:
        remaining_steps: Optional[int] = self._remaining_steps(
            seq_group_metadata_list)
        if remaining_steps is None:
            return False

        # TODO(will) this is a sanity check for now to make sure that all the
        # seqs are on the same steps. Eventually we will want to do some sort of
        # dynamic scheduling when doing multi-step decoding.
        ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps
        if any([
                seq_group.state.remaining_steps != ref_remaining_steps
                for seq_group in seq_group_metadata_list[1:]
        ]):
            raise AssertionError(("All running sequence groups should "
                                  "have the same remaining steps."))

        return ref_remaining_steps > 0
        return remaining_steps > 0

    def _cache_scheduler_outputs_for_multi_step(
        self, virtual_engine: int,
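The invariant `_remaining_steps` enforces can be shown with a small self-contained sketch over plain integers (not vLLM code; the names are made up):

```python
# Mirrors the _remaining_steps invariant using a list of per-group
# remaining-step counts instead of SequenceGroupMetadata objects.
from typing import List, Optional


def check_remaining_steps(remaining: List[int],
                          chunked_prefill_enabled: bool) -> Optional[int]:
    if not remaining:
        return None

    # Batches are sorted prefills-then-decodes, so the last entry reflects
    # the decode groups (or the prefills, if the batch is prefill-only).
    ref = remaining[-1]

    if chunked_prefill_enabled:
        # Decodes carry `ref` steps; prefills carry 1 when freshly scheduled
        # and 0 after their single step.
        if any(r not in (0, 1, ref) for r in remaining):
            raise AssertionError("unexpected remaining_step counts")
    elif any(r != ref for r in remaining):
        # Without chunked prefill the batch is homogeneous.
        raise AssertionError("all groups must share remaining steps")
    return ref


# Example: two chunked prefills scheduled alongside decodes with 8-step decoding.
assert check_remaining_steps([1, 1, 8, 8, 8], chunked_prefill_enabled=True) == 8
```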
---
@@ -60,6 +60,7 @@
VLLM_ALLOW_ENGINE_USE_RAY: bool = False
VLLM_PLUGINS: Optional[List[str]] = None
VLLM_TORCH_PROFILER_DIR: Optional[str] = None
VLLM_MULTI_STEP_CHUNKED_PREFILL_SINGLE_STEP_POLICY: bool = False

Review comment: Would that be better if we make it more extensible like the following? [...] And another policy would be [...]
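The suggested change itself was not captured above. As a rough sketch only, assuming the reviewer means a string-valued policy variable (the name, default, and accepted values here are guesses, not the actual proposal):

```python
# Hypothetical, more extensible knob: one string-valued policy instead of a
# boolean flag, so new policies can be added without new environment variables.
VLLM_MULTI_STEP_CHUNKED_PREFILL_POLICY: str = "default"

# Corresponding parser entry (sketch), accepting e.g. "default" or
# "force-single-step":
#     "VLLM_MULTI_STEP_CHUNKED_PREFILL_POLICY":
#     lambda: os.environ.get(
#         "VLLM_MULTI_STEP_CHUNKED_PREFILL_POLICY", "default").lower(),
```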

def get_default_cache_root():

@@ -400,6 +401,17 @@ def get_default_config_root():
    "VLLM_TORCH_PROFILER_DIR":
    lambda: (None if os.getenv("VLLM_TORCH_PROFILER_DIR", None) is None else os
             .path.expanduser(os.getenv("VLLM_TORCH_PROFILER_DIR", "."))),

    # Applicable when multi-step decoding (--num-scheduler-steps) and chunked
    # prefill (--enable-chunked-prefill) are both enabled. When prefills are
    # scheduled together with decode sequences, this flag forces the engine to
    # single-step the model execution for all the sequences. The default
    # behaviour is to run both the prefill and decode sequences for the first
    # step and run only the decode sequences for the rest of the steps.
    "VLLM_MULTI_STEP_CHUNKED_PREFILL_SINGLE_STEP_POLICY":
    lambda: os.environ.get(
        "VLLM_MULTI_STEP_CHUNKED_PREFILL_SINGLE_STEP_POLICY", "False").lower(
        ) in ("true", "1"),
}

# end-env-vars-definition

Review comment: We should define this in vLLM instead of tests.
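Assuming this last comment refers to the MultiStepChunkedPrefillPolicy enum defined in the test file above, one possible shape for moving it into vLLM proper (module location and helper name are assumptions, not an agreed design):

```python
# Hypothetical sketch: a single policy definition inside vLLM that both the
# scheduler and the tests could import.
from enum import Enum


class MultiStepChunkedPrefillPolicy(Enum):
    # Run prefills and decodes together for the first step, then only the
    # decode sequences for the remaining steps.
    DEFAULT = 1
    # Force every scheduled sequence to run a single step and re-schedule.
    FORCE_SINGLE_STEP = 2


def policy_from_env(value: str) -> MultiStepChunkedPrefillPolicy:
    """Map an environment-variable string to a policy (DEFAULT if unknown)."""
    if value.lower() in ("force_single_step", "force-single-step", "1"):
        return MultiStepChunkedPrefillPolicy.FORCE_SINGLE_STEP
    return MultiStepChunkedPrefillPolicy.DEFAULT
```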