diff --git a/tests/v1/engine/test_async_llm.py b/tests/v1/engine/test_async_llm.py index fffb5b8100ec7..2c805e18eebae 100644 --- a/tests/v1/engine/test_async_llm.py +++ b/tests/v1/engine/test_async_llm.py @@ -1,5 +1,5 @@ import asyncio -from typing import Tuple +from typing import List, Tuple import pytest @@ -13,6 +13,7 @@ allow_module_level=True) ENGINE_ARGS = AsyncEngineArgs(model="meta-llama/Llama-3.2-1B", + enforce_eager=True, disable_log_requests=True) @@ -53,17 +54,63 @@ async def test_load(monkeypatch): generate(engine, request_id, NUM_EXPECTED_TOKENS))) # Confirm that we got all the EXPECTED tokens from the requests. - failed_request_id = None - tokens = None for task in tasks: num_generated_tokens, request_id = await task - if (num_generated_tokens != NUM_EXPECTED_TOKENS - and failed_request_id is None): - failed_request_id = request_id - tokens = num_generated_tokens - - assert failed_request_id is None, ( - f"{failed_request_id} generated {tokens} but " - f"expected {NUM_EXPECTED_TOKENS}") + assert num_generated_tokens == NUM_EXPECTED_TOKENS, ( + f"{request_id} generated {num_generated_tokens} but " + f"expected {NUM_EXPECTED_TOKENS}") + + assert not engine.output_processor.has_unfinished_requests() + engine.shutdown() + + +@pytest.mark.asyncio +async def test_abort(monkeypatch): + + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1") + + engine = AsyncLLM.from_engine_args(ENGINE_ARGS) + + NUM_REQUESTS = 100 + NUM_EXPECTED_TOKENS = 100 + REQUEST_IDS_TO_ABORT = range(1, 100, 10) + + request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)] + + # Create concurrent requests. + tasks: List[asyncio.Task] = [] + for request_id in request_ids: + tasks.append( + asyncio.create_task( + generate(engine, request_id, NUM_EXPECTED_TOKENS))) + + # API server cancels requests when they disconnect. + for idx in REQUEST_IDS_TO_ABORT: + tasks[idx].cancel() + await asyncio.sleep(0.1) + + # Confirm the other requests are okay. + for idx, task in enumerate(tasks): + # Confirm that it was actually canceled. + if idx in REQUEST_IDS_TO_ABORT: + with pytest.raises(asyncio.CancelledError): + await task + else: + # Otherwise, make sure the request was not impacted. + num_generated_tokens, request_id = await task + assert num_generated_tokens == NUM_EXPECTED_TOKENS, ( + f"{request_id} generated {num_generated_tokens} but " + f"expected {NUM_EXPECTED_TOKENS}") + + assert not engine.output_processor.has_unfinished_requests() + + # Confirm we can do another generation. 
+ request_id = f"request-{REQUEST_IDS_TO_ABORT[0]}" + task = asyncio.create_task( + generate(engine, request_id, NUM_EXPECTED_TOKENS)) + num_generated_tokens, request_id = await task + assert num_generated_tokens == NUM_EXPECTED_TOKENS + assert not engine.output_processor.has_unfinished_requests() engine.shutdown() diff --git a/tests/v1/engine/test_detokenizer.py b/tests/v1/engine/test_output_processor.py similarity index 65% rename from tests/v1/engine/test_detokenizer.py rename to tests/v1/engine/test_output_processor.py index aeae697ca32b0..4735c6f947537 100644 --- a/tests/v1/engine/test_detokenizer.py +++ b/tests/v1/engine/test_output_processor.py @@ -3,11 +3,18 @@ import pytest from transformers import AutoTokenizer +from vllm.engine.arg_utils import EngineArgs from vllm.sampling_params import RequestOutputKind, SamplingParams +from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest -from vllm.v1.engine.detokenizer import Detokenizer +from vllm.v1.engine.output_processor import OutputProcessor TOKENIZER_NAME = "mistralai/Mistral-7B-Instruct-v0.3" +VLLM_CONFIG = EngineArgs(model=TOKENIZER_NAME).create_engine_config() +TOKENIZER_GROUP = init_tokenizer_from_configs(VLLM_CONFIG.model_config, + VLLM_CONFIG.scheduler_config, + VLLM_CONFIG.parallel_config, + VLLM_CONFIG.lora_config) tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME) FULL_STRINGS = [ @@ -66,7 +73,7 @@ def get_outputs(self) -> List[EngineCoreOutput]: "request_output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]) def test_incremental_detokenization(request_output_kind: RequestOutputKind): - detokenizer = Detokenizer(TOKENIZER_NAME) + output_processor = OutputProcessor(TOKENIZER_GROUP, log_stats=False) engine_core = MockEngineCore(GENERATION_TOKENS) # Make N requests. @@ -93,7 +100,7 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind): # Add requests to the detokenizer. for request in requests: - detokenizer.add_request(request) + output_processor.add_request(request) gen_strings = {} gen_tokens = {} @@ -104,7 +111,9 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind): break # Step the Detokenizer. - request_outputs, requests_to_abort = detokenizer.step(outputs) + processed_outputs = output_processor.process_outputs(outputs, ) + request_outputs = processed_outputs.request_outputs + requests_to_abort = processed_outputs.reqs_to_abort assert len(requests_to_abort) == 0 # Update tracking. @@ -128,13 +137,13 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind): assert gen_str == ref_gen_str, f"{gen_str=}, {ref_gen_str=}" assert gen_toks == ref_gen_toks, f"{gen_toks=}, {ref_gen_toks=}" - assert detokenizer.get_num_unfinished_requests() == 0 - assert not detokenizer.has_unfinished_requests() + assert output_processor.get_num_unfinished_requests() == 0 + assert not output_processor.has_unfinished_requests() @pytest.mark.parametrize("include_stop_str_in_output", [True, False]) def test_stop_string(include_stop_str_in_output: bool): - detokenizer = Detokenizer(TOKENIZER_NAME) + output_processor = OutputProcessor(TOKENIZER_GROUP, log_stats=False) engine_core = MockEngineCore(GENERATION_TOKENS) # Make N requests. @@ -162,7 +171,7 @@ def test_stop_string(include_stop_str_in_output: bool): # Add requests to the detokenizer. 
for request in requests: - detokenizer.add_request(request) + output_processor.add_request(request) gen_strings = {} aborted = [] @@ -173,7 +182,9 @@ def test_stop_string(include_stop_str_in_output: bool): break # Step the Detokenizer. - request_outputs, requests_to_abort = detokenizer.step(outputs) + processed_outputs = output_processor.process_outputs(outputs) + request_outputs = processed_outputs.request_outputs + requests_to_abort = processed_outputs.reqs_to_abort for request_output in request_outputs: # If aborted, we should not get a request output. assert request_output.request_id not in aborted @@ -214,5 +225,71 @@ def test_stop_string(include_stop_str_in_output: bool): assert gen_str == ref_str_exc_stop, ( f"{gen_str=}, {ref_str_exc_stop=}") - assert detokenizer.get_num_unfinished_requests() == 0 - assert not detokenizer.has_unfinished_requests() + assert output_processor.get_num_unfinished_requests() == 0 + assert not output_processor.has_unfinished_requests() + + +def test_iteration_stats(): + output_processor = OutputProcessor(TOKENIZER_GROUP, log_stats=True) + engine_core = MockEngineCore(GENERATION_TOKENS) + + # Make N requests. + requests = [ + EngineCoreRequest( + request_id=f"request-{idx}", + prompt=prompt, + prompt_token_ids=prompt_tokens, + arrival_time=0, + mm_inputs=None, + mm_hashes=None, + mm_placeholders=None, + eos_token_id=None, + lora_request=None, + sampling_params=SamplingParams(), + ) for idx, ( + prompt, + prompt_tokens) in enumerate(zip(PROMPT_STRINGS, PROMPT_TOKENS)) + ] + + # Add all requests except one to the OutputProcessor. + num_active = len(GENERATION_TOKENS) - 1 + for request in requests[:num_active]: + output_processor.add_request(request) + inactive_request = requests[num_active] + + # First iteration has 2 prefills. + outputs = engine_core.get_outputs()[:num_active] + processed_outputs = output_processor.process_outputs(outputs) + iteration_stats = processed_outputs.iteration_stats + total_prompt_tokens = sum( + [len(prompt_tokens) for prompt_tokens in PROMPT_TOKENS[:num_active]]) + + assert iteration_stats.num_prompt_tokens == total_prompt_tokens + assert iteration_stats.num_generation_tokens == num_active + + # Just decodes in this step. + outputs = engine_core.get_outputs()[:num_active] + processed_outputs = output_processor.process_outputs(outputs) + iteration_stats = processed_outputs.iteration_stats + + assert iteration_stats.num_prompt_tokens == 0 + assert iteration_stats.num_generation_tokens == num_active + + # Add a new request - prefill and 2 decodes in this step. + output_processor.add_request(inactive_request) + num_active += 1 + outputs = engine_core.get_outputs()[:num_active] + processed_outputs = output_processor.process_outputs(outputs) + iteration_stats = processed_outputs.iteration_stats + total_prompt_tokens = len(PROMPT_TOKENS[num_active - 1]) + + assert iteration_stats.num_prompt_tokens == total_prompt_tokens + assert iteration_stats.num_generation_tokens == num_active + + # Just decodes in this step. 
+ outputs = engine_core.get_outputs()[:num_active] + processed_outputs = output_processor.process_outputs(outputs) + iteration_stats = processed_outputs.iteration_stats + + assert iteration_stats.num_prompt_tokens == 0 + assert iteration_stats.num_generation_tokens == num_active diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index e0ceb59dffcbd..a74699f7513e6 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -1,6 +1,6 @@ import asyncio import os -from typing import AsyncGenerator, Dict, List, Mapping, Optional, Type, Union +from typing import AsyncGenerator, List, Mapping, Optional, Type, Union from vllm.config import ModelConfig, VllmConfig from vllm.engine.arg_utils import AsyncEngineArgs @@ -18,11 +18,11 @@ from vllm.usage.usage_lib import UsageContext from vllm.utils import kill_process_tree from vllm.v1.engine.core_client import EngineCoreClient -from vllm.v1.engine.detokenizer import Detokenizer +from vllm.v1.engine.output_processor import OutputProcessor from vllm.v1.engine.processor import Processor from vllm.v1.executor.abstract import Executor from vllm.v1.metrics.loggers import LoggingStatLogger, StatLoggerBase -from vllm.v1.metrics.stats import SchedulerStats +from vllm.v1.metrics.stats import IterationStats, SchedulerStats logger = init_logger(__name__) @@ -59,9 +59,6 @@ def __init__( lora_config=vllm_config.lora_config) self.tokenizer.ping() - # Request streams (map of request_id -> queue). - self.rid_to_queue: Dict[str, asyncio.Queue] = {} - # Processor (converts Inputs --> EngineCoreRequests). self.processor = Processor( model_config=vllm_config.model_config, @@ -71,13 +68,9 @@ def __init__( input_registry=input_registry, ) - # Detokenizer (converts EngineCoreOutputs --> RequestOutput). - self.detokenizer = Detokenizer( - tokenizer_name=vllm_config.model_config.tokenizer, - tokenizer_mode=vllm_config.model_config.tokenizer_mode, - trust_remote_code=vllm_config.model_config.trust_remote_code, - revision=vllm_config.model_config.tokenizer_revision, - ) + # OutputProcessor (converts EngineCoreOutputs --> RequestOutput). + self.output_processor = OutputProcessor(self.tokenizer, + log_stats=self.log_stats) # EngineCore (starts the engine in background process). self.engine_core = EngineCoreClient.make_client( @@ -140,9 +133,9 @@ async def add_request( """Add new request to the AsyncLLM.""" # 1) Create a new output queue for the request. - if request_id in self.rid_to_queue: + if self.output_processor.is_request_active(request_id): raise ValueError(f"Request id {request_id} already running.") - self.rid_to_queue[request_id] = asyncio.Queue() + queue: asyncio.Queue[RequestOutput] = asyncio.Queue() # 2) Convert Input --> Request. request = self.processor.process_inputs(request_id, prompt, params, @@ -151,8 +144,8 @@ async def add_request( prompt_adapter_request, priority) - # 3) Add the request to Detokenizer (this process). - self.detokenizer.add_request(request) + # 3) Add the request to OutputProcessor (this process). + self.output_processor.add_request(request, queue) # 4) Add the EngineCoreRequest to EngineCore (separate process). await self.engine_core.add_request_async(request) @@ -160,7 +153,7 @@ async def add_request( if self.log_requests: logger.info("Added request %s.", request_id) - return self.rid_to_queue[request_id] + return queue # TODO: we should support multiple prompts in one call, as you # can do with LLM.generate. 
So that for multi-prompt completion @@ -217,10 +210,9 @@ async def generate( # task switching under load which helps performance). out = q.get_nowait() if q.qsize() > 0 else await q.get() - # Note: both Detokenizer and EngineCore handle their + # Note: both OutputProcessor and EngineCore handle their # own request cleanup based on finished. if out.finished: - del self.rid_to_queue[request_id] yield out break @@ -233,57 +225,51 @@ async def generate( await self.abort(request_id) raise - def _process_request_outputs(self, request_outputs: List[RequestOutput]): - """Process outputs by putting them into per-request queues.""" - - for request_output in request_outputs: - request_id = request_output.request_id - - # Note: it is possible a request was aborted and removed from - # the state due to client cancellations, so if we encounter a - # request id not in the state, we skip. - if request_id in self.rid_to_queue: - self.rid_to_queue[request_id].put_nowait(request_output) - async def _run_output_handler(self): """Background loop: pulls from EngineCore and pushes to AsyncStreams.""" try: while True: - # 1) Pull EngineCoreOutput from the EngineCore. + # 1) Pull EngineCoreOutputs from the EngineCore. outputs = await self.engine_core.get_output_async() - # 2) Detokenize based on the output. - request_outputs, reqs_to_abort = self.detokenizer.step( + # 2) Process EngineCoreOutputs. + processed_outputs = self.output_processor.process_outputs( outputs.outputs) + # NOTE: RequestOutputs are pushed to their queues. + assert len(processed_outputs.request_outputs) == 0 - # 3) Put the RequestOutputs into the per-request queues. - self._process_request_outputs(request_outputs) + # 3) Abort any reqs that finished due to stop strings. + await self.engine_core.abort_requests_async( + processed_outputs.reqs_to_abort) - # 4) Abort any requests that finished due to stop strings. - await self.engine_core.abort_requests_async(reqs_to_abort) - - # 5) Log any stats. - await self._log_stats(scheduler_stats=outputs.scheduler_stats) + # 4) Logging. + # TODO(rob): make into a coroutine and launch it in + # background thread once we add Prometheus. + self._log_stats( + scheduler_stats=outputs.scheduler_stats, + iteration_stats=processed_outputs.iteration_stats, + ) except Exception as e: logger.exception("EngineCore output handler hit an error: %s", e) kill_process_tree(os.getpid()) async def abort(self, request_id: str) -> None: - """Abort RequestId in self, detokenizer, and engine core.""" + """Abort RequestId in OutputProcessor and EngineCore.""" request_ids = [request_id] await self.engine_core.abort_requests_async(request_ids) - self.detokenizer.abort_requests(request_ids) + self.output_processor.abort_requests(request_ids) - # If a request finishes while we await then the request_id - # will be removed from the tracked queues before we get here. 
- if request_id in self.rid_to_queue: - del self.rid_to_queue[request_id] + if self.log_requests: + logger.info("Aborted request %s.", request_id) - async def _log_stats(self, scheduler_stats: SchedulerStats): - """Log stats to the stat loggers.""" + def _log_stats( + self, + scheduler_stats: SchedulerStats, + iteration_stats: IterationStats, + ): if not self.log_stats: return @@ -314,8 +300,7 @@ async def get_tokenizer( self, lora_request: Optional[LoRARequest] = None, ) -> AnyTokenizer: - assert lora_request is None - return self.detokenizer.tokenizer + return self.tokenizer.get_lora_tokenizer(lora_request) async def is_tracing_enabled(self) -> bool: return False diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 9d6ae725e9d2b..ac0f0f14bf1ab 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -105,7 +105,8 @@ def add_request(self, request: EngineCoreRequest) -> None: self.engine_core.add_request(request) def abort_requests(self, request_ids: List[str]) -> None: - self.engine_core.abort_requests(request_ids) + if len(request_ids) > 0: + self.engine_core.abort_requests(request_ids) def shutdown(self): self.engine_core.shutdown() @@ -221,7 +222,8 @@ def add_request(self, request: EngineCoreRequest) -> None: self._send_input(EngineCoreRequestType.ADD, request) def abort_requests(self, request_ids: List[str]) -> None: - self._send_input(EngineCoreRequestType.ABORT, request_ids) + if len(request_ids) > 0: + self._send_input(EngineCoreRequestType.ABORT, request_ids) def profile(self, is_start: bool = True) -> None: self._send_input(EngineCoreRequestType.PROFILE, diff --git a/vllm/v1/engine/detokenizer.py b/vllm/v1/engine/detokenizer.py index 65be9e58e03c8..4a8b61beec037 100644 --- a/vllm/v1/engine/detokenizer.py +++ b/vllm/v1/engine/detokenizer.py @@ -1,18 +1,25 @@ from dataclasses import dataclass -from typing import Dict, Iterable, List, Optional, Tuple, Union +from typing import List, Optional, Union from vllm.engine.output_processor.stop_checker import StopChecker from vllm.logger import init_logger -from vllm.outputs import RequestOutput from vllm.sampling_params import RequestOutputKind from vllm.transformers_utils.detokenizer_utils import ( AnyTokenizer, convert_prompt_ids_to_tokens, detokenize_incrementally) -from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest logger = init_logger(__name__) +@dataclass +class DetokenizerOutput: + output_text: str + token_ids: List[int] + finished: bool + finish_reason: Optional[str] = None + stop_reason: Union[int, str, None] = None + + @dataclass class IncrementalDetokenizer: @@ -20,6 +27,7 @@ class IncrementalDetokenizer: output_text: str tokens: List[str] token_ids: List[int] + prompt_len: int # Stop strings stop: List[str] @@ -34,11 +42,6 @@ class IncrementalDetokenizer: spaces_between_special_tokens: bool output_kind: RequestOutputKind - # TODO: Probably decouple these - request_id: str - prompt: Optional[str] - prompt_token_ids: List[int] - # Tokenizer for this request tokenizer: AnyTokenizer @@ -48,8 +51,7 @@ class IncrementalDetokenizer: @property def output_token_ids(self) -> List[int]: - assert len(self.token_ids) >= len(self.prompt_token_ids) - return self.token_ids[len(self.prompt_token_ids):] + return self.token_ids[self.prompt_len:] @classmethod def from_new_request( @@ -87,25 +89,25 @@ def from_new_request( spaces_between_special_tokens=request.sampling_params. 
spaces_between_special_tokens, output_kind=request.sampling_params.output_kind, - request_id=request.request_id, - prompt=request.prompt, - prompt_token_ids=request.prompt_token_ids, + prompt_len=len(request.prompt_token_ids), tokenizer=tokenizer, stop_buffer_length=stop_buffer_length, ) - def add_tokens( + def update_from_output( self, - new_token_ids: List[int], - finish_reason: Optional[str], - stop_reason: Optional[Union[int, str, None]], - ) -> Optional[RequestOutput]: + output: EngineCoreOutput, + ) -> Optional[DetokenizerOutput]: """ Update RequestState for the request_id by: 1) Detokenize the new token ids incrementally. 2) Update the RequestOutput with the new text. """ + new_token_ids = output.new_token_ids + finish_reason = output.finish_reason + stop_reason = output.stop_reason + # 1) Detokenize the new token ids incrementally. # TODO(woosuk): This method becomes very inefficient when the number of # new_token_ids is more than 1. We need to optimize this. @@ -158,21 +160,8 @@ def add_tokens( output_text = self._get_next_output_text(finished, delta) token_ids = new_token_ids if delta else self.output_token_ids - request_output = RequestOutput.new( - self.request_id, - self.prompt, - self.prompt_token_ids, - output_text, - token_ids, - finished, - ) - - if finished: - completion_output = request_output.outputs[0] - completion_output.finish_reason = finish_reason - completion_output.stop_reason = stop_reason - - return request_output + return DetokenizerOutput(output_text, token_ids, finished, + finish_reason, stop_reason) def _get_next_output_text(self, finished: bool, delta: bool) -> str: """If delta is True, only new text since the last call to @@ -189,85 +178,3 @@ def _get_next_output_text(self, finished: bool, delta: bool) -> str: self._last_output_text_offset = length return self.output_text[last_offset:length] return "" - - -class Detokenizer: - - def __init__(self, - tokenizer_name: str, - tokenizer_mode: str = "auto", - trust_remote_code: bool = False, - revision: Optional[str] = None): - # TODO: once we support LoRA, we should should pass the tokenizer - # here. We currently have two copies (this + in the LLMEngine). 
- self.tokenizer = get_tokenizer(tokenizer_name=tokenizer_name, - tokenizer_mode=tokenizer_mode, - trust_remote_code=trust_remote_code, - revision=revision) - - # Request id -> IncrementalDetokenizer - self.request_states: Dict[str, IncrementalDetokenizer] = {} - - def is_request_active(self, request_id: str): - return request_id in self.request_states - - def get_num_unfinished_requests(self): - return len(self.request_states) - - def has_unfinished_requests(self) -> bool: - return len(self.request_states) > 0 - - def abort_requests( - self, - request_ids: Iterable[str], - ) -> None: - """Remove the request_ids from the Detokenizer.""" - - for request_id in request_ids: - self.request_states.pop(request_id, None) - - def add_request( - self, - request: EngineCoreRequest, - ): - """Add new request to the Detokenizer.""" - - assert (request.request_id not in self.request_states) - - request_state = IncrementalDetokenizer.from_new_request( - self.tokenizer, request) - self.request_states[request.request_id] = request_state - - def step( - self, encore_core_outputs: List[EngineCoreOutput] - ) -> Tuple[List[RequestOutput], List[str]]: - """Update state and request the RequestOutputs to the LLMEngine.""" - - request_outputs: List[RequestOutput] = [] - requests_to_abort: List[str] = [] - for engine_core_output in encore_core_outputs: - request_id = engine_core_output.request_id - detokenizer = self.request_states.get(request_id) - if detokenizer is None: - # Ignore output for already-aborted request. - continue - - # Detokenize and update state. - request_output = detokenizer.add_tokens( - new_token_ids=engine_core_output.new_token_ids, - finish_reason=engine_core_output.finish_reason, - stop_reason=engine_core_output.stop_reason, - ) - - if request_output is not None: - # Add to RequestOutputs list. - request_outputs.append(request_output) - - # Free completed requests. - if request_output.finished: - self.request_states.pop(request_id) - if not engine_core_output.finished: - requests_to_abort.append(request_id) - - # Return to EngineClient. - return request_outputs, requests_to_abort diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index ac392f5e4f4cf..f5999ccda6447 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -18,7 +18,7 @@ BaseTokenizerGroup, init_tokenizer_from_configs) from vllm.usage.usage_lib import UsageContext from vllm.v1.engine.core_client import EngineCoreClient -from vllm.v1.engine.detokenizer import Detokenizer +from vllm.v1.engine.output_processor import OutputProcessor from vllm.v1.engine.processor import Processor from vllm.v1.executor.abstract import Executor @@ -60,13 +60,9 @@ def __init__( input_registry=input_registry, mm_registry=mm_registry) - # Detokenizer (converts EngineCoreOutputs --> RequestOutput) - self.detokenizer = Detokenizer( - tokenizer_name=vllm_config.model_config.tokenizer, - tokenizer_mode=vllm_config.model_config.tokenizer_mode, - trust_remote_code=vllm_config.model_config.trust_remote_code, - revision=vllm_config.model_config.tokenizer_revision, - ) + # OutputProcessor (convert EngineCoreOutputs --> RequestOutput). 
+ self.output_processor = OutputProcessor(self.tokenizer, + log_stats=False) # EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs) self.engine_core = EngineCoreClient.make_client( @@ -103,10 +99,10 @@ def from_engine_args( multiprocess_mode=enable_multiprocessing) def get_num_unfinished_requests(self) -> int: - return self.detokenizer.get_num_unfinished_requests() + return self.output_processor.get_num_unfinished_requests() def has_unfinished_requests(self) -> bool: - return self.detokenizer.has_unfinished_requests() + return self.output_processor.has_unfinished_requests() @classmethod def validate_outputs(cls, outputs, output_type): @@ -116,7 +112,7 @@ def abort_request(self, request_ids: List[str]) -> None: """Remove request_ids from EngineCore and Detokenizer.""" self.engine_core.abort_requests(request_ids) - self.detokenizer.abort_requests(request_ids) + self.output_processor.abort_requests(request_ids) def add_request( self, @@ -137,8 +133,8 @@ def add_request( prompt_adapter_request, priority) - # 2) Add the request to Detokenizer. - self.detokenizer.add_request(request) + # 2) Make a new RequestState and queue. + self.output_processor.add_request(request) # 3) Add the request to EngineCore. self.engine_core.add_request(request) @@ -148,15 +144,14 @@ def step(self) -> List[RequestOutput]: # 1) Get EngineCoreOutput from the EngineCore. outputs = self.engine_core.get_output() - # 2) Detokenizer the EngineCoreOutput. - request_outputs, requests_to_abort = self.detokenizer.step( + # 2) Process EngineCoreOutputs. + processed_outputs = self.output_processor.process_outputs( outputs.outputs) - # 3) Abort requests that finished due to stopping criteria. - if requests_to_abort: - self.abort_request(requests_to_abort) + # 3) Abort any reqs that finished due to stop strings. 
+ self.engine_core.abort_requests(processed_outputs.reqs_to_abort) - return request_outputs + return processed_outputs.request_outputs def get_model_config(self): return self.model_config diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py new file mode 100644 index 0000000000000..749f4f5043c97 --- /dev/null +++ b/vllm/v1/engine/output_processor.py @@ -0,0 +1,200 @@ +import asyncio +from dataclasses import dataclass +from typing import Dict, List, Optional + +from vllm.outputs import RequestOutput +from vllm.transformers_utils.detokenizer_utils import AnyTokenizer +from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup +from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest +from vllm.v1.engine.detokenizer import (DetokenizerOutput, + IncrementalDetokenizer) +from vllm.v1.metrics.stats import IterationStats + + +@dataclass +class OutputProcessorOutput: + + request_outputs: List[RequestOutput] + reqs_to_abort: List[str] + iteration_stats: IterationStats + + +class RequestState: + + def __init__( + self, + request_id: str, + prompt: Optional[str], + prompt_token_ids: List[int], + detokenizer: IncrementalDetokenizer, + queue: Optional[asyncio.Queue[RequestOutput]], + ): + self.request_id = request_id + self.prompt = prompt + self.prompt_token_ids = prompt_token_ids + self.prompt_len = len(prompt_token_ids) + self.detokenizer = detokenizer + self.is_prefilling = True + self.queue = queue + + @classmethod + def from_new_request( + cls, + tokenizer: AnyTokenizer, + request: EngineCoreRequest, + queue: Optional[asyncio.Queue[RequestOutput]] = None, + ) -> "RequestState": + return cls( + request_id=request.request_id, + prompt=request.prompt, + prompt_token_ids=request.prompt_token_ids, + detokenizer=IncrementalDetokenizer.from_new_request( + tokenizer=tokenizer, + request=request, + ), + queue=queue, + ) + + +class OutputProcessor: + """Process EngineCoreOutputs into RequestOutputs.""" + + def __init__( + self, + tokenizer: BaseTokenizerGroup, + log_stats: bool, + ): + self.log_stats = log_stats + self.tokenizer = tokenizer + self.request_states: Dict[str, RequestState] = {} + + def is_request_active(self, request_id: str) -> bool: + return request_id in self.request_states + + def get_num_unfinished_requests(self): + return len(self.request_states) + + def has_unfinished_requests(self) -> bool: + return len(self.request_states) > 0 + + def abort_requests( + self, + request_ids: List[str], + ) -> None: + for request_id in request_ids: + self.request_states.pop(request_id, None) + + def add_request( + self, + request: EngineCoreRequest, + queue: Optional[asyncio.Queue[RequestOutput]] = None, + ) -> None: + request_id = request.request_id + if request_id in self.request_states: + raise ValueError(f"Request id {request_id} already running.") + + self.request_states[request_id] = RequestState.from_new_request( + tokenizer=self.tokenizer.get_lora_tokenizer(request.lora_request), + request=request, + queue=queue) + + def process_outputs( + self, + engine_core_outputs: List[EngineCoreOutput], + ) -> OutputProcessorOutput: + """ + Process the EngineCoreOutputs: + 1) Compute stats for logging + 2) Detokenize + 3) Create and handle RequestOutput objects: + * If there is a queue (for usage with AsyncLLM), + put the RequestOutput objects into the queue for + handling by the per-request generate() tasks. + + * If there is no queue (for usage with LLMEngine), + return a list of RequestOutput objects. 
+ + ****************** NOTE FOR DEVELOPERS ****************** + + VLLM V1 minimizes the number of python loops over the full + batch to ensure system overheads are minimized. This is the + only function that should loop over EngineCoreOutputs. + + If you need to touch every element of the batch, implement a + method called XXXClass.update_from_output() to be called + within the loop below. For examples, see: + * IterationStats.update_from_output() + * Detokenizer.update_from_output() + + TODO(rob): add Protocol makes update_from_output explicit. + + ********************************************************** + """ + + request_outputs: List[RequestOutput] = [] + reqs_to_abort: List[str] = [] + iteration_stats = IterationStats(self.log_stats) + for engine_core_output in engine_core_outputs: + req_id = engine_core_output.request_id + req_state = self.request_states.get(req_id) + if req_state is None: + # Ignore output for already-aborted request. + continue + + # 1) Compute stats for this iteration. + iteration_stats.update_from_output(engine_core_output, + req_state.is_prefilling, + req_state.prompt_len) + req_state.is_prefilling = False + + # 2) Detokenize the token ids into text. + detokenizer_output = req_state.detokenizer.update_from_output( + engine_core_output) + + # 3) Create and handle RequestOutput objects. + if request_output := self._make_request_output( + req_state, detokenizer_output): + if req_state.queue is not None: + # AsyncLLM: put into queue for handling by generate(). + req_state.queue.put_nowait(request_output) + else: + # LLMEngine: return list of RequestOutputs. + request_outputs.append(request_output) + + # Free completed requests. + if request_output.finished: + self.request_states.pop(req_id) + if not engine_core_output.finished: + # If req not finished in EngineCore, but Detokenizer + # detected stop string, abort needed in EngineCore. 
+ reqs_to_abort.append(req_id) + + return OutputProcessorOutput( + request_outputs=request_outputs, + reqs_to_abort=reqs_to_abort, + iteration_stats=iteration_stats, + ) + + def _make_request_output( + self, + request_state: RequestState, + detokenizer_output: Optional[DetokenizerOutput], + ) -> Optional[RequestOutput]: + + if detokenizer_output is None: + return None + + request_output = RequestOutput.new( + request_state.request_id, + request_state.prompt, + request_state.prompt_token_ids, + detokenizer_output.output_text, + detokenizer_output.token_ids, + detokenizer_output.finished, + ) + if detokenizer_output.finished: + completion_output = request_output.outputs[0] + completion_output.finish_reason = detokenizer_output.finish_reason + completion_output.stop_reason = detokenizer_output.stop_reason + + return request_output diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py index 5ebb4fd5b37db..60cb986f8bbce 100644 --- a/vllm/v1/metrics/stats.py +++ b/vllm/v1/metrics/stats.py @@ -1,4 +1,8 @@ from dataclasses import dataclass +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from vllm.v1.engine import EngineCoreOutput @dataclass @@ -10,3 +14,26 @@ class SchedulerStats: # gpu_cache_usage: float = 0.0 # gpu_prefix_cache_hit_rate: float = 0.0 + + +class IterationStats: + """Stats associated with a single set of EngineCoreOutputs.""" + + def __init__(self, log_stats: bool): + self.log_stats = log_stats + self.num_generation_tokens = 0 + self.num_prompt_tokens = 0 + + def update_from_output(self, output: "EngineCoreOutput", + is_prefilling: bool, prompt_len: int): + if not self.log_stats: + return + + self.num_generation_tokens += len(output.new_token_ids) + if is_prefilling: + # This relies on the invariant that EngineCore does + # not stream outputs for partially completed prefills + # (scheduler.update_from_output makes EngineCoreOutput + # iff num_computed_tokens == num_tokens). + assert (len(output.new_token_ids) > 0) + self.num_prompt_tokens += prompt_len