[Frontend] API support for beam search (vllm-project#9087)
Co-authored-by: youkaichao <[email protected]>
2 people authored and liuyanyi committed Oct 6, 2024
1 parent e8097e0 commit 7f616e5
Showing 12 changed files with 275 additions and 68 deletions.
12 changes: 8 additions & 4 deletions benchmarks/benchmark_throughput.py
@@ -15,6 +15,7 @@
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args)
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.sampling_params import BeamSearchParams
from vllm.utils import FlexibleArgumentParser, merge_async_iterators


@@ -145,10 +146,13 @@ def run_vllm(
for prompt, input_len, _output_len in requests:
assert _output_len == output_len
start = time.perf_counter()
-        llm.beam_search(prompts,
-                        beam_width=n,
-                        max_tokens=output_len,
-                        ignore_eos=True)
+        llm.beam_search(
+            prompts,
+            BeamSearchParams(
+                beam_width=n,
+                max_tokens=output_len,
+                ignore_eos=True,
+            ))
end = time.perf_counter()
return end - start

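For reference, the updated offline API now takes a single BeamSearchParams object instead of loose keyword arguments. A minimal sketch of the new call (the model name and prompt are placeholders, and sequences are assumed to be returned best-first):

from vllm import LLM
from vllm.sampling_params import BeamSearchParams

llm = LLM(model="facebook/opt-125m")  # placeholder model
params = BeamSearchParams(beam_width=4, max_tokens=32, ignore_eos=False)
outputs = llm.beam_search(["The capital of France is"], params)
for output in outputs:
    best = output.sequences[0]  # assumed best-first, as in the engine's sort
    print(best.cum_logprob, best.tokens)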
5 changes: 4 additions & 1 deletion tests/conftest.py
@@ -35,6 +35,7 @@
to_enc_dec_tuple_list, zip_enc_dec_prompts)
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
identity, is_cpu)

@@ -812,7 +813,9 @@ def generate_beam_search_new(
beam_width: int,
max_tokens: int,
) -> List[Tuple[List[List[int]], List[str]]]:
-        outputs = self.model.beam_search(prompts, beam_width, max_tokens)
+        outputs = self.model.beam_search(
+            prompts,
+            BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens))
returned_outputs = []
for output in outputs:
token_ids = [x.tokens for x in output.sequences]
43 changes: 24 additions & 19 deletions tests/entrypoints/openai/test_completion.py
@@ -495,25 +495,30 @@ async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str):
assert len(batch.choices) == 2
assert batch.choices[0].text == batch.choices[1].text

-        # test n = 2
-        batch = await client.completions.create(
-            model=model_name,
-            prompt=prompts,
-            n=2,
-            max_tokens=5,
-            temperature=0.0,
-            extra_body=dict(
-                # NOTE: this has to be true for n > 1 in vLLM, but not necessary
-                # for official client.
-                use_beam_search=True),
-        )
-        assert len(batch.choices) == 4
-        assert batch.choices[0].text != batch.choices[
-            1].text, "beam search should be different"
-        assert batch.choices[0].text == batch.choices[
-            2].text, "two copies of the same prompt should be the same"
-        assert batch.choices[1].text == batch.choices[
-            3].text, "two copies of the same prompt should be the same"
+        try:
+            # test n = 2
+            batch = await client.completions.create(
+                model=model_name,
+                prompt=prompts,
+                n=2,
+                max_tokens=5,
+                temperature=0.0,
+                extra_body=dict(
+                    # NOTE: this has to be true for n > 1 in vLLM, but
+                    # not necessary for official client.
+                    use_beam_search=True),
+            )
+            assert len(batch.choices) == 4
+            assert batch.choices[0].text != batch.choices[
+                1].text, "beam search should be different"
+            assert batch.choices[0].text == batch.choices[
+                2].text, "two copies of the same prompt should be the same"
+            assert batch.choices[1].text == batch.choices[
+                3].text, "two copies of the same prompt should be the same"
+        except BadRequestError as e:
+            # the only allowed exception is when beam search is not supported
+            # in the default mqllmengine
+            assert "--disable-frontend-multiprocessing" in str(e)

# test streaming
batch = await client.completions.create(
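As the updated test shows, clients opt into beam search on the OpenAI-compatible completions endpoint by passing use_beam_search in extra_body; per the except branch above, the request may be rejected unless the server is started with --disable-frontend-multiprocessing. A sketch of the client side (base URL, API key, and model name are placeholders):

import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
completion = client.completions.create(
    model="facebook/opt-125m",  # whatever model the server is serving
    prompt="The capital of France is",
    n=2,
    max_tokens=5,
    temperature=0.0,
    extra_body=dict(use_beam_search=True),
)
print([choice.text for choice in completion.choices])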
107 changes: 103 additions & 4 deletions vllm/engine/async_llm_engine.py
@@ -14,23 +14,26 @@
from vllm.engine.async_timeout import asyncio_timeout
from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
from vllm.engine.metrics_types import StatLoggerBase
from vllm.entrypoints.llm import BeamSearchSequence
from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.gpu_executor import GPUExecutorAsync
from vllm.executor.ray_utils import initialize_ray_cluster
-from vllm.inputs import PromptType
+from vllm.inputs import PromptType, TokensPrompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.outputs import EmbeddingRequestOutput, RequestOutput
+from vllm.outputs import (CompletionOutput, EmbeddingRequestOutput,
+                          RequestOutput)
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import ExecuteModelRequest
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.usage.usage_lib import UsageContext
-from vllm.utils import deprecate_kwargs, weak_bind
+from vllm.utils import (collect_from_async_generator, deprecate_kwargs,
+                        random_uuid, weak_bind)

logger = init_logger(__name__)
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
@@ -1036,6 +1039,102 @@ async def generate(
):
yield LLMEngine.validate_output(output, RequestOutput)

    async def beam_search(
        self,
        prompt: Union[PromptType, List[int]],
        request_id: str,
        params: BeamSearchParams,
    ) -> AsyncGenerator[RequestOutput, None]:

        beam_width = params.beam_width
        max_tokens = params.max_tokens
        ignore_eos = params.ignore_eos
        temperature = params.temperature

        tokenizer = await self.get_tokenizer()
        tokenizedPrompt = prompt if isinstance(
            prompt, list) else tokenizer.encode(prompt)
        tokenizedLength = len(tokenizedPrompt)

        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
                                            temperature=temperature)
        all_beams = [BeamSearchSequence(tokens=tokenizedPrompt, cum_logprob=0)]
        completed = []

        for _ in range(max_tokens):
            prompts_batch = [
                TokensPrompt(prompt_token_ids=beam.tokens)
                for beam in all_beams
            ]

            tasks = []

            request_id = f"beam_search-{random_uuid()}"
            for i, individual_prompt in enumerate(prompts_batch):
                request_id_item = f"{request_id}-{i}"
                task = asyncio.create_task(
                    collect_from_async_generator(
                        self.generate(individual_prompt, beam_search_params,
                                      request_id_item)))
                tasks.append(task)

            output = await asyncio.gather(*tasks)

            output = [x[0] for x in output]

            logger.info(output)

            new_beams = []
            for i, current_beam in enumerate(all_beams):
                result = output[i]

                if result.outputs[0].logprobs is not None:
                    logprobs = result.outputs[0].logprobs[0]
                    for token_id, logprob_obj in logprobs.items():
                        new_beam = BeamSearchSequence(
                            tokens=current_beam.tokens + [token_id],
                            cum_logprob=current_beam.cum_logprob +
                            logprob_obj.logprob)

                        if token_id == tokenizer.eos_token_id and \
                            not ignore_eos:
                            completed.append(new_beam)
                        else:
                            new_beams.append(new_beam)

            sorted_beams = sorted(new_beams,
                                  key=lambda x: x.cum_logprob,
                                  reverse=True)
            all_beams = sorted_beams[:beam_width]

        completed.extend(all_beams)
        sorted_completed = sorted(completed,
                                  key=lambda x: x.cum_logprob,
                                  reverse=True)
        best_beams = sorted_completed[:beam_width]

        for beam in best_beams:
            beam.text = tokenizer.decode(beam.tokens[tokenizedLength:])

        beam_search_output = RequestOutput(
            request_id=request_id,
            prompt=prompt,
            outputs=[
                CompletionOutput(
                    text=beam.text,
                    cumulative_logprob=beam.cum_logprob,
                    token_ids=beam.tokens,
                    index=i,
                    logprobs=beam.cum_logprob,
                ) for (i, beam) in enumerate(best_beams)
            ],
            finished=True,
            prompt_token_ids=tokenizedPrompt,
            prompt_logprobs=None)

        yield LLMEngine.validate_output(beam_search_output, RequestOutput)

async def encode(
self,
prompt: PromptType,
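The new method is cumulative-log-probability bookkeeping: each iteration asks the engine for the top 2 * beam_width next-token log-probabilities per live beam (logprobs=2 * beam_width, max_tokens=1), extends each beam with those tokens, routes beams that hit EOS to the completed list (unless ignore_eos), and keeps only the beam_width highest-scoring candidates. A self-contained toy illustration of that pruning step, with hand-picked log-probabilities instead of model output:

beam_width = 2
# (tokens, cum_logprob) for the beams alive after the previous step
all_beams = [([1, 7], -0.4), ([1, 3], -0.9)]
# hypothetical top-(2 * beam_width) next-token log-probabilities per beam
step_logprobs = [
    {5: -0.1, 9: -1.2, 4: -2.0, 2: -2.5},  # continuations of [1, 7]
    {8: -0.3, 5: -0.7, 6: -1.5, 2: -2.2},  # continuations of [1, 3]
]

new_beams = []
for (tokens, cum_logprob), logprobs in zip(all_beams, step_logprobs):
    for token_id, logprob in logprobs.items():
        new_beams.append((tokens + [token_id], cum_logprob + logprob))

# keep only the beam_width highest-scoring candidates, exactly as the engine
# does with sorted(..., key=lambda x: x.cum_logprob, reverse=True)
all_beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:beam_width]
print(all_beams)  # tokens [1, 7, 5] (score about -0.5) and [1, 3, 8] (about -1.2) survive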
20 changes: 10 additions & 10 deletions vllm/entrypoints/llm.py
@@ -22,8 +22,8 @@
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import (GuidedDecodingParams, RequestOutputKind,
-                                  SamplingParams)
+from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
+                                  RequestOutputKind, SamplingParams)
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
@@ -394,25 +394,25 @@ def generate(
def beam_search(
self,
prompts: List[Union[str, List[int]]],
-        beam_width: int,
-        max_tokens: int,
-        ignore_eos: bool = False,
-        temperature: float = 0.0,
+        params: BeamSearchParams,
) -> List[BeamSearchOutput]:
"""
Generate sequences using beam search.
Args:
prompts: A list of prompts. Each prompt can be a string or a list
of token IDs.
-            beam_width: The number of beams to keep at each step.
-            max_tokens: The max number of tokens to generate for each prompt.
-            temperature: The temperature to use for generation.
+            params: The beam search parameters.
TODO: how does beam search work together with length penalty, frequency
penalty, and stopping criteria, etc.?
"""

beam_width = params.beam_width
max_tokens = params.max_tokens
temperature = params.temperature
ignore_eos = params.ignore_eos

tokenizer = self.get_tokenizer()
# generate 2 * beam_width candidates at each step
# following the huggingface transformers implementation
5 changes: 3 additions & 2 deletions vllm/entrypoints/logger.py
@@ -4,7 +4,7 @@
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import BeamSearchParams, SamplingParams

logger = init_logger(__name__)

@@ -21,7 +21,8 @@ def log_inputs(
request_id: str,
prompt: Optional[str],
prompt_token_ids: Optional[List[int]],
-        params: Optional[Union[SamplingParams, PoolingParams]],
+        params: Optional[Union[SamplingParams, PoolingParams,
+                               BeamSearchParams]],
lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest],
) -> None:
36 changes: 34 additions & 2 deletions vllm/entrypoints/openai/protocol.py
@@ -11,8 +11,8 @@

from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.pooling_params import PoolingParams
-from vllm.sampling_params import (GuidedDecodingParams, RequestOutputKind,
-                                  SamplingParams)
+from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
+                                  RequestOutputKind, SamplingParams)
from vllm.sequence import Logprob
from vllm.utils import random_uuid

@@ -288,6 +288,22 @@ class ChatCompletionRequest(OpenAIBaseModel):

# doc: end-chat-completion-extra-params

    def to_beam_search_params(self,
                              default_max_tokens: int) -> BeamSearchParams:
        max_tokens = self.max_tokens
        if max_tokens is None:
            max_tokens = default_max_tokens

        n = self.n if self.n is not None else 1
        temperature = self.temperature if self.temperature is not None else 0.0

        return BeamSearchParams(
            beam_width=n,
            max_tokens=max_tokens,
            ignore_eos=self.ignore_eos,
            temperature=temperature,
        )

def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
max_tokens = self.max_tokens
if max_tokens is None:
@@ -567,6 +583,22 @@ class CompletionRequest(OpenAIBaseModel):

# doc: end-completion-extra-params

    def to_beam_search_params(self,
                              default_max_tokens: int) -> BeamSearchParams:
        max_tokens = self.max_tokens
        if max_tokens is None:
            max_tokens = default_max_tokens

        n = self.n if self.n is not None else 1
        temperature = self.temperature if self.temperature is not None else 0.0

        return BeamSearchParams(
            beam_width=n,
            max_tokens=max_tokens,
            ignore_eos=self.ignore_eos,
            temperature=temperature,
        )

def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
max_tokens = self.max_tokens
if max_tokens is None:
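Both request classes perform the same translation into the engine-level parameter object: n becomes the beam width, max_tokens falls back to a server-side default when the client omits it, and temperature and ignore_eos are passed through. A sketch of the result for one hypothetical request (default_max_tokens assumed to be 64):

from vllm.sampling_params import BeamSearchParams

# Equivalent of request.to_beam_search_params(default_max_tokens=64) for a
# completion request that sends n=2, temperature=0.0 and no max_tokens:
params = BeamSearchParams(
    beam_width=2,      # the request's n (1 if n is absent)
    max_tokens=64,     # default_max_tokens, because the request set none
    ignore_eos=False,  # the request's ignore_eos extra parameter
    temperature=0.0,   # the request's temperature (0.0 if it is null)
)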