
Avoid using tf.tile in embeddings for TF models #14735

Merged: 3 commits into huggingface:master, Dec 13, 2021

Conversation

ydshieh
Collaborator

@ydshieh ydshieh commented Dec 12, 2021

What does this PR do?

Some TF models use

position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))

which assumes that position_ids has size 1 along the batch dimension. If users don't specify position_ids, we create it before using it:

if position_ids is None:
    position_ids = tf.expand_dims(
        tf.range(start=past_key_values_length, limit=input_shape[1] + past_key_values_length), axis=0
    )

which will have batch size 1. However, INPUTS_DOCSTRING specifies the shape of position_ids as (batch_size, seq_len). If a user provides position_ids with a full batch dimension (although this is very unlikely), tf.tile shouldn't be used here.

This PR fixes this issue.
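The change described above can be sketched as follows. This is a toy illustration with made-up shapes and variable names (e.g. `position_embedding_table`), not the actual model code: the old path gathers the size-1-batch position embeddings and tiles them, while the fixed path simply lets `+` broadcast the batch dimension.

```python
import tensorflow as tf

# Toy dimensions for illustration (not the model's real sizes).
batch_size, seq_len, hidden_size, max_positions = 2, 5, 8, 16

position_embedding_table = tf.random.normal((max_positions, hidden_size))
inputs_embeds = tf.random.normal((batch_size, seq_len, hidden_size))

# Default position_ids, created with batch size 1.
position_ids = tf.expand_dims(tf.range(seq_len), axis=0)  # shape (1, seq_len)

# Old approach: gather, then tf.tile up to the batch size. This assumes
# position_ids has batch size 1 and breaks for a full-batch position_ids.
tiled = tf.tile(
    input=tf.gather(params=position_embedding_table, indices=position_ids),
    multiples=(batch_size, 1, 1),
)
old_result = inputs_embeds + tiled

# New approach: just gather and let broadcasting expand the batch dimension.
# This works for position_ids of shape (1, seq_len) and (batch_size, seq_len) alike.
position_embeds = tf.gather(params=position_embedding_table, indices=position_ids)
new_result = inputs_embeds + position_embeds

print(bool(tf.reduce_all(old_result == new_result)))  # True
```

The two paths produce identical values because tiling a tensor along the batch axis and then adding is elementwise equivalent to broadcasting the size-1 batch axis during the add.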

Who can review?

@Rocketknight1

@ydshieh ydshieh marked this pull request as ready for review December 12, 2021 17:49
@ydshieh ydshieh changed the title [WIP] avoid tf.tile in embeddings Avoid using tf.tile in embeddings for TF models Dec 12, 2021
@Rocketknight1
Member

I love it, thank you for doing this! I wonder if there's a reason for using the Add() layers like that originally? It feels very odd.

@Rocketknight1
Member

Either way, it's a straightforward change and I'm happy to merge as-is, so let me know once you're ready.

@ydshieh
Collaborator Author

ydshieh commented Dec 13, 2021

> I love it, thank you for doing this! I wonder if there's a reason for using the Add() layers like that originally? It feels very odd.

I feel the same, and I don't know why Add() was used. I removed it here as well, because Add() requires the shapes to match exactly, including the batch dim (so it wouldn't work after I removed tf.tile).

The PR is ready. I can rebase on master to see if I can make the tests green.
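The shape behavior under discussion can be seen in a minimal sketch (toy tensors, not the model code): a plain `+` broadcasts a size-1 batch dimension automatically, which is why neither tf.tile nor an Add() layer with strictly matching shapes is needed.

```python
import tensorflow as tf

inputs_embeds = tf.ones((2, 4, 8))     # (batch_size, seq_len, hidden_size)
position_embeds = tf.ones((1, 4, 8))   # batch dim of 1, as created by default

# Plain `+` broadcasts the size-1 batch dimension across the full batch,
# so no explicit tiling is required before the addition.
summed = inputs_embeds + position_embeds
print(summed.shape)  # (2, 4, 8)
```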

@ydshieh ydshieh force-pushed the remove_tf_embeddings_sum branch from 3a94ea5 to e6bc7b4 on December 13, 2021 16:38
@ydshieh
Collaborator Author

ydshieh commented Dec 13, 2021

The failing tests are unrelated to this PR. Let me know if you'd prefer to wait and rebase later.

@Rocketknight1
Member

No, we're seeing those tests on every PR. I'm happy to merge now - let me know whenever the PR is done!

@ydshieh
Collaborator Author

ydshieh commented Dec 13, 2021

> No, we're seeing those tests on every PR. I'm happy to merge now - let me know whenever the PR is done!

It's done. You can merge. Thanks!

@Rocketknight1 Rocketknight1 merged commit 15a9d01 into huggingface:master Dec 13, 2021
@Rocketknight1
Member

Done!

@ydshieh ydshieh deleted the remove_tf_embeddings_sum branch May 5, 2022 10:41