Add support for logprobs in OpenAI chat API #852

Merged
examples/usage/openai_parallel_sample.py (28 additions, 3 deletions)
@@ -106,12 +106,24 @@
         {"role": "user", "content": "List 3 countries and their capitals."},
     ],
     temperature=0.8,
-    max_tokens=64,
+    max_tokens=1,
+    logprobs=True,
     n=1,
+    top_logprobs=3,
 )
 print(response)
 
+# Chat completion
+response = client.chat.completions.create(
+    model="default",
+    messages=[
+        {"role": "system", "content": "You are a helpful AI assistant"},
+        {"role": "user", "content": "List 3 countries and their capitals."},
+    ],
+    temperature=0.8,
+    max_tokens=1,
+    n=1,
+)
+print(response)
+
 # Chat completion
 response = client.chat.completions.create(
@@ -121,8 +133,21 @@
         {"role": "user", "content": "List 3 countries and their capitals."},
     ],
     temperature=0.8,
-    max_tokens=64,
+    max_tokens=1,
+    logprobs=True,
+    top_logprobs=3,
 )
 print(response)
 
+# Chat completion
+response = client.chat.completions.create(
+    model="default",
+    messages=[
+        {"role": "system", "content": "You are a helpful AI assistant"},
+        {"role": "user", "content": "List 3 countries and their capitals."},
+    ],
+    temperature=0.8,
+    max_tokens=1,
+    n=4,
+)
+print(response)
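
For reference, a minimal sketch of how a client could read the logprobs that these example requests now ask for, once the server populates them. The base URL and api_key below are placeholders for a local SGLang deployment, and the field layout follows the ChoiceLogprobs and ChatCompletionTokenLogprob models added in protocol.py further down:

# Sketch only: inspect returned chat logprobs (assumes a local SGLang server).
import openai

client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="default",
    messages=[
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=0.8,
    max_tokens=1,
    logprobs=True,
    top_logprobs=3,
)

# choices[0].logprobs.content holds one entry per generated token.
for token_info in response.choices[0].logprobs.content:
    print(token_info.token, token_info.logprob)
    for candidate in token_info.top_logprobs:
        print("  candidate:", candidate.token, candidate.logprob)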
python/sglang/srt/openai_api/adapter.py (68 additions, 15 deletions)
@@ -43,7 +43,9 @@
     ChatCompletionResponseChoice,
     ChatCompletionResponseStreamChoice,
     ChatCompletionStreamResponse,
+    ChatCompletionTokenLogprob,
     ChatMessage,
+    ChoiceLogprobs,
     CompletionRequest,
     CompletionResponse,
     CompletionResponseChoice,
@@ -54,6 +56,7 @@
     FileRequest,
     FileResponse,
     LogProbs,
+    TopLogprob,
     UsageInfo,
 )
 
@@ -70,7 +73,7 @@ def __init__(self, filename: str, purpose: str):
 batch_storage: Dict[str, BatchResponse] = {}
 file_id_request: Dict[str, FileMetadata] = {}
 file_id_response: Dict[str, FileResponse] = {}
-## map file id to file path in SGlang backend
+# map file id to file path in SGlang backend
 file_id_storage: Dict[str, str] = {}
 
 
@@ -261,7 +264,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
             failed_requests += len(file_request_list)
 
         for idx, response in enumerate(responses):
-            ## the batch_req here can be changed to be named within a batch granularity
+            # the batch_req here can be changed to be named within a batch granularity
             response_json = {
                 "id": f"batch_req_{uuid.uuid4()}",
                 "custom_id": file_request_list[idx].get("custom_id"),
@@ -333,13 +336,19 @@ def v1_generate_request(all_requests):
 
     prompts = []
     sampling_params_list = []
+    return_logprobs = []
+    top_logprobs_nums = []
     first_prompt_type = type(all_requests[0].prompt)
     for request in all_requests:
         prompt = request.prompt
         assert (
             type(prompt) == first_prompt_type
         ), "All prompts must be of the same type in file input settings"
         prompts.append(prompt)
+        return_logprobs.append(request.logprobs is not None and request.logprobs > 0)
+        top_logprobs_nums.append(
+            request.logprobs if request.logprobs is not None else 0
+        )
         sampling_params_list.append(
             {
                 "temperature": request.temperature,
@@ -361,6 +370,8 @@ def v1_generate_request(all_requests):
     if len(all_requests) == 1:
         prompt = prompts[0]
         sampling_params_list = sampling_params_list[0]
+        return_logprobs = return_logprobs[0]
+        top_logprobs_nums = top_logprobs_nums[0]
         if isinstance(prompt, str) or isinstance(prompt[0], str):
             prompt_kwargs = {"text": prompt}
         else:
@@ -370,15 +381,11 @@ def v1_generate_request(all_requests):
             prompt_kwargs = {"text": prompts}
         else:
             prompt_kwargs = {"input_ids": prompts}
-
     adapted_request = GenerateReqInput(
         **prompt_kwargs,
         sampling_params=sampling_params_list,
-        return_logprob=all_requests[0].logprobs is not None
-        and all_requests[0].logprobs > 0,
-        top_logprobs_num=(
-            all_requests[0].logprobs if all_requests[0].logprobs is not None else 0
-        ),
+        return_logprob=return_logprobs,
+        top_logprobs_num=top_logprobs_nums,
         return_text_in_logprobs=True,
         stream=all_requests[0].stream,
     )
@@ -430,7 +437,7 @@ def v1_generate_response(request, ret, to_file=False):
             logprobs = None
 
         if to_file:
-            ## to make the choise data json serializable
+            # to make the choise data json serializable
             choice_data = {
                 "index": 0,
                 "text": text,
@@ -454,7 +461,7 @@ def v1_generate_response(request, ret, to_file=False):
                 "status_code": 200,
                 "request_id": ret[i]["meta_info"]["id"],
                 "body": {
-                    ## remain the same but if needed we can change that
+                    # remain the same but if needed we can change that
                     "id": ret[i]["meta_info"]["id"],
                     "object": "text_completion",
                     "created": int(time.time()),
@@ -590,6 +597,8 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
     texts = []
     sampling_params_list = []
     image_data_list = []
+    return_logprobs = []
+    top_logprobs_nums = []
     for request in all_requests:
         # Prep the data needed for the underlying GenerateReqInput:
         # - prompt: The full prompt string.
@@ -620,6 +629,8 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
             stop = request.stop
             image_data = None
         texts.append(prompt)
+        return_logprobs.append(request.logprobs)
+        top_logprobs_nums.append(request.top_logprobs)
         sampling_params_list.append(
             {
                 "temperature": request.temperature,
@@ -637,11 +648,16 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
         texts = texts[0]
         sampling_params_list = sampling_params_list[0]
         image_data = image_data_list[0]
+        return_logprobs = return_logprobs[0]
+        top_logprobs_nums = top_logprobs_nums[0]
     adapted_request = GenerateReqInput(
         text=texts,
         image_data=image_data,
         sampling_params=sampling_params_list,
-        stream=request.stream,
+        return_logprob=return_logprobs,
+        top_logprobs_num=top_logprobs_nums,
+        stream=all_requests[0].stream,
+        return_text_in_logprobs=True,
     )
     if len(all_requests) == 1:
         return adapted_request, all_requests[0]
@@ -654,26 +670,63 @@ def v1_chat_generate_response(request, ret, to_file=False):
     total_completion_tokens = 0
 
     for idx, ret_item in enumerate(ret):
+        logprobs = False
+        if isinstance(request, List) and request[idx].logprobs:
+            logprobs = True
+        elif (not isinstance(request, List)) and request.logprobs:
+            logprobs = True
+        if logprobs:
+            logprobs = to_openai_style_logprobs(
+                output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"],
+                output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
+            )
+            token_logprobs = []
+            for token, logprob in zip(logprobs.tokens, logprobs.token_logprobs):
+                token_bytes = list(token.encode("utf-8"))
+                top_logprobs = []
+                if logprobs.top_logprobs:
+                    for top_token, top_logprob in logprobs.top_logprobs[0].items():
+                        top_token_bytes = list(top_token.encode("utf-8"))
+                        top_logprobs.append(
+                            TopLogprob(
+                                token=top_token,
+                                bytes=top_token_bytes,
+                                logprob=top_logprob,
+                            )
+                        )
+                token_logprobs.append(
+                    ChatCompletionTokenLogprob(
+                        token=token,
+                        bytes=token_bytes,
+                        logprob=logprob,
+                        top_logprobs=top_logprobs,
+                    )
+                )
+
+            choice_logprobs = ChoiceLogprobs(content=token_logprobs)
+        else:
+            choice_logprobs = None
         prompt_tokens = ret_item["meta_info"]["prompt_tokens"]
         completion_tokens = ret_item["meta_info"]["completion_tokens"]
 
         if to_file:
-            ## to make the choice data json serializable
+            # to make the choice data json serializable
             choice_data = {
                 "index": 0,
                 "message": {"role": "assistant", "content": ret_item["text"]},
-                "logprobs": None,
+                "logprobs": choice_logprobs,
                 "finish_reason": ret_item["meta_info"]["finish_reason"],
             }
         else:
             choice_data = ChatCompletionResponseChoice(
                 index=idx,
                 message=ChatMessage(role="assistant", content=ret_item["text"]),
+                logprobs=choice_logprobs,
                 finish_reason=ret_item["meta_info"]["finish_reason"],
             )
 
         choices.append(choice_data)
-        total_prompt_tokens = prompt_tokens
+        total_prompt_tokens += prompt_tokens
         total_completion_tokens += completion_tokens
     if to_file:
         responses = []
@@ -683,7 +736,7 @@ def v1_chat_generate_response(request, ret, to_file=False):
                 "status_code": 200,
                 "request_id": ret[i]["meta_info"]["id"],
                 "body": {
-                    ## remain the same but if needed we can change that
+                    # remain the same but if needed we can change that
                     "id": ret[i]["meta_info"]["id"],
                     "object": "chat.completion",
                     "created": int(time.time()),
python/sglang/srt/openai_api/protocol.py (20 additions, 2 deletions)
@@ -54,6 +54,24 @@ class LogProbs(BaseModel):
     top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
 
 
+class TopLogprob(BaseModel):
+    token: str
+    bytes: List[int]
+    logprob: float
+
+
+class ChatCompletionTokenLogprob(BaseModel):
+    token: str
+    bytes: List[int]
+    logprob: float
+    top_logprobs: List[TopLogprob]
+
+
+class ChoiceLogprobs(BaseModel):
+    # build for v1/chat/completions response
+    content: List[ChatCompletionTokenLogprob]
+
+
 class UsageInfo(BaseModel):
     prompt_tokens: int = 0
     total_tokens: int = 0
@@ -239,8 +257,8 @@ class ChatMessage(BaseModel):
 class ChatCompletionResponseChoice(BaseModel):
     index: int
     message: ChatMessage
-    logprobs: Optional[LogProbs] = None
-    finish_reason: Optional[str] = None
+    logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
+    finish_reason: str
 
 
 class ChatCompletionResponse(BaseModel):
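
As a quick sanity check, the new models can be constructed by hand; the sketch below could serve as the seed of a unit test. The values are arbitrary, and the import path assumes the sglang package layout shown above:

# Sketch: build the nested logprobs objects directly and print them.
from sglang.srt.openai_api.protocol import (
    ChatCompletionTokenLogprob,
    ChoiceLogprobs,
    TopLogprob,
)

alternatives = [
    TopLogprob(token="Paris", bytes=list("Paris".encode("utf-8")), logprob=-0.2),
    TopLogprob(token="Lyon", bytes=list("Lyon".encode("utf-8")), logprob=-2.9),
]
token_entry = ChatCompletionTokenLogprob(
    token="Paris",
    bytes=list("Paris".encode("utf-8")),
    logprob=-0.2,
    top_logprobs=alternatives,
)
choice_logprobs = ChoiceLogprobs(content=[token_entry])
print(choice_logprobs)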