Skip to content

Commit

Permalink
add typehints to Live's public methods (#770)
Browse files Browse the repository at this point in the history
* add typehints to Live public methods

* Fixes type in general

- Removes redundancy with types import
- strPath is now using PurePath instead of Path

* revert types for template plots

* remove imported types and fix typo

* revert sklearnkind to str because types cannot be generated dynamically

* fix mypy error on _dvc_repo

---------

Co-authored-by: Dave Berenbaum <[email protected]>
  • Loading branch information
AlexandreKempf and Dave Berenbaum authored Feb 7, 2024
1 parent 38f2f51 commit b266b80
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 13 deletions.
35 changes: 24 additions & 11 deletions src/dvclive/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@
import os
import shutil
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Union, TYPE_CHECKING
from pathlib import Path, PurePath
from typing import Any, Dict, List, Optional, Set, Tuple, Union, TYPE_CHECKING, Literal

if TYPE_CHECKING:
import numpy as np
import pandas as pd
import matplotlib
import PIL

from dvc.exceptions import DvcException
from funcy import set_in
Expand Down Expand Up @@ -69,9 +71,9 @@ def __init__(
self,
dir: str = "dvclive", # noqa: A002
resume: bool = False,
report: Optional[str] = None,
report: Literal["md", "notebook", "html", None] = None,
save_dvc_exp: bool = True,
dvcyaml: Union[str, bool] = "dvc.yaml",
dvcyaml: Optional[str] = "dvc.yaml",
cache_images: bool = False,
exp_name: Optional[str] = None,
exp_message: Optional[str] = None,
Expand Down Expand Up @@ -379,11 +381,15 @@ def log_metric(
self.summary = set_in(self.summary, metric.summary_keys, val)
logger.debug(f"Logged {name}: {val}")

def log_image(self, name: str, val):
def log_image(
self,
name: str,
val: Union[np.ndarray, matplotlib.figure.Figure, PIL.Image, StrPath],
):
if not Image.could_log(val):
raise InvalidDataTypeError(name, type(val))

if isinstance(val, (str, Path)):
if isinstance(val, (str, PurePath)):
from PIL import Image as ImagePIL

val = ImagePIL.open(val)
Expand All @@ -401,10 +407,10 @@ def log_image(self, name: str, val):
def log_plot(
self,
name: str,
datapoints: pd.DataFrame | np.ndarray | List[Dict],
datapoints: Union[pd.DataFrame, np.ndarray, List[Dict]],
x: str,
y: str,
template: Optional[str] = None,
template: Optional[str] = "linear",
title: Optional[str] = None,
x_label: Optional[str] = None,
y_label: Optional[str] = None,
Expand Down Expand Up @@ -434,7 +440,14 @@ def log_plot(
plot.dump(datapoints)
logger.debug(f"Logged {name}")

def log_sklearn_plot(self, kind, labels, predictions, name=None, **kwargs):
def log_sklearn_plot(
self,
kind: str,
labels: Union[List, np.ndarray],
predictions: Union[List, Tuple, np.ndarray],
name: Optional[str] = None,
**kwargs,
):
val = (labels, predictions)

plot_config = {
Expand Down Expand Up @@ -491,7 +504,7 @@ def log_artifact(
cache: bool = True,
):
"""Tracks a local file or directory with DVC"""
if not isinstance(path, (str, Path)):
if not isinstance(path, (str, PurePath)):
raise InvalidDataTypeError(path, builtins.type(path))

if self._dvc_repo is not None:
Expand Down Expand Up @@ -574,7 +587,7 @@ def make_dvcyaml(self):
make_dvcyaml(self)

@catch_and_warn(DvcException, logger)
def post_to_studio(self, event):
def post_to_studio(self, event: str):
post_to_studio(self, event)

def end(self):
Expand Down
4 changes: 2 additions & 2 deletions src/dvclive/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os
import re
import shutil
from pathlib import Path
from pathlib import Path, PurePath
from platform import uname
from typing import Union, List, Dict, TYPE_CHECKING
import webbrowser
Expand All @@ -26,7 +26,7 @@
np = None


StrPath = Union[str, Path]
StrPath = Union[str, PurePath]


def run_once(f):
Expand Down

0 comments on commit b266b80

Please sign in to comment.