[ BugFix ] Prompt Logprobs Detokenization (vllm-project#6223)
Co-authored-by: Zifei Tong <[email protected]>
Signed-off-by: Alvant <[email protected]>
2 people authored and Alvant committed Oct 26, 2024
1 parent be7ccb6 commit 0977231
Showing 4 changed files with 117 additions and 32 deletions.
5 changes: 4 additions & 1 deletion .buildkite/test-pipeline.yaml
@@ -87,7 +87,10 @@ steps:
 
 - label: Engine Test
   mirror_hardwares: [amd]
-  command: pytest -v -s engine tokenization test_sequence.py test_config.py test_logger.py
+  commands:
+  - pytest -v -s engine test_sequence.py test_config.py test_logger.py
+  # OOM in the CI unless we run this separately
+  - pytest -v -s tokenization
 
 - label: Entrypoints Test
   mirror_hardwares: [amd]
109 changes: 87 additions & 22 deletions tests/tokenization/test_detokenize.py
@@ -1,4 +1,4 @@
-from typing import Dict, List
+from typing import Any, Dict, List, Optional
 
 import pytest
 from transformers import AutoTokenizer
@@ -139,6 +139,15 @@ def create_dummy_logprobs(
     } for token_id in complete_sequence_token_ids]
 
 
+def create_dummy_prompt_logprobs(
+        complete_sequence_token_ids: List[int]
+) -> List[Optional[Dict[int, Any]]]:
+    # logprob for the first prompt token is None.
+    logprobs: List[Optional[Dict[int, Any]]] = [None]
+    logprobs.extend(create_dummy_logprobs(complete_sequence_token_ids)[1:])
+    return logprobs
+
+
 @pytest.mark.parametrize("complete_sequence", TRUTH)
 @pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
 @pytest.mark.parametrize("skip_special_tokens", [True, False])
@@ -177,13 +186,10 @@ def test_decode_sequence_logprobs(complete_sequence: str,
 
 @pytest.mark.parametrize("complete_sequence", TRUTH)
 @pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
-@pytest.mark.parametrize("skip_special_tokens", [True])
-def test_decode_prompt_logprobs(complete_sequence: str,
-                                complete_sequence_token_ids: List[int],
-                                detokenizer: Detokenizer,
-                                skip_special_tokens: bool):
+def test_decode_prompt_logprobs(complete_sequence_token_ids: List[int],
+                                detokenizer: Detokenizer):
     """Verify Detokenizer decodes prompt logprobs correctly."""
-    sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens,
+    sampling_params = SamplingParams(skip_special_tokens=True,
                                      prompt_logprobs=1)
 
     # Run sequentially.
@@ -192,19 +198,78 @@ def test_decode_prompt_logprobs(complete_sequence: str,
                               seqs=[seq],
                               sampling_params=sampling_params,
                               arrival_time=0.0)
-    dummy_logprobs = create_dummy_logprobs(complete_sequence_token_ids)
-    detokenizer.decode_prompt_logprobs_inplace(seq_group, dummy_logprobs)
-    decoded_prompt_logprobs = dummy_logprobs
+    dummy_logprobs = create_dummy_prompt_logprobs(complete_sequence_token_ids)
+    detokenizer.decode_prompt_logprobs_inplace(seq_group,
+                                               dummy_logprobs,
+                                               position_offset=0)
+    # First logprob is None.
+    decoded_prompt_logprobs: List[Dict[int, Any]] = dummy_logprobs[
+        1:]  # type: ignore
 
-    if skip_special_tokens:
-        # Text for logprobs for the chosen token should be the same as the
-        # prompt text. Note that this will only be true if we skip
-        # special tokens.
-        assert complete_sequence == "".join([
-            logprobs[token_id].decoded_token for token_id, logprobs in zip(
-                complete_sequence_token_ids, decoded_prompt_logprobs)
-        ])
-        assert complete_sequence != "".join([
-            logprobs[token_id + 1].decoded_token for token_id, logprobs in zip(
-                complete_sequence_token_ids, decoded_prompt_logprobs)
-        ])
+    # decoded_prompt_logprobs doesn't contain the first token.
+    token_ids = complete_sequence_token_ids
+    tokenzier = detokenizer.get_tokenizer_for_seq(seq)
+    text_full = tokenzier.decode(token_ids, skip_special_tokens=True)
+    text_first = tokenzier.decode(token_ids[0], skip_special_tokens=True)
+    text = text_full[len(text_first):]
+
+    # Text for logprobs for the chosen token should be the same as the
+    # prompt text. Note that the first logprob is None.
+    assert text == "".join([
+        logprobs[token_id].decoded_token
+        for token_id, logprobs in zip(token_ids[1:], decoded_prompt_logprobs)
+    ])
+    assert text != "".join([
+        logprobs[token_id + 1].decoded_token
+        for token_id, logprobs in zip(token_ids[1:], decoded_prompt_logprobs)
+    ])
+
+
+@pytest.mark.parametrize("model", ["facebook/opt-125m"])
+@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 7, 16, -1])
+def test_decode_prompt_logprobs_chunked_prefill(
+        vllm_runner,
+        model,
+        chunked_prefill_token_size: int,
+        example_prompts,
+):
+    max_num_seqs = 256
+    enable_chunked_prefill = False
+    max_num_batched_tokens = None
+    if chunked_prefill_token_size != -1:
+        enable_chunked_prefill = True
+        max_num_seqs = min(chunked_prefill_token_size, max_num_seqs)
+        max_num_batched_tokens = chunked_prefill_token_size
+
+    with vllm_runner(model,
+                     dtype="half",
+                     max_logprobs=5,
+                     gpu_memory_utilization=0.5,
+                     enable_chunked_prefill=enable_chunked_prefill,
+                     max_num_batched_tokens=max_num_batched_tokens,
+                     max_num_seqs=max_num_seqs) as vllm_model:
+
+        vllm_sampling_params = SamplingParams(max_tokens=10,
+                                              logprobs=5,
+                                              prompt_logprobs=5,
+                                              temperature=0.0)
+        vllm_results = vllm_model.model.generate(
+            example_prompts, sampling_params=vllm_sampling_params)
+
+    for idx, result in enumerate(vllm_results):
+        assert result.prompt_logprobs is not None
+        assert result.prompt_logprobs[0] is None
+
+        # Compared detokenized prompts ids to original prompt.
+        generated_string = ""
+        for (prompt_token,
+             prompt_logprobs) in zip(result.prompt_token_ids[1:],
+                                     result.prompt_logprobs[1:]):
+            # prompt_logprobs is a dict of the token_id: logprob
+            # We select the token_id corresponding to the actual prompt
+            # Decoded token in the detokenized string corresponding to this
+            # prompt token.
+            generated_string += prompt_logprobs[prompt_token].decoded_token
+
+        assert generated_string == example_prompts[idx], (
+            "Detokenized prompt logprobs do not match original prompt")
19 changes: 14 additions & 5 deletions vllm/engine/output_processor/single_step.py
@@ -60,14 +60,23 @@ def process_prompt_logprob(self, seq_group: SequenceGroup,
         assert len(outputs) == 1, ("Single step should only has 1 output.")
         output = outputs[0]
         prompt_logprobs = output.prompt_logprobs
+
+        # If this is the first (or only) "chunk" of the prefill, we need
+        # to prepend None to the list of prompt logprobs. The reason for this
+        # is that for N prompt tokens, the Sampler will generate N-1 total
+        # prompt logprobs during prefill since the token at idx 0 will not
+        # have a logprob associated with it.
         if prompt_logprobs is not None:
+            if not seq_group.prompt_logprobs:
+                prompt_logprobs = [None] + prompt_logprobs
+                seq_group.prompt_logprobs = []
+
             if seq_group.sampling_params.detokenize and self.detokenizer:
                 self.detokenizer.decode_prompt_logprobs_inplace(
-                    seq_group, prompt_logprobs)
-            if not seq_group.prompt_logprobs:
-                # The first prompt token's logprob is None because it doesn't
-                # have tokens that are precedent.
-                seq_group.prompt_logprobs = [None]
+                    seq_group,
+                    prompt_logprobs,
+                    position_offset=len(seq_group.prompt_logprobs))
+
             seq_group.prompt_logprobs.extend(prompt_logprobs)
 
     def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
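
The comment block above carries the core reasoning of this change: the Sampler emits N-1 prompt logprobs for N prompt tokens, and with chunked prefill those logprobs arrive across several calls, so None is prepended exactly once and the detokenizer receives the running offset. Below is a self-contained sketch of that accumulation logic, not vLLM code: plain floats stand in for Logprob objects and the helper name is made up.

from typing import Dict, List, Optional

PromptLogprobs = List[Optional[Dict[int, float]]]


def append_prompt_logprob_chunk(accumulated: PromptLogprobs,
                                chunk: List[Dict[int, float]]) -> None:
    chunk_entries: PromptLogprobs = list(chunk)
    if not accumulated:
        # First (or only) prefill chunk: token 0 has no logprob, so prepend
        # None to keep list indices aligned with prompt token positions.
        chunk_entries = [None] + chunk_entries
    # decode_prompt_logprobs_inplace would be called here with
    # position_offset=len(accumulated), i.e. the absolute prompt position
    # of the first entry of this chunk.
    accumulated.extend(chunk_entries)


# A 7-token prompt prefilled in three chunks: the Sampler emits 6 logprob
# dicts (for tokens 1..6), split here as 2 + 3 + 1, with offsets 0, 3, 6.
logprobs: PromptLogprobs = []
append_prompt_logprob_chunk(logprobs, [{11: -0.1}, {12: -0.2}])
append_prompt_logprob_chunk(logprobs, [{13: -0.3}, {14: -0.4}, {15: -0.5}])
append_prompt_logprob_chunk(logprobs, [{16: -0.6}])
assert len(logprobs) == 7 and logprobs[0] is None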
16 changes: 12 additions & 4 deletions vllm/transformers_utils/detokenizer.py
@@ -21,14 +21,17 @@ def get_tokenizer_for_seq(self,
         """Returns the HF tokenizer to use for a given sequence."""
         return self.tokenizer_group.get_lora_tokenizer(sequence.lora_request)
 
-    def decode_prompt_logprobs_inplace(
-            self, seq_group: SequenceGroup,
-            prompt_logprobs: List[Optional[Dict[int, Logprob]]]) -> None:
+    def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup,
+                                       prompt_logprobs: List[Optional[Dict[
+                                           int, Logprob]]],
+                                       position_offset: int) -> None:
         """Decodes the logprobs for the prompt of a sequence group.
 
         Args:
             seq_group: The sequence group to decode.
             prompt_logprobs: The logprobs to decode.
+            position_offset: Offset of the first index of the logprobs
+                relative to the start of the sequence (for chunked prefill).
 
         Returns:
             The prompt logprobs with the decoded tokens.
@@ -47,8 +50,13 @@ def decode_prompt_logprobs_inplace(
         next_iter_tokens: List[str] = []
         prev_tokens = None
 
-        for token_position, prompt_logprobs_for_token in enumerate(
+        for token_position_in_logprob, prompt_logprobs_for_token in enumerate(
                 prompt_logprobs):
+
+            # Absolute token position equals the index in the logprobs
+            # list plus the offset of the entire logprobs list relative
+            # to the start of the sequence.
+            token_position = token_position_in_logprob + position_offset
             if not prompt_logprobs_for_token:
                 continue
             for token_id, sample_logprob in prompt_logprobs_for_token.items():
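
The position_offset bookkeeping in the loop above reduces to simple index arithmetic; a small standalone sketch (illustrative names, not vLLM's) of the mapping from chunk-local logprob indices to absolute prompt positions:

from typing import Dict, List, Optional


def absolute_positions(chunk: List[Optional[Dict[int, float]]],
                       position_offset: int) -> List[Optional[int]]:
    # Mirrors: token_position = token_position_in_logprob + position_offset
    return [None if entry is None else i + position_offset
            for i, entry in enumerate(chunk)]


# Second chunk of a chunked prefill that starts at prompt position 3:
assert absolute_positions([{13: -0.3}, {14: -0.4}, {15: -0.5}],
                          position_offset=3) == [3, 4, 5]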
