From a7a6e438d7a79ab57694ef58cc83ccfb2746e418 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Wed, 28 Aug 2024 12:10:23 -0700 Subject: [PATCH 01/29] [Benchmark] Add async throughput benchmark Like benchmark_throughput but using AsyncLLMEngine rather than LLM --- benchmarks/benchmark_throughput_async.py | 479 +++++++++++++++++++++++ vllm/entrypoints/openai/api_server.py | 35 +- 2 files changed, 505 insertions(+), 9 deletions(-) create mode 100644 benchmarks/benchmark_throughput_async.py diff --git a/benchmarks/benchmark_throughput_async.py b/benchmarks/benchmark_throughput_async.py new file mode 100644 index 0000000000000..0b9c2e16a3706 --- /dev/null +++ b/benchmarks/benchmark_throughput_async.py @@ -0,0 +1,479 @@ +"""Benchmark offline inference throughput.""" +import argparse +import json +import random +import time +from typing import List, Optional, Tuple + +import torch +import uvloop +from tqdm import tqdm +from transformers import (AutoModelForCausalLM, AutoTokenizer, + PreTrainedTokenizerBase) + +from vllm.entrypoints.openai.api_server import build_async_engine_client_from_engine_args +from vllm.utils import merge_async_iterators +from vllm.engine.arg_utils import EngineArgs, AsyncEngineArgs +from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS +from vllm.utils import FlexibleArgumentParser + + +def sample_requests( + dataset_path: str, + num_requests: int, + tokenizer: PreTrainedTokenizerBase, + fixed_output_len: Optional[int], +) -> List[Tuple[str, int, int]]: + if fixed_output_len is not None and fixed_output_len < 4: + raise ValueError("output_len too small") + + # Load the dataset. + with open(dataset_path) as f: + dataset = json.load(f) + # Filter out the conversations with less than 2 turns. + dataset = [data for data in dataset if len(data["conversations"]) >= 2] + # Only keep the first two turns of each conversation. + dataset = [(data["conversations"][0]["value"], + data["conversations"][1]["value"]) for data in dataset] + + # Shuffle the dataset. + random.shuffle(dataset) + + # Filter out sequences that are too long or too short + filtered_dataset: List[Tuple[str, int, int]] = [] + for i in range(len(dataset)): + if len(filtered_dataset) == num_requests: + break + + # Tokenize the prompts and completions. + prompt = dataset[i][0] + prompt_token_ids = tokenizer(prompt).input_ids + completion = dataset[i][1] + completion_token_ids = tokenizer(completion).input_ids + prompt_len = len(prompt_token_ids) + output_len = len(completion_token_ids + ) if fixed_output_len is None else fixed_output_len + if prompt_len < 4 or output_len < 4: + # Prune too short sequences. + continue + if prompt_len > 1024 or prompt_len + output_len > 2048: + # Prune too long sequences. 
+ continue + filtered_dataset.append((prompt, prompt_len, output_len)) + + return filtered_dataset + + +async def run_vllm( + requests: List[Tuple[str, int, int]], + model: str, + tokenizer: str, + quantization: Optional[str], + tensor_parallel_size: int, + seed: int, + n: int, + use_beam_search: bool, + trust_remote_code: bool, + dtype: str, + max_model_len: Optional[int], + enforce_eager: bool, + kv_cache_dtype: str, + quantization_param_path: Optional[str], + device: str, + enable_prefix_caching: bool, + enable_chunked_prefill: bool, + max_num_batched_tokens: int, + distributed_executor_backend: Optional[str], + gpu_memory_utilization: float = 0.9, + num_scheduler_steps: int = 1, + use_v2_block_manager: bool = False, + download_dir: Optional[str] = None, + load_format: str = EngineArgs.load_format, + disable_async_output_proc: bool = False, +) -> float: + from vllm import LLM, SamplingParams + engine_args = AsyncEngineArgs( + model=model, + tokenizer=tokenizer, + quantization=quantization, + tensor_parallel_size=tensor_parallel_size, + seed=seed, + trust_remote_code=trust_remote_code, + dtype=dtype, + max_model_len=max_model_len, + gpu_memory_utilization=gpu_memory_utilization, + enforce_eager=enforce_eager, + kv_cache_dtype=kv_cache_dtype, + quantization_param_path=quantization_param_path, + device=device, + enable_prefix_caching=enable_prefix_caching, + download_dir=download_dir, + enable_chunked_prefill=enable_chunked_prefill, + max_num_batched_tokens=max_num_batched_tokens, + distributed_executor_backend=distributed_executor_backend, + load_format=load_format, + num_scheduler_steps=num_scheduler_steps, + use_v2_block_manager=use_v2_block_manager, + disable_async_output_proc=disable_async_output_proc, + worker_use_ray=False, + engine_use_ray=False, + disable_log_requests=True, + ) + + decoupled = True + + async with build_async_engine_client_from_engine_args(engine_args, + not decoupled) as llm: + + # Add the requests to the engine. + prompts: List[str] = [] + sampling_params: List[SamplingParams] = [] + for prompt, _, output_len in requests: + prompts.append(prompt) + sampling_params.append( + SamplingParams( + n=n, + temperature=0.0 if use_beam_search else 1.0, + top_p=1.0, + use_beam_search=use_beam_search, + ignore_eos=True, + max_tokens=output_len, + )) + + generators = [] + start = time.perf_counter() + for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)): + generator = llm.generate(prompt, sp, request_id=f"test{i}") + generators.append(generator) + all_gens = merge_async_iterators(*generators) + async for i, res in all_gens: + pass + end = time.perf_counter() + return end - start + + +def run_hf( + requests: List[Tuple[str, int, int]], + model: str, + tokenizer: PreTrainedTokenizerBase, + n: int, + use_beam_search: bool, + max_batch_size: int, + trust_remote_code: bool, +) -> float: + assert not use_beam_search + llm = AutoModelForCausalLM.from_pretrained( + model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code) + if llm.config.model_type == "llama": + # To enable padding in the HF backend. + tokenizer.pad_token = tokenizer.eos_token + llm = llm.cuda() + + pbar = tqdm(total=len(requests)) + start = time.perf_counter() + batch: List[str] = [] + max_prompt_len = 0 + max_output_len = 0 + for i in range(len(requests)): + prompt, prompt_len, output_len = requests[i] + # Add the prompt to the batch. 
+ batch.append(prompt) + max_prompt_len = max(max_prompt_len, prompt_len) + max_output_len = max(max_output_len, output_len) + if len(batch) < max_batch_size and i != len(requests) - 1: + # Check if we can add more requests to the batch. + _, next_prompt_len, next_output_len = requests[i + 1] + if (max(max_prompt_len, next_prompt_len) + + max(max_output_len, next_output_len)) <= 2048: + # We can add more requests to the batch. + continue + + # Generate the sequences. + input_ids = tokenizer(batch, return_tensors="pt", + padding=True).input_ids + llm_outputs = llm.generate( + input_ids=input_ids.cuda(), + do_sample=not use_beam_search, + num_return_sequences=n, + temperature=1.0, + top_p=1.0, + use_cache=True, + max_new_tokens=max_output_len, + ) + # Include the decoding time. + tokenizer.batch_decode(llm_outputs, skip_special_tokens=True) + pbar.update(len(batch)) + + # Clear the batch. + batch = [] + max_prompt_len = 0 + max_output_len = 0 + end = time.perf_counter() + return end - start + + +def run_mii( + requests: List[Tuple[str, int, int]], + model: str, + tensor_parallel_size: int, + output_len: int, +) -> float: + from mii import client, serve + llm = serve(model, tensor_parallel=tensor_parallel_size) + prompts = [prompt for prompt, _, _ in requests] + + start = time.perf_counter() + llm.generate(prompts, max_new_tokens=output_len) + end = time.perf_counter() + client = client(model) + client.terminate_server() + return end - start + + +def main(args: argparse.Namespace): + print(args) + random.seed(args.seed) + + # Sample the requests. + tokenizer = AutoTokenizer.from_pretrained( + args.tokenizer, trust_remote_code=args.trust_remote_code) + if args.dataset is None: + # Synthesize a prompt with the given input length. + prompt = "hi" * (args.input_len - 1) + requests = [(prompt, args.input_len, args.output_len) + for _ in range(args.num_prompts)] + else: + requests = sample_requests(args.dataset, args.num_prompts, tokenizer, + args.output_len) + + if args.backend == "vllm": + coro = run_vllm( + requests, args.model, args.tokenizer, args.quantization, + args.tensor_parallel_size, args.seed, args.n, args.use_beam_search, + args.trust_remote_code, args.dtype, args.max_model_len, + args.enforce_eager, args.kv_cache_dtype, + args.quantization_param_path, args.device, + args.enable_prefix_caching, args.enable_chunked_prefill, + args.max_num_batched_tokens, args.distributed_executor_backend, + args.gpu_memory_utilization, args.num_scheduler_steps, + args.use_v2_block_manager, args.download_dir, args.load_format, + args.disable_async_output_proc) + + elapsed_time = uvloop.run(coro) + elif args.backend == "hf": + assert args.tensor_parallel_size == 1 + elapsed_time = run_hf(requests, args.model, tokenizer, args.n, + args.use_beam_search, args.hf_max_batch_size, + args.trust_remote_code) + elif args.backend == "mii": + elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size, + args.output_len) + else: + raise ValueError(f"Unknown backend: {args.backend}") + total_num_tokens = sum(prompt_len + output_len + for _, prompt_len, output_len in requests) + print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " + f"{total_num_tokens / elapsed_time:.2f} tokens/s") + + # Output JSON results if specified + if args.output_json: + results = { + "elapsed_time": elapsed_time, + "num_requests": len(requests), + "total_num_tokens": total_num_tokens, + "requests_per_second": len(requests) / elapsed_time, + "tokens_per_second": total_num_tokens / elapsed_time, + } + with 
open(args.output_json, "w") as f: + json.dump(results, f, indent=4) + + +if __name__ == "__main__": + parser = FlexibleArgumentParser(description="Benchmark the throughput.") + parser.add_argument("--backend", + type=str, + choices=["vllm", "hf", "mii"], + default="vllm") + parser.add_argument("--dataset", + type=str, + default=None, + help="Path to the dataset.") + parser.add_argument("--input-len", + type=int, + default=None, + help="Input prompt length for each request") + parser.add_argument("--output-len", + type=int, + default=None, + help="Output length for each request. Overrides the " + "output length from the dataset.") + parser.add_argument("--model", type=str, default="facebook/opt-125m") + parser.add_argument("--tokenizer", type=str, default=None) + parser.add_argument('--quantization', + '-q', + choices=[*QUANTIZATION_METHODS, None], + default=None) + parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1) + parser.add_argument("--n", + type=int, + default=1, + help="Number of generated sequences per prompt.") + parser.add_argument("--use-beam-search", action="store_true") + parser.add_argument("--num-prompts", + type=int, + default=1000, + help="Number of prompts to process.") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--hf-max-batch-size", + type=int, + default=None, + help="Maximum batch size for HF backend.") + parser.add_argument('--trust-remote-code', + action='store_true', + help='trust remote code from huggingface') + parser.add_argument( + '--max-model-len', + type=int, + default=None, + help='Maximum length of a sequence (including prompt and output). ' + 'If None, will be derived from the model.') + parser.add_argument( + '--dtype', + type=str, + default='auto', + choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'], + help='data type for model weights and activations. ' + 'The "auto" option will use FP16 precision ' + 'for FP32 and FP16 models, and BF16 precision ' + 'for BF16 models.') + parser.add_argument('--gpu-memory-utilization', + type=float, + default=0.9, + help='the fraction of GPU memory to be used for ' + 'the model executor, which can range from 0 to 1.' + 'If unspecified, will use the default value of 0.9.') + parser.add_argument("--enforce-eager", + action="store_true", + help="enforce eager execution") + parser.add_argument( + '--kv-cache-dtype', + type=str, + choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'], + default="auto", + help='Data type for kv cache storage. If "auto", will use model ' + 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' + 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)') + parser.add_argument( + '--quantization-param-path', + type=str, + default=None, + help='Path to the JSON file containing the KV cache scaling factors. ' + 'This should generally be supplied, when KV cache dtype is FP8. ' + 'Otherwise, KV cache scaling factors default to 1.0, which may cause ' + 'accuracy issues. FP8_E5M2 (without scaling) is only supported on ' + 'cuda version greater than 11.8. 
On ROCm (AMD GPU), FP8_E4M3 is ' + 'instead supported for common inference criteria.') + parser.add_argument( + "--device", + type=str, + default="auto", + choices=["auto", "cuda", "cpu", "openvino", "tpu", "xpu"], + help='device type for vLLM execution, supporting CUDA, OpenVINO and ' + 'CPU.') + parser.add_argument( + "--num-scheduler-steps", + type=int, + default=1, + help="Maximum number of forward steps per scheduler call.") + parser.add_argument("--use-v2-block-manager", + action='store_true', + help="Enable block manager v2.") + parser.add_argument( + "--enable-prefix-caching", + action='store_true', + help="Enable automatic prefix caching for vLLM backend.") + parser.add_argument("--enable-chunked-prefill", + action='store_true', + help="enable chunked prefill for vLLM backend.") + parser.add_argument('--max-num-batched-tokens', + type=int, + default=None, + help='maximum number of batched tokens per ' + 'iteration') + parser.add_argument('--download-dir', + type=str, + default=None, + help='directory to download and load the weights, ' + 'default to the default cache dir of huggingface') + parser.add_argument( + '--output-json', + type=str, + default=None, + help='Path to save the throughput results in JSON format.') + parser.add_argument( + '--distributed-executor-backend', + choices=['ray', 'mp'], + default=None, + help='Backend to use for distributed serving. When more than 1 GPU ' + 'is used, will be automatically set to "ray" if installed ' + 'or "mp" (multiprocessing) otherwise.') + parser.add_argument( + '--load-format', + type=str, + default=EngineArgs.load_format, + choices=[ + 'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer', + 'bitsandbytes' + ], + help='The format of the model weights to load.\n\n' + '* "auto" will try to load the weights in the safetensors format ' + 'and fall back to the pytorch bin format if safetensors format ' + 'is not available.\n' + '* "pt" will load the weights in the pytorch bin format.\n' + '* "safetensors" will load the weights in the safetensors format.\n' + '* "npcache" will load the weights in pytorch format and store ' + 'a numpy cache to speed up the loading.\n' + '* "dummy" will initialize the weights with random values, ' + 'which is mainly for profiling.\n' + '* "tensorizer" will load the weights using tensorizer from ' + 'CoreWeave. 
See the Tensorize vLLM Model script in the Examples' + 'section for more information.\n' + '* "bitsandbytes" will load the weights using bitsandbytes ' + 'quantization.\n') + parser.add_argument( + "--disable-async-output-proc", + action='store_true', + default=False, + help="Disable async output processor for vLLM backend.") + args = parser.parse_args() + if args.tokenizer is None: + args.tokenizer = args.model + if args.dataset is None: + assert args.input_len is not None + assert args.output_len is not None + else: + assert args.input_len is None + + if args.backend == "vllm": + if args.hf_max_batch_size is not None: + raise ValueError("HF max batch size is only for HF backend.") + elif args.backend == "hf": + if args.hf_max_batch_size is None: + raise ValueError("HF max batch size is required for HF backend.") + if args.quantization is not None: + raise ValueError("Quantization is only for vLLM backend.") + elif args.backend == "mii": + if args.dtype != "auto": + raise ValueError("dtype must be auto for MII backend.") + if args.n != 1: + raise ValueError("n must be 1 for MII backend.") + if args.use_beam_search: + raise ValueError("Beam search is not supported for MII backend.") + if args.quantization is not None: + raise ValueError("Quantization is only for vLLM backend.") + if args.hf_max_batch_size is not None: + raise ValueError("HF max batch size is only for HF backend.") + if args.tokenizer != args.model: + raise ValueError("Tokenizer must be the same as the model for MII " + "backend.") + main(args) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 8e8371ef1559a..e99e4bd951089 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -96,6 +96,22 @@ async def _force_log(): @asynccontextmanager async def build_async_engine_client( args: Namespace) -> AsyncIterator[Optional[AsyncEngineClient]]: + + # Context manager to handle async_engine_client lifecycle + # Ensures everything is shutdown and cleaned up on error/exit + global engine_args + engine_args = AsyncEngineArgs.from_cli_args(args) + + async with build_async_engine_client_from_engine_args( + engine_args, args.disable_frontend_multiprocessing) as engine: + yield engine + + +@asynccontextmanager +async def build_async_engine_client_from_engine_args( + engine_args: AsyncEngineArgs, + disable_frontend_multiprocessing: bool = False, +) -> AsyncIterator[Optional[AsyncEngineClient]]: """ Create AsyncEngineClient, either: - in-process using the AsyncLLMEngine Directly @@ -104,22 +120,21 @@ async def build_async_engine_client( Returns the Client or None if the creation failed. """ - # Context manager to handle async_engine_client lifecycle - # Ensures everything is shutdown and cleaned up on error/exit - global engine_args - engine_args = AsyncEngineArgs.from_cli_args(args) - # Backend itself still global for the silly lil' health handler global async_engine_client # If manually triggered or embedding model, use AsyncLLMEngine in process. # TODO: support embedding model via RPC. 
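# A minimal, runnable sketch of the lifecycle pattern this api_server.py
# refactor exposes: engine-client construction moves into its own async
# context manager so both the API server and the new benchmark can reuse it
# and still get guaranteed cleanup. DummyClient and build_client() below are
# illustrative stand-ins, not vLLM APIs.
import asyncio
from contextlib import asynccontextmanager


class DummyClient:
    async def generate(self, prompt: str) -> str:
        await asyncio.sleep(0)      # stand-in for real async generation
        return prompt.upper()

    def shutdown(self) -> None:
        print("client shut down")


@asynccontextmanager
async def build_client():
    # The real code chooses between an in-process AsyncLLMEngine and an RPC
    # client backed by a separate engine process; the cleanup guarantee is
    # the part being factored out here.
    client = DummyClient()
    try:
        yield client
    finally:
        client.shutdown()           # runs on normal exit and on error


async def _demo() -> None:
    async with build_client() as client:
        print(await client.generate("hello"))


if __name__ == "__main__":
    asyncio.run(_demo())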
- if (model_is_embedding(args.model, args.trust_remote_code, - args.quantization) - or args.disable_frontend_multiprocessing): + if (model_is_embedding(engine_args.model, engine_args.trust_remote_code, + engine_args.quantization) + or disable_frontend_multiprocessing): async_engine_client = AsyncLLMEngine.from_engine_args( engine_args, usage_context=UsageContext.OPENAI_API_SERVER) - yield async_engine_client + try: + yield async_engine_client + finally: + async_engine_client.shutdown_background_loop() + async_engine_client = None #TODO return # Otherwise, use the multiprocessing AsyncLLMEngine. @@ -192,6 +207,8 @@ async def build_async_engine_client( from prometheus_client import multiprocess multiprocess.mark_process_dead(rpc_server_process.pid) + async_engine_client = None #TODO + router = APIRouter() From ce7d15974028679c4d08742bb763489a4a06c004 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Wed, 28 Aug 2024 17:25:19 -0700 Subject: [PATCH 02/29] wip --- vllm/engine/async_llm_engine.py | 135 +++++++++++++++++------- vllm/entrypoints/openai/rpc/__init__.py | 5 + vllm/entrypoints/openai/rpc/client.py | 79 ++++++++++---- vllm/entrypoints/openai/rpc/server.py | 36 +++++-- 4 files changed, 186 insertions(+), 69 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 37696bf1d9dc9..fdad2d18c5b8c 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -47,7 +47,6 @@ def _log_task_completion(task: asyncio.Task, there is an exception. """ - exception = None try: return_value = task.result() raise AssertionError( @@ -80,8 +79,7 @@ def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None: self._queue: asyncio.Queue = asyncio.Queue() self._finished = False - def put(self, item: Union[RequestOutput, EmbeddingRequestOutput, - Exception]) -> None: + def put(self, item: Union[RequestOutput, EmbeddingRequestOutput]) -> None: if not self._finished: self._queue.put_nowait(item) @@ -123,10 +121,11 @@ def _is_raisable(value: Any): class RequestTracker: """Synchronous abstraction for tracking requests.""" - def __init__(self) -> None: + def __init__(self, per_request_streams: bool = True) -> None: + self._per_request_streams = per_request_streams self._request_streams: Dict[str, AsyncStream] = {} self._aborted_requests: asyncio.Queue[str] = asyncio.Queue() - self._new_requests: asyncio.Queue[Tuple[AsyncStream, + self._new_requests: asyncio.Queue[Tuple[Optional[AsyncStream], dict]] = asyncio.Queue() self.new_requests_event = asyncio.Event() @@ -186,14 +185,15 @@ def add_request(self, request_id: str, *, verbose: bool = False, - **engine_add_request_kwargs) -> AsyncStream: + **engine_add_request_kwargs) -> Optional[AsyncStream]: """Add a request to be sent to the engine on the next background loop iteration.""" if request_id in self._request_streams: raise KeyError(f"Request {request_id} already exists.") abort_request = partial(self.abort_request, verbose=verbose) - stream = AsyncStream(request_id, abort_request) + stream = AsyncStream(request_id, abort_request) \ + if self._per_request_streams else None self._new_requests.put_nowait((stream, { "request_id": request_id, **engine_add_request_kwargs @@ -234,13 +234,15 @@ def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]: while not self._new_requests.empty(): stream, new_request = self._new_requests.get_nowait() - request_id = stream.request_id + request_id = new_request["request_id"] if request_id in finished_requests: # The request has already been 
aborted. - stream.finish(asyncio.CancelledError) + if stream is not None: + stream.finish(asyncio.CancelledError) finished_requests.discard(request_id) else: - self._request_streams[request_id] = stream + if stream is not None: + self._request_streams[request_id] = stream new_requests.append(new_request) return new_requests, finished_requests @@ -639,7 +641,34 @@ def __init__(self, self._errored_with: Optional[BaseException] = None # Lazy initialized fields - self._request_tracker: RequestTracker + self._request_tracker: RequestTracker = None # type: ignore[assignment] + + self._global_queue: Optional[asyncio.Queue] = None + + async def global_output_generator( + self + ) -> AsyncGenerator[List[Union[RequestOutput, EmbeddingRequestOutput, + Tuple[str, BaseException]]], None]: + """Returns a single generator that streams outputs from all + requests. + + Must be called at most once prior to processing any requests, + and if used, generate() will return None rather than a per-request + stream. + """ + if self._global_queue is not None: + raise RuntimeError( + "global_output_generator can only be called once") + if self._request_tracker is not None: + raise RuntimeError( + "global_output_generator must be called before processing " + "any requests") + + self._global_queue = asyncio.Queue() + + # This runs until the engine is shut down + while True: + yield await self._global_queue.get() @classmethod def _get_executor_cls( @@ -763,6 +792,11 @@ def set_errored(self, exc: Exception) -> None: def _error_callback(self, exc: Exception) -> None: self.set_errored(exc) self._request_tracker.propagate_exception(exc) + if self._global_queue is not None: + #TODO clean this up + for request_id in tuple( + self._request_tracker._request_streams.keys()): + self._global_queue.put_nowait((request_id, exc)) async def get_tokenizer( self, @@ -783,7 +817,8 @@ def start_background_loop(self) -> None: if self.is_running: raise RuntimeError("Background loop is already running.") # Initialize the RequestTracker here so it uses the right event loop. - self._request_tracker = RequestTracker() + per_request_streams = self._global_queue is None + self._request_tracker = RequestTracker(per_request_streams) self._background_loop_unshielded = asyncio.get_event_loop( ).create_task(self.run_engine_loop()) @@ -844,11 +879,14 @@ async def engine_step(self, virtual_engine: int) -> bool: await self.engine.add_request_async(**new_request) except ValueError as e: # TODO: use a vLLM specific error for failed validation + request_id = new_request["request_id"] self._request_tracker.process_exception( - new_request["request_id"], + request_id, e, verbose=self.log_requests, ) + if self._global_queue is not None: + self._global_queue.put_nowait((request_id, e)) if aborted_requests: await self._engine_abort(aborted_requests) @@ -859,13 +897,18 @@ async def engine_step(self, virtual_engine: int) -> bool: request_outputs = await self.engine.step_async(virtual_engine) # Put the outputs into the corresponding streams. 
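# A runnable toy model (asyncio only, not vLLM code) of the global output
# path added in this commit: each engine step puts its whole batch of outputs
# on one shared asyncio.Queue, and global_output_generator() yields those
# batches to a single consumer instead of fanning out per-request streams.
# The sentinel value is added here only so the toy terminates.
import asyncio


async def engine_loop(queue: asyncio.Queue) -> None:
    # Pretend each step produces outputs for two in-flight requests.
    for step in range(3):
        await queue.put([f"req{i}-step{step}" for i in range(2)])
    await queue.put(None)               # sentinel: engine shutting down


async def global_output_generator(queue: asyncio.Queue):
    while True:
        batch = await queue.get()
        if batch is None:
            return
        yield batch                     # a list of outputs, one per request


async def _demo() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    producer = asyncio.create_task(engine_loop(queue))
    async for outputs in global_output_generator(queue):
        print(outputs)
    await producer


if __name__ == "__main__":
    asyncio.run(_demo())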
- finished = True + all_finished = True for request_output in request_outputs: - self._request_tracker.process_request_output( - request_output, verbose=self.log_requests) - finished = finished and request_output.finished + finished = request_output.finished + if finished or self._global_queue is None: + self._request_tracker.process_request_output( + request_output, verbose=self.log_requests) + all_finished = all_finished and finished + + if self._global_queue is not None: + self._global_queue.put_nowait(request_outputs) - return not finished + return not all_finished async def _engine_abort(self, request_ids: Iterable[str]): if self.engine_use_ray: @@ -950,8 +993,9 @@ async def add_request( arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None - ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]: + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> Optional[AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], + None]]: if not self.is_running: if self.start_engine_loop: self.start_background_loop() @@ -972,7 +1016,7 @@ async def add_request( trace_headers=trace_headers, prompt_adapter_request=prompt_adapter_request) - return stream.generator() + return stream.generator() if stream is not None else None async def generate( self, @@ -982,7 +1026,7 @@ async def generate( lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None - ) -> AsyncGenerator[RequestOutput, None]: + ) -> Optional[AsyncGenerator[RequestOutput, None]]: """Generate outputs for a request. Generate outputs for a request. This method is a coroutine. It adds the @@ -1004,6 +1048,9 @@ async def generate( The output `RequestOutput` objects from the LLMEngine for the request. + Unless a global output generator is being used, in which case + this methods will return None. + Details: - If the engine is not running, start the background loop, which iteratively invokes @@ -1047,15 +1094,22 @@ async def generate( >>> # Process and return the final output >>> ... """ - async for output in await self.add_request( - request_id, - inputs, - sampling_params, - lora_request=lora_request, - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request, - ): - yield LLMEngine.validate_output(output, RequestOutput) + maybe_generator = await self.add_request( + request_id, + inputs, + sampling_params, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + ) + if maybe_generator is None or not LLMEngine.DO_VALIDATE_OUTPUT: + return maybe_generator + + async def validating_generator(): + async for output in maybe_generator: + yield LLMEngine.validate_output(output, RequestOutput) + + return validating_generator() async def encode( self, @@ -1125,13 +1179,15 @@ async def encode( >>> # Process and return the final output >>> ... 
""" - async for output in await self.add_request( - request_id, - inputs, - pooling_params, - lora_request=lora_request, - trace_headers=trace_headers, - ): + generator = await self.add_request( + request_id, + inputs, + pooling_params, + lora_request=lora_request, + trace_headers=trace_headers, + ) + assert generator is not None + async for output in generator: yield LLMEngine.validate_output(output, EmbeddingRequestOutput) async def abort(self, request_id: str) -> None: @@ -1165,6 +1221,9 @@ def _abort(self, request_id: str) -> None: exception=asyncio.CancelledError, verbose=self.log_requests) + if self._global_queue is not None: + self._global_queue.put_nowait((request_id, asyncio.CancelledError)) + async def get_model_config(self) -> ModelConfig: """Get the model configuration of the vLLM engine.""" if self.engine_use_ray: diff --git a/vllm/entrypoints/openai/rpc/__init__.py b/vllm/entrypoints/openai/rpc/__init__.py index efc7e43afdcc9..c4cce036281ae 100644 --- a/vllm/entrypoints/openai/rpc/__init__.py +++ b/vllm/entrypoints/openai/rpc/__init__.py @@ -17,6 +17,11 @@ VLLM_RPC_ZMQ_HWM = 0 +@dataclass +class RPCOutputStreamRequest: + pass + + @dataclass class RPCGenerateRequest: inputs: PromptInputs diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py index c457555c54b9c..51fbbfd3b461a 100644 --- a/vllm/entrypoints/openai/rpc/client.py +++ b/vllm/entrypoints/openai/rpc/client.py @@ -1,12 +1,14 @@ import asyncio import pickle from contextlib import contextmanager, suppress -from typing import Any, AsyncGenerator, Iterator, Mapping, Optional +from typing import (Any, AsyncGenerator, Dict, Iterator, Mapping, Optional, + Union) from uuid import uuid4 import cloudpickle import zmq import zmq.asyncio +from zmq import Frame # type: ignore[attr-defined] from zmq.asyncio import Socket from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, @@ -16,7 +18,9 @@ VLLM_RPC_SOCKET_LIMIT_CUTOFF, VLLM_RPC_SUCCESS_STR, VLLM_RPC_ZMQ_HWM, RPCAbortRequest, - RPCGenerateRequest, RPCUtilityRequest) + RPCGenerateRequest, + RPCOutputStreamRequest, + RPCUtilityRequest) # yapf: enable from vllm.envs import VLLM_RPC_GET_DATA_TIMEOUT_MS from vllm.inputs import PromptInputs @@ -141,12 +145,37 @@ def __init__(self, rpc_path: str): # 1 for generate(), 1 for abort(), do_log_stats(), check_health() self.limit_concurrency = socket_limit // 2 - 2 + self.output_queues: Dict[str, asyncio.Queue] = {} + + self.output_handler = asyncio.create_task(self.run_output_handler()) + async def run_proxy(self, socket_from: Socket, socket_to: Socket): """Background task that runs a proxy""" while True: frames = await socket_from.recv_multipart(copy=False) await socket_to.send_multipart(frames, copy=False) + async def run_output_handler(self): + with self.to_proxy_socket() as socket: + await socket.send_multipart( + (cloudpickle.dumps(RPCOutputStreamRequest()), )) + + # Stream back the results from the RPC Server. 
+ while True: + message: Frame = await socket.recv(copy=False) + request_outputs = pickle.loads(message.buffer) + + for output in request_outputs: + if isinstance(output, tuple): + # Exception case + request_id, output = output + else: + request_id = output.request_id + + queue = self.output_queues.get(request_id) + if queue is not None: + queue.put_nowait(output) + async def setup(self): """Setup the client before it starts sending server requests.""" @@ -379,6 +408,9 @@ async def generate( ) -> AsyncGenerator[RequestOutput, None]: """Send an RPCGenerateRequest to the RPCServer and stream responses.""" + queue: asyncio.Queue[Union[RequestOutput, + BaseException]] = asyncio.Queue() + self.output_queues[request_id] = queue finished = False try: with self.to_proxy_socket() as socket: @@ -392,29 +424,30 @@ async def generate( trace_headers=trace_headers, prompt_adapter_request=prompt_adapter_request)), )) - # Stream back the results from the RPC Server. - while not finished: - message = await socket.recv(copy=False) - request_output = pickle.loads(message.buffer) - - if isinstance(request_output, Exception): - # On exception, check if the server is still healthy - # possibly setting the `errored` property. - if not self._errored: - try: - await self.check_health(socket=socket) - except Exception as e: - self._errored = True - logger.exception(repr(e)) - - # NB: do before raising here so that the flag is set - # by the time the caller receives this exception - raise request_output - - finished = request_output.finished - yield request_output + ack: Frame = await socket.recv(copy=False) + if len(ack.buffer) != 0: + exception = pickle.loads(ack.buffer) + raise exception + + while not finished: + request_output = await queue.get() + if isinstance(request_output, BaseException): + finished = True + # On exception, check if the server is still healthy + # possibly setting the `errored` property. + if not self._errored: + try: + await self.check_health(socket=socket) + except Exception as e: + self._errored = True + logger.exception(repr(e)) + raise request_output + + finished = request_output.finished + yield request_output finally: + self.output_queues.pop(request_id) # Request was canceled by the client. 
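# A runnable toy (asyncio only, no ZMQ) of the client-side routing added
# above: a single background handler receives combined output batches and
# drops each item into a per-request queue keyed by request_id, while each
# generate() call consumes only its own queue. The (request_id, text,
# finished) tuples are simplified stand-ins for RequestOutput objects.
import asyncio
from typing import Dict

output_queues: Dict[str, asyncio.Queue] = {}


async def output_handler(incoming: asyncio.Queue) -> None:
    # Stands in for run_output_handler() reading from the output socket.
    while True:
        batch = await incoming.get()
        if batch is None:
            return
        for request_id, text, finished in batch:
            queue = output_queues.get(request_id)
            if queue is not None:       # request may have been cancelled
                queue.put_nowait((text, finished))


async def generate(request_id: str):
    queue: asyncio.Queue = asyncio.Queue()
    output_queues[request_id] = queue
    try:
        finished = False
        while not finished:
            text, finished = await queue.get()
            yield text
    finally:
        output_queues.pop(request_id, None)   # mirrors the finally block above


async def _demo() -> None:
    incoming: asyncio.Queue = asyncio.Queue()
    handler = asyncio.create_task(output_handler(incoming))

    async def consume(rid: str) -> None:
        async for text in generate(rid):
            print(rid, text)

    consumers = [asyncio.create_task(consume(r)) for r in ("a", "b")]
    await asyncio.sleep(0)              # let both consumers register queues
    # Two interleaved batches, as if streamed back from the engine process.
    await incoming.put([("a", "A0", False), ("b", "B0", False)])
    await incoming.put([("a", "A1", True), ("b", "B1", True)])
    await asyncio.gather(*consumers)
    await incoming.put(None)            # stop the handler
    await handler


if __name__ == "__main__":
    asyncio.run(_demo())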
if not finished and not self._errored: await self.abort(request_id) diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py index bebc2faedb680..42a66e35a65bd 100644 --- a/vllm/entrypoints/openai/rpc/server.py +++ b/vllm/entrypoints/openai/rpc/server.py @@ -16,7 +16,9 @@ ParallelConfig, SchedulerConfig) from vllm.entrypoints.openai.rpc import (VLLM_RPC_SUCCESS_STR, VLLM_RPC_ZMQ_HWM, RPCAbortRequest, - RPCGenerateRequest, RPCUtilityRequest) + RPCGenerateRequest, + RPCOutputStreamRequest, + RPCUtilityRequest) from vllm.logger import init_logger from vllm.usage.usage_lib import UsageContext @@ -102,9 +104,27 @@ async def abort(self, identity, request: RPCAbortRequest): result = e await self.socket.send_multipart((identity, pickle.dumps(result))) + async def stream_outputs(self, identity): + # This runs indefinitely + #TODO handle shutdown + async for outputs in self.engine.global_output_generator(): + # Trim down contents to be equivalent to deltas (other PR for this) + # for output in outputs: + # output.prompt = None + # output.prompt_token_ids = None + # output.prompt_logprobs = None + # for o in output.outputs: + # o.token_ids = [0] + # o.text = " word" + + await self.socket.send_multipart((identity, pickle.dumps(outputs)), + copy=False) + async def generate(self, identity, generate_request: RPCGenerateRequest): + # Empty result to indicate success + result = b'' try: - results_generator = self.engine.generate( + await self.engine.generate( generate_request.inputs, sampling_params=generate_request.sampling_params, request_id=generate_request.request_id, @@ -112,13 +132,10 @@ async def generate(self, identity, generate_request: RPCGenerateRequest): trace_headers=generate_request.trace_headers, prompt_adapter_request=generate_request.prompt_adapter_request) - async for request_output in results_generator: - await self.socket.send_multipart( - (identity, pickle.dumps(request_output)), copy=False) - except Exception as e: - await self.socket.send_multipart((identity, pickle.dumps(e)), - copy=False) + result = pickle.dumps(e) + + await self.socket.send_multipart((identity, result), copy=False) async def check_health(self, identity): try: @@ -156,6 +173,9 @@ def _make_handler_coro(self, identity, request = cloudpickle.loads(message.buffer) + if isinstance(request, RPCOutputStreamRequest): + return self.stream_outputs(identity) + if isinstance(request, RPCGenerateRequest): return self.generate(identity, request) From d99ce6f2c5034c4292ce90d4bbfcaa0ef0502393 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sat, 31 Aug 2024 19:16:16 +0000 Subject: [PATCH 03/29] stash --- vllm/engine/async_llm_engine.py | 10 ++++++---- vllm/entrypoints/openai/rpc/server.py | 5 ++++- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index fdad2d18c5b8c..75cc637ca6a13 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -159,14 +159,16 @@ def process_request_output(self, if finished: stream = self._request_streams.pop(request_id, None) + if stream is not None: + stream.finish() else: stream = self._request_streams.get(request_id) # Guard against a KeyError which can occur if the request was aborted # while the output was generated - if stream is not None: - stream.put(request_output) - if finished: - stream.finish() + # if stream is not None: + # stream.put(request_output) + # if finished: + # stream.finish() if verbose and finished: logger.info("Finished 
request %s.", request_id) diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py index 42a66e35a65bd..c799ab8a35fa9 100644 --- a/vllm/entrypoints/openai/rpc/server.py +++ b/vllm/entrypoints/openai/rpc/server.py @@ -230,6 +230,9 @@ async def run_server_loop(self): async def run_server(server: AsyncEngineRPCServer): + # import pyinstrument + + # with pyinstrument.Profiler(async_mode="disabled") as prof: # Put the server task into the asyncio loop. loop = asyncio.get_running_loop() server_task = loop.create_task(server.run_server_loop()) @@ -249,7 +252,7 @@ def signal_handler() -> None: finally: # Clean up all resources. server.cleanup() - + # prof.write_html("prof-disabled.html", show_all=True) def run_rpc_server(async_engine_args: AsyncEngineArgs, usage_context: UsageContext, rpc_path: str): From 8d6b2e9d434908c7c9cfd686bbd8674175175f95 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 2 Sep 2024 17:51:49 +0000 Subject: [PATCH 04/29] remove proxy --- benchmarks/benchmark_throughput_async.py | 2 +- vllm/entrypoints/openai/rpc/client.py | 100 +++++++---------------- vllm/entrypoints/openai/rpc/server.py | 4 +- 3 files changed, 32 insertions(+), 74 deletions(-) diff --git a/benchmarks/benchmark_throughput_async.py b/benchmarks/benchmark_throughput_async.py index 0b9c2e16a3706..ec4351cebc29d 100644 --- a/benchmarks/benchmark_throughput_async.py +++ b/benchmarks/benchmark_throughput_async.py @@ -120,7 +120,7 @@ async def run_vllm( disable_log_requests=True, ) - decoupled = True + decoupled = False async with build_async_engine_client_from_engine_args(engine_args, not decoupled) as llm: diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py index 51fbbfd3b461a..21e8fbbefcbd3 100644 --- a/vllm/entrypoints/openai/rpc/client.py +++ b/vllm/entrypoints/openai/rpc/client.py @@ -104,77 +104,36 @@ def __init__(self, rpc_path: str): self._data_timeout = VLLM_RPC_GET_DATA_TIMEOUT_MS self._errored = False - # Maximum number of sockets that can be opened (typically 65536). - # ZMQ_SOCKET_LIMIT (http://api.zeromq.org/4-2:zmq-ctx-get) - socket_limit = self.context.get(zmq.constants.SOCKET_LIMIT) - assert isinstance(socket_limit, int) - if socket_limit < VLLM_RPC_SOCKET_LIMIT_CUTOFF: - raise ValueError( - f"Found zmq.constants.SOCKET_LIMIT={socket_limit}, which caps " - "the number of concurrent requests vLLM can process. Launch " - "vLLM with --disable-frontend-multiprocessing and open a " - "GitHub issue so we can investigate.") - - # We only have 1 ipc connection that uses unix sockets, so - # safe to set MAX_SOCKETS to the zmq SOCKET_LIMIT (i.e. will - # not run into ulimit issues) - self.context.set(zmq.constants.MAX_SOCKETS, socket_limit) - # IPC connection to RPC Server (uses unix sockets). - self.to_rpc_server: Socket = self.context.socket(zmq.constants.DEALER) - self.to_rpc_server.set_hwm(VLLM_RPC_ZMQ_HWM) - self.to_rpc_server.bind(rpc_path) - - # In process proxy to RPC Server (uses memory-based messaging). - self.from_api_server: Socket = self.context.socket( - zmq.constants.ROUTER) - self.from_api_server.set_hwm(VLLM_RPC_ZMQ_HWM) - self.from_api_server.bind(INPROC_PROXY_PATH) - - # Asyncio background task for the proxy. 
- self.proxy_in_task = asyncio.create_task( - self.run_proxy(self.from_api_server, self.to_rpc_server)) - self.proxy_out_task = asyncio.create_task( - self.run_proxy(self.to_rpc_server, self.from_api_server)) - - # Since we open 1 inproc socket per request, we have a hard cap on - # the number of requests that can run in vLLM w. frontend - # mulitprocessing. This value is used uvicorn to launch - # with --limit-concurrency to return 503 when server is overloaded. - # We need 2 sockets per request - 2: - # 1 for generate(), 1 for abort(), do_log_stats(), check_health() - self.limit_concurrency = socket_limit // 2 - 2 + self.socket: Socket = self.context.socket(zmq.constants.DEALER) + self.socket.set_hwm(VLLM_RPC_ZMQ_HWM) + self.socket.connect(rpc_path) + self.rpc_path = rpc_path + self.limit_concurrency = None self.output_queues: Dict[str, asyncio.Queue] = {} - self.output_handler = asyncio.create_task(self.run_output_handler()) + - async def run_proxy(self, socket_from: Socket, socket_to: Socket): - """Background task that runs a proxy""" + async def run_output_handler(self): + await self.socket.send_multipart( + (cloudpickle.dumps(RPCOutputStreamRequest()), )) + + # Stream back the results from the RPC Server. while True: - frames = await socket_from.recv_multipart(copy=False) - await socket_to.send_multipart(frames, copy=False) + message: Frame = await self.socket.recv(copy=False) + request_outputs = pickle.loads(message.buffer) - async def run_output_handler(self): - with self.to_proxy_socket() as socket: - await socket.send_multipart( - (cloudpickle.dumps(RPCOutputStreamRequest()), )) - - # Stream back the results from the RPC Server. - while True: - message: Frame = await socket.recv(copy=False) - request_outputs = pickle.loads(message.buffer) - - for output in request_outputs: - if isinstance(output, tuple): - # Exception case - request_id, output = output - else: - request_id = output.request_id - - queue = self.output_queues.get(request_id) - if queue is not None: - queue.put_nowait(output) + for output in request_outputs: + if isinstance(output, tuple): + # Exception case + request_id, output = output + else: + request_id = output.request_id + + queue = self.output_queues.get(request_id) + if queue is not None: + queue.put_nowait(output) async def setup(self): """Setup the client before it starts sending server requests.""" @@ -200,12 +159,11 @@ def close(self): """Destroy the ZeroMQ Context.""" # Close all sockets associated with this context and # then terminate the context. - self.from_api_server.close() - self.to_rpc_server.close() + self.socket.close() self.context.destroy() @contextmanager - def to_proxy_socket(self) -> Iterator[Socket]: + def rpc_get_data_socket(self) -> Iterator[Socket]: # Connect to the RPCServer via the proxy. # Raise a sensible error if the client was already closed. @@ -221,7 +179,7 @@ def to_proxy_socket(self) -> Iterator[Socket]: socket = self.context.socket(zmq.constants.DEALER) socket.set_hwm(VLLM_RPC_ZMQ_HWM) try: - socket.connect(INPROC_PROXY_PATH) + socket.connect(self.rpc_path) yield socket finally: socket.close(linger=0) @@ -231,7 +189,7 @@ async def _send_get_data_rpc_request(self, request: RPCUtilityRequest, error_message: str) -> Any: """Send an RPC request that is expecting data back.""" - with self.to_proxy_socket() as socket: + with self.rpc_get_data_socket() as socket: # Ping RPCServer with a request. 
await socket.send_multipart((cloudpickle.dumps(request), ), copy=False) @@ -280,7 +238,7 @@ async def do_rpc_call(socket: Socket, request: RPC_REQUEST_TYPE): # Make a new socket connection. if socket is None: - with self.to_proxy_socket() as socket: + with self.rpc_get_data_socket() as socket: response = await do_rpc_call(socket, request) # Use existing socket connection. @@ -413,7 +371,7 @@ async def generate( self.output_queues[request_id] = queue finished = False try: - with self.to_proxy_socket() as socket: + with self.rpc_get_data_socket() as socket: # Send RPCGenerateRequest to the RPCServer. await socket.send_multipart((cloudpickle.dumps( RPCGenerateRequest( diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py index c799ab8a35fa9..4b1ba19327b70 100644 --- a/vllm/entrypoints/openai/rpc/server.py +++ b/vllm/entrypoints/openai/rpc/server.py @@ -40,9 +40,9 @@ def __init__(self, async_engine_args: AsyncEngineArgs, self.context = zmq.asyncio.Context() # Init socket. - self.socket: Socket = self.context.socket(zmq.constants.DEALER) + self.socket: Socket = self.context.socket(zmq.constants.ROUTER) self.socket.set_hwm(VLLM_RPC_ZMQ_HWM) - self.socket.connect(rpc_path) + self.socket.bind(rpc_path) def cleanup(self): """Cleanup all resources.""" From 14f36373e0b549c55e2caa281ffb5e22904545da Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 2 Sep 2024 20:52:21 +0000 Subject: [PATCH 05/29] stash --- benchmarks/benchmark_throughput_async.py | 7 +- examples/openai_completion_client.py | 6 +- vllm/engine/async_llm_engine.py | 1 + vllm/entrypoints/openai/api_server.py | 16 ++- vllm/entrypoints/openai/rpc/__init__.py | 1 + vllm/entrypoints/openai/rpc/client.py | 145 ++++++++++++----------- vllm/utils.py | 1 - 7 files changed, 95 insertions(+), 82 deletions(-) diff --git a/benchmarks/benchmark_throughput_async.py b/benchmarks/benchmark_throughput_async.py index ec4351cebc29d..54eed0f4de783 100644 --- a/benchmarks/benchmark_throughput_async.py +++ b/benchmarks/benchmark_throughput_async.py @@ -1,5 +1,6 @@ """Benchmark offline inference throughput.""" import argparse +import asyncio import json import random import time @@ -120,7 +121,7 @@ async def run_vllm( disable_log_requests=True, ) - decoupled = False + decoupled = True async with build_async_engine_client_from_engine_args(engine_args, not decoupled) as llm: @@ -143,15 +144,15 @@ async def run_vllm( generators = [] start = time.perf_counter() for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)): + # generator = await llm.generate(prompt, sp, request_id=f"test{i}") generator = llm.generate(prompt, sp, request_id=f"test{i}") - generators.append(generator) + generators.append(generator) all_gens = merge_async_iterators(*generators) async for i, res in all_gens: pass end = time.perf_counter() return end - start - def run_hf( requests: List[Tuple[str, int, int]], model: str, diff --git a/examples/openai_completion_client.py b/examples/openai_completion_client.py index 58519f978d340..13f98d3220366 100644 --- a/examples/openai_completion_client.py +++ b/examples/openai_completion_client.py @@ -14,14 +14,12 @@ model = models.data[0].id # Completion API -stream = False +stream = True completion = client.completions.create( model=model, prompt="A robot may not injure a human being", - echo=False, - n=2, stream=stream, - logprobs=3) + max_tokens=1000) print("Completion results:") if stream: diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 
75cc637ca6a13..e4bc40150af11 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -1104,6 +1104,7 @@ async def generate( trace_headers=trace_headers, prompt_adapter_request=prompt_adapter_request, ) + return maybe_generator if maybe_generator is None or not LLMEngine.DO_VALIDATE_OUTPUT: return maybe_generator diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index e99e4bd951089..34daf5bcb35fc 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -39,7 +39,8 @@ TokenizeResponse) # yapf: enable from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient -from vllm.entrypoints.openai.rpc.server import run_rpc_server +# from vllm.entrypoints.openai.rpc.server import run_rpc_server +from vllm.engine.llm_engine2 import run_rpc_server from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding @@ -84,8 +85,9 @@ async def _force_log(): while True: await asyncio.sleep(10) await async_engine_client.do_log_stats() - - if not engine_args.disable_log_stats: + + # if not engine_args.disable_log_stats: + if False: task = asyncio.create_task(_force_log()) _running_tasks.add(task) task.add_done_callback(_running_tasks.remove) @@ -169,9 +171,11 @@ async def build_async_engine_client_from_engine_args( context = multiprocessing.get_context("spawn") # the current process might have CUDA context, # so we need to spawn a new process - rpc_server_process = context.Process( - target=run_rpc_server, - args=(engine_args, UsageContext.OPENAI_API_SERVER, rpc_path)) + # rpc_server_process = context.Process( + # target=run_rpc_server, + # args=(engine_args, UsageContext.OPENAI_API_SERVER, rpc_path)) + + rpc_server_process = context.Process(target=run_rpc_server, args=(engine_args,)) rpc_server_process.start() logger.info("Started engine process with PID %d", rpc_server_process.pid) diff --git a/vllm/entrypoints/openai/rpc/__init__.py b/vllm/entrypoints/openai/rpc/__init__.py index c4cce036281ae..4bf24bdc37f46 100644 --- a/vllm/entrypoints/openai/rpc/__init__.py +++ b/vllm/entrypoints/openai/rpc/__init__.py @@ -49,6 +49,7 @@ class RPCUtilityRequest(Enum): IS_TRACING_ENABLED = 9 START_PROFILE = 10 STOP_PROFILE = 11 + CLIENT_IS_READY = 11 RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest, diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py index 21e8fbbefcbd3..c71f250844224 100644 --- a/vllm/entrypoints/openai/rpc/client.py +++ b/vllm/entrypoints/openai/rpc/client.py @@ -104,24 +104,47 @@ def __init__(self, rpc_path: str): self._data_timeout = VLLM_RPC_GET_DATA_TIMEOUT_MS self._errored = False - # IPC connection to RPC Server (uses unix sockets). 
- self.socket: Socket = self.context.socket(zmq.constants.DEALER) - self.socket.set_hwm(VLLM_RPC_ZMQ_HWM) - self.socket.connect(rpc_path) - self.rpc_path = rpc_path + self.new_req_socket: Socket = self.context.socket(zmq.constants.PUSH) + self.new_req_socket.connect("ipc:///tmp/new_req_socket") + + self.output_socket: Socket = self.context.socket(zmq.constants.PULL) + self.output_socket.connect("ipc:///tmp/output_socket") + + # self.data_socket: Socket = self.context.socket(zmq.constants.DEALER) + # self.data_socket.connect("ipc:///tmp/data_socket") self.limit_concurrency = None self.output_queues: Dict[str, asyncio.Queue] = {} self.output_handler = asyncio.create_task(self.run_output_handler()) + @contextmanager + def get_data_socket(self) -> Iterator[Socket]: + # Connect to the RPCServer via the proxy. + + # Raise a sensible error if the client was already closed. + # This can happen if a server shutdown is triggered but some coroutines + # are still running requests. + # There should not be a race condition with this check because we don't + # yield to the event loop between here and opening the socket. + if self.context.closed: + raise RPCClientClosedError("The ZMQ client has already shut down") + + # Note that we use DEALER to enable asynchronous communication + # to enable streaming. + socket = self.context.socket(zmq.constants.DEALER) + try: + socket.connect("ipc:///tmp/data_socket") + yield socket + finally: + socket.close(linger=0) async def run_output_handler(self): - await self.socket.send_multipart( - (cloudpickle.dumps(RPCOutputStreamRequest()), )) + # await self.socket.send_multipart( + # (cloudpickle.dumps(RPCOutputStreamRequest()), )) # Stream back the results from the RPC Server. while True: - message: Frame = await self.socket.recv(copy=False) + message: Frame = await self.output_socket.recv(copy=False) request_outputs = pickle.loads(message.buffer) for output in request_outputs: @@ -155,69 +178,50 @@ async def setup(self): enable_lora=bool(await self._get_lora_config_rpc()), ) + await self._notify_ready() + def close(self): """Destroy the ZeroMQ Context.""" # Close all sockets associated with this context and # then terminate the context. - self.socket.close() - self.context.destroy() - - @contextmanager - def rpc_get_data_socket(self) -> Iterator[Socket]: - # Connect to the RPCServer via the proxy. + self.context.destroy(linger=0) - # Raise a sensible error if the client was already closed. - # This can happen if a server shutdown is triggered but some coroutines - # are still running requests. - # There should not be a race condition with this check because we don't - # yield to the event loop between here and opening the socket. - if self.context.closed: - raise RPCClientClosedError("The ZMQ client has already shut down") - - # Note that we use DEALER to enable asynchronous communication - # to enable streaming. - socket = self.context.socket(zmq.constants.DEALER) - socket.set_hwm(VLLM_RPC_ZMQ_HWM) - try: - socket.connect(self.rpc_path) - yield socket - finally: - socket.close(linger=0) async def _send_get_data_rpc_request(self, request: RPCUtilityRequest, expected_type: Any, error_message: str) -> Any: """Send an RPC request that is expecting data back.""" - with self.rpc_get_data_socket() as socket: + with self.get_data_socket() as socket: # Ping RPCServer with a request. 
- await socket.send_multipart((cloudpickle.dumps(request), ), - copy=False) + await socket.send_multipart( + (cloudpickle.dumps(request), ), + copy=False) # Make sure the server responds if await socket.poll(timeout=self._data_timeout) == 0: raise TimeoutError("Server didn't reply within " - f"{self._data_timeout} ms") + f"{self._data_timeout} ms") # Await the data from the Server. frame = await socket.recv(copy=False) data = pickle.loads(frame.buffer) - if isinstance(data, Exception): - # Re-raise exceptions returned by the server - raise data - - if not isinstance(data, expected_type): - # LoRAConfig can be None. - if expected_type == LoRAConfig and data is None: - pass - elif isinstance(data, Exception): - logger.error(error_message) + if isinstance(data, Exception): + # Re-raise exceptions returned by the server raise data - else: - raise ValueError(error_message) - return data + if not isinstance(data, expected_type): + # LoRAConfig can be None. + if expected_type == LoRAConfig and data is None: + pass + elif isinstance(data, Exception): + logger.error(error_message) + raise data + else: + raise ValueError(error_message) + + return data async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE, @@ -236,12 +240,9 @@ async def do_rpc_call(socket: Socket, request: RPC_REQUEST_TYPE): frame = await socket.recv(copy=False) return pickle.loads(frame.buffer) - # Make a new socket connection. if socket is None: - with self.rpc_get_data_socket() as socket: + with self.get_data_socket() as socket: response = await do_rpc_call(socket, request) - - # Use existing socket connection. else: response = await do_rpc_call(socket, request) @@ -270,6 +271,13 @@ async def _wait_for_server_rpc(self): request=RPCUtilityRequest.IS_SERVER_READY, error_message="Unable to start RPC Server") + async def _notify_ready(self): + """Get the RPCServer that the RPCClient is ready""" + + await self._send_one_way_rpc_request( + request=RPCUtilityRequest.CLIENT_IS_READY, + error_message="Unable to notify RPC Server of client readiness") + async def _get_model_config_rpc(self) -> ModelConfig: """Get the ModelConfig object from the RPC Server""" @@ -371,21 +379,21 @@ async def generate( self.output_queues[request_id] = queue finished = False try: - with self.rpc_get_data_socket() as socket: - # Send RPCGenerateRequest to the RPCServer. - await socket.send_multipart((cloudpickle.dumps( - RPCGenerateRequest( - inputs=inputs, - sampling_params=sampling_params, - request_id=request_id, - lora_request=lora_request, - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request)), )) - - ack: Frame = await socket.recv(copy=False) - if len(ack.buffer) != 0: - exception = pickle.loads(ack.buffer) - raise exception + + # Send RPCGenerateRequest to the RPCServer. + await self.new_req_socket.send_multipart((cloudpickle.dumps( + RPCGenerateRequest( + inputs=inputs, + sampling_params=sampling_params, + request_id=request_id, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request)), )) + + # ack: Frame = await socket.recv(copy=False) + # if len(ack.buffer) != 0: + # exception = pickle.loads(ack.buffer) + # raise exception while not finished: request_output = await queue.get() @@ -395,7 +403,8 @@ async def generate( # possibly setting the `errored` property. 
if not self._errored: try: - await self.check_health(socket=socket) + # await self.check_health(socket=socket) + pass except Exception as e: self._errored = True logger.exception(repr(e)) diff --git a/vllm/utils.py b/vllm/utils.py index dab8e5fe04359..dd255684cd0a0 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -449,7 +449,6 @@ async def merge_async_iterators( It also optionally polls a provided function at least once per second to check for client cancellation. """ - # Can use anext() in python >= 3.10 awaits = { ensure_future(pair[1].__anext__()): pair From 3b8311bc70d64086fef3b034de95cf36af442eaa Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 2 Sep 2024 21:03:04 +0000 Subject: [PATCH 06/29] added mp_llm_engine --- vllm/engine/mp_llm_engine.py | 119 +++++++++++++++++++++++++++++++++++ 1 file changed, 119 insertions(+) create mode 100644 vllm/engine/mp_llm_engine.py diff --git a/vllm/engine/mp_llm_engine.py b/vllm/engine/mp_llm_engine.py new file mode 100644 index 0000000000000..cb639021cc244 --- /dev/null +++ b/vllm/engine/mp_llm_engine.py @@ -0,0 +1,119 @@ +import zmq +import cloudpickle, pickle +from vllm.logger import init_logger +from vllm import EngineArgs, LLMEngine +from vllm.entrypoints.openai.rpc import (VLLM_RPC_SUCCESS_STR, + VLLM_RPC_ZMQ_HWM, + RPCAbortRequest, + RPCGenerateRequest, + RPCOutputStreamRequest, + RPCUtilityRequest) + +logger = init_logger(__name__) + +class MPLLMEngine: + def __init__(self, engine_args) -> None: + self.engine = LLMEngine.from_engine_args(engine_args) + + self.ctx = zmq.Context() + + self.new_req_socket = self.ctx.socket(zmq.constants.PULL) + self.new_req_socket.bind("ipc:///tmp/new_req_socket") + + self.output_socket = self.ctx.socket(zmq.constants.PUSH) + self.output_socket.bind("ipc:///tmp/output_socket") + + self.data_socket = self.ctx.socket(zmq.constants.ROUTER) + self.data_socket.bind("ipc:///tmp/data_socket") + + def run(self): + logger.info("Running Startup Loop.") + self.startup_loop() + logger.info("Running Engine Loop.") + self.engine_loop() + + def startup_loop(self): + client_is_ready = False + while not client_is_ready: + identity, message = self.data_socket.recv_multipart(copy=False) + request = cloudpickle.loads(message.buffer) + if request in [ + RPCUtilityRequest.GET_MODEL_CONFIG, + RPCUtilityRequest.GET_PARALLEL_CONFIG, + RPCUtilityRequest.GET_DECODING_CONFIG, + RPCUtilityRequest.GET_SCHEDULER_CONFIG, + RPCUtilityRequest.GET_LORA_CONFIG + ]: + config = self.get_config(request) + self.data_socket.send_multipart((identity, pickle.dumps(config)), copy=False) + elif request == RPCUtilityRequest.IS_SERVER_READY: + self.data_socket.send_multipart((identity, pickle.dumps(VLLM_RPC_SUCCESS_STR)), copy=False) + elif request == RPCUtilityRequest.IS_TRACING_ENABLED: + self.data_socket.send_multipart((identity, pickle.dumps(self.engine.is_tracing_enabled())), copy=False) + elif request == RPCUtilityRequest.CLIENT_IS_READY: + self.data_socket.send_multipart((identity, pickle.dumps(VLLM_RPC_SUCCESS_STR)), copy=False) + client_is_ready = True + self.data_socket.close() + del self.data_socket + + def engine_loop(self): + has_requests_in_progress = False + while True: + if not has_requests_in_progress: + self.wait_for_new_requests() + has_requests_in_progress = self.engine_step() + + def engine_step(self): + self.add_new_requests() + request_outputs = self.engine.step() + self.send_request_outputs(request_outputs) + + all_finished = True + for request_output in request_outputs: + finished = request_output.finished + if not 
finished: + all_finished = False + break + + return not all_finished + + def send_request_outputs(self, request_outputs): + self.output_socket.send_multipart( + (pickle.dumps(request_outputs),), copy=False) + + def add_new_requests(self): + while self.new_req_socket.poll(timeout=0) != 0: + message = self.new_req_socket.recv(copy=False) + generate_rpc_request = pickle.loads(message.buffer) + self.engine.add_request( + request_id=generate_rpc_request.request_id, + inputs=generate_rpc_request.inputs, + params=generate_rpc_request.sampling_params, + lora_request=generate_rpc_request.lora_request, + trace_headers=generate_rpc_request.trace_headers, + prompt_adapter_request=generate_rpc_request.prompt_adapter_request, + ) + + def wait_for_new_requests(self): + while self.new_req_socket.poll(timeout=1000) == 0: + logger.info("Waiting for new requests...") + logger.info("Found new request!") + + def get_config(self, request): + if request == RPCUtilityRequest.GET_MODEL_CONFIG: + model_config = self.engine.get_model_config() + return model_config + elif request == RPCUtilityRequest.GET_DECODING_CONFIG: + return self.engine.get_decoding_config() + elif request == RPCUtilityRequest.GET_LORA_CONFIG: + return self.engine.get_lora_config() + elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG: + return self.engine.get_scheduler_config() + elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG: + return self.engine.get_parallel_config() + else: + raise ValueError("Unknown Config Request: %s", request) + +def run_rpc_server(engine_args: EngineArgs): + engine = RPCLLMEngine(engine_args) + engine.run() From 5e2eb7449b3a414b4834c4bca616c7dd648a27e1 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 2 Sep 2024 21:05:35 +0000 Subject: [PATCH 07/29] fixed --- vllm/entrypoints/openai/api_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 34daf5bcb35fc..cdba0a0ecc9a1 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -40,7 +40,7 @@ # yapf: enable from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient # from vllm.entrypoints.openai.rpc.server import run_rpc_server -from vllm.engine.llm_engine2 import run_rpc_server +from vllm.engine.mp_llm_engine import run_rpc_server from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding From aa62f2e4137f6dcb0bea7b250923d4752ea95028 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 2 Sep 2024 21:06:37 +0000 Subject: [PATCH 08/29] format --- vllm/engine/mp_llm_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/engine/mp_llm_engine.py b/vllm/engine/mp_llm_engine.py index cb639021cc244..ff376208ed023 100644 --- a/vllm/engine/mp_llm_engine.py +++ b/vllm/engine/mp_llm_engine.py @@ -115,5 +115,5 @@ def get_config(self, request): raise ValueError("Unknown Config Request: %s", request) def run_rpc_server(engine_args: EngineArgs): - engine = RPCLLMEngine(engine_args) + engine = MPLLMEngine(engine_args) engine.run() From 863081bc1320244f4c172b5312791d4a38c07e19 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 2 Sep 2024 21:19:20 +0000 Subject: [PATCH 09/29] cleanup --- vllm/entrypoints/openai/rpc/server.py | 260 -------------------------- 1 file changed, 260 deletions(-) delete 
mode 100644 vllm/entrypoints/openai/rpc/server.py diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py deleted file mode 100644 index 4b1ba19327b70..0000000000000 --- a/vllm/entrypoints/openai/rpc/server.py +++ /dev/null @@ -1,260 +0,0 @@ -import asyncio -import pickle -import signal -from typing import Any, Coroutine, Union - -import cloudpickle -import uvloop -import zmq -import zmq.asyncio -from typing_extensions import Never -from zmq import Frame # type: ignore[attr-defined] -from zmq.asyncio import Socket - -from vllm import AsyncEngineArgs, AsyncLLMEngine -from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig) -from vllm.entrypoints.openai.rpc import (VLLM_RPC_SUCCESS_STR, - VLLM_RPC_ZMQ_HWM, RPCAbortRequest, - RPCGenerateRequest, - RPCOutputStreamRequest, - RPCUtilityRequest) -from vllm.logger import init_logger -from vllm.usage.usage_lib import UsageContext - -logger = init_logger(__name__) - -CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig, - SchedulerConfig, LoRAConfig] - - -class AsyncEngineRPCServer: - - def __init__(self, async_engine_args: AsyncEngineArgs, - usage_context: UsageContext, rpc_path: str): - # Initialize engine first. - self.engine = AsyncLLMEngine.from_engine_args( - async_engine_args, usage_context=usage_context) - - # Initialize context. - self.context = zmq.asyncio.Context() - - # Init socket. - self.socket: Socket = self.context.socket(zmq.constants.ROUTER) - self.socket.set_hwm(VLLM_RPC_ZMQ_HWM) - self.socket.bind(rpc_path) - - def cleanup(self): - """Cleanup all resources.""" - self.socket.close() - self.context.destroy() - self.engine.shutdown_background_loop() - # Clear the engine reference so that it can be GC'ed. - del self.engine - - async def get_config(self, identity, request): - try: - config: CONFIG_TYPE - if request == RPCUtilityRequest.GET_MODEL_CONFIG: - config = await self.engine.get_model_config() - elif request == RPCUtilityRequest.GET_DECODING_CONFIG: - config = await self.engine.get_decoding_config() - elif request == RPCUtilityRequest.GET_LORA_CONFIG: - config = await self.engine.get_lora_config() - elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG: - config = await self.engine.get_scheduler_config() - elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG: - config = await self.engine.get_parallel_config() - else: - raise ValueError("Unknown Config Request: %s", request) - - await self.socket.send_multipart((identity, pickle.dumps(config)), - copy=False) - - except Exception as e: - await self.socket.send_multipart((identity, pickle.dumps(e)), - copy=False) - - async def is_tracing_enabled(self, identity): - """Send the is_tracing_enabled flag""" - tracing_flag = await self.engine.is_tracing_enabled() - - await self.socket.send_multipart( - (identity, pickle.dumps(tracing_flag))) - - async def do_log_stats(self, identity): - """Log stats and confirm success.""" - await self.engine.do_log_stats() - - await self.socket.send_multipart( - (identity, pickle.dumps(VLLM_RPC_SUCCESS_STR))) - - async def is_server_ready(self, identity): - """Notify the client that we are ready.""" - await self.socket.send_multipart( - (identity, pickle.dumps(VLLM_RPC_SUCCESS_STR))) - - async def abort(self, identity, request: RPCAbortRequest): - """Abort request and notify the client of success.""" - try: - # Abort the request in the llm engine. 
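Both the AsyncEngineRPCServer being deleted in this patch and the new MPLLMEngine startup loop answer requests on a ZMQ ROUTER socket, which prepends an identity frame to every incoming message; echoing that frame back on the reply is what routes it to the correct DEALER peer. A minimal, self-contained sketch of that identity-envelope round trip (the ipc path and payload strings below are placeholders for illustration, not the paths or message types used by vLLM):

    import pickle

    import zmq

    IPC_PATH = "ipc:///tmp/identity_envelope_demo"  # placeholder path

    ctx = zmq.Context()

    # Server side: ROUTER receives (identity, payload) pairs and must send
    # the identity frame back so ZMQ can route the reply.
    router = ctx.socket(zmq.ROUTER)
    router.bind(IPC_PATH)

    # Client side: DEALER sends only the payload; ZMQ adds the identity.
    dealer = ctx.socket(zmq.DEALER)
    dealer.connect(IPC_PATH)
    dealer.send(pickle.dumps("IS_SERVER_READY"))

    identity, message = router.recv_multipart()
    print(pickle.loads(message))                     # -> IS_SERVER_READY
    router.send_multipart((identity, pickle.dumps("SUCCESS")))

    print(pickle.loads(dealer.recv()))               # -> SUCCESS

    dealer.close()
    router.close()
    ctx.term()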
- await self.engine.abort(request.request_id) - result: Union[str, Exception] = VLLM_RPC_SUCCESS_STR - except Exception as e: - result = e - await self.socket.send_multipart((identity, pickle.dumps(result))) - - async def stream_outputs(self, identity): - # This runs indefinitely - #TODO handle shutdown - async for outputs in self.engine.global_output_generator(): - # Trim down contents to be equivalent to deltas (other PR for this) - # for output in outputs: - # output.prompt = None - # output.prompt_token_ids = None - # output.prompt_logprobs = None - # for o in output.outputs: - # o.token_ids = [0] - # o.text = " word" - - await self.socket.send_multipart((identity, pickle.dumps(outputs)), - copy=False) - - async def generate(self, identity, generate_request: RPCGenerateRequest): - # Empty result to indicate success - result = b'' - try: - await self.engine.generate( - generate_request.inputs, - sampling_params=generate_request.sampling_params, - request_id=generate_request.request_id, - lora_request=generate_request.lora_request, - trace_headers=generate_request.trace_headers, - prompt_adapter_request=generate_request.prompt_adapter_request) - - except Exception as e: - result = pickle.dumps(e) - - await self.socket.send_multipart((identity, result), copy=False) - - async def check_health(self, identity): - try: - await self.engine.check_health() - await self.socket.send_multipart( - (identity, pickle.dumps(VLLM_RPC_SUCCESS_STR))) - - except Exception as e: - await self.socket.send_multipart((identity, pickle.dumps(e)), - copy=False) - - async def start_profile(self, identity): - logger.info("Starting profiler...") - await self.engine.start_profile() - logger.info("Profiler started.") - - await self.socket.send_multipart(( - identity, - pickle.dumps(VLLM_RPC_SUCCESS_STR), - )) - - async def stop_profile(self, identity): - logger.info("Stopping profiler...") - await self.engine.stop_profile() - logger.info("Profiler stopped.") - - await self.socket.send_multipart(( - identity, - pickle.dumps(VLLM_RPC_SUCCESS_STR), - )) - - def _make_handler_coro(self, identity, - message: Frame) -> Coroutine[Any, Any, Never]: - """Route the zmq message to the handler coroutine.""" - - request = cloudpickle.loads(message.buffer) - - if isinstance(request, RPCOutputStreamRequest): - return self.stream_outputs(identity) - - if isinstance(request, RPCGenerateRequest): - return self.generate(identity, request) - - elif isinstance(request, RPCAbortRequest): - return self.abort(identity, request) - - elif isinstance(request, RPCUtilityRequest): - if request in [ - RPCUtilityRequest.GET_MODEL_CONFIG, - RPCUtilityRequest.GET_PARALLEL_CONFIG, - RPCUtilityRequest.GET_DECODING_CONFIG, - RPCUtilityRequest.GET_SCHEDULER_CONFIG, - RPCUtilityRequest.GET_LORA_CONFIG - ]: - return self.get_config(identity, request) - elif request == RPCUtilityRequest.DO_LOG_STATS: - return self.do_log_stats(identity) - elif request == RPCUtilityRequest.IS_SERVER_READY: - return self.is_server_ready(identity) - elif request == RPCUtilityRequest.IS_SERVER_HEALTHY: - return self.check_health(identity) - elif request == RPCUtilityRequest.IS_TRACING_ENABLED: - return self.is_tracing_enabled(identity) - elif request == RPCUtilityRequest.START_PROFILE: - return self.start_profile(identity) - elif request == RPCUtilityRequest.STOP_PROFILE: - return self.stop_profile(identity) - else: - raise ValueError(f"Unknown RPCUtilityRequest type: {request}") - - else: - raise ValueError(f"Unknown RPCRequest type: {request}") - - async def 
run_server_loop(self): - """Inner RPC Server Loop""" - - running_tasks = set() - while True: - # Wait for a request. - identity, message = await self.socket.recv_multipart(copy=False) - - # Process the request async. - task = asyncio.create_task( - self._make_handler_coro(identity, message)) - - # We need to keep around a strong reference to the task, - # to avoid the task disappearing mid-execution as running tasks - # can be GC'ed. Below is a common "fire-and-forget" tasks - # https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task - running_tasks.add(task) - task.add_done_callback(running_tasks.discard) - - -async def run_server(server: AsyncEngineRPCServer): - # import pyinstrument - - # with pyinstrument.Profiler(async_mode="disabled") as prof: - # Put the server task into the asyncio loop. - loop = asyncio.get_running_loop() - server_task = loop.create_task(server.run_server_loop()) - - # Interruption handling. - def signal_handler() -> None: - # Kill the server on interrupt / terminate - server_task.cancel() - - loop.add_signal_handler(signal.SIGINT, signal_handler) - loop.add_signal_handler(signal.SIGTERM, signal_handler) - - try: - await server_task - except asyncio.CancelledError: - logger.info("vLLM ZMQ RPC Server was interrupted.") - finally: - # Clean up all resources. - server.cleanup() - # prof.write_html("prof-disabled.html", show_all=True) - -def run_rpc_server(async_engine_args: AsyncEngineArgs, - usage_context: UsageContext, rpc_path: str): - server = AsyncEngineRPCServer(async_engine_args, usage_context, rpc_path) - uvloop.run(run_server(server)) From 965b97a9e18a83987414ec7b6d7cb083d6a83598 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 2 Sep 2024 21:20:24 +0000 Subject: [PATCH 10/29] revert asyncllmengine --- vllm/engine/async_llm_engine.py | 214 +++++++++++++------------------- 1 file changed, 88 insertions(+), 126 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index e4bc40150af11..159281dabde4a 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -22,11 +22,12 @@ from vllm.inputs.parse import is_explicit_encoder_decoder_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.sequence import ExecuteModelRequest from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.usage.usage_lib import UsageContext from vllm.utils import print_warning_once @@ -47,6 +48,7 @@ def _log_task_completion(task: asyncio.Task, there is an exception. 
""" + exception = None try: return_value = task.result() raise AssertionError( @@ -79,7 +81,8 @@ def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None: self._queue: asyncio.Queue = asyncio.Queue() self._finished = False - def put(self, item: Union[RequestOutput, EmbeddingRequestOutput]) -> None: + def put(self, item: Union[RequestOutput, EmbeddingRequestOutput, + Exception]) -> None: if not self._finished: self._queue.put_nowait(item) @@ -121,11 +124,10 @@ def _is_raisable(value: Any): class RequestTracker: """Synchronous abstraction for tracking requests.""" - def __init__(self, per_request_streams: bool = True) -> None: - self._per_request_streams = per_request_streams + def __init__(self) -> None: self._request_streams: Dict[str, AsyncStream] = {} self._aborted_requests: asyncio.Queue[str] = asyncio.Queue() - self._new_requests: asyncio.Queue[Tuple[Optional[AsyncStream], + self._new_requests: asyncio.Queue[Tuple[AsyncStream, dict]] = asyncio.Queue() self.new_requests_event = asyncio.Event() @@ -159,16 +161,14 @@ def process_request_output(self, if finished: stream = self._request_streams.pop(request_id, None) - if stream is not None: - stream.finish() else: stream = self._request_streams.get(request_id) # Guard against a KeyError which can occur if the request was aborted # while the output was generated - # if stream is not None: - # stream.put(request_output) - # if finished: - # stream.finish() + if stream is not None: + stream.put(request_output) + if finished: + stream.finish() if verbose and finished: logger.info("Finished request %s.", request_id) @@ -187,15 +187,14 @@ def add_request(self, request_id: str, *, verbose: bool = False, - **engine_add_request_kwargs) -> Optional[AsyncStream]: + **engine_add_request_kwargs) -> AsyncStream: """Add a request to be sent to the engine on the next background loop iteration.""" if request_id in self._request_streams: raise KeyError(f"Request {request_id} already exists.") abort_request = partial(self.abort_request, verbose=verbose) - stream = AsyncStream(request_id, abort_request) \ - if self._per_request_streams else None + stream = AsyncStream(request_id, abort_request) self._new_requests.put_nowait((stream, { "request_id": request_id, **engine_add_request_kwargs @@ -236,15 +235,13 @@ def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]: while not self._new_requests.empty(): stream, new_request = self._new_requests.get_nowait() - request_id = new_request["request_id"] + request_id = stream.request_id if request_id in finished_requests: # The request has already been aborted. - if stream is not None: - stream.finish(asyncio.CancelledError) + stream.finish(asyncio.CancelledError) finished_requests.discard(request_id) else: - if stream is not None: - self._request_streams[request_id] = stream + self._request_streams[request_id] = stream new_requests.append(new_request) return new_requests, finished_requests @@ -283,6 +280,10 @@ async def step_async( scheduler_outputs = cached_outputs.scheduler_outputs allow_async_output_proc = cached_outputs.allow_async_output_proc + # Detect async + multi-step + use_async_and_multi_step = (self.scheduler_config.is_multi_step + and allow_async_output_proc) + ctx = self.scheduler_contexts[virtual_engine] # skip the scheduler if there are any remaining steps in the seq groups. 
@@ -293,17 +294,27 @@ async def step_async( # Clear outputs on scheduler iteration start ctx.request_outputs.clear() + # Schedule iteration (seq_group_metadata_list, scheduler_outputs, allow_async_output_proc ) = self.scheduler[virtual_engine].schedule() - # If current scheduler iteration has no async postprocessor, - # then we need first to drain the pending async postprocessor - # before moving forward + # Detect async + multi-step + use_async_and_multi_step = (self.scheduler_config.is_multi_step + and allow_async_output_proc) + + # Maybe switch from async mode to sync mode if not allow_async_output_proc and len(ctx.output_queue) > 0: self._process_model_outputs(virtual_engine=virtual_engine, is_async=True) + # For async + multi-step, init the queue + if use_async_and_multi_step: + assert len(ctx.output_queue) == 0 + assert seq_group_metadata_list is not None + ctx.output_queue.append( + (None, seq_group_metadata_list, scheduler_outputs)) + if (self.scheduler_config.is_multi_step and scheduler_outputs.num_lookahead_slots > 0): # cache the scheduler outputs for the next iteration if we have @@ -315,9 +326,6 @@ async def step_async( assert seq_group_metadata_list is not None assert scheduler_outputs is not None - assert not (self.scheduler_config.is_multi_step and \ - allow_async_output_proc) - if not scheduler_outputs.is_empty(): finished_requests_ids = self.scheduler[ virtual_engine].get_and_reset_finished_requests_ids() @@ -343,8 +351,13 @@ async def step_async( last_sampled_token_ids=last_sampled_token_ids) if allow_async_output_proc: - execute_model_req.async_callback = self.async_callback[ - virtual_engine] + async_callback = self.async_callback_multi_step[ + virtual_engine] if use_async_and_multi_step \ + else self.async_callback[virtual_engine] + + execute_model_req.async_callback = async_callback + execute_model_req.use_async_and_multi_step = \ + use_async_and_multi_step # Execute the model. output = await self.model_executor.execute_model_async( @@ -354,7 +367,7 @@ async def step_async( if self.scheduler_config.is_multi_step: self._update_cached_scheduler_output(virtual_engine, output) else: - if len(ctx.output_queue) > 0: + if not use_async_and_multi_step and len(ctx.output_queue) > 0: assert not self.scheduler_config.is_multi_step self._process_model_outputs(virtual_engine=virtual_engine, is_async=True) @@ -366,22 +379,25 @@ async def step_async( seq_group.finish_step() if not self._has_remaining_steps(seq_group_metadata_list): - # clear the cache if we have finished all the steps + # Clear the cache if we have finished all the steps if self.scheduler_config.is_multi_step: self.cached_scheduler_outputs[ virtual_engine] = SchedulerOutputState() - # Cache results in engine - ctx.output_queue.append( - (output, seq_group_metadata_list, scheduler_outputs)) + if use_async_and_multi_step: + # For async + multi-step, clear the queue + ctx.output_queue.clear() + else: + ctx.output_queue.append( + (output, seq_group_metadata_list, scheduler_outputs)) - if output and allow_async_output_proc: - assert len( - output - ) == 1, "Multi step decoding does not work with async output processing." # noqa: E501 - self._advance_to_next_step( - output[0], seq_group_metadata_list, - scheduler_outputs.scheduled_seq_groups) + if output and allow_async_output_proc: + assert len( + output + ) == 1, "Multi step decoding does not work with async output processing." 
# noqa: E501 + self._advance_to_next_step( + output[0], seq_group_metadata_list, + scheduler_outputs.scheduled_seq_groups) if not allow_async_output_proc: self._process_model_outputs(virtual_engine=virtual_engine, @@ -394,7 +410,11 @@ async def step_async( self.do_tracing(scheduler_outputs) else: - ctx.request_outputs = [] + # Multi-step case + if use_async_and_multi_step: + return [] + else: + ctx.request_outputs = [] if not self.has_unfinished_requests(): # Drain async postprocessor (if exists) @@ -643,34 +663,7 @@ def __init__(self, self._errored_with: Optional[BaseException] = None # Lazy initialized fields - self._request_tracker: RequestTracker = None # type: ignore[assignment] - - self._global_queue: Optional[asyncio.Queue] = None - - async def global_output_generator( - self - ) -> AsyncGenerator[List[Union[RequestOutput, EmbeddingRequestOutput, - Tuple[str, BaseException]]], None]: - """Returns a single generator that streams outputs from all - requests. - - Must be called at most once prior to processing any requests, - and if used, generate() will return None rather than a per-request - stream. - """ - if self._global_queue is not None: - raise RuntimeError( - "global_output_generator can only be called once") - if self._request_tracker is not None: - raise RuntimeError( - "global_output_generator must be called before processing " - "any requests") - - self._global_queue = asyncio.Queue() - - # This runs until the engine is shut down - while True: - yield await self._global_queue.get() + self._request_tracker: RequestTracker @classmethod def _get_executor_cls( @@ -794,11 +787,6 @@ def set_errored(self, exc: Exception) -> None: def _error_callback(self, exc: Exception) -> None: self.set_errored(exc) self._request_tracker.propagate_exception(exc) - if self._global_queue is not None: - #TODO clean this up - for request_id in tuple( - self._request_tracker._request_streams.keys()): - self._global_queue.put_nowait((request_id, exc)) async def get_tokenizer( self, @@ -819,8 +807,7 @@ def start_background_loop(self) -> None: if self.is_running: raise RuntimeError("Background loop is already running.") # Initialize the RequestTracker here so it uses the right event loop. - per_request_streams = self._global_queue is None - self._request_tracker = RequestTracker(per_request_streams) + self._request_tracker = RequestTracker() self._background_loop_unshielded = asyncio.get_event_loop( ).create_task(self.run_engine_loop()) @@ -881,14 +868,11 @@ async def engine_step(self, virtual_engine: int) -> bool: await self.engine.add_request_async(**new_request) except ValueError as e: # TODO: use a vLLM specific error for failed validation - request_id = new_request["request_id"] self._request_tracker.process_exception( - request_id, + new_request["request_id"], e, verbose=self.log_requests, ) - if self._global_queue is not None: - self._global_queue.put_nowait((request_id, e)) if aborted_requests: await self._engine_abort(aborted_requests) @@ -899,18 +883,13 @@ async def engine_step(self, virtual_engine: int) -> bool: request_outputs = await self.engine.step_async(virtual_engine) # Put the outputs into the corresponding streams. 
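The revert above drops the global output queue, and the engine_step hunk that follows goes back to pushing every RequestOutput into its own per-request stream via the RequestTracker. A stripped-down sketch of that queue-backed stream pattern, simplified from the AsyncStream class in this file (exception propagation and cancellation are omitted, and the names are illustrative):

    import asyncio
    from typing import Any, AsyncGenerator


    class MiniAsyncStream:
        """Per-request stream: the engine put()s outputs, the caller iterates."""

        _FINISHED = object()  # sentinel marking the end of the stream

        def __init__(self, request_id: str) -> None:
            self.request_id = request_id
            self._queue: asyncio.Queue = asyncio.Queue()
            self._finished = False

        def put(self, item: Any) -> None:
            if not self._finished:
                self._queue.put_nowait(item)

        def finish(self) -> None:
            self._finished = True
            self._queue.put_nowait(self._FINISHED)

        async def generator(self) -> AsyncGenerator[Any, None]:
            while True:
                item = await self._queue.get()
                if item is self._FINISHED:
                    return
                yield item


    async def main() -> None:
        stream = MiniAsyncStream("req-0")

        async def fake_engine() -> None:
            for token in ("Hello", " world"):
                stream.put(token)
                await asyncio.sleep(0)
            stream.finish()

        producer = asyncio.create_task(fake_engine())
        async for output in stream.generator():
            print(output)
        await producer


    asyncio.run(main())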
- all_finished = True + finished = True for request_output in request_outputs: - finished = request_output.finished - if finished or self._global_queue is None: - self._request_tracker.process_request_output( - request_output, verbose=self.log_requests) - all_finished = all_finished and finished - - if self._global_queue is not None: - self._global_queue.put_nowait(request_outputs) + self._request_tracker.process_request_output( + request_output, verbose=self.log_requests) + finished = finished and request_output.finished - return not all_finished + return not finished async def _engine_abort(self, request_ids: Iterable[str]): if self.engine_use_ray: @@ -995,9 +974,8 @@ async def add_request( arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> Optional[AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], - None]]: + prompt_adapter_request: Optional[PromptAdapterRequest] = None + ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]: if not self.is_running: if self.start_engine_loop: self.start_background_loop() @@ -1018,7 +996,7 @@ async def add_request( trace_headers=trace_headers, prompt_adapter_request=prompt_adapter_request) - return stream.generator() if stream is not None else None + return stream.generator() async def generate( self, @@ -1028,7 +1006,7 @@ async def generate( lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None - ) -> Optional[AsyncGenerator[RequestOutput, None]]: + ) -> AsyncGenerator[RequestOutput, None]: """Generate outputs for a request. Generate outputs for a request. This method is a coroutine. It adds the @@ -1050,9 +1028,6 @@ async def generate( The output `RequestOutput` objects from the LLMEngine for the request. - Unless a global output generator is being used, in which case - this methods will return None. - Details: - If the engine is not running, start the background loop, which iteratively invokes @@ -1096,23 +1071,15 @@ async def generate( >>> # Process and return the final output >>> ... """ - maybe_generator = await self.add_request( - request_id, - inputs, - sampling_params, - lora_request=lora_request, - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request, - ) - return maybe_generator - if maybe_generator is None or not LLMEngine.DO_VALIDATE_OUTPUT: - return maybe_generator - - async def validating_generator(): - async for output in maybe_generator: - yield LLMEngine.validate_output(output, RequestOutput) - - return validating_generator() + async for output in await self.add_request( + request_id, + inputs, + sampling_params, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + ): + yield LLMEngine.validate_output(output, RequestOutput) async def encode( self, @@ -1182,15 +1149,13 @@ async def encode( >>> # Process and return the final output >>> ... 
""" - generator = await self.add_request( - request_id, - inputs, - pooling_params, - lora_request=lora_request, - trace_headers=trace_headers, - ) - assert generator is not None - async for output in generator: + async for output in await self.add_request( + request_id, + inputs, + pooling_params, + lora_request=lora_request, + trace_headers=trace_headers, + ): yield LLMEngine.validate_output(output, EmbeddingRequestOutput) async def abort(self, request_id: str) -> None: @@ -1224,9 +1189,6 @@ def _abort(self, request_id: str) -> None: exception=asyncio.CancelledError, verbose=self.log_requests) - if self._global_queue is not None: - self._global_queue.put_nowait((request_id, asyncio.CancelledError)) - async def get_model_config(self) -> ModelConfig: """Get the model configuration of the vLLM engine.""" if self.engine_use_ray: From 8fd72f69ed8476c585b53b9d79972deb1084eda4 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 2 Sep 2024 21:23:39 +0000 Subject: [PATCH 11/29] fix nit --- vllm/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/utils.py b/vllm/utils.py index dd255684cd0a0..dab8e5fe04359 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -449,6 +449,7 @@ async def merge_async_iterators( It also optionally polls a provided function at least once per second to check for client cancellation. """ + # Can use anext() in python >= 3.10 awaits = { ensure_future(pair[1].__anext__()): pair From ddeb7c672f0b88d34c575d6fdb072cb709f6bc98 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 2 Sep 2024 21:24:38 +0000 Subject: [PATCH 12/29] format --- vllm/engine/async_llm_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 159281dabde4a..6c7e2fdc7a6d9 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -1281,4 +1281,4 @@ async def start_profile(self) -> None: self.engine.model_executor._run_workers("start_profile") async def stop_profile(self) -> None: - self.engine.model_executor._run_workers("stop_profile") + self.engine.model_executor._run_workers("stop_profile") \ No newline at end of file From 4b111e4eed882c6489493967009c62b922f44ee1 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 2 Sep 2024 21:25:18 +0000 Subject: [PATCH 13/29] clean --- vllm/engine/async_llm_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 6c7e2fdc7a6d9..159281dabde4a 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -1281,4 +1281,4 @@ async def start_profile(self) -> None: self.engine.model_executor._run_workers("start_profile") async def stop_profile(self) -> None: - self.engine.model_executor._run_workers("stop_profile") \ No newline at end of file + self.engine.model_executor._run_workers("stop_profile") From a5ffd2c3ea387103a2bf733e24930a565c6e66ec Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 2 Sep 2024 21:26:55 +0000 Subject: [PATCH 14/29] fix --- vllm/entrypoints/openai/rpc/__init__.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/vllm/entrypoints/openai/rpc/__init__.py b/vllm/entrypoints/openai/rpc/__init__.py index 4bf24bdc37f46..a99b6edcc65e9 100644 --- a/vllm/entrypoints/openai/rpc/__init__.py +++ b/vllm/entrypoints/openai/rpc/__init__.py @@ -17,11 +17,6 @@ VLLM_RPC_ZMQ_HWM = 0 -@dataclass -class RPCOutputStreamRequest: - pass - - @dataclass class 
RPCGenerateRequest: inputs: PromptInputs From 139587264dadefe503fa51a0b48242b1c1267c90 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 2 Sep 2024 21:57:24 +0000 Subject: [PATCH 15/29] stash --- vllm/engine/mp_llm_engine.py | 26 +++++--------------------- vllm/entrypoints/openai/rpc/client.py | 4 +--- 2 files changed, 6 insertions(+), 24 deletions(-) diff --git a/vllm/engine/mp_llm_engine.py b/vllm/engine/mp_llm_engine.py index ff376208ed023..4c1ede7cedff1 100644 --- a/vllm/engine/mp_llm_engine.py +++ b/vllm/engine/mp_llm_engine.py @@ -3,10 +3,6 @@ from vllm.logger import init_logger from vllm import EngineArgs, LLMEngine from vllm.entrypoints.openai.rpc import (VLLM_RPC_SUCCESS_STR, - VLLM_RPC_ZMQ_HWM, - RPCAbortRequest, - RPCGenerateRequest, - RPCOutputStreamRequest, RPCUtilityRequest) logger = init_logger(__name__) @@ -57,25 +53,13 @@ def startup_loop(self): del self.data_socket def engine_loop(self): - has_requests_in_progress = False while True: - if not has_requests_in_progress: + if not self.engine.has_unfinished_requests(): self.wait_for_new_requests() - has_requests_in_progress = self.engine_step() - - def engine_step(self): - self.add_new_requests() - request_outputs = self.engine.step() - self.send_request_outputs(request_outputs) - - all_finished = True - for request_output in request_outputs: - finished = request_output.finished - if not finished: - all_finished = False - break - - return not all_finished + + self.add_new_requests() + request_outputs = self.engine.step() + self.send_request_outputs(request_outputs) def send_request_outputs(self, request_outputs): self.output_socket.send_multipart( diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py index c71f250844224..a13e70e8f94d3 100644 --- a/vllm/entrypoints/openai/rpc/client.py +++ b/vllm/entrypoints/openai/rpc/client.py @@ -15,11 +15,9 @@ ParallelConfig, SchedulerConfig) # yapf: disable from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE, - VLLM_RPC_SOCKET_LIMIT_CUTOFF, VLLM_RPC_SUCCESS_STR, - VLLM_RPC_ZMQ_HWM, RPCAbortRequest, + RPCAbortRequest, RPCGenerateRequest, - RPCOutputStreamRequest, RPCUtilityRequest) # yapf: enable from vllm.envs import VLLM_RPC_GET_DATA_TIMEOUT_MS From 938cf85bda9da1880d34f46791bff4362098c7a3 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 2 Sep 2024 22:06:03 +0000 Subject: [PATCH 16/29] move files --- .../openai => engine}/rpc/__init__.py | 6 -- .../openai => engine}/rpc/client.py | 69 +++---------------- .../rpc_llm_engine.py} | 6 +- 3 files changed, 13 insertions(+), 68 deletions(-) rename vllm/{entrypoints/openai => engine}/rpc/__init__.py (89%) rename vllm/{entrypoints/openai => engine}/rpc/client.py (82%) rename vllm/engine/{mp_llm_engine.py => rpc/rpc_llm_engine.py} (96%) diff --git a/vllm/entrypoints/openai/rpc/__init__.py b/vllm/engine/rpc/__init__.py similarity index 89% rename from vllm/entrypoints/openai/rpc/__init__.py rename to vllm/engine/rpc/__init__.py index a99b6edcc65e9..387119a1b11e5 100644 --- a/vllm/entrypoints/openai/rpc/__init__.py +++ b/vllm/engine/rpc/__init__.py @@ -10,12 +10,6 @@ # Success string used for RPC instructions. VLLM_RPC_SUCCESS_STR = "SUCCESS" -# Minimum value of ZMQ.SOCKET_LIMIT to run mp. -VLLM_RPC_SOCKET_LIMIT_CUTOFF = 2000 - -# HWM is set to Infinity. 
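[PATCH 15/29] above collapses the engine side into a single blocking loop: block until the PULL socket has at least one new request whenever the engine is idle, drain all queued requests, run one engine.step(), and push the outputs back over the PUSH socket. A toy, self-contained version of that loop with a dummy engine standing in for LLMEngine (the socket paths and FakeEngine are placeholders for illustration only):

    import pickle

    import zmq

    NEW_REQ_PATH = "ipc:///tmp/demo_new_req"  # placeholder, not vLLM's path
    OUTPUT_PATH = "ipc:///tmp/demo_output"    # placeholder, not vLLM's path


    class FakeEngine:
        """Stand-in for LLMEngine: finishes each request in a single step."""

        def __init__(self) -> None:
            self._pending = []

        def add_request(self, request_id: str, prompt: str) -> None:
            self._pending.append((request_id, prompt))

        def has_unfinished_requests(self) -> bool:
            return bool(self._pending)

        def step(self):
            outputs = [(rid, f"echo: {prompt}") for rid, prompt in self._pending]
            self._pending.clear()
            return outputs


    def engine_loop() -> None:
        ctx = zmq.Context()
        new_req_socket = ctx.socket(zmq.PULL)
        new_req_socket.bind(NEW_REQ_PATH)
        output_socket = ctx.socket(zmq.PUSH)
        output_socket.bind(OUTPUT_PATH)

        engine = FakeEngine()
        while True:
            # Block until a request arrives whenever the engine is idle.
            if not engine.has_unfinished_requests():
                while new_req_socket.poll(timeout=1000) == 0:
                    pass
            # Drain every queued request without blocking.
            while new_req_socket.poll(timeout=0) != 0:
                request_id, prompt = pickle.loads(new_req_socket.recv())
                engine.add_request(request_id, prompt)
            # One engine iteration, then ship the outputs to the client.
            output_socket.send(pickle.dumps(engine.step()))

A frontend would mirror this with a PUSH socket connected to NEW_REQ_PATH and a PULL socket connected to OUTPUT_PATH, which is the shape the client changes in this patch series move toward.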
-VLLM_RPC_ZMQ_HWM = 0 - @dataclass class RPCGenerateRequest: diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/engine/rpc/client.py similarity index 82% rename from vllm/entrypoints/openai/rpc/client.py rename to vllm/engine/rpc/client.py index a13e70e8f94d3..2bcd12c4e2dff 100644 --- a/vllm/entrypoints/openai/rpc/client.py +++ b/vllm/engine/rpc/client.py @@ -48,70 +48,21 @@ class RPCClientClosedError(Exception): class AsyncEngineRPCClient: """ - RPCClient that connects to the RPCServer wrapping AsyncLLMEngine. - - The overall design mirrors the Asynchronous Client Server Pattern - https://zguide.zeromq.org/docs/chapter3/#The-Asynchronous-Client-Server-Pattern - - On startup, the RPCClient: - - makes DEALER socket (to_rpc_server) that connects to the RPCServer - via ipc, which uses unix sockets under the hood - (https://libzmq.readthedocs.io/en/zeromq4-1/zmq_ipc.html) - - makes ROUTER socket (from_api_server) that binds to a random - inproc address, which uses memory under the hood - (https://libzmq.readthedocs.io/en/zeromq3-x/zmq_inproc.html) - - runs a proxy in a background asyncio task between - from_api_server (ROUTER, inproc) and to_rpc_server (DEALER ipc, ) - - Each request handled by the asyncio api_server calls generate(): - - make a DEALER socket that connects to from_api_server via inproc - - send a RCPGenerateRequest to the inproc socket - - background proxy forwards the request from inproc -> ipc - - RPCServer responds to the request one token at a time over ipc - - background proxy forwards the response from ipc -> inproc - - The connection looks like this: - DEALER <- inproc -> [ ROUTER | DEALER ] <- ipc -> DEALER - - Message routing is performed via identities that are managed by the - ROUTER socket. ROUTER sockets track every connection it has and - tells the caller about these. The way it tells the caller is to stick - the connection identity in front of each message received. When we - send the message via a ROUTER, we first send an identity frame. - See https://zguide.zeromq.org/docs/chapter3/#The-Extended-Reply-Envelope - for more details on connection identities. - - This proxy design enables us to use a single unix socket, which - improves performance by avoiding syscalls (~5%) and avoids resource limits - such as ulimit, which defaults to 1024 on ubuntu. - - Note: we run set_hwm(0) on each socket, which sets the HWM to inf, - which is required to avoid dropping messages under high load. - This is generally not advisable. However, since we are in control - of both sides of the connection + failure on either side is - catastrophic to the overall system health and memory profiling - suggests limited memory overhead relative to asyncio, we will - proceed for now. - - See https://zguide.zeromq.org/docs/chapter2/#High-Water-Marks - for more details on high water marks. 
+ xxx """ def __init__(self, rpc_path: str): self.context = zmq.asyncio.Context() - self._data_timeout = VLLM_RPC_GET_DATA_TIMEOUT_MS self._errored = False self.new_req_socket: Socket = self.context.socket(zmq.constants.PUSH) - self.new_req_socket.connect("ipc:///tmp/new_req_socket") + self.new_req_socket.connect(f"{rpc_path}_new_req_socket") self.output_socket: Socket = self.context.socket(zmq.constants.PULL) - self.output_socket.connect("ipc:///tmp/output_socket") + self.new_req_socket.connect(f"{rpc_path}_output_socket") - # self.data_socket: Socket = self.context.socket(zmq.constants.DEALER) - # self.data_socket.connect("ipc:///tmp/data_socket") + self.get_data_path = f"{rpc_path}_data_socket" - self.limit_concurrency = None self.output_queues: Dict[str, asyncio.Queue] = {} self.output_handler = asyncio.create_task(self.run_output_handler()) @@ -120,8 +71,8 @@ def get_data_socket(self) -> Iterator[Socket]: # Connect to the RPCServer via the proxy. # Raise a sensible error if the client was already closed. - # This can happen if a server shutdown is triggered but some coroutines - # are still running requests. + # This can happen if a server shutdown is triggered but some + # coroutines are still running requests. # There should not be a race condition with this check because we don't # yield to the event loop between here and opening the socket. if self.context.closed: @@ -197,9 +148,9 @@ async def _send_get_data_rpc_request(self, request: RPCUtilityRequest, copy=False) # Make sure the server responds - if await socket.poll(timeout=self._data_timeout) == 0: + if await socket.poll(timeout=VLLM_RPC_GET_DATA_TIMEOUT_MS) == 0: raise TimeoutError("Server didn't reply within " - f"{self._data_timeout} ms") + f"{VLLM_RPC_GET_DATA_TIMEOUT_MS} ms") # Await the data from the Server. 
frame = await socket.recv(copy=False) @@ -231,9 +182,9 @@ async def do_rpc_call(socket: Socket, request: RPC_REQUEST_TYPE): await socket.send_multipart((cloudpickle.dumps(request), )) - if await socket.poll(timeout=self._data_timeout) == 0: + if await socket.poll(timeout=VLLM_RPC_GET_DATA_TIMEOUT_MS) == 0: raise TimeoutError("Server didn't reply within " - f"{self._data_timeout} ms") + f"{VLLM_RPC_GET_DATA_TIMEOUT_MS} ms") frame = await socket.recv(copy=False) return pickle.loads(frame.buffer) diff --git a/vllm/engine/mp_llm_engine.py b/vllm/engine/rpc/rpc_llm_engine.py similarity index 96% rename from vllm/engine/mp_llm_engine.py rename to vllm/engine/rpc/rpc_llm_engine.py index 4c1ede7cedff1..4c6ec13134a04 100644 --- a/vllm/engine/mp_llm_engine.py +++ b/vllm/engine/rpc/rpc_llm_engine.py @@ -2,12 +2,12 @@ import cloudpickle, pickle from vllm.logger import init_logger from vllm import EngineArgs, LLMEngine -from vllm.entrypoints.openai.rpc import (VLLM_RPC_SUCCESS_STR, - RPCUtilityRequest) +from vllm.engine.rpc import (VLLM_RPC_SUCCESS_STR, + RPCUtilityRequest) logger = init_logger(__name__) -class MPLLMEngine: +class RPCLLMEngine: def __init__(self, engine_args) -> None: self.engine = LLMEngine.from_engine_args(engine_args) From 72d1d4233cd24e66c72e7f5a7664a95c55942df2 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 3 Sep 2024 00:42:45 +0000 Subject: [PATCH 17/29] cleanup code --- vllm/engine/async_llm_engine.py | 5 - vllm/engine/protocol.py | 4 - vllm/engine/rpc/__init__.py | 45 --- vllm/engine/rpc/client.py | 396 -------------------------- vllm/engine/rpc/rpc_llm_engine.py | 103 ------- vllm/entrypoints/launcher.py | 9 - vllm/entrypoints/openai/api_server.py | 46 ++- 7 files changed, 22 insertions(+), 586 deletions(-) delete mode 100644 vllm/engine/rpc/__init__.py delete mode 100644 vllm/engine/rpc/client.py delete mode 100644 vllm/engine/rpc/rpc_llm_engine.py diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 159281dabde4a..203f2f2748916 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -776,11 +776,6 @@ def is_stopped(self) -> bool: def errored(self) -> bool: return self._errored_with is not None - @property - def limit_concurrency(self) -> Optional[int]: - """Maximum number of concurrently running requests.""" - return None - def set_errored(self, exc: Exception) -> None: self._errored_with = exc diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 34ae79f5fa8df..de6314d532193 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -29,10 +29,6 @@ def is_stopped(self) -> bool: def errored(self) -> bool: ... - @property - def limit_concurrency(self) -> Optional[int]: - """Maximum number of concurrently running requests.""" - def generate( self, inputs: PromptInputs, diff --git a/vllm/engine/rpc/__init__.py b/vllm/engine/rpc/__init__.py deleted file mode 100644 index 387119a1b11e5..0000000000000 --- a/vllm/engine/rpc/__init__.py +++ /dev/null @@ -1,45 +0,0 @@ -from dataclasses import dataclass -from enum import Enum -from typing import Mapping, Optional, Union - -from vllm.inputs import PromptInputs -from vllm.lora.request import LoRARequest -from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import SamplingParams - -# Success string used for RPC instructions. 
-VLLM_RPC_SUCCESS_STR = "SUCCESS" - - -@dataclass -class RPCGenerateRequest: - inputs: PromptInputs - sampling_params: SamplingParams - request_id: str - lora_request: Optional[LoRARequest] = None - trace_headers: Optional[Mapping[str, str]] = None - prompt_adapter_request: Optional[PromptAdapterRequest] = None - - -@dataclass -class RPCAbortRequest: - request_id: str - - -class RPCUtilityRequest(Enum): - IS_SERVER_READY = 1 - GET_MODEL_CONFIG = 2 - GET_DECODING_CONFIG = 3 - GET_PARALLEL_CONFIG = 4 - GET_SCHEDULER_CONFIG = 5 - GET_LORA_CONFIG = 6 - DO_LOG_STATS = 7 - IS_SERVER_HEALTHY = 8 - IS_TRACING_ENABLED = 9 - START_PROFILE = 10 - STOP_PROFILE = 11 - CLIENT_IS_READY = 11 - - -RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest, - RPCUtilityRequest] diff --git a/vllm/engine/rpc/client.py b/vllm/engine/rpc/client.py deleted file mode 100644 index 2bcd12c4e2dff..0000000000000 --- a/vllm/engine/rpc/client.py +++ /dev/null @@ -1,396 +0,0 @@ -import asyncio -import pickle -from contextlib import contextmanager, suppress -from typing import (Any, AsyncGenerator, Dict, Iterator, Mapping, Optional, - Union) -from uuid import uuid4 - -import cloudpickle -import zmq -import zmq.asyncio -from zmq import Frame # type: ignore[attr-defined] -from zmq.asyncio import Socket - -from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig) -# yapf: disable -from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE, - VLLM_RPC_SUCCESS_STR, - RPCAbortRequest, - RPCGenerateRequest, - RPCUtilityRequest) -# yapf: enable -from vllm.envs import VLLM_RPC_GET_DATA_TIMEOUT_MS -from vllm.inputs import PromptInputs -from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.outputs import EmbeddingRequestOutput, RequestOutput -from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import SamplingParams -from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs - -logger = init_logger(__name__) - -# Path used for inprocess proxy. -INPROC_PROXY_PATH = f"inproc://{uuid4()}" - - -class RPCClientClosedError(Exception): - """Exception class raised when the client is used post-close. - - The client can be closed, which closes the ZMQ context. This normally - happens on server shutdown. In some cases, methods like abort and - do_log_stats will still be called and then try to open a socket, which - causes a ZMQError and creates a huge stack trace. - So, we throw this error such that we can suppress it. - """ - - -class AsyncEngineRPCClient: - """ - xxx - """ - - def __init__(self, rpc_path: str): - self.context = zmq.asyncio.Context() - self._errored = False - - self.new_req_socket: Socket = self.context.socket(zmq.constants.PUSH) - self.new_req_socket.connect(f"{rpc_path}_new_req_socket") - - self.output_socket: Socket = self.context.socket(zmq.constants.PULL) - self.new_req_socket.connect(f"{rpc_path}_output_socket") - - self.get_data_path = f"{rpc_path}_data_socket" - - self.output_queues: Dict[str, asyncio.Queue] = {} - self.output_handler = asyncio.create_task(self.run_output_handler()) - - @contextmanager - def get_data_socket(self) -> Iterator[Socket]: - # Connect to the RPCServer via the proxy. - - # Raise a sensible error if the client was already closed. - # This can happen if a server shutdown is triggered but some - # coroutines are still running requests. 
- # There should not be a race condition with this check because we don't - # yield to the event loop between here and opening the socket. - if self.context.closed: - raise RPCClientClosedError("The ZMQ client has already shut down") - - # Note that we use DEALER to enable asynchronous communication - # to enable streaming. - socket = self.context.socket(zmq.constants.DEALER) - try: - socket.connect("ipc:///tmp/data_socket") - yield socket - finally: - socket.close(linger=0) - - async def run_output_handler(self): - # await self.socket.send_multipart( - # (cloudpickle.dumps(RPCOutputStreamRequest()), )) - - # Stream back the results from the RPC Server. - while True: - message: Frame = await self.output_socket.recv(copy=False) - request_outputs = pickle.loads(message.buffer) - - for output in request_outputs: - if isinstance(output, tuple): - # Exception case - request_id, output = output - else: - request_id = output.request_id - - queue = self.output_queues.get(request_id) - if queue is not None: - queue.put_nowait(output) - - async def setup(self): - """Setup the client before it starts sending server requests.""" - - # Wait until server is ready. - await self._wait_for_server_rpc() - - # Get the configs. - self.model_config = await self._get_model_config_rpc() - self.decoding_config = await self._get_decoding_config_rpc() - self.tracing_flag = await self._is_tracing_enabled_rpc() - - # Create the tokenizer group. - # TODO: refactor OAI server to avoid needing this info. - self.tokenizer = init_tokenizer_from_configs( - model_config=self.model_config, - scheduler_config=(await self._get_scheduler_config_rpc()), - parallel_config=(await self._get_parallel_config_rpc()), - enable_lora=bool(await self._get_lora_config_rpc()), - ) - - await self._notify_ready() - - def close(self): - """Destroy the ZeroMQ Context.""" - # Close all sockets associated with this context and - # then terminate the context. - self.context.destroy(linger=0) - - - async def _send_get_data_rpc_request(self, request: RPCUtilityRequest, - expected_type: Any, - error_message: str) -> Any: - """Send an RPC request that is expecting data back.""" - - with self.get_data_socket() as socket: - # Ping RPCServer with a request. - await socket.send_multipart( - (cloudpickle.dumps(request), ), - copy=False) - - # Make sure the server responds - if await socket.poll(timeout=VLLM_RPC_GET_DATA_TIMEOUT_MS) == 0: - raise TimeoutError("Server didn't reply within " - f"{VLLM_RPC_GET_DATA_TIMEOUT_MS} ms") - - # Await the data from the Server. - frame = await socket.recv(copy=False) - data = pickle.loads(frame.buffer) - - if isinstance(data, Exception): - # Re-raise exceptions returned by the server - raise data - - if not isinstance(data, expected_type): - # LoRAConfig can be None. 
- if expected_type == LoRAConfig and data is None: - pass - elif isinstance(data, Exception): - logger.error(error_message) - raise data - else: - raise ValueError(error_message) - - return data - - async def _send_one_way_rpc_request(self, - request: RPC_REQUEST_TYPE, - error_message: str, - socket: Optional[Socket] = None): - """Send one-way RPC request to trigger an action.""" - - async def do_rpc_call(socket: Socket, request: RPC_REQUEST_TYPE): - - await socket.send_multipart((cloudpickle.dumps(request), )) - - if await socket.poll(timeout=VLLM_RPC_GET_DATA_TIMEOUT_MS) == 0: - raise TimeoutError("Server didn't reply within " - f"{VLLM_RPC_GET_DATA_TIMEOUT_MS} ms") - - frame = await socket.recv(copy=False) - return pickle.loads(frame.buffer) - - if socket is None: - with self.get_data_socket() as socket: - response = await do_rpc_call(socket, request) - else: - response = await do_rpc_call(socket, request) - - if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR: - if isinstance(response, Exception): - logger.error(error_message) - raise response - raise ValueError(error_message) - - async def get_tokenizer(self, lora_request: LoRARequest): - return await self.tokenizer.get_lora_tokenizer_async(lora_request) - - async def get_decoding_config(self) -> DecodingConfig: - return self.decoding_config - - async def get_model_config(self) -> ModelConfig: - return self.model_config - - async def is_tracing_enabled(self) -> bool: - return self.tracing_flag - - async def _wait_for_server_rpc(self): - """Wait for the RPCServer to start up.""" - - await self._send_one_way_rpc_request( - request=RPCUtilityRequest.IS_SERVER_READY, - error_message="Unable to start RPC Server") - - async def _notify_ready(self): - """Get the RPCServer that the RPCClient is ready""" - - await self._send_one_way_rpc_request( - request=RPCUtilityRequest.CLIENT_IS_READY, - error_message="Unable to notify RPC Server of client readiness") - - async def _get_model_config_rpc(self) -> ModelConfig: - """Get the ModelConfig object from the RPC Server""" - - return await self._send_get_data_rpc_request( - RPCUtilityRequest.GET_MODEL_CONFIG, - expected_type=ModelConfig, - error_message="Could not get ModelConfig from RPC Server") - - async def _get_decoding_config_rpc(self) -> DecodingConfig: - """Get DecodingConfig from the RPCServer""" - - return await self._send_get_data_rpc_request( - RPCUtilityRequest.GET_DECODING_CONFIG, - expected_type=DecodingConfig, - error_message="Could not get DecodingConfig from RPC Server") - - async def _get_parallel_config_rpc(self) -> ParallelConfig: - """Get ParallelConfig from the RPCServer""" - - return await self._send_get_data_rpc_request( - RPCUtilityRequest.GET_PARALLEL_CONFIG, - expected_type=ParallelConfig, - error_message="Could not get ParallelConfig from RPC Server") - - async def _get_scheduler_config_rpc(self) -> SchedulerConfig: - """Get SchedulerConfig from the RPCServer""" - - return await self._send_get_data_rpc_request( - RPCUtilityRequest.GET_SCHEDULER_CONFIG, - expected_type=SchedulerConfig, - error_message="Could not get SchedulerConfig from RPC Server") - - async def _get_lora_config_rpc(self) -> LoRAConfig: - """Get LoRAConfig from the RPCServer""" - - return await self._send_get_data_rpc_request( - RPCUtilityRequest.GET_LORA_CONFIG, - expected_type=LoRAConfig, - error_message="Could not get LoRAConfig from RPC Server") - - async def _is_tracing_enabled_rpc(self) -> bool: - """Get is_tracing_enabled flag from the RPCServer""" - - return await 
self._send_get_data_rpc_request( - RPCUtilityRequest.IS_TRACING_ENABLED, - expected_type=bool, - error_message="Could not get is_tracing_enabled from RPC Server") - - async def abort(self, request_id: str): - """Send an ABORT_REQUEST signal to the RPC Server""" - - # Suppress timeouts as well. - # In cases where the server is busy processing requests and a very - # large volume of abort requests arrive, it is likely that the server - # will not be able to ack all of them in time. We have seen this when - # we abort 20k requests at once while another 2k are processing- many - # of them time out, but we see the server successfully abort all of the - # requests. - # In this case we assume that the server has received or will receive - # these abort requests, and ignore the timeout. This prevents a massive - # wall of `TimeoutError` stack traces. - with suppress(RPCClientClosedError, TimeoutError): - await self._send_one_way_rpc_request( - request=RPCAbortRequest(request_id), - error_message=f"RPCAbortRequest {request_id} failed") - - async def do_log_stats(self): - """Send a DO_LOG_STATS signal to the RPC Server""" - with suppress(RPCClientClosedError): - await self._send_one_way_rpc_request( - request=RPCUtilityRequest.DO_LOG_STATS, - error_message="RPCRequest DO_LOG_STATS failed.") - - @property - def is_running(self) -> bool: - return not self._errored - - @property - def is_stopped(self) -> bool: - return self._errored - - @property - def errored(self) -> bool: - return self._errored - - async def generate( - self, - inputs: PromptInputs, - sampling_params: SamplingParams, - request_id: str, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None - ) -> AsyncGenerator[RequestOutput, None]: - """Send an RPCGenerateRequest to the RPCServer and stream responses.""" - - queue: asyncio.Queue[Union[RequestOutput, - BaseException]] = asyncio.Queue() - self.output_queues[request_id] = queue - finished = False - try: - - # Send RPCGenerateRequest to the RPCServer. - await self.new_req_socket.send_multipart((cloudpickle.dumps( - RPCGenerateRequest( - inputs=inputs, - sampling_params=sampling_params, - request_id=request_id, - lora_request=lora_request, - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request)), )) - - # ack: Frame = await socket.recv(copy=False) - # if len(ack.buffer) != 0: - # exception = pickle.loads(ack.buffer) - # raise exception - - while not finished: - request_output = await queue.get() - if isinstance(request_output, BaseException): - finished = True - # On exception, check if the server is still healthy - # possibly setting the `errored` property. - if not self._errored: - try: - # await self.check_health(socket=socket) - pass - except Exception as e: - self._errored = True - logger.exception(repr(e)) - raise request_output - - finished = request_output.finished - yield request_output - - finally: - self.output_queues.pop(request_id) - # Request was canceled by the client. 
- if not finished and not self._errored: - await self.abort(request_id) - - async def check_health(self, socket: Optional[Socket] = None) -> None: - """Raise if unhealthy""" - - await self._send_one_way_rpc_request( - request=RPCUtilityRequest.IS_SERVER_HEALTHY, - error_message="Got Unhealthy response from RPC Server", - socket=socket) - - async def encode(self, *args, - **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]: - raise NotImplementedError( - "Embeddings not supported with multiprocessing backend") - - async def start_profile(self) -> None: - """Start profiling the engine""" - - await self._send_one_way_rpc_request( - request=RPCUtilityRequest.START_PROFILE, - error_message="RPCRequest START_PROFILE failed.") - - async def stop_profile(self) -> None: - """Stop profiling the engine""" - - await self._send_one_way_rpc_request( - request=RPCUtilityRequest.STOP_PROFILE, - error_message="RPCRequest STOP_PROFILE failed.") diff --git a/vllm/engine/rpc/rpc_llm_engine.py b/vllm/engine/rpc/rpc_llm_engine.py deleted file mode 100644 index 4c6ec13134a04..0000000000000 --- a/vllm/engine/rpc/rpc_llm_engine.py +++ /dev/null @@ -1,103 +0,0 @@ -import zmq -import cloudpickle, pickle -from vllm.logger import init_logger -from vllm import EngineArgs, LLMEngine -from vllm.engine.rpc import (VLLM_RPC_SUCCESS_STR, - RPCUtilityRequest) - -logger = init_logger(__name__) - -class RPCLLMEngine: - def __init__(self, engine_args) -> None: - self.engine = LLMEngine.from_engine_args(engine_args) - - self.ctx = zmq.Context() - - self.new_req_socket = self.ctx.socket(zmq.constants.PULL) - self.new_req_socket.bind("ipc:///tmp/new_req_socket") - - self.output_socket = self.ctx.socket(zmq.constants.PUSH) - self.output_socket.bind("ipc:///tmp/output_socket") - - self.data_socket = self.ctx.socket(zmq.constants.ROUTER) - self.data_socket.bind("ipc:///tmp/data_socket") - - def run(self): - logger.info("Running Startup Loop.") - self.startup_loop() - logger.info("Running Engine Loop.") - self.engine_loop() - - def startup_loop(self): - client_is_ready = False - while not client_is_ready: - identity, message = self.data_socket.recv_multipart(copy=False) - request = cloudpickle.loads(message.buffer) - if request in [ - RPCUtilityRequest.GET_MODEL_CONFIG, - RPCUtilityRequest.GET_PARALLEL_CONFIG, - RPCUtilityRequest.GET_DECODING_CONFIG, - RPCUtilityRequest.GET_SCHEDULER_CONFIG, - RPCUtilityRequest.GET_LORA_CONFIG - ]: - config = self.get_config(request) - self.data_socket.send_multipart((identity, pickle.dumps(config)), copy=False) - elif request == RPCUtilityRequest.IS_SERVER_READY: - self.data_socket.send_multipart((identity, pickle.dumps(VLLM_RPC_SUCCESS_STR)), copy=False) - elif request == RPCUtilityRequest.IS_TRACING_ENABLED: - self.data_socket.send_multipart((identity, pickle.dumps(self.engine.is_tracing_enabled())), copy=False) - elif request == RPCUtilityRequest.CLIENT_IS_READY: - self.data_socket.send_multipart((identity, pickle.dumps(VLLM_RPC_SUCCESS_STR)), copy=False) - client_is_ready = True - self.data_socket.close() - del self.data_socket - - def engine_loop(self): - while True: - if not self.engine.has_unfinished_requests(): - self.wait_for_new_requests() - - self.add_new_requests() - request_outputs = self.engine.step() - self.send_request_outputs(request_outputs) - - def send_request_outputs(self, request_outputs): - self.output_socket.send_multipart( - (pickle.dumps(request_outputs),), copy=False) - - def add_new_requests(self): - while self.new_req_socket.poll(timeout=0) != 0: - message = 
self.new_req_socket.recv(copy=False) - generate_rpc_request = pickle.loads(message.buffer) - self.engine.add_request( - request_id=generate_rpc_request.request_id, - inputs=generate_rpc_request.inputs, - params=generate_rpc_request.sampling_params, - lora_request=generate_rpc_request.lora_request, - trace_headers=generate_rpc_request.trace_headers, - prompt_adapter_request=generate_rpc_request.prompt_adapter_request, - ) - - def wait_for_new_requests(self): - while self.new_req_socket.poll(timeout=1000) == 0: - logger.info("Waiting for new requests...") - logger.info("Found new request!") - - def get_config(self, request): - if request == RPCUtilityRequest.GET_MODEL_CONFIG: - model_config = self.engine.get_model_config() - return model_config - elif request == RPCUtilityRequest.GET_DECODING_CONFIG: - return self.engine.get_decoding_config() - elif request == RPCUtilityRequest.GET_LORA_CONFIG: - return self.engine.get_lora_config() - elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG: - return self.engine.get_scheduler_config() - elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG: - return self.engine.get_parallel_config() - else: - raise ValueError("Unknown Config Request: %s", request) - -def run_rpc_server(engine_args: EngineArgs): - engine = MPLLMEngine(engine_args) - engine.run() diff --git a/vllm/entrypoints/launcher.py b/vllm/entrypoints/launcher.py index 3598872b65bb0..f4a9c61a431c1 100644 --- a/vllm/entrypoints/launcher.py +++ b/vllm/entrypoints/launcher.py @@ -27,15 +27,6 @@ async def serve_http(app: FastAPI, engine: AsyncEngineClient, logger.info("Route: %s, Methods: %s", path, ', '.join(methods)) - # Set concurrency limits in uvicorn if running in multiprocessing mode - # since zmq has maximum socket limit of zmq.constants.SOCKET_LIMIT (65536). - if engine.limit_concurrency is not None: - logger.info( - "Launching Uvicorn with --limit_concurrency %s. To avoid this " - "limit at the expense of performance run with " - "--disable-frontend-multiprocessing", engine.limit_concurrency) - uvicorn_kwargs["limit_concurrency"] = engine.limit_concurrency - config = uvicorn.Config(app, **uvicorn_kwargs) server = uvicorn.Server(config) _add_shutdown_handlers(app, server, engine) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index cdba0a0ecc9a1..6dcbbd433596d 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -38,9 +38,8 @@ TokenizeRequest, TokenizeResponse) # yapf: enable -from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient -# from vllm.entrypoints.openai.rpc.server import run_rpc_server -from vllm.engine.mp_llm_engine import run_rpc_server +from vllm.engine.multiprocessing.mp_client import MPEngineClient +from vllm.engine.multiprocessing.mp_llm_engine import run_mp_engine from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding @@ -157,38 +156,37 @@ async def build_async_engine_client_from_engine_args( "and vLLM will properly handle cleanup.") # Select random path for IPC. - rpc_path = get_open_zmq_ipc_path() - logger.info("Multiprocessing frontend to use %s for RPC Path.", - rpc_path) + ipc_path = get_open_zmq_ipc_path() + logger.info("Multiprocessing frontend to use %s for IPC Path.", + ipc_path) # Build RPCClient, which conforms to AsyncEngineClient Protocol. # NOTE: Actually, this is not true yet. 
We still need to support # embedding models via RPC (see TODO above) - rpc_client = AsyncEngineRPCClient(rpc_path) - async_engine_client = rpc_client # type: ignore + mp_engine_client = MPEngineClient(ipc_path) + async_engine_client = mp_engine_client # type: ignore - # Start RPCServer in separate process (holds the AsyncLLMEngine). - context = multiprocessing.get_context("spawn") + # Start RPCServer in separate process (holds the LLMEngine). # the current process might have CUDA context, # so we need to spawn a new process - # rpc_server_process = context.Process( - # target=run_rpc_server, - # args=(engine_args, UsageContext.OPENAI_API_SERVER, rpc_path)) - - rpc_server_process = context.Process(target=run_rpc_server, args=(engine_args,)) - rpc_server_process.start() + context = multiprocessing.get_context("spawn") + + engine_process = context.Process( + target=run_mp_engine, + args=(engine_args, UsageContext.OPENAI_API_SERVER, ipc_path)) + engine_process.start() logger.info("Started engine process with PID %d", - rpc_server_process.pid) + engine_process.pid) try: while True: try: - await rpc_client.setup() + await mp_engine_client.setup() break except TimeoutError: - if not rpc_server_process.is_alive(): + if not engine_process.is_alive(): logger.error( - "RPCServer process died before responding " + "Engine process died before responding " "to readiness probe") yield None return @@ -196,20 +194,20 @@ async def build_async_engine_client_from_engine_args( yield async_engine_client finally: # Ensure rpc server process was terminated - rpc_server_process.terminate() + engine_process.terminate() # Close all open connections to the backend - rpc_client.close() + mp_engine_client.close() # Wait for server process to join - rpc_server_process.join() + engine_process.join() # Lazy import for prometheus multiprocessing. # We need to set PROMETHEUS_MULTIPROC_DIR environment variable # before prometheus_client is imported. # See https://prometheus.github.io/client_python/multiprocess/ from prometheus_client import multiprocess - multiprocess.mark_process_dead(rpc_server_process.pid) + multiprocess.mark_process_dead(engine_process.pid) async_engine_client = None #TODO From fcdcfc921540cc3c115bb314187ddba5af17522f Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 3 Sep 2024 00:42:59 +0000 Subject: [PATCH 18/29] refactor, cleanup --- vllm/engine/multiprocessing/__init__.py | 51 +++ vllm/engine/multiprocessing/mp_client.py | 368 +++++++++++++++++++ vllm/engine/multiprocessing/mp_llm_engine.py | 253 +++++++++++++ 3 files changed, 672 insertions(+) create mode 100644 vllm/engine/multiprocessing/__init__.py create mode 100644 vllm/engine/multiprocessing/mp_client.py create mode 100644 vllm/engine/multiprocessing/mp_llm_engine.py diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py new file mode 100644 index 0000000000000..c6ecb6aa75459 --- /dev/null +++ b/vllm/engine/multiprocessing/__init__.py @@ -0,0 +1,51 @@ +from dataclasses import dataclass +from enum import Enum +from typing import Mapping, Optional, Union + +from vllm.inputs import PromptInputs +from vllm.lora.request import LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import SamplingParams + +# Success string used for RPC instructions. 
+VLLM_RPC_SUCCESS_STR = "SUCCESS" + +@dataclass +class RPCGenerateRequest: + inputs: PromptInputs + sampling_params: SamplingParams + request_id: str + lora_request: Optional[LoRARequest] = None + trace_headers: Optional[Mapping[str, str]] = None + prompt_adapter_request: Optional[PromptAdapterRequest] = None + + +@dataclass +class RPCAbortRequest: + request_id: str + +class RPCUtilityRequest(Enum): + IS_SERVER_READY = 1 + GET_MODEL_CONFIG = 2 + GET_DECODING_CONFIG = 3 + GET_PARALLEL_CONFIG = 4 + GET_SCHEDULER_CONFIG = 5 + GET_LORA_CONFIG = 6 + DO_LOG_STATS = 7 + IS_SERVER_HEALTHY = 8 + IS_TRACING_ENABLED = 9 + START_PROFILE = 10 + STOP_PROFILE = 11 + CLIENT_IS_READY = 11 + + +RPC_COFNIG_REQUEST = [ + RPCUtilityRequest.GET_MODEL_CONFIG, + RPCUtilityRequest.GET_PARALLEL_CONFIG, + RPCUtilityRequest.GET_DECODING_CONFIG, + RPCUtilityRequest.GET_SCHEDULER_CONFIG, + RPCUtilityRequest.GET_LORA_CONFIG +] + +RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest, + RPCUtilityRequest] diff --git a/vllm/engine/multiprocessing/mp_client.py b/vllm/engine/multiprocessing/mp_client.py new file mode 100644 index 0000000000000..fd3011d548884 --- /dev/null +++ b/vllm/engine/multiprocessing/mp_client.py @@ -0,0 +1,368 @@ +import asyncio +import pickle +from contextlib import contextmanager, suppress +from typing import (Any, AsyncGenerator, Dict, Iterator, Mapping, Optional, + Union) +from uuid import uuid4 + +import cloudpickle +import zmq +import zmq.asyncio +from zmq import Frame # type: ignore[attr-defined] +from zmq.asyncio import Socket + +from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, + ParallelConfig, SchedulerConfig) +# yapf: disable +from vllm.engine.multiprocessing import (RPC_REQUEST_TYPE, + VLLM_RPC_SUCCESS_STR, + RPCAbortRequest, + RPCGenerateRequest, + RPCUtilityRequest) +# yapf: enable +from vllm.envs import VLLM_RPC_GET_DATA_TIMEOUT_MS +from vllm.inputs import PromptInputs +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.outputs import EmbeddingRequestOutput, RequestOutput +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import SamplingParams +from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs + +logger = init_logger(__name__) + + +class MPEngineClient: + + def __init__(self, ipc_path: str): + self.context = zmq.asyncio.Context() + self._errored = False + + # Send RPCGenerateRequest to the MPLLMEngine. + self.input_socket: Socket = self.context.socket(zmq.constants.PUSH) + self.input_socket.connect(f"{ipc_path}_input_socket") + + # Recieve streams of RequestOutput from the MPLLMEngine. + self.output_socket: Socket = self.context.socket(zmq.constants.PULL) + self.output_socket.connect(f"{ipc_path}_output_socket") + + # IPC path for the data socket. + self.data_ipc_path = f"{ipc_path}_data_socket" + + # Stream for each individual request. + self.output_queues: Dict[str, asyncio.Queue] = {} + self.output_handler = asyncio.create_task(self.run_output_handler()) + + @contextmanager + def get_data_socket(self) -> Iterator[Socket]: + socket = self.context.socket(zmq.constants.DEALER) + try: + socket.connect(self.data_ipc_path) + yield socket + finally: + socket.close(linger=0) + + async def run_output_handler(self): + # Stream lists of RequestOutput from MPLLMEngine. 
+ while True: + message: Frame = await self.output_socket.recv(copy=False) + request_outputs = pickle.loads(message.buffer) + + for output in request_outputs: + if isinstance(output, tuple): + # Exception case + request_id, output = output + else: + request_id = output.request_id + + queue = self.output_queues.get(request_id) + if queue is not None: + queue.put_nowait(output) + + async def setup(self): + """Setup the client before it starts sending server requests.""" + + # Wait until server is ready. + await self._wait_for_server_rpc() + + # Get the configs. + self.model_config = await self._get_model_config_rpc() + self.decoding_config = await self._get_decoding_config_rpc() + self.tracing_flag = await self._is_tracing_enabled_rpc() + + # Create the tokenizer group. + # TODO: refactor OAI server to avoid needing this info. + self.tokenizer = init_tokenizer_from_configs( + model_config=self.model_config, + scheduler_config=(await self._get_scheduler_config_rpc()), + parallel_config=(await self._get_parallel_config_rpc()), + enable_lora=bool(await self._get_lora_config_rpc()), + ) + + await self._notify_ready() + + def close(self): + """Destroy the ZeroMQ Context.""" + # Close all sockets associated with this context and + # then terminate the context. + self.context.destroy(linger=0) + + + async def _send_get_data_rpc_request(self, request: RPCUtilityRequest, + expected_type: Any, + error_message: str) -> Any: + """Send an RPC request that is expecting data back.""" + + with self.get_data_socket() as socket: + # Ping RPCServer with a request. + await socket.send_multipart( + (cloudpickle.dumps(request), ), + copy=False) + + # Make sure the server responds + if await socket.poll(timeout=VLLM_RPC_GET_DATA_TIMEOUT_MS) == 0: + raise TimeoutError("Server didn't reply within " + f"{VLLM_RPC_GET_DATA_TIMEOUT_MS} ms") + + # Await the data from the Server. + frame = await socket.recv(copy=False) + data = pickle.loads(frame.buffer) + + if isinstance(data, Exception): + # Re-raise exceptions returned by the server + raise data + + if not isinstance(data, expected_type): + # LoRAConfig can be None. 
+ if expected_type == LoRAConfig and data is None: + pass + elif isinstance(data, Exception): + logger.error(error_message) + raise data + else: + raise ValueError(error_message) + + return data + + async def _send_one_way_rpc_request(self, + request: RPC_REQUEST_TYPE, + error_message: str, + socket: Optional[Socket] = None): + """Send one-way RPC request to trigger an action.""" + + async def do_rpc_call(socket: Socket, request: RPC_REQUEST_TYPE): + + await socket.send_multipart((cloudpickle.dumps(request), )) + + if await socket.poll(timeout=VLLM_RPC_GET_DATA_TIMEOUT_MS) == 0: + raise TimeoutError("Server didn't reply within " + f"{VLLM_RPC_GET_DATA_TIMEOUT_MS} ms") + + frame = await socket.recv(copy=False) + return pickle.loads(frame.buffer) + + if socket is None: + with self.get_data_socket() as socket: + response = await do_rpc_call(socket, request) + else: + response = await do_rpc_call(socket, request) + + if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR: + if isinstance(response, Exception): + logger.error(error_message) + raise response + raise ValueError(error_message) + + async def get_tokenizer(self, lora_request: LoRARequest): + return await self.tokenizer.get_lora_tokenizer_async(lora_request) + + async def get_decoding_config(self) -> DecodingConfig: + return self.decoding_config + + async def get_model_config(self) -> ModelConfig: + return self.model_config + + async def is_tracing_enabled(self) -> bool: + return self.tracing_flag + + async def _wait_for_server_rpc(self): + """Wait for the RPCServer to start up.""" + + await self._send_one_way_rpc_request( + request=RPCUtilityRequest.IS_SERVER_READY, + error_message="Unable to start RPC Server") + + async def _notify_ready(self): + """Get the RPCServer that the RPCClient is ready""" + + await self._send_one_way_rpc_request( + request=RPCUtilityRequest.CLIENT_IS_READY, + error_message="Unable to notify RPC Server of client readiness") + + async def _get_model_config_rpc(self) -> ModelConfig: + """Get the ModelConfig object from the RPC Server""" + + return await self._send_get_data_rpc_request( + RPCUtilityRequest.GET_MODEL_CONFIG, + expected_type=ModelConfig, + error_message="Could not get ModelConfig from RPC Server") + + async def _get_decoding_config_rpc(self) -> DecodingConfig: + """Get DecodingConfig from the RPCServer""" + + return await self._send_get_data_rpc_request( + RPCUtilityRequest.GET_DECODING_CONFIG, + expected_type=DecodingConfig, + error_message="Could not get DecodingConfig from RPC Server") + + async def _get_parallel_config_rpc(self) -> ParallelConfig: + """Get ParallelConfig from the RPCServer""" + + return await self._send_get_data_rpc_request( + RPCUtilityRequest.GET_PARALLEL_CONFIG, + expected_type=ParallelConfig, + error_message="Could not get ParallelConfig from RPC Server") + + async def _get_scheduler_config_rpc(self) -> SchedulerConfig: + """Get SchedulerConfig from the RPCServer""" + + return await self._send_get_data_rpc_request( + RPCUtilityRequest.GET_SCHEDULER_CONFIG, + expected_type=SchedulerConfig, + error_message="Could not get SchedulerConfig from RPC Server") + + async def _get_lora_config_rpc(self) -> LoRAConfig: + """Get LoRAConfig from the RPCServer""" + + return await self._send_get_data_rpc_request( + RPCUtilityRequest.GET_LORA_CONFIG, + expected_type=LoRAConfig, + error_message="Could not get LoRAConfig from RPC Server") + + async def _is_tracing_enabled_rpc(self) -> bool: + """Get is_tracing_enabled flag from the RPCServer""" + + return await 
self._send_get_data_rpc_request( + RPCUtilityRequest.IS_TRACING_ENABLED, + expected_type=bool, + error_message="Could not get is_tracing_enabled from RPC Server") + + async def abort(self, request_id: str): + """Send an ABORT_REQUEST signal to the RPC Server""" + + # Suppress timeouts as well. + # In cases where the server is busy processing requests and a very + # large volume of abort requests arrive, it is likely that the server + # will not be able to ack all of them in time. We have seen this when + # we abort 20k requests at once while another 2k are processing- many + # of them time out, but we see the server successfully abort all of the + # requests. + # In this case we assume that the server has received or will receive + # these abort requests, and ignore the timeout. This prevents a massive + # wall of `TimeoutError` stack traces. + with suppress(RPCClientClosedError, TimeoutError): + await self._send_one_way_rpc_request( + request=RPCAbortRequest(request_id), + error_message=f"RPCAbortRequest {request_id} failed") + + async def do_log_stats(self): + """Send a DO_LOG_STATS signal to the RPC Server""" + with suppress(RPCClientClosedError): + await self._send_one_way_rpc_request( + request=RPCUtilityRequest.DO_LOG_STATS, + error_message="RPCRequest DO_LOG_STATS failed.") + + @property + def is_running(self) -> bool: + return not self._errored + + @property + def is_stopped(self) -> bool: + return self._errored + + @property + def errored(self) -> bool: + return self._errored + + async def generate( + self, + inputs: PromptInputs, + sampling_params: SamplingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None + ) -> AsyncGenerator[RequestOutput, None]: + """Send an RPCGenerateRequest to the RPCServer and stream responses.""" + + queue: asyncio.Queue[Union[RequestOutput, + BaseException]] = asyncio.Queue() + self.output_queues[request_id] = queue + finished = False + try: + + # Send RPCGenerateRequest to the RPCServer. + await self.input_socket.send_multipart((cloudpickle.dumps( + RPCGenerateRequest( + inputs=inputs, + sampling_params=sampling_params, + request_id=request_id, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request)), )) + + # ack: Frame = await socket.recv(copy=False) + # if len(ack.buffer) != 0: + # exception = pickle.loads(ack.buffer) + # raise exception + + while not finished: + request_output = await queue.get() + if isinstance(request_output, BaseException): + finished = True + # On exception, check if the server is still healthy + # possibly setting the `errored` property. + if not self._errored: + try: + # await self.check_health(socket=socket) + pass + except Exception as e: + self._errored = True + logger.exception(repr(e)) + raise request_output + + finished = request_output.finished + yield request_output + + finally: + self.output_queues.pop(request_id) + # Request was canceled by the client. 
+ if not finished and not self._errored: + await self.abort(request_id) + + async def check_health(self, socket: Optional[Socket] = None) -> None: + """Raise if unhealthy""" + + await self._send_one_way_rpc_request( + request=RPCUtilityRequest.IS_SERVER_HEALTHY, + error_message="Got Unhealthy response from RPC Server", + socket=socket) + + async def encode(self, *args, + **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]: + raise NotImplementedError( + "Embeddings not supported with multiprocessing backend") + + async def start_profile(self) -> None: + """Start profiling the engine""" + + await self._send_one_way_rpc_request( + request=RPCUtilityRequest.START_PROFILE, + error_message="RPCRequest START_PROFILE failed.") + + async def stop_profile(self) -> None: + """Stop profiling the engine""" + + await self._send_one_way_rpc_request( + request=RPCUtilityRequest.STOP_PROFILE, + error_message="RPCRequest STOP_PROFILE failed.") diff --git a/vllm/engine/multiprocessing/mp_llm_engine.py b/vllm/engine/multiprocessing/mp_llm_engine.py new file mode 100644 index 0000000000000..0671c48d84c6d --- /dev/null +++ b/vllm/engine/multiprocessing/mp_llm_engine.py @@ -0,0 +1,253 @@ +import ray +import zmq +import cloudpickle +import pickle +from typing import Any, Type, Union, Iterator +from contextlib import contextmanager + +import vllm.envs as envs +from vllm import AsyncEngineArgs, LLMEngine, AsyncLLMEngine +from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, + ParallelConfig, SchedulerConfig) +from vllm.logger import init_logger +from vllm.engine.multiprocessing import (VLLM_RPC_SUCCESS_STR, + RPCUtilityRequest) +from vllm.utils import print_warning_once +from vllm.usage.usage_lib import UsageContext + +CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig, + SchedulerConfig, LoRAConfig] + +logger = init_logger(__name__) + +class MPLLMEngine: + """A multiprocessing wrapper for :class:`LLMEngine`. + + This class is used to wrap the :class:`LLMEngine` class to enable use + in asynchronous manner. It runs a background loop and uses zeromq to + recieve new requests and stream outputs incrementally to another process. + + The :class:`LLMEngine` is kicked off when a new RPCGenerateRequest + is recieved by the input_socket. + + The self.engine_loop checks the input_socket for new requests, + adds them to the LLMEngine if there are any, calls the internal + :class:`LLMEngine.step()` and sends the RequestOutputs back over + the output_socket. + + Args: + worker_use_ray: Whether to use Ray for model workers. Required for + distributed execution. Should be the same as + `parallel_config.worker_use_ray`. + engine_use_ray: Whether to make LLMEngine a Ray actor. If so, the + async frontend will be executed in a separate process as the + model workers. + async_engine_args: AsyncLLMEngine args + log_requests: Whether to log the requests. + """ + + _engine_class: Type[LLMEngine] = LLMEngine + + def __init__(self, + worker_use_ray: bool, + engine_use_ray: bool, + *args, + ipc_path: str, + log_requests: bool = True, + **kwargs) -> None: + + if engine_use_ray: + raise NotImplementedError("Not yet supported.") + + self.worker_use_ray = worker_use_ray + self.engine_use_ray = engine_use_ray + self.log_requests = log_requests + self.engine = self._init_engine(*args, **kwargs) + + if self.engine_use_ray: + print_warning_once( + "DEPRECATED. `--engine-use-ray` is deprecated and will " + "be removed in a future update. 
" + "See https://github.com/vllm-project/vllm/issues/7045.") + + if envs.VLLM_ALLOW_ENGINE_USE_RAY: + print_warning_once( + "VLLM_ALLOW_ENGINE_USE_RAY is set, force engine use Ray") + else: + raise ValueError("`--engine-use-ray` is deprecated. " + "Set `VLLM_ALLOW_ENGINE_USE_RAY=1` to " + "force use it") + + self.ctx = zmq.Context() + + # Recieve RPCGenerateRequest from the client. + self.input_socket = self.ctx.socket(zmq.constants.PULL) + self.input_socket.bind(f"{ipc_path}_input_socket") + + # Send streams of RequestOutput back to Client. + self.output_socket = self.ctx.socket(zmq.constants.PUSH) + self.output_socket.bind(f"{ipc_path}_output_socket") + + # IPC path for the data socket. + self.data_ipc_path = f"{ipc_path}_data_socket" + + @classmethod + def from_engine_args(cls, engine_args: AsyncEngineArgs, + usage_context: UsageContext, ipc_path: str): + """Creates an RPCLLM engine from the engine arguments.""" + + engine_config = engine_args.create_engine_config() + + if engine_args.engine_use_ray: + from vllm.executor import ray_utils + ray_utils.assert_ray_available() + + # TODO: better abstraction? + executor_class = AsyncLLMEngine._get_executor_cls(engine_config) + + return cls( + executor_class.uses_ray, + engine_args.engine_use_ray, + **engine_config.to_dict(), + executor_class=executor_class, + log_requests=not engine_args.disable_log_requests, + log_stats=not engine_args.disable_log_stats, + usage_context=usage_context, + ipc_path=ipc_path, + ) + + def cleanup(self): + """Cleanup zeromq state on shutdown.""" + self.input_socket.close() + self.output_socket.close() + self.ctx.destroy(linger=0) + del self.engine + + def _init_engine(self, *args, **kwargs) -> Union[LLMEngine, "ray.ObjectRef"]: + """Initialize the LLMEngine""" + + if not self.engine_use_ray: + engine_class = self._engine_class + elif self.worker_use_ray: + engine_class = ray.remote(num_cpus=0)(self._engine_class).remote + else: + # FIXME(woosuk): This is a bit hacky. Be careful when changing the + # order of the arguments. + cache_config = kwargs["cache_config"] + parallel_config = kwargs["parallel_config"] + if (parallel_config.tensor_parallel_size == 1 + and parallel_config.pipeline_parallel_size == 1): + num_gpus = cache_config.gpu_memory_utilization + else: + num_gpus = 1 + engine_class = ray.remote(num_gpus=num_gpus)( + self._engine_class).remote + return engine_class(*args, **kwargs) + + def run_background_loop(self): + """Entrypoint that kicks off the background processing loop.""" + + # Allow RPCClient to query data in startup phase. + self.run_startup_loop() + + # Kick off core processing loop. + self.run_engine_loop() + + @contextmanager + def make_data_socket(self) -> Iterator[zmq.Socket]: + socket = self.ctx.socket(zmq.constants.ROUTER) + try: + socket.bind(self.data_ipc_path) + yield socket + finally: + socket.close(linger=0) + + def run_startup_loop(self) -> None: + """Loop over startup RPCRequests from RPCClient.""" + + with self.make_data_socket() as socket: + + # Loop until the RPCClient has all the data it needs. + client_is_ready = False + while not client_is_ready: + try: + identity, message = socket.recv_multipart(copy=False) + request: RPCUtilityRequest = cloudpickle.loads(message.buffer) + + # Handle the query from the Client. 
+ if request == RPCUtilityRequest.GET_MODEL_CONFIG: + response = self.engine.get_model_config() + elif request == RPCUtilityRequest.GET_DECODING_CONFIG: + response = self.engine.get_decoding_config() + elif request == RPCUtilityRequest.GET_LORA_CONFIG: + response = self.engine.get_lora_config() + elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG: + response = self.engine.get_scheduler_config() + elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG: + response = self.engine.get_parallel_config() + elif request == RPCUtilityRequest.IS_SERVER_READY: + response = VLLM_RPC_SUCCESS_STR + elif request == RPCUtilityRequest.IS_TRACING_ENABLED: + response = self.engine.is_tracing_enabled() + elif request == RPCUtilityRequest.CLIENT_IS_READY: + response = VLLM_RPC_SUCCESS_STR + # Once client ready, breakout of loop. + client_is_ready = True + else: + raise ValueError(f"Unknown RPCRequest: {request}") + + socket.send_multipart( + (identity, pickle.dumps(response)), copy=False) + + except Exception as e: + socket.send_multipart((identity, pickle.dumps(e)), copy=False) + + def run_engine_loop(self) -> None: + # TODO: handle PP + + while True: + # Block until there is a new request. + if not self.engine.has_unfinished_requests(): + self.wait_for_new_requests() + + # Add new work from input socket. + self.maybe_add_new_requests() + + # Engine step. + request_outputs = self.engine.step() + + # Stream results to output socket. + self.stream_outputs(request_outputs) + + + def wait_for_new_requests(self): + while self.input_socket.poll(timeout=10000) == 0: + logger.debug("Waiting for new request.") + + def stream_outputs(self, request_outputs): + self.output_socket.send_multipart( + (pickle.dumps(request_outputs),), copy=False) + + def maybe_add_new_requests(self): + while self.input_socket.poll(timeout=0) != 0: + message = self.input_socket.recv(copy=False) + generate_rpc_request = pickle.loads(message.buffer) + self.engine.add_request( + request_id=generate_rpc_request.request_id, + inputs=generate_rpc_request.inputs, + params=generate_rpc_request.sampling_params, + lora_request=generate_rpc_request.lora_request, + trace_headers=generate_rpc_request.trace_headers, + prompt_adapter_request=generate_rpc_request.prompt_adapter_request, + ) + + +def run_mp_engine(engine_args: AsyncEngineArgs, + usage_context: UsageContext, + ipc_path: str): + engine = MPLLMEngine.from_engine_args( + engine_args=engine_args, + usage_context=usage_context, + ipc_path=ipc_path) + + engine.run_background_loop() From 659169ee8290812e1e32d89f9bed33a9ea8fe196 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 3 Sep 2024 01:42:58 +0000 Subject: [PATCH 19/29] updated --- examples/openai_completion_client.py | 2 +- vllm/engine/multiprocessing/__init__.py | 22 +- vllm/engine/multiprocessing/mp_client.py | 237 ++++++++++--------- vllm/engine/multiprocessing/mp_llm_engine.py | 116 ++++----- vllm/entrypoints/openai/api_server.py | 5 +- 5 files changed, 195 insertions(+), 187 deletions(-) diff --git a/examples/openai_completion_client.py b/examples/openai_completion_client.py index 13f98d3220366..0b77ed4d25584 100644 --- a/examples/openai_completion_client.py +++ b/examples/openai_completion_client.py @@ -19,7 +19,7 @@ model=model, prompt="A robot may not injure a human being", stream=stream, - max_tokens=1000) + max_tokens=100) print("Completion results:") if stream: diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index c6ecb6aa75459..be7d80072f964 100644 --- 
a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -25,27 +25,19 @@ class RPCAbortRequest: request_id: str class RPCUtilityRequest(Enum): + DO_LOG_STATS = 1 + CHECK_HEALTH = 2 + +class RPCStartupRequest(Enum): IS_SERVER_READY = 1 GET_MODEL_CONFIG = 2 GET_DECODING_CONFIG = 3 GET_PARALLEL_CONFIG = 4 GET_SCHEDULER_CONFIG = 5 GET_LORA_CONFIG = 6 - DO_LOG_STATS = 7 - IS_SERVER_HEALTHY = 8 - IS_TRACING_ENABLED = 9 - START_PROFILE = 10 - STOP_PROFILE = 11 - CLIENT_IS_READY = 11 - - -RPC_COFNIG_REQUEST = [ - RPCUtilityRequest.GET_MODEL_CONFIG, - RPCUtilityRequest.GET_PARALLEL_CONFIG, - RPCUtilityRequest.GET_DECODING_CONFIG, - RPCUtilityRequest.GET_SCHEDULER_CONFIG, - RPCUtilityRequest.GET_LORA_CONFIG -] + GET_TRACING_ENABLED = 7 + CLIENT_IS_READY = 8 + RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest, RPCUtilityRequest] diff --git a/vllm/engine/multiprocessing/mp_client.py b/vllm/engine/multiprocessing/mp_client.py index fd3011d548884..086242d28fb59 100644 --- a/vllm/engine/multiprocessing/mp_client.py +++ b/vllm/engine/multiprocessing/mp_client.py @@ -1,9 +1,8 @@ import asyncio import pickle from contextlib import contextmanager, suppress -from typing import (Any, AsyncGenerator, Dict, Iterator, Mapping, Optional, +from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping, Optional, Union) -from uuid import uuid4 import cloudpickle import zmq @@ -18,6 +17,7 @@ VLLM_RPC_SUCCESS_STR, RPCAbortRequest, RPCGenerateRequest, + RPCStartupRequest, RPCUtilityRequest) # yapf: enable from vllm.envs import VLLM_RPC_GET_DATA_TIMEOUT_MS @@ -31,6 +31,15 @@ logger = init_logger(__name__) +class MPClientClosedError(Exception): + """Exception class raised when the client is used post-close. + + The client can be closed, which closes the ZMQ context. This normally + happens on server shutdown. In some cases, methods like abort and + do_log_stats will still be called and then try to open a socket, which + causes a ZMQError and creates a huge stack trace. + So, we throw this error such that we can suppress it. + """ class MPEngineClient: @@ -82,24 +91,27 @@ async def run_output_handler(self): async def setup(self): """Setup the client before it starts sending server requests.""" - # Wait until server is ready. - await self._wait_for_server_rpc() + with self.get_data_socket() as socket: + + # Wait until server is ready. + await self._wait_for_server_rpc(socket) - # Get the configs. - self.model_config = await self._get_model_config_rpc() - self.decoding_config = await self._get_decoding_config_rpc() - self.tracing_flag = await self._is_tracing_enabled_rpc() + # Get the configs. + self.model_config = await self._get_model_config_rpc(socket) + self.decoding_config = await self._get_decoding_config_rpc(socket) + self.tracing_flag = await self._is_tracing_enabled_rpc(socket) - # Create the tokenizer group. - # TODO: refactor OAI server to avoid needing this info. - self.tokenizer = init_tokenizer_from_configs( - model_config=self.model_config, - scheduler_config=(await self._get_scheduler_config_rpc()), - parallel_config=(await self._get_parallel_config_rpc()), - enable_lora=bool(await self._get_lora_config_rpc()), - ) + # Create the tokenizer group. + # TODO: refactor OAI server to avoid needing this info. 
+ self.tokenizer = init_tokenizer_from_configs( + model_config=self.model_config, + scheduler_config=(await self._get_scheduler_config_rpc(socket)), + parallel_config=(await self._get_parallel_config_rpc(socket)), + enable_lora=bool(await self._get_lora_config_rpc(socket)), + ) - await self._notify_ready() + # Notify MPLLMEngine client is ready to start sending requests. + await self._notify_ready(socket) def close(self): """Destroy the ZeroMQ Context.""" @@ -110,64 +122,63 @@ def close(self): async def _send_get_data_rpc_request(self, request: RPCUtilityRequest, expected_type: Any, - error_message: str) -> Any: + error_message: str, + socket: Socket) -> Any: """Send an RPC request that is expecting data back.""" - with self.get_data_socket() as socket: - # Ping RPCServer with a request. - await socket.send_multipart( - (cloudpickle.dumps(request), ), - copy=False) - - # Make sure the server responds - if await socket.poll(timeout=VLLM_RPC_GET_DATA_TIMEOUT_MS) == 0: - raise TimeoutError("Server didn't reply within " - f"{VLLM_RPC_GET_DATA_TIMEOUT_MS} ms") - - # Await the data from the Server. - frame = await socket.recv(copy=False) - data = pickle.loads(frame.buffer) - - if isinstance(data, Exception): - # Re-raise exceptions returned by the server + # Ping RPCServer with a request. + await socket.send_multipart( + (cloudpickle.dumps(request), ), + copy=False) + + # Make sure the server responds + if await socket.poll(timeout=VLLM_RPC_GET_DATA_TIMEOUT_MS) == 0: + raise TimeoutError("Server didn't reply within " + f"{VLLM_RPC_GET_DATA_TIMEOUT_MS} ms") + + # Await the data from the Server. + frame = await socket.recv(copy=False) + data = pickle.loads(frame.buffer) + + if isinstance(data, Exception): + # Re-raise exceptions returned by the server + raise data + + if not isinstance(data, expected_type): + # LoRAConfig can be None. + if expected_type == LoRAConfig and data is None: + pass + elif isinstance(data, Exception): + logger.error(error_message) raise data + else: + raise ValueError(error_message) - if not isinstance(data, expected_type): - # LoRAConfig can be None. - if expected_type == LoRAConfig and data is None: - pass - elif isinstance(data, Exception): - logger.error(error_message) - raise data - else: - raise ValueError(error_message) - - return data + return data async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE, - error_message: str, - socket: Optional[Socket] = None): + socket: Socket): """Send one-way RPC request to trigger an action.""" - async def do_rpc_call(socket: Socket, request: RPC_REQUEST_TYPE): - - await socket.send_multipart((cloudpickle.dumps(request), )) + await socket.send_multipart((cloudpickle.dumps(request), )) - if await socket.poll(timeout=VLLM_RPC_GET_DATA_TIMEOUT_MS) == 0: - raise TimeoutError("Server didn't reply within " - f"{VLLM_RPC_GET_DATA_TIMEOUT_MS} ms") - - frame = await socket.recv(copy=False) - return pickle.loads(frame.buffer) - - if socket is None: - with self.get_data_socket() as socket: - response = await do_rpc_call(socket, request) - else: - response = await do_rpc_call(socket, request) + # TODO: is there a way to ack this if we are using the input_socket? 
+ # I don't think so, b/c we are using PUSH/PULL + + async def _awk_one_way_rpc_request(self, + timeout: int, + expected_str: str, + error_message: str, + socket: Socket,): + if await socket.poll(timeout=timeout) == 0: + raise TimeoutError(f"MPLLMEngine didn't reply within {timeout}ms") + + + frame = await socket.recv(copy=False) + response = pickle.loads(frame.buffer) - if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR: + if not isinstance(response, str) or response != expected_str: if isinstance(response, Exception): logger.error(error_message) raise response @@ -185,72 +196,86 @@ async def get_model_config(self) -> ModelConfig: async def is_tracing_enabled(self) -> bool: return self.tracing_flag - async def _wait_for_server_rpc(self): + async def _wait_for_server_rpc(self, socket: Socket): """Wait for the RPCServer to start up.""" + + # Readiness probe. + request = RPCStartupRequest.IS_SERVER_READY + await socket.send_multipart((cloudpickle.dumps(request), )) + + # Raises TimeoutError if not awk, causing a retry. + await self._awk_one_way_rpc_request( + expected_str=VLLM_RPC_SUCCESS_STR, + timeout=VLLM_RPC_GET_DATA_TIMEOUT_MS, + error_message="Unable to start RPC Server", + socket=socket) + - await self._send_one_way_rpc_request( - request=RPCUtilityRequest.IS_SERVER_READY, - error_message="Unable to start RPC Server") - - async def _notify_ready(self): + async def _notify_ready(self, socket: Socket): """Get the RPCServer that the RPCClient is ready""" await self._send_one_way_rpc_request( - request=RPCUtilityRequest.CLIENT_IS_READY, - error_message="Unable to notify RPC Server of client readiness") + request=RPCStartupRequest.CLIENT_IS_READY, + socket=socket) - async def _get_model_config_rpc(self) -> ModelConfig: + async def _get_model_config_rpc(self, socket: Socket) -> ModelConfig: """Get the ModelConfig object from the RPC Server""" return await self._send_get_data_rpc_request( - RPCUtilityRequest.GET_MODEL_CONFIG, + RPCStartupRequest.GET_MODEL_CONFIG, expected_type=ModelConfig, - error_message="Could not get ModelConfig from RPC Server") + error_message="Could not get ModelConfig from RPC Server", + socket=socket) - async def _get_decoding_config_rpc(self) -> DecodingConfig: + async def _get_decoding_config_rpc(self, socket: Socket) -> DecodingConfig: """Get DecodingConfig from the RPCServer""" return await self._send_get_data_rpc_request( - RPCUtilityRequest.GET_DECODING_CONFIG, + RPCStartupRequest.GET_DECODING_CONFIG, expected_type=DecodingConfig, - error_message="Could not get DecodingConfig from RPC Server") + error_message="Could not get DecodingConfig from RPC Server", + socket=socket) - async def _get_parallel_config_rpc(self) -> ParallelConfig: + async def _get_parallel_config_rpc(self, socket: Socket) -> ParallelConfig: """Get ParallelConfig from the RPCServer""" return await self._send_get_data_rpc_request( - RPCUtilityRequest.GET_PARALLEL_CONFIG, + RPCStartupRequest.GET_PARALLEL_CONFIG, expected_type=ParallelConfig, - error_message="Could not get ParallelConfig from RPC Server") + error_message="Could not get ParallelConfig from RPC Server", + socket=socket) - async def _get_scheduler_config_rpc(self) -> SchedulerConfig: + async def _get_scheduler_config_rpc(self, socket: Socket) -> SchedulerConfig: """Get SchedulerConfig from the RPCServer""" return await self._send_get_data_rpc_request( - RPCUtilityRequest.GET_SCHEDULER_CONFIG, + RPCStartupRequest.GET_SCHEDULER_CONFIG, expected_type=SchedulerConfig, - error_message="Could not get SchedulerConfig 
from RPC Server") + error_message="Could not get SchedulerConfig from RPC Server", + socket=socket) - async def _get_lora_config_rpc(self) -> LoRAConfig: + async def _get_lora_config_rpc(self, socket: Socket) -> LoRAConfig: """Get LoRAConfig from the RPCServer""" return await self._send_get_data_rpc_request( - RPCUtilityRequest.GET_LORA_CONFIG, + RPCStartupRequest.GET_LORA_CONFIG, expected_type=LoRAConfig, - error_message="Could not get LoRAConfig from RPC Server") + error_message="Could not get LoRAConfig from RPC Server", + socket=socket) - async def _is_tracing_enabled_rpc(self) -> bool: + async def _is_tracing_enabled_rpc(self, socket: Socket) -> bool: """Get is_tracing_enabled flag from the RPCServer""" return await self._send_get_data_rpc_request( - RPCUtilityRequest.IS_TRACING_ENABLED, + RPCStartupRequest.GET_TRACING_ENABLED, expected_type=bool, - error_message="Could not get is_tracing_enabled from RPC Server") + error_message="Could not get is_tracing_enabled from RPC Server", + socket=socket) async def abort(self, request_id: str): """Send an ABORT_REQUEST signal to the RPC Server""" - # Suppress timeouts as well. + # Suppress timeouts and MPClientClosedError. # In cases where the server is busy processing requests and a very # large volume of abort requests arrive, it is likely that the server # will not be able to ack all of them in time. We have seen this when @@ -260,17 +285,17 @@ async def abort(self, request_id: str): # In this case we assume that the server has received or will receive # these abort requests, and ignore the timeout. This prevents a massive # wall of `TimeoutError` stack traces. - with suppress(RPCClientClosedError, TimeoutError): + with suppress(MPClientClosedError, TimeoutError): await self._send_one_way_rpc_request( request=RPCAbortRequest(request_id), - error_message=f"RPCAbortRequest {request_id} failed") + socket=self.input_socket) async def do_log_stats(self): """Send a DO_LOG_STATS signal to the RPC Server""" - with suppress(RPCClientClosedError): + with suppress(MPClientClosedError): await self._send_one_way_rpc_request( request=RPCUtilityRequest.DO_LOG_STATS, - error_message="RPCRequest DO_LOG_STATS failed.") + socket=self.input_socket) @property def is_running(self) -> bool: @@ -340,29 +365,15 @@ async def generate( if not finished and not self._errored: await self.abort(request_id) - async def check_health(self, socket: Optional[Socket] = None) -> None: + async def check_health(self) -> None: """Raise if unhealthy""" await self._send_one_way_rpc_request( - request=RPCUtilityRequest.IS_SERVER_HEALTHY, - error_message="Got Unhealthy response from RPC Server", - socket=socket) + request=RPCUtilityRequest.CHECK_HEALTH, + socket=self.input_socket) + async def encode(self, *args, **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]: raise NotImplementedError( "Embeddings not supported with multiprocessing backend") - - async def start_profile(self) -> None: - """Start profiling the engine""" - - await self._send_one_way_rpc_request( - request=RPCUtilityRequest.START_PROFILE, - error_message="RPCRequest START_PROFILE failed.") - - async def stop_profile(self) -> None: - """Stop profiling the engine""" - - await self._send_one_way_rpc_request( - request=RPCUtilityRequest.STOP_PROFILE, - error_message="RPCRequest STOP_PROFILE failed.") diff --git a/vllm/engine/multiprocessing/mp_llm_engine.py b/vllm/engine/multiprocessing/mp_llm_engine.py index 0671c48d84c6d..6323b1d0734f4 100644 --- a/vllm/engine/multiprocessing/mp_llm_engine.py +++ 
b/vllm/engine/multiprocessing/mp_llm_engine.py @@ -2,17 +2,19 @@ import zmq import cloudpickle import pickle -from typing import Any, Type, Union, Iterator +from typing import Iterator, List, Type, Union from contextlib import contextmanager -import vllm.envs as envs from vllm import AsyncEngineArgs, LLMEngine, AsyncLLMEngine from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig) from vllm.logger import init_logger from vllm.engine.multiprocessing import (VLLM_RPC_SUCCESS_STR, - RPCUtilityRequest) -from vllm.utils import print_warning_once + RPCGenerateRequest, + RPCAbortRequest, + RPCStartupRequest, + RPCUtilityRequest) +from vllm.outputs import RequestOutput from vllm.usage.usage_lib import UsageContext CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig, @@ -64,27 +66,13 @@ def __init__(self, self.log_requests = log_requests self.engine = self._init_engine(*args, **kwargs) - if self.engine_use_ray: - print_warning_once( - "DEPRECATED. `--engine-use-ray` is deprecated and will " - "be removed in a future update. " - "See https://github.com/vllm-project/vllm/issues/7045.") - - if envs.VLLM_ALLOW_ENGINE_USE_RAY: - print_warning_once( - "VLLM_ALLOW_ENGINE_USE_RAY is set, force engine use Ray") - else: - raise ValueError("`--engine-use-ray` is deprecated. " - "Set `VLLM_ALLOW_ENGINE_USE_RAY=1` to " - "force use it") - self.ctx = zmq.Context() - # Recieve RPCGenerateRequest from the client. + # Recieve input from the client. self.input_socket = self.ctx.socket(zmq.constants.PULL) self.input_socket.bind(f"{ipc_path}_input_socket") - # Send streams of RequestOutput back to Client. + # Send output stream back to client. self.output_socket = self.ctx.socket(zmq.constants.PUSH) self.output_socket.bind(f"{ipc_path}_output_socket") @@ -144,6 +132,7 @@ def _init_engine(self, *args, **kwargs) -> Union[LLMEngine, "ray.ObjectRef"]: self._engine_class).remote return engine_class(*args, **kwargs) + def run_background_loop(self): """Entrypoint that kicks off the background processing loop.""" @@ -152,7 +141,8 @@ def run_background_loop(self): # Kick off core processing loop. self.run_engine_loop() - + + @contextmanager def make_data_socket(self) -> Iterator[zmq.Socket]: socket = self.ctx.socket(zmq.constants.ROUTER) @@ -163,7 +153,7 @@ def make_data_socket(self) -> Iterator[zmq.Socket]: socket.close(linger=0) def run_startup_loop(self) -> None: - """Loop over startup RPCRequests from RPCClient.""" + """Loop over startup RPCStatupRequest from RPCClient.""" with self.make_data_socket() as socket: @@ -172,29 +162,27 @@ def run_startup_loop(self) -> None: while not client_is_ready: try: identity, message = socket.recv_multipart(copy=False) - request: RPCUtilityRequest = cloudpickle.loads(message.buffer) + request: RPCStartupRequest = pickle.loads(message.buffer) # Handle the query from the Client. 
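
For reference, the startup handshake here leans on ZeroMQ's ROUTER/DEALER sockets so each pickled reply can be routed back to the client that asked for it. Below is a minimal, self-contained sketch of that request-reply pattern, assuming only pyzmq; the StartupRequest enum, the inproc transport, and the helper names are illustrative stand-ins rather than the patch's actual types.

import enum
import pickle
import threading

import zmq


class StartupRequest(enum.Enum):
    IS_SERVER_READY = 1
    CLIENT_IS_READY = 2


def server(ctx: zmq.Context, path: str, bound: threading.Event) -> None:
    # ROUTER prepends an identity frame to every message it receives,
    # so replies can be routed back to the DEALER that sent the request.
    socket = ctx.socket(zmq.ROUTER)
    socket.bind(path)
    bound.set()
    client_is_ready = False
    while not client_is_ready:
        identity, frame = socket.recv_multipart()
        request = pickle.loads(frame)
        if request == StartupRequest.IS_SERVER_READY:
            response = "SUCCESS"
        elif request == StartupRequest.CLIENT_IS_READY:
            response = "SUCCESS"
            client_is_ready = True  # client has what it needs; exit the loop
        else:
            response = ValueError(f"Unknown request: {request}")
        socket.send_multipart((identity, pickle.dumps(response)))
    socket.close(linger=0)


def client(ctx: zmq.Context, path: str) -> None:
    socket = ctx.socket(zmq.DEALER)
    socket.connect(path)
    for request in (StartupRequest.IS_SERVER_READY,
                    StartupRequest.CLIENT_IS_READY):
        socket.send_multipart((pickle.dumps(request),))
        print(request.name, "->", pickle.loads(socket.recv()))
    socket.close(linger=0)


if __name__ == "__main__":
    ctx = zmq.Context()
    path = "inproc://startup"  # the patch uses ipc:// paths between processes
    bound = threading.Event()
    t = threading.Thread(target=server, args=(ctx, path, bound))
    t.start()
    bound.wait()
    client(ctx, path)
    t.join()
    ctx.term()

The real loop additionally pickles engine configs into the replies, but the shape of the dispatch is the same.
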
- if request == RPCUtilityRequest.GET_MODEL_CONFIG: + if request == RPCStartupRequest.GET_MODEL_CONFIG: response = self.engine.get_model_config() - elif request == RPCUtilityRequest.GET_DECODING_CONFIG: + elif request == RPCStartupRequest.GET_DECODING_CONFIG: response = self.engine.get_decoding_config() - elif request == RPCUtilityRequest.GET_LORA_CONFIG: + elif request == RPCStartupRequest.GET_LORA_CONFIG: response = self.engine.get_lora_config() - elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG: + elif request == RPCStartupRequest.GET_SCHEDULER_CONFIG: response = self.engine.get_scheduler_config() - elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG: + elif request == RPCStartupRequest.GET_PARALLEL_CONFIG: response = self.engine.get_parallel_config() - elif request == RPCUtilityRequest.IS_SERVER_READY: - response = VLLM_RPC_SUCCESS_STR - elif request == RPCUtilityRequest.IS_TRACING_ENABLED: + elif request == RPCStartupRequest.GET_TRACING_ENABLED: response = self.engine.is_tracing_enabled() - elif request == RPCUtilityRequest.CLIENT_IS_READY: + elif request == RPCStartupRequest.IS_SERVER_READY: + response = VLLM_RPC_SUCCESS_STR + elif request == RPCStartupRequest.CLIENT_IS_READY: response = VLLM_RPC_SUCCESS_STR - # Once client ready, breakout of loop. + # Breakout of loop once client is ready. client_is_ready = True - else: - raise ValueError(f"Unknown RPCRequest: {request}") socket.send_multipart( (identity, pickle.dumps(response)), copy=False) @@ -203,43 +191,61 @@ def run_startup_loop(self) -> None: socket.send_multipart((identity, pickle.dumps(e)), copy=False) def run_engine_loop(self) -> None: - # TODO: handle PP - while True: # Block until there is a new request. if not self.engine.has_unfinished_requests(): - self.wait_for_new_requests() + self.wait_for_new_input() - # Add new work from input socket. - self.maybe_add_new_requests() + # Handle any new input from the input socket. + self.maybe_handle_new_input() # Engine step. request_outputs = self.engine.step() # Stream results to output socket. 
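
As a rough mental model for this loop, the sketch below reproduces the block-when-idle / drain-without-blocking / step / stream control flow with a plain queue.Queue standing in for the ZMQ sockets and a toy engine in place of LLMEngine. The shutdown sentinel is added only so the example terminates; none of the names here are vLLM APIs.

import queue
from typing import List, Optional, Tuple


class ToyEngine:
    """Stand-in for LLMEngine: each request needs a fixed number of steps."""

    def __init__(self) -> None:
        self.pending: List[Tuple[str, int]] = []  # (request_id, steps left)

    def add_request(self, request_id: str, steps: int) -> None:
        self.pending.append((request_id, steps))

    def has_unfinished_requests(self) -> bool:
        return bool(self.pending)

    def step(self) -> List[str]:
        # One "decode step": emit an output per request, retire finished ones.
        outputs, still_pending = [], []
        for request_id, steps in self.pending:
            outputs.append(f"{request_id}: output (remaining={steps - 1})")
            if steps - 1 > 0:
                still_pending.append((request_id, steps - 1))
        self.pending = still_pending
        return outputs


def engine_loop(inbox: "queue.Queue[Optional[Tuple[str, int]]]",
                outbox: "queue.Queue[str]") -> None:
    engine = ToyEngine()
    shutting_down = False
    while True:
        # Block only when there is nothing in flight (wait_for_new_input).
        if not engine.has_unfinished_requests():
            if shutting_down:
                return
            item = inbox.get()
            if item is None:
                return
            engine.add_request(*item)
        # Drain any further input without blocking (maybe_handle_new_input).
        while True:
            try:
                item = inbox.get_nowait()
            except queue.Empty:
                break
            if item is None:
                shutting_down = True  # finish in-flight work, then exit
            else:
                engine.add_request(*item)
        # Step the engine and stream its outputs (stream_outputs).
        for output in engine.step():
            outbox.put(output)


if __name__ == "__main__":
    inbox: "queue.Queue[Optional[Tuple[str, int]]]" = queue.Queue()
    outbox: "queue.Queue[str]" = queue.Queue()
    inbox.put(("req-0", 2))
    inbox.put(("req-1", 1))
    inbox.put(None)  # sentinel, only so this sketch terminates
    engine_loop(inbox, outbox)
    while not outbox.empty():
        print(outbox.get())
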
- self.stream_outputs(request_outputs) - + self.stream_outputs(request_outputs) - def wait_for_new_requests(self): + def wait_for_new_input(self): while self.input_socket.poll(timeout=10000) == 0: logger.debug("Waiting for new request.") - def stream_outputs(self, request_outputs): + def stream_outputs(self, request_outputs: List[RequestOutput]): self.output_socket.send_multipart( (pickle.dumps(request_outputs),), copy=False) - - def maybe_add_new_requests(self): + + def maybe_handle_new_input(self): + """Handle new input with non-blocking IO""" while self.input_socket.poll(timeout=0) != 0: message = self.input_socket.recv(copy=False) - generate_rpc_request = pickle.loads(message.buffer) - self.engine.add_request( - request_id=generate_rpc_request.request_id, - inputs=generate_rpc_request.inputs, - params=generate_rpc_request.sampling_params, - lora_request=generate_rpc_request.lora_request, - trace_headers=generate_rpc_request.trace_headers, - prompt_adapter_request=generate_rpc_request.prompt_adapter_request, - ) + request = cloudpickle.loads(message.buffer) + + if isinstance(request, RPCGenerateRequest): + self._handle_generate_request(request) + elif isinstance(request, RPCAbortRequest): + self._handle_abort_request(request) + elif isinstance(request, RPCUtilityRequest): + self._handle_utility_request(request) + else: + raise ValueError(f"Unknown RPCRequest: {request}") + + def _handle_generate_request(self, request: RPCGenerateRequest): + self.engine.add_request( + request_id=request.request_id, + inputs=request.inputs, + params=request.sampling_params, + lora_request=request.lora_request, + trace_headers=request.trace_headers, + prompt_adapter_request=request.prompt_adapter_request, + ) + + def _handle_abort_request(self, request: RPCAbortRequest): + self.engine.abort_request([request.request_id]) + + def _handle_utility_request(self, request: RPCUtilityRequest): + if request == RPCUtilityRequest.DO_LOG_STATS: + self.engine.do_log_stats() + elif request == RPCUtilityRequest.CHECK_HEALTH: + self.engine.check_health() def run_mp_engine(engine_args: AsyncEngineArgs, diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 6dcbbd433596d..ef8f98a3889b9 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -82,11 +82,10 @@ async def lifespan(app: FastAPI): async def _force_log(): while True: - await asyncio.sleep(10) + await asyncio.sleep(1.) await async_engine_client.do_log_stats() - # if not engine_args.disable_log_stats: - if False: + if not engine_args.disable_log_stats: task = asyncio.create_task(_force_log()) _running_tasks.add(task) task.add_done_callback(_running_tasks.remove) From 9886f3dc689e52c62ceea87723af3858b710f5f3 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 3 Sep 2024 01:54:40 +0000 Subject: [PATCH 20/29] make health check work --- vllm/engine/multiprocessing/mp_client.py | 20 ++++++++++++++++++-- vllm/engine/multiprocessing/mp_llm_engine.py | 13 +++++++++++-- 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/vllm/engine/multiprocessing/mp_client.py b/vllm/engine/multiprocessing/mp_client.py index 086242d28fb59..eff3b0d06e408 100644 --- a/vllm/engine/multiprocessing/mp_client.py +++ b/vllm/engine/multiprocessing/mp_client.py @@ -55,6 +55,10 @@ def __init__(self, ipc_path: str): self.output_socket: Socket = self.context.socket(zmq.constants.PULL) self.output_socket.connect(f"{ipc_path}_output_socket") + # IPC path for awk of check_health requests. 
+ self.health_socket: Socket = self.context.socket(zmq.constants.PULL) + self.health_socket.connect(f"{ipc_path}_health_socket") + # IPC path for the data socket. self.data_ipc_path = f"{ipc_path}_data_socket" @@ -164,7 +168,8 @@ async def _send_one_way_rpc_request(self, await socket.send_multipart((cloudpickle.dumps(request), )) # TODO: is there a way to ack this if we are using the input_socket? - # I don't think so, b/c we are using PUSH/PULL + # I don't think so, b/c we are using PUSH/PULL w/out identities so no + # way to preserve order. async def _awk_one_way_rpc_request(self, timeout: int, @@ -349,7 +354,7 @@ async def generate( # possibly setting the `errored` property. if not self._errored: try: - # await self.check_health(socket=socket) + await self.check_health() pass except Exception as e: self._errored = True @@ -371,6 +376,17 @@ async def check_health(self) -> None: await self._send_one_way_rpc_request( request=RPCUtilityRequest.CHECK_HEALTH, socket=self.input_socket) + + # Await awknoledgement from MPLLMEngine. + # Note: these requests are not necessarily serial. + # I.e. if two clients A, B send CHECK_HEALTH, the + # response to A could actually be the call send by B. + # TODO: is this bad? + await self._awk_one_way_rpc_request( + timeout=VLLM_RPC_GET_DATA_TIMEOUT_MS, + expected_str=VLLM_RPC_SUCCESS_STR, + error_message="Check health timeout.", + socket=self.health_socket) async def encode(self, *args, diff --git a/vllm/engine/multiprocessing/mp_llm_engine.py b/vllm/engine/multiprocessing/mp_llm_engine.py index 6323b1d0734f4..8ac1ade813166 100644 --- a/vllm/engine/multiprocessing/mp_llm_engine.py +++ b/vllm/engine/multiprocessing/mp_llm_engine.py @@ -76,6 +76,10 @@ def __init__(self, self.output_socket = self.ctx.socket(zmq.constants.PUSH) self.output_socket.bind(f"{ipc_path}_output_socket") + # Send health status back to client. + self.health_socket = self.ctx.socket(zmq.constants.PUSH) + self.health_socket.bind(f"{ipc_path}_health_socket") + # IPC path for the data socket. self.data_ipc_path = f"{ipc_path}_data_socket" @@ -213,6 +217,10 @@ def stream_outputs(self, request_outputs: List[RequestOutput]): self.output_socket.send_multipart( (pickle.dumps(request_outputs),), copy=False) + def awk_check_health(self): + self.health_socket.send_multipart( + (pickle.dumps(VLLM_RPC_SUCCESS_STR), ), copy=False) + def maybe_handle_new_input(self): """Handle new input with non-blocking IO""" while self.input_socket.poll(timeout=0) != 0: @@ -246,8 +254,9 @@ def _handle_utility_request(self, request: RPCUtilityRequest): self.engine.do_log_stats() elif request == RPCUtilityRequest.CHECK_HEALTH: self.engine.check_health() - - + # Special check_health channel for awk check health. 
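
The ack path added here flows over a dedicated one-way health socket rather than the request socket, which is why the client polls a separate channel with a timeout. A self-contained sketch of that shape, assuming only pyzmq (plain strings stand in for the RPC enums, and the inproc paths are illustrative):

import pickle
import threading

import zmq


def engine_side(ctx: zmq.Context, bound: threading.Event) -> None:
    input_socket = ctx.socket(zmq.PULL)   # requests flow in here
    input_socket.bind("inproc://input")
    health_socket = ctx.socket(zmq.PUSH)  # acks flow back out here
    health_socket.bind("inproc://health")
    bound.set()

    request = pickle.loads(input_socket.recv())
    if request == "CHECK_HEALTH":
        # A real engine would run its health check before acking.
        health_socket.send(pickle.dumps("SUCCESS"))
    input_socket.close(linger=0)
    health_socket.close(linger=0)


def client_side(ctx: zmq.Context) -> None:
    input_socket = ctx.socket(zmq.PUSH)
    input_socket.connect("inproc://input")
    health_socket = ctx.socket(zmq.PULL)
    health_socket.connect("inproc://health")

    input_socket.send(pickle.dumps("CHECK_HEALTH"))
    # Wait up to 5s for the ack; poll() returning 0 means timeout.
    if health_socket.poll(timeout=5000) == 0:
        raise TimeoutError("engine did not ack the health check")
    print("health ack:", pickle.loads(health_socket.recv()))
    input_socket.close(linger=0)
    health_socket.close(linger=0)


if __name__ == "__main__":
    ctx = zmq.Context()
    bound = threading.Event()
    t = threading.Thread(target=engine_side, args=(ctx, bound))
    t.start()
    bound.wait()
    client_side(ctx)
    t.join()
    ctx.term()

Because PUSH/PULL carries no identities, an ack is not tied to a particular probe, which is the ordering caveat the client code notes.
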
+ self.awk_check_health() + def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext, ipc_path: str): From 5b2f0577fdbe5bf0f86e50297f2c57254f95f7c2 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 3 Sep 2024 02:46:43 +0000 Subject: [PATCH 21/29] format --- benchmarks/benchmark_throughput_async.py | 22 ++--- vllm/engine/multiprocessing/__init__.py | 2 +- vllm/engine/multiprocessing/mp_client.py | 76 ++++++++---------- vllm/engine/multiprocessing/mp_llm_engine.py | 84 ++++++++++---------- vllm/entrypoints/openai/api_server.py | 23 +++--- 5 files changed, 100 insertions(+), 107 deletions(-) diff --git a/benchmarks/benchmark_throughput_async.py b/benchmarks/benchmark_throughput_async.py index 54eed0f4de783..217f11d14d30a 100644 --- a/benchmarks/benchmark_throughput_async.py +++ b/benchmarks/benchmark_throughput_async.py @@ -1,6 +1,5 @@ """Benchmark offline inference throughput.""" import argparse -import asyncio import json import random import time @@ -12,11 +11,11 @@ from transformers import (AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase) -from vllm.entrypoints.openai.api_server import build_async_engine_client_from_engine_args -from vllm.utils import merge_async_iterators -from vllm.engine.arg_utils import EngineArgs, AsyncEngineArgs +from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs +from vllm.entrypoints.openai.api_server import ( + build_async_engine_client_from_engine_args) from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS -from vllm.utils import FlexibleArgumentParser +from vllm.utils import FlexibleArgumentParser, merge_async_iterators def sample_requests( @@ -92,7 +91,7 @@ async def run_vllm( load_format: str = EngineArgs.load_format, disable_async_output_proc: bool = False, ) -> float: - from vllm import LLM, SamplingParams + from vllm import SamplingParams engine_args = AsyncEngineArgs( model=model, tokenizer=tokenizer, @@ -123,8 +122,8 @@ async def run_vllm( decoupled = True - async with build_async_engine_client_from_engine_args(engine_args, - not decoupled) as llm: + async with build_async_engine_client_from_engine_args( + engine_args, not decoupled) as llm: # Add the requests to the engine. 
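
The throughput measurement in this benchmark boils down to: start every streaming request up front, drain all of the streams concurrently, and time the whole thing. Below is a standalone sketch of that pattern using asyncio only, where fake_generate is an illustrative stand-in for llm.generate() and asyncio.gather plays the role of iterating the merged generators.

import asyncio
import time
from typing import AsyncGenerator, List, Tuple


async def fake_generate(prompt: str,
                        num_tokens: int) -> AsyncGenerator[str, None]:
    # Pretend each token takes ~1 ms to produce.
    for i in range(num_tokens):
        await asyncio.sleep(0.001)
        yield f"{prompt}-token{i}"


async def drain(gen: AsyncGenerator[str, None]) -> int:
    count = 0
    async for _ in gen:
        count += 1
    return count


async def run_benchmark(requests: List[Tuple[str, int]]) -> float:
    start = time.perf_counter()
    # One drain task per request; they interleave on the event loop, much
    # like iterating a merged view of all the generators.
    counts = await asyncio.gather(
        *(drain(fake_generate(prompt, n)) for prompt, n in requests))
    elapsed = time.perf_counter() - start
    total = sum(counts)
    print(f"{total} outputs in {elapsed:.3f}s ({total / elapsed:.1f}/s)")
    return elapsed


if __name__ == "__main__":
    asyncio.run(run_benchmark([(f"prompt{i}", 50) for i in range(8)]))
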
prompts: List[str] = [] @@ -146,13 +145,14 @@ async def run_vllm( for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)): # generator = await llm.generate(prompt, sp, request_id=f"test{i}") generator = llm.generate(prompt, sp, request_id=f"test{i}") - generators.append(generator) + generators.append(generator) all_gens = merge_async_iterators(*generators) async for i, res in all_gens: pass end = time.perf_counter() return end - start + def run_hf( requests: List[Tuple[str, int, int]], model: str, @@ -248,7 +248,7 @@ def main(args: argparse.Namespace): args.output_len) if args.backend == "vllm": - coro = run_vllm( + coro = run_vllm( requests, args.model, args.tokenizer, args.quantization, args.tensor_parallel_size, args.seed, args.n, args.use_beam_search, args.trust_remote_code, args.dtype, args.max_model_len, @@ -260,7 +260,7 @@ def main(args: argparse.Namespace): args.use_v2_block_manager, args.download_dir, args.load_format, args.disable_async_output_proc) - elapsed_time = uvloop.run(coro) + elapsed_time = uvloop.run(coro) elif args.backend == "hf": assert args.tensor_parallel_size == 1 elapsed_time = run_hf(requests, args.model, tokenizer, args.n, diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index be7d80072f964..cf566933801e9 100644 --- a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -40,4 +40,4 @@ class RPCStartupRequest(Enum): RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest, - RPCUtilityRequest] + RPCUtilityRequest, RPCStartupRequest] diff --git a/vllm/engine/multiprocessing/mp_client.py b/vllm/engine/multiprocessing/mp_client.py index eff3b0d06e408..ba3269c252ba3 100644 --- a/vllm/engine/multiprocessing/mp_client.py +++ b/vllm/engine/multiprocessing/mp_client.py @@ -1,8 +1,8 @@ import asyncio import pickle from contextlib import contextmanager, suppress -from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping, Optional, - Union) +from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping, + Optional, Union) import cloudpickle import zmq @@ -14,10 +14,8 @@ ParallelConfig, SchedulerConfig) # yapf: disable from vllm.engine.multiprocessing import (RPC_REQUEST_TYPE, - VLLM_RPC_SUCCESS_STR, - RPCAbortRequest, - RPCGenerateRequest, - RPCStartupRequest, + VLLM_RPC_SUCCESS_STR, RPCAbortRequest, + RPCGenerateRequest, RPCStartupRequest, RPCUtilityRequest) # yapf: enable from vllm.envs import VLLM_RPC_GET_DATA_TIMEOUT_MS @@ -31,6 +29,7 @@ logger = init_logger(__name__) + class MPClientClosedError(Exception): """Exception class raised when the client is used post-close. @@ -41,6 +40,7 @@ class MPClientClosedError(Exception): So, we throw this error such that we can suppress it. """ + class MPEngineClient: def __init__(self, ipc_path: str): @@ -51,7 +51,7 @@ def __init__(self, ipc_path: str): self.input_socket: Socket = self.context.socket(zmq.constants.PUSH) self.input_socket.connect(f"{ipc_path}_input_socket") - # Recieve streams of RequestOutput from the MPLLMEngine. + # Receive streams of RequestOutput from the MPLLMEngine. self.output_socket: Socket = self.context.socket(zmq.constants.PULL) self.output_socket.connect(f"{ipc_path}_output_socket") @@ -65,7 +65,7 @@ def __init__(self, ipc_path: str): # Stream for each individual request. 
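
The client's output path is a single socket carrying outputs for every in-flight request, which run_output_handler fans out into one asyncio.Queue per request id. The sketch below shows that demultiplexing shape in isolation, with a plain asyncio.Queue simulating the ZMQ PULL stream; the names here are illustrative, not the patch's API.

import asyncio
from typing import Dict


async def output_handler(shared: asyncio.Queue,
                         queues: Dict[str, asyncio.Queue]) -> None:
    # One task reads the shared stream and routes each item to the queue
    # registered for its request id (unknown ids are silently dropped).
    while True:
        request_id, payload = await shared.get()
        queue = queues.get(request_id)
        if queue is not None:
            queue.put_nowait(payload)


async def consume(request_id: str, queue: asyncio.Queue) -> None:
    while True:
        item = await queue.get()
        print(request_id, "got", item)
        if item == "DONE":
            return


async def main() -> None:
    shared: asyncio.Queue = asyncio.Queue()
    queues: Dict[str, asyncio.Queue] = {}
    handler = asyncio.create_task(output_handler(shared, queues))

    consumers = []
    for i in range(2):
        request_id = f"req{i}"
        queues[request_id] = asyncio.Queue()  # register before outputs arrive
        consumers.append(
            asyncio.create_task(consume(request_id, queues[request_id])))

    # Interleaved outputs for both requests arrive on the shared stream.
    for item in [("req0", "a"), ("req1", "b"), ("req0", "DONE"),
                 ("req1", "DONE")]:
        shared.put_nowait(item)

    await asyncio.gather(*consumers)
    handler.cancel()
    try:
        await handler
    except asyncio.CancelledError:
        pass


if __name__ == "__main__":
    asyncio.run(main())

Removing the per-request queue when a request finishes (the pop in generate()'s finally block) is elided here for brevity.
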
self.output_queues: Dict[str, asyncio.Queue] = {} self.output_handler = asyncio.create_task(self.run_output_handler()) - + @contextmanager def get_data_socket(self) -> Iterator[Socket]: socket = self.context.socket(zmq.constants.DEALER) @@ -79,7 +79,7 @@ async def run_output_handler(self): # Stream lists of RequestOutput from MPLLMEngine. while True: message: Frame = await self.output_socket.recv(copy=False) - request_outputs = pickle.loads(message.buffer) + request_outputs: List[RequestOutput] = pickle.loads(message.buffer) for output in request_outputs: if isinstance(output, tuple): @@ -109,7 +109,8 @@ async def setup(self): # TODO: refactor OAI server to avoid needing this info. self.tokenizer = init_tokenizer_from_configs( model_config=self.model_config, - scheduler_config=(await self._get_scheduler_config_rpc(socket)), + scheduler_config=(await + self._get_scheduler_config_rpc(socket)), parallel_config=(await self._get_parallel_config_rpc(socket)), enable_lora=bool(await self._get_lora_config_rpc(socket)), ) @@ -123,22 +124,19 @@ def close(self): # then terminate the context. self.context.destroy(linger=0) - - async def _send_get_data_rpc_request(self, request: RPCUtilityRequest, + async def _send_get_data_rpc_request(self, request: RPCStartupRequest, expected_type: Any, error_message: str, socket: Socket) -> Any: """Send an RPC request that is expecting data back.""" # Ping RPCServer with a request. - await socket.send_multipart( - (cloudpickle.dumps(request), ), - copy=False) + await socket.send_multipart((cloudpickle.dumps(request), ), copy=False) # Make sure the server responds if await socket.poll(timeout=VLLM_RPC_GET_DATA_TIMEOUT_MS) == 0: raise TimeoutError("Server didn't reply within " - f"{VLLM_RPC_GET_DATA_TIMEOUT_MS} ms") + f"{VLLM_RPC_GET_DATA_TIMEOUT_MS} ms") # Await the data from the Server. frame = await socket.recv(copy=False) @@ -160,8 +158,7 @@ async def _send_get_data_rpc_request(self, request: RPCUtilityRequest, return data - async def _send_one_way_rpc_request(self, - request: RPC_REQUEST_TYPE, + async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE, socket: Socket): """Send one-way RPC request to trigger an action.""" @@ -170,16 +167,17 @@ async def _send_one_way_rpc_request(self, # TODO: is there a way to ack this if we are using the input_socket? # I don't think so, b/c we are using PUSH/PULL w/out identities so no # way to preserve order. - - async def _awk_one_way_rpc_request(self, - timeout: int, - expected_str: str, - error_message: str, - socket: Socket,): + + async def _awk_one_way_rpc_request( + self, + timeout: int, + expected_str: str, + error_message: str, + socket: Socket, + ): if await socket.poll(timeout=timeout) == 0: raise TimeoutError(f"MPLLMEngine didn't reply within {timeout}ms") - - + frame = await socket.recv(copy=False) response = pickle.loads(frame.buffer) @@ -203,7 +201,7 @@ async def is_tracing_enabled(self) -> bool: async def _wait_for_server_rpc(self, socket: Socket): """Wait for the RPCServer to start up.""" - + # Readiness probe. 
request = RPCStartupRequest.IS_SERVER_READY await socket.send_multipart((cloudpickle.dumps(request), )) @@ -214,14 +212,12 @@ async def _wait_for_server_rpc(self, socket: Socket): timeout=VLLM_RPC_GET_DATA_TIMEOUT_MS, error_message="Unable to start RPC Server", socket=socket) - async def _notify_ready(self, socket: Socket): """Get the RPCServer that the RPCClient is ready""" await self._send_one_way_rpc_request( - request=RPCStartupRequest.CLIENT_IS_READY, - socket=socket) + request=RPCStartupRequest.CLIENT_IS_READY, socket=socket) async def _get_model_config_rpc(self, socket: Socket) -> ModelConfig: """Get the ModelConfig object from the RPC Server""" @@ -250,7 +246,8 @@ async def _get_parallel_config_rpc(self, socket: Socket) -> ParallelConfig: error_message="Could not get ParallelConfig from RPC Server", socket=socket) - async def _get_scheduler_config_rpc(self, socket: Socket) -> SchedulerConfig: + async def _get_scheduler_config_rpc(self, + socket: Socket) -> SchedulerConfig: """Get SchedulerConfig from the RPCServer""" return await self._send_get_data_rpc_request( @@ -292,8 +289,7 @@ async def abort(self, request_id: str): # wall of `TimeoutError` stack traces. with suppress(MPClientClosedError, TimeoutError): await self._send_one_way_rpc_request( - request=RPCAbortRequest(request_id), - socket=self.input_socket) + request=RPCAbortRequest(request_id), socket=self.input_socket) async def do_log_stats(self): """Send a DO_LOG_STATS signal to the RPC Server""" @@ -330,7 +326,7 @@ async def generate( self.output_queues[request_id] = queue finished = False try: - + # Send RPCGenerateRequest to the RPCServer. await self.input_socket.send_multipart((cloudpickle.dumps( RPCGenerateRequest( @@ -341,10 +337,10 @@ async def generate( trace_headers=trace_headers, prompt_adapter_request=prompt_adapter_request)), )) - # ack: Frame = await socket.recv(copy=False) - # if len(ack.buffer) != 0: - # exception = pickle.loads(ack.buffer) - # raise exception + # ack: Frame = await socket.recv(copy=False) + # if len(ack.buffer) != 0: + # exception = pickle.loads(ack.buffer) + # raise exception while not finished: request_output = await queue.get() @@ -374,8 +370,7 @@ async def check_health(self) -> None: """Raise if unhealthy""" await self._send_one_way_rpc_request( - request=RPCUtilityRequest.CHECK_HEALTH, - socket=self.input_socket) + request=RPCUtilityRequest.CHECK_HEALTH, socket=self.input_socket) # Await awknoledgement from MPLLMEngine. # Note: these requests are not necessarily serial. 
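For context on the check_health path above: the request itself goes out on the shared PUSH input socket, while the acknowledgement comes back on a dedicated PULL health socket guarded by a poll timeout. A rough standalone sketch of that request/ack shape follows; the socket suffixes, the "SUCCESS" sentinel, and the 5000 ms timeout are illustrative assumptions, not vLLM's actual constants.

# Minimal sketch (not vLLM code): one-way request over PUSH, with the ack
# awaited on a separate PULL socket using a poll-based timeout.
import pickle

import zmq
import zmq.asyncio

ACK_TIMEOUT_MS = 5000  # assumption for the sketch


async def request_with_ack(ipc_path: str, payload: bytes) -> None:
    ctx = zmq.asyncio.Context()
    push = ctx.socket(zmq.PUSH)
    push.connect(f"{ipc_path}_input_socket")
    ack = ctx.socket(zmq.PULL)
    ack.connect(f"{ipc_path}_health_socket")
    try:
        # Fire-and-forget: PUSH/PULL carries no reply channel of its own.
        await push.send(payload)
        # Wait for the server to post an acknowledgement on the ack socket.
        if await ack.poll(timeout=ACK_TIMEOUT_MS) == 0:
            raise TimeoutError(f"No ack within {ACK_TIMEOUT_MS} ms")
        response = pickle.loads(await ack.recv())
        if response != "SUCCESS":
            raise ValueError(f"Unexpected ack: {response!r}")
    finally:
        push.close()
        ack.close()
        ctx.destroy(linger=0)

As the TODO in the patch notes, with several callers sharing these sockets the ack a caller sees may have been triggered by someone else's request, since PUSH/PULL preserves no per-request identity.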
@@ -387,7 +382,6 @@ async def check_health(self) -> None: expected_str=VLLM_RPC_SUCCESS_STR, error_message="Check health timeout.", socket=self.health_socket) - async def encode(self, *args, **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]: diff --git a/vllm/engine/multiprocessing/mp_llm_engine.py b/vllm/engine/multiprocessing/mp_llm_engine.py index 8ac1ade813166..72ced337a1605 100644 --- a/vllm/engine/multiprocessing/mp_llm_engine.py +++ b/vllm/engine/multiprocessing/mp_llm_engine.py @@ -1,19 +1,18 @@ -import ray -import zmq -import cloudpickle import pickle -from typing import Iterator, List, Type, Union from contextlib import contextmanager +from typing import Iterator, List, Type, Union -from vllm import AsyncEngineArgs, LLMEngine, AsyncLLMEngine +import cloudpickle +import ray +import zmq + +from vllm import AsyncEngineArgs, AsyncLLMEngine, LLMEngine from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig) -from vllm.logger import init_logger -from vllm.engine.multiprocessing import (VLLM_RPC_SUCCESS_STR, - RPCGenerateRequest, - RPCAbortRequest, - RPCStartupRequest, +from vllm.engine.multiprocessing import (VLLM_RPC_SUCCESS_STR, RPCAbortRequest, + RPCGenerateRequest, RPCStartupRequest, RPCUtilityRequest) +from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.usage.usage_lib import UsageContext @@ -22,15 +21,16 @@ logger = init_logger(__name__) + class MPLLMEngine: """A multiprocessing wrapper for :class:`LLMEngine`. This class is used to wrap the :class:`LLMEngine` class to enable use in asynchronous manner. It runs a background loop and uses zeromq to - recieve new requests and stream outputs incrementally to another process. + receive new requests and stream outputs incrementally to another process. The :class:`LLMEngine` is kicked off when a new RPCGenerateRequest - is recieved by the input_socket. + is received by the input_socket. The self.engine_loop checks the input_socket for new requests, adds them to the LLMEngine if there are any, calls the internal @@ -60,15 +60,15 @@ def __init__(self, if engine_use_ray: raise NotImplementedError("Not yet supported.") - + self.worker_use_ray = worker_use_ray self.engine_use_ray = engine_use_ray self.log_requests = log_requests self.engine = self._init_engine(*args, **kwargs) - self.ctx = zmq.Context() + self.ctx = zmq.Context() # type: ignore[attr-defined] - # Recieve input from the client. + # Receive input from the client. 
self.input_socket = self.ctx.socket(zmq.constants.PULL) self.input_socket.bind(f"{ipc_path}_input_socket") @@ -84,10 +84,10 @@ def __init__(self, self.data_ipc_path = f"{ipc_path}_data_socket" @classmethod - def from_engine_args(cls, engine_args: AsyncEngineArgs, + def from_engine_args(cls, engine_args: AsyncEngineArgs, usage_context: UsageContext, ipc_path: str): """Creates an RPCLLM engine from the engine arguments.""" - + engine_config = engine_args.create_engine_config() if engine_args.engine_use_ray: @@ -115,7 +115,8 @@ def cleanup(self): self.ctx.destroy(linger=0) del self.engine - def _init_engine(self, *args, **kwargs) -> Union[LLMEngine, "ray.ObjectRef"]: + def _init_engine(self, *args, + **kwargs) -> Union[LLMEngine, "ray.ObjectRef"]: """Initialize the LLMEngine""" if not self.engine_use_ray: @@ -135,20 +136,19 @@ def _init_engine(self, *args, **kwargs) -> Union[LLMEngine, "ray.ObjectRef"]: engine_class = ray.remote(num_gpus=num_gpus)( self._engine_class).remote return engine_class(*args, **kwargs) - def run_background_loop(self): """Entrypoint that kicks off the background processing loop.""" - - # Allow RPCClient to query data in startup phase. + + # Allow RPCClient to query data in startup phase. self.run_startup_loop() # Kick off core processing loop. self.run_engine_loop() - @contextmanager - def make_data_socket(self) -> Iterator[zmq.Socket]: + def make_data_socket( + self) -> Iterator[zmq.Socket]: # type: ignore[name-defined] socket = self.ctx.socket(zmq.constants.ROUTER) try: socket.bind(self.data_ipc_path) @@ -158,7 +158,7 @@ def make_data_socket(self) -> Iterator[zmq.Socket]: def run_startup_loop(self) -> None: """Loop over startup RPCStatupRequest from RPCClient.""" - + with self.make_data_socket() as socket: # Loop until the RPCClient has all the data it needs. @@ -187,12 +187,13 @@ def run_startup_loop(self) -> None: response = VLLM_RPC_SUCCESS_STR # Breakout of loop once client is ready. client_is_ready = True - - socket.send_multipart( - (identity, pickle.dumps(response)), copy=False) + + socket.send_multipart((identity, pickle.dumps(response)), + copy=False) except Exception as e: - socket.send_multipart((identity, pickle.dumps(e)), copy=False) + socket.send_multipart((identity, pickle.dumps(e)), + copy=False) def run_engine_loop(self) -> None: while True: @@ -202,10 +203,10 @@ def run_engine_loop(self) -> None: # Handle any new input from the input socket. self.maybe_handle_new_input() - + # Engine step. request_outputs = self.engine.step() - + # Stream results to output socket. 
self.stream_outputs(request_outputs) @@ -214,9 +215,9 @@ def wait_for_new_input(self): logger.debug("Waiting for new request.") def stream_outputs(self, request_outputs: List[RequestOutput]): - self.output_socket.send_multipart( - (pickle.dumps(request_outputs),), copy=False) - + self.output_socket.send_multipart((pickle.dumps(request_outputs), ), + copy=False) + def awk_check_health(self): self.health_socket.send_multipart( (pickle.dumps(VLLM_RPC_SUCCESS_STR), ), copy=False) @@ -235,7 +236,7 @@ def maybe_handle_new_input(self): self._handle_utility_request(request) else: raise ValueError(f"Unknown RPCRequest: {request}") - + def _handle_generate_request(self, request: RPCGenerateRequest): self.engine.add_request( request_id=request.request_id, @@ -248,7 +249,7 @@ def _handle_generate_request(self, request: RPCGenerateRequest): def _handle_abort_request(self, request: RPCAbortRequest): self.engine.abort_request([request.request_id]) - + def _handle_utility_request(self, request: RPCUtilityRequest): if request == RPCUtilityRequest.DO_LOG_STATS: self.engine.do_log_stats() @@ -256,13 +257,12 @@ def _handle_utility_request(self, request: RPCUtilityRequest): self.engine.check_health() # Special check_health channel for awk check health. self.awk_check_health() - -def run_mp_engine(engine_args: AsyncEngineArgs, - usage_context: UsageContext, + + +def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext, ipc_path: str): - engine = MPLLMEngine.from_engine_args( - engine_args=engine_args, - usage_context=usage_context, - ipc_path=ipc_path) + engine = MPLLMEngine.from_engine_args(engine_args=engine_args, + usage_context=usage_context, + ipc_path=ipc_path) engine.run_background_loop() diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index ef8f98a3889b9..b7c0cee1af8b1 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -21,6 +21,9 @@ from vllm.config import ModelConfig from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine +# yapf: enable +from vllm.engine.multiprocessing.mp_client import MPEngineClient +from vllm.engine.multiprocessing.mp_llm_engine import run_mp_engine from vllm.engine.protocol import AsyncEngineClient from vllm.entrypoints.launcher import serve_http from vllm.entrypoints.logger import RequestLogger @@ -37,9 +40,6 @@ EmbeddingResponse, ErrorResponse, TokenizeRequest, TokenizeResponse) -# yapf: enable -from vllm.engine.multiprocessing.mp_client import MPEngineClient -from vllm.engine.multiprocessing.mp_llm_engine import run_mp_engine from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding @@ -84,7 +84,7 @@ async def _force_log(): while True: await asyncio.sleep(1.) 
await async_engine_client.do_log_stats() - + if not engine_args.disable_log_stats: task = asyncio.create_task(_force_log()) _running_tasks.add(task) @@ -170,12 +170,12 @@ async def build_async_engine_client_from_engine_args( # so we need to spawn a new process context = multiprocessing.get_context("spawn") - engine_process = context.Process( - target=run_mp_engine, - args=(engine_args, UsageContext.OPENAI_API_SERVER, ipc_path)) + engine_process = context.Process(target=run_mp_engine, + args=(engine_args, + UsageContext.OPENAI_API_SERVER, + ipc_path)) engine_process.start() - logger.info("Started engine process with PID %d", - engine_process.pid) + logger.info("Started engine process with PID %d", engine_process.pid) try: while True: @@ -184,9 +184,8 @@ async def build_async_engine_client_from_engine_args( break except TimeoutError: if not engine_process.is_alive(): - logger.error( - "Engine process died before responding " - "to readiness probe") + logger.error("Engine process died before responding " + "to readiness probe") yield None return From ae4564c239af500cda87c72118ccd83120388f8a Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 3 Sep 2024 02:50:13 +0000 Subject: [PATCH 22/29] awk -> ack --- vllm/engine/multiprocessing/mp_client.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/vllm/engine/multiprocessing/mp_client.py b/vllm/engine/multiprocessing/mp_client.py index ba3269c252ba3..201a3c0317705 100644 --- a/vllm/engine/multiprocessing/mp_client.py +++ b/vllm/engine/multiprocessing/mp_client.py @@ -55,7 +55,7 @@ def __init__(self, ipc_path: str): self.output_socket: Socket = self.context.socket(zmq.constants.PULL) self.output_socket.connect(f"{ipc_path}_output_socket") - # IPC path for awk of check_health requests. + # IPC path for ack of check_health requests. self.health_socket: Socket = self.context.socket(zmq.constants.PULL) self.health_socket.connect(f"{ipc_path}_health_socket") @@ -168,7 +168,7 @@ async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE, # I don't think so, b/c we are using PUSH/PULL w/out identities so no # way to preserve order. - async def _awk_one_way_rpc_request( + async def _ack_one_way_rpc_request( self, timeout: int, expected_str: str, @@ -206,8 +206,8 @@ async def _wait_for_server_rpc(self, socket: Socket): request = RPCStartupRequest.IS_SERVER_READY await socket.send_multipart((cloudpickle.dumps(request), )) - # Raises TimeoutError if not awk, causing a retry. - await self._awk_one_way_rpc_request( + # Raises TimeoutError if not ack, causing a retry. + await self._ack_one_way_rpc_request( expected_str=VLLM_RPC_SUCCESS_STR, timeout=VLLM_RPC_GET_DATA_TIMEOUT_MS, error_message="Unable to start RPC Server", @@ -337,11 +337,6 @@ async def generate( trace_headers=trace_headers, prompt_adapter_request=prompt_adapter_request)), )) - # ack: Frame = await socket.recv(copy=False) - # if len(ack.buffer) != 0: - # exception = pickle.loads(ack.buffer) - # raise exception - while not finished: request_output = await queue.get() if isinstance(request_output, BaseException): @@ -372,12 +367,12 @@ async def check_health(self) -> None: await self._send_one_way_rpc_request( request=RPCUtilityRequest.CHECK_HEALTH, socket=self.input_socket) - # Await awknoledgement from MPLLMEngine. + # Await acknowledgement from MPLLMEngine. # Note: these requests are not necessarily serial. # I.e. if two clients A, B send CHECK_HEALTH, the # response to A could actually be the call send by B. # TODO: is this bad? 
- await self._awk_one_way_rpc_request( + await self._ack_one_way_rpc_request( timeout=VLLM_RPC_GET_DATA_TIMEOUT_MS, expected_str=VLLM_RPC_SUCCESS_STR, error_message="Check health timeout.", From f9ccecc7048f28b3261658d28879abbf5b3b1e42 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 3 Sep 2024 02:53:08 +0000 Subject: [PATCH 23/29] add better shutdown --- vllm/engine/multiprocessing/mp_client.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/vllm/engine/multiprocessing/mp_client.py b/vllm/engine/multiprocessing/mp_client.py index 201a3c0317705..0fd1afe953c14 100644 --- a/vllm/engine/multiprocessing/mp_client.py +++ b/vllm/engine/multiprocessing/mp_client.py @@ -122,8 +122,13 @@ def close(self): """Destroy the ZeroMQ Context.""" # Close all sockets associated with this context and # then terminate the context. + self.output_socket.close() + self.input_socket.close() + self.health_socket.close() self.context.destroy(linger=0) + # TODO: cancel the handler task. + async def _send_get_data_rpc_request(self, request: RPCStartupRequest, expected_type: Any, error_message: str, From 89b730b9c650864a4b2c3a4eda2378f616ae1eb8 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 3 Sep 2024 02:55:06 +0000 Subject: [PATCH 24/29] cleanup comment --- vllm/engine/multiprocessing/mp_llm_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/engine/multiprocessing/mp_llm_engine.py b/vllm/engine/multiprocessing/mp_llm_engine.py index 72ced337a1605..5bdcc419de2f5 100644 --- a/vllm/engine/multiprocessing/mp_llm_engine.py +++ b/vllm/engine/multiprocessing/mp_llm_engine.py @@ -86,7 +86,7 @@ def __init__(self, @classmethod def from_engine_args(cls, engine_args: AsyncEngineArgs, usage_context: UsageContext, ipc_path: str): - """Creates an RPCLLM engine from the engine arguments.""" + """Creates an MPLLMEngine from the engine arguments.""" engine_config = engine_args.create_engine_config() From f3dc82b584f3d2db67c87d5e8fc12d498189e902 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 3 Sep 2024 02:59:20 +0000 Subject: [PATCH 25/29] more awk --> ack --- vllm/engine/multiprocessing/mp_llm_engine.py | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/vllm/engine/multiprocessing/mp_llm_engine.py b/vllm/engine/multiprocessing/mp_llm_engine.py index 5bdcc419de2f5..106f7a2fb8051 100644 --- a/vllm/engine/multiprocessing/mp_llm_engine.py +++ b/vllm/engine/multiprocessing/mp_llm_engine.py @@ -124,17 +124,7 @@ def _init_engine(self, *args, elif self.worker_use_ray: engine_class = ray.remote(num_cpus=0)(self._engine_class).remote else: - # FIXME(woosuk): This is a bit hacky. Be careful when changing the - # order of the arguments. 
- cache_config = kwargs["cache_config"] - parallel_config = kwargs["parallel_config"] - if (parallel_config.tensor_parallel_size == 1 - and parallel_config.pipeline_parallel_size == 1): - num_gpus = cache_config.gpu_memory_utilization - else: - num_gpus = 1 - engine_class = ray.remote(num_gpus=num_gpus)( - self._engine_class).remote + raise NotImplementedError("Not supported yet!") return engine_class(*args, **kwargs) def run_background_loop(self): @@ -218,7 +208,7 @@ def stream_outputs(self, request_outputs: List[RequestOutput]): self.output_socket.send_multipart((pickle.dumps(request_outputs), ), copy=False) - def awk_check_health(self): + def ack_check_health(self): self.health_socket.send_multipart( (pickle.dumps(VLLM_RPC_SUCCESS_STR), ), copy=False) @@ -255,8 +245,7 @@ def _handle_utility_request(self, request: RPCUtilityRequest): self.engine.do_log_stats() elif request == RPCUtilityRequest.CHECK_HEALTH: self.engine.check_health() - # Special check_health channel for awk check health. - self.awk_check_health() + self.ack_check_health() def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext, From ac97a9ebef67308dfa963279082a798971243349 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 3 Sep 2024 03:00:21 +0000 Subject: [PATCH 26/29] use constant --- vllm/engine/multiprocessing/mp_llm_engine.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/engine/multiprocessing/mp_llm_engine.py b/vllm/engine/multiprocessing/mp_llm_engine.py index 106f7a2fb8051..aca11b9293c2e 100644 --- a/vllm/engine/multiprocessing/mp_llm_engine.py +++ b/vllm/engine/multiprocessing/mp_llm_engine.py @@ -21,6 +21,7 @@ logger = init_logger(__name__) +POLLING_TIMEOUT_MS = 10000 class MPLLMEngine: """A multiprocessing wrapper for :class:`LLMEngine`. @@ -201,7 +202,7 @@ def run_engine_loop(self) -> None: self.stream_outputs(request_outputs) def wait_for_new_input(self): - while self.input_socket.poll(timeout=10000) == 0: + while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: logger.debug("Waiting for new request.") def stream_outputs(self, request_outputs: List[RequestOutput]): From becd7abe426e0437fde18b854a08785d7b914db0 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 3 Sep 2024 03:00:30 +0000 Subject: [PATCH 27/29] format --- vllm/engine/multiprocessing/mp_llm_engine.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/engine/multiprocessing/mp_llm_engine.py b/vllm/engine/multiprocessing/mp_llm_engine.py index aca11b9293c2e..f839a272e40f7 100644 --- a/vllm/engine/multiprocessing/mp_llm_engine.py +++ b/vllm/engine/multiprocessing/mp_llm_engine.py @@ -23,6 +23,7 @@ POLLING_TIMEOUT_MS = 10000 + class MPLLMEngine: """A multiprocessing wrapper for :class:`LLMEngine`. From b7f49edd6da76a29339ee3d0954d2a583c7af642 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 3 Sep 2024 03:03:06 +0000 Subject: [PATCH 28/29] remove set to None --- vllm/entrypoints/openai/api_server.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index b7c0cee1af8b1..cca08e11912fd 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -134,7 +134,6 @@ async def build_async_engine_client_from_engine_args( yield async_engine_client finally: async_engine_client.shutdown_background_loop() - async_engine_client = None #TODO return # Otherwise, use the multiprocessing AsyncLLMEngine. 
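The engine side introduced in this series reduces to a simple control flow: block on the input socket (with the POLLING_TIMEOUT_MS poll shown above), drain whatever requests arrived, step the engine once, and push the resulting outputs as a single pickled message. Below is a self-contained sketch of that loop; FakeEngine, the socket suffixes, and the details of the drain are stand-ins rather than the real vLLM classes.

# Rough sketch of the poll -> drain -> step -> stream loop, not MPLLMEngine.
import pickle

import zmq

POLLING_TIMEOUT_MS = 10_000  # mirrors the constant introduced in the patch


class FakeEngine:
    """Stand-in for LLMEngine: every step drains whatever was queued."""

    def __init__(self) -> None:
        self.pending: list = []

    def add_request(self, payload) -> None:
        self.pending.append(payload)

    def step(self) -> list:
        outputs, self.pending = self.pending, []
        return outputs


def engine_loop(ipc_path: str) -> None:
    ctx = zmq.Context()
    input_socket = ctx.socket(zmq.PULL)
    input_socket.bind(f"{ipc_path}_input_socket")
    output_socket = ctx.socket(zmq.PUSH)
    output_socket.bind(f"{ipc_path}_output_socket")

    engine = FakeEngine()
    while True:
        # Block until at least one request shows up (or the poll times out).
        while not engine.pending and input_socket.poll(
                timeout=POLLING_TIMEOUT_MS) == 0:
            pass  # idle; the real loop logs "Waiting for new request."
        # Drain everything currently queued without blocking.
        while input_socket.poll(timeout=0) != 0:
            engine.add_request(pickle.loads(input_socket.recv()))
        # Step once and stream whatever the step produced.
        outputs = engine.step()
        if outputs:
            output_socket.send(pickle.dumps(outputs), copy=False)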
@@ -207,8 +206,6 @@ async def build_async_engine_client_from_engine_args( from prometheus_client import multiprocess multiprocess.mark_process_dead(engine_process.pid) - async_engine_client = None #TODO - router = APIRouter() From d0f964158de7d3db00fd90796207eadcf0fac862 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Tue, 3 Sep 2024 20:50:41 -0700 Subject: [PATCH 29/29] Remove redundant pass --- vllm/engine/multiprocessing/mp_client.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/engine/multiprocessing/mp_client.py b/vllm/engine/multiprocessing/mp_client.py index 0fd1afe953c14..c49a836755d4a 100644 --- a/vllm/engine/multiprocessing/mp_client.py +++ b/vllm/engine/multiprocessing/mp_client.py @@ -351,7 +351,6 @@ async def generate( if not self._errored: try: await self.check_health() - pass except Exception as e: self._errored = True logger.exception(repr(e))
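Taken together with the api_server changes earlier in the series, client startup amounts to: spawn the engine in a separate process via a "spawn" multiprocessing context, then keep retrying the client's setup() readiness probe, bailing out if the child process dies first. Below is a standalone sketch of that handshake; run_engine, ClientStub, and the timings are illustrative stand-ins, not vLLM APIs.

# Sketch of the spawn-then-probe startup pattern (stand-ins, not vLLM code).
import asyncio
import multiprocessing
import time


def run_engine(ipc_path: str) -> None:
    # Placeholder for run_mp_engine: pretend the engine takes a while to start.
    time.sleep(5.0)


class ClientStub:
    """Stand-in for MPEngineClient: setup() fails until the engine is ready."""

    def __init__(self, ready_at: float) -> None:
        self._ready_at = ready_at

    async def setup(self) -> None:
        if time.monotonic() < self._ready_at:
            raise TimeoutError("engine not ready yet")


async def start_engine(ipc_path: str = "/tmp/example_ipc") -> None:
    context = multiprocessing.get_context("spawn")
    engine_process = context.Process(target=run_engine, args=(ipc_path, ))
    engine_process.start()
    print(f"Started engine process with PID {engine_process.pid}")

    client = ClientStub(ready_at=time.monotonic() + 1.0)
    try:
        while True:
            try:
                await client.setup()  # readiness probe; raises TimeoutError
                break
            except TimeoutError:
                if not engine_process.is_alive():
                    raise RuntimeError("Engine process died before responding "
                                       "to readiness probe") from None
                await asyncio.sleep(0.1)  # retry until the engine answers
        print("Engine ready")
    finally:
        engine_process.terminate()
        engine_process.join()


if __name__ == "__main__":
    asyncio.run(start_engine())

In the real client, setup() itself blocks on the RPC readiness probe for up to VLLM_RPC_GET_DATA_TIMEOUT_MS, so no explicit sleep is needed there; the sleep in the stub only keeps this sketch from spinning.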