[V1][Frontend] Coalesce bunched RequestOutputs #12298

Merged · 6 commits · Jan 24, 2025
tests/v1/engine/test_async_llm.py (35 additions, 16 deletions)

```diff
@@ -1,11 +1,13 @@
 import asyncio
+from contextlib import ExitStack
 from typing import List, Tuple

 import pytest

 from vllm import SamplingParams
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.platforms import current_platform
+from vllm.sampling_params import RequestOutputKind
 from vllm.v1.engine.async_llm import AsyncLLM

 if not current_platform.is_cuda():
@@ -18,28 +20,39 @@
 async def generate(engine: AsyncLLM, request_id: str,
+                   output_kind: RequestOutputKind,
                    max_tokens: int) -> Tuple[int, str]:
     count = 0
-    async for _ in engine.generate(request_id=request_id,
-                                   prompt="Hello my name is Robert and",
-                                   sampling_params=SamplingParams(
-                                       max_tokens=max_tokens, temperature=0)):
+    sampling_params = SamplingParams(max_tokens=max_tokens,
+                                     output_kind=output_kind,
+                                     temperature=0)
+    async for out in engine.generate(request_id=request_id,
+                                     prompt="Hello my name is Robert and",
+                                     sampling_params=sampling_params):

-        count += 1
+        num_tokens = len(out.outputs[0].token_ids)
+        if output_kind == RequestOutputKind.DELTA:
+            count += num_tokens
+        else:
+            count = num_tokens
+
         await asyncio.sleep(0.)

     return count, request_id


+@pytest.mark.parametrize(
+    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
 @pytest.mark.asyncio
-async def test_load(monkeypatch):
+async def test_load(monkeypatch, output_kind: RequestOutputKind):
     # TODO(rickyx): Remove monkeypatch once we have a better way to test V1
     # so that in the future when we switch, we don't have to change all the
     # tests.
-    with monkeypatch.context() as m:
+    with monkeypatch.context() as m, ExitStack() as after:
         m.setenv("VLLM_USE_V1", "1")

         engine = AsyncLLM.from_engine_args(ENGINE_ARGS)
+        after.callback(engine.shutdown)

         NUM_REQUESTS = 10000
         NUM_EXPECTED_TOKENS = 10
@@ -51,26 +64,33 @@ async def test_load(monkeypatch):
         for request_id in request_ids:
             tasks.append(
                 asyncio.create_task(
-                    generate(engine, request_id, NUM_EXPECTED_TOKENS)))
+                    generate(engine, request_id, output_kind,
+                             NUM_EXPECTED_TOKENS)))

         # Confirm that we got all the EXPECTED tokens from the requests.
-        for task in tasks:
+        done, pending = await asyncio.wait(tasks,
+                                           return_when=asyncio.FIRST_EXCEPTION)
+        for task in pending:
+            task.cancel()
+        for task in done:
             num_generated_tokens, request_id = await task
             assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
                 f"{request_id} generated {num_generated_tokens} but "
                 f"expected {NUM_EXPECTED_TOKENS}")

         assert not engine.output_processor.has_unfinished_requests()
-        engine.shutdown()


+@pytest.mark.parametrize(
+    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
 @pytest.mark.asyncio
-async def test_abort(monkeypatch):
+async def test_abort(monkeypatch, output_kind: RequestOutputKind):

-    with monkeypatch.context() as m:
+    with monkeypatch.context() as m, ExitStack() as after:
         m.setenv("VLLM_USE_V1", "1")

         engine = AsyncLLM.from_engine_args(ENGINE_ARGS)
+        after.callback(engine.shutdown)

         NUM_REQUESTS = 100
         NUM_EXPECTED_TOKENS = 100
@@ -83,7 +103,8 @@ async def test_abort(monkeypatch):
         for request_id in request_ids:
             tasks.append(
                 asyncio.create_task(
-                    generate(engine, request_id, NUM_EXPECTED_TOKENS)))
+                    generate(engine, request_id, output_kind,
+                             NUM_EXPECTED_TOKENS)))

         # API server cancels requests when they disconnect.
         for idx in REQUEST_IDS_TO_ABORT:
@@ -108,9 +129,7 @@ async def test_abort(monkeypatch):
         # Confirm we can do another generation.
         request_id = f"request-{REQUEST_IDS_TO_ABORT[0]}"
         task = asyncio.create_task(
-            generate(engine, request_id, NUM_EXPECTED_TOKENS))
+            generate(engine, request_id, output_kind, NUM_EXPECTED_TOKENS))
         num_generated_tokens, request_id = await task
         assert num_generated_tokens == NUM_EXPECTED_TOKENS
         assert not engine.output_processor.has_unfinished_requests()
-
-        engine.shutdown()
```
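A side change in these tests is the cleanup style: `engine.shutdown()` is no longer called at the end of the `with` block but registered up front as an `ExitStack` callback, so it also runs when an assertion fails mid-test. A small standalone illustration of that pattern (the `FakeEngine` class here is a stand-in, not a vLLM class):

```python
# Standalone illustration of the ExitStack-callback cleanup pattern used in
# the updated tests; FakeEngine is an illustrative stand-in, not a vLLM class.
from contextlib import ExitStack


class FakeEngine:
    def __init__(self):
        self.alive = True

    def shutdown(self):
        self.alive = False
        print("engine shut down")


engine = FakeEngine()
try:
    with ExitStack() as after:
        after.callback(engine.shutdown)
        # Test body: even if this raises, the registered callback still runs.
        raise AssertionError("simulated test failure")
except AssertionError:
    pass

assert not engine.alive  # shutdown ran despite the failure
```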
vllm/outputs.py (21 additions, 1 deletion)

```diff
@@ -1,6 +1,6 @@
 import time
 from dataclasses import dataclass
-from typing import Dict, Generic, List, Optional
+from typing import Dict, Generic, List, MutableSequence, Optional
 from typing import Sequence as GenericSequence
 from typing import Union

@@ -162,6 +162,26 @@ def new(
             finished=finished,
         )

+    def add(self, next_output: "RequestOutput") -> None:
+        """Merge subsequent RequestOutput into this one"""
+
+        self.prompt = next_output.prompt
+        self.prompt_token_ids = next_output.prompt_token_ids
+        self.prompt_logprobs = next_output.prompt_logprobs
+        self.finished |= next_output.finished
+
+        #TODO assuming n == 1 for now
+        completion = self.outputs[0]
+        next_completion = next_output.outputs[0]
+        completion.text += next_completion.text
+        if not isinstance(completion.token_ids, MutableSequence):
+            completion.token_ids = list(completion.token_ids)
+        completion.token_ids.extend(next_completion.token_ids)
+        if next_completion.logprobs:
+            assert completion.logprobs is not None
+            completion.logprobs.extend(next_completion.logprobs)
+        completion.cumulative_logprob = next_completion.cumulative_logprob
+
     @classmethod
     def from_seq_group(
         cls, seq_group: SequenceGroup, use_cache: bool,
```
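The new `RequestOutput.add` defines the coalescing contract: text and token IDs from a later delta are appended to the earlier output, `finished` is OR-ed, and `cumulative_logprob` is taken from the newer chunk. A rough standalone illustration of that merge rule, using simplified toy classes rather than vLLM's actual `RequestOutput`/`CompletionOutput`:

```python
# Hypothetical, simplified stand-ins for RequestOutput/CompletionOutput,
# only to illustrate the merge rule implemented by RequestOutput.add above.
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class ToyCompletion:
    text: str = ""
    token_ids: List[int] = field(default_factory=list)
    cumulative_logprob: Optional[float] = None


@dataclass
class ToyOutput:
    completion: ToyCompletion
    finished: bool = False

    def add(self, nxt: "ToyOutput") -> None:
        # Same shape as the diff above: concatenate deltas, OR finished,
        # and keep the newest cumulative logprob.
        self.finished |= nxt.finished
        self.completion.text += nxt.completion.text
        self.completion.token_ids.extend(nxt.completion.token_ids)
        self.completion.cumulative_logprob = nxt.completion.cumulative_logprob


first = ToyOutput(ToyCompletion(" Smith", [17], -0.5))
second = ToyOutput(ToyCompletion(" and I", [23, 42], -1.2), finished=True)
first.add(second)
assert first.completion.text == " Smith and I"
assert first.completion.token_ids == [17, 23, 42]
assert first.finished
```

The `MutableSequence` check in the real method exists because `token_ids` may arrive as an immutable sequence, so it is converted to a list before being extended in place.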
vllm/v1/engine/async_llm.py (9 additions, 1 deletion)

```diff
@@ -15,7 +15,7 @@
 from vllm.outputs import RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import RequestOutputKind, SamplingParams
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.usage.usage_lib import UsageContext
@@ -214,6 +214,14 @@ async def generate(
                 # task switching under load which helps performance).
                 out = q.get_nowait() if not q.empty() else await q.get()

+                # Coalesce any additional queued outputs
+                while not q.empty():
+                    next_out = q.get_nowait()
+                    if sampling_params.output_kind == RequestOutputKind.DELTA:
+                        out.add(next_out)
+                    else:
+                        out = next_out
+
                 # Note: both OutputProcessor and EngineCore handle their
                 # own request cleanup based on finished.
                 finished = out.finished
```

Review comment (Member), on the coalescing loop:

> Can be in a future PR, but we should check for invalid sampling params like n > 1 here.

Reply (Member, Author):

> Yes, I guess that's a general V1 thing: we should be explicitly rejecting requests that include unsupported parameters if we aren't already.
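The queue-draining loop above is the heart of the change: after awaiting one output, any outputs that have piled up in the per-request queue are drained without awaiting and either merged (DELTA) or replaced by the newest one (cumulative kinds). A minimal standalone sketch of the same pattern, using a plain `asyncio.Queue` and tuple items instead of vLLM's engine and `RequestOutput`:

```python
# Minimal sketch of the coalescing pattern, independent of vLLM.
# Assumption: items are (text_delta, finished) tuples and "merging" means
# concatenating deltas; the real PR merges RequestOutput objects instead.
import asyncio


async def consume(q: asyncio.Queue, delta_mode: bool) -> list:
    yielded = []
    finished = False
    while not finished:
        # Block only when the queue is empty.
        item = q.get_nowait() if not q.empty() else await q.get()

        # Coalesce anything else that is already queued.
        while not q.empty():
            nxt = q.get_nowait()
            if delta_mode:
                # Merge the bunched deltas into a single item.
                item = (item[0] + nxt[0], item[1] or nxt[1])
            else:
                # Non-delta outputs are cumulative; keep only the newest.
                item = nxt

        finished = item[1]
        yielded.append(item)
    return yielded


async def main():
    q: asyncio.Queue = asyncio.Queue()
    for chunk, done in [("Hel", False), ("lo", False), (" world", True)]:
        q.put_nowait((chunk, done))
    # Everything was already queued, so a DELTA consumer sees one merged item.
    print(await consume(q, delta_mode=True))  # [('Hello world', True)]


asyncio.run(main())
```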
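On the reviewer's point about rejecting unsupported sampling parameters: `RequestOutput.add` merges `outputs[0]` only (the `#TODO` notes n == 1), so a request-time guard would be the natural enforcement point. A hypothetical sketch of such a check, not part of this PR; `_Params` is an illustrative stand-in for `SamplingParams`:

```python
# Hypothetical guard, not part of this PR: where a request-time check for
# unsupported parameters could live. Only the `n` attribute is inspected.
class _Params:
    def __init__(self, n: int = 1):
        self.n = n


def validate_for_v1(params) -> None:
    # Coalescing via RequestOutput.add merges outputs[0] only, so parallel
    # sampling (n > 1) cannot be coalesced yet.
    if getattr(params, "n", 1) > 1:
        raise ValueError(f"unsupported sampling param for V1: n={params.n}")


validate_for_v1(_Params(n=1))  # accepted
try:
    validate_for_v1(_Params(n=4))
except ValueError as exc:
    print(exc)  # unsupported sampling param for V1: n=4
```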