Remember the eval mode of submodules when switching trainer stages #18951
Conversation
Codecov Report
Additional details and impacted files

@@            Coverage Diff             @@
##           master   #18951      +/-   ##
==========================================
- Coverage      76%      48%     -27%
==========================================
  Files         450      442       -8
  Lines       36508    36383     -125
==========================================
- Hits        27583    17572   -10011
- Misses       8925    18811    +9886
Does this PR replace #18826?
Co-authored-by: Carlos Mocholí <[email protected]>
The problem was in the training loop, which called `.train()` on the whole model. Also, currently in the docs it is shown that:

# ...
for batch_idx, batch in enumerate(train_dataloader):
    loss = model.training_step(batch, batch_idx)
    loss.backward()
    # ...

    if validate_at_some_point:
        # disable grads + batchnorm + dropout
        torch.set_grad_enabled(False)
        model.eval()

        # ----------------- VAL LOOP ---------------
        for val_batch_idx, val_batch in enumerate(val_dataloader):
            val_out = model.validation_step(val_batch, val_batch_idx)
        # ----------------- VAL LOOP ---------------

        # enable grads + batchnorm + dropout
        torch.set_grad_enabled(True)
        model.train()

I would be happy to open a new issue for fixing the docs to better align with this very useful PR!
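For reference, a sketch of how that pseudocode could remember and restore the per-submodule modes instead of calling `model.train()` on everything (illustrative only; the helper names are made up and this is not the Trainer's actual implementation):

# capture each submodule's .training flag before switching to validation
def capture_training_modes(model):
    return {name: module.training for name, module in model.named_modules()}

# restore the captured flags instead of blindly calling model.train()
def restore_training_modes(model, modes):
    for name, module in model.named_modules():
        module.training = modes[name]

# around the validation block of the loop above:
modes = capture_training_modes(model)
torch.set_grad_enabled(False)
model.eval()
# ... run the validation loop ...
torch.set_grad_enabled(True)
restore_training_modes(model, modes)  # layers the user froze stay in eval mode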
What does this PR do?
Fixes #18930
Part of #16827
A common issue users are facing is that the loop calls `train()` on the LightningModule despite the user having frozen certain layers. For example, a user may freeze a submodule and put it in `eval()` mode. This leads to a surprise when the user finds out that their batch norm layers have changed statistics, even though they were set explicitly to `eval()` mode. To avoid this, the user has to learn that they should override the `on_validation_model_eval()` and `on_validation_model_train()` hooks in the module, but this is a detail that is difficult to find in our docs and to get right. Most users who face this challenge end up on Slack or GitHub to ask for help.
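Until now, the workaround looked roughly like this (a sketch assuming a frozen submodule named `backbone`; the names are illustrative, not taken from the docs):

import lightning.pytorch as pl


class MyModel(pl.LightningModule):
    # ... __init__ defines a frozen self.backbone plus the trainable parts ...

    def on_validation_model_eval(self):
        self.eval()  # same as the default: put everything into eval mode

    def on_validation_model_train(self):
        self.train()          # the default would put *everything* back into training mode
        self.backbone.eval()  # so the frozen submodule must be switched back by hand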
The PR makes the following changes to automate this for the user:

- The Trainer now remembers the `.training` mode of every submodule before calling `.eval()`. When the validation loop ends, and before switching to training, it restores the `.training` mode on all submodules to what it was before. This ensures that layers the user has chosen to be in eval mode remain in eval mode!
- The Trainer no longer calls `.train()` at the beginning, with the same motivation: the user can now set a subset of their model to `.eval()` mode / freeze it explicitly in the LightningModule's `__init__` without doing acrobatics with hooks, and the Trainer will respect it and preserve it (see the added test; a sketch of this setup follows below). Note: This is not a breaking change, because PyTorch's default is to have a model in `.training=True` mode.
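For illustration, a minimal sketch of the kind of setup that now works out of the box (the module and layer names below are made up, not taken from the added test):

import torch
import lightning.pytorch as pl


class FinetuneModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # hypothetical frozen feature extractor containing batch norm layers
        self.backbone = torch.nn.Sequential(
            torch.nn.Conv2d(3, 8, kernel_size=3),
            torch.nn.BatchNorm2d(8),
            torch.nn.ReLU(),
        )
        self.backbone.requires_grad_(False)
        self.backbone.eval()  # batch norm statistics should stay fixed
        self.head = torch.nn.Linear(8, 2)

    # training_step, validation_step, configure_optimizers, ... omitted


# With this change, the backbone keeps .training=False across the Trainer's
# switches between the training and validation loops, instead of being
# flipped back to training mode.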
📚 Documentation preview 📚: https://pytorch-lightning--18951.org.readthedocs.build/en/18951/
cc @Borda @justusschock @awaelchli