Skip to content

Commit

Permalink
fix: Improve safety of logging images with implicit formats (#744)
Browse files Browse the repository at this point in the history
* fix: Implicitly infer image format if possible

* Guard against images without format

* Whoops, forgot the linters

* refactor image validation

---------

Co-authored-by: dberenbaum <[email protected]>
Co-authored-by: AlexandreKempf <[email protected]>
  • Loading branch information
3 people authored Feb 7, 2024
1 parent b266b80 commit b1e6bf4
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 1 deletion.
6 changes: 6 additions & 0 deletions src/dvclive/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@ def __init__(self):
super().__init__("`dvcyaml` path must have filename 'dvc.yaml'")


class InvalidImageNameError(DvcLiveError):
def __init__(self, name):
self.name = name
super().__init__(f"Cannot log image with name '{name}'")


class InvalidPlotTypeError(DvcLiveError):
def __init__(self, name):
from .plots import SKLEARN_PLOTS
Expand Down
10 changes: 10 additions & 0 deletions src/dvclive/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from .error import (
InvalidDataTypeError,
InvalidDvcyamlError,
InvalidImageNameError,
InvalidParameterTypeError,
InvalidPlotTypeError,
InvalidReportModeError,
Expand Down Expand Up @@ -389,11 +390,20 @@ def log_image(
if not Image.could_log(val):
raise InvalidDataTypeError(name, type(val))

# If we're given a path, try loading the image first. This might error out.
if isinstance(val, (str, PurePath)):
from PIL import Image as ImagePIL

suffix = Path(val).suffix
if not Path(name).suffix and suffix in Image.suffixes:
name = f"{name}{suffix}"

val = ImagePIL.open(val)

# See if the image name is valid
if Path(name).suffix not in Image.suffixes:
raise InvalidImageNameError(name)

if name in self._images:
image = self._images[name]
else:
Expand Down
30 changes: 29 additions & 1 deletion tests/plots/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from PIL import Image

from dvclive import Live
from dvclive.error import InvalidImageNameError
from dvclive.plots import Image as LiveImage


Expand All @@ -24,10 +25,37 @@ def test_pil(tmp_dir):
assert (tmp_dir / live.plots_dir / LiveImage.subfolder / "image.png").exists()


def test_pil_omitting_extension_doesnt_save_without_valid_format(tmp_dir):
live = Live()
img = Image.new("RGB", (10, 10), (250, 250, 250))
with pytest.raises(
InvalidImageNameError, match="Cannot log image with name 'whoops'"
):
live.log_image("whoops", img)


def test_pil_omitting_extension_sets_the_format_if_path_given(tmp_dir):
live = Live()
img = Image.new("RGB", (10, 10), (250, 250, 250))

# Save it first, we'll reload it and pass it's path to log_image again
live.log_image("saved_with_format.png", img)

# Now try saving without explicit format and check if the format is set correctly.
live.log_image(
"whoops",
(tmp_dir / live.plots_dir / LiveImage.subfolder / "saved_with_format.png"),
)

assert (tmp_dir / live.plots_dir / LiveImage.subfolder / "whoops.png").exists()


def test_invalid_extension(tmp_dir):
live = Live()
img = Image.new("RGB", (10, 10), (250, 250, 250))
with pytest.raises(ValueError, match="unknown file extension"):
with pytest.raises(
InvalidImageNameError, match="Cannot log image with name 'image.foo'"
):
live.log_image("image.foo", img)


Expand Down

0 comments on commit b1e6bf4

Please sign in to comment.