diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 4752adab444c8..e15926d9a8780 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -7,7 +7,6 @@ from weakref import ReferenceType import vllm.envs as envs -from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig) from vllm.core.scheduler import SchedulerOutputs @@ -15,25 +14,24 @@ from vllm.engine.async_timeout import asyncio_timeout from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState from vllm.engine.metrics_types import StatLoggerBase +from vllm.engine.protocol import EngineClient from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.gpu_executor import GPUExecutorAsync from vllm.executor.ray_utils import initialize_ray_cluster -from vllm.inputs import PromptType, TokensPrompt +from vllm.inputs import PromptType from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.guided_decoding import ( get_guided_decoding_logits_processor) from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.outputs import (CompletionOutput, EmbeddingRequestOutput, - RequestOutput) +from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import BeamSearchParams, SamplingParams +from vllm.sampling_params import SamplingParams from vllm.sequence import ExecuteModelRequest from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.usage.usage_lib import UsageContext -from vllm.utils import (collect_from_async_generator, deprecate_kwargs, - random_uuid, weak_bind) +from vllm.utils import deprecate_kwargs, weak_bind logger = init_logger(__name__) ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S @@ -583,7 +581,7 @@ async def build_guided_decoding_logits_processor_async( return sampling_params -class AsyncLLMEngine: +class AsyncLLMEngine(EngineClient): """An asynchronous wrapper for :class:`LLMEngine`. 
This class is used to wrap the :class:`LLMEngine` class to make it @@ -1081,102 +1079,6 @@ async def generate( ): yield LLMEngine.validate_output(output, RequestOutput) - async def beam_search( - self, - prompt: Union[PromptType, List[int]], - request_id: str, - params: BeamSearchParams, - ) -> AsyncGenerator[RequestOutput, None]: - - beam_width = params.beam_width - max_tokens = params.max_tokens - ignore_eos = params.ignore_eos - temperature = params.temperature - length_penalty = params.length_penalty - - tokenizer = await self.get_tokenizer() - tokenizedPrompt = prompt if isinstance( - prompt, list) else tokenizer.encode(prompt) - tokenizedLength = len(tokenizedPrompt) - - sort_beams_key = create_sort_beams_key_function( - tokenizer.eos_token_id, length_penalty) - - beam_search_params = SamplingParams(logprobs=2 * beam_width, - max_tokens=1, - temperature=temperature) - all_beams = [BeamSearchSequence(tokens=tokenizedPrompt, cum_logprob=0)] - completed = [] - - for _ in range(max_tokens): - prompts_batch = [ - TokensPrompt(prompt_token_ids=beam.tokens) - for beam in all_beams - ] - - tasks = [] - - request_id = f"beam_search-{random_uuid()}" - for i, individual_prompt in enumerate(prompts_batch): - request_id_item = f"{request_id}-{i}" - task = asyncio.create_task( - collect_from_async_generator( - self.generate(individual_prompt, beam_search_params, - request_id_item))) - tasks.append(task) - - output = await asyncio.gather(*tasks) - - output = [x[0] for x in output] - - logger.info(output) - - new_beams = [] - for i, current_beam in enumerate(all_beams): - result = output[i] - - if result.outputs[0].logprobs is not None: - logprobs = result.outputs[0].logprobs[0] - for token_id, logprob_obj in logprobs.items(): - new_beam = BeamSearchSequence( - tokens=current_beam.tokens + [token_id], - cum_logprob=current_beam.cum_logprob + - logprob_obj.logprob) - - if token_id == tokenizer.eos_token_id and \ - not ignore_eos: - completed.append(new_beam) - else: - new_beams.append(new_beam) - - sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True) - all_beams = sorted_beams[:beam_width] - - completed.extend(all_beams) - sorted_completed = sorted(completed, key=sort_beams_key, reverse=True) - best_beams = sorted_completed[:beam_width] - - for beam in best_beams: - beam.text = tokenizer.decode(beam.tokens[tokenizedLength:]) - - beam_search_output = RequestOutput( - request_id=request_id, - prompt=prompt, - outputs=[ - CompletionOutput( - text=beam.text, - cumulative_logprob=beam.cum_logprob, - token_ids=beam.tokens, - index=i, - logprobs=beam.cum_logprob, - ) for (i, beam) in enumerate(best_beams) - ], - finished=True, - prompt_token_ids=tokenizedPrompt, - prompt_logprobs=None) - - yield LLMEngine.validate_output(beam_search_output, RequestOutput) - async def encode( self, prompt: PromptType, diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 166906f24673b..6bf553666a852 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -12,8 +12,8 @@ from zmq.asyncio import Socket from vllm import PoolingParams -from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function from vllm.config import DecodingConfig, EngineConfig, ModelConfig +from vllm.core.scheduler import SchedulerOutputs from vllm.engine.arg_utils import AsyncEngineArgs # yapf conflicts with isort for this block # yapf: disable @@ -26,18 +26,18 @@ RPCError, RPCProcessRequest, RPCStartupRequest, RPCStartupResponse, 
RPCUProfileRequest) +from vllm.engine.protocol import EngineClient # yapf: enable from vllm.envs import VLLM_RPC_TIMEOUT -from vllm.inputs import PromptType, TokensPrompt +from vllm.inputs import PromptType from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.outputs import (CompletionOutput, EmbeddingRequestOutput, - RequestOutput) +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import BeamSearchParams, SamplingParams +from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs -from vllm.utils import (collect_from_async_generator, deprecate_kwargs, - random_uuid) +from vllm.utils import deprecate_kwargs logger = init_logger(__name__) @@ -53,7 +53,7 @@ class MQClientClosedError(Exception): """ -class MQLLMEngineClient: +class MQLLMEngineClient(EngineClient): """A client wrapper for MQLLMEngine that conforms to the EngineClient protocol. @@ -316,7 +316,7 @@ async def _check_success(error_message: str, socket: Socket): or response != VLLM_RPC_SUCCESS_STR): raise ValueError(error_message) - async def get_tokenizer(self, lora_request: LoRARequest): + async def get_tokenizer(self, lora_request: Optional[LoRARequest] = None): return await self.tokenizer.get_lora_tokenizer_async(lora_request) async def get_decoding_config(self) -> DecodingConfig: @@ -344,8 +344,14 @@ async def abort(self, request_id: str): await self._send_one_way_rpc_request( request=RPCAbortRequest(request_id), socket=self.input_socket) - async def do_log_stats(self): - """Ignore do_log_stats (handled on MQLLMEngine polling)""" + async def do_log_stats( + self, + scheduler_outputs: Optional[SchedulerOutputs] = None, + model_output: Optional[List[SamplerOutput]] = None, + ) -> None: + """ + Ignore do_log_stats (handled on MQLLMEngine polling) + """ pass async def check_health(self): @@ -444,104 +450,6 @@ def generate( lora_request, trace_headers, prompt_adapter_request, priority) - async def beam_search( - self, - prompt: Union[PromptType, List[int]], - request_id: str, - params: BeamSearchParams, - ) -> AsyncGenerator[RequestOutput, None]: - - beam_width = params.beam_width - max_tokens = params.max_tokens - ignore_eos = params.ignore_eos - temperature = params.temperature - length_penalty = params.length_penalty - - tokenizer = await self.get_tokenizer(lora_request=None) - tokenizedPrompt = prompt if isinstance( - prompt, list) else tokenizer.encode(prompt) - tokenizedLength = len(tokenizedPrompt) - - sort_beams_key = create_sort_beams_key_function( - tokenizer.eos_token_id, length_penalty) - - beam_search_params = SamplingParams(logprobs=2 * beam_width, - max_tokens=1, - temperature=temperature) - all_beams = [BeamSearchSequence(tokens=tokenizedPrompt, cum_logprob=0)] - completed = [] - - for _ in range(max_tokens): - prompts_batch = [ - TokensPrompt(prompt_token_ids=beam.tokens) - for beam in all_beams - ] - - tasks = [] - - request_id = f"beam_search-{random_uuid()}" - for i, individual_prompt in enumerate(prompts_batch): - request_id_item = f"{request_id}-{i}" - task = asyncio.create_task( - collect_from_async_generator( - self.generate(individual_prompt, beam_search_params, - request_id_item))) - tasks.append(task) - - output = await asyncio.gather(*tasks) - - output = [x[0] for x in output] - - logger.info(output) - - new_beams = [] - for i, current_beam 
in enumerate(all_beams): - result = output[i] - - if result.outputs[0].logprobs is not None: - logprobs = result.outputs[0].logprobs[0] - for token_id, logprob_obj in logprobs.items(): - new_beam = BeamSearchSequence( - tokens=current_beam.tokens + [token_id], - cum_logprob=current_beam.cum_logprob + - logprob_obj.logprob) - - if token_id == tokenizer.eos_token_id and \ - not ignore_eos: - completed.append(new_beam) - else: - new_beams.append(new_beam) - - sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True) - all_beams = sorted_beams[:beam_width] - - completed.extend(all_beams) - sorted_completed = sorted(completed, key=sort_beams_key, reverse=True) - best_beams = sorted_completed[:beam_width] - - for beam in best_beams: - beam.text = tokenizer.decode(beam.tokens[tokenizedLength:]) - - beam_search_output = RequestOutput( - request_id=request_id, - prompt=prompt, - outputs=[ - CompletionOutput( - text=beam.text, - cumulative_logprob=beam.cum_logprob, - token_ids=beam.tokens, - index=i, - logprobs=beam.cum_logprob, - ) for (i, beam) in enumerate(best_beams) - ], - finished=True, - prompt_token_ids=tokenizedPrompt, - prompt_logprobs=None) - - logger.info(beam_search_output) - - yield beam_search_output - @overload # DEPRECATED def encode( self, diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index d7ff743e0ada6..16ceddf13511c 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -1,38 +1,49 @@ -from typing import (AsyncGenerator, List, Mapping, Optional, Protocol, - runtime_checkable) +import asyncio +from abc import ABC, abstractmethod +from typing import AsyncGenerator, List, Mapping, Optional, Union +from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function from vllm.config import DecodingConfig, ModelConfig from vllm.core.scheduler import SchedulerOutputs -from vllm.inputs.data import PromptType +from vllm.inputs.data import PromptType, TokensPrompt +from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.outputs import EmbeddingRequestOutput, RequestOutput +from vllm.outputs import (CompletionOutput, EmbeddingRequestOutput, + RequestOutput) from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import SamplingParams +from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import collect_from_async_generator, random_uuid +logger = init_logger(__name__) -@runtime_checkable -class EngineClient(Protocol): + +class EngineClient(ABC): """Protocol class for Clients to Engine""" @property + @abstractmethod def is_running(self) -> bool: ... @property + @abstractmethod def is_stopped(self) -> bool: ... @property + @abstractmethod def errored(self) -> bool: ... @property + @abstractmethod def dead_error(self) -> BaseException: ... + @abstractmethod def generate( self, prompt: PromptType, @@ -46,6 +57,101 @@ def generate( """Generate outputs for a request.""" ... 
+ async def beam_search( + self, + prompt: Union[PromptType, List[int]], + request_id: str, + params: BeamSearchParams, + ) -> AsyncGenerator[RequestOutput, None]: + + beam_width = params.beam_width + max_tokens = params.max_tokens + ignore_eos = params.ignore_eos + temperature = params.temperature + length_penalty = params.length_penalty + + tokenizer = await self.get_tokenizer(lora_request=None) + tokenizedPrompt = prompt if isinstance( + prompt, list) else tokenizer.encode(prompt) + tokenizedLength = len(tokenizedPrompt) + + sort_beams_key = create_sort_beams_key_function( + tokenizer.eos_token_id, length_penalty) + + beam_search_params = SamplingParams(logprobs=2 * beam_width, + max_tokens=1, + temperature=temperature) + all_beams = [BeamSearchSequence(tokens=tokenizedPrompt, cum_logprob=0)] + completed = [] + + for _ in range(max_tokens): + prompts_batch = [ + TokensPrompt(prompt_token_ids=beam.tokens) + for beam in all_beams + ] + + tasks = [] + + request_id = f"beam_search-{random_uuid()}" + for i, individual_prompt in enumerate(prompts_batch): + request_id_item = f"{request_id}-{i}" + task = asyncio.create_task( + collect_from_async_generator( + self.generate(individual_prompt, beam_search_params, + request_id_item))) + tasks.append(task) + + output = await asyncio.gather(*tasks) + + output = [x[0] for x in output] + + new_beams = [] + for i, current_beam in enumerate(all_beams): + result = output[i] + + if result.outputs[0].logprobs is not None: + logprobs = result.outputs[0].logprobs[0] + for token_id, logprob_obj in logprobs.items(): + new_beam = BeamSearchSequence( + tokens=current_beam.tokens + [token_id], + cum_logprob=current_beam.cum_logprob + + logprob_obj.logprob) + + if token_id == tokenizer.eos_token_id and \ + not ignore_eos: + completed.append(new_beam) + else: + new_beams.append(new_beam) + + sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True) + all_beams = sorted_beams[:beam_width] + + completed.extend(all_beams) + sorted_completed = sorted(completed, key=sort_beams_key, reverse=True) + best_beams = sorted_completed[:beam_width] + + for beam in best_beams: + beam.text = tokenizer.decode(beam.tokens[tokenizedLength:]) + + beam_search_output = RequestOutput( + request_id=request_id, + prompt=prompt, + outputs=[ + CompletionOutput( + text=beam.text, + cumulative_logprob=beam.cum_logprob, + token_ids=beam.tokens, + index=i, + logprobs=beam.cum_logprob, + ) for (i, beam) in enumerate(best_beams) + ], + finished=True, + prompt_token_ids=tokenizedPrompt, + prompt_logprobs=None) + + yield beam_search_output + + @abstractmethod def encode( self, prompt: PromptType, @@ -58,6 +164,7 @@ def encode( """Generate outputs for a request from an embedding model.""" ... + @abstractmethod async def abort(self, request_id: str) -> None: """Abort a request. @@ -65,14 +172,17 @@ async def abort(self, request_id: str) -> None: request_id: The unique id of the request. """ + @abstractmethod async def get_model_config(self) -> ModelConfig: """Get the model configuration of the vLLM engine.""" ... + @abstractmethod async def get_decoding_config(self) -> DecodingConfig: ... """Get the decoding configuration of the vLLM engine.""" + @abstractmethod async def get_tokenizer( self, lora_request: Optional[LoRARequest] = None, @@ -80,9 +190,11 @@ async def get_tokenizer( """Get the appropriate tokenizer for the request""" ... + @abstractmethod async def is_tracing_enabled(self) -> bool: ... 
+ @abstractmethod async def do_log_stats( self, scheduler_outputs: Optional[SchedulerOutputs] = None, @@ -90,14 +202,17 @@ async def do_log_stats( ) -> None: ... + @abstractmethod async def check_health(self) -> None: """Raise if unhealthy""" ... + @abstractmethod async def start_profile(self) -> None: """Start profiling the engine""" ... + @abstractmethod async def stop_profile(self) -> None: """Start profiling the engine""" ... diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 4931195ae0e02..9470b6ea03ef6 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -9,8 +9,6 @@ from fastapi import Request from vllm.config import ModelConfig -from vllm.engine.async_llm_engine import AsyncLLMEngine -from vllm.engine.multiprocessing.client import MQLLMEngineClient from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils import (ConversationMessage, apply_hf_chat_template, @@ -237,11 +235,6 @@ async def create_chat_completion( log_tracing_disabled_warning() if isinstance(sampling_params, BeamSearchParams): - assert isinstance(self.engine_client, - (AsyncLLMEngine, - MQLLMEngineClient)), \ - "Beam search is only supported with" \ - "AsyncLLMEngine and MQLLMEngineClient." result_generator = self.engine_client.beam_search( engine_inputs['prompt_token_ids'], request_id, diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 077312dd1414e..7aa4587e23c15 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -8,8 +8,6 @@ from fastapi import Request from vllm.config import ModelConfig -from vllm.engine.async_llm_engine import AsyncLLMEngine -from vllm.engine.multiprocessing.client import MQLLMEngineClient from vllm.engine.protocol import EngineClient from vllm.entrypoints.logger import RequestLogger # yapf conflicts with isort for this block @@ -151,11 +149,6 @@ async def create_completion( log_tracing_disabled_warning() if isinstance(sampling_params, BeamSearchParams): - assert isinstance(self.engine_client, - (AsyncLLMEngine, - MQLLMEngineClient)), \ - "Beam search is only supported with" \ - "AsyncLLMEngine and MQLLMEngineClient." generator = self.engine_client.beam_search( prompt_inputs["prompt_token_ids"], request_id_item,
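
The hunks above drop the isinstance assertions from the OpenAI serving layers because every concrete client now derives from the `EngineClient` ABC and therefore exposes the shared `beam_search` coroutine. A quick sanity-check sketch of that relationship (nothing below is added by this diff; it only exercises the classes the diff touches):

```python
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.engine.protocol import EngineClient

# Both concrete clients subclass the new ABC, so serving code can rely on
# beam_search being available instead of asserting the engine type.
assert issubclass(AsyncLLMEngine, EngineClient)
assert issubclass(MQLLMEngineClient, EngineClient)

# The default implementation lives on the base class itself.
assert hasattr(EngineClient, "beam_search")
print("EngineClient provides beam_search for all clients")
```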
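For context, a minimal usage sketch of the inherited coroutine called through the abstract interface, assuming a locally constructed `AsyncLLMEngine`; the model name, beam width, and `max_tokens` below are illustrative placeholders, not values taken from this change:

```python
import asyncio

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.protocol import EngineClient
from vllm.sampling_params import BeamSearchParams


async def run_beam_search(client: EngineClient) -> None:
    # Works for any EngineClient subclass, since beam_search is inherited.
    params = BeamSearchParams(beam_width=4, max_tokens=32)
    async for output in client.beam_search(
            prompt="The capital of France is",
            request_id="beam-search-demo",
            params=params):
        # The default implementation yields one final RequestOutput holding
        # the top `beam_width` beams as CompletionOutput entries.
        for beam in output.outputs:
            print(beam.cumulative_logprob, beam.text)


if __name__ == "__main__":
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))  # placeholder model
    asyncio.run(run_beam_search(engine))
```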
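The base-class implementation ranks beams with `create_sort_beams_key_function` from `vllm.beam_search`, whose body is outside this diff. The sketch below is an assumption about the kind of length-penalized score such a key computes, written against a stand-in `Beam` type rather than the real `BeamSearchSequence`; it is illustrative only, not the actual vLLM helper:

```python
from dataclasses import dataclass
from typing import Callable, List


@dataclass
class Beam:
    # Stand-in for vllm.beam_search.BeamSearchSequence in this sketch.
    tokens: List[int]
    cum_logprob: float


def make_sort_beams_key(eos_token_id: int,
                        length_penalty: float) -> Callable[[Beam], float]:
    """Assumed scoring: cumulative log-probability normalized by the
    sequence length raised to `length_penalty`, with a trailing EOS token
    excluded from the length."""

    def key(beam: Beam) -> float:
        seq_len = len(beam.tokens)
        if seq_len and beam.tokens[-1] == eos_token_id:
            seq_len -= 1
        return beam.cum_logprob / (seq_len**length_penalty)

    return key


# Higher score sorts first, matching sorted(..., reverse=True) in the
# default beam_search implementation above.
sort_key = make_sort_beams_key(eos_token_id=2, length_penalty=1.0)
beams = [Beam([5, 7, 11, 2], -3.2), Beam([5, 9], -1.4)]
print([b.tokens for b in sorted(beams, key=sort_key, reverse=True)])
```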