
TF generate refactor - past without encoder outputs #15944

Merged: 10 commits merged into huggingface:master from destroy_past on Mar 8, 2022

Conversation

@gante (Member) commented Mar 4, 2022:

What does this PR do?

As discussed in the original TF generate refactor plan (#15562), this PR removes `encoder_outputs` from `past`. In practice, the changes consist mostly of:

  1. Delete the lines flagged by Patrick;
  2. Adapt prepare_inputs_for_generation and _reorder_cache from PT to TF for each class (a minimal sketch of the resulting pattern follows after the notes below).

Three important notes:

  1. Beam search was still in the old format, and a few changes were needed there to enable the changes above, mostly around how past or encoder_outputs were handled;
  2. Some models include cross_attn_head_mask in their PT prepare_inputs_for_generation, but it raised errors in TF, so I've removed it from the function output;
  3. I've run `RUN_SLOW=1 pytest -vv tests/model_name/test_modeling_tf_model_name.py` for all affected models.
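
As referenced in item 2 above, here is a minimal sketch of the resulting prepare_inputs_for_generation pattern for a generic TF encoder-decoder model. This is illustrative only, not the exact transformers source: after the refactor, `past` carries only the decoder cache, while the encoder outputs travel as their own model kwarg.

```python
# Illustrative sketch, not the exact transformers source: after the refactor,
# `past` holds only the decoder cache, and `encoder_outputs` is passed along
# as its own model kwarg instead of being smuggled inside `past`.
def prepare_inputs_for_generation(self, decoder_input_ids, past=None, encoder_outputs=None, use_cache=None, **kwargs):
    if past is not None:
        # with a populated cache, only the newest decoder token must be embedded
        decoder_input_ids = decoder_input_ids[:, -1:]
    return {
        "input_ids": None,  # the encoder has already run; its result is in encoder_outputs
        "encoder_outputs": encoder_outputs,
        "past_key_values": past,
        "decoder_input_ids": decoder_input_ids,
        "use_cache": use_cache,
    }
```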

@sgugger (Collaborator) left a comment:

I don't know this part of the library well enough to give a useful review. Styling is good, but I'll defer to the others for approval :-)

```python
# the refactored generate, without the encoder outputs in `past`, expects the `encoder_outputs`
# variable to contain all of (encoder_outputs, encoder_hidden_states, encoder_attentions) in
# `prepare_inputs_for_generation`
if encoder_hidden_states is not None:
```
@patrickvonplaten (Contributor) commented Mar 5, 2022:

(nit) Why not wrap it into a `TFEncoderOutputs` class here?

@gante (Member, Author) commented:

Great question! I tried that, and it would be the most sensible change IMO (as the updated generate gets the encoder outputs with return_dict=True). However, a TFEncoderOutputs class would make the T5 tests fail. At that point I had two options: update TF T5 or write this. Since this PR is mostly about updating the `past` variable, I took the path of least resistance.

Happy to change T5 instead :)
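
For context, a sketch of what the suggested wrapping could look like. `TFEncoderOutputs` is not an existing class in the library; the closest real class is `TFBaseModelOutput` from `transformers.modeling_tf_outputs`, and the placeholder tensors below stand in for the tuple elements mentioned in the snippet above:

```python
import tensorflow as tf
from transformers.modeling_tf_outputs import TFBaseModelOutput

# Placeholder tensors standing in for the tuple elements in the snippet above.
encoder_sequence_output = tf.zeros((1, 4, 8))  # (batch, seq_len, hidden)
encoder_hidden_states = None                   # per-layer hidden states, if requested
encoder_attentions = None                      # attention maps, if requested

# This only shows the tuple -> output-class mapping the suggestion is about.
encoder_outputs = TFBaseModelOutput(
    last_hidden_state=encoder_sequence_output,
    hidden_states=encoder_hidden_states,
    attentions=encoder_attentions,
)
```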

```diff
-model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
-    input_ids, return_dict_in_generate, model_kwargs
-)
+model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs)
```
@patrickvonplaten (Contributor) commented:

(nit) you could maybe put this under the `# 4. Prepare ...` comment and change the comment to `prepare model inputs which will be used for ...`
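
For context, a rough sketch of what the simplified helper does after dropping the `return_dict_in_generate` argument. This is grounded in the note above that the refactored generate fetches encoder outputs with `return_dict=True`, but it is an illustration, not the verbatim transformers implementation:

```python
# Rough sketch, not the verbatim transformers implementation: the helper runs
# the encoder once and stashes its dict-form outputs in `model_kwargs`, so the
# `return_dict_in_generate` argument is no longer needed here.
def _prepare_encoder_decoder_kwargs_for_generation(self, input_ids, model_kwargs):
    encoder = self.get_encoder()
    # always request a dict so downstream code can rely on named fields
    model_kwargs["encoder_outputs"] = encoder(input_ids, return_dict=True)
    return model_kwargs
```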

@patrickvonplaten (Contributor) left a comment:

Awesome work! The code looks much cleaner now and the `prepare_inputs_for_generation_...` functions are greatly simplified.

I left a couple of nits. Finally, it would be nice if you could also update the TF encoder-decoder model templates (a copy of TFBart) so that this test doesn't fail.

The fastest way to test these things locally is to do the following:

    1. Update the templates similarly to how TFBart was updated, then commit your changes.
    2. Create a new TFBart-like model with the add-new-model command and run the tests for the created model.
    3. Run `git reset --hard` so that the new model code disappears again.
    4. If the tests all pass, you can commit; if not, repeat steps 1-4.

It would be nice if @Rocketknight1 could also take a look here.

@patrickvonplaten (Contributor) commented:

Let's merge this? cc @Rocketknight1

@Rocketknight1 (Member) left a comment:

Overall, the TF changes look good to me and I don't see any problems. I *think*, after talking to @gante, that the change of trimming input_ids to only the last token whenever past is present is okay, but I'm still a bit confused about how that works!
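
For readers sharing that confusion, a toy illustration (not transformers code): once `past` caches the keys and values for every token processed so far, the decoder only needs the id of the newest token at each generation step, so feeding the full sequence would be redundant work.

```python
import tensorflow as tf

# Toy illustration, not transformers code: with a populated cache, only the
# newest token id is fed to the model; earlier positions are served by `past`.
input_ids = tf.constant([[5, 8, 13]])  # all tokens generated so far
past = object()                        # stand-in for a real key/value cache

if past is not None:
    # keep only the last token, shape (batch, 1)
    input_ids = tf.expand_dims(input_ids[:, -1], -1)

print(input_ids.numpy())  # [[13]] -- positions 0 and 1 come from the cache
```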

@gante merged commit 70203b5 into huggingface:master on Mar 8, 2022
@gante deleted the destroy_past branch on March 8, 2022 at 14:46
```diff
 # keep only the last token of input_ids if past is defined in kwargs
 if past:
     inputs = tf.expand_dims(inputs[:, -1], -1)

-return {"input_ids": inputs, "past": past, "use_cache": kwargs["use_cache"]}
+return {"input_ids": inputs, "past_key_values": past, "use_cache": use_cache}
```
A contributor commented:

This is not necessary if there is no `past` input variable name in GPT2.

@patrickvonplaten (Contributor) commented Mar 22, 2022:

We should revert this, I think, and maybe deprecate `past` as an input argument name for all models in a separate PR :-)
