TF - Fix interchangeable past/past_key_values and revert output variable name in GPT2 #16332
Conversation
@patrickvonplaten hold your review, this change is not conflicting with the

The documentation is not available anymore as the PR was closed or merged.

@patrickvonplaten now it is properly fixed -- please check the updated description at the top :) Meanwhile, the scope increased a bit, so I'm tagging a 2nd reviewer (@Rocketknight1)
```diff
@@ -423,13 +424,13 @@ def input_processing(func, config, input_ids, **kwargs):
         )
         output["past_key_values"] = kwargs["kwargs_call"].pop("decoder_cached_states")

-    if "past" in kwargs["kwargs_call"] and "past_key_values" in kwargs:
+    if "past" in kwargs["kwargs_call"] and "past_key_values" in parameter_names:
```
This was the root cause of the problem in the decorator -- previously, this function was called inside `call`, where `kwargs` contained all keyword arguments (at the very least, with their default values). The decorator now calls this before `call` and, because it does not have default values, `kwargs` was empty. This meant that the `past` <> `past_key_values` magic, needed for gpt2 + encoder_decoder, was not happening when the decorator was applied on gpt2.
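To make that concrete, here is a minimal, hypothetical sketch of the swapping behavior described above (the helper name `swap_past_kwargs` is invented for illustration; the real logic lives inside `input_processing` and, after this fix, checks the function signature rather than `kwargs`):

```python
import inspect


def swap_past_kwargs(func, kwargs):
    # Illustrative stand-in for the `past` <> `past_key_values` magic: rename
    # the keyword when the caller and the signature disagree on its spelling.
    parameter_names = list(inspect.signature(func).parameters.keys())
    if "past" in kwargs and "past_key_values" in parameter_names:
        kwargs["past_key_values"] = kwargs.pop("past")
    elif "past_key_values" in kwargs and "past" in parameter_names:
        kwargs["past"] = kwargs.pop("past_key_values")
    return kwargs


def call(input_ids, past=None):  # e.g. TF GPT2's `call` expects `past`
    return input_ids, past


# An encoder_decoder wrapper hands over `past_key_values`; it must land in
# `past`. Checking the signature (instead of `kwargs`, which is empty before
# `call`) is what makes this work under the decorator.
print(swap_past_kwargs(call, {"input_ids": [0], "past_key_values": "cache"}))
# {'input_ids': [0], 'past': 'cache'}
```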
Makes sense!
```diff
@@ -694,14 +694,17 @@ def prepare_inputs_for_generation(
     ):
         decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past)
         decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
+        past_key_values = decoder_inputs.get("past_key_values")
+        if past_key_values is None:
+            past_key_values = decoder_inputs.get("past")  # e.g. on TF GPT2
```
Nice!
```diff
@@ -878,7 +878,7 @@ def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, use_x
             "input_ids": inputs,
             "attention_mask": attention_mask,
             "position_ids": position_ids,
-            "past_key_values": past,
+            "past": past,
```
Thanks for reverting this
Thanks for fixing - I'm not too familiar with the changes in `modeling_tf_utils.py`, so if possible it'd be nice if someone else could take a look here. I think Sylvain also tries to avoid
```diff
-    parameter_names = list(signature.keys())
+    parameter_names_list = list(signature.keys())
+    parameter_names = set(parameter_names_list)
```
Is it possible for the signature to have duplicate keys, or is this just to make `if x in parameter_names` faster?
Yeah, I don't see why we need two of those. Creating them is probably slower than the lookup in the list (models usually have ~10 arguments).
Haha yeah, I went overboard with this one -- with the number of lookups we do per call, it is faster to create the set, but we're talking about microseconds (I went on to check it with `timeit`). Clearly not worth the extra code.
Reverting.
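For reference, a quick sketch of the kind of `timeit` check mentioned above (the sizes, names, and numbers here are illustrative, not taken from the PR):

```python
import timeit

# Hypothetical signature with ~10 parameters, as in a typical model
parameter_names_list = [f"arg_{i}" for i in range(10)]
parameter_names_set = set(parameter_names_list)

n = 1_000_000
list_time = timeit.timeit("'arg_9' in parameter_names_list", globals=globals(), number=n)
set_time = timeit.timeit("'arg_9' in parameter_names_set", globals=globals(), number=n)

# Set lookups are O(1) vs. the list's O(n), but over a handful of lookups per
# call the difference amounts to microseconds.
print(f"list: {list_time:.3f}s, set: {set_time:.3f}s over {n} lookups")
```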
Overall, this looks like a clean fix to the GPT2 workaround, so I'm happy to approve it. I'd really like to get rid of these old non-standard arguments next time we can make a breaking change, though!
Thanks for fixing this, @gante!
@Rocketknight1 me too 🙈 that function is a mess
…ble name in GPT2 (#16332)

* revert tf gpt2
* add test for unpack_inputs and fix test case
* add changes to vision encoder decoder
Context

From the discussion in #16311 (the PR that applies `@unpack_inputs` to TF `gpt2`): in the generate refactor, TF `gpt2` got an updated `prepare_inputs_for_generation()`, where its output `past` got renamed to `past_key_values` (i.e. as in FLAX/PT). Patrick suggested reverting it, since this prepared input could be used externally.

What did I find while working on this PR?
Reverting as suggested above makes TF `gpt2` fail tests related to `encoder_decoder`, which got an updated `prepare_inputs_for_generation()` in the same PR that expects a `past_key_values` (and not a `past`).

Meanwhile, I've also noticed a related bug in the new `@unpack_inputs` decorator, where it was not preserving a previous behavior -- when the model received a `past_key_values` but expected a `past` input (and vice-versa), it automatically swapped the keyword. This feature was the key enabler behind `encoder_decoder` + `gpt2`, as `encoder_decoder` was throwing out `past_key_values` prepared inputs that were caught by `gpt2`'s `past` argument.

So, what's in this PR?
This PR fixes the two issues above, which are needed for proper behavior in all combinations of inputs to TF `gpt2` after the introduction of the decorator:

* Fixes the `@unpack_inputs` decorator and adds tests to ensure we don't regress on some key properties of our TF input handling. After this PR, `gpt2` preserves its ability to receive `past` (and `past_key_values`, if through `encoder_decoder`-like models), with and without the decorator. A sketch of the preserved behavior follows this list.
* Reverts `past_key_values` into `past` wherever the change was introduced in #15944 (TF generate refactor - past without encoder outputs), and makes the necessary changes in `encoder_decoder`-like models.
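For illustration, a hedged sketch of the behavior this PR preserves, using the public transformers API (the checkpoint name and token ids are placeholders, and the exact argument names depend on the transformers version this PR targets):

```python
import tensorflow as tf
from transformers import TFGPT2LMHeadModel

model = TFGPT2LMHeadModel.from_pretrained("gpt2")
input_ids = tf.constant([[464, 3290]])  # arbitrary token ids

# The first forward pass returns the cache under `past_key_values`
outputs = model(input_ids, use_cache=True)
cache = outputs.past_key_values

# On the next step, both spellings should be accepted: `past` is TF GPT2's own
# argument, and `past_key_values` is what encoder_decoder-like models pass
next_token = tf.constant([[502]])
out_with_past = model(next_token, past=cache)
out_with_pkv = model(next_token, past_key_values=cache)
```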