TF generate refactor - past without encoder outputs #15944
Conversation
I don't know this part of the library well enough to give a useful review. Styling is good, but I'll defer to the others for approval :-)
```python
# the refactored generate, without the encoder outputs in `past`, expects the `encoder_outputs`
# variable to contain all (encoder_outputs, encoder_hidden_states, encoder_attentions) in
# `prepare_inputs_for_generation`
if encoder_hidden_states is not None:
```
(nit) Why not wrap it into a `TFEncoderOutputs` class here?
Great question! I tried that, and it would be the most sensible change IMO (as the updated generate gets the encoder outputs with `return_dict=True`). However, a `TFEncoderOutputs` would make the T5 tests fail. At this point, I had two options: update TF T5 or write this. Since this PR is mostly about updating the `past` variable, I thought this would be the path of least resistance.

Happy to change T5 instead :)
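For readers following the thread: the two options discussed look roughly like the sketch below. This is illustrative, not code from the PR; the dummy tensors and shapes are assumptions. `TFBaseModelOutput` is the existing ModelOutput-style TF class that a hypothetical `TFEncoderOutputs` would resemble.

```python
import tensorflow as tf
from transformers.modeling_tf_outputs import TFBaseModelOutput

# dummy stand-ins for a real encoder forward pass (shapes are illustrative)
last_hidden_state = tf.zeros((1, 8, 512))
encoder_hidden_states = (tf.zeros((1, 8, 512)),)  # one tensor per layer
encoder_attentions = (tf.zeros((1, 8, 8, 8)),)    # (batch, heads, seq, seq)

# approach taken in this PR: a plain tuple, in the order documented above
encoder_outputs = (last_hidden_state, encoder_hidden_states, encoder_attentions)

# alternative raised in the review: a named output object
encoder_outputs = TFBaseModelOutput(
    last_hidden_state=last_hidden_state,
    hidden_states=encoder_hidden_states,
    attentions=encoder_attentions,
)
```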
```diff
- model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
-     input_ids, return_dict_in_generate, model_kwargs
- )
+ model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs)
```
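For context, this helper runs the encoder a single time before the decoding loop and stores its outputs in `model_kwargs`. A rough sketch of that idea, assuming a standalone `encoder` callable; the function name and the `decoder_`-prefix filtering below are illustrative assumptions, not the PR's exact code:

```python
def prepare_encoder_decoder_kwargs(encoder, input_ids, model_kwargs):
    # run the encoder exactly once; the decoder steps reuse the stored
    # outputs instead of re-encoding the input at every generation step
    encoder_kwargs = {k: v for k, v in model_kwargs.items() if not k.startswith("decoder_")}
    model_kwargs["encoder_outputs"] = encoder(input_ids, return_dict=True, **encoder_kwargs)
    return model_kwargs
```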
(nit) you could maybe put this under the `# 4. Prepare ...` comment and change the comment to `prepare model inputs which will be used for ...`
Awesome work! The code looks much cleaner now and the `prepare_inputs_for_generation_...` functions are greatly simplified.

Left a couple of nits. Finally, it would be nice if you could also update the TF encoder-decoder model templates (a copy of TFBart) so that this test doesn't fail.
The fastest way to test these things locally is to do the following:

1. Update the templates similar to how TFBart was updated. Commit your changes.
2. Create a new TFBart-like model with the `add-new-model` command & run tests for the created model.
3. Run `git reset --hard` so that the new model code disappears again.
4. If tests are all passing, then you can commit; if not, repeat 1-4.
Would be nice if @Rocketknight1 could also take a look here.

Let's merge this? cc @Rocketknight1
Overall, the TF changes look good to me and I don't see any problems. I -think-, after talking to @gante, that the change of trimming `input_ids` to only the last token whenever `past` is present is okay, but I'm still a bit confused about how that works!
```diff
  # only last token for inputs_ids if past is defined in kwargs
  if past:
      inputs = tf.expand_dims(inputs[:, -1], -1)

- return {"input_ids": inputs, "past": past, "use_cache": kwargs["use_cache"]}
+ return {"input_ids": inputs, "past_key_values": past, "use_cache": use_cache}
```
This is not necessary if there is no `past` input variable name in GPT2
We should revert this, I think, and maybe deprecate `past` as an input argument name for all models in a separate PR :-)
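A deprecation along those lines could take the shape of a small shim, sketched below. The helper name is made up for illustration; it is not from the PR or the library:

```python
import warnings

def resolve_past_kwarg(past=None, past_key_values=None):
    # accept the legacy `past` name while steering callers toward
    # `past_key_values`, the name used by the PT implementations
    if past is not None:
        warnings.warn("`past` is deprecated; use `past_key_values`", FutureWarning)
        past_key_values = past
    return past_key_values
```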
What does this PR do?

As discussed in the original TF generate refactor plan (#15562), this removes the `encoder_outputs` from `past`. In practice, these changes consist mostly of:

- copying `prepare_inputs_for_generation` and `_reorder_cache` from PT to TF, for each class.

Three important notes:

- the code paths where `past` or `encoder_outputs` were handled have been updated accordingly;
- some models had `cross_attn_head_mask` in `prepare_inputs_for_generation` in their PT implementation, but it raised errors in TF -> I've deleted it from the function output;
- ran `RUN_SLOW=1 pytest -vv tests/model_name/test_modeling_tf_model_name.py` for all affected models.
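Since the PR ports `_reorder_cache` to TF for each class, a rough idea of what such a method does may help reviewers: during beam search the beams are reshuffled at every step, so the cached key/value tensors must be reordered along the batch axis to match. A minimal sketch, under the assumption that the cache is a per-layer tuple of tuples of tensors with the beam/batch dimension on axis 0 (illustrative, not the PR's exact code):

```python
import tensorflow as tf

def reorder_cache(past, beam_idx):
    # select the cache entries of the surviving beams, layer by layer
    return tuple(
        tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past)
        for layer_past in past
    )

# tiny usage example: a fake 2-layer cache with 3 beams of hidden size 2
layer = (tf.reshape(tf.range(6, dtype=tf.float32), (3, 2)),)
past = (layer, layer)
reordered = reorder_cache(past, tf.constant([2, 0, 1]))
```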