[Frontend] OpenAI server: propagate usage accounting to FastAPI middleware layer #8672

Merged · 2 commits · Sep 25, 2024
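
This PR threads per-request token usage out of the OpenAI-compatible chat and completions handlers so that FastAPI middleware can observe it: each handler creates a `RequestResponseMetadata` object up front, stores it on `raw_request.state.request_metadata`, and fills in `final_usage_info` once generation finishes. The sketch below shows one way a middleware could consume that state; the `track_usage` and `log_usage` names and the app wiring are illustrative, not part of this PR. Because streaming responses only populate the usage after the body has been fully generated, the sketch defers the read to a Starlette `BackgroundTask` that runs once the response has been sent.

```python
from fastapi import FastAPI, Request
from starlette.background import BackgroundTask

app = FastAPI()


def log_usage(request: Request) -> None:
    # The OpenAI-compatible handlers attach request_metadata in
    # create_chat_completion / create_completion and fill in
    # final_usage_info once generation is done.
    metadata = getattr(request.state, "request_metadata", None)
    if metadata is None or metadata.final_usage_info is None:
        return
    usage = metadata.final_usage_info
    print(f"{metadata.request_id}: prompt={usage.prompt_tokens} "
          f"completion={usage.completion_tokens} total={usage.total_tokens}")


@app.middleware("http")
async def track_usage(request: Request, call_next):
    response = await call_next(request)
    # For streaming requests final_usage_info is only set after the response
    # body has been fully generated, so defer the read until the response has
    # been sent (this sketch assumes no other background task is attached).
    response.background = BackgroundTask(log_usage, request)
    return response
```

If a request never reaches the handler (for example, it fails validation), `request.state.request_metadata` is simply absent, which is why the reader tolerates a missing attribute.
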
5 changes: 5 additions & 0 deletions vllm/entrypoints/openai/protocol.py
@@ -107,6 +107,11 @@ class UsageInfo(OpenAIBaseModel):
completion_tokens: Optional[int] = 0


class RequestResponseMetadata(BaseModel):
request_id: str
final_usage_info: Optional[UsageInfo] = None


class JsonSchemaResponseFormat(OpenAIBaseModel):
name: str
description: Optional[str] = None
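
`RequestResponseMetadata` is a plain Pydantic model: the request ID plus an optional `UsageInfo` that stays `None` until a handler fills it in at the end of the request. A minimal sketch of that lifecycle (values invented; assumes the Pydantic v2 API vLLM's models use):

```python
from vllm.entrypoints.openai.protocol import (RequestResponseMetadata,
                                              UsageInfo)

# Created with only the request ID; usage is attached after generation.
metadata = RequestResponseMetadata(request_id="chat-demo")
assert metadata.final_usage_info is None

metadata.final_usage_info = UsageInfo(prompt_tokens=12,
                                      completion_tokens=34,
                                      total_tokens=46)
print(metadata.model_dump_json())
```
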
26 changes: 23 additions & 3 deletions vllm/entrypoints/openai/serving_chat.py
@@ -22,7 +22,8 @@
ChatCompletionRequest, ChatCompletionResponse,
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
DeltaToolCall, ErrorResponse, FunctionCall, ToolCall, UsageInfo)
DeltaToolCall, ErrorResponse, FunctionCall, RequestResponseMetadata,
ToolCall, UsageInfo)
from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
LoRAModulePath,
OpenAIServing,
@@ -175,6 +176,11 @@ async def create_chat_completion(
"--enable-auto-tool-choice and --tool-call-parser to be set")

request_id = f"chat-{random_uuid()}"

request_metadata = RequestResponseMetadata(request_id=request_id)
if raw_request:
raw_request.state.request_metadata = request_metadata

try:
guided_decode_logits_processor = (
await self._guided_decode_logits_processor(request, tokenizer))
@@ -241,11 +247,13 @@
# Streaming response
if request.stream:
return self.chat_completion_stream_generator(
request, result_generator, request_id, conversation, tokenizer)
request, result_generator, request_id, conversation, tokenizer,
request_metadata)

try:
return await self.chat_completion_full_generator(
request, result_generator, request_id, conversation, tokenizer)
request, result_generator, request_id, conversation, tokenizer,
request_metadata)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
@@ -262,6 +270,7 @@ async def chat_completion_stream_generator(
request_id: str,
conversation: List[ConversationMessage],
tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata,
) -> AsyncGenerator[str, None]:
model_name = self.base_model_paths[0].name
created_time = int(time.time())
@@ -580,6 +589,13 @@ async def chat_completion_stream_generator(
exclude_unset=True, exclude_none=True))
yield f"data: {final_usage_data}\n\n"

# report to FastAPI middleware aggregate usage across all choices
num_completion_tokens = sum(previous_num_tokens)
request_metadata.final_usage_info = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_completion_tokens,
total_tokens=num_prompt_tokens + num_completion_tokens)

except ValueError as e:
# TODO: Use a vllm-specific Validation Error
logger.error("error in chat completion stream generator: %s", e)
@@ -595,6 +611,7 @@ async def chat_completion_full_generator(
request_id: str,
conversation: List[ConversationMessage],
tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata,
) -> Union[ErrorResponse, ChatCompletionResponse]:

model_name = self.base_model_paths[0].name
@@ -714,6 +731,9 @@ async def chat_completion_full_generator(
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
)

request_metadata.final_usage_info = usage

response = ChatCompletionResponse(
id=request_id,
created=created_time,
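
For streaming chat requests the handler already tracks generated tokens per choice in `previous_num_tokens`; the new block sums those counts and pairs them with the shared prompt-token count before writing `final_usage_info`, so middleware sees one aggregate figure rather than per-choice chunks. A toy illustration of that arithmetic (counts invented):

```python
# One chat request with n=2 choices; token counts are invented.
num_prompt_tokens = 9            # tokens in the shared prompt
previous_num_tokens = [17, 21]   # completion tokens generated per choice

num_completion_tokens = sum(previous_num_tokens)           # 38
total_tokens = num_prompt_tokens + num_completion_tokens   # 47
print(num_prompt_tokens, num_completion_tokens, total_tokens)
```
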
37 changes: 29 additions & 8 deletions vllm/entrypoints/openai/serving_completion.py
@@ -18,7 +18,9 @@
CompletionResponseChoice,
CompletionResponseStreamChoice,
CompletionStreamResponse,
ErrorResponse, UsageInfo)
ErrorResponse,
RequestResponseMetadata,
UsageInfo)
# yapf: enable
from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
LoRAModulePath,
@@ -94,6 +96,10 @@ async def create_completion(
request_id = f"cmpl-{random_uuid()}"
created_time = int(time.time())

request_metadata = RequestResponseMetadata(request_id=request_id)
if raw_request:
raw_request.state.request_metadata = request_metadata

# Schedule the request and get the result generator.
generators: List[AsyncGenerator[RequestOutput, None]] = []
try:
@@ -165,13 +171,15 @@

# Streaming response
if stream:
return self.completion_stream_generator(request,
result_generator,
request_id,
created_time,
model_name,
num_prompts=len(prompts),
tokenizer=tokenizer)
return self.completion_stream_generator(
request,
result_generator,
request_id,
created_time,
model_name,
num_prompts=len(prompts),
tokenizer=tokenizer,
request_metadata=request_metadata)

# Non-streaming response
final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
@@ -198,6 +206,7 @@
created_time,
model_name,
tokenizer,
request_metadata,
)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
@@ -227,6 +236,7 @@ async def completion_stream_generator(
model_name: str,
num_prompts: int,
tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata,
) -> AsyncGenerator[str, None]:
num_choices = 1 if request.n is None else request.n
previous_text_lens = [0] * num_choices * num_prompts
@@ -346,6 +356,14 @@ async def completion_stream_generator(
exclude_unset=False, exclude_none=True))
yield f"data: {final_usage_data}\n\n"

# report to FastAPI middleware aggregate usage across all choices
total_prompt_tokens = sum(num_prompt_tokens)
total_completion_tokens = sum(previous_num_tokens)
request_metadata.final_usage_info = UsageInfo(
prompt_tokens=total_prompt_tokens,
completion_tokens=total_completion_tokens,
total_tokens=total_prompt_tokens + total_completion_tokens)

except ValueError as e:
# TODO: Use a vllm-specific Validation Error
data = self.create_streaming_error_response(str(e))
@@ -360,6 +378,7 @@ def request_output_to_completion_response(
created_time: int,
model_name: str,
tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata,
) -> CompletionResponse:
choices: List[CompletionResponseChoice] = []
num_prompt_tokens = 0
@@ -433,6 +452,8 @@
total_tokens=num_prompt_tokens + num_generated_tokens,
)

request_metadata.final_usage_info = usage

return CompletionResponse(
id=request_id,
created=created_time,
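
The completions path mirrors the chat path, with one difference: `/v1/completions` accepts a batch of prompts, so the prompt-token counts are themselves a per-prompt list and both sides of the usage are summed before being reported. A toy illustration (counts invented):

```python
# Two prompts, each with n=2 choices -> four streamed choices in total.
num_prompt_tokens = [5, 8]             # prompt tokens per input prompt
previous_num_tokens = [11, 13, 7, 9]   # completion tokens per choice

total_prompt_tokens = sum(num_prompt_tokens)                    # 13
total_completion_tokens = sum(previous_num_tokens)              # 40
total_tokens = total_prompt_tokens + total_completion_tokens    # 53
print(total_prompt_tokens, total_completion_tokens, total_tokens)
```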