diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 484983d7b62be..df9101910dc8e 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -353,12 +353,9 @@ def train(self): # stop training stop = should_stop and met_min_epochs if stop: - self.main_progress_bar.close() - model.on_train_end() - return + break self.main_progress_bar.close() - model.on_train_end() if self.logger is not None: diff --git a/tests/test_cpu_models.py b/tests/test_cpu_models.py index fa643c64ac5fc..2f39dbd25e124 100644 --- a/tests/test_cpu_models.py +++ b/tests/test_cpu_models.py @@ -21,6 +21,7 @@ def test_early_stopping_cpu_model(tmpdir): stopping = EarlyStopping(monitor='val_loss', min_delta=0.1) trainer_options = dict( default_save_path=tmpdir, + min_epochs=2, early_stop_callback=stopping, gradient_clip_val=1.0, overfit_pct=0.20, @@ -33,7 +34,7 @@ def test_early_stopping_cpu_model(tmpdir): ) model, hparams = tutils.get_model() - tutils.run_model_test(trainer_options, model, on_gpu=False) + tutils.run_model_test(trainer_options, model, on_gpu=False, early_stop=True) # test freeze on cpu model.freeze() diff --git a/tests/utils.py b/tests/utils.py index 70ecccb28d7a0..33da9b6228939 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -45,7 +45,7 @@ def run_model_test_no_loggers(trainer_options, model, min_acc=0.50): trainer.optimizers, trainer.lr_schedulers = pretrained_model.configure_optimizers() -def run_model_test(trainer_options, model, on_gpu=True): +def run_model_test(trainer_options, model, on_gpu=True, early_stop=False): save_dir = trainer_options['default_save_path'] # logger file to get meta @@ -65,6 +65,10 @@ def run_model_test(trainer_options, model, on_gpu=True): # correct result and ok accuracy assert result == 1, 'amp + ddp model failed to complete' + if early_stop: + assert trainer.current_epoch >= trainer.min_epochs, 'amp + ddp model failed to complete' + assert trainer.current_epoch < trainer.max_epochs-1, 'amp + ddp model failed to stop early' + # test model loading pretrained_model = load_model(logger, trainer.checkpoint_callback.filepath)