Remove logits.float()
#33902
Conversation
🚀🚀thanks for being so prompt in fixing! 🤗
Hey @ringohoffman, I'm curious about the batch size you used to notice such a difference with the `logits.float()` cast.
I have observed these severe spikes with batch size 1. At this point I always just comment that piece of code out manually, since it does not matter for my workflows at all.
Hmm, are you using …
We observe these spikes during training and generation, but we are using a different loss than the LM loss, so we don't require full precision during training.
I see, makes sense then, as long as you don't pass `labels`.
Some misses from huggingface#31292 and huggingface#33902:

* Only cast logits to float when computing loss
* Move `logits.float()` into the existing `if labels is not None` branch
* Remove `logits.float()` if not computing loss
* Remove the warning about the 4.46 logits dtype change if not computing loss
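The change in the commits above can be sketched as follows. This is an illustrative stand-in, not the actual transformers forward pass: `forward_logits`, its signature, and the loss wiring are hypothetical, but the core idea — casting logits to float32 only inside the `if labels is not None` branch — matches the PR.

```python
import torch


def forward_logits(hidden_states, lm_head, labels=None):
    """Illustrative sketch: keep logits in the model's dtype (e.g. bf16)
    and only materialize a float32 copy when a loss is actually computed."""
    logits = lm_head(hidden_states)  # stays in model dtype; no eager .float()

    loss = None
    if labels is not None:
        # Full precision is only needed for a numerically stable loss.
        float_logits = logits.float()
        shift_logits = float_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        loss = torch.nn.functional.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
        )
    return logits, loss
```

With this pattern, inference-only calls (no `labels`) never allocate the large float32 logits copy, while training keeps the full-precision loss computation.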
What does this PR do?
Follow-up to earlier PRs, given that 4.45 has been released and 4.46 is next.
Llama 3.1 8B FSDP2 peak inference memory usage with `float()`: 18.5 GiB

Llama 3.1 8B FSDP2 peak inference memory usage without `float()`: 10.6 GiB

In my environment, that is a ~43% reduction in peak memory usage.
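To see why the saving is so large: `.float()` materializes a second, full-precision copy of the logits tensor, whose size scales with the vocabulary. A back-of-envelope sketch (the batch size and sequence length below are assumed for illustration and are not from the PR; 128,256 is the Llama 3 / 3.1 vocabulary size):

```python
def logits_bytes(batch_size, seq_len, vocab_size, bytes_per_element):
    """Size in bytes of a [batch, seq, vocab] logits tensor."""
    return batch_size * seq_len * vocab_size * bytes_per_element


VOCAB = 128_256  # Llama 3 / 3.1 vocabulary size

# Hypothetical shapes: batch size 1, 8K-token sequence.
bf16_logits = logits_bytes(1, 8192, VOCAB, 2)  # logits kept in bfloat16
fp32_copy = logits_bytes(1, 8192, VOCAB, 4)    # extra float32 copy from .float()

print(f"bf16 logits alone:      {bf16_logits / 2**30:.2f} GiB")
print(f"plus the float32 copy:  {(bf16_logits + fp32_copy) / 2**30:.2f} GiB")
```

The float32 copy alone is twice the size of the bf16 logits, so for long sequences and large vocabularies the cast can dominate peak memory during inference.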
Before submitting

* Did you read the contributor guideline, Pull Request section?
* Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
* Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@Cyrilvallez @gante @ArthurZucker