Fix ExponentialDecayLengthPenalty negative logits issue (#25594)
* Fix issues in test_exponential_decay_length_penalty

Fix tests which were broken and add validation of negative scores.

The current test didn't take into account that ExponentialDecayLengthPenalty updates the scores in place, so the assertions compared against an already-modified base tensor.

In addition, the gt assert compared empty tensors, because its slice indexed along the batch dimension rather than the vocabulary dimension. Both pitfalls are illustrated in the sketch below.

The test is currently expected to fail, demonstrating the ExponentialDecayLengthPenalty issue with negative scores.
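A minimal sketch (hypothetical tensors, not the test's actual data) of why both asserts were vacuous:

```python
import torch

# Pitfall 1: an in-place processor mutates `scores` and returns the same tensor,
# so "before" and "after" are one object and any comparison between them is vacuous.
scores = torch.zeros(2, 4)
out = scores          # what an in-place processor effectively hands back
out[:, 0] += 1.0
print(scores[0, 0])   # tensor(1.) -- the "original" changed too

# Pitfall 2: slicing the batch dimension (size 2) from index 6 yields an empty
# tensor, and .all() over an empty tensor is vacuously True.
empty = out[6:, 0]
print(empty.numel())                 # 0
print(torch.gt(empty, empty).all())  # tensor(True)
```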

* Fix ExponentialDecayLengthPenalty negative logits issue

In cases where the scores are negative, ExponentialDecayLengthPenalty decreases the score of eos_token_id instead of increasing it.
To fix this, we compute the penalty on the absolute value of the score and add it to the original score; see the sketch below.
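A minimal sketch of the arithmetic, using a hypothetical EOS logit of -2.0 penalized 5 steps after regulation starts with factor 1.6:

```python
import torch

eos_score = torch.tensor([-2.0])  # hypothetical negative EOS logit
factor, steps = 1.6, 5

old = eos_score * factor**steps                               # -20.97: EOS pushed further down
new = eos_score + torch.abs(eos_score) * (factor**steps - 1)  # 16.97: EOS pushed up, as intended
print(old.item(), new.item())
```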

* Add examples for ExponentialDecayLengthPenalty

* Fix styling issue in ExponentialDecayLengthPenalty doc

* Apply suggestions from code review

Co-authored-by: Arthur <[email protected]>

* Style and quality fix

* Fix example outputs

---------

Co-authored-by: Arthur <[email protected]>
pokjay and ArthurZucker authored Sep 12, 2023
1 parent d65c4a4 commit 6acc27e
Showing 2 changed files with 63 additions and 10 deletions.
54 changes: 51 additions & 3 deletions src/transformers/generation/logits_process.py
@@ -1297,8 +1297,9 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:

class ExponentialDecayLengthPenalty(LogitsProcessor):
r"""
[`LogitsProcessor`] that exponentially increases the score of the eos_token_id after regulation_start has been
reached.
[`LogitsProcessor`] that exponentially increases the score of the `eos_token_id` after `start_index` has been
reached. This allows generating shorter sequences without a hard cutoff, letting the `eos_token` be
predicted in a meaningful position.
Args:
exponential_decay_length_penalty (`tuple(int, float)`):
@@ -1308,6 +1309,51 @@ class ExponentialDecayLengthPenalty(LogitsProcessor):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
input_ids_seq_length (`int`):
The length of the input sequence.
Examples:
```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
>>> set_seed(1)
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
>>> text = "Just wanted to let you know, I"
>>> inputs = tokenizer(text, return_tensors="pt")
>>> # Generate sequences without exponential penalty. We want short sentences, so we limit max_length=30
>>> # see that the answer tends to end abruptly
>>> outputs = model.generate(**inputs, do_sample=True, temperature=0.9, max_length=30, pad_token_id=50256)
>>> print(tokenizer.batch_decode(outputs)[0])
Just wanted to let you know, I'm not even a lawyer. I'm a man. I have no real knowledge of politics. I'm a
>>> # Generate sequences with exponential penalty; we add exponential_decay_length_penalty=(start_index, decay_factor)
>>> # We see that instead of cutting at max_tokens, the output ends earlier (at 25 tokens) and is more meaningful
>>> # What happens is that starting from `start_index` the EOS token score is increased exponentially by decay_factor
>>> outputs = model.generate(
... **inputs,
... do_sample=True,
... temperature=0.9,
... max_length=30,
... pad_token_id=50256,
... exponential_decay_length_penalty=(15, 1.6),
... )
>>> print(tokenizer.batch_decode(outputs)[0])
Just wanted to let you know, I've got a very cool t-shirt educating people on how to use the Internet<|endoftext|>
>>> # Generate sequences with a smaller decay_factor, still avoiding the hard cutoff mid-sentence
>>> outputs = model.generate(
... **inputs,
... do_sample=True,
... temperature=0.9,
... max_length=30,
... pad_token_id=50256,
... exponential_decay_length_penalty=(15, 1.05),
... )
>>> print(tokenizer.batch_decode(outputs)[0])
Just wanted to let you know, I've been working on it for about 6 months and now it's in Alpha.<|endoftext|>
```
"""

def __init__(
@@ -1327,7 +1373,9 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
cur_len = input_ids.shape[-1]
if cur_len > self.regulation_start:
for i in self.eos_token_id:
scores[:, i] = scores[:, i] * pow(self.regulation_factor, cur_len - self.regulation_start)
penalty_idx = cur_len - self.regulation_start
# To support negative logits we compute the penalty on the absolute value and add it to the original logit
scores[:, i] = scores[:, i] + torch.abs(scores[:, i]) * (pow(self.regulation_factor, penalty_idx) - 1)
return scores
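For reference, a minimal direct invocation of the fixed processor on all-negative logits. The vocab size, EOS id, and prompt length are hypothetical; the constructor arguments follow the docstring above, and we assume regulation starts at `input_ids_seq_length + start_index`, consistent with the tests below:

```python
import torch
from transformers.generation.logits_process import ExponentialDecayLengthPenalty

processor = ExponentialDecayLengthPenalty(
    exponential_decay_length_penalty=(5, 1.5),  # (start_index, decay_factor)
    eos_token_id=0,
    input_ids_seq_length=2,
)

input_ids = torch.zeros((1, 10), dtype=torch.long)  # cur_len=10 is past regulation start (2 + 5)
scores = torch.full((1, 8), -3.0)                   # all-negative logits, EOS at index 0
out = processor(input_ids, scores.clone())          # clone: the processor mutates in place
assert out[0, 0] > scores[0, 0]                     # EOS logit increased, not decreased
```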


19 changes: 12 additions & 7 deletions tests/generation/test_logits_process.py
@@ -717,18 +717,23 @@ def test_exponential_decay_length_penalty(self):

# check that penalty is not applied before start
scores = self._get_uniform_logits(batch_size, vocab_size)
scores_before_start = length_decay_processor(input_ids, scores)
scores_before_start = torch.clone(scores) # clone scores as the processor updates them in place
scores_before_start = length_decay_processor(input_ids, scores_before_start)
self.assertListEqual(scores_before_start[:, eos_token_id].tolist(), scores[:, eos_token_id].tolist())

# check that penalty is applied after start
input_ids = ids_tensor((batch_size, 20), vocab_size=vocab_size)
scores = self._get_uniform_logits(batch_size, vocab_size)
scores_after_start = length_decay_processor(input_ids, scores)
self.assertTrue(
torch.gt(
scores_after_start[penalty_start + 1 :, eos_token_id], scores[penalty_start + 1 :, eos_token_id]
).all()
)
scores_after_start = torch.clone(scores) # clone scores as the processor updates them in place
scores_after_start = length_decay_processor(input_ids, scores_after_start)
self.assertTrue(torch.gt(scores_after_start[:, eos_token_id], scores[:, eos_token_id]).all())

# check the penalty increases negative scores
input_ids = ids_tensor((batch_size, 20), vocab_size=vocab_size)
scores = torch.neg(self._get_uniform_logits(batch_size, vocab_size))
scores_after_start = torch.clone(scores) # clone scores as the processor updates them in place
scores_after_start = length_decay_processor(input_ids, scores_after_start)
self.assertTrue(torch.gt(scores_after_start[:, eos_token_id], scores[:, eos_token_id]).all())

def test_normalization(self):
input_ids = None
