
Fix BatchSizeFinder leaving model in train state #18826

Conversation

Contributor

@tanaymeh tanaymeh commented Oct 19, 2023

What does this PR do?

This PR patches the bug where BatchSizeFinder would leave the model in the train state if used with trainer.validate.

Fixes #18813

Before submitting
  • Was this discussed/agreed via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you list all the breaking changes introduced by this pull request?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or minor internal changes/refactors)

PR review

@BoringDonut

Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet-list:

Reviewer checklist
  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

📚 Documentation preview 📚: https://pytorch-lightning--18826.org.readthedocs.build/en/18826/

@github-actions github-actions bot added the pl Generic label for PyTorch Lightning package label Oct 19, 2023
Contributor

@BoringDonut BoringDonut left a comment

@tanaymeh you probably missed the new is_training arg in on_fit_start:

  File "lightning/pytorch/callbacks/batch_size_finder.py", line 187, in on_fit_start
    self.scale_batch_size(trainer, pl_module)
TypeError: scale_batch_size() missing 1 required positional argument: 'is_training'
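
One way the hook might forward that argument, assuming the new three-argument scale_batch_size from this PR (a rough sketch, not the actual patch):

    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # forward the module's current mode so the new is_training argument is supplied
        self.scale_batch_size(trainer, pl_module, pl_module.training)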

@BoringDonut
Contributor

BoringDonut commented Oct 19, 2023

Other than the aforementioned missing arg, this PR seems to fix the issue for val/test/prediction. Thanks a lot!

@tanaymeh tanaymeh requested a review from BoringDonut October 19, 2023 14:55
@tanaymeh
Contributor Author

@BoringDonut Thanks for the review! I have added the requested changes; could you please review again?

Thanks again!

Contributor

@BoringDonut BoringDonut left a comment

Seems to work well and fixes the issue. Tested with trainer.fit, .predict, .validate, and .test.
Presumably this PR doesn't change any behavior within the BatchSizeFinder internals, since it only affects the model after _scale_batch_size has already done its main task anyway.

@codecov

codecov bot commented Oct 19, 2023

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 49%. Comparing base (b8a96fe) to head (591eb4e).
Report is 492 commits behind head on master.

❗ There is a different number of reports uploaded between BASE (b8a96fe) and HEAD (591eb4e).

HEAD has 190 uploads less than BASE

Flag               BASE (b8a96fe)   HEAD (591eb4e)
lightning                      42               16
cpu                            71               24
pytest                         53                2
python3.10                     21                9
python3.9                       6                3
lightning_fabric               10                0
python3.8                      12                6
app                             9                0
examples                        9                0
python3.11                     15                6
gpu                             4                2
tpu                             2                0
lightning_app                   4                0
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #18826      +/-   ##
==========================================
- Coverage      84%      49%     -35%     
==========================================
  Files         443      435       -8     
  Lines       36123    35974     -149     
==========================================
- Hits        30231    17579   -12652     
- Misses       5892    18395   +12503     

Comment on lines 171 to 180:

    def scale_batch_size(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", is_training: bool) -> None:
        new_size = _scale_batch_size(
            trainer,
            self._mode,
            self._steps_per_trial,
            self._init_val,
            self._max_trials,
            self._batch_arg_name,
            is_training,
        )
Contributor

@tanaymeh On second thought, I think we can avoid adding a new argument here.

pl_module is passed to scale_batch_size anyway, so there is no need to read pl_module.training earlier.

Example:

    def scale_batch_size(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        is_training = pl_module.training
        new_size = _scale_batch_size(
            trainer,
            self._mode,
            self._steps_per_trial,
            self._init_val,
            self._max_trials,
            self._batch_arg_name,
            is_training,
        )
...
    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        self.scale_batch_size(trainer, pl_module)
...

Doing it that way keeps compatibility for people who might have called scale_batch_size in their own code (e.g. classes inheriting from BatchSizeFinder), as illustrated below.
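
For illustration, a hypothetical subclass like the following (not part of this PR) would keep working unchanged if the two-argument signature is preserved:

    class MyBatchSizeFinder(BatchSizeFinder):
        def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
            # existing user code keeps calling the hook with only two arguments
            self.scale_batch_size(trainer, pl_module)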

@awaelchli
Contributor

@tanaymeh @BoringDonut have you found where the batch size tuner calls model.train()? Ideally, we would want to make the finder not change it in the first place. This would be the real fix IMO. We should leave the responsibility of calling model.train()/eval() to the respective loop IMO.

@awaelchli awaelchli added tuner bug Something isn't working labels Oct 21, 2023
@tanaymeh
Contributor Author

@tanaymeh @BoringDonut have you found where the batch size tuner calls model.train()? Ideally, we would want to make the finder not change it in the first place. This would be the real fix IMO. We should leave the responsibility of calling model.train()/eval() to the respective loop IMO.

@awaelchli I haven't found the place where the bug happens yet, but I am working on it!

@BoringDonut
Contributor

BoringDonut commented Oct 21, 2023

have you found where the batch size tuner calls model.train()?

@tanaymeh @awaelchli

I think I found it.
The call trace looks like this:

1. https://github.com/Lightning-AI/lightning/blob/e7afe04ee86b64c76a5446088b3b75d9c275e5bf/src/lightning/pytorch/tuner/batch_size_scaling.py#L333
   which redirects to
2. https://github.com/Lightning-AI/lightning/blob/e7afe04ee86b64c76a5446088b3b75d9c275e5bf/src/lightning/pytorch/loops/evaluation_loop.py#L108
   whose last line calls return self.on_run_end()
3. https://github.com/Lightning-AI/lightning/blob/e7afe04ee86b64c76a5446088b3b75d9c275e5bf/src/lightning/pytorch/loops/evaluation_loop.py#L246
   and this method assumes it runs within the train loop, so it puts the model back into the train state
4. https://github.com/Lightning-AI/lightning/blob/e7afe04ee86b64c76a5446088b3b75d9c275e5bf/src/lightning/pytorch/loops/evaluation_loop.py#L270
   this _on_evaluation_model_train call ends up in
5. https://github.com/Lightning-AI/lightning/blob/e7afe04ee86b64c76a5446088b3b75d9c275e5bf/src/lightning/pytorch/core/hooks.py#L166

Basically, it seems BatchSizeFinder should either not call _EvaluationLoop through .run(), since that method unconditionally puts the model back into the train state, or it should compensate by calling model.eval() when it knows it is working with an _EvaluationLoop.
I don't see an easy way to modify _EvaluationLoop.run() to address this without too much hassle.
As such, I think the solution by @tanaymeh looks reasonable.
To be more consistent with the existing BatchSizeFinder style, this call to .eval() could be done in BatchSizeFinder.__scale_batch_restore_params.
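
A minimal sketch of that restore-the-mode idea, written as it might look inside the callback's scale_batch_size; the placement and the exact _scale_batch_size call here are assumptions, not necessarily what this PR ends up doing:

    # remember the module's mode before the search; the evaluation loop's
    # on_run_end() calls model.train() unconditionally on the way out
    was_training = pl_module.training
    new_size = _scale_batch_size(
        trainer,
        self._mode,
        self._steps_per_trial,
        self._init_val,
        self._max_trials,
        self._batch_arg_name,
    )
    # put the module back into whatever mode it was in before the finder ran
    pl_module.train(was_training)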

@tanaymeh
Contributor Author

@BoringDonut I added your previously suggested change of moving the is_training read inside the scale_batch_size function.

Please let me know if I should make any more changes.

cc: @awaelchli

Contributor

@carmocca carmocca left a comment

Do you think you can add a test to tests/tests_pytorch/tuner/test_scale_batch_size.py?

@tanaymeh
Contributor Author

Do you think you can add a test to tests/tests_pytorch/tuner/test_scale_batch_size.py?

@carmocca Should the test be similar to @BoringDonut's script (here), which checks whether the model produces the same output in train and eval modes?

@carmocca
Contributor

I would suggest calling _scale_batch_size directly with some unittest Mocks, and asserting that the original training state is preserved

@tanaymeh
Contributor Author

tanaymeh commented Oct 30, 2023

I would suggest calling _scale_batch_size directly with some unittest Mocks, and assert that the original training state is preserved

@carmocca Should I create mocks for the trainer and the lightning module as well?

from unittest.mock import Mock

...

trainer = Mock()
lightning_module = Mock()
trainer.lightning_module = lightning_module

...
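
For illustration, an end-to-end variant of such a test might look roughly like this; the BatchSizeModel helper, the dataloader sizes, and the finder settings are assumptions, and this is only a sketch, not the test that ends up in the PR:

    from torch.utils.data import DataLoader

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import BatchSizeFinder
    from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset


    class BatchSizeModel(BoringModel):
        # hypothetical helper exposing a batch_size attribute for the finder to tune
        def __init__(self, batch_size=2):
            super().__init__()
            self.batch_size = batch_size

        def val_dataloader(self):
            return DataLoader(RandomDataset(32, 64), batch_size=self.batch_size)


    def test_batch_size_finder_keeps_eval_mode(tmp_path):
        model = BatchSizeModel()
        trainer = Trainer(default_root_dir=tmp_path, callbacks=[BatchSizeFinder(max_trials=2)])
        trainer.validate(model)
        # before this fix, the finder left the model in train mode after validate
        assert not model.training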

gitguardian bot commented Jan 16, 2024

⚠️ GitGuardian has uncovered 2 secrets following the scan of your pull request.

Please consider investigating the findings and remediating the incidents. Failure to do so may lead to compromising the associated services or software components.

🔎 Detected hardcoded secrets in your pull request
GitGuardian id   Secret                        Commit    Filename
-                Generic High Entropy Secret   78fa3af   tests/tests_app/utilities/test_login.py
-                Base64 Basic Authentication   78fa3af   tests/tests_app/utilities/test_login.py
🛠 Guidelines to remediate hardcoded secrets
  1. Understand the implications of revoking this secret by investigating where it is used in your code.
  2. Replace and store your secret safely.
  3. Revoke and rotate this secret.
  4. If possible, rewrite git history. Rewriting git history is not a trivial act. You might completely break other contributing developers' workflow and you risk accidentally deleting legitimate data.


@awaelchli awaelchli modified the milestones: 2.1.x, 2.2.x Feb 8, 2024
@awaelchli awaelchli modified the milestones: 2.2.x, 2.3.x Jun 13, 2024
@tanaymeh tanaymeh closed this Jun 17, 2024
Labels
bug (Something isn't working), community (This PR is from the community), pl (Generic label for PyTorch Lightning package), tuner
Projects
None yet
Development

Successfully merging this pull request may close these issues.

BatchSizeFinder leaves model in the train state if used with trainer.validate
5 participants