Skip to content

Commit

Permalink
[Bugfix] Fix PP for Multi-Step (vllm-project#8887)
Browse files Browse the repository at this point in the history
Signed-off-by: Alvant <[email protected]>
  • Loading branch information
varun-sundar-rabindranath authored and Alvant committed Oct 26, 2024
1 parent 954a6dd commit 9cbe49b
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 15 deletions.
82 changes: 82 additions & 0 deletions tests/multi_step/test_correctness_async_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,85 @@ async def test_multi_step(
name_0="hf",
name_1="vllm",
)


@pytest.mark.parametrize(("tp_size, pp_size"), [
(1, 2),
])
@pytest.mark.asyncio
async def test_multi_step_pp_smoke(
tp_size: int,
pp_size: int,
monkeypatch,
) -> None:
"""
Smoke test for the vLLM engine with multi-step scheduling in an
OpenAI-protocol client/server environment.
This tests compares the outputs between multi-step scheduling and
single-step scheduling. Notably, this test lets the engines generate
more tokens (default is 5) and test for an exact match over all the
tokens.
Args:
tp_size: degree of tensor-parallelism
pp_size: degree of pipeline-parallelism
eager_mode
"""

model = "JackFram/llama-160m"
num_scheduler_steps = 8
attention_backend = "FLASH_ATTN"
max_num_seqs = 3

override_backend_env_variable(monkeypatch, attention_backend)

# Prompt from the ShareGPT dataset
prompts = [
"in the jtbd context whats a push?", # codespell:ignore
"in the jtbd context whats a push?", # codespell:ignore
"in the jtbd context whats a push?", # codespell:ignore
"in the jtbd context whats a push?", # codespell:ignore
]
# Use varying max_tokens to introduce scheduling randomness.
max_tokens = [10 * i for i in range(1, len(prompts) + 1)]
assert len(prompts) == len(max_tokens)

test_args = [
"--tensor-parallel-size",
str(tp_size), "--pipeline-parallel-size",
str(pp_size), "--max-num-seqs",
str(max_num_seqs)
]

server_args = DEFAULT_SERVER_ARGS + test_args
ms_server_args = DEFAULT_SERVER_ARGS + \
["--num-scheduler-steps", f"{num_scheduler_steps}"] + \
test_args

# Spin up client/server & issue completion API requests.
# Default `max_wait_seconds` is 240 but was empirically
# was raised 3x to 720 *just for this test* due to
# observed timeouts in GHA CI
ref_completions = await completions_with_server_args(
prompts=prompts,
model_name=model,
server_cli_args=server_args,
num_logprobs=None,
max_wait_seconds=5 * 240,
max_tokens=max_tokens)

test_completions = await completions_with_server_args(
prompts=prompts,
model_name=model,
server_cli_args=ms_server_args,
num_logprobs=None,
max_wait_seconds=5 * 240,
max_tokens=max_tokens)

# Assert multi-step scheduling produces identical tokens
# to single-step scheduling.
ref_generations = get_client_text_generations(ref_completions)
test_generations = get_client_text_generations(test_completions)

assert ref_generations == test_generations
38 changes: 26 additions & 12 deletions tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import functools
import os
import signal
Expand All @@ -7,7 +8,7 @@
import warnings
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional, Union

import openai
import pytest
Expand Down Expand Up @@ -476,7 +477,8 @@ async def completions_with_server_args(
server_cli_args: List[str],
num_logprobs: Optional[int],
max_wait_seconds: int = 240,
) -> Completion:
max_tokens: Union[int, list] = 5,
) -> List[Completion]:
'''Construct a remote OpenAI server, obtain an async client to the
server & invoke the completions API to obtain completions.
Expand All @@ -487,37 +489,49 @@ async def completions_with_server_args(
num_logprobs: Number of logprobs to report (or `None`)
max_wait_seconds: timeout interval for bringing up server.
Default: 240sec
max_tokens: max_tokens value for each of the given input prompts.
if only one max_token value is given, the same value is used
for all the prompts.
Returns:
OpenAI Completion instance
'''

if isinstance(max_tokens, int):
max_tokens = [max_tokens] * len(prompts)

assert len(max_tokens) == len(prompts)

