Skip to content

Commit

Permalink
Fix ModelCheckpoint.CHECKPOINT_NAME_LAST test interaction (#18993)
Browse files Browse the repository at this point in the history
(cherry picked from commit b4605b4)
  • Loading branch information
awaelchli authored and lantiga committed Nov 15, 2023
1 parent d947b01 commit a49b1c8
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions tests/tests_pytorch/checkpointing/test_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,12 +485,12 @@ def test_model_checkpoint_file_extension(tmpdir):
assert set(expected) == set(os.listdir(tmpdir))


def test_model_checkpoint_save_last(tmpdir):
def test_model_checkpoint_save_last(tmpdir, monkeypatch):
"""Tests that save_last produces only one last checkpoint."""
seed_everything()
model = LogInTwoMethods()
epochs = 3
ModelCheckpoint.CHECKPOINT_NAME_LAST = "last-{epoch}"
monkeypatch.setattr(ModelCheckpoint, "CHECKPOINT_NAME_LAST", "last-{epoch}")
model_checkpoint = ModelCheckpoint(monitor="early_stop_on", dirpath=tmpdir, save_top_k=-1, save_last=True)
trainer = Trainer(
default_root_dir=tmpdir,
Expand All @@ -511,7 +511,6 @@ def test_model_checkpoint_save_last(tmpdir):
)
assert os.path.islink(tmpdir / last_filename)
assert os.path.realpath(tmpdir / last_filename) == model_checkpoint._last_checkpoint_saved
ModelCheckpoint.CHECKPOINT_NAME_LAST = "last"


def test_model_checkpoint_link_checkpoint(tmp_path):
Expand Down

0 comments on commit a49b1c8

Please sign in to comment.