Remove logits.float() #33902

Merged · 2 commits · Oct 4, 2024

Conversation

ringohoffman (Contributor)

What does this PR do?

Follow-up to an earlier PR, given that 4.45 has been released and 4.46 is next.

Llama 3.1 8B FSDP2 peak inference memory usage with float(): 18.5 GiB [memory profile screenshot]

Llama 3.1 8B FSDP2 peak inference memory usage without float(): 10.6 GiB [memory profile screenshot]

In my environment, this is a ~43% reduction in peak inference memory usage.
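
As a rough way to reproduce this kind of comparison, here is a minimal single-GPU measurement sketch; it does not recreate the FSDP2 setup above, and the model id, prompt, and dtype are illustrative assumptions.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-3.1-8B"  # assumed model id, for illustration only
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).cuda()

inputs = tokenizer("Peak memory test prompt", return_tensors="pt").to("cuda")

torch.cuda.reset_peak_memory_stats()
with torch.no_grad():
    # Before this PR, the causal-LM forward upcast the full logits tensor to float32 here.
    model(**inputs)
peak_gib = torch.cuda.max_memory_allocated() / 2**30
print(f"peak inference memory: {peak_gib:.1f} GiB")
```

Running the same snippet on a build with and without the unconditional float() cast is what produces the kind of gap reported above.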

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@Cyrilvallez @gante @ArthurZucker

@ArthurZucker (Collaborator) left a comment

🚀🚀 Thanks for being so prompt in fixing! 🤗

@Cyrilvallez (Member) commented Oct 3, 2024

Hey @ringohoffman, I'm curious which batch size you used to see such a difference from the float() cast alone, given that num_logits_to_keep=1 by default?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@KyleMylonakisProtopia

> batch size you used to notice such

I have observed these severe spikes with batch size 1. At this point I always just comment out that piece of code manually, since it does not matter for my workflows at all.

@Cyrilvallez (Member)

Hmm, are you using generate? If not, are you making sure to pass num_logits_to_keep=1 to the forward? By default it is 0 when not using generate. With num_logits_to_keep=1, a batch size of 1 results in a logits tensor of shape (1, 1, 128k) for Llama 3.1, whose memory footprint is small regardless of dtype.
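
To make the shapes concrete, here is a sketch against the 4.45-era API, where the forward keyword is still named num_logits_to_keep; the model id and prompt are assumptions.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-3.1-8B"  # assumed model id for illustration
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="cuda")

inputs = tokenizer("The quick brown fox", return_tensors="pt").to(model.device)

with torch.no_grad():
    # Manual forward with the default num_logits_to_keep=0: logits are
    # materialized for every position, shape (batch, seq_len, vocab_size).
    full = model(**inputs)
    # Keeping only the last position, as generate() does internally:
    # shape (batch, 1, vocab_size), which stays small in any dtype.
    last = model(**inputs, num_logits_to_keep=1)

print(full.logits.shape, last.logits.shape)
```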

@KyleMylonakisProtopia commented Oct 3, 2024

We observe these spikes during training and generation, but we are using a different loss than the LM loss, so we don't require full precision during training.

@Cyrilvallez (Member)

I see, makes sense then, as long as you don't pass labels 🤗

@ArthurZucker merged commit 550673a into huggingface:main on Oct 4, 2024
18 checks passed
@ringohoffman deleted the remove-logits-float branch on October 9, 2024 at 07:35
ringohoffman added a commit to ringohoffman/transformers that referenced this pull request Oct 14, 2024
ArthurZucker pushed a commit that referenced this pull request Oct 18, 2024
* Only cast logits to float when computing loss

Some misses from #31292 and #33902

* Move logits.float() into existing if labels is not None branch
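
The pattern those follow-up commits describe looks roughly like the sketch below; this is a simplified placeholder, not the actual modeling code, and the helper name is made up for illustration: keep the logits in the model's compute dtype and upcast only when a loss is actually computed.

```python
import torch.nn.functional as F

def causal_lm_tail(hidden_states, lm_head, labels=None):
    """Simplified sketch of the tail of a causal-LM forward (hypothetical helper)."""
    logits = lm_head(hidden_states)  # stays in the model's compute dtype, e.g. bf16

    loss = None
    if labels is not None:
        # Upcast only inside the loss branch, where fp32 precision matters.
        logits = logits.float()
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        loss = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
        )
    return logits, loss
```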
stevhliu pushed a commit to stevhliu/transformers that referenced this pull request Oct 21, 2024
* Only cast logits to float when computing loss

Some misses from huggingface#31292 and huggingface#33902

* Move logits.float() into existing if labels is not None branch
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
* Remove logits.float() if not computing loss

* Remove warning about 4.46 logits dtype change if not computing loss
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
* Only cast logits to float when computing loss

Some misses from huggingface#31292 and huggingface#33902

* Move logits.float() into existing if labels is not None branch