Reduce by 2 the memory requirement in generate()
🔥🔥🔥
#30536
Conversation
Great PR and great in-depth analysis.
One thing we could / should also do is remove this reference, no?
For now we pass the cache as an input to every single layer, which also returns it.
I suppose that having less of this could solve what you mention about having 2 copies of the cache.
Anyhow, if we can fix the memory consumption, let's go 🔥
Changing the cache format is super breaking and cannot be done easily, but we could do our best to avoid copies!
Hi @ArthurZucker, thanks for the feedback! Unfortunately, I don't think we can handle all cases by keeping the same data structure. For example, if someone passes the … I checked the rest of the code base and I agree that it is a big change, but it is very much worth it. I'm definitely down to collaborate on it if you wish!
Hi @ArthurZucker, I am done with the work. Could you please review it? At this point, it should be 100% backward compatible. The only change is that now …

Here is a final benchmark of the improved performance for every decoding strategy available in Transformers. It was still done on one of my RTX 4090s, with Mistral-7B-v0.1. Note that for assisted decoding, for 1000 tokens and more, I ran into a weird CUDA error (…).

As you can see, for non-beam decoding methods the memory is divided by 2 as advertised (minus some small overhead, so around 86% improvement in practice for Mistral; between 91% and 98% for Llama2 7B and Llama3 8B). For beam methods, as I still rely on …

Finally, however, as the input size becomes large, the memory bottleneck may become the first model forward pass instead of the cache size. In settings where the input size is large compared to the number of new tokens generated (say, input size 4000 and 5 new tokens), the memory footprint will be dominated by the cost of the first model forward pass instead of the cache size. In those figures (the first is for beam-based methods, the second for non-beam-based), we can observe three zones. First, while the number of new tokens generated is very small compared to the input size, we have a plateau of "minimal memory improvement"; then, as the number of new tokens increases, a sharp increase in memory improvement, until the size of the cache is similar to the memory footprint of the first forward pass. At this point, we plateau in the zone of "maximum memory improvement", which is roughly 2 times more efficient for all non-beam-based methods, and roughly 1.5 times more efficient for all beam-based decoding strategies.

I checked and we ALWAYS hit the zone of minimal improvement, which means in this case (see figure) that ANY call to a non-beam-based method will yield a memory improvement of at least ~1.22x for Mistral and ~1.62x for Llama2 (even when generating only 2 new tokens for any input size), with even better efficiency (up to between 1.85x and 2x) for a large range of input sizes and max new tokens. The same is true for beam-based methods, with a minimal improvement of ~1.3x for Mistral and more for Llama2 (up to ~1.5x for a large range of inputs).

I hope this is clear and can be integrated as soon as possible! Don't hesitate to tell me if I am missing something that should be done before merging.
Hi @ArthurZucker, I was investigating why we observe those "minimal improvements" even with very large input sizes and just 2 new tokens. I found out that the reason was that the logits of every iteration leaked into the next iteration through some hanging references to the outputs dictionary. This was true for every decoding strategy except assisted decoding. I fixed it in my last 2 commits.

Non-beam methods

That means that for any input size, the memory peak was not caused by the forward pass of the first iteration with the large input, but occurred during the 2nd iteration, when the copying of the full cache was happening. At that point, the memory peak consisted of …
Now, both the cache and the logits trivially scale linearly with the input size. Since flash attention, the memory needed for the forward pass also scales linearly with the input size, thus …

Now, as we generate more tokens after the first iteration, and the cache size exceeds the …

Beam methods

For beam methods, the story is slightly different. Because reordering the cache allocates a full copy of the cache, the leak of the cache and the logits caused …
… and when generating enough new tokens, the …

Contrastive search

Contrastive search is a bit of a special case. As a non-beam strategy, it will benefit from the 2x memory savings but will not scale in the same way. Because we artificially multiply the size of the cache by … As you see on the image, as …

Restarting from previous cache

Now, because greedy and sample search never perform any copies of the cache anymore, restarting from a previous cache, such as in the sketch below.
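A minimal sketch of that restart pattern (a hedged reconstruction, not the exact snippet from the original comment; the model name, prompts and dtype are illustrative):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1", torch_dtype=torch.float16, device_map="cuda"
)

# First turn: keep the cache returned by generate()
inputs = tokenizer("Hello, how are you?", return_tensors="pt").to("cuda")
out = model.generate(**inputs, max_new_tokens=50, return_dict_in_generate=True)

# Second turn: feed the conversation so far and restart from the previous cache
new_inputs = tokenizer(
    tokenizer.decode(out.sequences[0], skip_special_tokens=True) + " And what about you?",
    return_tensors="pt",
).to("cuda")
out = model.generate(
    **new_inputs,
    past_key_values=out.past_key_values,  # reuse the cache instead of recomputing it
    max_new_tokens=50,
    return_dict_in_generate=True,
)
```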
Such a restart will actually benefit from a 3x memory improvement (that is, we divide the memory footprint by 3). This is because, before, we would always have 2 copies of the cache internally during the second call to …

As most chat applications will actually always use greedy or sample decoding, and will restart every new conversation turn from the previous cache (at least they should most of the time), it means that these chat applications will always save between 2x and 3x memory.

TL;DR

This PR significantly reduces the memory footprint of …
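For reference, a hedged, self-contained stand-in for the logits-leak fix described above (in the real decoding loops the big tensor is `outputs.logits`; names and shapes here are illustrative):

```python
import torch

# Stand-in for outputs.logits after a long-prompt forward pass: [batch, seq_len, vocab_size]
full_logits = torch.randn(1, 4096, 32000)

# Keep only the last position and .clone() it, so the returned tensor no longer shares
# storage with (and therefore keeps alive) the full logits buffer.
next_token_logits = full_logits[:, -1, :].clone()

# Drop the hanging reference so the big tensor can be freed before the next iteration.
del full_logits
```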
@Cyrilvallez kudos for this really high quality contribution 🤗 🚀
Your in-depth explanations are very useful, and all of it makes sense.
I think at some point we did want to change the way the cache was stored, to also remove the two transpose operations that are always needed in the modeling code (which is not the case for the JAX code).
@gante some potential break is the cache format, but I think we can add an EfficientDynamicCache class that would be used upon activation. @Cyrilvallez this would be better than having two code paths and makes more sense for isolating bugs / maintenance!
PS: do you have any gist on how to reproduce the memory benchmarks, etc.?
WDYT @gante
src/transformers/cache_utils.py
Outdated
# Whenever we have more than N new K-V value, cat() them. That way, we keep a relatively low number
# of tensors in self.key_cache[layer_idx], which is more efficient to later cat() them all, and we only
# copy a small subset into memory whenever we cat() the last N K-V states
N = 50
index = None
for i, x in enumerate(self.key_cache[layer_idx]):
    if x.shape[-2] == 1:
        index = i
        break
if index is not None and len(self.key_cache[layer_idx]) - 1 - index > N:
    self.key_cache[layer_idx] = self.key_cache[layer_idx][:index] + [
        torch.cat(self.key_cache[layer_idx][index:], dim=-2)
    ]
    self.value_cache[layer_idx] = self.value_cache[layer_idx][:index] + [
        torch.cat(self.value_cache[layer_idx][index:], dim=-2)
    ]
an interesting way to do some bucketing
src/transformers/cache_utils.py
Outdated
# Whenever we have more than N new K-V value, cat() them. That way, we keep a relatively low number
# of tensors in self.key_cache[layer_idx], which is more efficient to later cat() them all, and we only
# copy a small subset into memory whenever we cat() the last N K-V states
N = 50
can be a generation_config argument
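A hypothetical sketch of what that could look like; `cache_cat_threshold` is an invented name for illustration, not an existing `GenerationConfig` argument:

```python
from transformers import GenerationConfig

generation_config = GenerationConfig(max_new_tokens=2000)
# Hypothetical knob that would replace the hard-coded N = 50 above.
generation_config.cache_cat_threshold = 50
```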
src/transformers/generation/utils.py
Outdated
@@ -1302,7 +1302,10 @@ def _get_initial_cache_position(self, input_ids, model_kwargs):
    if isinstance(model_kwargs["past_key_values"], Cache):
        past_length = model_kwargs["past_key_values"].get_seq_length()
    else:
        past_length = model_kwargs["past_key_values"][0][0].shape[2]
        if isinstance(model_kwargs["past_key_values"][0][0], list):
            past_length = sum(x.shape[-2] for x in model_kwargs["past_key_values"][0][0])
if the past key values is an `EfficientCache` class, this should be easier to extract. + @gante using the cache_positions would be the best here!
Thanks for reviewing @ArthurZucker! Here is a link to the benchmarks I ran: https://gist.github.com/Cyrilvallez/ce1adfad1d561c1e8dc92666ab5a9e8c It is a bit messy but you should find everything. Each benchmark was run on the official version 4.40.1, saving outputs with …
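For readers who want a rough idea of how such peak-memory numbers can be taken, a hedged sketch of the measurement pattern (the linked gist contains the actual benchmark code; model name and prompt are illustrative):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1", torch_dtype=torch.float16, device_map="cuda"
)
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
inputs = tokenizer("some ~300-token prompt ...", return_tensors="pt").to("cuda")

# Reset the peak-memory counter, run generation, then read the peak back.
torch.cuda.reset_peak_memory_stats()
model.generate(**inputs, max_new_tokens=2000, do_sample=False)
print(f"peak memory: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GiB")
```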
@Cyrilvallez what do you think about creating a new cache class?
@ArthurZucker I agree that a new class would be clearer and more maintainable. Not sure what you meant by "that would be used upon activation" however? I would use the new EfficientDynamicCache as the new default for all models actually using DynamicCache by default if not provided. As I cannot think of any backward compatibility issues (even if some people save the cache to file to restart from it, the new class will work), doing otherwise would be counter-productive I think (it would be weird to have to pass an argument to activate a feature that does the same as before, but more efficiently in all cases). I don't have your experience maintaining such a widely used codebase however, so if you think people may run into problems, we could raise a warning to point them to how to use the old DynamicCache if they run into any issue?
By default we don't change anything, and just passing `cache_implementation="efficient"` …
Ok, then by default we will still benefit from removing the leak of the logits, which is already a big gain. I will make the necessary changes next Monday 👌🏻
@Cyrilvallez very cool in-depth exploration 😮🔥 And also very impactful consequences of the suggested changes! Reading the discussion, from a usage perspective, I agree with Arthur: a separate class would be ideal! With a separate cache class, we: …
I'll review the full code after it is implemented as a separate class, as we all seem aligned on what should be done 🤗 @Cyrilvallez ping me when it's ready :D
@Cyrilvallez out of curiosity: have you considered/explored expanding the cache with fixed-size blocks every time we hit the limit, similarly to paged attention?
@ArthurZucker @gante I realized yesterday that what actually creates the copies is not the current …

Doing so will avoid the back-and-forth switch from legacy to cache, and thus all references in the decoding strategy functions will point to the same … This would greatly simplify the above approach while retaining all benefits (even more as …).
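For context, the legacy round-trip being referred to looks roughly like this (a simplified, self-contained illustration with a single dummy layer):

```python
import torch
from transformers.cache_utils import DynamicCache

# Legacy format: Tuple[Tuple[key, value]] with tensors of shape [batch, heads, seq_len, head_dim]
legacy = ((torch.randn(1, 8, 10, 64), torch.randn(1, 8, 10, 64)),)

cache = DynamicCache.from_legacy_cache(legacy)  # tuple -> DynamicCache (at the start of generate)
legacy_again = cache.to_legacy_cache()          # DynamicCache -> tuple (when looping/returning)
```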
Force-pushed from b835d4d to 8095a4a
@ArthurZucker @gante The work is ready for final review! As previously said, … Models without support for … Moreover, this work paves the way to the full adoption of the new …

Finally, here is the final benchmark of performances: beam-based methods now benefit from using 3x less memory because even … Finally, as we are using …
New idea to further improve memory: #30860
Damn that's impressive! Reviewing now!
Alright! In general 🔥 looks great
- I think splitting the PR into 3 would be the best!
- PR with the clone and del of the output
- PR with the removal of the support for tuples
- PR with the new efficient cache.
Unless 2 and 3 have to go together!
@gante will be the one reviewing, and deciding so let's get his opinion! 🤗
In any case, amazing work! 🚀
@@ -382,6 +384,7 @@ def __init__(self, **kwargs):

        # Cache implementation
        self.cache_implementation = kwargs.pop("cache_implementation", None)
        self.return_legacy_cache = kwargs.pop("cache_implementation", True)
bit weird to fetch cache_implementation
I'm assuming `return_legacy_cache` should be fetched instead (and not `cache_implementation`) :)
Indeed, should be `return_legacy_cache`, this one is clearly my mistake!
src/transformers/generation/utils.py
Outdated
# New efficient cache
elif isinstance(data, DynamicCache):
    return data.split(full_batch_size, split_size)
that's a nice way to put it IMO
# We may have an initialized but empty DynamicCache during first iteration
past_exist = past_key_values is not None and not (
    isinstance(past_key_values, DynamicCache) and len(past_key_values) == 0
)
if past_exist:
if we are in generate, this should not even exist anyway and cache positions should be used!
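For context, a minimal illustration of what "using cache positions" means here (simplified, not the exact `generate()` code; numbers are made up):

```python
import torch

past_length = 12                                # e.g. past_key_values.get_seq_length()
input_ids = torch.ones(1, 5, dtype=torch.long)  # the 5 tokens fed in this forward pass

# cache_position explicitly carries which slots this pass writes, so the modeling code
# does not have to inspect the cache object to infer the past length.
cache_position = torch.arange(past_length, past_length + input_ids.shape[1])
print(cache_position)  # tensor([12, 13, 14, 15, 16])
```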
this particular piece of code is a blocker for me. It's something we should and can handle in generate no?
Very very cool developments! It's amazing to see that we could get the memory benefits without adding a new class 💛 Really good work, @Cyrilvallez 🔥
I've added a few comments, mostly related to long-term maintenance
@ArthurZucker @Cyrilvallez -- I'm happy with this being done in a single PR :) After we 3 are happy with the changes, before merging, I will run a few slow tests on my end to confirm the PR is fully backwards compatible!
src/transformers/cache_utils.py
Outdated
            self.key_cache[idx] = self.key_cache[idx][..., :maximum_length, :]
            self.value_cache[idx] = self.value_cache[idx][..., :maximum_length, :]

    def split(self, full_batch_size: int, split_size: int) -> List["DynamicCache"]:
-    def split(self, full_batch_size: int, split_size: int) -> List["DynamicCache"]:
+    def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicCache"]:
Suggestion: `split` can mean many things, as the object contains 2 lists of 4D tensors (so 6 possible "dimensions" to split). A more precise name helps with readability 🤗
src/transformers/cache_utils.py
Outdated
    def from_splits(cls, splits: List["DynamicCache"]) -> "DynamicCache":
        """This is the opposite of the above `split()` method. This will be used by `stack_model_outputs` in
-    def from_splits(cls, splits: List["DynamicCache"]) -> "DynamicCache":
-        """This is the opposite of the above `split()` method. This will be used by `stack_model_outputs` in
+    def from_batch_splits(cls, splits: List["DynamicCache"]) -> "DynamicCache":
+        """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
(continuation of the suggestion above)
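A hedged sketch of how the renamed helpers would compose once they exist under these names (shapes and sizes are illustrative):

```python
import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
# One dummy layer with batch 4, 8 heads, 6 cached tokens, head_dim 64
cache.update(torch.randn(4, 8, 6, 64), torch.randn(4, 8, 6, 64), layer_idx=0)

chunks = cache.batch_split(full_batch_size=4, split_size=2)  # -> 2 caches of batch size 2
rebuilt = DynamicCache.from_batch_splits(chunks)             # back to a single batch-4 cache
print(rebuilt.key_cache[0].shape)                            # torch.Size([4, 8, 6, 64])
```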
src/transformers/generation/utils.py
Outdated
if isinstance(past, DynamicCache) and not self._supports_dynamic_cache_class:
    raise ValueError(
        f"{self.__class__.__name__} does not support an instance of `DynamicCache` as `past_key_values`. Please "
        "check the model documentation for supported cache formats."
    )
I believe this new `if` and the `_supports_dynamic_cache_class` model attribute are redundant: all models with `self._supports_cache_class = True` should support `DynamicCache`. As such, both this `if` and the new model attribute can be removed.
The exception is Jamba, which has a custom `Cache` (and no support for the legacy class). lmk if custom logic is needed for Jamba, so we can find a solution that doesn't require a new model attribute 🤗
Ha yes, this has indeed become redundant since the clarification of the cache support attributes in 9d889f8! No need for this attribute anymore then 👌 I rebased my branch to apply the latest changes in main, and removed the attribute.
src/transformers/generation/utils.py
Outdated
# Remove potential default DynamicCache if assistant does not support it
assistant_kwargs = copy.copy(model_kwargs)
if assistant_model is not None:
    if use_dynamic_cache_by_default and not assistant_model._supports_dynamic_cache_class:
        if len(assistant_kwargs["past_key_values"]) == 0:
            del assistant_kwargs["past_key_values"]
        else:
            assistant_kwargs["past_key_values"] = assistant_kwargs["past_key_values"].to_legacy_cache()
request: can we move this logic to `AssistedCandidateGenerator.__init__`? That way, all logic regarding the initialization of the assistant attributes stays in one place :)
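A hedged sketch of what that refactor could look like, written here as a standalone helper for illustration (the function name `_prepare_assistant_cache_kwargs` is hypothetical; the real change would live inside `AssistedCandidateGenerator.__init__`):

```python
from transformers.cache_utils import DynamicCache


def _prepare_assistant_cache_kwargs(model_kwargs: dict, assistant_supports_dynamic_cache: bool) -> dict:
    # Copy the kwargs, then drop or downgrade a default DynamicCache the assistant cannot consume.
    assistant_kwargs = dict(model_kwargs)
    past = assistant_kwargs.get("past_key_values")
    if isinstance(past, DynamicCache) and not assistant_supports_dynamic_cache:
        if len(past) == 0:
            # An empty default cache carries no information for the assistant model.
            del assistant_kwargs["past_key_values"]
        else:
            # Fall back to the legacy tuple format the assistant understands.
            assistant_kwargs["past_key_values"] = past.to_legacy_cache()
    return assistant_kwargs
```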
Yes! This will indeed be much cleaner
Force-pushed from 93cd2ea to 6608872
@ArthurZucker @gante I applied all changes following your comments! Repo consistency and code quality errors do not come from files I modified (I think the code quality errors come from the ruff version now being 0.4.4, previously 0.1.x).
LGTM, thank you for iterating and making `transformers` better for all of us 💛
@Cyrilvallez I believe CI is red due to external causes; this PR should fix it after it gets merged. @ArthurZucker can you give this PR a final check? :)
Ran a few tests locally (slow tests for popular models, slow …
Having a look!
Force-pushed from 9e2968e to 3c0999b
Yep, according to @younes this should not work. I'll review and push another empty commit.
Hi @zucchini-nlp! When rebasing I noticed that in your recent #30483 you made QuantizedCache a subclass of DynamicCache. However, some code paths in …

Before, I had added a few methods (…). This means that if …

If this PR gets merged, overwriting the 5 methods mentioned above should make …
@ArthurZucker Rebasing is done and all CIs are green!
@Cyrilvallez Right, QuantizedCache stores most of the past kv in a private list, so I think these methods would not work even before your changes. Thanks for noticing! I will think about whether we want to make the quantized cache compatible with these generation strategies, and otherwise raise an error when a user tries to go down that path. And amazing work in this PR for making generation better! 🔥
Yes, it will fail in the current state of the library, but adding support will be straightforward and won't require messing with `generation.utils` after this is merged!
Reviewing today, sorry for the delay!
Hi @ArthurZucker, any news? 😁
Super excited to see this get merged, to hopefully resolve #30019. Thank you all for your hard work on this 🚀
LGTM apart from the modeling modifications!
@@ -389,6 +389,57 @@ def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTens
                cache.update(key_states, value_states, layer_idx)
        return cache

    def crop(self, maximum_length: int):
We need to update the docstring of the class to explain why we have these methods! 🤗
class DynamicCache(Cache):
    """
    A cache that grows dynamically as more tokens are generated. This is the default for generative models.
    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
    `[batch_size, num_heads, seq_len, head_dim]`.
    """
this needs to be updated!
# We may have an initialized but empty DynamicCache during first iteration
past_exist = past_key_values is not None and not (
    isinstance(past_key_values, DynamicCache) and len(past_key_values) == 0
)
if past_exist:
this particular piece of code is a blocker for me. It's something we should and can handle in generate no?
@ArthurZucker The modeling update was mostly used in the case … Should be good now!
Okay! 🚀 a last nit and let's merge
@ArthurZucker With how things evolved in the end, the docstring for …
I am thinking more about explaining that we do some bucketing internally.
The bucketing part was removed in the end because it added unnecessary complexity to the code, and we could get all the memory benefits without it! So I don't think that the docstring needs any more details (the …
Okay! Merging then! Congrats and great work 🤗 🚀
What does this PR do?
Change the data structure and implementation of `DynamicCache` to halve the memory requirement of `generate()`, with no noticeable speed degradation.

Reason
I was working on precise memory estimation of `generate()` and noticed that the expected memory peak and the observed one are different (the observed one is much higher). I was able to track the problem down to the current implementation of `DynamicCache`.

Since `torch.cat()` creates a copy, and since during `generate` we loop over each token using the previous cache in the inputs (the old cache is referenced in the inputs, so it cannot be garbage collected), the current implementation basically keeps 2 copies of the full cache in every iteration. By changing the data structure from `Tuple[Tuple[Tensor]]` to `Tuple[Tuple[List[Tensor]]]`, I was able to avoid this copy and reduce the memory footprint of `generate()` by 2.
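A minimal illustration (not the transformers code itself) of the two update patterns this paragraph contrasts:

```python
import torch

new_k = torch.randn(1, 8, 1, 64)  # key states for one new token

# Legacy-style update: cat() allocates a brand-new tensor. If the previous tensor is still
# referenced elsewhere (e.g. inside the inputs of the generation loop), both copies stay alive.
key_cache = torch.zeros(1, 8, 0, 64)
key_cache = torch.cat([key_cache, new_k], dim=-2)

# List-style update: appending only stores a reference, no copy is made. The full tensor is
# only materialized when the attention layer actually needs it.
key_chunks = [new_k]
key_chunks.append(torch.randn(1, 8, 1, 64))
full_key = torch.cat(key_chunks, dim=-2)
```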
However, that means that potentially thousands of tensors must be `cat()`'ed when `cache.update()` is called to feed the correct tensor to the `AttentionLayer`. This results in a speed degradation. I was able to mitigate this issue by periodically `cat()`'ing the tensors in the cache when there are more than `N` of them (I used 50 as of now), which incurs a negligible memory increase, as the cache for a sequence length of `N=50` is usually completely negligible compared to the size of the full cache (hundreds, or even more frequently thousands of tokens). This strategy almost completely removes the speed degradation, allowing us to get the best of both worlds. `N` could even be chosen dynamically in `generate()` depending on the input length and max new tokens, but 50 seemed like a good heuristic to start with.

Basically, at a very small performance penalty that is visible only for very large sequence lengths, we reduce the memory footprint by 2, which by itself allows increasing the batch size by 2 (so it should actually be possible to speed up the process, as passing a batch of twice the size at lower speed should still be faster than running 2 loops at a faster speed).
Of course, the best would still be to use a `StaticCache`, as I saw you have started implementing, but a `DynamicCache` is still very much useful, and using it in a loop referencing itself should not imply doubling its effective memory footprint.

Benchmark
Here you can see the benchmark I ran using an RTX 4090 (24 GB) and Mistral 7B, Llama2 7B and Llama3 8B.
Fixed batch size of 1 and input length of 300, variable number of new tokens
Fixed number of new tokens of 2000 and input length of 300, variable batch size
Integration in Transformers
As I am changing the data structure of the `DynamicCache`, I had to modify how it is used in the model modeling code. For now, I only modified `LlamaForCausalLM` and `MistralForCausalLM`, which use the `DynamicCache` by default, to test my implementation. Also, the change of data structure may have impacts elsewhere that I overlooked (if other code relies on attributes of the cache, e.g. a call to `cache[0][0].shape` would be an `AttributeError` now). If so, do not hesitate to point me towards these parts and I can modify them.

I would be happy to help integrate the change in all models if you decide to move forward with the PR.
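A small self-contained illustration of the access-pattern change mentioned above (dummy shapes):

```python
import torch

# Old structure: Tuple[Tuple[Tensor]] -- cache[0][0] is a tensor
old_cache = ((torch.randn(1, 8, 10, 64), torch.randn(1, 8, 10, 64)),)
print(old_cache[0][0].shape)  # torch.Size([1, 8, 10, 64])

# New structure: Tuple[Tuple[List[Tensor]]] -- cache[0][0] is now a list of tensors,
# so code calling cache[0][0].shape would raise an AttributeError.
new_cache = (([torch.randn(1, 8, 10, 64)], [torch.randn(1, 8, 10, 64)]),)
print(sum(t.shape[-2] for t in new_cache[0][0]))  # 10 -- total cached sequence length
```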
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
@ArthurZucker and @younesbelkada and @gante