Fix BatchSizeFinder leaving model in train state #18826
Conversation
@tanaymeh probably missed the new `is_training` arg in `on_fit_start`:

    File "lightning/pytorch/callbacks/batch_size_finder.py", line 187, in on_fit_start
        self.scale_batch_size(trainer, pl_module)
    TypeError: scale_batch_size() missing 1 required positional argument: 'is_training'

Other than the aforementioned missing arg, this PR seems to fix the issue for val/test/prediction, thanks a lot!
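The traceback above can be reproduced with a minimal stand-in (a sketch only: `BatchSizeFinderSketch` is a toy class mirroring the signature mismatch, not Lightning's actual callback):

```python
# Toy reproduction of the reported TypeError: the method gained a new
# required positional argument, but an old call site was not updated.
class BatchSizeFinderSketch:
    def scale_batch_size(self, trainer, pl_module, is_training):
        return is_training

finder = BatchSizeFinderSketch()
try:
    finder.scale_batch_size(None, None)  # old two-argument call site
except TypeError as exc:
    print(exc)  # ... missing 1 required positional argument: 'is_training'
```

This is why adding a required positional parameter to a public method is a breaking change for any existing caller, including subclasses.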
@BoringDonut Thanks for the review! I have added the requested changes, can you please review it again? Thanks again!
Seems to work well and fixes the issue. Tested with `trainer.fit`, `.predict`, `.validate`, and `.test`. Presumably this PR doesn't change any behavior within `BatchSizeFinder` internals, since it only affects the model after `_scale_batch_size` has already done its main task anyway.
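The manual check described above can be sketched with stdlib-only stand-ins (assumptions: `FakeModule` mimics a torch module's `training` flag, and `run_finder_sketch` mirrors the fix under review; neither is Lightning's real API):

```python
# FakeModule mimics torch.nn.Module's train/eval mode flag.
class FakeModule:
    def __init__(self):
        self.training = True

    def train(self, mode=True):
        self.training = mode
        return self

    def eval(self):
        return self.train(False)

def run_finder_sketch(module):
    """Stand-in for the finder: restore whatever mode the model had on entry."""
    was_training = module.training   # capture mode on entry
    module.train()                   # trial steps run in train mode
    module.train(was_training)       # the fix: restore the original mode

model = FakeModule()
model.eval()                 # simulate trainer.validate() putting the model in eval
run_finder_sketch(model)
print(model.training)        # -> False: eval mode preserved after the finder ran
```

Without the final restore step, the model would be left with `training=True` after validation, which is the bug this PR addresses.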
Codecov Report: All modified and coverable lines are covered by tests ✅

Additional details and impacted files:

    @@            Coverage Diff             @@
    ##           master   #18826      +/-   ##
    ==========================================
    - Coverage      84%      49%     -35%
    ==========================================
      Files         443      435       -8
      Lines       36123    35974     -149
    ==========================================
    - Hits        30231    17579   -12652
    - Misses       5892    18395   +12503
    def scale_batch_size(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", is_training: bool) -> None:
        new_size = _scale_batch_size(
            trainer,
            self._mode,
            self._steps_per_trial,
            self._init_val,
            self._max_trials,
            self._batch_arg_name,
            is_training,
        )
@tanaymeh On second thought, I think we can avoid adding a new argument here. `pl_module` is passed to `scale_batch_size` anyway, so there is no need to read `pl_module.training` earlier.
Example:
    def scale_batch_size(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        is_training = pl_module.training
        new_size = _scale_batch_size(
            trainer,
            self._mode,
            self._steps_per_trial,
            self._init_val,
            self._max_trials,
            self._batch_arg_name,
            is_training,
        )
        ...

    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        self.scale_batch_size(trainer, pl_module)
        ...
Doing it that way would keep compatibility for people who might have called `scale_batch_size` in their own code (e.g. classes inheriting from `BatchSizeFinder`).
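The compatibility point can be illustrated with toy stand-ins (a sketch under assumptions: `FinderSketch`, `CustomFinder`, and `FakeModule` are illustrative names, not Lightning's code). Reading `pl_module.training` inside the method keeps the two-argument signature, so existing subclasses keep working unchanged:

```python
class FakeModule:
    # e.g. the model is in eval mode under trainer.validate
    training = False

class FinderSketch:
    def scale_batch_size(self, trainer, pl_module):
        # read the flag here instead of threading it through every caller
        is_training = pl_module.training
        return is_training

class CustomFinder(FinderSketch):
    # a hypothetical user subclass relying on the two-argument signature
    def on_fit_start(self, trainer, pl_module):
        return self.scale_batch_size(trainer, pl_module)

print(CustomFinder().on_fit_start(trainer=None, pl_module=FakeModule()))  # -> False
```

Had the signature gained a required `is_training` parameter instead, `CustomFinder.on_fit_start` above would raise a `TypeError` until the subclass was updated.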
@tanaymeh @BoringDonut have you found where the batch size tuner calls …
@awaelchli I haven't found the place where the bug is happening yet, but I am working on it!
I think I found it: https://github.com/Lightning-AI/lightning/blob/e7afe04ee86b64c76a5446088b3b75d9c275e5bf/src/lightning/pytorch/tuner/batch_size_scaling.py#L333. Basically it seems …
@BoringDonut I added your previously suggested changes of moving the … Please let me know if I should add any more changes. cc: @awaelchli
Do you think you can add a test to `tests/tests_pytorch/tuner/test_scale_batch_size.py`?
@carmocca Should the test be similar to @BoringDonut's script (here) that is used to test if the model produces the same output in train and eval modes?
I would suggest calling …
@carmocca Should I create a mock for the trainer and lightning module as well?

    from unittest.mock import Mock
    ...
    trainer = Mock()
    lightning_module = Mock()
    trainer.lightning_module = lightning_module
    ...
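The mock idea above could be fleshed out along these lines (a hedged sketch: `scale_batch_size_stub` is a hypothetical stand-in for the real tuner call, not Lightning's implementation, and the assertion only checks that the last `train(...)` call restored the mode observed on entry):

```python
from unittest.mock import Mock

trainer = Mock()
lightning_module = Mock()
lightning_module.training = False          # model entered in eval mode
trainer.lightning_module = lightning_module

def scale_batch_size_stub(trainer, pl_module):
    """Stand-in for the tuner call: must leave pl_module's mode unchanged."""
    was_training = pl_module.training
    pl_module.train()                      # trial steps toggle train mode
    pl_module.train(was_training)          # restore on exit (the behavior under test)

scale_batch_size_stub(trainer, lightning_module)
lightning_module.train.assert_called_with(False)  # last call restored eval mode
print(lightning_module.train.call_count)  # -> 2
```

Using `Mock` this way avoids constructing a real `Trainer`, at the cost of only verifying the calls made, not the model's actual behavior.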
| GitGuardian id | Secret | Commit | Filename |
|---|---|---|---|
| - | Generic High Entropy Secret | 78fa3af | tests/tests_app/utilities/test_login.py |
| - | Base64 Basic Authentication | 78fa3af | tests/tests_app/utilities/test_login.py |
🛠 Guidelines to remediate hardcoded secrets
- Understand the implications of revoking this secret by investigating where it is used in your code.
- Replace and store your secret safely.
- Revoke and rotate this secret.
- If possible, rewrite git history. Rewriting git history is not a trivial act: you might completely break other contributing developers' workflows, and you risk accidentally deleting legitimate data.

To avoid such incidents in the future, consider:
- following best practices for managing and storing secrets, including API keys and other credentials;
- installing secret detection on pre-commit to catch secrets before they leave your machine and ease remediation.
🦉 GitGuardian detects secrets in your source code to help developers and security teams secure the modern development process. You are seeing this because you or someone else with access to this repository has authorized GitGuardian to scan your pull request.
Our GitHub checks need improvements? Share your feedback!
What does this PR do?
This PR patches the bug where `BatchSizeFinder` would leave the model in train state if used with `trainer.validate`.
Fixes #18813
Before submitting
PR review
@BoringDonut
Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines.
Reviewer checklist
📚 Documentation preview 📚: https://pytorch-lightning--18826.org.readthedocs.build/en/18826/