[ BugFix ] Prompt Logprobs Detokenization #6223

Merged · 22 commits · Jul 11, 2024
Changes from 12 commits
77 changes: 68 additions & 9 deletions tests/tokenization/test_detokenize.py
@@ -139,6 +139,12 @@ def create_dummy_logprobs(
} for token_id in complete_sequence_token_ids]


def create_dummy_prompt_logprobs(
complete_sequence_token_ids: List[int]) -> List[Dict[int, Logprob]]:
# logprob for the first prompt token is not defined.
return create_dummy_logprobs(complete_sequence_token_ids)[1:]


@pytest.mark.parametrize("complete_sequence", TRUTH)
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", [True, False])
@@ -178,8 +184,7 @@ def test_decode_sequence_logprobs(complete_sequence: str,
@pytest.mark.parametrize("complete_sequence", TRUTH)
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", [True])
-def test_decode_prompt_logprobs(complete_sequence: str,
-                                complete_sequence_token_ids: List[int],
+def test_decode_prompt_logprobs(complete_sequence_token_ids: List[int],
detokenizer: Detokenizer,
skip_special_tokens: bool):
"""Verify Detokenizer decodes prompt logprobs correctly."""
@@ -192,19 +197,73 @@ def test_decode_prompt_logprobs(complete_sequence: str,
seqs=[seq],
sampling_params=sampling_params,
arrival_time=0.0)
-    dummy_logprobs = create_dummy_logprobs(complete_sequence_token_ids)
+    dummy_logprobs = create_dummy_prompt_logprobs(complete_sequence_token_ids)
detokenizer.decode_prompt_logprobs_inplace(seq_group, dummy_logprobs)
decoded_prompt_logprobs = dummy_logprobs

if skip_special_tokens:
# decoded_prompt_logprobs doesn't contain the first token.
token_ids = complete_sequence_token_ids[1:]
        tokenizer = detokenizer.get_tokenizer_for_seq(seq)
        text = tokenizer.decode(token_ids,
skip_special_tokens=skip_special_tokens)
# Text for logprobs for the chosen token should be the same as the
# prompt text. Note that this will only be true if we skip
# special tokens.
-        assert complete_sequence == "".join([
-            logprobs[token_id].decoded_token for token_id, logprobs in zip(
-                complete_sequence_token_ids, decoded_prompt_logprobs)
+        assert text == "".join([
+            logprobs[token_id].decoded_token
+            for token_id, logprobs in zip(token_ids, decoded_prompt_logprobs)
])
-        assert complete_sequence != "".join([
-            logprobs[token_id + 1].decoded_token for token_id, logprobs in zip(
-                complete_sequence_token_ids, decoded_prompt_logprobs)
+        assert text != "".join([
+            logprobs[token_id + 1].decoded_token
+            for token_id, logprobs in zip(token_ids, decoded_prompt_logprobs)
])


@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 7, 16, -1])
def test_decode_logprobs_regression(
[Inline review comment — Collaborator] Suggested change:
-def test_decode_logprobs_regression(
+def test_decode_logprobs_chunked_prefill(
vllm_runner,
model,
chunked_prefill_token_size: int,
example_prompts,
):
max_num_seqs = 256
enable_chunked_prefill = False
max_num_batched_tokens = None
if chunked_prefill_token_size != -1:
enable_chunked_prefill = True
max_num_seqs = min(chunked_prefill_token_size, max_num_seqs)
max_num_batched_tokens = chunked_prefill_token_size

with vllm_runner(model,
dtype="half",
max_logprobs=5,
enable_chunked_prefill=enable_chunked_prefill,
max_num_batched_tokens=max_num_batched_tokens,
max_num_seqs=max_num_seqs) as vllm_model:

vllm_sampling_params = SamplingParams(max_tokens=10,
logprobs=5,
prompt_logprobs=5,
temperature=0.0)
vllm_results = vllm_model.model.generate(
example_prompts, sampling_params=vllm_sampling_params)

for idx, result in enumerate(vllm_results):
assert result.prompt_logprobs is not None
assert result.prompt_logprobs[0] is None

        # Compare detokenized prompt token ids to the original prompt.
generated_string = ""
for (prompt_token,
prompt_logprobs) in zip(result.prompt_token_ids[1:],
result.prompt_logprobs[1:]):
            # prompt_logprobs is a dict mapping token_id to Logprob. Select
            # the entry for the actual prompt token and append its decoded
            # text to rebuild the detokenized prompt string.
generated_string += prompt_logprobs[prompt_token].decoded_token

assert generated_string == example_prompts[idx], (
"Detokenized prompt logprobs do not match original prompt")
19 changes: 14 additions & 5 deletions vllm/engine/output_processor/single_step.py
@@ -60,14 +60,23 @@ def process_prompt_logprob(self, seq_group: SequenceGroup,
        assert len(outputs) == 1, ("Single step should only have 1 output.")
output = outputs[0]
prompt_logprobs = output.prompt_logprobs

# If this is the first (or only) "chunk" of the prefill, we need
# to prepend None to the list of prompt logprobs. The reason for this
# is that for N prompt tokens, the Sampler will generate N-1 total
# prompt logprobs during prefill since the token at idx 0 will not
# have a logprob associated with it.
if prompt_logprobs is not None:
if not seq_group.prompt_logprobs:
prompt_logprobs = [None] + prompt_logprobs
seq_group.prompt_logprobs = []

if seq_group.sampling_params.detokenize and self.detokenizer:
self.detokenizer.decode_prompt_logprobs_inplace(
-                seq_group, prompt_logprobs)
-            if not seq_group.prompt_logprobs:
-                # The first prompt token's logprob is None because it doesn't
-                # have tokens that are precedent.
-                seq_group.prompt_logprobs = [None]
+                    seq_group,
+                    prompt_logprobs,
+                    position_offset=len(seq_group.prompt_logprobs))

seq_group.prompt_logprobs.extend(prompt_logprobs)

def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
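The chunked-prefill handling above is easier to see in isolation. The sketch below is illustrative only: plain lists stand in for the `SequenceGroup` state and the per-chunk Sampler output, and the helper name is hypothetical. It shows the two invariants the new code relies on: the first chunk gets a leading `None` for prompt token 0, and each chunk is detokenized with `position_offset` equal to the number of logprobs already accumulated.

```python
# Illustrative sketch only: how per-chunk prompt logprobs are accumulated.
# Plain dicts/lists stand in for the Sampler output of each prefill chunk;
# the helper name is hypothetical.
from typing import Dict, List, Optional

Chunk = List[Dict[int, float]]


def accumulate_prompt_logprobs(
        chunks: List[Chunk]) -> List[Optional[Dict[int, float]]]:
    accumulated: List[Optional[Dict[int, float]]] = []
    for chunk in chunks:
        entries: List[Optional[Dict[int, float]]] = list(chunk)
        if not accumulated:
            # First chunk of the prefill: prompt token 0 has no logprob, so
            # the Sampler emits N-1 entries for an N-token prompt and we
            # prepend None, as process_prompt_logprob does above.
            entries = [None] + entries
        # The real code detokenizes here with
        # position_offset=len(accumulated), i.e. the number of prompt
        # positions already covered by earlier chunks.
        accumulated.extend(entries)
    return accumulated


# A 7-token prompt prefilled in chunks of 3 and 4 tokens: the Sampler emits
# 2 + 4 = 6 logprob dicts, and accumulation yields 7 entries, one per token.
first_chunk = [{11: -0.1}, {12: -0.2}]
second_chunk = [{13: -0.3}, {14: -0.4}, {15: -0.5}, {16: -0.6}]
merged = accumulate_prompt_logprobs([first_chunk, second_chunk])
assert merged[0] is None and len(merged) == 7
```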
16 changes: 12 additions & 4 deletions vllm/transformers_utils/detokenizer.py
@@ -21,14 +21,17 @@ def get_tokenizer_for_seq(self,
"""Returns the HF tokenizer to use for a given sequence."""
return self.tokenizer_group.get_lora_tokenizer(sequence.lora_request)

-    def decode_prompt_logprobs_inplace(
-            self, seq_group: SequenceGroup,
-            prompt_logprobs: List[Optional[Dict[int, Logprob]]]) -> None:
+    def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup,
+                                       prompt_logprobs: List[Optional[Dict[
+                                           int, Logprob]]],
+                                       position_offset: int) -> None:
"""Decodes the logprobs for the prompt of a sequence group.

Args:
seq_group: The sequence group to decode.
prompt_logprobs: The logprobs to decode.
position_offset: Offset of the first index of the logprobs
relative to the start of the sequence (for chunked prefill).

Returns:
The prompt logprobs with the decoded tokens.
@@ -47,8 +50,13 @@ def decode_prompt_logprobs_inplace(
next_iter_tokens: List[str] = []
prev_tokens = None

-        for token_position, prompt_logprobs_for_token in enumerate(
+        for token_position_in_logprob, prompt_logprobs_for_token in enumerate(
prompt_logprobs):

# Absolute token position equals the index in the logprobs
# list plus the offset of the entire logprobs list relative
# to the start of the sequence.
token_position = token_position_in_logprob + position_offset
if not prompt_logprobs_for_token:
continue
for token_id, sample_logprob in prompt_logprobs_for_token.items():
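The effect of `position_offset` can be shown with a small self-contained example (not part of the PR; the token ids and chunk sizes are made up): the absolute position of each logprob is its chunk-local index plus the offset, and the prompt tokens before that absolute position form the prefix used for incremental detokenization.

```python
# Illustrative sketch only (not the vLLM implementation): mapping chunk-local
# logprob indices to absolute prompt positions via position_offset.
prompt_token_ids = [2, 100, 101, 102, 103, 104, 105]  # hypothetical 7-token prompt

# The second prefill chunk covers prompt positions 3..6, so the engine passes
# position_offset=3 (one leading None plus two logprobs already accumulated).
position_offset = 3
chunk_logprobs = [{102: -0.3}, {103: -0.4}, {104: -0.5}, {105: -0.6}]

for local_idx, logprobs_for_token in enumerate(chunk_logprobs):
    token_position = local_idx + position_offset
    # Incremental detokenization needs the tokens before this absolute
    # position as its prefix; without the offset, every chunk after the
    # first would decode against the wrong prefix.
    prefix_token_ids = prompt_token_ids[:token_position]
    print(token_position, prefix_token_ids)
```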