OpenAI API JSON formatted #995

Merged: 3 commits, Aug 5, 2024
108 changes: 80 additions & 28 deletions api/api.py
@@ -8,7 +8,7 @@
import uuid
from abc import ABC
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union

from build.utils import device_sync

@@ -87,31 +87,39 @@ class StreamOptions:
include_usage: bool = False


@dataclass
class ResponseFormat:
type: Optional[str] = None


@dataclass
class CompletionRequest:
"""A full chat completion request.

See the "Create Chat Completion >>> Request body" section of the OpenAI API docs for more details.
"""

messages: List[_AbstractMessage]
model: str
prompt: str
messages: Optional[List[_AbstractMessage]]
frequency_penalty: float = 0.0
temperature: float = 0.0
stop: Optional[List[str]] = None
stream: bool = False
stream_options: Optional[StreamOptions] = None
echo: bool = False
frequency_penalty: float = 0.0
guided_decode_json_schema: str = None
guided_decode_json_schema_path: str = None
frequency_penalty: float = 0.0 # unimplemented
logit_bias: Optional[Dict[str, float]] = None # unimplemented
logprobs: Optional[bool] = None # unimplemented
top_logprobs: Optional[int] = None # unimplemented
max_tokens: Optional[int] = None # unimplemented
n: int = 1
presence_penalty: float = 0
logit_bias: Optional[Dict[str, float]] = None
logprobs: Optional[bool] = None
top_logprobs: Optional[int] = None
max_tokens: Optional[int] = None
presence_penalty: float = 0 # unimplemented
response_format: Optional[ResponseFormat] = None # unimplemented
seed: Optional[int] = None # unimplemented
service_tier: Optional[str] = None # unimplemented
stop: Optional[List[str]] = None # unimplemented
stream: bool = False
stream_options: Optional[StreamOptions] = None # unimplemented
temperature: Optional[float] = 1.0 # unimplemented
top_p: Optional[float] = 1.0 # unimplemented
tools: Optional[List[Any]] = None # unimplemented
tool_choice: Optional[Union[str, Any]] = None # unimplemented
parallel_tool_calls: Optional[bool] = None # unimplemented
user: Optional[str] = None # unimplemented
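
For illustration, a minimal sketch, not part of this diff, of the kind of JSON body this dataclass mirrors. The model id and message text are made up; the dict-style message entries line up with the messages[-1].get("content") access used later in the diff.

example_body = {
    "model": "llama3.1-8b-instruct",  # illustrative id, not taken from the PR
    "messages": [{"role": "user", "content": "Tell me a joke."}],
    "temperature": 1.0,
    "stream": True,
}
# e.g. CompletionRequest(**example_body), assuming the merged field set matches these keys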


@dataclass
@@ -121,10 +129,10 @@ class CompletionChoice:
See the "The chat completion object >>> choices" section of the OpenAI API docs for more details.
"""

finish_reason: str
index: int
message: AssistantMessage
logprobs: Optional[List[Any]]
finish_reason: str = None
logprobs: Optional[List[Any]] = None


@dataclass
@@ -151,9 +159,9 @@ class CompletionResponse:
created: int
model: str
system_fingerprint: str
usage: UsageStats
object: str = "chat.completion"
service_tier: Optional[str] = None
usage: Optional[UsageStats] = None
object: str = "chat.completion"


@dataclass
@@ -193,8 +201,8 @@ class CompletionResponseChunk:
created: int
model: str
system_fingerprint: str
object: str = "chat.completion.chunk"
service_tier: Optional[str] = None
object: str = "chat.completion.chunk"
usage: Optional[UsageStats] = None


@@ -220,10 +228,27 @@ def __init__(self, *args, **kwargs):
if self.draft_model is not None
else self.model.config.max_seq_length
)
# The system fingerprint is a unique identifier for the model and its configuration.
# Currently, this is not implemented in a spec-compliant way.
self.system_fingerprint = (
self.builder_args.device + type(self.builder_args.precision).__name__
)
Comment on lines +233 to +235 (Contributor):

Let's add a comment that this field doesn't match the spec, but is populated

We'll fix in a separate PR


def completion(self, completion_request: CompletionRequest):
def chunked_completion(self, completion_request: CompletionRequest):
"""Handle a chat completion request and yield a chunked response.

**Warning**: Not all arguments of the CompletionRequest are consumed, as the server isn't completely implemented.
Current treatment of parameters is described below.

- messages: The server consumes the final element of the array as the prompt.
- model: This has no impact on the server state, i.e. changing the model in the request
will not change which model is responding. Instead, use the --model flag to select the model when starting the server.
- temperature: This is used to control the randomness of the response.
- system_fingerprint: A unique identifier for the model and its configuration. Currently unimplemented - subject to change.

See https://github.com/pytorch/torchchat/issues/973 for more details.


Args:
completion_request: Request object with prompt and other parameters.

@@ -235,13 +260,16 @@ def completion(self, completion_request: CompletionRequest):

# Initialize counters for chunk responses and encode the prompt.
id = str(uuid.uuid4())

idx = 0
buffer = []
encoded = self.encode_tokens(
completion_request.prompt, bos=True, device=self.builder_args.device
completion_request.messages[-1].get("content"),
bos=True,
device=self.builder_args.device,
)
generator_args = GeneratorArgs(
completion_request.prompt,
completion_request.messages[-1].get("content"),
encoded_prompt=encoded,
chat_mode=False,
)
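
For illustration, a hedged sketch, not part of this diff, of how a caller could consume chunked_completion and emit OpenAI-style server-sent events; api (an instance of this handler class) and request (a CompletionRequest) are assumed to exist, and dataclasses.asdict flattens each chunk before JSON encoding.

import json
from dataclasses import asdict

def stream_sse(api, request):
    # Stream each chunk as a JSON-encoded SSE event, then signal completion.
    for chunk in api.chunked_completion(request):
        yield f"data: {json.dumps(asdict(chunk))}\n\n"
    yield "data: [DONE]\n\n"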
@@ -291,21 +319,45 @@ def callback(x, *, done_generating=False):
choices=[choice_chunk],
created=int(time.time()),
model=completion_request.model,
system_fingerprint=uuid.UUID(int=uuid.getnode()),
system_fingerprint=self.system_fingerprint,
)
yield chunk_response
self.start_pos += y.size(0)
idx += 1

# Yield an ending chunk indicating the generation has completed.
end_chunk = CompletionChoiceChunk(ChunkDelta(None, None, None), idx, "eos")
end_chunk = CompletionChoiceChunk(
ChunkDelta(None, None, None), idx, finish_reason="stop"
)

yield CompletionResponseChunk(
id=str(id),
choices=[end_chunk],
created=int(time.time()),
model=completion_request.model,
system_fingerprint=uuid.UUID(int=uuid.getnode()),
system_fingerprint=self.system_fingerprint,
)

def sync_completion(self, request: CompletionRequest):
"""Handle a chat completion request and yield a single, non-chunked response"""
output = ""
for chunk in self.chunked_completion(request):
if not chunk.choices[0].finish_reason:
output += chunk.choices[0].delta.content

message = AssistantMessage(content=output)
return CompletionResponse(
id=str(uuid.uuid4()),
choices=[
CompletionChoice(
finish_reason="stop",
index=0,
message=message,
)
],
created=int(time.time()),
model=request.model,
system_fingerprint=self.system_fingerprint,
)

def _callback(self, x, *, buffer, done_generating):
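
For illustration, a hedged sketch, not part of this diff, of serializing the non-chunked result of sync_completion into the OpenAI "chat.completion" JSON shape; api and request are assumed as in the sketch above.

import json
from dataclasses import asdict

response = api.sync_completion(request)
# Produces an "object": "chat.completion" body whose single choice carries the
# concatenated streamed deltas as an assistant message.
print(json.dumps(asdict(response), indent=2))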
86 changes: 86 additions & 0 deletions api/models.py
@@ -0,0 +1,86 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os

from dataclasses import dataclass
from pwd import getpwuid
from typing import List, Union

from download import is_model_downloaded, load_model_configs

"""Helper functions for the OpenAI API Models endpoint.

See https://platform.openai.com/docs/api-reference/models for the full specification and details.
Please create an issue if anything doesn't match the specification.
"""


@dataclass
class ModelInfo:
"""The Model object per the OpenAI API specification containing information about a model.

See https://platform.openai.com/docs/api-reference/models/object for more details.
"""

id: str
created: int
owner: str
object: str = "model"


@dataclass
class ModelInfoList:
"""A list of ModelInfo objects."""

data: List[ModelInfo]
object: str = "list"


def retrieve_model_info(args, model_id: str) -> Union[ModelInfo, None]:
"""Implementation of the OpenAI API Retrieve Model endpoint.

See https://platform.openai.com/docs/api-reference/models/retrieve

Inputs:
args: command line arguments
model_id: the id of the model requested

Returns:
ModelInfo describing the specified model if it is downloaded, None otherwise.
"""
if model_config := load_model_configs().get(model_id):
if is_model_downloaded(model_id, args.model_directory):
path = args.model_directory / model_config.name
created = int(os.path.getctime(path))
owner = getpwuid(os.stat(path).st_uid).pw_name

return ModelInfo(id=model_config.name, created=created, owner=owner)
return None
return None


def get_model_info_list(args) -> ModelInfoList:
"""Implementation of the OpenAI API List Models endpoint.

See https://platform.openai.com/docs/api-reference/models/list

Inputs:
args: command line arguments

Returns:
ModelInfoList describing all downloaded models.
"""
data = []
for model_id, model_config in load_model_configs().items():
if is_model_downloaded(model_id, args.model_directory):
path = args.model_directory / model_config.name
created = int(os.path.getctime(path))
owner = getpwuid(os.stat(path).st_uid).pw_name

data.append(ModelInfo(id=model_config.name, created=created, owner=owner))
response = ModelInfoList(data=data)
return response
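
For illustration, a hedged usage sketch, not part of this diff, of turning the list helper above into the JSON body of the OpenAI /v1/models endpoint; args is assumed to carry model_directory, as in the helpers above.

import json
from dataclasses import asdict

def models_endpoint_body(args) -> str:
    # Collect all downloaded models and serialize to {"object": "list", "data": [...]}.
    info_list = get_model_info_list(args)
    return json.dumps(asdict(info_list))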
4 changes: 4 additions & 0 deletions generate.py
@@ -450,11 +450,15 @@ def generate(
sequential_prefill=True,
callback=lambda x: x,
max_seq_length: int,
seed: Optional[int] = None,
**sampling_kwargs,
) -> torch.Tensor:
"""
Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
"""
if seed:
torch.manual_seed(seed)

is_speculative = draft_model is not None
device, dtype = prompt.device, prompt.dtype

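
For illustration, a hedged note, not part of this diff, on why the added seed argument makes generation reproducible: fixing the manual seed makes PyTorch's sampling deterministic, so repeated runs draw the same tokens.

import torch

torch.manual_seed(1234)
a = torch.multinomial(torch.ones(10), 3)
torch.manual_seed(1234)
b = torch.multinomial(torch.ones(10), 3)
assert torch.equal(a, b)  # identical draws from the same seed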