diff --git a/api/api.py b/api/api.py
index e52870d60..bef0eb914 100644
--- a/api/api.py
+++ b/api/api.py
@@ -8,7 +8,7 @@
 import uuid
 from abc import ABC
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union
 
 from build.utils import device_sync
 
@@ -87,6 +87,11 @@ class StreamOptions:
     include_usage: bool = False
 
 
+@dataclass
+class ResponseFormat:
+    type: Optional[str] = None
+
+
 @dataclass
 class CompletionRequest:
     """A full chat completion request.
@@ -94,24 +99,27 @@ class CompletionRequest:
     See the "Create Chat Completion >>> Request body" section of the OpenAI API docs for more details.
     """
 
+    messages: List[_AbstractMessage]
     model: str
-    prompt: str
-    messages: Optional[List[_AbstractMessage]]
-    frequency_penalty: float = 0.0
-    temperature: float = 0.0
-    stop: Optional[List[str]] = None
-    stream: bool = False
-    stream_options: Optional[StreamOptions] = None
-    echo: bool = False
-    frequency_penalty: float = 0.0
-    guided_decode_json_schema: str = None
-    guided_decode_json_schema_path: str = None
+    frequency_penalty: float = 0.0  # unimplemented
+    logit_bias: Optional[Dict[str, float]] = None  # unimplemented
+    logprobs: Optional[bool] = None  # unimplemented
+    top_logprobs: Optional[int] = None  # unimplemented
+    max_tokens: Optional[int] = None  # unimplemented
     n: int = 1
-    presence_penalty: float = 0
-    logit_bias: Optional[Dict[str, float]] = None
-    logprobs: Optional[bool] = None
-    top_logprobs: Optional[int] = None
-    max_tokens: Optional[int] = None
+    presence_penalty: float = 0  # unimplemented
+    response_format: Optional[ResponseFormat] = None  # unimplemented
+    seed: Optional[int] = None  # unimplemented
+    service_tier: Optional[str] = None  # unimplemented
+    stop: Optional[List[str]] = None  # unimplemented
+    stream: bool = False
+    stream_options: Optional[StreamOptions] = None  # unimplemented
+    temperature: Optional[float] = 1.0  # unimplemented
+    top_p: Optional[float] = 1.0  # unimplemented
+    tools: Optional[List[Any]] = None  # unimplemented
+    tool_choice: Optional[Union[str, Any]] = None  # unimplemented
+    parallel_tool_calls: Optional[bool] = None  # unimplemented
+    user: Optional[str] = None  # unimplemented
 
 
 @dataclass
@@ -121,10 +129,10 @@ class CompletionChoice:
     See the "The chat completion object >>> choices" section of the OpenAI API docs for more details.
     """
 
-    finish_reason: str
     index: int
     message: AssistantMessage
-    logprobs: Optional[List[Any]]
+    finish_reason: str = None
+    logprobs: Optional[List[Any]] = None
 
 
 @dataclass
@@ -151,9 +159,9 @@ class CompletionResponse:
     created: int
     model: str
     system_fingerprint: str
-    usage: UsageStats
-    object: str = "chat.completion"
     service_tier: Optional[str] = None
+    usage: Optional[UsageStats] = None
+    object: str = "chat.completion"
 
 
 @dataclass
@@ -193,8 +201,8 @@ class CompletionResponseChunk:
     created: int
     model: str
     system_fingerprint: str
-    object: str = "chat.completion.chunk"
     service_tier: Optional[str] = None
+    object: str = "chat.completion.chunk"
     usage: Optional[UsageStats] = None
 
 
@@ -220,10 +228,27 @@ def __init__(self, *args, **kwargs):
             if self.draft_model is not None
             else self.model.config.max_seq_length
         )
+        # The System fingerprint is a unique identifier for the model and its configuration.
+        # Currently, a true fingerprint is not implemented; a placeholder derived from the device and precision is used instead.
+        self.system_fingerprint = (
+            self.builder_args.device + type(self.builder_args.precision).__name__
+        )
 
-    def completion(self, completion_request: CompletionRequest):
+    def chunked_completion(self, completion_request: CompletionRequest):
         """Handle a chat completion request and yield a chunked response.
 
+        ** Warning ** : Not all arguments of the CompletionRequest are consumed, as the server isn't completely implemented.
+        The current treatment of parameters is described below.
+
+        - messages: The server consumes the final element of the array as the prompt.
+        - model: This has no impact on the server state, i.e. changing the model in the request
+          will not change which model is responding. Instead, use the --model flag to select the model when starting the server.
+        - temperature: This is used to control the randomness of the response.
+        - system_fingerprint: A unique identifier for the model and its configuration. Currently unimplemented - subject to change.
+
+        See https://github.com/pytorch/torchchat/issues/973 for more details.
+
         Args:
            completion_request: Request object with prompt and other parameters.
 
@@ -235,13 +260,16 @@ def completion(self, completion_request: CompletionRequest):
 
         # Initialize counters for chunk responses and encode the prompt.
         id = str(uuid.uuid4())
+        idx = 0
         buffer = []
         encoded = self.encode_tokens(
-            completion_request.prompt, bos=True, device=self.builder_args.device
+            completion_request.messages[-1].get("content"),
+            bos=True,
+            device=self.builder_args.device,
         )
         generator_args = GeneratorArgs(
-            completion_request.prompt,
+            completion_request.messages[-1].get("content"),
             encoded_prompt=encoded,
             chat_mode=False,
         )
@@ -291,21 +319,45 @@ def callback(x, *, done_generating=False):
                 choices=[choice_chunk],
                 created=int(time.time()),
                 model=completion_request.model,
-                system_fingerprint=uuid.UUID(int=uuid.getnode()),
+                system_fingerprint=self.system_fingerprint,
             )
             yield chunk_response
             self.start_pos += y.size(0)
             idx += 1
 
         # Yield an ending chunk indicating the generation has completed.
-        end_chunk = CompletionChoiceChunk(ChunkDelta(None, None, None), idx, "eos")
+        end_chunk = CompletionChoiceChunk(
+            ChunkDelta(None, None, None), idx, finish_reason="stop"
+        )
 
         yield CompletionResponseChunk(
             id=str(id),
             choices=[end_chunk],
             created=int(time.time()),
             model=completion_request.model,
-            system_fingerprint=uuid.UUID(int=uuid.getnode()),
+            system_fingerprint=self.system_fingerprint,
+        )
+
+    def sync_completion(self, request: CompletionRequest):
+        """Handle a chat completion request and yield a single, non-chunked response."""
+        output = ""
+        for chunk in self.chunked_completion(request):
+            if not chunk.choices[0].finish_reason:
+                output += chunk.choices[0].delta.content
+
+        message = AssistantMessage(content=output)
+        return CompletionResponse(
+            id=str(uuid.uuid4()),
+            choices=[
+                CompletionChoice(
+                    finish_reason="stop",
+                    index=0,
+                    message=message,
+                )
+            ],
+            created=int(time.time()),
+            model=request.model,
+            system_fingerprint=self.system_fingerprint,
         )
 
     def _callback(self, x, *, buffer, done_generating):
diff --git a/api/models.py b/api/models.py
new file mode 100644
index 000000000..45e459294
--- /dev/null
+++ b/api/models.py
@@ -0,0 +1,86 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+
+from dataclasses import dataclass
+from pwd import getpwuid
+from typing import List, Union
+
+from download import is_model_downloaded, load_model_configs
+
+"""Helper functions for the OpenAI API Models endpoint.
+
+See https://platform.openai.com/docs/api-reference/models for the full specification and details.
+Please create an issue if anything doesn't match the specification.
+"""
+
+
+@dataclass
+class ModelInfo:
+    """The Model object per the OpenAI API specification containing information about a model.
+
+    See https://platform.openai.com/docs/api-reference/models/object for more details.
+    """
+
+    id: str
+    created: int
+    owner: str
+    object: str = "model"
+
+
+@dataclass
+class ModelInfoList:
+    """A list of ModelInfo objects."""
+
+    data: List[ModelInfo]
+    object: str = "list"
+
+
+def retrieve_model_info(args, model_id: str) -> Union[ModelInfo, None]:
+    """Implementation of the OpenAI API Retrieve Model endpoint.
+
+    See https://platform.openai.com/docs/api-reference/models/retrieve
+
+    Inputs:
+        args: command line arguments
+        model_id: the id of the model requested
+
+    Returns:
+        ModelInfo describing the specified model if it is downloaded, None otherwise.
+    """
+    if model_config := load_model_configs().get(model_id):
+        if is_model_downloaded(model_id, args.model_directory):
+            path = args.model_directory / model_config.name
+            created = int(os.path.getctime(path))
+            owner = getpwuid(os.stat(path).st_uid).pw_name
+
+            return ModelInfo(id=model_config.name, created=created, owner=owner)
+        return None
+    return None
+
+
+def get_model_info_list(args) -> ModelInfoList:
+    """Implementation of the OpenAI API List Models endpoint.
+
+    See https://platform.openai.com/docs/api-reference/models/list
+
+    Inputs:
+        args: command line arguments
+
+    Returns:
+        ModelInfoList describing all downloaded models.
+    """
+    data = []
+    for model_id, model_config in load_model_configs().items():
+        if is_model_downloaded(model_id, args.model_directory):
+            path = args.model_directory / model_config.name
+            created = int(os.path.getctime(path))
+            owner = getpwuid(os.stat(path).st_uid).pw_name
+
+            data.append(ModelInfo(id=model_config.name, created=created, owner=owner))
+    response = ModelInfoList(data=data)
+    return response
diff --git a/generate.py b/generate.py
index 21d54373c..b8a2ef2a7 100644
--- a/generate.py
+++ b/generate.py
@@ -450,11 +450,15 @@ def generate(
     sequential_prefill=True,
     callback=lambda x: x,
     max_seq_length: int,
+    seed: Optional[int] = None,
     **sampling_kwargs,
 ) -> torch.Tensor:
     """
     Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
     """
+    if seed:
+        torch.manual_seed(seed)
+
     is_speculative = draft_model is not None
 
     device, dtype = prompt.device, prompt.dtype
diff --git a/server.py b/server.py
index a9132b98d..7d5fab009 100644
--- a/server.py
+++ b/server.py
@@ -4,62 +4,89 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from api.api import AssistantMessage, CompletionRequest, OpenAiApiGenerator
+import json
+
+from dataclasses import asdict
+from typing import Dict, List, Union
+
+from api.api import CompletionRequest, OpenAiApiGenerator
+from api.models import get_model_info_list, retrieve_model_info
 
 from build.builder import BuilderArgs, TokenizerArgs
-from flask import Flask, jsonify, request, Response
+from flask import Flask, request, Response
 from generate import GeneratorArgs
 
-app = Flask(__name__)
-# Messages and gen are kept global so they can be accessed by the flask app endpoints.
-messages: list = []
-gen: OpenAiApiGenerator = None
-
-@app.route("/chat", methods=["POST"])
-def chat_endpoint():
+
+def create_app(args):
     """
-    Endpoint for the Chat API. This endpoint is used to generate a response to a user prompt.
-    This endpoint emulates the behavior of the OpenAI Chat API. (https://platform.openai.com/docs/api-reference/chat)
+    Creates a Flask app that can be used to serve the model as a chat API.
     """
-    data = request.get_json()
-
-    # Add user message to chat history
-    messages.append(data["messages"][-1])
-    prompt = messages[-1]["content"]
-
-    # Generate the assistant response
-    req = CompletionRequest(
-        model=gen.builder_args.checkpoint_path,
-        prompt=prompt,
-        temperature=0,
-        messages=[],
-    )
+    app = Flask(__name__)
+
+    gen: OpenAiApiGenerator = initialize_generator(args)
+
+    def _del_none(d: Union[Dict, List]) -> Union[Dict, List]:
+        """Recursively delete None values from a dictionary."""
+        if type(d) is dict:
+            return {k: _del_none(v) for k, v in d.items() if v}
+        elif type(d) is list:
+            return [_del_none(v) for v in d if v]
+        return d
+
+    @app.route("/chat", methods=["POST"])
+    def chat_endpoint():
+        """
+        Endpoint for the Chat API. This endpoint is used to generate a response to a user prompt.
+        This endpoint emulates the behavior of the OpenAI Chat API. (https://platform.openai.com/docs/api-reference/chat)
+
+        ** Warning ** : Not all arguments of the CompletionRequest are consumed.
+
+        See https://github.com/pytorch/torchchat/issues/973 and the OpenAiApiGenerator class for more details.
+
+        If stream is set to true, the response will be streamed back as a series of CompletionResponseChunk objects. Otherwise,
+        a single CompletionResponse object will be returned.
+        """
+
+        print(" === Completion Request ===")
+
+        # Parse the request into a CompletionRequest object
+        data = request.get_json()
+        req = CompletionRequest(**data)
+
+        if data.get("stream") == "true":
+
+            def chunk_processor(chunked_completion_generator):
+                """Inline function for postprocessing CompletionResponseChunk objects.
+
+                Here, we just jsonify the chunk and yield it as a string.
+                """
+                for chunk in chunked_completion_generator:
+                    if (next_tok := chunk.choices[0].delta.content) is None:
+                        next_tok = ""
+                    print(next_tok, end="")
+                    yield json.dumps(_del_none(asdict(chunk)))
 
-    response = ""
+            return Response(
+                chunk_processor(gen.chunked_completion(req)),
+                mimetype="text/event-stream",
+            )
+        else:
+            response = gen.sync_completion(req)
 
-    def unwrap(completion_generator):
-        token_count = 0
-        for chunk_response in completion_generator:
-            content = chunk_response.choices[0].delta.content
-            if not gen.is_llama3_model or content not in set(
-                gen.tokenizer.special_tokens.keys()
-            ):
-                yield content if content is not None else ""
-            if content == gen.tokenizer.eos_id():
-                yield "."
-            token_count += 1
+            return json.dumps(_del_none(asdict(response)))
 
-    if data.get("stream") == "true":
-        return Response(unwrap(gen.completion(req)), mimetype="text/event-stream")
-    else:
-        for content in unwrap(gen.completion(req)):
-            response += content
+    @app.route("/models", methods=["GET"])
+    def models_endpoint():
+        return json.dumps(asdict(get_model_info_list(args)))
 
-    # Add assistant response to chat history
-    messages.append(AssistantMessage(content=response))
+    @app.route("/models/<model_id>", methods=["GET"])
+    def models_retrieve_endpoint(model_id):
+        if response := retrieve_model_info(args, model_id):
+            return json.dumps(asdict(response))
+        else:
+            return "Model not found", 404
 
-    return jsonify({"response": response})
+    return app
 
 
 def initialize_generator(args) -> OpenAiApiGenerator:
@@ -81,6 +108,5 @@
 
 
 def main(args):
-    global gen
-    gen = initialize_generator(args)
+    app = create_app(args)
     app.run()
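
For anyone who wants to poke at the reworked endpoints by hand, below is a rough client-side sketch. It is not part of the diff itself: it assumes the Flask app from create_app() is running on Flask's default host and port (127.0.0.1:5000), that the third-party "requests" package is installed, and it uses a placeholder model name, since the server ignores the request's model field and serves whichever model it was started with. It only mirrors the request/response shapes defined in api/api.py and api/models.py above.

    # Hypothetical smoke-test client for the /chat and /models endpoints added in this diff.
    # Assumptions: the server is running on Flask's default host/port and the third-party
    # "requests" package is available. The "model" value is a placeholder; the server
    # responds with whichever model it was started with.
    import requests

    BASE_URL = "http://127.0.0.1:5000"  # assumed Flask default

    payload = {
        "model": "placeholder",  # ignored by the server, but CompletionRequest requires it
        "messages": [{"role": "user", "content": "Tell me a short joke."}],
    }

    # Non-streaming: the server returns a single CompletionResponse serialized as JSON.
    resp = requests.post(f"{BASE_URL}/chat", json=payload)
    print(resp.json()["choices"][0]["message"]["content"])

    # Streaming: note the endpoint checks for the literal string "true". Chunks arrive as
    # back-to-back JSON objects with no delimiter, so they are printed raw here.
    with requests.post(f"{BASE_URL}/chat", json={**payload, "stream": "true"}, stream=True) as r:
        for piece in r.iter_content(chunk_size=None):
            print(piece.decode("utf-8"), end="", flush=True)
    print()

    # List the downloaded models known to the server.
    models = requests.get(f"{BASE_URL}/models").json()
    for m in models["data"]:
        print(m["id"], m["created"], m["owner"])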