From ebbe8d8014900e16876938c3e4a0d15c6f8acb67 Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Wed, 4 Sep 2024 12:05:31 +0200 Subject: [PATCH] Cache docs: update (#32929) * some changes * more updates * fix cache copy * nits * nits * add tests --- docs/source/en/kv_cache.md | 54 +++++++++++++++++++++----- src/transformers/cache_utils.py | 67 +++++++++++++++++++-------------- tests/utils/test_cache_utils.py | 32 ++++++++++++++++ 3 files changed, 114 insertions(+), 39 deletions(-) diff --git a/docs/source/en/kv_cache.md b/docs/source/en/kv_cache.md index 1ae97497d2ff..be566437a347 100644 --- a/docs/source/en/kv_cache.md +++ b/docs/source/en/kv_cache.md @@ -22,7 +22,7 @@ Effective caching helps reduce computation time and improve response rates, espe Transformers support various caching methods, leveraging "Cache" classes to abstract and manage the caching logic. This document outlines best practices for using these classes to maximize performance and efficiency. -Check out all the available `Cache` classes in the [API documentation](./internal/generation_utils.md). +Check out all the available `Cache` classes in the [API documentation](./internal/generation_utils). ## What is Cache and why we should care? @@ -30,7 +30,7 @@ Imagine you’re having a conversation with someone, and instead of remembering KV cache is needed to optimize the generation in autoregressive models, where the model predicts text token by token. This process can be slow since the model can generate only one token at a time, and each new prediction is dependent on the previous context. That means, to predict token number 1000 in the generation, you need information from the previous 999 tokens, which comes in the form of some matrix multiplications across the representations of those tokens. But to predict token number 1001, you also need the same information from the first 999 tokens, plus additional information from token number 1000. That is where key-value cache is used to optimize the sequential generation process by storing previous calculations to reuse in subsequent tokens, so they don't need to be computed again. -More concretely, key-value cache acts as a memory bank for these generative models, where the model stores key-value pairs derived from self-attention layers for previously processed tokens. By storing this information, the model can avoid redundant computations and instead retrieve keys and values of previous tokens from the cache. +More concretely, key-value cache acts as a memory bank for these generative models, where the model stores key-value pairs derived from self-attention layers for previously processed tokens. By storing this information, the model can avoid redundant computations and instead retrieve keys and values of previous tokens from the cache. Note that caching can be used only in inference and should be disabled when training, otherwise it might cause unexpected errors.
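To make the mechanics above concrete, here is a small, self-contained PyTorch sketch (toy tensors and projections only, not the actual Transformers implementation): at every generation step the key/value projection is computed just for the newest token and appended to the cache, so attention for that token can reuse everything already stored for earlier tokens.

```python
# Toy illustration of key-value caching: all names and shapes below are made up for the example.
import torch

head_dim = 8
w_q = torch.randn(head_dim, head_dim)  # toy query projection
w_k = torch.randn(head_dim, head_dim)  # toy key projection
w_v = torch.randn(head_dim, head_dim)  # toy value projection

cached_keys = torch.empty(0, head_dim)    # grows by one row per processed token
cached_values = torch.empty(0, head_dim)

for step in range(5):
    hidden_state = torch.randn(1, head_dim)  # representation of the current token only
    # Project only the newest token and append it to the cache.
    cached_keys = torch.cat([cached_keys, hidden_state @ w_k], dim=0)
    cached_values = torch.cat([cached_values, hidden_state @ w_v], dim=0)
    # Attention for the current token attends over all cached keys/values,
    # so nothing from previous steps has to be recomputed.
    attn_weights = torch.softmax((hidden_state @ w_q) @ cached_keys.T / head_dim**0.5, dim=-1)
    attn_output = attn_weights @ cached_values  # shape: (1, head_dim)
    print(f"step {step}: cache holds {cached_keys.shape[0]} key/value rows")
```

Without the cache, the keys and values of all previous tokens would have to be recomputed at every step, which is exactly the redundant work described above.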
For the Curious Minds Who Like to Dive Deep
@@ -94,7 +94,7 @@ More concretely, key-value cache acts as a memory bank for these generative mode
In 🤗 Transformers, we support various Cache types to optimize the performance across different models and tasks. By default, all models generate with caching, with the [`~DynamicCache`] class being the default cache for most models. It allows us to dynamically grow cache size, by saving more and more keys and values as we generate. If for some reason you don't want to use caches, you can pass `use_cache=False` into the `generate()` method.

-Refer to the table below to see the difference between cache types and choose the one that suits best for your use-case.
+Refer to the table below to see the difference between cache types and choose the one that best suits your use case. Cache types for which initialization is recommended should be initialized before calling the model and passed to it as a kwarg. In all other cases you can simply define the desired `cache_implementation` and we take care of the rest for you.

| Cache Type | Memory Efficient | Supports torch.compile() | Initialization Recommended | Latency | Long Context Generation |
|------------------------|------------------|--------------------------|----------------------------|---------|-------------------------|
@@ -107,7 +107,7 @@ Refer to the table below to see the difference between cache types and choose th
| Sink Cache | Yes | No | Yes | Mid | Yes |

-These cache classes can be set with a `cache_implementation` argument when generating. To learn about the available options for the cache_implementation flag, please refer to the [API Documentation](./main_classes/text_generation.md#transformers.GenerationConfig). Now, let's explore each cache type in detail and see how to use them. Note that the below examples are for decoder-only Tranformer-based models. We also support ["Model-Specific Cache"] classes for models such as Mamba or Jamba, keep reading for more details.
+These cache classes can be set with a `cache_implementation` argument when generating. To learn about the available options for the `cache_implementation` flag, please refer to the [API Documentation](./main_classes/text_generation#transformers.GenerationConfig). Now, let's explore each cache type in detail and see how to use them. Note that the below examples are for decoder-only Transformer-based models. We also support "Model-Specific Cache" classes for models such as Mamba or Jamba; keep reading for more details.

### Quantized Cache
@@ -120,6 +120,8 @@ To enable quantization of the key-value cache, one needs to indicate `cache_impl
Quantization related arguments should be passed to the `generation_config` either as a `dict` or an instance of a [`~QuantizedCacheConfig`] class. One has to indicate which quantization backend to use in the [`~QuantizedCacheConfig`], the default is `quanto`.

+It is recommended to set the `axis_key`/`axis_value` parameters in the cache config to `0` if you're using the `quanto` backend and to `1` if you're using the `HQQ` backend. For other config values, please use the defaults unless you're running out of memory. In that case, you may consider decreasing the residual length.
+
Cache quantization can be detrimental in terms of latency if the context length is short and there is enough GPU VRAM available to run without cache quantization. It is recommended to seek balance between memory efficiency and latency.

@@ -143,7 +145,7 @@ I like rock music because it's loud and energetic. 
It's a great way to express m I like rock music because it's loud and energetic. I like to listen to it when I'm feeling ``` -## Offloaded Cache +### Offloaded Cache Similarly to KV cache quantization, [`~OffloadedCache`] strategy aims to reduce GPU VRAM usage. It does so by moving the KV cache for most layers to the CPU. @@ -223,7 +225,7 @@ before successfully generating 40 beams. Since the "DynamicCache" dynamically grows with each generation step, it prevents you from taking advantage of JIT optimizations. The [`~StaticCache`] pre-allocates a specific maximum size for the keys and values, allowing you to generate up to the maximum length without having to modify cache size. Check the below usage example. -For more examples with Static Cache and JIT compilation, take a look at [StaticCache & torchcompile](./llm_optims.md#static-kv-cache-and-torchcompile) +For more examples with Static Cache and JIT compilation, take a look at [StaticCache & torchcompile](./llm_optims#static-kv-cache-and-torchcompile) ```python >>> import torch @@ -306,21 +308,21 @@ Unlike other cache classes, this one can't be used directly by indicating a `cac ### Encoder-Decoder Cache -The [`~EncoderDecoderCache`] is a wrapper designed to handle the caching needs of encoder-decoder models. This cache type is specifically built to manage both self-attention and cross-attention caches, ensuring storage and retrieval of past key/values required for these complex models. Cool thing about Encoder-Decoder Cache is that you can set different cache types for the encoder and for the decoder, depending on your use case. Currently this cache is only supported in [Whisper](./model_doc/whisper.md) models but we will be adding more models soon. +The [`~EncoderDecoderCache`] is a wrapper designed to handle the caching needs of encoder-decoder models. This cache type is specifically built to manage both self-attention and cross-attention caches, ensuring storage and retrieval of past key/values required for these complex models. Cool thing about Encoder-Decoder Cache is that you can set different cache types for the encoder and for the decoder, depending on your use case. Currently this cache is only supported in [Whisper](./model_doc/whisper) models but we will be adding more models soon. In terms of usage, there is nothing special to be done and calling `generate()` or `forward()` will handle everything for you. ### Model-specific Cache Classes -Some models require storing previous keys, values, or states in a specific way, and the above cache classes cannot be used. For such cases, we have several specialized cache classes that are designed for specific models. These models only accept their own dedicated cache classes and do not support using any other cache types. Some examples include [`~HybridCache`] for [Gemma2](./model_doc/gemma2.md) series models or [`~MambaCache`] for [Mamba](./model_doc/mamba.md) architecture models. +Some models require storing previous keys, values, or states in a specific way, and the above cache classes cannot be used. For such cases, we have several specialized cache classes that are designed for specific models. These models only accept their own dedicated cache classes and do not support using any other cache types. Some examples include [`~HybridCache`] for [Gemma2](./model_doc/gemma2) series models or [`~MambaCache`] for [Mamba](./model_doc/mamba) architecture models. ## Iterative Generation with Cache We have seen how to use each of the cache types when generating. 
What if you want to use the cache in an iterative generation setting, for example in applications like chatbots, where interactions involve multiple turns and continuous back-and-forth exchanges? Iterative generation with a cache allows these systems to handle ongoing conversations effectively without reprocessing the entire context at each step. But there are some tips that you should know before you start implementing:

-The general format when doing iterative generation is as below. First you have to initialize an empty cache of the type you want, and you can start feeding in new prompts iteratively. Keeping track of dialogues history and formatting can be done with chat templates, read more on that in [chat_templating](./chat_templating.md)
+The general format when doing iterative generation is as below. First you have to initialize an empty cache of the type you want, and you can start feeding in new prompts iteratively. Keeping track of the dialogue history and formatting can be done with chat templates; read more on that in [chat_templating](./chat_templating).

In case you are using Sink Cache, you have to crop your inputs to that maximum length because Sink Cache can generate text longer than its maximum window size, but it expects the first input to not exceed the maximum cache length.

@@ -366,4 +368,36 @@ print(messages)
## Re-use Cache to continue generation

-Sometimes you would want to fist fill-in cache object with key/values for certain prefix prompt and re-use it several times to generate different sequences from it. We are working hard on adding this feature to 🤗 Transformers and will update this section soon.
+Sometimes you may want to first fill a cache object with key/values for a certain prefix prompt and re-use it several times to generate different sequences from it. In that case you can construct a `Cache` object that holds the instruction prompt, and re-use it several times with different text sequences.
+
+```python
+>>> import copy
+>>> import torch
+>>> from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache, StaticCache
+
+>>> model_id = "meta-llama/Llama-2-7b-chat-hf"
+>>> model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="cuda")
+>>> tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+>>> # Init StaticCache with a big enough max length (1024 tokens for the below example)
+>>> # You can also init a DynamicCache, if that suits you better
+>>> prompt_cache = StaticCache(config=model.config, max_batch_size=1, max_cache_len=1024, device="cuda", dtype=torch.bfloat16)
+
+>>> INITIAL_PROMPT = "You are a helpful assistant. "
+>>> inputs_initial_prompt = tokenizer(INITIAL_PROMPT, return_tensors="pt").to("cuda")
+>>> # This is the common prompt cached; we need to run forward without grad to be able to copy it
+>>> with torch.no_grad():
+...     prompt_cache = model(**inputs_initial_prompt, past_key_values=prompt_cache).past_key_values
+
+>>> prompts = ["Help me to write a blogpost about travelling.", "What is the capital of France?"]
+>>> responses = []
+>>> for prompt in prompts:
+...     new_inputs = tokenizer(INITIAL_PROMPT + prompt, return_tensors="pt").to("cuda")
+...     past_key_values = copy.deepcopy(prompt_cache)
+...     outputs = model.generate(**new_inputs, past_key_values=past_key_values, max_new_tokens=20)
+...     response = tokenizer.batch_decode(outputs)[0]
+...     responses.append(response)
+
+>>> print(responses)
+[' You are a helpful assistant. 
Help me to write a blogpost about travelling.\n\nTitle: The Ultimate Guide to Travelling: Tips, Tricks, and', ' You are a helpful assistant. What is the capital of France?\n\nYes, the capital of France is Paris.'] +``` \ No newline at end of file diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 80c36b9f68ee..3c157018ecd2 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -305,15 +305,16 @@ class DynamicCache(Cache): ```python >>> from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache - >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2") - >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") - >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt") + >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt") >>> # Prepare a cache class and pass it to model's forward >>> past_key_values = DynamicCache() >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation + >>> outputs.past_key_values # access cache filled with key/values from generation + DynamicCache() ``` """ @@ -680,16 +681,17 @@ class QuantoQuantizedCache(QuantizedCache): >>> # Run pip install quanto first if you don't have it yet >>> from transformers import AutoTokenizer, AutoModelForCausalLM, QuantoQuantizedCache, QuantizedCacheConfig - >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2") - >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") - >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt") + >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt") >>> # Prepare a cache class and pass it to model's forward >>> cache_config = QuantizedCacheConfig(nbits=4) >>> past_key_values = QuantoQuantizedCache(cache_config=cache_config) >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation + >>> outputs.past_key_values # access cache filled with key/values from generation + QuantoQuantizedCache() ``` """ @@ -739,16 +741,17 @@ class HQQQuantizedCache(QuantizedCache): >>> # Run pip install hqq first if you don't have it yet >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HQQQuantizedCache, QuantizedCacheConfig - >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2") - >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") - >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt") + >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt") >>> # Prepare a cache class and pass it to model's forward >>> cache_config = QuantizedCacheConfig(nbits=4, axis_key=1, axis_value=1) >>> past_key_values = HQQQuantizedCache(cache_config=cache_config) >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> past_kv_length = 
outputs.past_key_values # access cache filled with key/values from generation + >>> outputs.past_key_values # access cache filled with key/values from generation + HQQQuantizedCache() ``` """ @@ -806,15 +809,16 @@ class SinkCache(Cache): ```python >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SinkCache - >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2") - >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") - >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt") + >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt") >>> # Prepare a cache class and pass it to model's forward >>> past_key_values = SinkCache(window_length=256, num_sink_tokens=4) >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation + >>> outputs.past_key_values # access cache filled with key/values from generation + SinkCache() ``` """ @@ -992,17 +996,18 @@ class StaticCache(Cache): ```python >>> from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache - >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2") - >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + >>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf") - >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt") + >>> inputs = tokenizer(text="My name is Llama", return_tensors="pt") >>> # Prepare a cache class and pass it to model's forward >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate >>> max_generated_length = inputs.input_ids.shape[1] + 10 >>> past_key_values = StaticCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation + >>> outputs.past_key_values # access cache filled with key/values from generation + StaticCache() ``` """ @@ -1161,17 +1166,18 @@ class SlidingWindowCache(StaticCache): ```python >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SlidingWindowCache - >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2") - >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + >>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3") + >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3") - >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt") + >>> inputs = tokenizer(text="My name is Mistral", return_tensors="pt") >>> # Prepare a cache class and pass it to model's forward >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate >>> max_generated_length = inputs.input_ids.shape[1] + 10 >>> past_key_values = SlidingWindowCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> past_kv_length = 
outputs.past_key_values # access cache filled with key/values from generation + >>> outputs.past_key_values # access cache filled with key/values from generation + SlidingWindowCache() ``` """ @@ -1281,7 +1287,8 @@ class EncoderDecoderCache(Cache): >>> cross_attention_cache = DynamicCache() >>> past_key_values = EncoderDecoderCache(self_attention_cache, cross_attention_cache) >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation + >>> outputs.past_key_values # access cache filled with key/values from generation + EncoderDecoderCache() ``` """ @@ -1453,8 +1460,8 @@ class HybridCache(Cache): ```python >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache - >>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-9b") - >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b") + >>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b") + >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b") >>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt") @@ -1463,7 +1470,8 @@ class HybridCache(Cache): >>> max_generated_length = inputs.input_ids.shape[1] + 10 >>> past_key_values = HybridCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation + >>> outputs.past_key_values # access cache filled with key/values from generation + HybridCache() ``` """ @@ -1645,7 +1653,8 @@ class MambaCache: >>> # Prepare a cache class and pass it to model's forward >>> past_key_values = MambaCache(config=model.config, batch_size=1, device=model.device, dtype=model.dtype) >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> past_kv = outputs.past_key_values + >>> outputs.past_key_values + MambaCache() ``` """ diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index 0bb604c96f8c..fb5459be10cb 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import unittest from packaging import version @@ -616,3 +617,34 @@ def test_offloaded_cache_uses_less_memory_than_dynamic_cache(self): model.generate(generation_config=offloaded, **inputs) offloaded_peak_memory = torch.cuda.max_memory_allocated(device) assert offloaded_peak_memory < original_peak_memory + + @require_torch_gpu + def test_cache_copy(self): + model_name = "microsoft/Phi-3-mini-4k-instruct" + tokenizer = AutoTokenizer.from_pretrained(model_name) + model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cuda", torch_dtype=torch.bfloat16) + + prompt_cache = StaticCache( + config=model.config, max_batch_size=1, max_cache_len=1024, device="cuda", dtype=torch.bfloat16 + ) + + INITIAL_PROMPT = "You are a helpful assistant. 
" + inputs_initial_prompt = tokenizer(INITIAL_PROMPT, return_tensors="pt").to("cuda") + # This is the common prompt cached, we need to run forward without grad to be abel to copy + with torch.no_grad(): + prompt_cache = model(**inputs_initial_prompt, past_key_values=prompt_cache).past_key_values + + prompts = ["Help me to write a blogpost about travelling.", "What is the capital of France?"] + responses = [] + for prompt in prompts: + new_inputs = tokenizer(INITIAL_PROMPT + prompt, return_tensors="pt").to("cuda") + past_key_values = copy.deepcopy(prompt_cache) + outputs = model.generate(**new_inputs, past_key_values=past_key_values, max_new_tokens=40) + response = tokenizer.batch_decode(outputs)[0] + responses.append(response) + + EXPECTED_DECODED_TEXT = [ + "You are a helpful assistant. Help me to write a blogpost about travelling.\n\nTraveling is an enriching experience that broadens our horizons and exposes us to new cultures, landscapes, and people. Whether it's a week", + 'You are a helpful assistant. What is the capital of France?\n\n\n## Response:Paris is the capital of France.\n\n\n\n\n\n## Query:\n\nIn a detailed analysis, compare the economic impacts of the introduction of the' + ] # fmt: skip + self.assertTrue(responses == EXPECTED_DECODED_TEXT)