From dea8d6044d95a5fb7b56bc541282864bac08825b Mon Sep 17 00:00:00 2001
From: vmpuri <45368418+vmpuri@users.noreply.github.com>
Date: Mon, 5 Aug 2024 16:32:39 -0700
Subject: [PATCH] OpenAI API JSON formatted (#995)

* Add warning comments referring to unimplemented functionality

* JSON formatted response using OpenAI API types for server completion requests

* Add models endpoint (#1000)
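
* Example (illustrative only, not part of this patch): with the server running
  locally on Flask's default 127.0.0.1:5000 and the third-party `requests`
  package installed, a non-streaming completion can be requested as below;
  "<model-id>" is a placeholder, and the field is informational only because
  the served model is fixed at startup.

      import requests

      payload = {
          "model": "<model-id>",  # placeholder; does not switch the loaded model
          "messages": [{"role": "user", "content": "Hello!"}],
      }
      # Without stream set to "true", /chat returns a single CompletionResponse as JSON.
      resp = requests.post("http://127.0.0.1:5000/chat", json=payload)
      print(resp.json()["choices"][0]["message"]["content"])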
---
 api/api.py    | 108 +++++++++++++++++++++++++++++++++------------
 api/models.py |  86 ++++++++++++++++++++++++++++++++++++
 generate.py   |   4 ++
 server.py     | 118 ++++++++++++++++++++++++++++++--------------------
 4 files changed, 242 insertions(+), 74 deletions(-)
 create mode 100644 api/models.py

diff --git a/api/api.py b/api/api.py
index e52870d60..bef0eb914 100644
--- a/api/api.py
+++ b/api/api.py
@@ -8,7 +8,7 @@
 import uuid
 from abc import ABC
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union
 
 from build.utils import device_sync
 
@@ -87,6 +87,11 @@ class StreamOptions:
     include_usage: bool = False
 
 
+@dataclass
+class ResponseFormat:
+    type: Optional[str] = None
+
+
 @dataclass
 class CompletionRequest:
     """A full chat completion request.
@@ -94,24 +99,27 @@ class CompletionRequest:
     See the "Create Chat Completion >>> Request body" section of the OpenAI API docs for more details.
     """
 
+    messages: List[_AbstractMessage]
     model: str
-    prompt: str
-    messages: Optional[List[_AbstractMessage]]
-    frequency_penalty: float = 0.0
-    temperature: float = 0.0
-    stop: Optional[List[str]] = None
-    stream: bool = False
-    stream_options: Optional[StreamOptions] = None
-    echo: bool = False
-    frequency_penalty: float = 0.0
-    guided_decode_json_schema: str = None
-    guided_decode_json_schema_path: str = None
+    frequency_penalty: float = 0.0  # unimplemented
+    logit_bias: Optional[Dict[str, float]] = None  # unimplemented
+    logprobs: Optional[bool] = None  # unimplemented
+    top_logprobs: Optional[int] = None  # unimplemented
+    max_tokens: Optional[int] = None  # unimplemented
     n: int = 1
-    presence_penalty: float = 0
-    logit_bias: Optional[Dict[str, float]] = None
-    logprobs: Optional[bool] = None
-    top_logprobs: Optional[int] = None
-    max_tokens: Optional[int] = None
+    presence_penalty: float = 0  # unimplemented
+    response_format: Optional[ResponseFormat] = None  # unimplemented
+    seed: Optional[int] = None  # unimplemented
+    service_tier: Optional[str] = None  # unimplemented
+    stop: Optional[List[str]] = None  # unimplemented
+    stream: bool = False
+    stream_options: Optional[StreamOptions] = None  # unimplemented
+    temperature: Optional[float] = 1.0  # unimplemented
+    top_p: Optional[float] = 1.0  # unimplemented
+    tools: Optional[List[Any]] = None  # unimplemented
+    tool_choice: Optional[Union[str, Any]] = None  # unimplemented
+    parallel_tool_calls: Optional[bool] = None  # unimplemented
+    user: Optional[str] = None  # unimplemented
 
 
 @dataclass
@@ -121,10 +129,10 @@ class CompletionChoice:
     See the "The chat completion object >>> choices" section of the OpenAI API docs for more details.
     """
 
-    finish_reason: str
     index: int
     message: AssistantMessage
-    logprobs: Optional[List[Any]]
+    finish_reason: Optional[str] = None
+    logprobs: Optional[List[Any]] = None
 
 
 @dataclass
@@ -151,9 +159,9 @@ class CompletionResponse:
     created: int
     model: str
     system_fingerprint: str
-    usage: UsageStats
-    object: str = "chat.completion"
     service_tier: Optional[str] = None
+    usage: Optional[UsageStats] = None
+    object: str = "chat.completion"
 
 
 @dataclass
@@ -193,8 +201,8 @@ class CompletionResponseChunk:
     created: int
     model: str
     system_fingerprint: str
-    object: str = "chat.completion.chunk"
     service_tier: Optional[str] = None
+    object: str = "chat.completion.chunk"
     usage: Optional[UsageStats] = None
 
 
@@ -220,10 +228,27 @@ def __init__(self, *args, **kwargs):
             if self.draft_model is not None
             else self.model.config.max_seq_length
         )
+        # The system fingerprint is a unique identifier for the model and its configuration.
+        # Currently this is a placeholder derived from the device and precision type rather than a true fingerprint.
+        self.system_fingerprint = (
+            self.builder_args.device + type(self.builder_args.precision).__name__
+        )
 
-    def completion(self, completion_request: CompletionRequest):
+    def chunked_completion(self, completion_request: CompletionRequest):
         """Handle a chat completion request and yield a chunked response.
 
+        ** Warning **: Not all arguments of the CompletionRequest are consumed, as the server is not fully implemented.
+        The current treatment of parameters is described below.
+
+        - messages: The server consumes the final element of the array as the prompt.
+        - model: This has no impact on the server state, i.e. changing the model in the request
+        will not change which model is responding. Instead, use the --model flag to select the model when starting the server.
+        - temperature: This is used to control the randomness of the response.
+        - system_fingerprint: A unique identifier for the model and its configuration. Currently a placeholder derived from the device and precision; subject to change.
+
+        See https://github.com/pytorch/torchchat/issues/973 for more details.
+
+
         Args:
             completion_request: Request object with prompt and other parameters.
 
@@ -235,13 +260,16 @@ def completion(self, completion_request: CompletionRequest):
 
         # Initialize counters for chunk responses and encode the prompt.
         id = str(uuid.uuid4())
+
         idx = 0
         buffer = []
         encoded = self.encode_tokens(
-            completion_request.prompt, bos=True, device=self.builder_args.device
+            completion_request.messages[-1].get("content"),
+            bos=True,
+            device=self.builder_args.device,
         )
         generator_args = GeneratorArgs(
-            completion_request.prompt,
+            completion_request.messages[-1].get("content"),
             encoded_prompt=encoded,
             chat_mode=False,
         )
@@ -291,21 +319,45 @@ def callback(x, *, done_generating=False):
                 choices=[choice_chunk],
                 created=int(time.time()),
                 model=completion_request.model,
-                system_fingerprint=uuid.UUID(int=uuid.getnode()),
+                system_fingerprint=self.system_fingerprint,
             )
             yield chunk_response
             self.start_pos += y.size(0)
             idx += 1
 
         # Yield an ending chunk indicating the generation has completed.
-        end_chunk = CompletionChoiceChunk(ChunkDelta(None, None, None), idx, "eos")
+        end_chunk = CompletionChoiceChunk(
+            ChunkDelta(None, None, None), idx, finish_reason="stop"
+        )
 
         yield CompletionResponseChunk(
             id=str(id),
             choices=[end_chunk],
             created=int(time.time()),
             model=completion_request.model,
-            system_fingerprint=uuid.UUID(int=uuid.getnode()),
+            system_fingerprint=self.system_fingerprint,
+        )
+
+    def sync_completion(self, request: CompletionRequest):
+        """Handle a chat completion request and return a single, non-chunked response."""
+        output = ""
+        for chunk in self.chunked_completion(request):
+            if not chunk.choices[0].finish_reason:
+                output += chunk.choices[0].delta.content
+
+        message = AssistantMessage(content=output)
+        return CompletionResponse(
+            id=str(uuid.uuid4()),
+            choices=[
+                CompletionChoice(
+                    finish_reason="stop",
+                    index=0,
+                    message=message,
+                )
+            ],
+            created=int(time.time()),
+            model=request.model,
+            system_fingerprint=self.system_fingerprint,
         )
 
     def _callback(self, x, *, buffer, done_generating):
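
A usage sketch for the two completion paths above (illustrative, not applied by
this patch): `demo` is a hypothetical helper, and `gen` is assumed to be an
OpenAiApiGenerator built the same way server.initialize_generator() builds one;
that setup is omitted here.

    from api.api import CompletionRequest, OpenAiApiGenerator

    def demo(gen: OpenAiApiGenerator, stream: bool = True) -> None:
        """Drive one request through either completion path of `gen`."""
        req = CompletionRequest(
            model="any-string",  # informational; does not switch the loaded model
            messages=[{"role": "user", "content": "Tell me a joke."}],
        )
        if stream:
            # chunked_completion yields CompletionResponseChunk objects until a
            # final chunk whose finish_reason is "stop".
            for chunk in gen.chunked_completion(req):
                if not chunk.choices[0].finish_reason:
                    print(chunk.choices[0].delta.content or "", end="")
            print()
        else:
            # sync_completion drains the same chunk stream into one CompletionResponse.
            resp = gen.sync_completion(req)
            print(resp.choices[0].message.content)
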
diff --git a/api/models.py b/api/models.py
new file mode 100644
index 000000000..45e459294
--- /dev/null
+++ b/api/models.py
@@ -0,0 +1,86 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+
+from dataclasses import dataclass
+from pwd import getpwuid
+from typing import List, Union
+
+from download import is_model_downloaded, load_model_configs
+
+"""Helper functions for the OpenAI API Models endpoint.
+
+See https://platform.openai.com/docs/api-reference/models for the full specification and details.
+Please create an issue if anything doesn't match the specification.
+"""
+
+
+@dataclass
+class ModelInfo:
+    """The Model object per the OpenAI API specification containing information about a model.
+
+    See https://platform.openai.com/docs/api-reference/models/object for more details.
+    """
+
+    id: str
+    created: int
+    owner: str
+    object: str = "model"
+
+
+@dataclass
+class ModelInfoList:
+    """A list of ModelInfo objects."""
+
+    data: List[ModelInfo]
+    object: str = "list"
+
+
+def retrieve_model_info(args, model_id: str) -> Union[ModelInfo, None]:
+    """Implementation of the OpenAI API Retrieve Model endpoint.
+
+    See https://platform.openai.com/docs/api-reference/models/retrieve
+
+    Inputs:
+        args: command line arguments
+        model_id: the id of the model requested
+
+    Returns:
+        ModelInfo describing the specified model if it is downloaded, None otherwise.
+    """
+    if model_config := load_model_configs().get(model_id):
+        if is_model_downloaded(model_id, args.model_directory):
+            path = args.model_directory / model_config.name
+            created = int(os.path.getctime(path))
+            owner = getpwuid(os.stat(path).st_uid).pw_name
+
+            return ModelInfo(id=model_config.name, created=created, owner=owner)
+        return None
+    return None
+
+
+def get_model_info_list(args) -> ModelInfoList:
+    """Implementation of the OpenAI API List Models endpoint.
+
+    See https://platform.openai.com/docs/api-reference/models/list
+
+    Inputs:
+        args: command line arguments
+
+    Returns:
+        ModelInfoList describing all downloaded models.
+    """
+    data = []
+    for model_id, model_config in load_model_configs().items():
+        if is_model_downloaded(model_id, args.model_directory):
+            path = args.model_directory / model_config.name
+            created = int(os.path.getctime(path))
+            owner = getpwuid(os.stat(path).st_uid).pw_name
+
+            data.append(ModelInfo(id=model_config.name, created=created, owner=owner))
+    response = ModelInfoList(data=data)
+    return response
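
A usage sketch for the helpers above (illustrative, not applied by this patch):
only `model_directory` is read from `args`, so a bare namespace suffices; the
directory and "<model-id>" below are placeholders.

    import json
    from argparse import Namespace
    from dataclasses import asdict
    from pathlib import Path

    from api.models import get_model_info_list, retrieve_model_info

    args = Namespace(model_directory=Path("/path/to/torchchat/models"))  # placeholder

    # OpenAI-style "list" object describing every downloaded model.
    print(json.dumps(asdict(get_model_info_list(args)), indent=2))

    # Single model lookup; returns None if the model is unknown or not downloaded.
    if info := retrieve_model_info(args, "<model-id>"):
        print(info.id, info.created, info.owner)
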
diff --git a/generate.py b/generate.py
index eff086afd..5920bd656 100644
--- a/generate.py
+++ b/generate.py
@@ -452,11 +452,15 @@ def generate(
         sequential_prefill=True,
         callback=lambda x: x,
         max_seq_length: int,
+        seed: Optional[int] = None,
         **sampling_kwargs,
     ) -> torch.Tensor:
         """
         Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
         """
+        if seed is not None:
+            torch.manual_seed(seed)
+
         is_speculative = draft_model is not None
         device, dtype = prompt.device, prompt.dtype
 
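The new `seed` argument simply seeds torch's global RNG before generation; a
standalone illustration of the effect (not torchchat code):

    import torch

    def sample_once(seed: int) -> torch.Tensor:
        # Seeding the global RNG makes the multinomial draw reproducible,
        # assuming no other source of nondeterminism (e.g. CUDA kernels).
        torch.manual_seed(seed)
        probs = torch.softmax(torch.randn(32), dim=-1)
        return torch.multinomial(probs, num_samples=1)

    assert torch.equal(sample_once(1234), sample_once(1234))
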
diff --git a/server.py b/server.py
index a9132b98d..7d5fab009 100644
--- a/server.py
+++ b/server.py
@@ -4,62 +4,89 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from api.api import AssistantMessage, CompletionRequest, OpenAiApiGenerator
+import json
+
+from dataclasses import asdict
+from typing import Dict, List, Union
+
+from api.api import CompletionRequest, OpenAiApiGenerator
+from api.models import get_model_info_list, retrieve_model_info
 
 from build.builder import BuilderArgs, TokenizerArgs
-from flask import Flask, jsonify, request, Response
+from flask import Flask, request, Response
 from generate import GeneratorArgs
 
-app = Flask(__name__)
-# Messages and gen are kept global so they can be accessed by the flask app endpoints.
-messages: list = []
-gen: OpenAiApiGenerator = None
 
-
-@app.route("/chat", methods=["POST"])
-def chat_endpoint():
+def create_app(args):
     """
-    Endpoint for the Chat API. This endpoint is used to generate a response to a user prompt.
-    This endpoint emulates the behavior of the OpenAI Chat API. (https://platform.openai.com/docs/api-reference/chat)
+    Creates a Flask app that can be used to serve the model as a chat API.
     """
-    data = request.get_json()
-
-    # Add user message to chat history
-    messages.append(data["messages"][-1])
-    prompt = messages[-1]["content"]
-
-    # Generate the assistant response
-    req = CompletionRequest(
-        model=gen.builder_args.checkpoint_path,
-        prompt=prompt,
-        temperature=0,
-        messages=[],
-    )
+    app = Flask(__name__)
+
+    gen: OpenAiApiGenerator = initialize_generator(args)
+
+    def _del_none(d: Union[Dict, List]) -> Union[Dict, List]:
+        """Recursively delete None (and other falsy) values from a dictionary or list."""
+        if type(d) is dict:
+            return {k: _del_none(v) for k, v in d.items() if v}
+        elif type(d) is list:
+            return [_del_none(v) for v in d if v]
+        return d
+
+    @app.route("/chat", methods=["POST"])
+    def chat_endpoint():
+        """
+        Endpoint for the Chat API. This endpoint is used to generate a response to a user prompt.
+        This endpoint emulates the behavior of the OpenAI Chat API. (https://platform.openai.com/docs/api-reference/chat)
+
+        ** Warning **: Not all arguments of the CompletionRequest are consumed.
+
+        See https://github.com/pytorch/torchchat/issues/973 and the OpenAiApiGenerator class for more details.
+
+        If stream is set to true, the response will be streamed back as a series of CompletionResponseChunk objects. Otherwise,
+        a single CompletionResponse object will be returned.
+        """
+
+        print(" === Completion Request ===")
+
+        # Parse the request into a CompletionRequest object
+        data = request.get_json()
+        req = CompletionRequest(**data)
+
+        if data.get("stream") in (True, "true"):
+
+            def chunk_processor(chunked_completion_generator):
+                """Inline function for postprocessing CompletionResponseChunk objects.
+
+                Here, we just jsonify the chunk and yield it as a string.
+                """
+                for chunk in chunked_completion_generator:
+                    if (next_tok := chunk.choices[0].delta.content) is None:
+                        next_tok = ""
+                    print(next_tok, end="")
+                    yield json.dumps(_del_none(asdict(chunk)))
 
-    response = ""
+            return Response(
+                chunk_processor(gen.chunked_completion(req)),
+                mimetype="text/event-stream",
+            )
+        else:
+            response = gen.sync_completion(req)
 
-    def unwrap(completion_generator):
-        token_count = 0
-        for chunk_response in completion_generator:
-            content = chunk_response.choices[0].delta.content
-            if not gen.is_llama3_model or content not in set(
-                gen.tokenizer.special_tokens.keys()
-            ):
-                yield content if content is not None else ""
-            if content == gen.tokenizer.eos_id():
-                yield "."
-            token_count += 1
+            return json.dumps(_del_none(asdict(response)))
 
-    if data.get("stream") == "true":
-        return Response(unwrap(gen.completion(req)), mimetype="text/event-stream")
-    else:
-        for content in unwrap(gen.completion(req)):
-            response += content
+    @app.route("/models", methods=["GET"])
+    def models_endpoint():
+        return json.dumps(asdict(get_model_info_list(args)))
 
-    # Add assistant response to chat history
-    messages.append(AssistantMessage(content=response))
+    @app.route("/models/<model_id>", methods=["GET"])
+    def models_retrieve_endpoint(model_id):
+        if response := retrieve_model_info(args, model_id):
+            return json.dumps(asdict(response))
+        else:
+            return "Model not found", 404
 
-    return jsonify({"response": response})
+    return app
 
 
 def initialize_generator(args) -> OpenAiApiGenerator:
@@ -81,6 +108,5 @@ def initialize_generator(args) -> OpenAiApiGenerator:
 
 
 def main(args):
-    global gen
-    gen = initialize_generator(args)
+    app = create_app(args)
     app.run()
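
For completeness, a client-side sketch of the streaming path and the new model
routes (illustrative, not part of this patch): it assumes the app returned by
create_app() is running on Flask's default 127.0.0.1:5000 and that the
third-party `requests` package is installed. Because chunk_processor streams
back-to-back JSON objects with no separator, the client peels complete objects
off a buffer with json.JSONDecoder.raw_decode; "<model-id>" is a placeholder.

    import codecs
    import json

    import requests

    BASE = "http://127.0.0.1:5000"

    # List downloaded models, then look one up by id (404 if not downloaded).
    print(requests.get(f"{BASE}/models").json())
    print(requests.get(f"{BASE}/models/<model-id>").status_code)

    payload = {
        "model": "<model-id>",  # informational only; the served model is fixed at startup
        "messages": [{"role": "user", "content": "Write a haiku about patches."}],
        "stream": "true",  # string form accepted by the /chat handler
    }
    decoder = json.JSONDecoder()
    utf8 = codecs.getincrementaldecoder("utf-8")()
    buf = ""
    with requests.post(f"{BASE}/chat", json=payload, stream=True) as resp:
        for piece in resp.iter_content(chunk_size=None):
            buf += utf8.decode(piece)
            # Peel off every complete CompletionResponseChunk currently in the buffer.
            while buf:
                try:
                    chunk, end = decoder.raw_decode(buf)
                except json.JSONDecodeError:
                    break  # incomplete object; wait for more bytes
                buf = buf[end:]
                delta = chunk["choices"][0].get("delta", {})
                print(delta.get("content", ""), end="", flush=True)
    print()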