-
Notifications
You must be signed in to change notification settings - Fork 27.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add type hints for PyTorch Models #16425
Add type hints for PyTorch Models #16425
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. |
…to type hint 'config' accordingly for model _init__ method.
…r' failing tests because of Bart methods custom correction overwriting .
Hello all, I changed the cokkiecutter template code because the Changing a method in Bart model: ##########
## From ##
##########
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
########
## To ##
########
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(
self,
attention_mask: torch.Tensor,
input_shape: torch.Size,
inputs_embeds: torch.FloatTensor,
past_key_values_length: int,
) -> Optional[torch.Tensor]: The code correcter script would make the following change: # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
self,
attention_mask: torch.Tensor,
input_shape: torch.Size,
inputs_embeds: torch.FloatTensor,
past_key_values_length: int,
) -> Optional[torch.Tensor]: Which caused the errors in one of the test runs. Just commenting the reason here in case someone is taking a look at this PR later. |
… that Data2Vec was dependent on, to use 'make fix-copies'.
Hello all, I have updated the code base with type hints for a few models. Thanks cc: @Rocketknight1 |
Hello all, The ImportError: cannot import name 'get_current_traceback' from 'werkzeug.debug.tbtools' (/home/runner/.local/lib/python3.8/site-packages/werkzeug/debug/tbtools.py) I am unsure as to what is causing the error and any leads on how to resolve this issue would be appreciated. Thank you |
Wow, this is a huge PR! Did you do this manually, or have you figured out some kind of tool for it? |
Hello @Rocketknight1 , Yeah, I made all this manually. This was how I spent my weekend 😛. |
That's amazing! I'll try to review now. |
This is a huge and very impressive PR, thank you! The main suggestion I have is that bools are not annotated in some cases, e.g. |
Hello, Sure, I can also a take a look once again to fix the missing ones. Thanks for the update and glad you liked the work. |
Absolutely! I saw in some cases |
Sure, later in the process I figured out the type and I had added for a few files. Fill fix for others as well. |
Note that |
Ohhh, thanks for the heads up. |
@karthikrangasai The best way to make sure the type hints are correct is to check the [Model Name]_INPUTS_DOCSTRING, right before the first user interfaced forward method |
Hello @Tegzes , I have type hinted the entire file, from first function to last class. So i might have missed something in other places. |
Hi @karthikrangasai ! This is totally my bad - other PRs came in and I reviewed them without realizing they would create conflicts with your one. Would it be possible to break this PR up into a few separate ones and submit them one at a time? That greatly reduces the chances of conflicts for each one, and it'll make it possible for me to add specific comments/suggestions, whereas at this size I really can just give general advice! |
Hello @Rocketknight1 , Yeah sure. I will break the PR into multiple ones based on the corrections made or the model that was type hinted. Should I close this one then ? |
Hi @karthikrangasai, sorry for the delay! Yeah, it's probably easiest to close this one, make new ones and just tag me in them. Thank you! |
What does this PR do?
Add type hints to as many PyTorch models as possible.
This PR targets the following models to type hint entire files:
Any other file that has been edited is a result of running
make fix-copies
.In the next PR, I will target few other models to type hint complete files.
Fixes #16059
Before submitting
Pull Request section?
to it if that's the case.
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@Rocketknight1