Remove additional float/clone() for perf #1374
Conversation
Extra aten::item and cast ops are causing perf degradation.
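For context, one way to confirm these extra ops is the PyTorch profiler. The sketch below (tensor shape and dtype are assumptions, not taken from this PR) shows how the cast and the host-syncing `aten::item` surface in a trace:

```python
import torch
from torch.profiler import profile, ProfilerActivity

# Assumed shape/dtype purely for illustration.
logits = torch.randn(2, 32000, dtype=torch.bfloat16)

with profile(activities=[ProfilerActivity.CPU]) as prof:
    _ = logits.float()       # the cast appears as aten::to / aten::_to_copy
    _ = logits[0, 0].item()  # the host sync appears as aten::item

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
```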
```diff
@@ -3580,8 +3579,7 @@ def _assisted_decoding(

            # 2.3. Process the new logits
            # .float() is needed to retain precision for later logits manipulations
```
well, I think you should at least remove this comment
@jiminha Have you got time to check if we get the same output sequences after removing all these `.float()` casts?
I tested all the text_generation and test_encoder_decoder test cases from pytest and compared the outputs; they all looked the same. We'd still like to understand why the original transformers code added this float() to the last-logits computation: which specific test cases did they run that showed this float() is needed? Would you be able to check with them?
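A rough sketch of that kind of output-equivalence check (the model name, prompt, and generation settings below are placeholders, not the actual pytest suites):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder model; the real comparison ran the text_generation and
# test_encoder_decoder pytest suites.
tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.bfloat16)

inputs = tok("The quick brown fox", return_tensors="pt")
# Greedy decoding, so any logits-precision difference shows up as a token diff.
out = model.generate(**inputs, max_new_tokens=20, do_sample=False)
print(out[0].tolist())  # run on both branches and diff the token ids
```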
@jiminha It's explained in the first message of this issue: huggingface/transformers#30860. Basically, we don't want to cast all the logits to float in the forward of causal-LM models, and we only do it on the next-token logits inside the generation methods.
So we should probably remove this float here and keep the ones in `_sample`.
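To make that concrete, a minimal sketch with assumed shapes (not the actual transformers code): casting all logits allocates a float32 copy of the whole [batch, seq, vocab] tensor, while generation only needs the last position:

```python
import torch

# Assumed sizes for illustration: batch 4, sequence 1024, vocab 32000.
logits = torch.randn(4, 1024, 32000, dtype=torch.bfloat16)

# What we want to avoid: a float32 copy of every position
# (4 * 1024 * 32000 * 4 bytes, roughly 0.5 GB at these sizes).
all_fp32 = logits.float()

# What generate() actually needs: only the last step, upcast for the
# logits processors; shape [4, 32000], tiny by comparison.
next_token_logits = logits[:, -1, :].float()
```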
Thanks for the link and explanation. So do we actually have a double float cast in v4.45, since we added float() in generation/utils.py and also kept the same logic in the model (causal LM) files? Maybe that's why I didn't see any regression even after removing float() from the utils. The perf regression I saw was in the T5 model (the t5 test_encoder_decoder test), which doesn't use CausalLM and didn't have the float conversion to begin with, so its duration increased with this change.
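One quick way to sanity-check the double-cast theory (a standalone snippet, not this PR's code): if the modeling code already returns float32 logits, the second `.float()` in `generation/utils.py` returns the same tensor and costs nothing, which would explain why no regression was visible for causal LMs:

```python
import torch

x = torch.randn(4, 8, dtype=torch.float32)
print(x.float() is x)  # True: float32 -> float32 is a no-op, returns self

y = torch.randn(4, 8, dtype=torch.bfloat16)
print(y.float() is y)  # False: a real cast allocates a new tensor
```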
```diff
@@ -2370,7 +2369,7 @@ def _sample(
                 next_token_scores = logits_processor(input_ids, next_token_logits)
             else:
                 # .float() is needed to retain precision for later logits manipulations
-                next_token_logits = outputs.logits[:, -1, :].float()
+                next_token_logits = outputs.logits[:, -1, :]
```
consider keeping this
```diff
@@ -2814,7 +2813,7 @@ def expand_if_needed(tensor, new_size, value, dim=-1):
                 else:
                     next_token_logits = torch.index_select(outputs.logits, -2, token_idx - 1).squeeze(-2)
             else:
-                next_token_logits = outputs.logits[:, -1, :].float()
+                next_token_logits = outputs.logits[:, -1, :]
```
we don't normally reach this code path
For the T5 model, I think the perf diff is bigger with smaller batch/sample sizes; for bigger batch/sample sizes the difference is very small.

Batch 2, sample 200 (current test):
Original
Remove float from `_sample`: predict_runtime = 0:00:16.51

**bs128 / sample 1000**
Original
Remove float from `_sample`
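For what it's worth, a hypothetical micro-benchmark (sizes, device, and iteration count are assumptions; the real measurement is the T5 predict_runtime above) that isolates the cost of the extra cast:

```python
import time
import torch

# Assumed small-batch shape, loosely mirroring the batch-2 run.
logits = torch.randn(2, 32000, dtype=torch.bfloat16)

def bench(fn, iters=10_000):
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    return time.perf_counter() - t0

print("with .float():   ", bench(lambda: logits[-1:, :].float()))
print("without .float():", bench(lambda: logits[-1:, :]))
```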
For T5 it would be in
Closing this PR since it's not adding much perf improvement for most models. T5 has a small regression, but it's very little.
What does this PR do?
Removes the float() casts added in the 4.45 upgrade due to a perf issue: the extra aten::item and cast ops are causing perf degradation.
DO NOT MERGE NOW: an accuracy test is needed to check for any accuracy drop.