-
Notifications
You must be signed in to change notification settings - Fork 230
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add warning comments referring to unimplemented functionality * JSON formatted response using OpenAI API types for server completion requests * Add models endpoint (#1000)
- Loading branch information
Showing
4 changed files
with
242 additions
and
74 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
|
||
# This source code is licensed under the license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import os | ||
|
||
from dataclasses import dataclass | ||
from pwd import getpwuid | ||
from typing import List, Union | ||
|
||
from download import is_model_downloaded, load_model_configs | ||
|
||
"""Helper functions for the OpenAI API Models endpoint. | ||
See https://platform.openai.com/docs/api-reference/models for the full specification and details. | ||
Please create an issue if anything doesn't match the specification. | ||
""" | ||
|
||
|
||
@dataclass | ||
class ModelInfo: | ||
"""The Model object per the OpenAI API specification containing information about a model. | ||
See https://platform.openai.com/docs/api-reference/models/object for more details. | ||
""" | ||
|
||
id: str | ||
created: int | ||
owner: str | ||
object: str = "model" | ||
|
||
|
||
@dataclass | ||
class ModelInfoList: | ||
"""A list of ModelInfo objects.""" | ||
|
||
data: List[ModelInfo] | ||
object: str = "list" | ||
|
||
|
||
def retrieve_model_info(args, model_id: str) -> Union[ModelInfo, None]: | ||
"""Implementation of the OpenAI API Retrieve Model endpoint. | ||
See https://platform.openai.com/docs/api-reference/models/retrieve | ||
Inputs: | ||
args: command line arguments | ||
model_id: the id of the model requested | ||
Returns: | ||
ModelInfo describing the specified if it is downloaded, None otherwise. | ||
""" | ||
if model_config := load_model_configs().get(model_id): | ||
if is_model_downloaded(model_id, args.model_directory): | ||
path = args.model_directory / model_config.name | ||
created = int(os.path.getctime(path)) | ||
owner = getpwuid(os.stat(path).st_uid).pw_name | ||
|
||
return ModelInfo(id=model_config.name, created=created, owner=owner) | ||
return None | ||
return None | ||
|
||
|
||
def get_model_info_list(args) -> ModelInfo: | ||
"""Implementation of the OpenAI API List Models endpoint. | ||
See https://platform.openai.com/docs/api-reference/models/list | ||
Inputs: | ||
args: command line arguments | ||
Returns: | ||
ModelInfoList describing all downloaded models. | ||
""" | ||
data = [] | ||
for model_id, model_config in load_model_configs().items(): | ||
if is_model_downloaded(model_id, args.model_directory): | ||
path = args.model_directory / model_config.name | ||
created = int(os.path.getctime(path)) | ||
owner = getpwuid(os.stat(path).st_uid).pw_name | ||
|
||
data.append(ModelInfo(id=model_config.name, created=created, owner=owner)) | ||
response = ModelInfoList(data=data) | ||
return response |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.