fix init nan for checkpointing (#3863)
* add test for checkpoint nan

* fix

* pep
Borda authored Oct 5, 2020
1 parent b014223 commit 6ac0958
Showing 2 changed files with 30 additions and 1 deletion.
4 changes: 4 additions & 0 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -517,6 +517,10 @@ def _update_best_and_save(
            self.best_k_models.pop(self.kth_best_model_path)
            del_list.append(delpath)

        # do not save nan; replace it with +/- inf so any real score can displace it
        if torch.isnan(current):
            current = {"min": torch.tensor(float('inf')), "max": torch.tensor(-float('inf'))}[self.mode]

        self.best_k_models[filepath] = current
        if len(self.best_k_models) == k:
            # monitor dict has reached k elements
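Why the substitution above is needed: IEEE 754 NaN compares as neither smaller nor larger than any number, so a NaN stored as a best score never loses a comparison and can never be displaced. A minimal standalone sketch of the behavior (illustration only; the helper name sanitize is not part of the codebase):

import torch

nan = torch.tensor(float('nan'))
assert not (nan < 5.0) and not (nan > 5.0)  # NaN loses every comparison

# Mirror of the fix: map NaN to the worst possible score for the mode,
# so that any finite value logged later replaces it.
def sanitize(current, mode):
    if torch.isnan(current):
        current = {"min": torch.tensor(float('inf')),
                   "max": torch.tensor(-float('inf'))}[mode]
    return current

assert sanitize(nan, 'min') == float('inf')
assert sanitize(nan, 'max') == -float('inf')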
27 changes: 26 additions & 1 deletion tests/checkpointing/test_model_checkpoint.py
@@ -15,7 +15,7 @@
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from tests.base import EvalModelTemplate
from tests.base import EvalModelTemplate, BoringModel
from pytorch_lightning.utilities.exceptions import MisconfigurationException


@@ -457,3 +457,28 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
    )
    for w0, w1 in zip(model_last_epoch.parameters(), model_last.parameters()):
        assert w0.eq(w1).all()


@pytest.mark.parametrize('mode', ['min', 'max'])
def test_checkpointing_with_nan_as_first(tmpdir, mode):
    os.environ['PL_DEV_DEBUG'] = '1'
    monitor = [float('nan')]
    monitor += [5, 7, 8] if mode == 'max' else [8, 7, 5]

    class CurrentModel(BoringModel):
        def validation_epoch_end(self, outputs):
            val_loss = monitor[self.current_epoch]
            self.log('abc', val_loss)

    model = CurrentModel()

    trainer = Trainer(
        checkpoint_callback=ModelCheckpoint(monitor='abc', mode=mode, save_top_k=1, filepath=tmpdir),
        default_root_dir=tmpdir,
        val_check_interval=1.0,
        max_epochs=len(monitor),
    )
    trainer.fit(model)

    # check that last one is also the best one
    assert trainer.dev_debugger.checkpoint_callback_history[-1]['epoch'] == len(monitor) - 1
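The monitor sequence deliberately puts NaN first: without the fix, the NaN logged at epoch 0 becomes the recorded best and, since every comparison against NaN is False, save_top_k=1 would never replace that checkpoint. A plain-Python illustration of the failure mode (not part of the test):

# mode='min': lower is better
scores = [float('nan'), 8, 7, 5]
best = scores[0]            # NaN logged at epoch 0
for s in scores[1:]:
    if s < best:            # always False against NaN
        best = s
print(best)                 # nan -- the "best" checkpoint stays stuck at epoch 0

With the fix, the NaN is stored as +inf (mode='min') or -inf (mode='max'), so every later epoch improves on it and the final epoch ends up as the best checkpoint, which is what the dev-debugger assertion verifies.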
