Skip to content

Commit

Permalink
Cache docs: update (#32929)
Browse files Browse the repository at this point in the history
* some changes

* more updates

* fix cache copy

* nits

* nits

* add tests
  • Loading branch information
zucchini-nlp authored Sep 4, 2024
1 parent 35f72eb commit ebbe8d8
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 39 deletions.
54 changes: 44 additions & 10 deletions docs/source/en/kv_cache.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@ Effective caching helps reduce computation time and improve response rates, espe

Transformers support various caching methods, leveraging "Cache" classes to abstract and manage the caching logic.
This document outlines best practices for using these classes to maximize performance and efficiency.
Check out all the available `Cache` classes in the [API documentation](./internal/generation_utils.md).
Check out all the available `Cache` classes in the [API documentation](./internal/generation_utils).

## What is Cache and why we should care?

Imagine you’re having a conversation with someone, and instead of remembering what was said previously, you have to start from scratch every time you respond. This would be slow and inefficient, right? In the world of Transformer models, a similar concept applies, and that's where Caching keys and values come into play. From now on, I'll refer to the concept as KV Cache.

KV cache is needed to optimize the generation in autoregressive models, where the model predicts text token by token. This process can be slow since the model can generate only one token at a time, and each new prediction is dependent on the previous context. That means, to predict token number 1000 in the generation, you need information from the previous 999 tokens, which comes in the form of some matrix multiplications across the representations of those tokens. But to predict token number 1001, you also need the same information from the first 999 tokens, plus additional information from token number 1000. That is where key-value cache is used to optimize the sequential generation process by storing previous calculations to reuse in subsequent tokens, so they don't need to be computed again.

More concretely, key-value cache acts as a memory bank for these generative models, where the model stores key-value pairs derived from self-attention layers for previously processed tokens. By storing this information, the model can avoid redundant computations and instead retrieve keys and values of previous tokens from the cache.
More concretely, key-value cache acts as a memory bank for these generative models, where the model stores key-value pairs derived from self-attention layers for previously processed tokens. By storing this information, the model can avoid redundant computations and instead retrieve keys and values of previous tokens from the cache. Note that caching can be used only in inference and should be disabled when training, otherwise it might cause unexpected errors.

<details>
<summary><em>For the Curious Minds Who Like to Dive Deep</em></summary>
Expand Down Expand Up @@ -94,7 +94,7 @@ More concretely, key-value cache acts as a memory bank for these generative mode
In 🤗 Transformers, we support various Cache types to optimize the performance across different models and tasks. By default, all models generate with caching,
with the [`~DynamicCache`] class being the default cache for most models. It allows us to dynamically grow cache size, by saving more and more keys and values as we generate. If for some reason you don't want to use caches, you can pass `use_cache=False` into the `generate()` method.

Refer to the table below to see the difference between cache types and choose the one that suits best for your use-case.
Refer to the table below to see the difference between cache types and choose the one that suits best for your use-case. Models for which initialization is recommended should be initialized before calling the model and passed to model as a kwarg. In all other cases you can simply define desired `cache_implementation` and we take care of the rest for you.

| Cache Type | Memory Efficient | Supports torch.compile() | Initialization Recommended | Latency | Long Context Generation |
|------------------------|------------------|--------------------------|----------------------------|---------|-------------------------|
Expand All @@ -107,7 +107,7 @@ Refer to the table below to see the difference between cache types and choose th
| Sink Cache | Yes | No | Yes | Mid | Yes |


These cache classes can be set with a `cache_implementation` argument when generating. To learn about the available options for the cache_implementation flag, please refer to the [API Documentation](./main_classes/text_generation.md#transformers.GenerationConfig). Now, let's explore each cache type in detail and see how to use them. Note that the below examples are for decoder-only Tranformer-based models. We also support ["Model-Specific Cache"] classes for models such as Mamba or Jamba, keep reading for more details.
These cache classes can be set with a `cache_implementation` argument when generating. To learn about the available options for the cache_implementation flag, please refer to the [API Documentation](./main_classes/text_generation#transformers.GenerationConfig). Now, let's explore each cache type in detail and see how to use them. Note that the below examples are for decoder-only Tranformer-based models. We also support ["Model-Specific Cache"] classes for models such as Mamba or Jamba, keep reading for more details.

### Quantized Cache

Expand All @@ -120,6 +120,8 @@ To enable quantization of the key-value cache, one needs to indicate `cache_impl
Quantization related arguments should be passed to the `generation_config` either as a `dict` or an instance of a [`~QuantizedCacheConfig`] class.
One has to indicate which quantization backend to use in the [`~QuantizedCacheConfig`], the default is `quanto`.

It is recommended to set `axis-key/axis-value` parameters in the cache config to `0` if you're using the `quanto` backend and to `1` if you're using the `HQQ` backend. For other config values, please use the defaults unless you're running out of memory. In that case, you may consider decreasing the residual length.

<Tip warning={true}>

Cache quantization can be detrimental in terms of latency if the context length is short and there is enough GPU VRAM available to run without cache quantization. It is recommended to seek balance between memory efficiency and latency.
Expand All @@ -143,7 +145,7 @@ I like rock music because it's loud and energetic. It's a great way to express m
I like rock music because it's loud and energetic. I like to listen to it when I'm feeling
```

## Offloaded Cache
### Offloaded Cache

Similarly to KV cache quantization, [`~OffloadedCache`] strategy aims to reduce GPU VRAM usage.
It does so by moving the KV cache for most layers to the CPU.
Expand Down Expand Up @@ -223,7 +225,7 @@ before successfully generating 40 beams.
Since the "DynamicCache" dynamically grows with each generation step, it prevents you from taking advantage of JIT optimizations. The [`~StaticCache`] pre-allocates
a specific maximum size for the keys and values, allowing you to generate up to the maximum length without having to modify cache size. Check the below usage example.

For more examples with Static Cache and JIT compilation, take a look at [StaticCache & torchcompile](./llm_optims.md#static-kv-cache-and-torchcompile)
For more examples with Static Cache and JIT compilation, take a look at [StaticCache & torchcompile](./llm_optims#static-kv-cache-and-torchcompile)

```python
>>> import torch
Expand Down Expand Up @@ -306,21 +308,21 @@ Unlike other cache classes, this one can't be used directly by indicating a `cac

### Encoder-Decoder Cache

The [`~EncoderDecoderCache`] is a wrapper designed to handle the caching needs of encoder-decoder models. This cache type is specifically built to manage both self-attention and cross-attention caches, ensuring storage and retrieval of past key/values required for these complex models. Cool thing about Encoder-Decoder Cache is that you can set different cache types for the encoder and for the decoder, depending on your use case. Currently this cache is only supported in [Whisper](./model_doc/whisper.md) models but we will be adding more models soon.
The [`~EncoderDecoderCache`] is a wrapper designed to handle the caching needs of encoder-decoder models. This cache type is specifically built to manage both self-attention and cross-attention caches, ensuring storage and retrieval of past key/values required for these complex models. Cool thing about Encoder-Decoder Cache is that you can set different cache types for the encoder and for the decoder, depending on your use case. Currently this cache is only supported in [Whisper](./model_doc/whisper) models but we will be adding more models soon.

In terms of usage, there is nothing special to be done and calling `generate()` or `forward()` will handle everything for you.


### Model-specific Cache Classes

Some models require storing previous keys, values, or states in a specific way, and the above cache classes cannot be used. For such cases, we have several specialized cache classes that are designed for specific models. These models only accept their own dedicated cache classes and do not support using any other cache types. Some examples include [`~HybridCache`] for [Gemma2](./model_doc/gemma2.md) series models or [`~MambaCache`] for [Mamba](./model_doc/mamba.md) architecture models.
Some models require storing previous keys, values, or states in a specific way, and the above cache classes cannot be used. For such cases, we have several specialized cache classes that are designed for specific models. These models only accept their own dedicated cache classes and do not support using any other cache types. Some examples include [`~HybridCache`] for [Gemma2](./model_doc/gemma2) series models or [`~MambaCache`] for [Mamba](./model_doc/mamba) architecture models.


## Iterative Generation with Cache

We have seen how to use each of the cache types when generating. What if you want to use cache in iterative generation setting, for example in applications like chatbots, where interactions involve multiple turns and continuous back-and-forth exchanges. Iterative generation with cache allows these systems to handle ongoing conversations effectively without reprocessing the entire context at each step. But there are some tips that you should know before you start implementing:

The general format when doing iterative generation is as below. First you have to initialize an empty cache of the type you want, and you can start feeding in new prompts iteratively. Keeping track of dialogues history and formatting can be done with chat templates, read more on that in [chat_templating](./chat_templating.md)
The general format when doing iterative generation is as below. First you have to initialize an empty cache of the type you want, and you can start feeding in new prompts iteratively. Keeping track of dialogues history and formatting can be done with chat templates, read more on that in [chat_templating](./chat_templating)

In case you are using Sink Cache, you have to crop your inputs to that maximum length because Sink Cache can generate text longer than its maximum window size, but it expects the first input to not exceed the maximum cache length.

Expand Down Expand Up @@ -366,4 +368,36 @@ print(messages)

## Re-use Cache to continue generation

Sometimes you would want to fist fill-in cache object with key/values for certain prefix prompt and re-use it several times to generate different sequences from it. We are working hard on adding this feature to 🤗 Transformers and will update this section soon.
Sometimes you would want to first fill-in cache object with key/values for certain prefix prompt and re-use it several times to generate different sequences from it. In that case you can construct a `Cache` object that will hold the instruction prompt, and re-use it several times with different text sequences.

```python
>>> import copy
>>> import torch
>>> from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache, StaticCache

>>> model_id = "meta-llama/Llama-2-7b-chat-hf"
>>> model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="cuda")
>>> tokenizer = AutoTokenizer.from_pretrained(model_id)

>>> # Init StaticCache with big enough max-length (1024 tokens for the below example)
>>> # You can also init a DynamicCache, if that suits you better
>>> prompt_cache = StaticCache(config=model.config, max_batch_size=1, max_cache_len=1024, device="cuda", dtype=torch.bfloat16)

>>> INITIAL_PROMPT = "You are a helpful assistant. "
>>> inputs_initial_prompt = tokenizer(INITIAL_PROMPT, return_tensors="pt").to("cuda")
>>> # This is the common prompt cached, we need to run forward without grad to be abel to copy
>>> with torch.no_grad():
... prompt_cache = model(**inputs_initial_prompt, past_key_values = prompt_cache).past_key_values

>>> prompts = ["Help me to write a blogpost about travelling.", "What is the capital of France?"]
>>> responses = []
>>> for prompt in prompts:
... new_inputs = tokenizer(INITIAL_PROMPT + prompt, return_tensors="pt").to("cuda")
... past_key_values = copy.deepcopy(prompt_cache)
... outputs = model.generate(**new_inputs, past_key_values=past_key_values,max_new_tokens=20)
... response = tokenizer.batch_decode(outputs)[0]
... responses.append(response)

>>> print(responses)
['<s> You are a helpful assistant. Help me to write a blogpost about travelling.\n\nTitle: The Ultimate Guide to Travelling: Tips, Tricks, and', '<s> You are a helpful assistant. What is the capital of France?\n\nYes, the capital of France is Paris.</s>']
```
Loading

0 comments on commit ebbe8d8

Please sign in to comment.