Add Classifier-Free Guidance sampling #24536
Comments
cc @gante |
Hey @Vermeille 👋 I have the impression that our MusicGen PR (still open, expected to get merged soon) introduces the bulk of the logic to make it happen -- see this file. It is the same thing with a slightly different code implementation, correct? In the MusicGen PR, the model does a forward pass with 2x the batch size, where half of the batch corresponds to the unprompted tokens |
Indeed @gante ! I don't fully get how the 2x batch size thing works, but if it does, it's cool.
|
cc @sanchit-gandhi, who's probably better equipped to comment on potential differences :) |
Hey @Vermeille - thanks for the comprehensive write-up! Just a clarifying question: in your implementation, how do you construct the token ids for the model based on the conditional ids and the un-conditional ones? You mention:
Which suggests you concatenate them together in the same batch item? In MusicGen (and also the HF Diffusers library for models like Stable Diffusion), we construct our input ids by concatenating the input ids for the conditional prompt and the un-conditional prompt along the batch dimension: `input_ids = torch.concatenate([conditional_ids, unconditional_ids], dim=0)`. This is what's referred to by the 2x batch size 'trick' (concatenating the conditional prompt and unconditional prompt over the batch dim). There's no restriction on how these unconditional ids are formed - they can be from a 'null' input, or from a negative prompt. So we can do negative prompting in exactly the way you've described. When we run our model forward, the logits for the first half of the batch correspond to the conditional prompt, and the second half to the unconditional prompt (or negative prompt if we use one). By splitting along the batch dim, we can partition the conditional logits and the unconditional ones: `conditional_logits, unconditional_logits = torch.split(logits, batch_size // 2)` -> we then perform our weighted sum over the conditional and unconditional logits for CFG. Hope that explains how the 2x batch size trick works - would be keen to hear whether this aligns with how you've run CFG in your experiments. Regarding implementing a new logits processor, we'd probably want to add it when the time comes for integrating the model you've worked on into `transformers`. Have you trained a new model that uses this processor? Or built on top of an existing one? (if it's the latter, then adding the CFG logits processor standalone makes sense, otherwise let's integrate it all in one go) |
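For illustration, here is a minimal sketch of the 2x batch size trick described above (the `model` call and tensor names are placeholders, not the actual MusicGen code):

```python
import torch

def cfg_next_token_logits(model, conditional_ids, unconditional_ids, guidance_scale):
    # One forward pass over a doubled batch: first half conditional,
    # second half unconditional (or negative-prompt) ids.
    # Assumes both prompts are padded to the same sequence length.
    input_ids = torch.concatenate([conditional_ids, unconditional_ids], dim=0)
    logits = model(input_ids).logits[:, -1, :]  # next-token logits for both halves

    # Split the doubled batch back into its two halves along the batch dim.
    batch_size = input_ids.shape[0]
    conditional_logits, unconditional_logits = torch.split(logits, batch_size // 2)

    # Weighted sum for classifier-free guidance.
    return unconditional_logits + guidance_scale * (conditional_logits - unconditional_logits)
```

The same next token is then sampled from the guided logits and appended to both halves of the batch.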
Thank you for your detailed answer @sanchit-gandhi ! The part I'm the most unclear about regarding the 2x batch trick is how the sampling happens. Do you actually sample the same continuation token for the conditional and unconditional branch, or do they diverge in their own directions (which would be weird imho)? Regarding the integration, there is no need to train models to support CFG; it works out of the box. The paper will be out in a few days, but as you can see in the figures, we employed it with LLaMA models, all Pythias, the GPT-2 family, and even GPT4All. We don't train a new model. It's meant to be an addition to the .generate() method that is totally model agnostic and needs neither training nor finetuning. Hence the PR with the standalone logits processor :) |
Maybe this helps! Pre-processing: concatenate the conditional and unconditional (or negative) prompt ids along the batch dimension.
Forward pass: run a single forward pass over this 2x-sized batch.
CFG: split the logits back into their conditional and unconditional halves and take the weighted sum.
Sampling: sample one next token from the guided distribution and append it to both halves of the batch, so the two branches never diverge.
How have you been getting the conditional and unconditional logits in your experiments? Through two forward passes? (one with the conditional inputs and then a second with the unconditional ones). This batch size concatenation trick means you only have to run one forward pass, but with 2x the batch size. The only pain point I see is getting this to work in `generate`.
Very cool indeed! Would be nice to have this as a standalone PR then as suggested |
Thank you!
I'm happy to address the changes that have to be made to contribute this into the lib :) |
Awesome - feel free to open a PR and tag myself and @gante! How do you do it without the 2x batch size trick? Do you do two forward passes? Just asking in case there's a simpler way we can integrate this! |
(catching up on the paper and thinking a bit about usage experience -- will comment tomorrow with specific suggestions, but I think @Vermeille's suggested implementation above will be pretty close to a great user experience with minimal compute overhead) |
Here is an alternative implementation we used for some of our other experiments in the paper, for your consideration; it was designed with Hugging Face's typical API in mind. |
Yes. Two consecutive passes. Which is indeed not that great wrt latency. |
Would be great to have both the 2x batch size and the two forward passes, since the 2x batch size is better for throughput but the two forward passes are much better for VRAM usage, as the paper outlines (unless I misunderstood). |
So given you already have this (https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L1070), what do you want me to add or change in the PR? |
This is correct: our focus was on getting the best results for a fixed amount of VRAM in our experiments. Hence it didn't occur to us to simply 2x the batch size. I agree that having this be togglable is a good idea and don't have any preference about the default. |
The application to LLMs seems more of a situational sampling technique. With smaller conditional generative models like MusicGen, trained from-scratch with (explicit) condition dropout, it's practically part of the model. MusicGen isn't the first AR Transformer here, last year's DALL-E Mega already did it (itself inspired by https://twitter.com/RiversHaveWings/status/1478093658716966912 ), and in these models it's essential for performance. So I'd expect "batch size 1 dramatically underutilizes available resources" to be the more common case.
Depending on model and hardware, "biggest batch size that fits" isn't necessarily optimal. On decent hardware, you can hit optimal compute utilisation before VRAM limits with batched inference in smaller models. Normalizing the summands, then interpolating with the original scores is intriguing. If adding this to the CFG implementation that's now in Transformers is still being considered, this would be unexpected as default behavior though. In diffusion models, it's not applicable, and in sequence prediction, I've only seen people combine the unnormalized scores. |
This is a technique we borrowed from *Common Diffusion Noise Schedules and Sample Steps are Flawed*, where they call it CFG Rescale. You can see Imagen doing a similar normalizing trick too.
That's what we started with, and our results were a little bit worse. |
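For concreteness, here is a sketch of the renormalize-then-interpolate variant being discussed (the `rescale_factor` name follows later comments in this thread; treat this as an illustration rather than the paper's exact formulation):

```python
from torch.nn import functional as F

def rescaled_cfg(conditional_scores, unconditional_scores, guidance_scale, rescale_factor=0.7):
    cond = F.log_softmax(conditional_scores, dim=-1)
    uncond = F.log_softmax(unconditional_scores, dim=-1)

    # Plain CFG in log-probability space.
    guided = uncond + guidance_scale * (cond - uncond)

    # Renormalize the guided scores, then interpolate with the conditional ones
    # to pull the distribution back towards the unguided model.
    guided = F.log_softmax(guided, dim=-1)
    return rescale_factor * guided + (1 - rescale_factor) * cond
```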
This method is interesting to implement from an engineering and maintenance point of view! The simplest approach would be to proceed as @Vermeille suggested: add a logits processor that calls a model forward pass for the unconditional part of the input. It would be a small self-contained piece of code, which means low long-term maintenance on our end. On the negative side, we have the 2x latency, which is more impactful than the extra VRAM (IMO). If we go the 2x batch size route, we need to implement a function to expand the inputs to twice the batch size, and we have a plan to reorganize `generate` that would make that easier. How about we go with @Vermeille's proposal now, which will make CFG sampling available this week with low overhead on our end, and we implement the 2x batch size version after the `generate` reorganization? |
Expect a PR in a few hours. Thank you for your interest and answers! |
@gante There is a name clash for the arguments to .generate(). For this PR, unless instructed otherwise before I submit it, |
@Vermeille Adding more (and partially redundant) parameterization is highly undesirable, and we'd want to favor the more general case (yours). You also have the additional requirement of renormalizing the logits before applying your logits processor. Fortunately, we haven't officially released a version with the clashing argument yet. Let's try to fit everything together -- here's my suggestion:
This way the two strategies can coexist, share the argument, and not clash 🤗 |
Great! Thank you for the walkthrough. On it. |
Wait @gante, integrating it after the LogitNormalization is not something we want: all the prior processing (temperature, top_p, etc.) will be applied only to the conditional branch and not the unconditional one, and will be executed before computing the CFG logits. To be fair, we haven't tested this transformation order, but being asymmetrical like this scares me. And it is even invalid: top-k/p may not even select the same tokens in both branches, so that will misbehave. I'm afraid I can't do that. CFG has to happen as one of the first logits processors. |
@Vermeille looking at your code example above, I didn't notice it already had normalization inside the processor. My bad -- feel free to add it as the 1st one :) (will edit my comment above accordingly, for clarity) |
So this is the code I got to get it working. It is just a hack, but if you want to play with it, use this code:

```python
import torch
from torch.nn import functional as F

from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import (
    LogitsProcessorList,
    LogitsWarper,
    TemperatureLogitsWarper,
    TopPLogitsWarper,
)

device = 'cuda' if torch.cuda.is_available() else 'cpu'


class CFGLogits(LogitsWarper):
    """Classifier-free guidance warper: contrasts the conditional scores with
    the logits of a second (unconditional / negative-prompt) forward pass."""

    def __init__(self, cfg, inputs, model, verbose=True):
        self.cfg = cfg          # guidance strength
        self.inputs = inputs    # ids fed to the unconditional branch (e.g. last prompt token or a negative prompt)
        self.model = model
        self.out = None         # cached output of the unconditional branch
        self.verbose = verbose

    def __call__(self, input_ids, scores):
        if self.cfg == 1:
            return F.log_softmax(scores, dim=-1)
        scores = F.log_softmax(scores, dim=-1)
        # Run (and cache) the unconditional branch, reusing past_key_values
        # so only the newly sampled token is fed on subsequent calls.
        if self.out is None:
            self.out = self.model(self.inputs.to(device), use_cache=True)
        else:
            self.out = self.model(
                input_ids[:, -1:],
                use_cache=True,
                past_key_values=self.out.past_key_values,
            )
        unconditional_logits = F.log_softmax(self.out.logits[0][-1:], dim=-1)
        # CFG: move the conditional distribution away from the unconditional one.
        out = self.cfg * (scores - unconditional_logits) + unconditional_logits
        out = F.log_softmax(out, dim=-1)
        # Blend back with the original scores to smooth the distribution.
        return 0.7 * out + 0.3 * scores


tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m")
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-160m")
model.to(device)

prompt = "Salve, dispiculi."
inputs = tokenizer(prompt, return_tensors='pt')

outputs = model.generate(
    input_ids=inputs['input_ids'].to(device),
    attention_mask=inputs['attention_mask'].to(device),
    max_new_tokens=125,
    logits_processor=LogitsProcessorList([
        # inputs_cfg is usually the last token of the prompt, but there are
        # possibilities of negative prompting that are explored in the paper
        CFGLogits(3, inputs['input_ids'], model),
        TemperatureLogitsWarper(0.8),
        TopPLogitsWarper(0.95),
    ]),
    do_sample=True,
)
print(tokenizer.decode(outputs[0]))
```

This worked on my end. |
@grantCelley Pythia models are trained on English. I'm really confused by what you're trying to achieve there. |
I was just trying to get it to work. Also, it does continue in Latin for a little while, which is interesting, then drifts into a Romance language. But it just showed how to do it. I didn't realize that you had updated the original code block. |
Ok this helped - generation for the same amount of tokens takes longer now, is this expected?
- Vanilla / no CFG: 512 tokens / 3 min
- CFG, neg_token = last token, cfg_scale=1.5: 512 tokens / 5 min
- CFG, neg_token = last token, cfg_scale=1.25: 512 tokens / 5 min
|
@grantCelley Shouldn't a negative prompt of 'Latin' prohibit Latin output? Or do I misunderstand the concept of negative prompts? |
Yes, there are two forward passes per token now.
You are correct |
It is hard to say in certain terms what a negative prompt does. I had it generate a poem and specified the negative prompt as 'happy', and it used somewhat gloomy language, and vice versa - so it "does" work, but beyond that I think only further experimentation will tell. |
Yes. Neg prompts in language are somewhat harder to pull off than in vision. Especially because the continuation should be kinda grammatical with the neg prompt too. Not gonna lie, we were under time constraints and having a clear neg prompt methodology was unsolved in that time frame. But we're working on it, and the example in the first post works.
Hard to say yet, but it should depend on the guidance strength (decrease the rescale_factor as you increase the guidance strength)
from the paper:
|
Thanks for explanation. |
This implies to me that the two things should be separate concepts, with separate implementations... but if you (reasonably) wanted to use both focus-on-first-prompt and negative prompting, it would be computationally expensive to do them separately. That said, I do feel a little like the 'adding them back in' is a fudge factor, trying to reduce the effect slightly. But I don't understand the math notation in the original paper very well, so I'm very cautious about that. |
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored. |
It is merged, feel free to install from `main`! |
Can you provide sample code on how to use classifier free guidance? |
Here are the docs @sersoage - you can enable CFG by passing `guidance_scale` to `generate`.
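For anyone landing here later, a minimal usage sketch (assuming the merged API: CFG is enabled when `guidance_scale > 1` is passed to `generate`, with an optional `negative_prompt_ids`; double-check the linked docs for the exact signature):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m")
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-160m")

inputs = tokenizer("Today, a dragon flew over Paris, France,", return_tensors="pt")
negative = tokenizer("A sad story:", return_tensors="pt")  # optional negative prompt

outputs = model.generate(
    **inputs,
    max_new_tokens=64,
    do_sample=True,
    guidance_scale=1.5,                         # values > 1 turn on CFG
    negative_prompt_ids=negative["input_ids"],  # omit to fall back to the default (last prompt token, per the discussion above)
)
print(tokenizer.decode(outputs[0]))
```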
thanks!!! |
@Vermeille Thanks for the code! |
@sakrat-az We set the last token of the prompt as the negative prompt |
|
That doesn't exist. You necessarily need something negative or unconditional in the equation. The last token of the prompt is the closest way I've found to emulate an unconditional prompt. |
Can you check my code once? I wanted to focus on the word "France". @Vermeille |
Then you did the opposite. France has to be only in the positive prompt if you want to focus on it. Here you try to sample away from France. |
@Vermeille, but I have also changed the code in the `__call__()` function. Did you check that too? |
# Description

Implement classifier-free guidance function based on vLLM. The author of this paper implements this function in huggingface-transformers: huggingface/transformers#24536. The pseudo-code:

```
conditional_logits = log_softmax(model(positive_prompt))
unconditional_logits = log_softmax(model(negative_prompt))
logits = unconditional_logits + cfg_scale * (conditional_logits - unconditional_logits)
next_token = do_sample(logits)
positive_prompt.append(next_token)
negative_prompt.append(next_token)
```

Usage in FlagScale can reference `tests/unit_tests/test_classifier_free_guidance.py`.
Hey, how do you convert log(softmax(x)) to probabilities? If we apply exp() to the logits, we often get inf values. |
okay, I have my answer: transformers/src/transformers/generation/utils.py Line 3296 in 125de41
|
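For reference, a sketch of the safe conversion (the general recipe, not the exact code at the linked line): exponentiating a `log_softmax` output is bounded, whereas exponentiating raw logits can overflow.

```python
import torch
from torch.nn import functional as F

logits = torch.randn(1, 50257) * 50          # raw scores straight from the LM head
log_probs = F.log_softmax(logits, dim=-1)    # all values are <= 0
probs = log_probs.exp()                      # safe: results lie in [0, 1] and sum to 1

# By contrast, logits.exp() can overflow to inf for large positive logits,
# which is why the conversion should go through log_softmax (or softmax) first.
```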
EDIT: ===========================
As I see many people copy-pasting this initial code that was meant to be a basis for discussion, here is a cleaner version (yet not perfect! We're still doing improvement rounds with the huggingface team to improve it! Check the state of the PR #24654 until it's merged!).
===============================
Feature request
Hello!

I wish to contribute CFG sampling. I'm working with EleutherAI and @StellaAthena and will have a paper about it by Friday. CFG brings non-trivial improvements on many standard benchmarks. It contrasts the logits for the next token $P(w_t|w_{..t}, prompt)$ to those of the input deprived of the prompt, $P(w_t|w_{..t})$, by defining

$$\log P_{\text{cfg}}(w_t|w_{..t}, prompt) = \log P(w_t|w_{..t}) + \gamma \left( \log P(w_t|w_{..t}, prompt) - \log P(w_t|w_{..t}) \right)$$

where $\gamma$ is the guidance strength. And then we can blend $\log P_{\text{cfg}}$ with $\log P(w|w_{..t}, prompt)$ to smoothen that distribution a bit, but it's optional.
Motivation
My current implementation is the `CFGLogits` logits warper shown in the code block earlier in this thread.
I am not familiar enough with the design guidelines of HF to know if this implementation as a LogitsWarper is satisfactory.
Just a few figures supporting the claims: [benchmark figures omitted]
Your contribution
I can contribute the code but I need to be guided as I don't know the exact design guidelines and overall architecture of HF.
Thank you for your time!