diff --git a/tests/samplers/test_logprobs.py b/tests/samplers/test_logprobs.py
index 40d054cd472b8..61720cccf50b4 100644
--- a/tests/samplers/test_logprobs.py
+++ b/tests/samplers/test_logprobs.py
@@ -12,6 +12,7 @@
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1])
 @pytest.mark.parametrize("num_top_logprobs", [6])  # 32000 == vocab_size
+@pytest.mark.parametrize("detokenize", [True, False])
 def test_get_prompt_logprobs(
     hf_runner,
     vllm_runner,
@@ -19,6 +20,7 @@ def test_get_prompt_logprobs(
     dtype,
     chunked_prefill_token_size: int,
     num_top_logprobs: int,
+    detokenize: bool,
     example_prompts,
 ):
     max_num_seqs = 256
@@ -48,7 +50,8 @@ def test_get_prompt_logprobs(
     vllm_sampling_params = SamplingParams(max_tokens=max_tokens,
                                           logprobs=num_top_logprobs,
                                           prompt_logprobs=num_top_logprobs,
-                                          temperature=0.0)
+                                          temperature=0.0,
+                                          detokenize=detokenize)
     vllm_results = vllm_model.model.generate(
         example_prompts, sampling_params=vllm_sampling_params)
 
@@ -65,11 +68,16 @@ def test_get_prompt_logprobs(
             top_logprob = next(iter(top_logprobs.values()))
             output_string_from_most_likely_tokens.append(
                 top_logprob.decoded_token)
-        output_string_from_most_likely_tokens = "".join(
-            output_string_from_most_likely_tokens)
-        assert output_text == output_string_from_most_likely_tokens, (
-            "The output text from the top logprob for each token position "
-            "should be the same as the output text in the result.")
+
+        if detokenize:
+            output_string_from_most_likely_tokens = "".join(
+                output_string_from_most_likely_tokens)
+            assert output_text == output_string_from_most_likely_tokens, (
+                "The output text from the top logprob for each token position "
+                "should be the same as the output text in the result.")
+        else:
+            assert output_text == ''
+            assert output_string_from_most_likely_tokens == [None] * max_tokens
 
         # The first prompt logprob is always None
         assert result.prompt_logprobs[0] is None
@@ -98,9 +106,10 @@ def test_get_prompt_logprobs(
                                            hf_logprob[i][-1][token_id].item(),
                                            atol=1e-2,
                                            rtol=1e-2)
-                assert isinstance(sample_logprob.decoded_token, str), (
-                    "The token should be decoded by the time it is returned "
-                    " to the user.")
+                if detokenize:
+                    assert isinstance(sample_logprob.decoded_token, str), (
+                        "The token should be decoded by the time it is returned"
+                        " to the user.")
 
     # Test if prompt logprobs are correctly set.
     for vllm_result in vllm_results:
diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py
index 44de1d7ec5607..cad44f476f06e 100644
--- a/vllm/engine/output_processor/single_step.py
+++ b/vllm/engine/output_processor/single_step.py
@@ -60,10 +60,10 @@ def process_prompt_logprob(self, seq_group: SequenceGroup,
         assert len(outputs) == 1, ("Single step should only has 1 output.")
         output = outputs[0]
         prompt_logprobs = output.prompt_logprobs
-        if (prompt_logprobs is not None
-                and seq_group.sampling_params.detokenize and self.detokenizer):
-            self.detokenizer.decode_prompt_logprobs_inplace(
-                seq_group, prompt_logprobs)
+        if prompt_logprobs is not None:
+            if seq_group.sampling_params.detokenize and self.detokenizer:
+                self.detokenizer.decode_prompt_logprobs_inplace(
+                    seq_group, prompt_logprobs)
             if not seq_group.prompt_logprobs:
                 # The first prompt token's logprob is None because it doesn't
                 # have tokens that are precedent.
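
The snippet below is a minimal usage sketch (not part of the patch) of the behavior the updated test exercises: with `detokenize=False`, logprobs are still returned per token id, but no text or `decoded_token` strings are produced. The model name and prompt are placeholders, not taken from the patch.

```python
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # example model; any supported model works

params = SamplingParams(
    max_tokens=5,
    logprobs=3,
    prompt_logprobs=3,
    temperature=0.0,
    detokenize=False,  # skip detokenization; logprobs keyed by token id only
)

outputs = llm.generate(["The capital of France is"], params)
completion = outputs[0].outputs[0]

# With detokenize=False the engine produces no output text ...
assert completion.text == ""
# ... and each Logprob entry carries a token id and logprob, but no string.
for top_logprobs in completion.logprobs:
    for token_id, logprob in top_logprobs.items():
        assert logprob.decoded_token is None
        print(token_id, logprob.logprob)
```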