[Bugfix] Fix prompt_logprobs when SamplingParams.detokenize is set to False #5226

Merged (2 commits) on Jun 5, 2024
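In practical terms, the fix makes prompt logprobs usable when detokenization is turned off: with `detokenize=False`, requested prompt logprobs are now returned keyed by token id, with no decoded strings attached. A minimal sketch of that scenario (not part of this PR), assuming the offline `LLM` entry point and using `facebook/opt-125m` purely as a placeholder model:

```python
# Sketch of the scenario this PR fixes; the model name is a placeholder.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
params = SamplingParams(
    max_tokens=5,
    logprobs=6,
    prompt_logprobs=6,
    temperature=0.0,
    detokenize=False,  # skip detokenization entirely
)

result = llm.generate(["The quick brown fox jumps"], params)[0]

# The first prompt token has no preceding tokens, so its entry is None.
assert result.prompt_logprobs[0] is None
# With the fix, the remaining entries are populated even though nothing
# was detokenized; the generated text itself stays empty in this mode.
assert all(entry is not None for entry in result.prompt_logprobs[1:])
assert result.outputs[0].text == ""
```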
tests/samplers/test_logprobs.py (27 changes: 18 additions, 9 deletions)
@@ -12,13 +12,15 @@
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1])
 @pytest.mark.parametrize("num_top_logprobs", [6])  # 32000 == vocab_size
+@pytest.mark.parametrize("detokenize", [True, False])
 def test_get_prompt_logprobs(
     hf_runner,
     vllm_runner,
     model,
     dtype,
     chunked_prefill_token_size: int,
     num_top_logprobs: int,
+    detokenize: bool,
     example_prompts,
 ):
     max_num_seqs = 256
@@ -48,7 +50,8 @@ def test_get_prompt_logprobs(
     vllm_sampling_params = SamplingParams(max_tokens=max_tokens,
                                           logprobs=num_top_logprobs,
                                           prompt_logprobs=num_top_logprobs,
-                                          temperature=0.0)
+                                          temperature=0.0,
+                                          detokenize=detokenize)
     vllm_results = vllm_model.model.generate(
         example_prompts, sampling_params=vllm_sampling_params)
@@ -65,11 +68,16 @@
             top_logprob = next(iter(top_logprobs.values()))
             output_string_from_most_likely_tokens.append(
                 top_logprob.decoded_token)
-        output_string_from_most_likely_tokens = "".join(
-            output_string_from_most_likely_tokens)
-        assert output_text == output_string_from_most_likely_tokens, (
-            "The output text from the top logprob for each token position "
-            "should be the same as the output text in the result.")
+
+        if detokenize:
+            output_string_from_most_likely_tokens = "".join(
+                output_string_from_most_likely_tokens)
+            assert output_text == output_string_from_most_likely_tokens, (
+                "The output text from the top logprob for each token position "
+                "should be the same as the output text in the result.")
+        else:
+            assert output_text == ''
+            assert output_string_from_most_likely_tokens == [None] * max_tokens

         # The first prompt logprob is always None
         assert result.prompt_logprobs[0] is None
@@ -98,9 +106,10 @@ def test_get_prompt_logprobs(
                 hf_logprob[i][-1][token_id].item(),
                 atol=1e-2,
                 rtol=1e-2)
-            assert isinstance(sample_logprob.decoded_token, str), (
-                "The token should be decoded by the time it is returned "
-                " to the user.")
+            if detokenize:
+                assert isinstance(sample_logprob.decoded_token, str), (
+                    "The token should be decoded by the time it is returned"
+                    " to the user.")

     # Test if prompt logprobs are correctly set.
     for vllm_result in vllm_results:
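The new `detokenize=False` branch nails down the expected behaviour: the generated text is empty, every `decoded_token` is `None`, and only the numeric logprobs are compared against the HuggingFace reference. Callers who still want strings in that mode can decode the token-id keys themselves; a sketch continuing from the snippet above, under the assumptions that `llm.get_tokenizer()` hands back the underlying HuggingFace tokenizer and that each prompt-logprob entry maps token ids to objects exposing a `logprob` field:

```python
# Sketch: turn token-id-keyed prompt logprobs back into text yourself.
tokenizer = llm.get_tokenizer()  # assumed to return a HuggingFace tokenizer

for position, entry in enumerate(result.prompt_logprobs):
    if entry is None:
        continue  # the first prompt token has no logprob entry
    for token_id, logprob in entry.items():
        piece = tokenizer.decode([token_id])
        print(position, token_id, round(logprob.logprob, 3), repr(piece))
```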
vllm/engine/output_processor/single_step.py (8 changes: 4 additions, 4 deletions)
@@ -60,10 +60,10 @@ def process_prompt_logprob(self, seq_group: SequenceGroup,
         assert len(outputs) == 1, ("Single step should only has 1 output.")
         output = outputs[0]
         prompt_logprobs = output.prompt_logprobs
-        if (prompt_logprobs is not None
-                and seq_group.sampling_params.detokenize and self.detokenizer):
-            self.detokenizer.decode_prompt_logprobs_inplace(
-                seq_group, prompt_logprobs)
+        if prompt_logprobs is not None:
+            if seq_group.sampling_params.detokenize and self.detokenizer:
+                self.detokenizer.decode_prompt_logprobs_inplace(
+                    seq_group, prompt_logprobs)
             if not seq_group.prompt_logprobs:
                 # The first prompt token's logprob is None because it doesn't
                 # have tokens that are precedent.
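The root cause shows up in the condition being rewritten: previously, attaching the prompt logprobs to the sequence group sat inside one combined `if` that also required detokenization to be enabled, so a request with `detokenize=False` never got its prompt logprobs back. The new shape gates only the decode call on `detokenize`. A self-contained toy (hypothetical names, not vLLM's real classes) that mimics the two control flows:

```python
# Toy contrast of the old and new control flow; all names here are made up.
from typing import Dict, List, Optional

PromptLogprobs = List[Optional[Dict[int, float]]]


def old_flow(prompt_logprobs: Optional[List[Dict[int, float]]],
             detokenize: bool) -> PromptLogprobs:
    attached: PromptLogprobs = []
    # Old: one condition guarded both detokenization and the bookkeeping.
    if prompt_logprobs is not None and detokenize:
        # (decoding would happen here)
        attached = [None, *prompt_logprobs]
    return attached


def new_flow(prompt_logprobs: Optional[List[Dict[int, float]]],
             detokenize: bool) -> PromptLogprobs:
    attached: PromptLogprobs = []
    # New: only the decode step depends on detokenize.
    if prompt_logprobs is not None:
        if detokenize:
            pass  # (decoding would happen here)
        attached = [None, *prompt_logprobs]
    return attached


fake_logprobs = [{101: -0.3}, {102: -1.2}]  # made-up token ids and values
print(old_flow(fake_logprobs, detokenize=False))  # [] -> logprobs silently lost
print(new_flow(fake_logprobs, detokenize=False))  # [None, {101: -0.3}, {102: -1.2}]
```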