outputs = None
max_wait_seconds = 240 * 3 # 240 is default
with RemoteOpenAIServer(model_name,
server_cli_args,
max_wait_seconds=max_wait_seconds) as server:
client = server.get_async_client()
outputs = await client.completions.create(model=model_name,
prompt=prompts,
temperature=0,
stream=False,
max_tokens=5,
logprobs=num_logprobs)
outputs = [ client.completions.create(model=model_name,
prompt=[p],
temperature=0,
stream=False,
max_tokens=max_tok,
logprobs=num_logprobs) \
for p, max_tok in zip(prompts, max_tokens) ]
outputs = await asyncio.gather(*outputs)

assert outputs is not None, "Completion API call failed."

return outputs


def get_client_text_generations(completions: Completion) -> List[str]:
def get_client_text_generations(completions: List[Completion]) -> List[str]:
'''Extract generated tokens from the output of a
request made to an Open-AI-protocol completions endpoint.
'''
return [x.text for x in completions.choices]
assert all([len(x.choices) == 1 for x in completions])
return [x.choices[0].text for x in completions]


def get_client_text_logprob_generations(
completions: Completion) -> List[TextTextLogprobs]:
completions: List[Completion]) -> List[TextTextLogprobs]:
'''Operates on the output of a request made to an Open-AI-protocol
completions endpoint; obtains top-rank logprobs for each token in
each :class:`SequenceGroup`
Expand All @@ -526,4 +540,4 @@ def get_client_text_logprob_generations(
text = ''.join(text_generations)
return [(text_generations, text,
(None if x.logprobs is None else x.logprobs.top_logprobs))
for x in completions.choices]
for completion in completions for x in completion.choices]
3 changes: 3 additions & 0 deletions vllm/engine/output_processor/multi_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ def process_outputs(self,
assert len(seqs) == 1, (
"Beam search not supported in multi-step decoding.")
seq = seqs[0]
seq_id = seq.seq_id
assert all(
[seq_id == output.samples[0].parent_seq_id for output in outputs])

if is_async:
# Async case: We process tokens one by one. Here, we know the token
Expand Down
10 changes: 9 additions & 1 deletion vllm/worker/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1007,8 +1007,16 @@ def __init__(

# Used to cache python objects
self.inter_data_cache: Dict[int, PyObjectCache] = {}

# Using the PythonizationCache in Pipeline-Parallel clobbers the
# SequenceGroupToSample object. In Pipeline-Parallel, we have
# more than 1 Scheduler, resulting in a potential back-to-back
# prepare_model_inputs() call. This clobbers the cached
# SequenceGroupToSample objects, as we reset the cache during
# every prepare_model_inputs() call.
self.sampling_metadata_cache: SamplingMetadataCache = \
SamplingMetadataCache()
SamplingMetadataCache() \
if self.parallel_config.pipeline_parallel_size == 1 else None

def load_model(self) -> None:
logger.info("Starting to load model %s...", self.model_config.model)
Expand Down
12 changes: 10 additions & 2 deletions vllm/worker/multi_step_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,14 @@ def __init__(self, base_model_runner: GPUModelRunnerBase, *args, **kwargs):
self.is_multi_step = self.scheduler_config.is_multi_step
self.pinned_sampled_token_ids: Optional[torch.Tensor] = None

self.pythonization_cache = PythonizationCache()
# Using the PythonizationCache in Pipeline-Parallel clobbers the
# SequenceOutput and CompletionSequenceGroupOutput object.
# When cache-reset happens at the last step of a multi-step
# execution, there may be other on-going single-step/multi-step
# executions. The current caching implementation does not check
# for this.
self.pythonization_cache = PythonizationCache() \
if self.parallel_config.pipeline_parallel_size == 1 else None

@functools.cached_property
def _copy_stream(self):
Expand Down Expand Up @@ -577,7 +584,8 @@ def execute_model(
if model_input.is_last_step:
outputs = self._final_process_outputs(
model_input, model_input.base_output_proc_callback)
self.pythonization_cache.reset()
if self.pythonization_cache:
self.pythonization_cache.reset()
return outputs

# should be [SamplerOutput]
Expand Down

0 comments on commit 9cbe49b

Please sign in to comment.