From b4644da688626a4082cb4d10b3a7141ed401dff1 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 5 Oct 2020 11:05:09 +0200 Subject: [PATCH 1/3] add test for checkpoint nan --- tests/checkpointing/test_model_checkpoint.py | 23 +++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index b3b8204166ecc..bb22d491d9764 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -15,7 +15,7 @@ from pytorch_lightning import Trainer, seed_everything from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.loggers import TensorBoardLogger -from tests.base import EvalModelTemplate +from tests.base import EvalModelTemplate, BoringModel from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -457,3 +457,24 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): ) for w0, w1 in zip(model_last_epoch.parameters(), model_last.parameters()): assert w0.eq(w1).all() + + +# todo consider if we shall ignore also +/- inf +@pytest.mark.parametrize('mode', ['min', 'max', 'auto']) +def test_checkpointing_with_nan(tmpdir, mode): + losses = [8, 7, float('nan'), 5] + + class CurrentModel(BoringModel): + def validation_epoch_end(self, outputs): + val_loss = losses[self.current_epoch] + self.log('abc', torch.tensor(val_loss)) + + model = CurrentModel() + + trainer = Trainer( + checkpoint_callback=ModelCheckpoint(monitor='abc', mode=mode, save_top_k=1, filepath=tmpdir), + default_root_dir=tmpdir, + val_check_interval=1.0, + max_epochs=len(losses), + ) + trainer.fit(model) \ No newline at end of file From 3d6d1436a84adc9b8328d24278e3e560cf118d3b Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 5 Oct 2020 11:30:27 +0200 Subject: [PATCH 2/3] fix --- .../callbacks/model_checkpoint.py | 4 ++++ tests/checkpointing/test_model_checkpoint.py | 20 +++++++++++-------- 2 
files changed, 16 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index f27f9c4c61476..2517888ad5ed7 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -517,6 +517,10 @@ def _update_best_and_save( self.best_k_models.pop(self.kth_best_model_path) del_list.append(delpath) + # do not save nan, instead replace them by +/- inf + if torch.isnan(current): + current = {"min": torch.tensor(float('inf')), "max": torch.tensor(-float('inf'))}[self.mode] + self.best_k_models[filepath] = current if len(self.best_k_models) == k: # monitor dict has reached k elements diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index bb22d491d9764..ab8c9a652e15b 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -459,15 +459,16 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): assert w0.eq(w1).all() -# todo consider if we shall ignore also +/- inf -@pytest.mark.parametrize('mode', ['min', 'max', 'auto']) -def test_checkpointing_with_nan(tmpdir, mode): - losses = [8, 7, float('nan'), 5] +@pytest.mark.parametrize('mode', ['min', 'max']) +def test_checkpointing_with_nan_as_first(tmpdir, mode): + os.environ['PL_DEV_DEBUG'] = '1' + monitor = [float('nan')] + monitor += [5, 7, 8] if mode == 'max' else [8, 7, 5] class CurrentModel(BoringModel): def validation_epoch_end(self, outputs): - val_loss = losses[self.current_epoch] - self.log('abc', torch.tensor(val_loss)) + val_loss = monitor[self.current_epoch] + self.log('abc', val_loss) model = CurrentModel() @@ -475,6 +476,9 @@ def validation_epoch_end(self, outputs): checkpoint_callback=ModelCheckpoint(monitor='abc', mode=mode, save_top_k=1, filepath=tmpdir), default_root_dir=tmpdir, val_check_interval=1.0, - max_epochs=len(losses), + max_epochs=len(monitor), ) - 
trainer.fit(model) \ No newline at end of file + trainer.fit(model) + + # check that last one is also the best one + assert trainer.dev_debugger.checkpoint_callback_history[-1]['epoch'] == len(monitor) - 1 \ No newline at end of file From eb7d7a3a9a529c242618f14f0274f38792b4a830 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 5 Oct 2020 11:33:01 +0200 Subject: [PATCH 3/3] pep --- tests/checkpointing/test_model_checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index ab8c9a652e15b..ee988bb8f4b60 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -481,4 +481,4 @@ def validation_epoch_end(self, outputs): trainer.fit(model) # check that last one is also the best one - assert trainer.dev_debugger.checkpoint_callback_history[-1]['epoch'] == len(monitor) - 1 \ No newline at end of file + assert trainer.dev_debugger.checkpoint_callback_history[-1]['epoch'] == len(monitor) - 1