diff --git a/src/lightning/pytorch/tuner/batch_size_scaling.py b/src/lightning/pytorch/tuner/batch_size_scaling.py
index 6618f7e930ca1..772ce7762b4c6 100644
--- a/src/lightning/pytorch/tuner/batch_size_scaling.py
+++ b/src/lightning/pytorch/tuner/batch_size_scaling.py
@@ -63,6 +63,9 @@ def _scale_batch_size(
         rank_zero_warn("Skipping batch size scaler since `fast_dev_run` is enabled.")
         return None
 
+    # Capture whether the module is in training or eval mode
+    is_training = trainer.training
+
     # Save initial model, that is loaded after batch size is found
     ckpt_path = os.path.join(trainer.default_root_dir, f".scale_batch_size_{uuid.uuid4()}.ckpt")
     trainer.save_checkpoint(ckpt_path)
@@ -95,6 +98,12 @@ def _scale_batch_size(
     trainer._checkpoint_connector.restore(ckpt_path)
    trainer.strategy.remove_checkpoint(ckpt_path)
 
+    # Set the module back to its original mode based on the captured is_training flag
+    if is_training:
+        trainer.lightning_module.train()
+    else:
+        trainer.lightning_module.eval()
+
     return new_size
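
For context, a minimal sketch of the capture-and-restore pattern this patch applies. It uses a plain torch.nn.Module for illustration; the patch itself reads the trainer-level trainer.training flag and restores the mode on trainer.lightning_module.

import torch

module = torch.nn.Linear(4, 2)
module.eval()  # suppose the user left the module in eval mode

# Capture the current mode before the tuner runs its trial fits
is_training = module.training

# ... trial fits flip the module into training mode ...
module.train()

# Restore the captured mode; train(False) is equivalent to eval()
module.train(is_training)

assert module.training is False  # the original eval mode is preserved

The patch spells the restore out as an explicit if/else over train()/eval(), which matches the two-branch structure of the captured flag; module.train(is_training) is a one-line equivalent.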