Improving memory efficiency further 🚀 #30860
Comments
That's actually something we should really do, in light of #29943, which has this block: transformers/src/transformers/models/jamba/modeling_jamba.py, lines 1657 to 1662 at e2ecd86
(clone is missing)
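For context, a paraphrased sketch of the `num_logits_to_keep` pattern that the referenced Jamba lines revolve around (paraphrased from memory, not an exact quote of the lines at e2ecd86):

```python
# Paraphrased sketch, not the exact Jamba code: project only the last
# `num_logits_to_keep` hidden states through the LM head, so the huge
# (seq_len, vocab_size) logits tensor is never materialized for the full prompt.
if num_logits_to_keep is None:
    logits = self.lm_head(hidden_states)
else:
    logits = self.lm_head(hidden_states[..., -num_logits_to_keep:, :])
logits = logits.float()
```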
This is true except in assisted generation, where we want the logits for all candidate tokens 😛 But we can generalize to "we only ever want as many logits as input tokens".
👉 regarding keeping all the logits at prefill time: in our
👉 regarding casting the logits with
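A minimal illustration of that generalization (hypothetical variable names, not the actual `generate()` code): in every decoding scheme, only as many logit rows as tokens fed to the current forward pass are ever consumed.

```python
# Hypothetical sketch: regular decoding feeds 1 token per step; assisted
# generation feeds the candidate tokens plus one, which is still far fewer
# than the full sequence length.
num_new_tokens = model_inputs["input_ids"].shape[1]      # tokens in this forward pass
needed_logits = outputs.logits[:, -num_new_tokens:, :]   # everything else can be dropped
```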
Yeah I think it should be okay. Our tests are gonna fail but I would want a bench result to see if the break is worth it!
Great! I'll open a PR soon and will provide benchmarks.
@Cyrilvallez this issue is complete, correct?
Indeed! Closing it
## Summary
The analogous `logits.float()` calls were moved in the Hugging Face modeling source code to be inside the `if labels is not None` block, to avoid upcasting logits unless they are being used in a loss calculation; this avoids a memory spike during inference if the model is in lower precision.

* https://github.com/huggingface/transformers/blob/37ea04013b34b39c01b51aeaacd8d56f2c62a7eb/src/transformers/models/llama/modeling_llama.py#L1211-L1212
* https://github.com/huggingface/transformers/blob/37ea04013b34b39c01b51aeaacd8d56f2c62a7eb/src/transformers/models/mixtral/modeling_mixtral.py#L1329-L1330
* https://github.com/huggingface/transformers/blob/37ea04013b34b39c01b51aeaacd8d56f2c62a7eb/src/transformers/models/phi3/modeling_phi3.py#L1303-L1304
* https://github.com/huggingface/transformers/blob/37ea04013b34b39c01b51aeaacd8d56f2c62a7eb/src/transformers/models/qwen2/modeling_qwen2.py#L1206-L1207

Some of your models already have this change:

* https://github.com/linkedin/Liger-Kernel/blob/ff6650bbcef5d31b7522694cbeb73a21169460e9/src/liger_kernel/transformers/model/mistral.py#L114-L116
* https://github.com/linkedin/Liger-Kernel/blob/ff6650bbcef5d31b7522694cbeb73a21169460e9/src/liger_kernel/transformers/model/gemma.py#L114-L116

See also:

* huggingface/transformers#30860

## Testing Done
- Hardware Type: <BLANK>
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
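A minimal sketch of the pattern described above, as a self-contained helper (hypothetical function, simplified relative to the actual modeling code):

```python
import torch
import torch.nn.functional as F

def lm_head_with_optional_loss(hidden_states, lm_head, labels=None):
    """Sketch of the 'upcast only when a loss is needed' pattern
    (hypothetical helper, not the upstream transformers code)."""
    logits = lm_head(hidden_states)  # stays in the model's dtype (e.g. bf16)
    loss = None
    if labels is not None:
        # Upcast to fp32 only for the loss computation, where precision matters.
        # Pure inference never materializes an extra fp32 copy of the logits.
        logits = logits.float()
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))
    return logits, loss
```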
### Feature request
Removing the line `logits = logits.float()` in most `ModelForCausalLM`. This would save a lot of memory for models with a large vocabulary size; on Llama3, it divides the memory peak by more than 2.

### Motivation
This is in relation to my work in #30536.
I noticed that almost all `ModelForCausalLM` contain the following line in their `forward`: `logits = logits.float()`.

Now, since most models are used in (b)float16, or even quantized, that line will almost always double the memory footprint of the logits. As the vocabulary size can be quite big (e.g. Llama3), this results in a lot of memory being used.
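To put rough numbers on this (my own arithmetic, not figures from the issue): for Llama3's vocabulary of 128 256 entries, a single sequence of 2048 tokens already gives a sizeable logits tensor, and the fp32 copy adds twice that again.

```python
# Back-of-the-envelope sizes for the logits tensor (batch 1, 2048 tokens,
# Llama3 vocabulary). Illustrative arithmetic only.
vocab_size, seq_len = 128_256, 2048
bf16_logits = seq_len * vocab_size * 2 / 2**30   # ~0.49 GiB in bfloat16
fp32_logits = seq_len * vocab_size * 4 / 2**30   # ~0.98 GiB after .float()
print(f"bf16: {bf16_logits:.2f} GiB, fp32 copy: {fp32_logits:.2f} GiB")
```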
I suspect that it was originally introduced so that later manipulations of the logits (processors, warpers...) can be applied without losing too much precision. However, in `generate()` we only ever use the last token's logits, not the whole logits matrix, so this is a huge waste of memory.

### Your contribution
If the casting of the logits to float is indeed only there to avoid losing precision in these manipulations, I propose to cast only the last token to `float` in each decoding strategy function. So, instead of doing `logits = logits.float()` in `forward()`, do the cast in each decoding strategy function, on the last token vector only, which is negligible in terms of memory overhead (see the sketch below).
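A minimal sketch of what the cast inside a decoding strategy could look like (hypothetical variable names; the exact placement inside `generate()` is up to the implementation):

```python
# Inside a decoding-strategy loop (e.g. greedy search), hypothetical sketch:
outputs = self(**model_inputs, return_dict=True)

# Cast only the last position to fp32 before applying logits processors.
# Shape (batch, vocab_size): tiny compared to the full (batch, seq, vocab) tensor.
next_token_logits = outputs.logits[:, -1, :].float()
next_token_scores = logits_processor(input_ids, next_token_logits)
```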
As an example of the potential memory gains, running this very simple code snippet on Llama3 8B (vocabulary size 128256):
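A minimal sketch of such a measurement, assuming the `meta-llama/Meta-Llama-3-8B` checkpoint, a single CUDA device, and `torch.cuda.max_memory_allocated` for the peak (a reconstruction, not the author's exact snippet):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Reconstruction sketch: load the model in bfloat16, generate a few hundred
# tokens, and report the peak GPU memory used during generation.
model_id = "meta-llama/Meta-Llama-3-8B"  # assumed checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to("cuda")

inputs = tokenizer("The quick brown fox", return_tensors="pt").to("cuda")
torch.cuda.reset_peak_memory_stats()
model.generate(**inputs, max_new_tokens=200, do_sample=False)
print(f"peak memory: {torch.cuda.max_memory_allocated() / 2**30:.2f} GiB")
```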
gives:


That is, the memory footprint is more than halved. This is because the vocabulary size is so large that computing the logits from the hidden states is actually more costly than computing the hidden states themselves. Thus, when casting to `float()`, we more than double the memory requirements (double for the new logits, plus the overhead while actually copying).

Of course, other models usually have a smaller vocabulary size, so they will not benefit as much, but the memory peak will still decrease by a non-negligible portion for all applicable models (see below for Mistral, ~30% memory gain). And Llama3, which I believe is the hottest open-source model at the moment, will be much more efficient.
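To see why the logits dominate (my own arithmetic, assuming Llama3's hidden size of 4096 and vocabulary of 128 256): each position's logits vector is roughly 31× larger than its final hidden state, so the logits tensor dwarfs the activations it is computed from.

```python
# Per-token sizes in bfloat16 for Llama3-8B (illustrative arithmetic only).
hidden_size, vocab_size = 4096, 128_256
hidden_bytes = hidden_size * 2   # ~8 KiB per token of final hidden state
logits_bytes = vocab_size * 2    # ~250 KiB per token of logits
print(f"logits are ~{logits_bytes / hidden_bytes:.0f}x larger per token")  # ~31x
```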
mistral_ratio_example.pdf
Of course, if this casting to float is done for something else that I overlooked, this may not be applicable. Otherwise, I would be happy to make the change.
@ArthurZucker @gante
Cheers,
Cyril