Skip to content

Commit

Permalink
Add token_check endpoint to OpenAI API server (#1384)
Browse files Browse the repository at this point in the history
Co-authored-by: Kalila <[email protected]>
  • Loading branch information
digisomni and Kalila authored May 21, 2023
1 parent dc06bea commit 32fc382
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 2 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,6 @@ output
# Data
*.pkl
*.csv

# Build
build
11 changes: 10 additions & 1 deletion fastchat/protocol/openai_api_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class UsageInfo(BaseModel):

class ChatCompletionRequest(BaseModel):
model: str
messages: List[Dict[str, str]]
messages: Union[str, List[Dict[str, str]]]
temperature: Optional[float] = 0.7
top_p: Optional[float] = 1.0
n: Optional[int] = 1
Expand Down Expand Up @@ -100,6 +100,15 @@ class ChatCompletionStreamResponse(BaseModel):
model: str
choices: List[ChatCompletionResponseStreamChoice]

class TokenCheckRequest(BaseModel):
    """Request body for the ``/v1/token_check`` endpoint (not part of the OpenAI API spec)."""

    # Model name; used to resolve the serving worker's address.
    model: str
    # Text whose token count the worker will measure.
    prompt: str
    # Number of completion tokens the caller intends to request; added to the
    # prompt's token count when checking against the model's context length.
    max_tokens: int

class TokenCheckResponse(BaseModel):
    """Response body for ``/v1/token_check``."""

    # True when tokenCount + max_tokens fits within the model's context length.
    fits: bool
    # Token count of the submitted prompt.
    # NOTE(review): camelCase field names are part of the wire format — do not rename.
    tokenCount: int
    # Maximum context length reported by the model worker.
    contextLength: int

class EmbeddingsRequest(BaseModel):
model: Optional[str] = None
Expand Down
37 changes: 36 additions & 1 deletion fastchat/serve/openai_api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@
ModelCard,
ModelList,
ModelPermission,
TokenCheckRequest,
TokenCheckResponse,
UsageInfo,
)

Expand All @@ -67,7 +69,7 @@ class AppSettings(BaseSettings):

def create_error_response(code: int, message: str) -> JSONResponse:
return JSONResponse(
ErrorResponse(message=message, code=code).dict(), status_code=500
ErrorResponse(message=message, code=code).dict(), status_code=400
)


Expand Down Expand Up @@ -280,6 +282,39 @@ async def show_available_models():
return ModelList(data=model_cards)


# TODO: Have check_length and count_tokens share code.
@app.post("/v1/token_check")
async def count_tokens(request: TokenCheckRequest):
    """
    Checks the token count against your message
    This is not part of the OpenAI API spec.
    """
    async with httpx.AsyncClient() as client:
        worker_addr = await _get_worker_address(request.model, client)

        # Ask the worker for its maximum context length.
        details_resp = await client.post(
            worker_addr + "/model_details",
            headers=headers,
            json={},
            timeout=WORKER_API_TIMEOUT,
        )
        context_len = details_resp.json()["context_length"]

        # Have the worker tokenize the prompt and report the count.
        count_resp = await client.post(
            worker_addr + "/count_token",
            headers=headers,
            json={"prompt": request.prompt},
            timeout=WORKER_API_TIMEOUT,
        )
        token_num = count_resp.json()["count"]

        # The request "fits" when prompt tokens plus the requested completion
        # budget stay within the model's context window.
        can_fit = token_num + request.max_tokens <= context_len

        return TokenCheckResponse(
            fits=can_fit, contextLength=context_len, tokenCount=token_num
        )


@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
"""Creates a completion for the chat message"""
Expand Down

0 comments on commit 32fc382

Please sign in to comment.