Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] Fix min_tokens when eos_token_id is None #4389

Merged
merged 4 commits
Apr 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions tests/samplers/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
def create_sampling_params(min_tokens,
eos_token_id=0,
*,
stop_token_ids: Optional[List[str]] = None,
stop_token_ids: Optional[List[int]] = None,
prompt_logprobs: Optional[int] = None):
sampling_params = SamplingParams(
min_tokens=min_tokens,
Expand All @@ -216,7 +216,7 @@ def create_sampling_params(min_tokens,
# requesting prompt_logprobs changes the structure of `logits`
prompt_logprobs=prompt_logprobs,
)
sampling_params.eos_token_id = eos_token_id
sampling_params.all_stop_token_ids.add(eos_token_id)
return sampling_params

def create_sequence_data(num_input=3, num_generated=0):
Expand Down Expand Up @@ -461,10 +461,7 @@ def run_test_case(*,
for logits_idx, (should_penalize, sampling_params) in enumerate(
zip(expected_penalization, sampling_params_per_row)):

tokens_to_check = [sampling_params.eos_token_id]
if sampling_params.stop_token_ids:
tokens_to_check.extend(sampling_params.stop_token_ids)
tokens_to_check = set(tokens_to_check)
tokens_to_check = sampling_params.all_stop_token_ids

if should_penalize:
for token_id in tokens_to_check:
Expand Down
5 changes: 3 additions & 2 deletions vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,9 +431,10 @@ def add_request(
# Defensive copy of SamplingParams, which are used by the sampler,
# this doesn't deep-copy LogitsProcessor objects
sampling_params = sampling_params.clone()
# inject the eos token id into the sampling_params to support min_tokens
# Add the eos token id into the sampling_params to support min_tokens
# processing
sampling_params.eos_token_id = seq.eos_token_id
if seq.eos_token_id is not None:
sampling_params.all_stop_token_ids.add(seq.eos_token_id)
sampling_params.update_from_generation_config(
self.generation_config_fields)

Expand Down
14 changes: 6 additions & 8 deletions vllm/model_executor/layers/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,19 +169,17 @@ def _apply_min_tokens_penalty(

start_idx = sample_indices[0]
min_tokens = sampling_params.min_tokens
if min_tokens > 0:
token_ids_to_penalize = sampling_params.all_stop_token_ids
if min_tokens > 0 and token_ids_to_penalize:
seqs_to_penalize = []
for i, seq_id in enumerate(seq_ids):
for j, seq_id in enumerate(seq_ids):
seq_data = seq_group.seq_data[seq_id]
if len(seq_data.output_token_ids) < min_tokens:
seqs_to_penalize.append(i)
seqs_to_penalize.append(j)

if seqs_to_penalize:
# convert to the index into logits
seqs_to_penalize = [start_idx + i for i in seqs_to_penalize]
# use set() to remove any duplicates
token_ids_to_penalize = set(sampling_params.stop_token_ids +
[sampling_params.eos_token_id])
seqs_to_penalize = [start_idx + j for j in seqs_to_penalize]
# itertools.product pairs each seq index with every token id
logits_to_penalize.extend(
itertools.product(seqs_to_penalize, token_ids_to_penalize))
Expand Down Expand Up @@ -645,7 +643,7 @@ def _sample(
Returns:
(next_token_ids, parent_seq_ids) for each seq group in a batch.
If sampling is skipped, it returns ([], [])
sampled_token_ids_tensor: A tensor of sampled token ids.
sampled_token_ids_tensor: A tensor of sampled token ids.
"""
return _sample_with_torch(
probs,
Expand Down
4 changes: 2 additions & 2 deletions vllm/sampling_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,8 @@ def __init__(
self.top_k = -1
self.min_p = 0.0
self._verify_greedy_sampling()
# injected by the engine
self.eos_token_id = None
# eos_token_id is added to this by the engine
self.all_stop_token_ids = set(self.stop_token_ids)

def _verify_args(self) -> None:
if self.n < 1:
Expand Down
Loading