Iterative generation using Input embeds and past_key_values
#35890
Conversation
Code which I am using to check the feature branch
Thanks a lot for the PR! We'd also need to add a test for continue_generate_from_inputs_embeds
in https://github.com/huggingface/transformers/blob/main/tests/generation/test_utils.py
LMK if you need any help with adding/running tests :)
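Roughly, such a test could look like the sketch below (illustrative only, not the test eventually added; it assumes the GenerationTesterMixin context of test_utils.py, so torch and the model_tester helpers are available, and helper names may differ):

def test_generate_continue_from_inputs_embeds(self):
    # Sketch: prefill a cache from embeddings, then continue generation with the
    # grown cache plus the embeddings of the full sequence so far.
    for model_class in self.all_generative_model_classes:
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        model = model_class(config).eval()

        input_ids = inputs_dict["input_ids"]
        inputs_embeds = model.get_input_embeddings()(input_ids)

        # First call: generate a few tokens from embeddings and keep the cache.
        first = model.generate(
            inputs_embeds=inputs_embeds,
            max_new_tokens=4,
            use_cache=True,
            return_dict_in_generate=True,
        )

        # Second call: pass the cache back in together with the embeddings of the
        # prompt plus the freshly generated tokens, and continue generating.
        continued_embeds = torch.cat(
            [inputs_embeds, model.get_input_embeddings()(first.sequences)], dim=1
        )
        second = model.generate(
            inputs_embeds=continued_embeds,
            past_key_values=first.past_key_values,
            max_new_tokens=4,
            use_cache=True,
        )
        self.assertTrue(second.shape[1] > 0)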
src/transformers/generation/utils.py
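One possible wording for the small comment requested below, mirroring the three exceptions already documented above this branch (a suggestion only, not necessarily the wording that was merged):

# 4. If generate was called with a prefilled cache and inputs_embeds but no input_ids,
#    generation continues from the embeddings, so keep only the positions not yet in the cache.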
if inputs_embeds is not None and input_ids.shape[1] == 0:
    inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
elif (
Can we add a small comment, same way as there are 3 exceptions above?
Thank you for opening a PR with the fix!
Happy to accept the PR as soon as @zucchini-nlp's comments are addressed :)
@zucchini-nlp I have written a test case to check the continuation using input embeds. Could you check whether it's a valid test case? It is failing for some models, and I am looking into them (most likely they need an override in the model class). Also, should we make this functionality specific to the static cache or apply it to all cache types? I'm not sure whether it will break with other cache types; I will test with different caches once the test case is complete.
@yaswanth19 thanks a lot! Yes, the test looks good, but one small nit: I'd recommend not using StaticCache, and using only the default cache is enough imo; if it works there, then StaticCache should also work.
@zucchini-nlp I have modified the testcase and it's failing on a few models. AFAIK this feature breaks an existing testcase only for Moshi.
@yaswanth19 sorry, completely forgot about this PR. Let me quickly run the tests. For the models you mentioned, we can skip them with a reason why (different cache type, needs pixels, etc.). Btw, you should be able to run the tests locally as well.

UPDATE: the failing models are Bloom, Chameleon, CLVP, Cohere2, Fuyu, Idefics, Moshi, Qwen2-VL, Zamba. For some of them it might be fixable by applying the fix in the model code, where we override prepare_inputs_for_generation.
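For illustration, a hedged sketch of what such a per-model override could look like, mirroring the utils.py slicing shown earlier (the actual fix applied in this PR may differ per model):

def prepare_inputs_for_generation(
    self,
    input_ids,
    past_key_values=None,
    inputs_embeds=None,
    cache_position=None,
    **kwargs,
):
    # Start from the generic GenerationMixin preparation.
    model_inputs = super().prepare_inputs_for_generation(
        input_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        cache_position=cache_position,
        **kwargs,
    )
    # When generation continues from embeddings with a prefilled cache, input_ids
    # is empty, so feed only the embeddings that the cache has not processed yet.
    if inputs_embeds is not None and input_ids.shape[1] == 0:
        model_inputs["inputs_embeds"] = inputs_embeds[:, -cache_position.shape[0] :]
        model_inputs["input_ids"] = None
    return model_inputs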
@zucchini-nlp I think the PR is ready for review; the only catch is, as I said before, that this feature breaks an existing functionality of Moshi. cc: @ylacombe
@yaswanth19 for Moshi: if it actually doesn't break anything and you just mean Moshi can't do continuation from inputs_embeds, we can skip the test for it with a reason.
@zucchini-nlp The CI is green 😃 and the failing tests are unrelated to this PR. Ready for review 🤗
Thank you for iterating 💛 If @zucchini-nlp is happy with the changes, we can then merge :)
(the test is a bit complex, but this is also a complex feature)
Thanks, happy with changes
Thanks for your outstanding contribution! Does this PR support batch inference like this, by pre-computing an attention_mask_real?
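The PR discussion doesn't spell out batching, but here is a hedged sketch of the kind of batched continuation the question seems to describe, with a left-padded batch and a manually pre-computed mask; attention_mask_real is the asker's own name and gpt2 is only a stand-in checkpoint:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
tok.pad_token = tok.eos_token
model = AutoModelForCausalLM.from_pretrained("gpt2")

batch = tok(["Hello", "A longer first prompt"], return_tensors="pt", padding=True)
prefix_embeds = model.get_input_embeddings()(batch.input_ids)

# Turn 1: prefill the cache from the padded prompt embeddings.
first = model.generate(
    inputs_embeds=prefix_embeds,
    attention_mask=batch.attention_mask,
    max_new_tokens=4,
    use_cache=True,
    return_dict_in_generate=True,
)

# Turn 2: embed the generated tokens plus a new turn, and pre-compute a mask
# (attention_mask_real) covering padding, prompt, generated tokens and the new turn.
new_ids = tok([" and then", " and then"], return_tensors="pt").input_ids
full_embeds = torch.cat(
    [prefix_embeds, model.get_input_embeddings()(first.sequences), model.get_input_embeddings()(new_ids)],
    dim=1,
)
attention_mask_real = torch.cat(
    [batch.attention_mask, torch.ones_like(first.sequences), torch.ones_like(new_ids)], dim=1
)
second = model.generate(
    inputs_embeds=full_embeds,
    attention_mask=attention_mask_real,
    past_key_values=first.past_key_values,
    max_new_tokens=4,
)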
* Iterative generation using input embeds
* ruff fix
* Added Testcase
* Updated comment
* ♻️ Refactored testcase
* Skip test for these models
* Continue generation using input embeds and cache
* Skip generate_continue_from_embeds test
* Refactor `prepare_input_for_generation` func
* Continue generation using input embeds and cache
* Modular changes fix
* Overwrite 'prepare_inputs_for_generation' function
What does this PR do?
Fixes #34678 #35707
Logic: If a cache is present along with inputs_embeds, then use inputs_embeds to generate the first token for every prompt, rather than only for the first token of the cache.
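As a rough usage sketch of that flow (illustrative only; gpt2 is a stand-in checkpoint and the exact cache and mask semantics follow the linked issues):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

ids = tok("The capital of France is", return_tensors="pt").input_ids
embeds = model.get_input_embeddings()(ids)

# Turn 1: generate from embeddings, keeping the cache.
out = model.generate(
    inputs_embeds=embeds, max_new_tokens=5, use_cache=True, return_dict_in_generate=True
)

# Turn 2: continue from the cache; because input_ids is empty here, the first new
# token of this turn is predicted from inputs_embeds rather than from a token id.
next_embeds = model.get_input_embeddings()(tok(" and its population is", return_tensors="pt").input_ids)
full_embeds = torch.cat([embeds, model.get_input_embeddings()(out.sequences), next_embeds], dim=1)
out2 = model.generate(
    inputs_embeds=full_embeds, past_key_values=out.past_key_values, max_new_tokens=5
)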
Who can review?
@zucchini-nlp