
Fix BatchSizeFinder leaving model in train state #18826

src/lightning/pytorch/callbacks/batch_size_finder.py (12 changes: 8 additions, 4 deletions)
@@ -168,14 +168,15 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: O
" If this is not the intended behavior, please remove either one."
)

def scale_batch_size(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
def scale_batch_size(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", is_training: bool) -> None:
new_size = _scale_batch_size(
trainer,
self._mode,
self._steps_per_trial,
self._init_val,
self._max_trials,
self._batch_arg_name,
is_training,
)
Review comment (Contributor):
@tanaymeh On second thought, I think we can avoid adding a new argument here.

pl_module is passed to scale_batch_size anyway, so there is no need to read pl_module.training earlier.

Example:

    def scale_batch_size(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        is_training = pl_module.training
        new_size = _scale_batch_size(
            trainer,
            self._mode,
            self._steps_per_trial,
            self._init_val,
            self._max_trials,
            self._batch_arg_name,
            is_training,
        )
...
    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        self.scale_batch_size(trainer, pl_module)
...

Doing it that way would keep compatibility for people who might have called scale_batch_size in their own code (e.g. classes that inherit from BatchSizeFinder), as in the sketch below.
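For illustration, a minimal sketch of the kind of downstream code this would keep working. The MyBatchSizeFinder subclass is hypothetical and not part of this PR; it only shows that a two-argument call to scale_batch_size still succeeds if is_training is read from pl_module internally instead of being a new parameter:

    from lightning.pytorch.callbacks import BatchSizeFinder

    # Hypothetical user subclass written against the current two-argument API.
    class MyBatchSizeFinder(BatchSizeFinder):
        def on_fit_start(self, trainer, pl_module):
            # Calls the inherited method without an is_training argument;
            # reading pl_module.training inside scale_batch_size keeps this valid.
            self.scale_batch_size(trainer, pl_module)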


         self.optimal_batch_size = new_size
@@ -189,10 +190,13 @@ def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMod
         if trainer.sanity_checking or trainer.state.fn != "validate":
             return
 
-        self.scale_batch_size(trainer, pl_module)
+        is_training = pl_module.training
+        self.scale_batch_size(trainer, pl_module, is_training)
 
     def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
-        self.scale_batch_size(trainer, pl_module)
+        is_training = pl_module.training
+        self.scale_batch_size(trainer, pl_module, is_training)
 
     def on_predict_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
-        self.scale_batch_size(trainer, pl_module)
+        is_training = pl_module.training
+        self.scale_batch_size(trainer, pl_module, is_training)
src/lightning/pytorch/tuner/batch_size_scaling.py (7 changes: 7 additions, 0 deletions)
@@ -32,6 +32,7 @@ def _scale_batch_size(
     init_val: int = 2,
     max_trials: int = 25,
     batch_arg_name: str = "batch_size",
+    is_training: bool = True,
 ) -> Optional[int]:
     """Iteratively try to find the largest batch size for a given model that does not give an out of memory (OOM)
     error.
@@ -95,6 +96,12 @@
     trainer._checkpoint_connector.restore(ckpt_path)
     trainer.strategy.remove_checkpoint(ckpt_path)
 
+    # Set the model to training or evaluation mode based on the is_training parameter
+    if is_training:
+        trainer.lightning_module.train()
+    else:
+        trainer.lightning_module.eval()
+
     return new_size
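As a rough usage sketch of the intended behavior (model and dm are hypothetical placeholders, assumed to be a LightningModule / LightningDataModule exposing a batch_size attribute), the train/eval mode the module was in before the search should be restored once _scale_batch_size finishes:

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import BatchSizeFinder

    # `model` and `dm` are placeholders for a LightningModule / LightningDataModule
    # with a tunable `batch_size` attribute.
    trainer = Trainer(callbacks=[BatchSizeFinder(mode="power", steps_per_trial=3)])
    trainer.validate(model, datamodule=dm)

    # With this change the module is put back into the train/eval mode it was in
    # when the batch size search started, instead of being left in train mode.
    print(model.training)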

