diff --git a/benchmarks/benchmark_throughput_async.py b/benchmarks/benchmark_throughput_async.py new file mode 100644 index 0000000000000..217f11d14d30a --- /dev/null +++ b/benchmarks/benchmark_throughput_async.py @@ -0,0 +1,480 @@ +"""Benchmark offline inference throughput.""" +import argparse +import json +import random +import time +from typing import List, Optional, Tuple + +import torch +import uvloop +from tqdm import tqdm +from transformers import (AutoModelForCausalLM, AutoTokenizer, + PreTrainedTokenizerBase) + +from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs +from vllm.entrypoints.openai.api_server import ( + build_async_engine_client_from_engine_args) +from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS +from vllm.utils import FlexibleArgumentParser, merge_async_iterators + + +def sample_requests( + dataset_path: str, + num_requests: int, + tokenizer: PreTrainedTokenizerBase, + fixed_output_len: Optional[int], +) -> List[Tuple[str, int, int]]: + if fixed_output_len is not None and fixed_output_len < 4: + raise ValueError("output_len too small") + + # Load the dataset. + with open(dataset_path) as f: + dataset = json.load(f) + # Filter out the conversations with less than 2 turns. + dataset = [data for data in dataset if len(data["conversations"]) >= 2] + # Only keep the first two turns of each conversation. + dataset = [(data["conversations"][0]["value"], + data["conversations"][1]["value"]) for data in dataset] + + # Shuffle the dataset. + random.shuffle(dataset) + + # Filter out sequences that are too long or too short + filtered_dataset: List[Tuple[str, int, int]] = [] + for i in range(len(dataset)): + if len(filtered_dataset) == num_requests: + break + + # Tokenize the prompts and completions. + prompt = dataset[i][0] + prompt_token_ids = tokenizer(prompt).input_ids + completion = dataset[i][1] + completion_token_ids = tokenizer(completion).input_ids + prompt_len = len(prompt_token_ids) + output_len = len(completion_token_ids + ) if fixed_output_len is None else fixed_output_len + if prompt_len < 4 or output_len < 4: + # Prune too short sequences. + continue + if prompt_len > 1024 or prompt_len + output_len > 2048: + # Prune too long sequences. 
+ continue + filtered_dataset.append((prompt, prompt_len, output_len)) + + return filtered_dataset + + +async def run_vllm( + requests: List[Tuple[str, int, int]], + model: str, + tokenizer: str, + quantization: Optional[str], + tensor_parallel_size: int, + seed: int, + n: int, + use_beam_search: bool, + trust_remote_code: bool, + dtype: str, + max_model_len: Optional[int], + enforce_eager: bool, + kv_cache_dtype: str, + quantization_param_path: Optional[str], + device: str, + enable_prefix_caching: bool, + enable_chunked_prefill: bool, + max_num_batched_tokens: int, + distributed_executor_backend: Optional[str], + gpu_memory_utilization: float = 0.9, + num_scheduler_steps: int = 1, + use_v2_block_manager: bool = False, + download_dir: Optional[str] = None, + load_format: str = EngineArgs.load_format, + disable_async_output_proc: bool = False, +) -> float: + from vllm import SamplingParams + engine_args = AsyncEngineArgs( + model=model, + tokenizer=tokenizer, + quantization=quantization, + tensor_parallel_size=tensor_parallel_size, + seed=seed, + trust_remote_code=trust_remote_code, + dtype=dtype, + max_model_len=max_model_len, + gpu_memory_utilization=gpu_memory_utilization, + enforce_eager=enforce_eager, + kv_cache_dtype=kv_cache_dtype, + quantization_param_path=quantization_param_path, + device=device, + enable_prefix_caching=enable_prefix_caching, + download_dir=download_dir, + enable_chunked_prefill=enable_chunked_prefill, + max_num_batched_tokens=max_num_batched_tokens, + distributed_executor_backend=distributed_executor_backend, + load_format=load_format, + num_scheduler_steps=num_scheduler_steps, + use_v2_block_manager=use_v2_block_manager, + disable_async_output_proc=disable_async_output_proc, + worker_use_ray=False, + engine_use_ray=False, + disable_log_requests=True, + ) + + decoupled = True + + async with build_async_engine_client_from_engine_args( + engine_args, not decoupled) as llm: + + # Add the requests to the engine. + prompts: List[str] = [] + sampling_params: List[SamplingParams] = [] + for prompt, _, output_len in requests: + prompts.append(prompt) + sampling_params.append( + SamplingParams( + n=n, + temperature=0.0 if use_beam_search else 1.0, + top_p=1.0, + use_beam_search=use_beam_search, + ignore_eos=True, + max_tokens=output_len, + )) + + generators = [] + start = time.perf_counter() + for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)): + # generator = await llm.generate(prompt, sp, request_id=f"test{i}") + generator = llm.generate(prompt, sp, request_id=f"test{i}") + generators.append(generator) + all_gens = merge_async_iterators(*generators) + async for i, res in all_gens: + pass + end = time.perf_counter() + return end - start + + +def run_hf( + requests: List[Tuple[str, int, int]], + model: str, + tokenizer: PreTrainedTokenizerBase, + n: int, + use_beam_search: bool, + max_batch_size: int, + trust_remote_code: bool, +) -> float: + assert not use_beam_search + llm = AutoModelForCausalLM.from_pretrained( + model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code) + if llm.config.model_type == "llama": + # To enable padding in the HF backend. + tokenizer.pad_token = tokenizer.eos_token + llm = llm.cuda() + + pbar = tqdm(total=len(requests)) + start = time.perf_counter() + batch: List[str] = [] + max_prompt_len = 0 + max_output_len = 0 + for i in range(len(requests)): + prompt, prompt_len, output_len = requests[i] + # Add the prompt to the batch. 
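+        # Requests are packed greedily: the batch keeps growing until it
+        # reaches max_batch_size or until padding every sequence to the
+        # longest prompt/output in the batch would exceed the 2048-token
+        # budget checked below.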
+ batch.append(prompt) + max_prompt_len = max(max_prompt_len, prompt_len) + max_output_len = max(max_output_len, output_len) + if len(batch) < max_batch_size and i != len(requests) - 1: + # Check if we can add more requests to the batch. + _, next_prompt_len, next_output_len = requests[i + 1] + if (max(max_prompt_len, next_prompt_len) + + max(max_output_len, next_output_len)) <= 2048: + # We can add more requests to the batch. + continue + + # Generate the sequences. + input_ids = tokenizer(batch, return_tensors="pt", + padding=True).input_ids + llm_outputs = llm.generate( + input_ids=input_ids.cuda(), + do_sample=not use_beam_search, + num_return_sequences=n, + temperature=1.0, + top_p=1.0, + use_cache=True, + max_new_tokens=max_output_len, + ) + # Include the decoding time. + tokenizer.batch_decode(llm_outputs, skip_special_tokens=True) + pbar.update(len(batch)) + + # Clear the batch. + batch = [] + max_prompt_len = 0 + max_output_len = 0 + end = time.perf_counter() + return end - start + + +def run_mii( + requests: List[Tuple[str, int, int]], + model: str, + tensor_parallel_size: int, + output_len: int, +) -> float: + from mii import client, serve + llm = serve(model, tensor_parallel=tensor_parallel_size) + prompts = [prompt for prompt, _, _ in requests] + + start = time.perf_counter() + llm.generate(prompts, max_new_tokens=output_len) + end = time.perf_counter() + client = client(model) + client.terminate_server() + return end - start + + +def main(args: argparse.Namespace): + print(args) + random.seed(args.seed) + + # Sample the requests. + tokenizer = AutoTokenizer.from_pretrained( + args.tokenizer, trust_remote_code=args.trust_remote_code) + if args.dataset is None: + # Synthesize a prompt with the given input length. + prompt = "hi" * (args.input_len - 1) + requests = [(prompt, args.input_len, args.output_len) + for _ in range(args.num_prompts)] + else: + requests = sample_requests(args.dataset, args.num_prompts, tokenizer, + args.output_len) + + if args.backend == "vllm": + coro = run_vllm( + requests, args.model, args.tokenizer, args.quantization, + args.tensor_parallel_size, args.seed, args.n, args.use_beam_search, + args.trust_remote_code, args.dtype, args.max_model_len, + args.enforce_eager, args.kv_cache_dtype, + args.quantization_param_path, args.device, + args.enable_prefix_caching, args.enable_chunked_prefill, + args.max_num_batched_tokens, args.distributed_executor_backend, + args.gpu_memory_utilization, args.num_scheduler_steps, + args.use_v2_block_manager, args.download_dir, args.load_format, + args.disable_async_output_proc) + + elapsed_time = uvloop.run(coro) + elif args.backend == "hf": + assert args.tensor_parallel_size == 1 + elapsed_time = run_hf(requests, args.model, tokenizer, args.n, + args.use_beam_search, args.hf_max_batch_size, + args.trust_remote_code) + elif args.backend == "mii": + elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size, + args.output_len) + else: + raise ValueError(f"Unknown backend: {args.backend}") + total_num_tokens = sum(prompt_len + output_len + for _, prompt_len, output_len in requests) + print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " + f"{total_num_tokens / elapsed_time:.2f} tokens/s") + + # Output JSON results if specified + if args.output_json: + results = { + "elapsed_time": elapsed_time, + "num_requests": len(requests), + "total_num_tokens": total_num_tokens, + "requests_per_second": len(requests) / elapsed_time, + "tokens_per_second": total_num_tokens / elapsed_time, + } + with 
open(args.output_json, "w") as f: + json.dump(results, f, indent=4) + + +if __name__ == "__main__": + parser = FlexibleArgumentParser(description="Benchmark the throughput.") + parser.add_argument("--backend", + type=str, + choices=["vllm", "hf", "mii"], + default="vllm") + parser.add_argument("--dataset", + type=str, + default=None, + help="Path to the dataset.") + parser.add_argument("--input-len", + type=int, + default=None, + help="Input prompt length for each request") + parser.add_argument("--output-len", + type=int, + default=None, + help="Output length for each request. Overrides the " + "output length from the dataset.") + parser.add_argument("--model", type=str, default="facebook/opt-125m") + parser.add_argument("--tokenizer", type=str, default=None) + parser.add_argument('--quantization', + '-q', + choices=[*QUANTIZATION_METHODS, None], + default=None) + parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1) + parser.add_argument("--n", + type=int, + default=1, + help="Number of generated sequences per prompt.") + parser.add_argument("--use-beam-search", action="store_true") + parser.add_argument("--num-prompts", + type=int, + default=1000, + help="Number of prompts to process.") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--hf-max-batch-size", + type=int, + default=None, + help="Maximum batch size for HF backend.") + parser.add_argument('--trust-remote-code', + action='store_true', + help='trust remote code from huggingface') + parser.add_argument( + '--max-model-len', + type=int, + default=None, + help='Maximum length of a sequence (including prompt and output). ' + 'If None, will be derived from the model.') + parser.add_argument( + '--dtype', + type=str, + default='auto', + choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'], + help='data type for model weights and activations. ' + 'The "auto" option will use FP16 precision ' + 'for FP32 and FP16 models, and BF16 precision ' + 'for BF16 models.') + parser.add_argument('--gpu-memory-utilization', + type=float, + default=0.9, + help='the fraction of GPU memory to be used for ' + 'the model executor, which can range from 0 to 1.' + 'If unspecified, will use the default value of 0.9.') + parser.add_argument("--enforce-eager", + action="store_true", + help="enforce eager execution") + parser.add_argument( + '--kv-cache-dtype', + type=str, + choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'], + default="auto", + help='Data type for kv cache storage. If "auto", will use model ' + 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' + 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)') + parser.add_argument( + '--quantization-param-path', + type=str, + default=None, + help='Path to the JSON file containing the KV cache scaling factors. ' + 'This should generally be supplied, when KV cache dtype is FP8. ' + 'Otherwise, KV cache scaling factors default to 1.0, which may cause ' + 'accuracy issues. FP8_E5M2 (without scaling) is only supported on ' + 'cuda version greater than 11.8. 
On ROCm (AMD GPU), FP8_E4M3 is ' + 'instead supported for common inference criteria.') + parser.add_argument( + "--device", + type=str, + default="auto", + choices=["auto", "cuda", "cpu", "openvino", "tpu", "xpu"], + help='device type for vLLM execution, supporting CUDA, OpenVINO and ' + 'CPU.') + parser.add_argument( + "--num-scheduler-steps", + type=int, + default=1, + help="Maximum number of forward steps per scheduler call.") + parser.add_argument("--use-v2-block-manager", + action='store_true', + help="Enable block manager v2.") + parser.add_argument( + "--enable-prefix-caching", + action='store_true', + help="Enable automatic prefix caching for vLLM backend.") + parser.add_argument("--enable-chunked-prefill", + action='store_true', + help="enable chunked prefill for vLLM backend.") + parser.add_argument('--max-num-batched-tokens', + type=int, + default=None, + help='maximum number of batched tokens per ' + 'iteration') + parser.add_argument('--download-dir', + type=str, + default=None, + help='directory to download and load the weights, ' + 'default to the default cache dir of huggingface') + parser.add_argument( + '--output-json', + type=str, + default=None, + help='Path to save the throughput results in JSON format.') + parser.add_argument( + '--distributed-executor-backend', + choices=['ray', 'mp'], + default=None, + help='Backend to use for distributed serving. When more than 1 GPU ' + 'is used, will be automatically set to "ray" if installed ' + 'or "mp" (multiprocessing) otherwise.') + parser.add_argument( + '--load-format', + type=str, + default=EngineArgs.load_format, + choices=[ + 'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer', + 'bitsandbytes' + ], + help='The format of the model weights to load.\n\n' + '* "auto" will try to load the weights in the safetensors format ' + 'and fall back to the pytorch bin format if safetensors format ' + 'is not available.\n' + '* "pt" will load the weights in the pytorch bin format.\n' + '* "safetensors" will load the weights in the safetensors format.\n' + '* "npcache" will load the weights in pytorch format and store ' + 'a numpy cache to speed up the loading.\n' + '* "dummy" will initialize the weights with random values, ' + 'which is mainly for profiling.\n' + '* "tensorizer" will load the weights using tensorizer from ' + 'CoreWeave. 
See the Tensorize vLLM Model script in the Examples' + 'section for more information.\n' + '* "bitsandbytes" will load the weights using bitsandbytes ' + 'quantization.\n') + parser.add_argument( + "--disable-async-output-proc", + action='store_true', + default=False, + help="Disable async output processor for vLLM backend.") + args = parser.parse_args() + if args.tokenizer is None: + args.tokenizer = args.model + if args.dataset is None: + assert args.input_len is not None + assert args.output_len is not None + else: + assert args.input_len is None + + if args.backend == "vllm": + if args.hf_max_batch_size is not None: + raise ValueError("HF max batch size is only for HF backend.") + elif args.backend == "hf": + if args.hf_max_batch_size is None: + raise ValueError("HF max batch size is required for HF backend.") + if args.quantization is not None: + raise ValueError("Quantization is only for vLLM backend.") + elif args.backend == "mii": + if args.dtype != "auto": + raise ValueError("dtype must be auto for MII backend.") + if args.n != 1: + raise ValueError("n must be 1 for MII backend.") + if args.use_beam_search: + raise ValueError("Beam search is not supported for MII backend.") + if args.quantization is not None: + raise ValueError("Quantization is only for vLLM backend.") + if args.hf_max_batch_size is not None: + raise ValueError("HF max batch size is only for HF backend.") + if args.tokenizer != args.model: + raise ValueError("Tokenizer must be the same as the model for MII " + "backend.") + main(args) diff --git a/examples/openai_completion_client.py b/examples/openai_completion_client.py index 58519f978d340..0b77ed4d25584 100644 --- a/examples/openai_completion_client.py +++ b/examples/openai_completion_client.py @@ -14,14 +14,12 @@ model = models.data[0].id # Completion API -stream = False +stream = True completion = client.completions.create( model=model, prompt="A robot may not injure a human being", - echo=False, - n=2, stream=stream, - logprobs=3) + max_tokens=100) print("Completion results:") if stream: diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 7fe8053fffb7b..17b9ed40e41cf 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -761,11 +761,6 @@ def is_stopped(self) -> bool: def errored(self) -> bool: return self._errored_with is not None - @property - def limit_concurrency(self) -> Optional[int]: - """Maximum number of concurrently running requests.""" - return None - def set_errored(self, exc: Exception) -> None: self._errored_with = exc diff --git a/vllm/entrypoints/openai/rpc/__init__.py b/vllm/engine/multiprocessing/__init__.py similarity index 76% rename from vllm/entrypoints/openai/rpc/__init__.py rename to vllm/engine/multiprocessing/__init__.py index efc7e43afdcc9..cf566933801e9 100644 --- a/vllm/entrypoints/openai/rpc/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -10,13 +10,6 @@ # Success string used for RPC instructions. VLLM_RPC_SUCCESS_STR = "SUCCESS" -# Minimum value of ZMQ.SOCKET_LIMIT to run mp. -VLLM_RPC_SOCKET_LIMIT_CUTOFF = 2000 - -# HWM is set to Infinity. 
-VLLM_RPC_ZMQ_HWM = 0 - - @dataclass class RPCGenerateRequest: inputs: PromptInputs @@ -31,20 +24,20 @@ class RPCGenerateRequest: class RPCAbortRequest: request_id: str - class RPCUtilityRequest(Enum): + DO_LOG_STATS = 1 + CHECK_HEALTH = 2 + +class RPCStartupRequest(Enum): IS_SERVER_READY = 1 GET_MODEL_CONFIG = 2 GET_DECODING_CONFIG = 3 GET_PARALLEL_CONFIG = 4 GET_SCHEDULER_CONFIG = 5 GET_LORA_CONFIG = 6 - DO_LOG_STATS = 7 - IS_SERVER_HEALTHY = 8 - IS_TRACING_ENABLED = 9 - START_PROFILE = 10 - STOP_PROFILE = 11 + GET_TRACING_ENABLED = 7 + CLIENT_IS_READY = 8 RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest, - RPCUtilityRequest] + RPCUtilityRequest, RPCStartupRequest] diff --git a/vllm/engine/multiprocessing/mp_client.py b/vllm/engine/multiprocessing/mp_client.py new file mode 100644 index 0000000000000..c49a836755d4a --- /dev/null +++ b/vllm/engine/multiprocessing/mp_client.py @@ -0,0 +1,388 @@ +import asyncio +import pickle +from contextlib import contextmanager, suppress +from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping, + Optional, Union) + +import cloudpickle +import zmq +import zmq.asyncio +from zmq import Frame # type: ignore[attr-defined] +from zmq.asyncio import Socket + +from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, + ParallelConfig, SchedulerConfig) +# yapf: disable +from vllm.engine.multiprocessing import (RPC_REQUEST_TYPE, + VLLM_RPC_SUCCESS_STR, RPCAbortRequest, + RPCGenerateRequest, RPCStartupRequest, + RPCUtilityRequest) +# yapf: enable +from vllm.envs import VLLM_RPC_GET_DATA_TIMEOUT_MS +from vllm.inputs import PromptInputs +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.outputs import EmbeddingRequestOutput, RequestOutput +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import SamplingParams +from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs + +logger = init_logger(__name__) + + +class MPClientClosedError(Exception): + """Exception class raised when the client is used post-close. + + The client can be closed, which closes the ZMQ context. This normally + happens on server shutdown. In some cases, methods like abort and + do_log_stats will still be called and then try to open a socket, which + causes a ZMQError and creates a huge stack trace. + So, we throw this error such that we can suppress it. + """ + + +class MPEngineClient: + + def __init__(self, ipc_path: str): + self.context = zmq.asyncio.Context() + self._errored = False + + # Send RPCGenerateRequest to the MPLLMEngine. + self.input_socket: Socket = self.context.socket(zmq.constants.PUSH) + self.input_socket.connect(f"{ipc_path}_input_socket") + + # Receive streams of RequestOutput from the MPLLMEngine. + self.output_socket: Socket = self.context.socket(zmq.constants.PULL) + self.output_socket.connect(f"{ipc_path}_output_socket") + + # IPC path for ack of check_health requests. + self.health_socket: Socket = self.context.socket(zmq.constants.PULL) + self.health_socket.connect(f"{ipc_path}_health_socket") + + # IPC path for the data socket. + self.data_ipc_path = f"{ipc_path}_data_socket" + + # Stream for each individual request. 
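+        # generate() registers an asyncio.Queue per request_id in this dict;
+        # run_output_handler() routes RequestOutputs arriving on the output
+        # socket into the matching queue.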
+ self.output_queues: Dict[str, asyncio.Queue] = {} + self.output_handler = asyncio.create_task(self.run_output_handler()) + + @contextmanager + def get_data_socket(self) -> Iterator[Socket]: + socket = self.context.socket(zmq.constants.DEALER) + try: + socket.connect(self.data_ipc_path) + yield socket + finally: + socket.close(linger=0) + + async def run_output_handler(self): + # Stream lists of RequestOutput from MPLLMEngine. + while True: + message: Frame = await self.output_socket.recv(copy=False) + request_outputs: List[RequestOutput] = pickle.loads(message.buffer) + + for output in request_outputs: + if isinstance(output, tuple): + # Exception case + request_id, output = output + else: + request_id = output.request_id + + queue = self.output_queues.get(request_id) + if queue is not None: + queue.put_nowait(output) + + async def setup(self): + """Setup the client before it starts sending server requests.""" + + with self.get_data_socket() as socket: + + # Wait until server is ready. + await self._wait_for_server_rpc(socket) + + # Get the configs. + self.model_config = await self._get_model_config_rpc(socket) + self.decoding_config = await self._get_decoding_config_rpc(socket) + self.tracing_flag = await self._is_tracing_enabled_rpc(socket) + + # Create the tokenizer group. + # TODO: refactor OAI server to avoid needing this info. + self.tokenizer = init_tokenizer_from_configs( + model_config=self.model_config, + scheduler_config=(await + self._get_scheduler_config_rpc(socket)), + parallel_config=(await self._get_parallel_config_rpc(socket)), + enable_lora=bool(await self._get_lora_config_rpc(socket)), + ) + + # Notify MPLLMEngine client is ready to start sending requests. + await self._notify_ready(socket) + + def close(self): + """Destroy the ZeroMQ Context.""" + # Close all sockets associated with this context and + # then terminate the context. + self.output_socket.close() + self.input_socket.close() + self.health_socket.close() + self.context.destroy(linger=0) + + # TODO: cancel the handler task. + + async def _send_get_data_rpc_request(self, request: RPCStartupRequest, + expected_type: Any, + error_message: str, + socket: Socket) -> Any: + """Send an RPC request that is expecting data back.""" + + # Ping RPCServer with a request. + await socket.send_multipart((cloudpickle.dumps(request), ), copy=False) + + # Make sure the server responds + if await socket.poll(timeout=VLLM_RPC_GET_DATA_TIMEOUT_MS) == 0: + raise TimeoutError("Server didn't reply within " + f"{VLLM_RPC_GET_DATA_TIMEOUT_MS} ms") + + # Await the data from the Server. + frame = await socket.recv(copy=False) + data = pickle.loads(frame.buffer) + + if isinstance(data, Exception): + # Re-raise exceptions returned by the server + raise data + + if not isinstance(data, expected_type): + # LoRAConfig can be None. + if expected_type == LoRAConfig and data is None: + pass + elif isinstance(data, Exception): + logger.error(error_message) + raise data + else: + raise ValueError(error_message) + + return data + + async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE, + socket: Socket): + """Send one-way RPC request to trigger an action.""" + + await socket.send_multipart((cloudpickle.dumps(request), )) + + # TODO: is there a way to ack this if we are using the input_socket? + # I don't think so, b/c we are using PUSH/PULL w/out identities so no + # way to preserve order. 
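+        # Where an ack is required, it is read from a separate socket via
+        # _ack_one_way_rpc_request() (e.g. the health socket in
+        # check_health()).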
+ + async def _ack_one_way_rpc_request( + self, + timeout: int, + expected_str: str, + error_message: str, + socket: Socket, + ): + if await socket.poll(timeout=timeout) == 0: + raise TimeoutError(f"MPLLMEngine didn't reply within {timeout}ms") + + frame = await socket.recv(copy=False) + response = pickle.loads(frame.buffer) + + if not isinstance(response, str) or response != expected_str: + if isinstance(response, Exception): + logger.error(error_message) + raise response + raise ValueError(error_message) + + async def get_tokenizer(self, lora_request: LoRARequest): + return await self.tokenizer.get_lora_tokenizer_async(lora_request) + + async def get_decoding_config(self) -> DecodingConfig: + return self.decoding_config + + async def get_model_config(self) -> ModelConfig: + return self.model_config + + async def is_tracing_enabled(self) -> bool: + return self.tracing_flag + + async def _wait_for_server_rpc(self, socket: Socket): + """Wait for the RPCServer to start up.""" + + # Readiness probe. + request = RPCStartupRequest.IS_SERVER_READY + await socket.send_multipart((cloudpickle.dumps(request), )) + + # Raises TimeoutError if not ack, causing a retry. + await self._ack_one_way_rpc_request( + expected_str=VLLM_RPC_SUCCESS_STR, + timeout=VLLM_RPC_GET_DATA_TIMEOUT_MS, + error_message="Unable to start RPC Server", + socket=socket) + + async def _notify_ready(self, socket: Socket): + """Get the RPCServer that the RPCClient is ready""" + + await self._send_one_way_rpc_request( + request=RPCStartupRequest.CLIENT_IS_READY, socket=socket) + + async def _get_model_config_rpc(self, socket: Socket) -> ModelConfig: + """Get the ModelConfig object from the RPC Server""" + + return await self._send_get_data_rpc_request( + RPCStartupRequest.GET_MODEL_CONFIG, + expected_type=ModelConfig, + error_message="Could not get ModelConfig from RPC Server", + socket=socket) + + async def _get_decoding_config_rpc(self, socket: Socket) -> DecodingConfig: + """Get DecodingConfig from the RPCServer""" + + return await self._send_get_data_rpc_request( + RPCStartupRequest.GET_DECODING_CONFIG, + expected_type=DecodingConfig, + error_message="Could not get DecodingConfig from RPC Server", + socket=socket) + + async def _get_parallel_config_rpc(self, socket: Socket) -> ParallelConfig: + """Get ParallelConfig from the RPCServer""" + + return await self._send_get_data_rpc_request( + RPCStartupRequest.GET_PARALLEL_CONFIG, + expected_type=ParallelConfig, + error_message="Could not get ParallelConfig from RPC Server", + socket=socket) + + async def _get_scheduler_config_rpc(self, + socket: Socket) -> SchedulerConfig: + """Get SchedulerConfig from the RPCServer""" + + return await self._send_get_data_rpc_request( + RPCStartupRequest.GET_SCHEDULER_CONFIG, + expected_type=SchedulerConfig, + error_message="Could not get SchedulerConfig from RPC Server", + socket=socket) + + async def _get_lora_config_rpc(self, socket: Socket) -> LoRAConfig: + """Get LoRAConfig from the RPCServer""" + + return await self._send_get_data_rpc_request( + RPCStartupRequest.GET_LORA_CONFIG, + expected_type=LoRAConfig, + error_message="Could not get LoRAConfig from RPC Server", + socket=socket) + + async def _is_tracing_enabled_rpc(self, socket: Socket) -> bool: + """Get is_tracing_enabled flag from the RPCServer""" + + return await self._send_get_data_rpc_request( + RPCStartupRequest.GET_TRACING_ENABLED, + expected_type=bool, + error_message="Could not get is_tracing_enabled from RPC Server", + socket=socket) + + async def abort(self, request_id: str): 
+ """Send an ABORT_REQUEST signal to the RPC Server""" + + # Suppress timeouts and MPClientClosedError. + # In cases where the server is busy processing requests and a very + # large volume of abort requests arrive, it is likely that the server + # will not be able to ack all of them in time. We have seen this when + # we abort 20k requests at once while another 2k are processing- many + # of them time out, but we see the server successfully abort all of the + # requests. + # In this case we assume that the server has received or will receive + # these abort requests, and ignore the timeout. This prevents a massive + # wall of `TimeoutError` stack traces. + with suppress(MPClientClosedError, TimeoutError): + await self._send_one_way_rpc_request( + request=RPCAbortRequest(request_id), socket=self.input_socket) + + async def do_log_stats(self): + """Send a DO_LOG_STATS signal to the RPC Server""" + with suppress(MPClientClosedError): + await self._send_one_way_rpc_request( + request=RPCUtilityRequest.DO_LOG_STATS, + socket=self.input_socket) + + @property + def is_running(self) -> bool: + return not self._errored + + @property + def is_stopped(self) -> bool: + return self._errored + + @property + def errored(self) -> bool: + return self._errored + + async def generate( + self, + inputs: PromptInputs, + sampling_params: SamplingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None + ) -> AsyncGenerator[RequestOutput, None]: + """Send an RPCGenerateRequest to the RPCServer and stream responses.""" + + queue: asyncio.Queue[Union[RequestOutput, + BaseException]] = asyncio.Queue() + self.output_queues[request_id] = queue + finished = False + try: + + # Send RPCGenerateRequest to the RPCServer. + await self.input_socket.send_multipart((cloudpickle.dumps( + RPCGenerateRequest( + inputs=inputs, + sampling_params=sampling_params, + request_id=request_id, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request)), )) + + while not finished: + request_output = await queue.get() + if isinstance(request_output, BaseException): + finished = True + # On exception, check if the server is still healthy + # possibly setting the `errored` property. + if not self._errored: + try: + await self.check_health() + except Exception as e: + self._errored = True + logger.exception(repr(e)) + raise request_output + + finished = request_output.finished + yield request_output + + finally: + self.output_queues.pop(request_id) + # Request was canceled by the client. + if not finished and not self._errored: + await self.abort(request_id) + + async def check_health(self) -> None: + """Raise if unhealthy""" + + await self._send_one_way_rpc_request( + request=RPCUtilityRequest.CHECK_HEALTH, socket=self.input_socket) + + # Await acknowledgement from MPLLMEngine. + # Note: these requests are not necessarily serial. + # I.e. if two clients A, B send CHECK_HEALTH, the + # response to A could actually be the call send by B. + # TODO: is this bad? 
+ await self._ack_one_way_rpc_request( + timeout=VLLM_RPC_GET_DATA_TIMEOUT_MS, + expected_str=VLLM_RPC_SUCCESS_STR, + error_message="Check health timeout.", + socket=self.health_socket) + + async def encode(self, *args, + **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]: + raise NotImplementedError( + "Embeddings not supported with multiprocessing backend") diff --git a/vllm/engine/multiprocessing/mp_llm_engine.py b/vllm/engine/multiprocessing/mp_llm_engine.py new file mode 100644 index 0000000000000..f839a272e40f7 --- /dev/null +++ b/vllm/engine/multiprocessing/mp_llm_engine.py @@ -0,0 +1,259 @@ +import pickle +from contextlib import contextmanager +from typing import Iterator, List, Type, Union + +import cloudpickle +import ray +import zmq + +from vllm import AsyncEngineArgs, AsyncLLMEngine, LLMEngine +from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, + ParallelConfig, SchedulerConfig) +from vllm.engine.multiprocessing import (VLLM_RPC_SUCCESS_STR, RPCAbortRequest, + RPCGenerateRequest, RPCStartupRequest, + RPCUtilityRequest) +from vllm.logger import init_logger +from vllm.outputs import RequestOutput +from vllm.usage.usage_lib import UsageContext + +CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig, + SchedulerConfig, LoRAConfig] + +logger = init_logger(__name__) + +POLLING_TIMEOUT_MS = 10000 + + +class MPLLMEngine: + """A multiprocessing wrapper for :class:`LLMEngine`. + + This class is used to wrap the :class:`LLMEngine` class to enable use + in asynchronous manner. It runs a background loop and uses zeromq to + receive new requests and stream outputs incrementally to another process. + + The :class:`LLMEngine` is kicked off when a new RPCGenerateRequest + is received by the input_socket. + + The self.engine_loop checks the input_socket for new requests, + adds them to the LLMEngine if there are any, calls the internal + :class:`LLMEngine.step()` and sends the RequestOutputs back over + the output_socket. + + Args: + worker_use_ray: Whether to use Ray for model workers. Required for + distributed execution. Should be the same as + `parallel_config.worker_use_ray`. + engine_use_ray: Whether to make LLMEngine a Ray actor. If so, the + async frontend will be executed in a separate process as the + model workers. + async_engine_args: AsyncLLMEngine args + log_requests: Whether to log the requests. + """ + + _engine_class: Type[LLMEngine] = LLMEngine + + def __init__(self, + worker_use_ray: bool, + engine_use_ray: bool, + *args, + ipc_path: str, + log_requests: bool = True, + **kwargs) -> None: + + if engine_use_ray: + raise NotImplementedError("Not yet supported.") + + self.worker_use_ray = worker_use_ray + self.engine_use_ray = engine_use_ray + self.log_requests = log_requests + self.engine = self._init_engine(*args, **kwargs) + + self.ctx = zmq.Context() # type: ignore[attr-defined] + + # Receive input from the client. + self.input_socket = self.ctx.socket(zmq.constants.PULL) + self.input_socket.bind(f"{ipc_path}_input_socket") + + # Send output stream back to client. + self.output_socket = self.ctx.socket(zmq.constants.PUSH) + self.output_socket.bind(f"{ipc_path}_output_socket") + + # Send health status back to client. + self.health_socket = self.ctx.socket(zmq.constants.PUSH) + self.health_socket.bind(f"{ipc_path}_health_socket") + + # IPC path for the data socket. 
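+        # The data socket itself is a ROUTER created lazily by
+        # make_data_socket() and used only for the startup handshake
+        # (see run_startup_loop()).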
+ self.data_ipc_path = f"{ipc_path}_data_socket" + + @classmethod + def from_engine_args(cls, engine_args: AsyncEngineArgs, + usage_context: UsageContext, ipc_path: str): + """Creates an MPLLMEngine from the engine arguments.""" + + engine_config = engine_args.create_engine_config() + + if engine_args.engine_use_ray: + from vllm.executor import ray_utils + ray_utils.assert_ray_available() + + # TODO: better abstraction? + executor_class = AsyncLLMEngine._get_executor_cls(engine_config) + + return cls( + executor_class.uses_ray, + engine_args.engine_use_ray, + **engine_config.to_dict(), + executor_class=executor_class, + log_requests=not engine_args.disable_log_requests, + log_stats=not engine_args.disable_log_stats, + usage_context=usage_context, + ipc_path=ipc_path, + ) + + def cleanup(self): + """Cleanup zeromq state on shutdown.""" + self.input_socket.close() + self.output_socket.close() + self.ctx.destroy(linger=0) + del self.engine + + def _init_engine(self, *args, + **kwargs) -> Union[LLMEngine, "ray.ObjectRef"]: + """Initialize the LLMEngine""" + + if not self.engine_use_ray: + engine_class = self._engine_class + elif self.worker_use_ray: + engine_class = ray.remote(num_cpus=0)(self._engine_class).remote + else: + raise NotImplementedError("Not supported yet!") + return engine_class(*args, **kwargs) + + def run_background_loop(self): + """Entrypoint that kicks off the background processing loop.""" + + # Allow RPCClient to query data in startup phase. + self.run_startup_loop() + + # Kick off core processing loop. + self.run_engine_loop() + + @contextmanager + def make_data_socket( + self) -> Iterator[zmq.Socket]: # type: ignore[name-defined] + socket = self.ctx.socket(zmq.constants.ROUTER) + try: + socket.bind(self.data_ipc_path) + yield socket + finally: + socket.close(linger=0) + + def run_startup_loop(self) -> None: + """Loop over startup RPCStatupRequest from RPCClient.""" + + with self.make_data_socket() as socket: + + # Loop until the RPCClient has all the data it needs. + client_is_ready = False + while not client_is_ready: + try: + identity, message = socket.recv_multipart(copy=False) + request: RPCStartupRequest = pickle.loads(message.buffer) + + # Handle the query from the Client. + if request == RPCStartupRequest.GET_MODEL_CONFIG: + response = self.engine.get_model_config() + elif request == RPCStartupRequest.GET_DECODING_CONFIG: + response = self.engine.get_decoding_config() + elif request == RPCStartupRequest.GET_LORA_CONFIG: + response = self.engine.get_lora_config() + elif request == RPCStartupRequest.GET_SCHEDULER_CONFIG: + response = self.engine.get_scheduler_config() + elif request == RPCStartupRequest.GET_PARALLEL_CONFIG: + response = self.engine.get_parallel_config() + elif request == RPCStartupRequest.GET_TRACING_ENABLED: + response = self.engine.is_tracing_enabled() + elif request == RPCStartupRequest.IS_SERVER_READY: + response = VLLM_RPC_SUCCESS_STR + elif request == RPCStartupRequest.CLIENT_IS_READY: + response = VLLM_RPC_SUCCESS_STR + # Breakout of loop once client is ready. + client_is_ready = True + + socket.send_multipart((identity, pickle.dumps(response)), + copy=False) + + except Exception as e: + socket.send_multipart((identity, pickle.dumps(e)), + copy=False) + + def run_engine_loop(self) -> None: + while True: + # Block until there is a new request. + if not self.engine.has_unfinished_requests(): + self.wait_for_new_input() + + # Handle any new input from the input socket. + self.maybe_handle_new_input() + + # Engine step. 
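+            # step() runs a single engine iteration (scheduling plus model
+            # execution) and returns the RequestOutputs produced by it.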
+ request_outputs = self.engine.step() + + # Stream results to output socket. + self.stream_outputs(request_outputs) + + def wait_for_new_input(self): + while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: + logger.debug("Waiting for new request.") + + def stream_outputs(self, request_outputs: List[RequestOutput]): + self.output_socket.send_multipart((pickle.dumps(request_outputs), ), + copy=False) + + def ack_check_health(self): + self.health_socket.send_multipart( + (pickle.dumps(VLLM_RPC_SUCCESS_STR), ), copy=False) + + def maybe_handle_new_input(self): + """Handle new input with non-blocking IO""" + while self.input_socket.poll(timeout=0) != 0: + message = self.input_socket.recv(copy=False) + request = cloudpickle.loads(message.buffer) + + if isinstance(request, RPCGenerateRequest): + self._handle_generate_request(request) + elif isinstance(request, RPCAbortRequest): + self._handle_abort_request(request) + elif isinstance(request, RPCUtilityRequest): + self._handle_utility_request(request) + else: + raise ValueError(f"Unknown RPCRequest: {request}") + + def _handle_generate_request(self, request: RPCGenerateRequest): + self.engine.add_request( + request_id=request.request_id, + inputs=request.inputs, + params=request.sampling_params, + lora_request=request.lora_request, + trace_headers=request.trace_headers, + prompt_adapter_request=request.prompt_adapter_request, + ) + + def _handle_abort_request(self, request: RPCAbortRequest): + self.engine.abort_request([request.request_id]) + + def _handle_utility_request(self, request: RPCUtilityRequest): + if request == RPCUtilityRequest.DO_LOG_STATS: + self.engine.do_log_stats() + elif request == RPCUtilityRequest.CHECK_HEALTH: + self.engine.check_health() + self.ack_check_health() + + +def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext, + ipc_path: str): + engine = MPLLMEngine.from_engine_args(engine_args=engine_args, + usage_context=usage_context, + ipc_path=ipc_path) + + engine.run_background_loop() diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 34ae79f5fa8df..de6314d532193 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -29,10 +29,6 @@ def is_stopped(self) -> bool: def errored(self) -> bool: ... - @property - def limit_concurrency(self) -> Optional[int]: - """Maximum number of concurrently running requests.""" - def generate( self, inputs: PromptInputs, diff --git a/vllm/entrypoints/launcher.py b/vllm/entrypoints/launcher.py index 3598872b65bb0..f4a9c61a431c1 100644 --- a/vllm/entrypoints/launcher.py +++ b/vllm/entrypoints/launcher.py @@ -27,15 +27,6 @@ async def serve_http(app: FastAPI, engine: AsyncEngineClient, logger.info("Route: %s, Methods: %s", path, ', '.join(methods)) - # Set concurrency limits in uvicorn if running in multiprocessing mode - # since zmq has maximum socket limit of zmq.constants.SOCKET_LIMIT (65536). - if engine.limit_concurrency is not None: - logger.info( - "Launching Uvicorn with --limit_concurrency %s. 
To avoid this " - "limit at the expense of performance run with " - "--disable-frontend-multiprocessing", engine.limit_concurrency) - uvicorn_kwargs["limit_concurrency"] = engine.limit_concurrency - config = uvicorn.Config(app, **uvicorn_kwargs) server = uvicorn.Server(config) _add_shutdown_handlers(app, server, engine) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 7632e8aa5e32e..2e09954800d02 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -21,6 +21,9 @@ from vllm.config import ModelConfig from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine +# yapf: enable +from vllm.engine.multiprocessing.mp_client import MPEngineClient +from vllm.engine.multiprocessing.mp_llm_engine import run_mp_engine from vllm.engine.protocol import AsyncEngineClient from vllm.entrypoints.launcher import serve_http from vllm.entrypoints.logger import RequestLogger @@ -37,9 +40,6 @@ EmbeddingResponse, ErrorResponse, TokenizeRequest, TokenizeResponse) -# yapf: enable -from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient -from vllm.entrypoints.openai.rpc.server import run_rpc_server from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding @@ -82,7 +82,7 @@ async def lifespan(app: FastAPI): async def _force_log(): while True: - await asyncio.sleep(10) + await asyncio.sleep(1.) await async_engine_client.do_log_stats() if not engine_args.disable_log_stats: @@ -156,56 +156,56 @@ async def build_async_engine_client_from_engine_args( "and vLLM will properly handle cleanup.") # Select random path for IPC. - rpc_path = get_open_zmq_ipc_path() - logger.info("Multiprocessing frontend to use %s for RPC Path.", - rpc_path) + ipc_path = get_open_zmq_ipc_path() + logger.info("Multiprocessing frontend to use %s for IPC Path.", + ipc_path) # Build RPCClient, which conforms to AsyncEngineClient Protocol. # NOTE: Actually, this is not true yet. We still need to support # embedding models via RPC (see TODO above) - rpc_client = AsyncEngineRPCClient(rpc_path) + mp_engine_client = MPEngineClient(ipc_path) - # Start RPCServer in separate process (holds the AsyncLLMEngine). - context = multiprocessing.get_context("spawn") + # Start RPCServer in separate process (holds the LLMEngine). 
# the current process might have CUDA context, # so we need to spawn a new process - rpc_server_process = context.Process( - target=run_rpc_server, - args=(engine_args, UsageContext.OPENAI_API_SERVER, rpc_path)) - rpc_server_process.start() - logger.info("Started engine process with PID %d", - rpc_server_process.pid) + context = multiprocessing.get_context("spawn") + + engine_process = context.Process(target=run_mp_engine, + args=(engine_args, + UsageContext.OPENAI_API_SERVER, + ipc_path)) + engine_process.start() + logger.info("Started engine process with PID %d", engine_process.pid) try: while True: try: - await rpc_client.setup() + await mp_engine_client.setup() break except TimeoutError: - if not rpc_server_process.is_alive(): - logger.error( - "RPCServer process died before responding " - "to readiness probe") + if not engine_process.is_alive(): + logger.error("Engine process died before responding " + "to readiness probe") yield None return - yield rpc_client # type: ignore[misc] + yield mp_engine_client # type: ignore[misc] finally: # Ensure rpc server process was terminated - rpc_server_process.terminate() + engine_process.terminate() # Close all open connections to the backend - rpc_client.close() + mp_engine_client.close() # Wait for server process to join - rpc_server_process.join() + engine_process.join() # Lazy import for prometheus multiprocessing. # We need to set PROMETHEUS_MULTIPROC_DIR environment variable # before prometheus_client is imported. # See https://prometheus.github.io/client_python/multiprocess/ from prometheus_client import multiprocess - multiprocess.mark_process_dead(rpc_server_process.pid) + multiprocess.mark_process_dead(engine_process.pid) router = APIRouter() diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py deleted file mode 100644 index 9b88db746be5c..0000000000000 --- a/vllm/entrypoints/openai/rpc/client.py +++ /dev/null @@ -1,451 +0,0 @@ -import asyncio -import pickle -from contextlib import contextmanager, suppress -from typing import Any, AsyncGenerator, Iterator, Mapping, Optional -from uuid import uuid4 - -import cloudpickle -import zmq -import zmq.asyncio -from zmq import Frame # type: ignore[attr-defined] -from zmq.asyncio import Socket - -from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig) -# yapf: disable -from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE, - VLLM_RPC_SOCKET_LIMIT_CUTOFF, - VLLM_RPC_SUCCESS_STR, - VLLM_RPC_ZMQ_HWM, RPCAbortRequest, - RPCGenerateRequest, RPCUtilityRequest) -# yapf: enable -from vllm.envs import VLLM_RPC_GET_DATA_TIMEOUT_MS -from vllm.inputs import PromptInputs -from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.outputs import EmbeddingRequestOutput, RequestOutput -from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import SamplingParams -from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs - -logger = init_logger(__name__) - -# Path used for inprocess proxy. -INPROC_PROXY_PATH = f"inproc://{uuid4()}" - - -class RPCClientClosedError(Exception): - """Exception class raised when the client is used post-close. - - The client can be closed, which closes the ZMQ context. This normally - happens on server shutdown. In some cases, methods like abort and - do_log_stats will still be called and then try to open a socket, which - causes a ZMQError and creates a huge stack trace. 
- So, we throw this error such that we can suppress it. - """ - - -class AsyncEngineRPCClient: - """ - RPCClient that connects to the RPCServer wrapping AsyncLLMEngine. - - The overall design mirrors the Asynchronous Client Server Pattern - https://zguide.zeromq.org/docs/chapter3/#The-Asynchronous-Client-Server-Pattern - - On startup, the RPCClient: - - makes DEALER socket (to_rpc_server) that connects to the RPCServer - via ipc, which uses unix sockets under the hood - (https://libzmq.readthedocs.io/en/zeromq4-1/zmq_ipc.html) - - makes ROUTER socket (from_api_server) that binds to a random - inproc address, which uses memory under the hood - (https://libzmq.readthedocs.io/en/zeromq3-x/zmq_inproc.html) - - runs a proxy in a background asyncio task between - from_api_server (ROUTER, inproc) and to_rpc_server (DEALER ipc, ) - - Each request handled by the asyncio api_server calls generate(): - - make a DEALER socket that connects to from_api_server via inproc - - send a RCPGenerateRequest to the inproc socket - - background proxy forwards the request from inproc -> ipc - - RPCServer responds to the request one token at a time over ipc - - background proxy forwards the response from ipc -> inproc - - The connection looks like this: - DEALER <- inproc -> [ ROUTER | DEALER ] <- ipc -> DEALER - - Message routing is performed via identities that are managed by the - ROUTER socket. ROUTER sockets track every connection it has and - tells the caller about these. The way it tells the caller is to stick - the connection identity in front of each message received. When we - send the message via a ROUTER, we first send an identity frame. - See https://zguide.zeromq.org/docs/chapter3/#The-Extended-Reply-Envelope - for more details on connection identities. - - This proxy design enables us to use a single unix socket, which - improves performance by avoiding syscalls (~5%) and avoids resource limits - such as ulimit, which defaults to 1024 on ubuntu. - - Note: we run set_hwm(0) on each socket, which sets the HWM to inf, - which is required to avoid dropping messages under high load. - This is generally not advisable. However, since we are in control - of both sides of the connection + failure on either side is - catastrophic to the overall system health and memory profiling - suggests limited memory overhead relative to asyncio, we will - proceed for now. - - See https://zguide.zeromq.org/docs/chapter2/#High-Water-Marks - for more details on high water marks. - """ - - def __init__(self, rpc_path: str): - self.context = zmq.asyncio.Context() - self._data_timeout = VLLM_RPC_GET_DATA_TIMEOUT_MS - self._errored = False - - # Maximum number of sockets that can be opened (typically 65536). - # ZMQ_SOCKET_LIMIT (http://api.zeromq.org/4-2:zmq-ctx-get) - socket_limit = self.context.get(zmq.constants.SOCKET_LIMIT) - assert isinstance(socket_limit, int) - if socket_limit < VLLM_RPC_SOCKET_LIMIT_CUTOFF: - raise ValueError( - f"Found zmq.constants.SOCKET_LIMIT={socket_limit}, which caps " - "the number of concurrent requests vLLM can process. Launch " - "vLLM with --disable-frontend-multiprocessing and open a " - "GitHub issue so we can investigate.") - - # We only have 1 ipc connection that uses unix sockets, so - # safe to set MAX_SOCKETS to the zmq SOCKET_LIMIT (i.e. will - # not run into ulimit issues) - self.context.set(zmq.constants.MAX_SOCKETS, socket_limit) - - # IPC connection to RPC Server (uses unix sockets). 
- self.to_rpc_server: Socket = self.context.socket(zmq.constants.DEALER) - self.to_rpc_server.set_hwm(VLLM_RPC_ZMQ_HWM) - self.to_rpc_server.bind(rpc_path) - - # In process proxy to RPC Server (uses memory-based messaging). - self.from_api_server: Socket = self.context.socket( - zmq.constants.ROUTER) - self.from_api_server.set_hwm(VLLM_RPC_ZMQ_HWM) - self.from_api_server.bind(INPROC_PROXY_PATH) - - # Asyncio background task for the proxy. - self.proxy_in_task = asyncio.create_task( - self.run_proxy(self.from_api_server, self.to_rpc_server)) - self.proxy_out_task = asyncio.create_task( - self.run_proxy(self.to_rpc_server, self.from_api_server)) - - # Since we open 1 inproc socket per request, we have a hard cap on - # the number of requests that can run in vLLM w. frontend - # mulitprocessing. This value is used uvicorn to launch - # with --limit-concurrency to return 503 when server is overloaded. - # We need 2 sockets per request - 2: - # 1 for generate(), 1 for abort(), do_log_stats(), check_health() - self.limit_concurrency = socket_limit // 2 - 2 - - async def run_proxy(self, socket_from: Socket, socket_to: Socket): - """Background task that runs a proxy""" - while True: - frames = await socket_from.recv_multipart(copy=False) - await socket_to.send_multipart(frames, copy=False) - - async def setup(self): - """Setup the client before it starts sending server requests.""" - - # Wait until server is ready. - await self._wait_for_server_rpc() - - # Get the configs. - self.model_config = await self._get_model_config_rpc() - self.decoding_config = await self._get_decoding_config_rpc() - self.tracing_flag = await self._is_tracing_enabled_rpc() - - # Create the tokenizer group. - # TODO: refactor OAI server to avoid needing this info. - self.tokenizer = init_tokenizer_from_configs( - model_config=self.model_config, - scheduler_config=(await self._get_scheduler_config_rpc()), - parallel_config=(await self._get_parallel_config_rpc()), - enable_lora=bool(await self._get_lora_config_rpc()), - ) - - def close(self): - """Destroy the ZeroMQ Context.""" - # Close all sockets associated with this context and - # then terminate the context. - self.from_api_server.close() - self.to_rpc_server.close() - self.context.destroy() - - @contextmanager - def to_proxy_socket(self) -> Iterator[Socket]: - # Connect to the RPCServer via the proxy. - - # Raise a sensible error if the client was already closed. - # This can happen if a server shutdown is triggered but some coroutines - # are still running requests. - # There should not be a race condition with this check because we don't - # yield to the event loop between here and opening the socket. - if self.context.closed: - raise RPCClientClosedError("The ZMQ client has already shut down") - - # Note that we use DEALER to enable asynchronous communication - # to enable streaming. - socket = self.context.socket(zmq.constants.DEALER) - socket.set_hwm(VLLM_RPC_ZMQ_HWM) - try: - socket.connect(INPROC_PROXY_PATH) - yield socket - finally: - socket.close(linger=0) - - async def _send_get_data_rpc_request(self, request: RPCUtilityRequest, - expected_type: Any, - error_message: str) -> Any: - """Send an RPC request that is expecting data back.""" - - with self.to_proxy_socket() as socket: - # Ping RPCServer with a request. 
-            await socket.send_multipart((cloudpickle.dumps(request), ),
-                                        copy=False)
-
-            # Make sure the server responds
-            if await socket.poll(timeout=self._data_timeout) == 0:
-                raise TimeoutError("Server didn't reply within "
-                                   f"{self._data_timeout} ms")
-
-            # Await the data from the Server.
-            frame = await socket.recv(copy=False)
-            assert isinstance(frame, Frame)
-            data = pickle.loads(frame.buffer)
-
-        if isinstance(data, Exception):
-            # Re-raise exceptions returned by the server
-            raise data
-
-        if not isinstance(data, expected_type):
-            # LoRAConfig can be None.
-            if expected_type == LoRAConfig and data is None:
-                pass
-            elif isinstance(data, Exception):
-                logger.error(error_message)
-                raise data
-            else:
-                raise ValueError(error_message)
-
-        return data
-
-    async def _send_one_way_rpc_request(self,
-                                        request: RPC_REQUEST_TYPE,
-                                        error_message: str,
-                                        socket: Optional[Socket] = None):
-        """Send one-way RPC request to trigger an action."""
-
-        async def do_rpc_call(socket: Socket, request: RPC_REQUEST_TYPE):
-
-            await socket.send_multipart((cloudpickle.dumps(request), ))
-
-            if await socket.poll(timeout=self._data_timeout) == 0:
-                raise TimeoutError("Server didn't reply within "
-                                   f"{self._data_timeout} ms")
-
-            frame = await socket.recv(copy=False)
-            assert isinstance(frame, Frame)
-            return pickle.loads(frame.buffer)
-
-        # Make a new socket connection.
-        if socket is None:
-            with self.to_proxy_socket() as socket:
-                response = await do_rpc_call(socket, request)
-
-        # Use existing socket connection.
-        else:
-            response = await do_rpc_call(socket, request)
-
-        if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR:
-            if isinstance(response, Exception):
-                logger.error(error_message)
-                raise response
-            raise ValueError(error_message)
-
-    async def get_tokenizer(self, lora_request: LoRARequest):
-        return await self.tokenizer.get_lora_tokenizer_async(lora_request)
-
-    async def get_decoding_config(self) -> DecodingConfig:
-        return self.decoding_config
-
-    async def get_model_config(self) -> ModelConfig:
-        return self.model_config
-
-    async def is_tracing_enabled(self) -> bool:
-        return self.tracing_flag
-
-    async def _wait_for_server_rpc(self):
-        """Wait for the RPCServer to start up."""
-
-        await self._send_one_way_rpc_request(
-            request=RPCUtilityRequest.IS_SERVER_READY,
-            error_message="Unable to start RPC Server")
-
-    async def _get_model_config_rpc(self) -> ModelConfig:
-        """Get the ModelConfig object from the RPC Server"""
-
-        return await self._send_get_data_rpc_request(
-            RPCUtilityRequest.GET_MODEL_CONFIG,
-            expected_type=ModelConfig,
-            error_message="Could not get ModelConfig from RPC Server")
-
-    async def _get_decoding_config_rpc(self) -> DecodingConfig:
-        """Get DecodingConfig from the RPCServer"""
-
-        return await self._send_get_data_rpc_request(
-            RPCUtilityRequest.GET_DECODING_CONFIG,
-            expected_type=DecodingConfig,
-            error_message="Could not get DecodingConfig from RPC Server")
-
-    async def _get_parallel_config_rpc(self) -> ParallelConfig:
-        """Get ParallelConfig from the RPCServer"""
-
-        return await self._send_get_data_rpc_request(
-            RPCUtilityRequest.GET_PARALLEL_CONFIG,
-            expected_type=ParallelConfig,
-            error_message="Could not get ParallelConfig from RPC Server")
-
-    async def _get_scheduler_config_rpc(self) -> SchedulerConfig:
-        """Get SchedulerConfig from the RPCServer"""
-
-        return await self._send_get_data_rpc_request(
-            RPCUtilityRequest.GET_SCHEDULER_CONFIG,
-            expected_type=SchedulerConfig,
-            error_message="Could not get SchedulerConfig from RPC Server")
-
-    async def _get_lora_config_rpc(self) -> LoRAConfig:
-        """Get LoRAConfig from the RPCServer"""
-
-        return await self._send_get_data_rpc_request(
-            RPCUtilityRequest.GET_LORA_CONFIG,
-            expected_type=LoRAConfig,
-            error_message="Could not get LoRAConfig from RPC Server")
-
-    async def _is_tracing_enabled_rpc(self) -> bool:
-        """Get is_tracing_enabled flag from the RPCServer"""
-
-        return await self._send_get_data_rpc_request(
-            RPCUtilityRequest.IS_TRACING_ENABLED,
-            expected_type=bool,
-            error_message="Could not get is_tracing_enabled from RPC Server")
-
-    async def abort(self, request_id: str):
-        """Send an ABORT_REQUEST signal to the RPC Server"""
-
-        # Suppress timeouts as well.
-        # In cases where the server is busy processing requests and a very
-        # large volume of abort requests arrive, it is likely that the server
-        # will not be able to ack all of them in time. We have seen this when
-        # we abort 20k requests at once while another 2k are processing- many
-        # of them time out, but we see the server successfully abort all of the
-        # requests.
-        # In this case we assume that the server has received or will receive
-        # these abort requests, and ignore the timeout. This prevents a massive
-        # wall of `TimeoutError` stack traces.
-        with suppress(RPCClientClosedError, TimeoutError):
-            await self._send_one_way_rpc_request(
-                request=RPCAbortRequest(request_id),
-                error_message=f"RPCAbortRequest {request_id} failed")
-
-    async def do_log_stats(self):
-        """Send a DO_LOG_STATS signal to the RPC Server"""
-        with suppress(RPCClientClosedError):
-            await self._send_one_way_rpc_request(
-                request=RPCUtilityRequest.DO_LOG_STATS,
-                error_message="RPCRequest DO_LOG_STATS failed.")
-
-    @property
-    def is_running(self) -> bool:
-        return not self._errored
-
-    @property
-    def is_stopped(self) -> bool:
-        return self._errored
-
-    @property
-    def errored(self) -> bool:
-        return self._errored
-
-    async def generate(
-        self,
-        inputs: PromptInputs,
-        sampling_params: SamplingParams,
-        request_id: str,
-        lora_request: Optional[LoRARequest] = None,
-        trace_headers: Optional[Mapping[str, str]] = None,
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None
-    ) -> AsyncGenerator[RequestOutput, None]:
-        """Send an RPCGenerateRequest to the RPCServer and stream responses."""
-
-        finished = False
-        try:
-            with self.to_proxy_socket() as socket:
-                # Send RPCGenerateRequest to the RPCServer.
-                await socket.send_multipart((cloudpickle.dumps(
-                    RPCGenerateRequest(
-                        inputs=inputs,
-                        sampling_params=sampling_params,
-                        request_id=request_id,
-                        lora_request=lora_request,
-                        trace_headers=trace_headers,
-                        prompt_adapter_request=prompt_adapter_request)), ))
-
-                # Stream back the results from the RPC Server.
-                while not finished:
-                    message = await socket.recv(copy=False)
-                    assert isinstance(message, Frame)
-                    request_output = pickle.loads(message.buffer)
-
-                    if isinstance(request_output, Exception):
-                        # On exception, check if the server is still healthy
-                        # possibly setting the `errored` property.
-                        if not self._errored:
-                            try:
-                                await self.check_health(socket=socket)
-                            except Exception as e:
-                                self._errored = True
-                                logger.exception(repr(e))
-
-                        # NB: do before raising here so that the flag is set
-                        # by the time the caller receives this exception
-                        raise request_output
-
-                    finished = request_output.finished
-                    yield request_output
-
-        finally:
-            # Request was canceled by the client.
-            if not finished and not self._errored:
-                await self.abort(request_id)
-
-    async def check_health(self, socket: Optional[Socket] = None) -> None:
-        """Raise if unhealthy"""
-
-        await self._send_one_way_rpc_request(
-            request=RPCUtilityRequest.IS_SERVER_HEALTHY,
-            error_message="Got Unhealthy response from RPC Server",
-            socket=socket)
-
-    async def encode(self, *args,
-                     **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]:
-        raise NotImplementedError(
-            "Embeddings not supported with multiprocessing backend")
-
-    async def start_profile(self) -> None:
-        """Start profiling the engine"""
-
-        await self._send_one_way_rpc_request(
-            request=RPCUtilityRequest.START_PROFILE,
-            error_message="RPCRequest START_PROFILE failed.")
-
-    async def stop_profile(self) -> None:
-        """Stop profiling the engine"""
-
-        await self._send_one_way_rpc_request(
-            request=RPCUtilityRequest.STOP_PROFILE,
-            error_message="RPCRequest STOP_PROFILE failed.")
diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py
deleted file mode 100644
index bebc2faedb680..0000000000000
--- a/vllm/entrypoints/openai/rpc/server.py
+++ /dev/null
@@ -1,237 +0,0 @@
-import asyncio
-import pickle
-import signal
-from typing import Any, Coroutine, Union
-
-import cloudpickle
-import uvloop
-import zmq
-import zmq.asyncio
-from typing_extensions import Never
-from zmq import Frame  # type: ignore[attr-defined]
-from zmq.asyncio import Socket
-
-from vllm import AsyncEngineArgs, AsyncLLMEngine
-from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
-                         ParallelConfig, SchedulerConfig)
-from vllm.entrypoints.openai.rpc import (VLLM_RPC_SUCCESS_STR,
-                                         VLLM_RPC_ZMQ_HWM, RPCAbortRequest,
-                                         RPCGenerateRequest, RPCUtilityRequest)
-from vllm.logger import init_logger
-from vllm.usage.usage_lib import UsageContext
-
-logger = init_logger(__name__)
-
-CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig,
-                    SchedulerConfig, LoRAConfig]
-
-
-class AsyncEngineRPCServer:
-
-    def __init__(self, async_engine_args: AsyncEngineArgs,
-                 usage_context: UsageContext, rpc_path: str):
-        # Initialize engine first.
-        self.engine = AsyncLLMEngine.from_engine_args(
-            async_engine_args, usage_context=usage_context)
-
-        # Initialize context.
-        self.context = zmq.asyncio.Context()
-
-        # Init socket.
-        self.socket: Socket = self.context.socket(zmq.constants.DEALER)
-        self.socket.set_hwm(VLLM_RPC_ZMQ_HWM)
-        self.socket.connect(rpc_path)
-
-    def cleanup(self):
-        """Cleanup all resources."""
-        self.socket.close()
-        self.context.destroy()
-        self.engine.shutdown_background_loop()
-        # Clear the engine reference so that it can be GC'ed.
-        del self.engine
-
-    async def get_config(self, identity, request):
-        try:
-            config: CONFIG_TYPE
-            if request == RPCUtilityRequest.GET_MODEL_CONFIG:
-                config = await self.engine.get_model_config()
-            elif request == RPCUtilityRequest.GET_DECODING_CONFIG:
-                config = await self.engine.get_decoding_config()
-            elif request == RPCUtilityRequest.GET_LORA_CONFIG:
-                config = await self.engine.get_lora_config()
-            elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG:
-                config = await self.engine.get_scheduler_config()
-            elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG:
-                config = await self.engine.get_parallel_config()
-            else:
-                raise ValueError("Unknown Config Request: %s", request)
-
-            await self.socket.send_multipart((identity, pickle.dumps(config)),
-                                             copy=False)
-
-        except Exception as e:
-            await self.socket.send_multipart((identity, pickle.dumps(e)),
-                                             copy=False)
-
-    async def is_tracing_enabled(self, identity):
-        """Send the is_tracing_enabled flag"""
-        tracing_flag = await self.engine.is_tracing_enabled()
-
-        await self.socket.send_multipart(
-            (identity, pickle.dumps(tracing_flag)))
-
-    async def do_log_stats(self, identity):
-        """Log stats and confirm success."""
-        await self.engine.do_log_stats()
-
-        await self.socket.send_multipart(
-            (identity, pickle.dumps(VLLM_RPC_SUCCESS_STR)))
-
-    async def is_server_ready(self, identity):
-        """Notify the client that we are ready."""
-        await self.socket.send_multipart(
-            (identity, pickle.dumps(VLLM_RPC_SUCCESS_STR)))
-
-    async def abort(self, identity, request: RPCAbortRequest):
-        """Abort request and notify the client of success."""
-        try:
-            # Abort the request in the llm engine.
-            await self.engine.abort(request.request_id)
-            result: Union[str, Exception] = VLLM_RPC_SUCCESS_STR
-        except Exception as e:
-            result = e
-        await self.socket.send_multipart((identity, pickle.dumps(result)))
-
-    async def generate(self, identity, generate_request: RPCGenerateRequest):
-        try:
-            results_generator = self.engine.generate(
-                generate_request.inputs,
-                sampling_params=generate_request.sampling_params,
-                request_id=generate_request.request_id,
-                lora_request=generate_request.lora_request,
-                trace_headers=generate_request.trace_headers,
-                prompt_adapter_request=generate_request.prompt_adapter_request)
-
-            async for request_output in results_generator:
-                await self.socket.send_multipart(
-                    (identity, pickle.dumps(request_output)), copy=False)
-
-        except Exception as e:
-            await self.socket.send_multipart((identity, pickle.dumps(e)),
-                                             copy=False)
-
-    async def check_health(self, identity):
-        try:
-            await self.engine.check_health()
-            await self.socket.send_multipart(
-                (identity, pickle.dumps(VLLM_RPC_SUCCESS_STR)))
-
-        except Exception as e:
-            await self.socket.send_multipart((identity, pickle.dumps(e)),
-                                             copy=False)
-
-    async def start_profile(self, identity):
-        logger.info("Starting profiler...")
-        await self.engine.start_profile()
-        logger.info("Profiler started.")
-
-        await self.socket.send_multipart((
-            identity,
-            pickle.dumps(VLLM_RPC_SUCCESS_STR),
-        ))
-
-    async def stop_profile(self, identity):
-        logger.info("Stopping profiler...")
-        await self.engine.stop_profile()
-        logger.info("Profiler stopped.")
-
-        await self.socket.send_multipart((
-            identity,
-            pickle.dumps(VLLM_RPC_SUCCESS_STR),
-        ))
-
-    def _make_handler_coro(self, identity,
-                           message: Frame) -> Coroutine[Any, Any, Never]:
-        """Route the zmq message to the handler coroutine."""
-
-        request = cloudpickle.loads(message.buffer)
-
-        if isinstance(request, RPCGenerateRequest):
-            return self.generate(identity, request)
-
-        elif isinstance(request, RPCAbortRequest):
-            return self.abort(identity, request)
-
-        elif isinstance(request, RPCUtilityRequest):
-            if request in [
-                    RPCUtilityRequest.GET_MODEL_CONFIG,
-                    RPCUtilityRequest.GET_PARALLEL_CONFIG,
-                    RPCUtilityRequest.GET_DECODING_CONFIG,
-                    RPCUtilityRequest.GET_SCHEDULER_CONFIG,
-                    RPCUtilityRequest.GET_LORA_CONFIG
-            ]:
-                return self.get_config(identity, request)
-            elif request == RPCUtilityRequest.DO_LOG_STATS:
-                return self.do_log_stats(identity)
-            elif request == RPCUtilityRequest.IS_SERVER_READY:
-                return self.is_server_ready(identity)
-            elif request == RPCUtilityRequest.IS_SERVER_HEALTHY:
-                return self.check_health(identity)
-            elif request == RPCUtilityRequest.IS_TRACING_ENABLED:
-                return self.is_tracing_enabled(identity)
-            elif request == RPCUtilityRequest.START_PROFILE:
-                return self.start_profile(identity)
-            elif request == RPCUtilityRequest.STOP_PROFILE:
-                return self.stop_profile(identity)
-            else:
-                raise ValueError(f"Unknown RPCUtilityRequest type: {request}")
-
-        else:
-            raise ValueError(f"Unknown RPCRequest type: {request}")
-
-    async def run_server_loop(self):
-        """Inner RPC Server Loop"""
-
-        running_tasks = set()
-        while True:
-            # Wait for a request.
-            identity, message = await self.socket.recv_multipart(copy=False)
-
-            # Process the request async.
-            task = asyncio.create_task(
-                self._make_handler_coro(identity, message))
-
-            # We need to keep around a strong reference to the task,
-            # to avoid the task disappearing mid-execution as running tasks
-            # can be GC'ed. Below is a common "fire-and-forget" tasks
-            # https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
-            running_tasks.add(task)
-            task.add_done_callback(running_tasks.discard)
-
-
-async def run_server(server: AsyncEngineRPCServer):
-    # Put the server task into the asyncio loop.
-    loop = asyncio.get_running_loop()
-    server_task = loop.create_task(server.run_server_loop())
-
-    # Interruption handling.
-    def signal_handler() -> None:
-        # Kill the server on interrupt / terminate
-        server_task.cancel()
-
-    loop.add_signal_handler(signal.SIGINT, signal_handler)
-    loop.add_signal_handler(signal.SIGTERM, signal_handler)
-
-    try:
-        await server_task
-    except asyncio.CancelledError:
-        logger.info("vLLM ZMQ RPC Server was interrupted.")
-    finally:
-        # Clean up all resources.
-        server.cleanup()
-
-
-def run_rpc_server(async_engine_args: AsyncEngineArgs,
-                   usage_context: UsageContext, rpc_path: str):
-    server = AsyncEngineRPCServer(async_engine_args, usage_context, rpc_path)
-    uvloop.run(run_server(server))
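
The removed server reads an identity frame off its ZMQ socket, dispatches the pickled request to a handler coroutine, and sends each reply addressed back to that identity, while the removed client talks to it through a proxy socket (`to_proxy_socket`) with poll timeouts. Below is a minimal, self-contained sketch of that identity-routed request/reply idea using plain pyzmq and asyncio only; it is illustrative, not a reproduction of the vLLM wiring (no proxy, no AsyncLLMEngine), and the address, handler, and timeout value are made-up stand-ins.

# Illustrative sketch of ROUTER/DEALER identity routing with a toy echo handler.
# RPC_PATH and the "ack" reply are assumptions, not vLLM APIs.
import asyncio
import pickle

import zmq
import zmq.asyncio

RPC_PATH = "inproc://toy-rpc"  # hypothetical address


async def server(ctx: zmq.asyncio.Context) -> None:
    sock = ctx.socket(zmq.ROUTER)
    sock.bind(RPC_PATH)
    try:
        while True:
            # ROUTER prepends the sender's identity frame to every message.
            identity, payload = await sock.recv_multipart()
            request = pickle.loads(payload)
            # Reply to the same identity; the real server dispatched to
            # generate/abort/config handlers here instead of echoing.
            await sock.send_multipart(
                (identity, pickle.dumps(f"ack: {request}")))
    except asyncio.CancelledError:
        sock.close()


async def client(ctx: zmq.asyncio.Context) -> str:
    sock = ctx.socket(zmq.DEALER)
    sock.connect(RPC_PATH)
    await sock.send_multipart((pickle.dumps("ping"), ))
    # Poll with a timeout so a dead server surfaces as an error, not a hang.
    if await sock.poll(timeout=5000) == 0:
        raise TimeoutError("Server didn't reply within 5000 ms")
    (reply, ) = await sock.recv_multipart()
    sock.close()
    return pickle.loads(reply)


async def main() -> None:
    ctx = zmq.asyncio.Context()
    server_task = asyncio.create_task(server(ctx))
    await asyncio.sleep(0)  # let the server bind before the client connects
    print(await client(ctx))  # -> "ack: ping"
    server_task.cancel()
    await server_task
    ctx.destroy()


if __name__ == "__main__":
    asyncio.run(main())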