Skip to content

Commit

Permalink
Fix test interactions (#18994)
Browse files Browse the repository at this point in the history
(cherry picked from commit 340961a)
  • Loading branch information
awaelchli authored and lantiga committed Nov 15, 2023
1 parent d4d27a6 commit 8d0830b
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 11 deletions.
4 changes: 2 additions & 2 deletions tests/tests_fabric/loggers/test_tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,10 +214,10 @@ def test_tensorboard_finalize(monkeypatch, tmp_path):


@mock.patch("lightning.fabric.loggers.tensorboard.log")
def test_tensorboard_with_symlink(log, tmp_path):
def test_tensorboard_with_symlink(log, tmp_path, monkeypatch):
"""Tests a specific failure case when tensorboard logger is used with empty name, symbolic link ``save_dir``, and
relative paths."""
os.chdir(tmp_path) # need to use relative paths
monkeypatch.chdir(tmp_path) # need to use relative paths
source = os.path.join(".", "lightning_logs")
dest = os.path.join(".", "sym_lightning_logs")

Expand Down
12 changes: 5 additions & 7 deletions tests/tests_pytorch/checkpointing/test_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ def test_model_checkpoint_no_extraneous_invocations(tmpdir):
assert trainer.state.finished, f"Training failed with {trainer.state}"


def test_model_checkpoint_format_checkpoint_name(tmpdir):
def test_model_checkpoint_format_checkpoint_name(tmpdir, monkeypatch):
# empty filename:
ckpt_name = ModelCheckpoint._format_checkpoint_name("", {"epoch": 3, "step": 2})
assert ckpt_name == "epoch=3-step=2"
Expand All @@ -422,18 +422,16 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir):
assert ckpt_name == "epoch=003-epoch_test=003"

# prefix
char_org = ModelCheckpoint.CHECKPOINT_JOIN_CHAR
ModelCheckpoint.CHECKPOINT_JOIN_CHAR = "@"
monkeypatch.setattr(ModelCheckpoint, "CHECKPOINT_JOIN_CHAR", "@")
ckpt_name = ModelCheckpoint._format_checkpoint_name("{epoch},{acc:.5f}", {"epoch": 3, "acc": 0.03}, prefix="test")
assert ckpt_name == "test@epoch=3,acc=0.03000"
ModelCheckpoint.CHECKPOINT_JOIN_CHAR = char_org
monkeypatch.undo()

# non-default char for equals sign
default_char = ModelCheckpoint.CHECKPOINT_EQUALS_CHAR
ModelCheckpoint.CHECKPOINT_EQUALS_CHAR = ":"
monkeypatch.setattr(ModelCheckpoint, "CHECKPOINT_EQUALS_CHAR", ":")
ckpt_name = ModelCheckpoint._format_checkpoint_name("{epoch:03d}-{acc}", {"epoch": 3, "acc": 0.03})
assert ckpt_name == "epoch:003-acc:0.03"
ModelCheckpoint.CHECKPOINT_EQUALS_CHAR = default_char
monkeypatch.undo()

# no dirpath set
ckpt_name = ModelCheckpoint(monitor="early_stop_on", dirpath=None).format_checkpoint_name({"epoch": 3, "step": 2})
Expand Down
4 changes: 2 additions & 2 deletions tests/tests_pytorch/loggers/test_tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,10 +313,10 @@ def test_tensorboard_save_hparams_to_yaml_once(tmp_path):


@mock.patch("lightning.pytorch.loggers.tensorboard.log")
def test_tensorboard_with_symlink(log, tmp_path):
def test_tensorboard_with_symlink(log, tmp_path, monkeypatch):
"""Tests a specific failure case when tensorboard logger is used with empty name, symbolic link ``save_dir``, and
relative paths."""
os.chdir(tmp_path) # need to use relative paths
monkeypatch.chdir(tmp_path) # need to use relative paths
source = os.path.join(".", "lightning_logs")
dest = os.path.join(".", "sym_lightning_logs")

Expand Down

0 comments on commit 8d0830b

Please sign in to comment.