
Iterative generation using Input embeds and past_key_values #35890

Merged

Conversation

@yaswanth19 (Contributor) commented on Jan 25, 2025

What does this PR do?

Fixes #34678 #35707
Logic: if a cache is passed together with inputs_embeds, use the inputs_embeds to generate the first token of every new prompt, rather than only on the very first call that fills the cache.
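In rough terms, the fix amounts to the condition quoted in the review below; here is a minimal sketch of the idea (the helper name is made up for illustration only, the real change lives inside prepare_inputs_for_generation):

def keep_uncached_inputs_embeds(inputs_embeds, input_ids, cache_position):
    # Sketch: when generation continues from a cache and the caller supplied
    # inputs_embeds, input_ids is empty for the new positions, so keep only the
    # embeddings that are not yet covered by the cache.
    if inputs_embeds is not None and input_ids.shape[1] == 0:
        inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
    return inputs_embeds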

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@zucchini-nlp

@yaswanth19 (Contributor, Author) commented on Jan 25, 2025

The code I am using to check the feature branch:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache, StaticCache

model_id = "microsoft/phi-2"
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

model.generation_config.max_new_tokens = 30

prompt_cache = StaticCache(config=model.config, batch_size=1, max_cache_len=1000)

INITIAL_PROMPT = "You are a helpful assistant. "
inputs_initial_prompt = tokenizer(INITIAL_PROMPT, return_tensors="pt")

inputs_embeds = model.get_input_embeddings()(inputs_initial_prompt.input_ids)
outputs = model.generate(inputs_embeds=inputs_embeds, past_key_values=prompt_cache)

response = tokenizer.batch_decode(outputs)[0]
print(response)

prompts = ["Help me to write a blogpost about travelling.", "Write a short note on AI"]
responses = []
for prompt in prompts:
    new_inputs = tokenizer(prompt, return_tensors="pt")
    new_input_ids = torch.cat([outputs, new_inputs.input_ids], dim=1)

    inputs_embeds = torch.cat([inputs_embeds, model.get_input_embeddings()(new_input_ids)], dim=1)  # necessary to align with the cache

    outputs = model.generate(inputs_embeds=inputs_embeds, past_key_values=prompt_cache)
    response = tokenizer.batch_decode(outputs)[0]
    print(response)
    responses.append(response)

@zucchini-nlp (Member) left a comment


Thanks a lot for the PR! We'd also need to add a test for continue_generate_from_inputs_embeds in https://github.com/huggingface/transformers/blob/main/tests/generation/test_utils.py

LMK if you need any help with adding/running tests :)

Comment on lines 386 to 390
if inputs_embeds is not None and input_ids.shape[1] == 0:
    inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
elif (
A Member commented on this hunk:

Can we add a small comment, same way as there are 3 exceptions above?

@zucchini-nlp zucchini-nlp requested a review from gante January 27, 2025 08:25
@gante (Member) left a comment


Thank you for opening a PR with the fix!

Happy to accept the PR as soon as @zucchini-nlp's comments are addressed :)

@yaswanth19 force-pushed the iterative-prompting-with-embeds branch from 1890c7b to 6f62483 on January 30, 2025 17:10
@yaswanth19 (Contributor, Author) commented on Jan 30, 2025

@zucchini-nlp I have written a test case to check continuation using input embeds. Could you check whether it's a valid test case? It is failing for some models, and I am looking into them (most likely they need to override it in the model class). Also, should we make this functionality specific to the static cache or apply it to all types of caches? I'm not sure whether this will break with other cache types; I will test with different caches once the test cases are complete.

@zucchini-nlp (Member) commented

@yaswanth19 thanks a lot! Yes, the test looks good, but one small nit: I'd recommend not taking the input ids from the first and second batches and flat-concatenating them. You can take a look at how it is done in the continue_generate_from_input_ids test.

And using only the default cache is enough, imo; if it works, then StaticCache should also work.
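For reference, this is roughly what that continuation flow looks like with the default cache, following the verification script from the first comment (only a sketch of the idea, not the test that was eventually added to test_utils.py):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

model_id = "microsoft/phi-2"
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("You are a helpful assistant. ", return_tensors="pt")
inputs_embeds = model.get_input_embeddings()(inputs.input_ids)
cache = DynamicCache()

# First call: prefill the default cache from the prompt embeddings.
outputs = model.generate(inputs_embeds=inputs_embeds, past_key_values=cache, max_new_tokens=10)

# Continuation: extend the embeddings with the generated tokens plus a new prompt
# and reuse the same cache object; only the uncached tail gets processed again.
new_ids = torch.cat([outputs, tokenizer("Write a short note on AI", return_tensors="pt").input_ids], dim=1)
inputs_embeds = torch.cat([inputs_embeds, model.get_input_embeddings()(new_ids)], dim=1)
outputs = model.generate(inputs_embeds=inputs_embeds, past_key_values=cache, max_new_tokens=10)
print(tokenizer.batch_decode(outputs)[0])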

@yaswanth19 (Contributor, Author) commented on Feb 1, 2025

@zucchini-nlp I have modified the test case, and it's failing on a few models like Idefics, Zamba2, and Moshi, either because they have their own custom prepare_inputs_for_generation function, use a non-standard KV cache, etc. Can you help run the generate CI for all the tests? Right now the CI stops once a single test fails; running it on all the cases will help me identify every model the test case fails on, so I can either override or skip them.

AFAIK this feature breaks an existing test case only for Moshi. I will need to debug it.

@zucchini-nlp (Member) commented on Feb 4, 2025

@yaswanth19 sorry, I completely forgot about this PR. Let me quickly run the tests. For the models you mentioned, we can skip them with a reason why (different cache type, needs pixels, etc.).

Btw, you should be able to run the tests locally with `pytest -k test_generate_continue_from_inputs_embeds tests/models/`; they don't require a GPU.

UPDATE: the failing models are Bloom, Chameleon, CLVP, Cohere2, Fuyu, Idefics, Moshi, Qwen2-VL, and Zamba. For some of them it might be fixable by applying the fix in the model code; we usually override prepare_inputs_for_generation for multimodal models.

@yaswanth19 force-pushed the iterative-prompting-with-embeds branch 2 times, most recently from 460f61a to 3793039 on February 5, 2025 13:22
@yaswanth19 force-pushed the iterative-prompting-with-embeds branch from 3793039 to 8b718dc on February 5, 2025 14:14
@yaswanth19 (Contributor, Author) commented
@zucchini-nlp I think the PR is ready for review. The only catch, as I said before, is that this feature breaks existing functionality in Moshi: instead of overriding prepare_inputs_for_generation, Moshi uses the generic function, performs some post-processing on top of it, and computes inputs_embeds from that. Hence I am not sure how to fix/skip it.

cc: @ylacombe

@zucchini-nlp (Member) commented
@yaswanth19 for Moshi, if adding this new condition breaks generation, we should override the method in Moshi (the same way Qwen2-VL does it) and add a comment on top explaining why it is overridden.

If it actually doesn't break anything and you mean Moshi can't do continue_from_input_embeds, feel free to just skip the test. No one would want to continue from embeds in audio models, imo.
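As a purely illustrative sketch of the suggested "override and explain" pattern (not the actual Moshi or Qwen2-VL code; the class here is a placeholder):

from transformers import GPT2LMHeadModel  # any concrete causal LM class works for the illustration

class ModelWithCustomPrepare(GPT2LMHeadModel):
    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # Overridden because this (imaginary) model builds its own inputs_embeds by
        # post-processing the generic outputs, which would clash with the new
        # embeds-slicing branch in the base implementation.
        model_inputs = super().prepare_inputs_for_generation(input_ids, **kwargs)
        # ...model-specific post-processing of model_inputs would go here...
        return model_inputs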

@yaswanth19 (Contributor, Author) commented
@zucchini-nlp The CI is green 😃 and the failing tests are unrelated to this PR. Ready for Review 🤗

@yaswanth19 changed the title from "Iterative generation using Input embeds and static cache" to "Iterative generation using Input embeds and past_key_values" on Feb 5, 2025
@gante (Member) left a comment


Thank you for iterating 💛 If @zucchini-nlp is happy with the changes, we can then merge :)

(the test is a bit complex, but this is also a complex feature)

@zucchini-nlp (Member) left a comment


Thanks, happy with the changes

@zucchini-nlp zucchini-nlp merged commit 7aee036 into huggingface:main Feb 6, 2025
21 of 25 checks passed
@lzl-mt commented on Feb 7, 2025

[Quoted: the verification script from yaswanth19's first comment above.]

Thanks for your outstanding contribution! Does this PR support batch inference like this, by pre-computing an attention_mask_real?
model_outputs = self.llm.generate(
    inputs_embeds=inputs_embeds_audio,
    # max_length=kwargs.get("max_length", 200),
    max_new_tokens=kwargs.get("max_new_tokens", 200),
    num_beams=kwargs.get("num_beams", 4),
    do_sample=kwargs.get("do_sample", False),
    min_length=kwargs.get("min_length", 1),
    top_p=kwargs.get("top_p", 1.0),
    repetition_penalty=kwargs.get("repetition_penalty", 1.0),
    length_penalty=kwargs.get("length_penalty", 1.0),
    temperature=kwargs.get("temperature", 1.0),
    attention_mask=attention_mask_real,
    bos_token_id=self.tokenizer.bos_token_id,
    eos_token_id=self.tokenizer.eos_token_id,
    pad_token_id=self.tokenizer.pad_token_id,
    past_key_values=kv_cache,
)
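For what it's worth, here is a rough sketch of a plain batched call with precomputed embeddings and an explicit attention mask (an assumed usage pattern, not something this PR tests, and it does not exercise the pre-filled-cache path from the question above):

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "microsoft/phi-2"
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"  # left-pad so the last position of every row lines up

prompts = ["Help me to write a blogpost about travelling.", "Write a short note on AI"]
batch = tokenizer(prompts, return_tensors="pt", padding=True)

inputs_embeds = model.get_input_embeddings()(batch.input_ids)
outputs = model.generate(
    inputs_embeds=inputs_embeds,
    attention_mask=batch.attention_mask,  # masks out the padded positions
    max_new_tokens=30,
    do_sample=False,
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))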

MekkCyber pushed a commit that referenced this pull request Feb 7, 2025
* Iterative generation using input embeds

* ruff fix

* Added Testcase

* Updated comment

* ♻️ Refactored testcase

* Skip test for these models

* Continue generation using input embeds and cache

* Skip generate_continue_from_embeds test

* Refactor `prepare_input_for_generation` func

* Continue generation using input embeds and cache

* Modular changes fix

* Overwrite 'prepare_inputs_for_generation' function
Development

Successfully merging this pull request may close these issues.

Bug when using StaticCache in Qwen2.5 Inference
4 participants