diff --git a/src/dvclive/live.py b/src/dvclive/live.py index ce8d5357..48b5f94f 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -7,12 +7,14 @@ import os import shutil import tempfile -from pathlib import Path -from typing import Any, Dict, List, Optional, Set, Union, TYPE_CHECKING +from pathlib import Path, PurePath +from typing import Any, Dict, List, Optional, Set, Tuple, Union, TYPE_CHECKING, Literal if TYPE_CHECKING: import numpy as np import pandas as pd + import matplotlib + import PIL from dvc.exceptions import DvcException from funcy import set_in @@ -69,9 +71,9 @@ def __init__( self, dir: str = "dvclive", # noqa: A002 resume: bool = False, - report: Optional[str] = None, + report: Literal["md", "notebook", "html", None] = None, save_dvc_exp: bool = True, - dvcyaml: Union[str, bool] = "dvc.yaml", + dvcyaml: Optional[str] = "dvc.yaml", cache_images: bool = False, exp_name: Optional[str] = None, exp_message: Optional[str] = None, @@ -379,11 +381,15 @@ def log_metric( self.summary = set_in(self.summary, metric.summary_keys, val) logger.debug(f"Logged {name}: {val}") - def log_image(self, name: str, val): + def log_image( + self, + name: str, + val: Union[np.ndarray, matplotlib.figure.Figure, PIL.Image, StrPath], + ): if not Image.could_log(val): raise InvalidDataTypeError(name, type(val)) - if isinstance(val, (str, Path)): + if isinstance(val, (str, PurePath)): from PIL import Image as ImagePIL val = ImagePIL.open(val) @@ -401,10 +407,10 @@ def log_image(self, name: str, val): def log_plot( self, name: str, - datapoints: pd.DataFrame | np.ndarray | List[Dict], + datapoints: Union[pd.DataFrame, np.ndarray, List[Dict]], x: str, y: str, - template: Optional[str] = None, + template: Optional[str] = "linear", title: Optional[str] = None, x_label: Optional[str] = None, y_label: Optional[str] = None, @@ -434,7 +440,14 @@ def log_plot( plot.dump(datapoints) logger.debug(f"Logged {name}") - def log_sklearn_plot(self, kind, labels, predictions, name=None, **kwargs): + def log_sklearn_plot( + self, + kind: str, + labels: Union[List, np.ndarray], + predictions: Union[List, Tuple, np.ndarray], + name: Optional[str] = None, + **kwargs, + ): val = (labels, predictions) plot_config = { @@ -491,7 +504,7 @@ def log_artifact( cache: bool = True, ): """Tracks a local file or directory with DVC""" - if not isinstance(path, (str, Path)): + if not isinstance(path, (str, PurePath)): raise InvalidDataTypeError(path, builtins.type(path)) if self._dvc_repo is not None: @@ -574,7 +587,7 @@ def make_dvcyaml(self): make_dvcyaml(self) @catch_and_warn(DvcException, logger) - def post_to_studio(self, event): + def post_to_studio(self, event: str): post_to_studio(self, event) def end(self): diff --git a/src/dvclive/utils.py b/src/dvclive/utils.py index a168de9e..8b762590 100644 --- a/src/dvclive/utils.py +++ b/src/dvclive/utils.py @@ -4,7 +4,7 @@ import os import re import shutil -from pathlib import Path +from pathlib import Path, PurePath from platform import uname from typing import Union, List, Dict, TYPE_CHECKING import webbrowser @@ -26,7 +26,7 @@ np = None -StrPath = Union[str, Path] +StrPath = Union[str, PurePath] def run_once(f):