[Core] Chunked Prefill support for Multi Step Scheduling #7814

78 changes: 66 additions & 12 deletions tests/multi_step/test_correctness.py
@@ -1,16 +1,41 @@
# Test the AsyncLLMEngine with multi-step-decoding

import os
from collections import namedtuple
from enum import Enum
from typing import List

import pytest

from ..utils import RemoteOpenAIServer


class MultiStepChunkedPrefillPolicy(Enum):
# When prompt and decode sequences are scheduled together,
# the DEFAULT policy is to run the prompt and decode sequences
# together only for the first step and run just the decode sequences
# in the rest of the steps.
DEFAULT = 1
# In the FORCE_SINGLE_STEP policy, we force the scheduled sequences to
# run a single step and then re-schedule.
FORCE_SINGLE_STEP = 2
INVALID = 3
Comment on lines +13 to +22 (Collaborator): We should define this in vLLM instead of tests.
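To make the two policies concrete, here is a toy sketch of the per-step schedules they imply (illustrative only; `steps_for` is not part of vLLM):

    # num_scheduler_steps=8, one prefill "P" scheduled together with decodes "D"
    def steps_for(policy: str, num_scheduler_steps: int = 8):
        if policy == "FORCE_SINGLE_STEP":
            return ["P+D"]  # one combined step, then everything is re-scheduled
        # DEFAULT: prompt and decodes run together only on the first step,
        # the remaining steps run the decodes alone
        return ["P+D"] + ["D"] * (num_scheduler_steps - 1)

    assert steps_for("DEFAULT") == ["P+D"] + ["D"] * 7
    assert steps_for("FORCE_SINGLE_STEP") == ["P+D"]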



ChunkedPrefillTestArgType = namedtuple('ChunkedPrefillTestArgType',
['enabled', 'policy'])

MODELS = [
"JackFram/llama-160m",
]
NUM_SCHEDULER_STEPS = [8] # Multi-step decoding steps
NUM_PROMPTS = [10]
CHUNKED_PREFILL_ARGS = [
ChunkedPrefillTestArgType(False, MultiStepChunkedPrefillPolicy.INVALID),
ChunkedPrefillTestArgType(True, MultiStepChunkedPrefillPolicy.DEFAULT),
ChunkedPrefillTestArgType(True,
MultiStepChunkedPrefillPolicy.FORCE_SINGLE_STEP)
]

DEFAULT_SERVER_ARGS: List[str] = [
"--disable-log-requests",
@@ -23,17 +48,36 @@
]


async def completions_with_server_args(prompts: List[str], model_name: str,
server_cli_args: List[str]):
class EnvContextManager():
Comment (Collaborator): Can you look around the tests and see if we already have a similar utility? If so please reuse it; otherwise please move this to the right place so that other tests in the future can use it.


def __init__(self, env: dict):
self.os_env = dict(os.environ)
self.add_env = dict(env)

def __enter__(self):
os.environ.update(self.add_env)

def __exit__(self, *args, **kwargs):
os.environ.clear()
os.environ.update(self.os_env)
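For reference, a minimal usage sketch of this context manager (the env var name is the one introduced later in this PR; values are illustrative):

    # Usage sketch: temporarily apply env vars, restore the old environment on exit.
    flag = "VLLM_MULTI_STEP_CHUNKED_PREFILL_SINGLE_STEP_POLICY"
    with EnvContextManager({flag: "1"}):
        assert os.environ[flag] == "1"  # visible only inside the block
    # On exit, os.environ is reset to its previous contents.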


async def completions_with_server_args(prompts: List[str],
Comment (Collaborator): #7651 also introduces this utility. Please coordinate with @afeldman-nm to better organize these.
Reply (Collaborator): @comaniac do you mean #7652?
Reply (Collaborator): Oops, yes, sorry for the typo.

model_name: str,
server_cli_args: List[str],
with_env: dict = {}): # noqa: B006
# env setup
os.environ.update(with_env)
Comment (Contributor Author): TODO: fix stray update.


outputs = None
with RemoteOpenAIServer(model_name, server_cli_args) as server:
client = server.get_async_client()
outputs = await client.completions.create(model=model_name,
prompt=prompts,
temperature=0,
stream=False,
max_tokens=5)
with EnvContextManager(with_env) as _: # noqa: SIM117
with RemoteOpenAIServer(model_name, server_cli_args) as server:
client = server.get_async_client()
outputs = await client.completions.create(model=model_name,
prompt=prompts,
temperature=0,
stream=False,
max_tokens=5)
assert outputs is not None

return outputs
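A hedged usage sketch of this helper (model name, CLI flags, and env var are the ones used elsewhere in this file; the call must run inside an asyncio event loop):

    # Sketch: start a temporary server with multi-step + chunked prefill enabled
    # and request completions for a couple of prompts.
    outputs = await completions_with_server_args(
        prompts=["Hello, my name is", "The capital of France is"],
        model_name="JackFram/llama-160m",
        server_cli_args=["--num-scheduler-steps", "8",
                         "--enable-chunked-prefill"],
        with_env={"VLLM_MULTI_STEP_CHUNKED_PREFILL_SINGLE_STEP_POLICY": "1"})
    print([choice.text for choice in outputs.choices])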
@@ -47,10 +91,12 @@ async def completions_with_server_args(prompts: List[str], model_name: str,
@pytest.mark.parametrize("eager_mode", [False, True])
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
@pytest.mark.parametrize("chunked_prefill", CHUNKED_PREFILL_ARGS)
@pytest.mark.asyncio
async def test_multi_step(example_prompts, model: str, tp_size: int,
pp_size: int, eager_mode: int,
num_scheduler_steps: int, num_prompts: int):
pp_size: int, eager_mode: bool,
num_scheduler_steps: int, num_prompts: int,
chunked_prefill: ChunkedPrefillTestArgType):

prompts = example_prompts
if len(prompts) < num_prompts:
@@ -65,6 +111,14 @@ async def test_multi_step(example_prompts, model: str, tp_size: int,
if eager_mode:
ms_server_args.append("--enforce-eager")

test_env = {}
if chunked_prefill.enabled:
ms_server_args.append("--enable-chunked-prefill")
if chunked_prefill.policy == \
MultiStepChunkedPrefillPolicy.FORCE_SINGLE_STEP:
test_env[
'VLLM_MULTI_STEP_CHUNKED_PREFILL_SINGLE_STEP_POLICY'] = '1'

distributed_args = [
"--tensor-parallel-size",
str(tp_size),
@@ -75,7 +129,7 @@ async def test_multi_step(example_prompts, model: str, tp_size: int,
ref_completions = await completions_with_server_args(
prompts, model, server_args + distributed_args)
test_completions = await completions_with_server_args(
prompts, model, ms_server_args + distributed_args)
prompts, model, ms_server_args + distributed_args, test_env)

def get_text_generations(completions):
return [x.text for x in completions.choices]
53 changes: 53 additions & 0 deletions vllm/attention/backends/flash_attn.py
@@ -1,4 +1,5 @@
"""Attention layer with FlashAttention."""
import dataclasses
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type

@@ -300,8 +301,60 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
block_tables=self.block_tables[self.num_prefills:],
use_cuda_graph=self.use_cuda_graph,
)

return self._cached_decode_metadata

# TODO (varun) : Try using decode_metadata here. We hit some asserts in
# advance_step - but that seems resolvable.
Comment on lines +307 to +308 (Collaborator): Can you elaborate more on this? Specifically how it can be solved and what's the plan?

@staticmethod
def without_prefills(m: "FlashAttentionMetadata") \
-> "FlashAttentionMetadata":
"""
Extract all information related to decodes from the given attention
metadata.
"""

num_prefills = m.num_prefills
num_prefill_tokens = m.num_prefill_tokens
if num_prefills == 0:
# Simply return a copy
return dataclasses.replace(m)

# Slice into GPU tensors to remove prefill related information
query_start_loc = None
seq_start_loc = None
if m.query_start_loc is not None and m.seq_start_loc is not None:
query_start_loc = m.query_start_loc[num_prefills:]
seq_start_loc = m.seq_start_loc[num_prefills:]
# query_start_loc and seq_start_loc store indices for
# indexing into some other tensor. As we are removing
# all the prefill related information from all the tensors,
# the decode information would now start from 0. Therefore,
# offset the indices in query_start_loc and seq_start_loc
query_start_loc = query_start_loc - query_start_loc[0]
seq_start_loc = seq_start_loc - seq_start_loc[0]

# All the other tensors can be sliced in-place
return FlashAttentionMetadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=m.num_decode_tokens,
slot_mapping=m.slot_mapping[num_prefill_tokens:],
seq_lens=m.seq_lens[num_prefills:]
if m.seq_lens is not None else None,
seq_lens_tensor=m.seq_lens_tensor[num_prefills:]
if m.seq_lens_tensor is not None else None,
max_query_len=1,
max_prefill_seq_len=0,
max_decode_seq_len=m.max_decode_seq_len,
query_start_loc=query_start_loc,
seq_start_loc=seq_start_loc,
context_lens_tensor=m.context_lens_tensor[num_prefills:]
if m.context_lens_tensor is not None else None,
block_tables=m.block_tables[num_prefills:]
if m.block_tables is not None else None,
use_cuda_graph=False)
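As a small illustration of the start-loc re-basing above (made-up values; one prefill of 4 tokens followed by two single-token decodes):

    import torch
    query_start_loc = torch.tensor([0, 4, 5, 6])   # cumulative query start indices
    num_prefills = 1
    sliced = query_start_loc[num_prefills:]        # tensor([4, 5, 6])
    decode_start_loc = sliced - sliced[0]          # tensor([0, 1, 2]), re-based at 0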

def advance_step(self, num_seqs: int, num_queries: int):
"""
Update metadata in-place to advance one decode step.
18 changes: 17 additions & 1 deletion vllm/core/scheduler.py
@@ -6,6 +6,7 @@
from dataclasses import dataclass, field
from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union

import vllm.envs as envs
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.logger import init_logger
@@ -983,6 +984,20 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs:
[s.seq_group for s in swapped_in.prefill_seq_groups])
# Update swapped requests.
self.swapped.extend(running_scheduled.swapped_out)

if self.scheduler_config.is_multi_step and \
envs.VLLM_MULTI_STEP_CHUNKED_PREFILL_SINGLE_STEP_POLICY:
# When prefill sequences are scheduled together with decode
# sequences, force all sequences to take a single step.
has_prefills = len(prefills.seq_groups) + \
len(running_scheduled.prefill_seq_groups) + \
len(swapped_in.prefill_seq_groups) > 0
if has_prefills:
for sg in running_scheduled.decode_seq_groups:
sg.seq_group.init_multi_step(1)
for sg in swapped_in.decode_seq_groups:
sg.seq_group.init_multi_step(1)

return SchedulerOutputs(
scheduled_seq_groups=(prefills.seq_groups +
running_scheduled.prefill_seq_groups +
@@ -1202,7 +1217,8 @@ def _append_slots(
the new source and destination block indices for the appended
slots.
"""
num_lookahead_slots = self._get_num_lookahead_slots(is_prefill=False)
num_lookahead_slots = self._get_num_lookahead_slots(
is_prefill=seq_group.is_prefill())
Comment (@SolitaryThinker, Contributor, Aug 23, 2024): Spent some time looking into PP. The assertion issue may be related to this change. Something is causing the batch to generate logits for the prefills in the batch, and since len(sample_indices) == 0 (they don't perform sampling), the assertion fails on assert logits_applied == logits.shape[0] in _apply_min_tokens_penalty, because logits_applied is the sum of all sampled_indices in the batch.

Reply (Contributor Author): I see. Thanks @SolitaryThinker. I believe it should be

        num_lookahead_slots = self._get_num_lookahead_slots(is_prefill=seq_group.is_prefill() and seq_group.do_sample)

instead of just

        num_lookahead_slots = self._get_num_lookahead_slots(is_prefill=False)

I haven't been able to reproduce this assertion yet. I'll keep at it. When you find some time, can you try this out as well? Thanks.

Reply (Contributor Author): 976d032 fixes the PP issue. The SamplingMetadata objects were being clobbered due to the SamplingMetadataCache reset.

seq_group.init_multi_step(num_scheduler_steps=num_lookahead_slots + 1)

for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
10 changes: 3 additions & 7 deletions vllm/engine/arg_utils.py
@@ -888,13 +888,9 @@ def create_engine_config(self, ) -> EngineConfig:
disable_logprobs=self.disable_logprobs_during_spec_decoding,
)

if self.num_scheduler_steps > 1:
if speculative_config is not None:
raise ValueError("Speculative decoding is not supported with "
"multi-step (--num-scheduler-steps > 1)")
if self.enable_chunked_prefill:
raise ValueError("Chunked prefill is not supported with "
"multi-step (--num-scheduler-steps > 1)")
if self.num_scheduler_steps > 1 and speculative_config is not None:
raise ValueError("Speculative decoding is not supported with "
"multi-step (--num-scheduler-steps > 1)")

# make sure num_lookahead_slots is set the higher value depending on
# if we are using speculative decoding or multi-step
72 changes: 54 additions & 18 deletions vllm/engine/async_llm_engine.py
@@ -300,10 +300,11 @@ async def step_async(
seq_group_metadata_list, scheduler_outputs = self.scheduler[
virtual_engine].schedule()

if (self.scheduler_config.is_multi_step
and scheduler_outputs.num_lookahead_slots > 0):
remaining_steps = self._remaining_steps(seq_group_metadata_list)
if self.scheduler_config.is_multi_step and \
remaining_steps is not None and remaining_steps > 1:
# cache the scheduler outputs for the next iteration if we have
# lookahead slots
# one.
self._cache_scheduler_outputs_for_multi_step(
virtual_engine, seq_group_metadata_list, scheduler_outputs)

@@ -346,7 +347,8 @@ async def step_async(
# Finish the current step for all the sequence groups.
if self.scheduler_config.is_multi_step:
for seq_group in seq_group_metadata_list:
seq_group.finish_step()
if seq_group.state.remaining_steps > 0:
seq_group.finish_step()
Comment on lines +350 to +351 (Collaborator): Could you comment on this?


if not self._has_remaining_steps(seq_group_metadata_list):
# clear the cache if we have finished all the steps
@@ -367,25 +369,59 @@

return request_outputs

def _has_remaining_steps(
def _remaining_steps(
self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
) -> bool:
) -> Optional[int]:
if (not self.scheduler_config.is_multi_step
or not seq_group_metadata_list):
return None

# Get the remaining steps of the last sequence. This is motivated
# by a few assumptions that are generally true.
# 1. The sequences in seq_group_metadata_list are always sorted by
# "prefills-then-decodes".
# 2. All the prefill sequences have the same num_steps.
# 3. All the decode sequences have the same num_steps.
# 4. The num_steps of the decode sequences >= num_steps of the
# prefill sequences.

remaining_steps = seq_group_metadata_list[-1].state.remaining_steps

if self.scheduler_config.chunked_prefill_enabled:
# When chunked prefill is enabled, the prompt and decode sequences
# may be scheduled together.
#
# The decode sequences should have `remaining_steps` steps to go.
# The prefill sequences' remaining_steps is 1 when they are
# scheduled initially. After the first step, their remaining_steps
# becomes 0.
if any([sgml.state.remaining_steps not in [0, 1, remaining_steps] \
Comment (Collaborator), suggested change (use a tuple instead of a list): replace
    if any([sgml.state.remaining_steps not in [0, 1, remaining_steps] \
with
    if any([sgml.state.remaining_steps not in (0, 1, remaining_steps) \

for sgml in seq_group_metadata_list]):
raise AssertionError(
"Running sequences violate assumptions about "
"remaining_step counts.")
else:
# In the normal case, the sequences in seq_group_metadata_list are
# either all prefills or all decodes, and therefore all sequences
# must have the same number of remaining_steps.
if any([
seq_group.state.remaining_steps != remaining_steps
for seq_group in seq_group_metadata_list[1:]
]):
raise AssertionError(("All running sequence groups should "
"have the same remaining steps."))

return remaining_steps
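As an illustration of the assumptions above (hypothetical values with --num-scheduler-steps 8, prefills sorted before decodes):

    # remaining_steps per scheduled group: two prefills followed by two decodes
    remaining = [1, 1, 8, 8]
    last = remaining[-1]                  # 8, taken from the trailing decode group
    assert all(r in (0, 1, last) for r in remaining)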

def _has_remaining_steps(
self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
) -> bool:
remaining_steps: Optional[int] = self._remaining_steps(
seq_group_metadata_list)
if remaining_steps is None:
return False

# TODO(will) this is a sanity check for nowto make sure that all the
# seqs are on the same steps. Eventually we will want to do some sort of
# dynamic scheduling when doing multi-step decoding.
ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps
if any([
seq_group.state.remaining_steps != ref_remaining_steps
for seq_group in seq_group_metadata_list[1:]
]):
raise AssertionError(("All running sequence groups should "
"have the same remaining steps."))

return ref_remaining_steps > 0
return remaining_steps > 0

def _cache_scheduler_outputs_for_multi_step(
self, virtual_engine: int,
8 changes: 4 additions & 4 deletions vllm/engine/output_processor/multi_step.py
@@ -75,12 +75,12 @@ def process_outputs(self, sequence_group: SequenceGroup,
"Beam search not supported in multi-step decoding.")
seq = seqs[0]

# Since there's only one sequence per sequence group, we can take the
# first sample.
samples = [output.samples[0] for output in outputs]
# TODO (Varun) : Pass in an output_token_id of -1 instead of returning
# 0 samples.
samples = [output.samples[0] for output in outputs if output.samples]

# -1 means the output token is not valid (eg. due to spec decode
# rejecting tokens).
# rejecting tokens)
valid_samples = [
sample for sample in samples if sample.output_token != -1
]
12 changes: 12 additions & 0 deletions vllm/envs.py
@@ -60,6 +60,7 @@
VLLM_ALLOW_ENGINE_USE_RAY: bool = False
VLLM_PLUGINS: Optional[List[str]] = None
VLLM_TORCH_PROFILER_DIR: Optional[str] = None
VLLM_MULTI_STEP_CHUNKED_PREFILL_SINGLE_STEP_POLICY: bool = False
Comment (Collaborator): Would it be better if we made this more extensible, like the following? Suggested change: replace
    VLLM_MULTI_STEP_CHUNKED_PREFILL_SINGLE_STEP_POLICY: bool = False
with
    VLLM_MULTI_STEP_CHUNKED_PREFILL_POLICY: str = "let-prefill-wait"
And another policy would be "single-step-with-prefill".



def get_default_cache_root():
@@ -400,6 +401,17 @@ def get_default_config_root():
"VLLM_TORCH_PROFILER_DIR":
lambda: (None if os.getenv("VLLM_TORCH_PROFILER_DIR", None) is None else os
.path.expanduser(os.getenv("VLLM_TORCH_PROFILER_DIR", "."))),

# Applicable when multi-step decoding (--num-scheduler-steps) and chunked
# prefill (--enable-chunked-prefill) are both enabled. When prefills are
# scheduled together with decode sequences, this flag forces the engine to
# single-step the model execution for all the sequences. The default
# behaviour is to run both the prefill and decode sequences for the first
# step and run only the decode sequences for the rest of the steps.
"VLLM_MULTI_STEP_CHUNKED_PREFILL_SINGLE_STEP_POLICY":
lambda: os.environ.get(
"VLLM_MULTI_STEP_CHUNKED_PREFILL_SINGLE_STEP_POLICY", "False").lower(
) in ("true", "1"),
}
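A small standalone sketch of the truthiness rule used by the lambda above (not vLLM API; for illustration only):

    # The flag is enabled when the env var is "1" or "true" (case-insensitive);
    # anything else, or an unset variable, leaves it disabled.
    def _parse_bool_env(value: str) -> bool:
        return value.lower() in ("true", "1")

    assert _parse_bool_env("1") and _parse_bool_env("True")
    assert not _parse_bool_env("0") and not _parse_bool_env("False")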

# end-env-vars-definition