
Commit

JSON formatted response using OpenAI API types for server completion requests

vmpuri committed Aug 5, 2024
1 parent 6303c8c commit 4e26b22
Showing 3 changed files with 123 additions and 61 deletions.
99 changes: 70 additions & 29 deletions api/api.py
@@ -8,7 +8,7 @@
import uuid
from abc import ABC
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union

from build.utils import device_sync

@@ -87,31 +87,39 @@ class StreamOptions:
include_usage: bool = False


@dataclass
class ResponseFormat:
type: Optional[str] = None


@dataclass
class CompletionRequest:
"""A full chat completion request.
See the "Create Chat Completion >>> Request body" section of the OpenAI API docs for more details.
"""

messages: List[_AbstractMessage]
model: str
prompt: str
messages: Optional[List[_AbstractMessage]]
frequency_penalty: float = 0.0
temperature: float = 0.0
stop: Optional[List[str]] = None
stream: bool = False
stream_options: Optional[StreamOptions] = None
echo: bool = False
frequency_penalty: float = 0.0
guided_decode_json_schema: str = None
guided_decode_json_schema_path: str = None
frequency_penalty: float = 0.0 # unimplemented
logit_bias: Optional[Dict[str, float]] = None # unimplemented
logprobs: Optional[bool] = None # unimplemented
top_logprobs: Optional[int] = None # unimplemented
max_tokens: Optional[int] = None # unimplemented
n: int = 1
presence_penalty: float = 0
logit_bias: Optional[Dict[str, float]] = None
logprobs: Optional[bool] = None
top_logprobs: Optional[int] = None
max_tokens: Optional[int] = None
presence_penalty: float = 0 # unimplemented
response_format: Optional[ResponseFormat] = None # unimplemented
seed: Optional[int] = None # unimplemented
service_tier: Optional[str] = None # unimplemented
stop: Optional[List[str]] = None # unimplemented
stream: bool = False
stream_options: Optional[StreamOptions] = None # unimplemented
temperature: Optional[float] = 1.0 # unimplemented
top_p: Optional[float] = 1.0 # unimplemented
tools: Optional[List[Any]] = None # unimplemented
tool_choice: Optional[Union[str, Any]] = None # unimplemented
parallel_tool_calls: Optional[bool] = None # unimplemented
user: Optional[str] = None # unimplemented
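
For illustration, a request body like the following maps onto this dataclass; the server (see server.py below) builds it with CompletionRequest(**data). The values are made up, fields marked "unimplemented" are accepted but ignored, and whether additional fields such as prompt remain required depends on the parts of the dataclass this hunk does not show.

# Hypothetical JSON body for the /chat endpoint (illustrative values only):
data = {
    "model": "llama3",                                    # echoed back in the response; does not select the served model
    "messages": [{"role": "user", "content": "Hello!"}],  # only the final message is consumed as the prompt
    "stream": "true",                                     # the endpoint compares against the string "true"
}
req = CompletionRequest(**data)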


@dataclass
@@ -121,10 +129,10 @@ class CompletionChoice:
See the "The chat completion object >>> choices" section of the OpenAI API docs for more details.
"""

finish_reason: str
index: int
message: AssistantMessage
logprobs: Optional[List[Any]]
finish_reason: str = None
logprobs: Optional[List[Any]] = None


@dataclass
@@ -151,9 +159,9 @@ class CompletionResponse:
created: int
model: str
system_fingerprint: str
usage: UsageStats
object: str = "chat.completion"
service_tier: Optional[str] = None
usage: Optional[UsageStats] = None
object: str = "chat.completion"


@dataclass
@@ -193,8 +201,8 @@ class CompletionResponseChunk:
created: int
model: str
system_fingerprint: str
object: str = "chat.completion.chunk"
service_tier: Optional[str] = None
object: str = "chat.completion.chunk"
usage: Optional[UsageStats] = None


@@ -220,8 +228,13 @@ def __init__(self, *args, **kwargs):
if self.draft_model is not None
else self.model.config.max_seq_length
)
# The system fingerprint is a unique identifier for the model and its configuration.
# Currently this is a placeholder built from the device and the precision type name; subject to change.
self.system_fingerprint = (
self.builder_args.device + type(self.builder_args.precision).__name__
)

def completion(self, completion_request: CompletionRequest):
def chunked_completion(self, completion_request: CompletionRequest):
"""Handle a chat completion request and yield a chunked response.
** Warning ** : Not all arguments of the CompletionRequest are consumed as the server isn't completely implemented.
@@ -230,7 +243,8 @@ def completion(self, completion_request: CompletionRequest):
- messages: The server consumes the final element of the array as the prompt.
- model: This has no impact on the server state, i.e. changing the model in the request
will not change which model is responding. Instead, use the --model flag to select the model when starting the server.
- temperature: This is used to control the randomness of the response. The server will use the temperature
- temperature: This is used to control the randomness of the response.
- system_fingerprint: A unique identifier for the model and its configuration. Currently unimplemented - subject to change.
See https://github.com/pytorch/torchchat/issues/973 for more details.
@@ -246,13 +260,16 @@ def completion(self, completion_request: CompletionRequest):

# Initialize counters for chunk responses and encode the prompt.
id = str(uuid.uuid4())

idx = 0
buffer = []
encoded = self.encode_tokens(
completion_request.prompt, bos=True, device=self.builder_args.device
completion_request.messages[-1].get("content"),
bos=True,
device=self.builder_args.device,
)
generator_args = GeneratorArgs(
completion_request.prompt,
completion_request.messages[-1].get("content"),
encoded_prompt=encoded,
chat_mode=False,
)
@@ -302,21 +319,45 @@ def callback(x, *, done_generating=False):
choices=[choice_chunk],
created=int(time.time()),
model=completion_request.model,
system_fingerprint=uuid.UUID(int=uuid.getnode()),
system_fingerprint=self.system_fingerprint,
)
yield chunk_response
self.start_pos += y.size(0)
idx += 1

# Yield an ending chunk indicating the generation has completed.
end_chunk = CompletionChoiceChunk(ChunkDelta(None, None, None), idx, "eos")
end_chunk = CompletionChoiceChunk(
ChunkDelta(None, None, None), idx, finish_reason="stop"
)

yield CompletionResponseChunk(
id=str(id),
choices=[end_chunk],
created=int(time.time()),
model=completion_request.model,
system_fingerprint=uuid.UUID(int=uuid.getnode()),
system_fingerprint=self.system_fingerprint,
)
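
A minimal sketch of consuming this generator directly, mirroring what sync_completion and the server's chunk_processor do below (gen is an OpenAiApiGenerator and req a CompletionRequest; both names are assumed here):

# Illustrative only: stream content until the final chunk (which carries finish_reason="stop").
for chunk in gen.chunked_completion(req):
    choice = chunk.choices[0]
    if not choice.finish_reason:
        print(choice.delta.content or "", end="", flush=True)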

def sync_completion(self, request: CompletionRequest):
"""Handle a chat completion request and yield a single, non-chunked response"""
output = ""
for chunk in self.chunked_completion(request):
if not chunk.choices[0].finish_reason:
output += chunk.choices[0].delta.content

message = AssistantMessage(content=output)
return CompletionResponse(
id=str(uuid.uuid4()),
choices=[
CompletionChoice(
finish_reason="stop",
index=0,
message=message,
)
],
created=int(time.time()),
model=request.model,
system_fingerprint=self.system_fingerprint,
)
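
And the corresponding non-streaming usage (hypothetical, same assumed names as above):

# Illustrative only: one aggregated response object instead of a stream of chunks.
resp = gen.sync_completion(req)
print(resp.choices[0].message.content)   # the full assistant reply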

def _callback(self, x, *, buffer, done_generating):
4 changes: 4 additions & 0 deletions generate.py
@@ -450,11 +450,15 @@ def generate(
sequential_prefill=True,
callback=lambda x: x,
max_seq_length: int,
seed: Optional[int] = None,
**sampling_kwargs,
) -> torch.Tensor:
"""
Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
"""
if seed:
torch.manual_seed(seed)

is_speculative = draft_model is not None
device, dtype = prompt.device, prompt.dtype

81 changes: 49 additions & 32 deletions server.py
@@ -4,18 +4,36 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from api.api import AssistantMessage, CompletionRequest, OpenAiApiGenerator
import json

from dataclasses import asdict
from typing import Dict, List, Union

from api.api import AssistantMessage, CompletionRequest, OpenAiApiGenerator, UserMessage

from build.builder import BuilderArgs, TokenizerArgs
from flask import Flask, jsonify, request, Response
from flask import Flask, request, Response
from generate import GeneratorArgs


"""
Creates a flask app that can be used to serve the model as a chat API.
"""
app = Flask(__name__)
# Messages and gen are kept global so they can be accessed by the flask app endpoints.
messages: list = []
gen: OpenAiApiGenerator = None


def _del_none(d: Union[Dict, List]) -> Union[Dict, List]:
"""Recursively delete None values from a dictionary."""
if type(d) is dict:
return {k: _del_none(v) for k, v in d.items() if v}
elif type(d) is list:
return [_del_none(v) for v in d if v]
return d
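
To make the helper concrete, a small sketch with illustrative input (note that the `if v` checks drop any falsy value, not just None):

# Illustrative only:
_del_none({"id": "abc", "usage": None,
           "choices": [{"delta": {"content": "hi", "role": None}}]})
# -> {"id": "abc", "choices": [{"delta": {"content": "hi"}}]}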


@app.route("/chat", methods=["POST"])
def chat_endpoint():
"""
@@ -26,45 +44,44 @@ def chat_endpoint():
See https://github.com/pytorch/torchchat/issues/973 and the OpenAiApiGenerator class for more details.
If stream is set to true, the response will be streamed back as a series of CompletionResponseChunk objects. Otherwise,
a single CompletionResponse object will be returned.
"""
data = request.get_json()

# Add user message to chat history
messages.append(data["messages"][-1])
prompt = messages[-1]["content"]

# Generate the assistant response
req = CompletionRequest(
model=gen.builder_args.checkpoint_path,
prompt=prompt,
temperature=0,
messages=[],
)
print(" === Completion Request ===")

response = ""
# Parse the request into a CompletionRequest object
data = request.get_json()
req = CompletionRequest(**data)

def unwrap(completion_generator):
token_count = 0
for chunk_response in completion_generator:
content = chunk_response.choices[0].delta.content
if not gen.is_llama3_model or content not in set(
gen.tokenizer.special_tokens.keys()
):
yield content if content is not None else ""
if content == gen.tokenizer.eos_id():
yield "."
token_count += 1
# Add the user message to our internal message history.
messages.append(UserMessage(**req.messages[-1]))

if data.get("stream") == "true":
return Response(unwrap(gen.completion(req)), mimetype="text/event-stream")

def chunk_processor(chunked_completion_generator):
"""Inline function for postprocessing CompletionResponseChunk objects.
Here, we just jsonify the chunk and yield it as a string.
"""
messages.append(AssistantMessage(content=""))
for chunk in chunked_completion_generator:
if (next_tok := chunk.choices[0].delta.content) is None:
next_tok = ""
messages[-1].content += next_tok
print(next_tok, end="")
yield json.dumps(_del_none(asdict(chunk)))

return Response(
chunk_processor(gen.chunked_completion(req)), mimetype="text/event-stream"
)
else:
for content in unwrap(gen.completion(req)):
response += content
response = gen.sync_completion(req)

# Add assistant response to chat history
messages.append(AssistantMessage(content=response))
messages.append(response.choices[0].message)
print(messages[-1].content)

return jsonify({"response": response})
return json.dumps(_del_none(asdict(response)))
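
To round this out, a hypothetical client for the endpoint. The host and port are assumptions that depend on how the Flask app is launched, and the requests library is not part of this commit:

import requests

payload = {
    "model": "llama3",
    "messages": [{"role": "user", "content": "Hello!"}],
    "stream": "true",   # anything other than the string "true" yields a single CompletionResponse
}

# Streaming: the event stream is a sequence of JSON-serialized CompletionResponseChunk
# objects emitted back to back (no newline delimiters between chunks).
with requests.post("http://127.0.0.1:5000/chat", json=payload, stream=True) as r:
    for piece in r.iter_content(chunk_size=None, decode_unicode=True):
        print(piece, end="", flush=True)

# Non-streaming: drop "stream" to receive one JSON CompletionResponse.
payload.pop("stream")
resp = requests.post("http://127.0.0.1:5000/chat", json=payload)
print(resp.json()["choices"][0]["message"]["content"])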


def initialize_generator(args) -> OpenAiApiGenerator:
