Token based (or sequence of token based) repetition penalty exclusion #26902

Closed
teknium1 opened this issue Oct 18, 2023 · 14 comments

Comments

@teknium1

teknium1 commented Oct 18, 2023

Feature request

I started this issue in TGI, but it applies to all inference code that has some form of repetition penalty. I'll paste my feature request notes from there here as well; the original is at huggingface/text-generation-inference#1170.

Hello, I would like to propose a feature that allows you to set a list of tokens, or even token sequences, that can be excluded from repetition penalty calculations.

The reasoning is that, given a multi-turn prompt format such as:

user: abc
assistant: def
user: ghi
assistant: jkl

Or, even worse, a format like ChatML, which now as standard uses <|im_end|> as a stopping token included in every turn. Since these tokens appear in every single turn, it seems only logical that repetition penalty will destroy the validity of these prompt formats, especially when turns are only a few tokens long.
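For illustration, the same exchange in ChatML has <|im_end|> closing every single turn:

<|im_start|>user
abc<|im_end|>
<|im_start|>assistant
def<|im_end|>
<|im_start|>user
ghi<|im_end|>
<|im_start|>assistant
jkl<|im_end|>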

While I haven't noticed this with Hermes 2, that may be solely because it has long responses; if the average turn length is only a few tokens, the problem may become more prominent.

Motivation

(screenshot attachment omitted)

Your contribution

The following is TGI's code. I haven't looked at transformers' code, but I assume the principles are the same:

class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor):
    r"""
    [`LogitsProcessor`] enforcing an exponential penalty on repeated sequences.
    This version allows for a separate value for each sample and runs inplace when possible.
    It doesn't validate inputs.

    Args:
        repetition_penalty (`List[float]`):
            The parameter for repetition penalty. 1.0 means no penalty. See [this
            paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
    """

    def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device):
        self.penalty = penalty
        self.penalty_tensor = torch.tensor(
            penalty, dtype=dtype, device=device
        ).unsqueeze(1)

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        score = torch.gather(scores, 1, input_ids)

        # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
        score = torch.where(
            score < 0, score * self.penalty_tensor, score / self.penalty_tensor
        )

        scores.scatter_(1, input_ids, score)
        return scores

    def filter(self, indices):
        self.penalty = [self.penalty[i] for i in indices]
        if any([x != 1.0 for x in self.penalty]):
            self.penalty_tensor = self.penalty_tensor[indices]
            return self
        return None

I'm thinking we take the input_ids before they get scored and simply replace them with input_ids that drop any ids listed in some variable holding a list of token ids (or token strings mapped to ids).
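As a rough sketch of what that could look like as a standalone processor (this is not existing transformers code; the class name and the excluded_token_ids argument are made up for illustration), the penalty could be applied as usual and the original logits restored afterwards for the excluded ids:

import torch
from typing import List
from transformers import LogitsProcessor


class ExclusionRepetitionPenaltyLogitsProcessor(LogitsProcessor):
    # Sketch only: same math as the stock repetition penalty, but a
    # user-supplied list of token ids (e.g. the id of <|im_end|>) is
    # left untouched.
    def __init__(self, penalty: float, excluded_token_ids: List[int]):
        if not isinstance(penalty, float) or not (penalty > 0):
            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
        self.penalty = penalty
        self.excluded_token_ids = torch.tensor(excluded_token_ids, dtype=torch.long)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Standard repetition penalty: divide positive logits, multiply negative ones.
        score = torch.gather(scores, 1, input_ids)
        score = torch.where(score < 0, score * self.penalty, score / self.penalty)
        penalized = scores.clone()
        penalized.scatter_(1, input_ids, score)
        # Restore the original logits for the excluded token ids so that
        # prompt-format tokens keep their unpenalized probability.
        excluded = self.excluded_token_ids.to(scores.device)
        penalized[:, excluded] = scores[:, excluded]
        return penalized

Dropping the excluded ids from input_ids before the gather/scatter would be closer to the wording above, but restoring the columns afterwards avoids dealing with ragged per-row sequences.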

@152334H

152334H commented Oct 18, 2023

+1

@teknium1
Author

An addendum to this, for at least transformers and maybe TGI as well: if this system works, perhaps it can be added as a tokenizer config setting, similar to the chat templates, so users don't have to implement it themselves for whatever inference stack they use.

@osanseviero
Contributor

cc @gante @Rocketknight1

@teknium1
Author

teknium1 commented Oct 18, 2023

For reference, HF transformers' code for this is even simpler:

class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
    def __init__(self, penalty: float):
        if not isinstance(penalty, float) or not (penalty > 0):
            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")

        self.penalty = penalty

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        score = torch.gather(scores, 1, input_ids)

        # if score < 0 then repetition penalty has to be multiplied to reduce the token probabilities
        score = torch.where(score < 0, score * self.penalty, score / self.penalty)

        scores.scatter_(1, input_ids, score)
        return scores

https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L270C1-L325C22

@Rocketknight1
Member

This is a really good point that I hadn't thought of! This code is mostly @gante's area but I think I can see how to implement this, so if you're busy let me know and I can take it!

@teknium1
Author

Hey, I should also add: it would be really beneficial if we could select whole sections of a prompt to exempt as well. Thinking more on this: for, let's say, RAG, memories, or other cases where you have context the model should be drawing on as closely as possible, it seems repetition penalty would interfere with its ability to properly recite from source material in context. This seems doable, but probably harder to implement.
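A very rough sketch of that span-level variant (again not existing code; protected_spans is a made-up argument holding absolute (start, end) positions of e.g. the RAG context, assumed identical across the batch):

import torch
from typing import List, Tuple
from transformers import LogitsProcessor


class SpanExemptRepetitionPenaltyLogitsProcessor(LogitsProcessor):
    # Sketch only: penalize only tokens that occur *outside* the protected spans.
    def __init__(self, penalty: float, protected_spans: List[Tuple[int, int]]):
        self.penalty = penalty
        self.protected_spans = protected_spans

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        seq_len = input_ids.shape[1]
        # Mask of positions whose tokens should still be penalized.
        keep = torch.ones(seq_len, dtype=torch.bool, device=input_ids.device)
        for start, end in self.protected_spans:
            keep[start:min(end, seq_len)] = False
        if not keep.any():
            return scores
        penalizable_ids = input_ids[:, keep]
        score = torch.gather(scores, 1, penalizable_ids)
        score = torch.where(score < 0, score * self.penalty, score / self.penalty)
        scores.scatter_(1, penalizable_ids, score)
        return scores

One caveat that hints at why this is harder: a token id that appears both inside and outside a protected span still gets penalized here, because the penalty operates on ids, not positions.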

There are also problems that are much harder to deal with, like the following:

User: What's your favorite color?
Assistant: Blue!
User: What did you say your favorite color was?
Assistant: hmm, Red!

This may be the model being dumb, or it may be because of repetition penalty. As for solving this one, I really have no idea, since it's so dynamic and situational.

@WolframRavenwolf

WolframRavenwolf commented Oct 19, 2023

Yes, we generally need a better solution than simple (stupid) repetition penalty. There are many use cases where the LLM needs to repeat information verbatim, especially when generating code. ChatGPT/GPT-4 does that extremely well, where you talk back and forth to iterate over the code. Local LLMs wouldn't be able to handle that with regular repetition penalty settings.

Or imagine a situation like this (disregarding that LLMs aren't calculators ;)):

0+2=2, 2+0=2, 3-1=2, 4-2=2, 1+1=

With all those 2's already in the prompt, the answer is likely not what we'd expect.
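To make that concrete with made-up numbers: a moderate penalty divides the positive logit of the already-seen "2" token, which can be enough to flip the argmax under greedy decoding.

import torch

# Hypothetical next-token logits after "1+1=": index 0 is "2", index 1 is some other digit.
logits = torch.tensor([5.0, 4.6])
penalty = 1.18

# "2" already occurs several times in the prompt, so its positive logit gets divided.
penalized = logits.clone()
penalized[0] = penalized[0] / penalty  # 5.0 / 1.18 ≈ 4.24, now below 4.6

print(logits.argmax().item())     # 0 -> "2" (the expected answer)
print(penalized.argmax().item())  # 1 -> the wrong digit wins under greedy decoding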

@gante
Member

gante commented Oct 19, 2023

Hey folks! 👋

Before further discussion, a recap of the repetition penalty:

  1. It was designed for greedy methods to avoid low-quality repetitions. This problem was not present in sampling, which is what most LLMs rely on;
  2. In my experience, and from what I've heard from users over time, reasonable values of repetition penalty (e.g. 1.2 or 1.3) usually prevent low-quality repetitions and not critical pieces of the generated text (like <|im_end|>);
  3. It is only applied at most once per token: the sequences "the the the" and "the" will have the same penalty for the token corresponding to "the";
  4. It may have a negative impact when used in problems where repetition is expected, like multi-turn chats or summarization. Playing with other parameters that are more closely tied to the model's next-token probability distribution (like top_k, temperature, and top_p) is often preferred -- e.g. increasing temperature makes repetition less likely.

Based on what I know, I would need stronger evidence, well beyond a few examples (that might be false positives since LLMs typically sample), to support your claims. I also agree with these two comments that are skeptical about this being a problem that the "improved" repetition penalty would solve (1 2).

However, this does not prevent you from using it with transformers! You can pass custom logit processor instances through the logits_processor argument in generate 🤗
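For example, a sketch of wiring in a custom processor (the model id is a placeholder, and ExclusionRepetitionPenaltyLogitsProcessor refers to the illustrative class sketched earlier in this thread, not to an existing transformers class):

from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList

tokenizer = AutoTokenizer.from_pretrained("some/chatml-model")  # placeholder model id
model = AutoModelForCausalLM.from_pretrained("some/chatml-model")

# Keep <|im_end|> (and any other prompt-format tokens) out of the penalty.
excluded = tokenizer.convert_tokens_to_ids(["<|im_end|>"])
processors = LogitsProcessorList(
    [ExclusionRepetitionPenaltyLogitsProcessor(penalty=1.2, excluded_token_ids=excluded)]
)

inputs = tokenizer("user: hi\nassistant:", return_tensors="pt")
outputs = model.generate(**inputs, logits_processor=processors, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=False))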

In conclusion, I'm skeptical but open to being convinced otherwise with evidence :)

@WolframRavenwolf

Thanks for the recap!

Since I'm doing a lot of model comparisons and tests with multi-turn chats, I use deterministic settings (do_sample=false in oobabooga's text-generation-webui, or temperature=0, top_k=1, top_p=0, top_a=0 with llama.cpp/koboldcpp) to eliminate as many random factors as possible. I'm using repetition penalty 1.18, range 0, no slope.
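Roughly the transformers equivalent of that deterministic setup would be something like this (a sketch; the model id and prompt are placeholders):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("some/model")  # placeholder
model = AutoModelForCausalLM.from_pretrained("some/model")

inputs = tokenizer("user: What's your favorite color?\nassistant:", return_tensors="pt")
outputs = model.generate(
    **inputs,
    do_sample=False,           # greedy decoding, no sampling randomness
    repetition_penalty=1.18,   # the value discussed above
    max_new_tokens=64,
)
print(tokenizer.decode(outputs[0]))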

My intention is to test what the model considers the most likely generation, which isn't perfect of course, but outside of running an infinite number of gens and picking the average, it's the best I could come up with. Always open for better ideas, though.

That's just so you know my setup and the settings under which I've been observing, for many months, the issues I attribute to repetition penalty. If you think no repetition penalty would be better (now that llama.cpp's tokenizer bug that messes up EOS and other special tokens is fixed - ggml-org/llama.cpp#3538 - which could have contributed to the excessive repetition issues so many Llama 2 models exhibited), I'd happily test going without it.

@gante
Member

gante commented Oct 19, 2023

@StefanDanielSchwarz thank you for your reply :)

In a deterministic setup you should see improvements with a moderate repetition penalty like yours, as it is common for the model to repeat concepts (or even get into never-ending loops). The best would be a blind test with sampling, as is done in lmsys' chatbot arena. After a few hundred evals, it should be clear whether excluding special tokens from the repetition penalty makes a difference. Keep in mind that, if the difference is minimal, less code is better!

@WolframRavenwolf

@gante But how would you handle the use case of e.g. code generation? Imagine a pair-programmer/co-pilot scenario, which I use a lot with ChatGPT/GPT-4: you describe what program you want, the LLM gives you the code, you tell it what to change, and after a lot of back-and-forth, it's usable. The slightest repetition penalty could ruin that, so we'd probably need a way to exempt code blocks from repetition penalty. Same for RAG/retrieved memories, as @teknium1 mentioned.

@gante
Member

gante commented Oct 20, 2023

@StefanDanielSchwarz It may indeed, but this is the first time I'm reading this issue (keeping in mind that I'm tagged in everything generate-related) and, from a numerical point of view, I don't think the repetition penalty is strong enough to block high-confidence cases like those. Nevertheless, I recommend experimenting with it, as I might be wrong :)

Regardless of my opinion, we need to see clear results before we add new code to transformers 🤗 Otherwise, the codebase will grow beyond our means to maintain it. This is why we have provided a code path for custom operations like the one this issue is proposing!

@WolframRavenwolf

@gante Thanks for your feedback again. I agree with you about clear results being required. Hopefully this discussion has raised awareness of this (potential) issue.

So far I've been quite happy with repetition penalty 1.18 and my deterministic settings, and the problems I noticed might be attributed to other factors like improper tokenization, quantization, or model-specific quirks. So I'll keep my eyes open and hope others do the same, so that if there is an actual issue, it will eventually be proven and fixed.

Thank you all for your attention and please do keep up the great work! 😎👍


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
