
TF - Fix interchangeable past/past_key_values and revert output variable name in GPT2 #16332

Merged (11 commits) on Mar 23, 2022

Conversation

@gante (Member) commented Mar 22, 2022

Context

From the discussion in #16311 (the PR that applies @unpack_inputs to TF GPT2): in the generate refactor, TF GPT2 got an updated prepare_inputs_for_generation(), whose output past was renamed to past_key_values (i.e. matching the Flax/PyTorch naming). Patrick suggested reverting the rename, since this prepared input may be used externally.

What did I find while working on this PR?

Reverting as suggested above makes TF GPT2 fail tests related to encoder_decoder, which got an updated prepare_inputs_for_generation() in the same PR that expects past_key_values (and not past).

Meanwhile, I also noticed a related bug in the new @unpack_inputs decorator: it was not preserving a previous behavior -- when the model received a past_key_values keyword but its signature expected past (or vice versa), the keyword was automatically swapped. This feature was the key enabler behind encoder_decoder+GPT2, as encoder_decoder produces prepared past_key_values inputs that were caught by GPT2's past argument.
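The keyword swap described above can be sketched as a minimal standalone decorator. This is an illustration only, not the real transformers implementation: the decorator name, the toy call function, and its signature are all hypothetical.

```python
import inspect
from functools import wraps

def swap_past_kwarg(func):
    """Hypothetical sketch of the past<->past_key_values swap: if the caller
    passes one name but the wrapped function's signature expects the other,
    rename the keyword before calling."""
    parameter_names = set(inspect.signature(func).parameters)

    @wraps(func)
    def wrapper(*args, **kwargs):
        if ("past" in kwargs and "past" not in parameter_names
                and "past_key_values" in parameter_names):
            kwargs["past_key_values"] = kwargs.pop("past")
        elif ("past_key_values" in kwargs and "past_key_values" not in parameter_names
                and "past" in parameter_names):
            kwargs["past"] = kwargs.pop("past_key_values")
        return func(*args, **kwargs)

    return wrapper

@swap_past_kwarg
def call(input_ids, past=None):  # toy model call expecting `past`
    return past

# The caller uses the Flax/PyTorch name, the signature uses `past`:
print(call(input_ids=[1], past_key_values="cache"))  # → cache
```

This is the behavior encoder_decoder relied on: it hands GPT2 a past_key_values keyword, and the swap routes it into GPT2's past argument.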

So, what's in this PR?

This PR fixes the two issues above, ensuring correct behavior for all combinations of inputs to TF GPT2 after the introduction of the decorator:

  1. It corrects the bug in the @unpack_inputs decorator and adds tests to ensure we don't regress on key properties of our TF input handling. After this PR, GPT2 keeps its ability to receive past (and past_key_values, when wrapped in an encoder_decoder-like model), with and without the decorator.
  2. It reverts past_key_values back to past wherever the rename was introduced in #15944 (TF generate refactor - past without encoder outputs), and makes the necessary changes in encoder_decoder-like models.

@gante gante requested a review from patrickvonplaten March 22, 2022 12:51
@gante gante changed the title Revert past variable name in TF GPT Revert past variable name in TF GPT2 Mar 22, 2022
@gante (Member, Author) commented Mar 22, 2022

(Wait, there is an error)
Should be good now
Nope

@gante (Member, Author) commented Mar 22, 2022

@patrickvonplaten hold your review, this change is now conflicting with the encoder_decoder models. I believe I know why and am digging deeper.

@HuggingFaceDocBuilderDev commented Mar 22, 2022

The documentation is not available anymore as the PR was closed or merged.

@gante gante changed the title Revert past variable name in TF GPT2 TF - Fix interchangeable past/past_key_values and revert past variable name in GPT2 Mar 22, 2022
@gante gante changed the title TF - Fix interchangeable past/past_key_values and revert past variable name in GPT2 TF - Fix interchangeable past/past_key_values and revert output variable name in GPT2 Mar 22, 2022
@gante (Member, Author) commented Mar 22, 2022

@patrickvonplaten now it is properly fixed -- please check the updated description at the top :)

Meanwhile, the scope increased a bit, so I'm tagging a 2nd reviewer (@Rocketknight1 )

@gante gante requested a review from Rocketknight1 March 22, 2022 19:41
@@ -423,13 +424,13 @@ def input_processing(func, config, input_ids, **kwargs):
)
output["past_key_values"] = kwargs["kwargs_call"].pop("decoder_cached_states")

-        if "past" in kwargs["kwargs_call"] and "past_key_values" in kwargs:
+        if "past" in kwargs["kwargs_call"] and "past_key_values" in parameter_names:
@gante (Member, Author) commented Mar 22, 2022

This was the root cause of the problem in the decorator -- previously, this function was called inside call, where kwargs contained all keyword arguments (at the very least, with their default values).

The decorator now calls this function before call and, because unpassed arguments have no default values at that point, kwargs was empty. This meant the past<->past_key_values magic, needed for gpt2+encoder_decoder, was not happening when the decorator was applied to gpt2.
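The difference the fix exploits can be shown in a few lines. This is a sketch with a made-up call signature, not the real GPT2 one: a decorator only sees the keywords the caller actually passed, while the function signature is the same no matter where the check runs.

```python
import inspect

# Hypothetical model call signature (names only, for illustration)
def call(input_ids=None, past=None, use_cache=None):
    pass

# What a decorator receives: only the keywords the caller actually passed.
received_kwargs = {"input_ids": [1, 2]}   # caller did not pass `past`

# The old check looked in kwargs, so it never fires from a decorator:
print("past" in received_kwargs)          # → False

# The fixed check inspects the signature, which is call-independent:
parameter_names = set(inspect.signature(call).parameters)
print("past" in parameter_names)          # → True
```

Inside call itself, every parameter exists at least with its default, so the old check happened to work; moving the check to the signature makes it correct in both call sites.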


Makes sense!

@@ -694,14 +694,17 @@ def prepare_inputs_for_generation(
):
decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past)
decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
past_key_values = decoder_inputs.get("past_key_values")
if past_key_values is None:
past_key_values = decoder_inputs.get("past") # e.g. on TF GPT2

Nice!

@@ -878,7 +878,7 @@ def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, use_x
"input_ids": inputs,
"attention_mask": attention_mask,
"position_ids": position_ids,
-            "past_key_values": past,
+            "past": past,

Thanks for reverting this

@patrickvonplaten (Contributor) left a comment

Thanks for fixing - I'm not too familiar with the changes in modeling_tf_utils.py, so if possible it'd be nice if someone else could take a look here

@Rocketknight1 (Member)
I think Sylvain also tries to avoid modeling_tf_utils.py as much as possible too these days, lol. Let me take a look!

Comment on lines 407 to 408
-    parameter_names = list(signature.keys())
+    parameter_names_list = list(signature.keys())
+    parameter_names = set(parameter_names_list)

Is it possible for the signature to have duplicate keys, or is this just to make if x in parameter_names faster?


Yeah, I don't see why we need two of those. Creating them is probably slower than the lookup in the list (models have 10 arguments usually).

@gante (Member, Author)

Haha yeah, I went overboard with this one -- with the number of lookups we do per call, it is faster to create the set, but we're talking about microseconds (I checked with timeit). Clearly not worth adding code.

Reverting.
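The trade-off discussed above is easy to measure. A rough sketch of the timeit comparison (the argument names are made up; typical model signatures have about 10 parameters):

```python
import timeit

# ~10 parameter names, as in a typical TF model call signature
names = [f"arg_{i}" for i in range(10)]

# Building a set on every call pays the construction cost each time:
build_set_then_lookup = timeit.timeit(
    lambda: "arg_9" in set(names), number=100_000
)

# A plain list membership test over ~10 items, with no setup cost:
plain_list_lookup = timeit.timeit(
    lambda: "arg_9" in names, number=100_000
)

print(f"set build + lookup: {build_set_then_lookup:.3f}s")
print(f"plain list lookup:  {plain_list_lookup:.3f}s")
```

Set lookups only pay off when the set is built once and queried many times; built fresh per call over ~10 names, the list is the simpler and usually cheaper choice.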

@Rocketknight1 (Member) left a comment

Overall, this looks like a clean fix to the GPT2 workaround, so I'm happy to approve it. I'd really like to get rid of these old non-standard arguments next time we can make a breaking change, though!

@sgugger (Collaborator) left a comment

Thanks for fixing this @gante !


@gante (Member, Author) commented Mar 23, 2022

> I'd really like to get rid of these old non-standard arguments next time we can make a breaking change, though!

@Rocketknight1 me too 🙈 that function is a mess

@gante gante merged commit 9e8c37d into huggingface:main Mar 23, 2022
FrancescoSaverioZuppichini pushed a commit that referenced this pull request Mar 24, 2022
…ble name in GPT2 (#16332)

* revert tf gpt2

* add test for unpack_inputs and fix test case

* add changes to vision encoder decoder
@gante gante deleted the tf_past_revert branch March 28, 2022 16:00