diff --git a/src/dvclive/error.py b/src/dvclive/error.py index d2e040c2..ac94259a 100644 --- a/src/dvclive/error.py +++ b/src/dvclive/error.py @@ -17,6 +17,12 @@ def __init__(self): super().__init__("`dvcyaml` path must have filename 'dvc.yaml'") +class InvalidImageNameError(DvcLiveError): + def __init__(self, name): + self.name = name + super().__init__(f"Cannot log image with name '{name}'") + + class InvalidPlotTypeError(DvcLiveError): def __init__(self, name): from .plots import SKLEARN_PLOTS diff --git a/src/dvclive/live.py b/src/dvclive/live.py index 48b5f94f..81c4ca0b 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -31,6 +31,7 @@ from .error import ( InvalidDataTypeError, InvalidDvcyamlError, + InvalidImageNameError, InvalidParameterTypeError, InvalidPlotTypeError, InvalidReportModeError, @@ -389,11 +390,20 @@ def log_image( if not Image.could_log(val): raise InvalidDataTypeError(name, type(val)) + # If we're given a path, try loading the image first. This might error out. if isinstance(val, (str, PurePath)): from PIL import Image as ImagePIL + suffix = Path(val).suffix + if not Path(name).suffix and suffix in Image.suffixes: + name = f"{name}{suffix}" + val = ImagePIL.open(val) + # See if the image name is valid + if Path(name).suffix not in Image.suffixes: + raise InvalidImageNameError(name) + if name in self._images: image = self._images[name] else: diff --git a/tests/plots/test_image.py b/tests/plots/test_image.py index f5621b67..dfcb8efc 100644 --- a/tests/plots/test_image.py +++ b/tests/plots/test_image.py @@ -4,6 +4,7 @@ from PIL import Image from dvclive import Live +from dvclive.error import InvalidImageNameError from dvclive.plots import Image as LiveImage @@ -24,10 +25,37 @@ def test_pil(tmp_dir): assert (tmp_dir / live.plots_dir / LiveImage.subfolder / "image.png").exists() +def test_pil_omitting_extension_doesnt_save_without_valid_format(tmp_dir): + live = Live() + img = Image.new("RGB", (10, 10), (250, 250, 250)) + with pytest.raises( + InvalidImageNameError, match="Cannot log image with name 'whoops'" + ): + live.log_image("whoops", img) + + +def test_pil_omitting_extension_sets_the_format_if_path_given(tmp_dir): + live = Live() + img = Image.new("RGB", (10, 10), (250, 250, 250)) + + # Save it first, we'll reload it and pass it's path to log_image again + live.log_image("saved_with_format.png", img) + + # Now try saving without explicit format and check if the format is set correctly. + live.log_image( + "whoops", + (tmp_dir / live.plots_dir / LiveImage.subfolder / "saved_with_format.png"), + ) + + assert (tmp_dir / live.plots_dir / LiveImage.subfolder / "whoops.png").exists() + + def test_invalid_extension(tmp_dir): live = Live() img = Image.new("RGB", (10, 10), (250, 250, 250)) - with pytest.raises(ValueError, match="unknown file extension"): + with pytest.raises( + InvalidImageNameError, match="Cannot log image with name 'image.foo'" + ): live.log_image("image.foo", img)