From cec5d7d64db5d7bdb19adb1615ee6e1a496ab29b Mon Sep 17 00:00:00 2001 From: Natik Gadzhi Date: Wed, 29 Nov 2023 16:44:03 -0800 Subject: [PATCH 1/4] fix: Implicitly infer image format if possible --- src/dvclive/live.py | 10 ++++++++++ src/dvclive/plots/image.py | 2 +- tests/plots/test_image.py | 23 +++++++++++++++++++++++ 3 files changed, 34 insertions(+), 1 deletion(-) diff --git a/src/dvclive/live.py b/src/dvclive/live.py index d1971601..979458c3 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -37,6 +37,7 @@ clean_and_copy_into, env2bool, inside_notebook, + isinstance_without_import, matplotlib_installed, open_file_in_browser, ) @@ -378,6 +379,15 @@ def log_image(self, name: str, val): val = ImagePIL.open(val) + # If the provided image name does not have a format on it, + # try to infer the format from PIL Image. + if len(name.split(".")) <= 1: + if ( + isinstance_without_import(val, "PIL.Image", "Image") + and f".{str(val.format).lower()}" in Image.suffixes + ): + name = f"{name}.{str(val.format).lower()}" + if name in self._images: image = self._images[name] else: diff --git a/src/dvclive/plots/image.py b/src/dvclive/plots/image.py index 5cc65362..ae1437ad 100644 --- a/src/dvclive/plots/image.py +++ b/src/dvclive/plots/image.py @@ -7,7 +7,7 @@ class Image(Data): suffixes = (".jpg", ".jpeg", ".gif", ".png") - subfolder = "images" + subfolder: str = "images" @property def output_path(self) -> Path: diff --git a/tests/plots/test_image.py b/tests/plots/test_image.py index f5621b67..3fee3c4d 100644 --- a/tests/plots/test_image.py +++ b/tests/plots/test_image.py @@ -24,6 +24,29 @@ def test_pil(tmp_dir): assert (tmp_dir / live.plots_dir / LiveImage.subfolder / "image.png").exists() +def test_pil_omitting_extension_doesnt_save_without_valid_format(tmp_dir): + live = Live() + img = Image.new("RGB", (10, 10), (250, 250, 250)) + with pytest.raises(ValueError, match="unknown file extension"): + live.log_image("whoops", img) + + +def test_pil_omitting_extension_sets_the_format_if_path_given(tmp_dir): + live = Live() + img = Image.new("RGB", (10, 10), (250, 250, 250)) + + # Save it first, we'll reload it and pass it's path to log_image again + live.log_image("saved_with_format.png", img) + + # Now try saving without explicit format and check if the format is set correctly. + live.log_image( + "whoops", + (tmp_dir / live.plots_dir / LiveImage.subfolder / "saved_with_format.png"), + ) + + assert (tmp_dir / live.plots_dir / LiveImage.subfolder / "whoops.png").exists() + + def test_invalid_extension(tmp_dir): live = Live() img = Image.new("RGB", (10, 10), (250, 250, 250)) From 757bb64fafc5871f025ed4f0691542f85facc28a Mon Sep 17 00:00:00 2001 From: Natik Gadzhi Date: Wed, 29 Nov 2023 17:00:58 -0800 Subject: [PATCH 2/4] Guard against images without format --- src/dvclive/live.py | 8 +++++--- src/dvclive/plots/image.py | 11 +++++++---- tests/plots/test_image.py | 5 +++-- 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/src/dvclive/live.py b/src/dvclive/live.py index 979458c3..028d29c9 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -371,9 +371,7 @@ def log_metric( logger.debug(f"Logged {name}: {val}") def log_image(self, name: str, val): - if not Image.could_log(val): - raise InvalidDataTypeError(name, type(val)) - + # If we're given a path, try loading the image first. This might error out. if isinstance(val, (str, Path)): from PIL import Image as ImagePIL @@ -388,6 +386,10 @@ def log_image(self, name: str, val): ): name = f"{name}.{str(val.format).lower()}" + # See if the image format and image name are valid + if not Image.could_log(name, val): + raise InvalidDataTypeError(name, type(val)) + if name in self._images: image = self._images[name] else: diff --git a/src/dvclive/plots/image.py b/src/dvclive/plots/image.py index ae1437ad..001d5046 100644 --- a/src/dvclive/plots/image.py +++ b/src/dvclive/plots/image.py @@ -16,18 +16,21 @@ def output_path(self) -> Path: return _path @staticmethod - def could_log(val: object) -> bool: + def could_log(name: str, val: object) -> bool: acceptable = { ("numpy", "ndarray"), ("matplotlib.figure", "Figure"), ("PIL.Image", "Image"), } + + supported_format = False for cls in type(val).mro(): if any(isinstance_without_import(val, *cls) for cls in acceptable): - return True + supported_format = True if isinstance(val, (PurePath, str)): - return True - return False + supported_format = True + + return supported_format and f".{name.split('.')[-1]}" in Image.suffixes def dump(self, val, **kwargs) -> None: # noqa: ARG002 if isinstance_without_import(val, "numpy", "ndarray"): diff --git a/tests/plots/test_image.py b/tests/plots/test_image.py index 3fee3c4d..577b4d12 100644 --- a/tests/plots/test_image.py +++ b/tests/plots/test_image.py @@ -4,6 +4,7 @@ from PIL import Image from dvclive import Live +from dvclive.error import InvalidDataTypeError from dvclive.plots import Image as LiveImage @@ -27,7 +28,7 @@ def test_pil(tmp_dir): def test_pil_omitting_extension_doesnt_save_without_valid_format(tmp_dir): live = Live() img = Image.new("RGB", (10, 10), (250, 250, 250)) - with pytest.raises(ValueError, match="unknown file extension"): + with pytest.raises(InvalidDataTypeError, match="has not supported type"): live.log_image("whoops", img) @@ -50,7 +51,7 @@ def test_pil_omitting_extension_sets_the_format_if_path_given(tmp_dir): def test_invalid_extension(tmp_dir): live = Live() img = Image.new("RGB", (10, 10), (250, 250, 250)) - with pytest.raises(ValueError, match="unknown file extension"): + with pytest.raises(InvalidDataTypeError, match="has not supported type"): live.log_image("image.foo", img) From 7e696ce3c39fd05ab0fb0cd23f89717c33f19bac Mon Sep 17 00:00:00 2001 From: Natik Gadzhi Date: Wed, 29 Nov 2023 17:27:52 -0800 Subject: [PATCH 3/4] Whoops, forgot the linters --- src/dvclive/live.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/dvclive/live.py b/src/dvclive/live.py index 028d29c9..eb8fc37c 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -379,12 +379,12 @@ def log_image(self, name: str, val): # If the provided image name does not have a format on it, # try to infer the format from PIL Image. - if len(name.split(".")) <= 1: - if ( - isinstance_without_import(val, "PIL.Image", "Image") - and f".{str(val.format).lower()}" in Image.suffixes - ): - name = f"{name}.{str(val.format).lower()}" + if ( + len(name.split(".")) <= 1 + and isinstance_without_import(val, "PIL.Image", "Image") + and f".{str(val.format).lower()}" in Image.suffixes + ): + name = f"{name}.{str(val.format).lower()}" # See if the image format and image name are valid if not Image.could_log(name, val): From efe31f85b0c9e68b644db13ee06ed79c9f402689 Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Thu, 4 Jan 2024 17:28:16 -0500 Subject: [PATCH 4/4] refactor image validation --- src/dvclive/error.py | 6 ++++++ src/dvclive/live.py | 24 +++++++++++------------- src/dvclive/plots/image.py | 13 +++++-------- tests/plots/test_image.py | 10 +++++++--- 4 files changed, 29 insertions(+), 24 deletions(-) diff --git a/src/dvclive/error.py b/src/dvclive/error.py index d2e040c2..ac94259a 100644 --- a/src/dvclive/error.py +++ b/src/dvclive/error.py @@ -17,6 +17,12 @@ def __init__(self): super().__init__("`dvcyaml` path must have filename 'dvc.yaml'") +class InvalidImageNameError(DvcLiveError): + def __init__(self, name): + self.name = name + super().__init__(f"Cannot log image with name '{name}'") + + class InvalidPlotTypeError(DvcLiveError): def __init__(self, name): from .plots import SKLEARN_PLOTS diff --git a/src/dvclive/live.py b/src/dvclive/live.py index e4d25ef5..680a5a78 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -28,6 +28,7 @@ from .error import ( InvalidDataTypeError, InvalidDvcyamlError, + InvalidImageNameError, InvalidParameterTypeError, InvalidPlotTypeError, InvalidReportModeError, @@ -43,7 +44,6 @@ convert_datapoints_to_list_of_dicts, env2bool, inside_notebook, - isinstance_without_import, matplotlib_installed, open_file_in_browser, ) @@ -380,24 +380,22 @@ def log_metric( logger.debug(f"Logged {name}: {val}") def log_image(self, name: str, val): + if not Image.could_log(val): + raise InvalidDataTypeError(name, type(val)) + # If we're given a path, try loading the image first. This might error out. if isinstance(val, (str, Path)): from PIL import Image as ImagePIL - val = ImagePIL.open(val) + suffix = Path(val).suffix + if not Path(name).suffix and suffix in Image.suffixes: + name = f"{name}{suffix}" - # If the provided image name does not have a format on it, - # try to infer the format from PIL Image. - if ( - len(name.split(".")) <= 1 - and isinstance_without_import(val, "PIL.Image", "Image") - and f".{str(val.format).lower()}" in Image.suffixes - ): - name = f"{name}.{str(val.format).lower()}" + val = ImagePIL.open(val) - # See if the image format and image name are valid - if not Image.could_log(name, val): - raise InvalidDataTypeError(name, type(val)) + # See if the image name is valid + if Path(name).suffix not in Image.suffixes: + raise InvalidImageNameError(name) if name in self._images: image = self._images[name] diff --git a/src/dvclive/plots/image.py b/src/dvclive/plots/image.py index 001d5046..5cc65362 100644 --- a/src/dvclive/plots/image.py +++ b/src/dvclive/plots/image.py @@ -7,7 +7,7 @@ class Image(Data): suffixes = (".jpg", ".jpeg", ".gif", ".png") - subfolder: str = "images" + subfolder = "images" @property def output_path(self) -> Path: @@ -16,21 +16,18 @@ def output_path(self) -> Path: return _path @staticmethod - def could_log(name: str, val: object) -> bool: + def could_log(val: object) -> bool: acceptable = { ("numpy", "ndarray"), ("matplotlib.figure", "Figure"), ("PIL.Image", "Image"), } - - supported_format = False for cls in type(val).mro(): if any(isinstance_without_import(val, *cls) for cls in acceptable): - supported_format = True + return True if isinstance(val, (PurePath, str)): - supported_format = True - - return supported_format and f".{name.split('.')[-1]}" in Image.suffixes + return True + return False def dump(self, val, **kwargs) -> None: # noqa: ARG002 if isinstance_without_import(val, "numpy", "ndarray"): diff --git a/tests/plots/test_image.py b/tests/plots/test_image.py index 577b4d12..dfcb8efc 100644 --- a/tests/plots/test_image.py +++ b/tests/plots/test_image.py @@ -4,7 +4,7 @@ from PIL import Image from dvclive import Live -from dvclive.error import InvalidDataTypeError +from dvclive.error import InvalidImageNameError from dvclive.plots import Image as LiveImage @@ -28,7 +28,9 @@ def test_pil(tmp_dir): def test_pil_omitting_extension_doesnt_save_without_valid_format(tmp_dir): live = Live() img = Image.new("RGB", (10, 10), (250, 250, 250)) - with pytest.raises(InvalidDataTypeError, match="has not supported type"): + with pytest.raises( + InvalidImageNameError, match="Cannot log image with name 'whoops'" + ): live.log_image("whoops", img) @@ -51,7 +53,9 @@ def test_pil_omitting_extension_sets_the_format_if_path_given(tmp_dir): def test_invalid_extension(tmp_dir): live = Live() img = Image.new("RGB", (10, 10), (250, 250, 250)) - with pytest.raises(InvalidDataTypeError, match="has not supported type"): + with pytest.raises( + InvalidImageNameError, match="Cannot log image with name 'image.foo'" + ): live.log_image("image.foo", img)