Fix prompt len in parallel sampling (#928)
yichuan520030910320 authored Aug 5, 2024
1 parent 399cad9 commit fd7926e
Showing 2 changed files with 11 additions and 15 deletions.
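
For context on the fix: with parallel sampling (request.n > 1), the results list ret carries n entries per prompt, and each entry repeats that prompt's token count in meta_info, so summing prompt_tokens over every item counts each prompt n times (the test FIXME removed below called this out). The change counts each prompt once by stepping through ret in strides of request.n. Below is a minimal sketch of the before/after accounting with made-up numbers, not the adapter code itself:

    # Toy data standing in for the per-sample results of one prompt with n = 3.
    ret = [
        {"meta_info": {"prompt_tokens": 5, "completion_tokens": 7}},
        {"meta_info": {"prompt_tokens": 5, "completion_tokens": 9}},
        {"meta_info": {"prompt_tokens": 5, "completion_tokens": 8}},
    ]
    n = 3  # request.n, parallel samples per prompt

    # Old accounting: every sample re-counts its prompt -> 15 instead of 5.
    old_prompt_tokens = sum(item["meta_info"]["prompt_tokens"] for item in ret)

    # New accounting: count the prompt once per group of n samples.
    prompt_tokens = sum(
        ret[i]["meta_info"]["prompt_tokens"] for i in range(0, len(ret), n)
    )
    completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)

    assert old_prompt_tokens == 15
    assert prompt_tokens == 5 and completion_tokens == 24
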
python/sglang/srt/openai_api/adapter.py (21 changes: 11 additions & 10 deletions)
@@ -500,7 +500,9 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
             responses.append(response)
         return responses
     else:
-        prompt_tokens = sum(item["meta_info"]["prompt_tokens"] for item in ret)
+        prompt_tokens = sum(
+            ret[i]["meta_info"]["prompt_tokens"] for i in range(0, len(ret), request.n)
+        )
         completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
         response = CompletionResponse(
             id=ret[0]["meta_info"]["id"],
@@ -707,8 +709,6 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
 
 def v1_chat_generate_response(request, ret, to_file=False):
     choices = []
-    total_prompt_tokens = 0
-    total_completion_tokens = 0
 
     for idx, ret_item in enumerate(ret):
         logprobs = False
@@ -747,8 +747,6 @@ def v1_chat_generate_response(request, ret, to_file=False):
             choice_logprobs = ChoiceLogprobs(content=token_logprobs)
         else:
             choice_logprobs = None
-        prompt_tokens = ret_item["meta_info"]["prompt_tokens"]
-        completion_tokens = ret_item["meta_info"]["completion_tokens"]
 
         if to_file:
             # to make the choice data json serializable
@@ -767,8 +765,7 @@ def v1_chat_generate_response(request, ret, to_file=False):
             )
 
         choices.append(choice_data)
-        total_prompt_tokens += prompt_tokens
-        total_completion_tokens += completion_tokens
+
     if to_file:
         responses = []
 
@@ -795,14 +792,18 @@ def v1_chat_generate_response(request, ret, to_file=False):
             responses.append(response)
         return responses
     else:
+        prompt_tokens = sum(
+            ret[i]["meta_info"]["prompt_tokens"] for i in range(0, len(ret), request.n)
+        )
+        completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
         response = ChatCompletionResponse(
             id=ret[0]["meta_info"]["id"],
             model=request.model,
             choices=choices,
             usage=UsageInfo(
-                prompt_tokens=total_prompt_tokens,
-                completion_tokens=total_completion_tokens,
-                total_tokens=total_prompt_tokens + total_completion_tokens,
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=prompt_tokens + completion_tokens,
             ),
         )
         return response
test/srt/test_openai_server.py (5 changes: 0 additions & 5 deletions)
@@ -45,11 +45,6 @@ def run_completion(
             prompt_arg = prompt_input
             num_choices = 1
 
-        if parallel_sample_num:
-            # FIXME: This is wrong. We should not count the prompt tokens multiple times for
-            # parallel sampling.
-            num_prompt_tokens *= parallel_sample_num
-
         response = client.completions.create(
             model=self.model,
             prompt=prompt_arg,
