Skip to content

Commit

Permalink
Fix data-related public API (#368)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #368

This PR aims to expose the right data-related API.

Two more changes are made in this PR to convert public APIs to private APIs:
`check_lambda_fn` -> `_check_lambda_fn`
`deprecation_warning` -> `_deprecation_warning`

X-link: pytorch/pytorch#76143

Reviewed By: albanD, NivekT

Differential Revision: D35798311

Pulled By: ejguan

fbshipit-source-id: b13fded5c88a533c706702fb2070c918c839dca4
ejguan authored and facebook-github-bot committed Apr 21, 2022

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent 6da391e commit c1d89fe
Showing 8 changed files with 21 additions and 21 deletions.
6 changes: 3 additions & 3 deletions torchdata/datapipes/iter/transform/callable.py
Original file line number Diff line number Diff line change
@@ -7,7 +7,7 @@
from typing import Callable, Iterator, List, TypeVar

from torch.utils.data import functional_datapipe, IterDataPipe
from torch.utils.data.datapipes.utils.common import check_lambda_fn
from torch.utils.data.datapipes.utils.common import _check_lambda_fn

T_co = TypeVar("T_co", covariant=True)

@@ -59,7 +59,7 @@ def __init__(
) -> None:
self.datapipe = datapipe

check_lambda_fn(fn)
_check_lambda_fn(fn)
self.fn = fn # type: ignore[assignment]

assert batch_size > 0, "Batch size is required to be larger than 0!"
@@ -118,7 +118,7 @@ class FlatMapperIterDataPipe(IterDataPipe[T_co]):
def __init__(self, datapipe: IterDataPipe, fn: Callable, input_col=None) -> None:
self.datapipe = datapipe

check_lambda_fn(fn)
_check_lambda_fn(fn)
self.fn = fn # type: ignore[assignment]
self.input_col = input_col

4 changes: 2 additions & 2 deletions torchdata/datapipes/iter/util/cacheholder.py
Original file line number Diff line number Diff line change
@@ -13,7 +13,7 @@
from functools import partial
from typing import Callable, Deque, Dict, Iterator, Optional, TypeVar

from torch.utils.data.datapipes.utils.common import check_lambda_fn, DILL_AVAILABLE
from torch.utils.data.datapipes.utils.common import _check_lambda_fn, DILL_AVAILABLE

from torch.utils.data.graph import traverse
from torchdata.datapipes import functional_datapipe
@@ -160,7 +160,7 @@ def __init__(
):
self.source_datapipe = source_datapipe

check_lambda_fn(filepath_fn)
_check_lambda_fn(filepath_fn)
filepath_fn = _generator_to_list(filepath_fn) if inspect.isgeneratorfunction(filepath_fn) else filepath_fn

if hash_dict is not None and hash_type not in ("sha256", "md5"):
12 changes: 6 additions & 6 deletions torchdata/datapipes/iter/util/combining.py
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@
from typing import Callable, Iterator, Optional, TypeVar

from torch.utils.data import functional_datapipe, IterDataPipe, MapDataPipe
from torch.utils.data.datapipes.utils.common import check_lambda_fn
from torch.utils.data.datapipes.utils.common import _check_lambda_fn

T_co = TypeVar("T_co", covariant=True)

@@ -64,14 +64,14 @@ def __init__(
raise TypeError(f"ref_datapipe must be a IterDataPipe, but its type is {type(ref_datapipe)} instead.")
self.source_datapipe = source_datapipe
self.ref_datapipe = ref_datapipe
check_lambda_fn(key_fn)
_check_lambda_fn(key_fn)
self.key_fn = key_fn
if ref_key_fn is not None:
check_lambda_fn(ref_key_fn)
_check_lambda_fn(ref_key_fn)
self.ref_key_fn = key_fn if ref_key_fn is None else ref_key_fn
self.keep_key = keep_key
if merge_fn is not None:
check_lambda_fn(merge_fn)
_check_lambda_fn(merge_fn)
self.merge_fn = merge_fn
if buffer_size is not None and buffer_size <= 0:
raise ValueError("'buffer_size' is required to be either None or a positive integer.")
@@ -153,10 +153,10 @@ def __init__(
raise TypeError(f"map_datapipe must be a MapDataPipe, but its type is {type(map_datapipe)} instead.")
self.source_iterdatapipe: IterDataPipe = source_iterdatapipe
self.map_datapipe: MapDataPipe = map_datapipe
check_lambda_fn(key_fn)
_check_lambda_fn(key_fn)
self.key_fn: Callable = key_fn
if merge_fn is not None:
check_lambda_fn(merge_fn)
_check_lambda_fn(merge_fn)
self.merge_fn: Optional[Callable] = merge_fn
self.length: int = -1

4 changes: 2 additions & 2 deletions torchdata/datapipes/iter/util/paragraphaggregator.py
Original file line number Diff line number Diff line change
@@ -6,7 +6,7 @@

from typing import Callable, Iterator, List, Tuple, TypeVar

from torch.utils.data.datapipes.utils.common import check_lambda_fn
from torch.utils.data.datapipes.utils.common import _check_lambda_fn

from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe
@@ -44,7 +44,7 @@ class ParagraphAggregatorIterDataPipe(IterDataPipe[Tuple[str, str]]):

def __init__(self, source_datapipe: IterDataPipe[Tuple[str, T_co]], joiner: Callable = _default_line_join) -> None:
self.source_datapipe: IterDataPipe[Tuple[str, T_co]] = source_datapipe
check_lambda_fn(joiner)
_check_lambda_fn(joiner)
self.joiner: Callable = joiner

def __iter__(self) -> Iterator[Tuple[str, str]]:
4 changes: 2 additions & 2 deletions torchdata/datapipes/iter/util/tararchiveloader.py
Original file line number Diff line number Diff line change
@@ -10,7 +10,7 @@
from io import BufferedIOBase
from typing import cast, IO, Iterable, Iterator, Optional, Tuple

from torch.utils.data.datapipes.utils.common import deprecation_warning
from torch.utils.data.datapipes.utils.common import _deprecation_warning

from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe
@@ -85,7 +85,7 @@ class TarArchiveReaderIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]):
"""

def __new__(cls, datapipe: Iterable[Tuple[str, BufferedIOBase]], mode: str = "r:*", length: int = -1):
deprecation_warning(
_deprecation_warning(
cls.__name__,
deprecation_version="0.4",
removal_version="0.6",
4 changes: 2 additions & 2 deletions torchdata/datapipes/iter/util/xzfileloader.py
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@
from io import BufferedIOBase
from typing import Iterable, Iterator, Tuple

from torch.utils.data.datapipes.utils.common import deprecation_warning
from torch.utils.data.datapipes.utils.common import _deprecation_warning

from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe
@@ -74,7 +74,7 @@ class XzFileReaderIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]):
"""

def __new__(cls, datapipe: Iterable[Tuple[str, BufferedIOBase]], length: int = -1):
deprecation_warning(
_deprecation_warning(
cls.__name__,
deprecation_version="0.4",
removal_version="0.6",
4 changes: 2 additions & 2 deletions torchdata/datapipes/iter/util/ziparchiveloader.py
Original file line number Diff line number Diff line change
@@ -11,7 +11,7 @@
from io import BufferedIOBase
from typing import cast, IO, Iterable, Iterator, Tuple

from torch.utils.data.datapipes.utils.common import deprecation_warning
from torch.utils.data.datapipes.utils.common import _deprecation_warning

from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe
@@ -86,7 +86,7 @@ class ZipArchiveReaderIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]):
"""

def __new__(cls, datapipe: Iterable[Tuple[str, BufferedIOBase]], length: int = -1):
deprecation_warning(
_deprecation_warning(
cls.__name__,
deprecation_version="0.4",
removal_version="0.6",
4 changes: 2 additions & 2 deletions torchdata/datapipes/map/util/utils.py
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@
from typing import Callable, Dict, Optional

from torch.utils.data import IterDataPipe, MapDataPipe
from torch.utils.data.datapipes.utils.common import check_lambda_fn, DILL_AVAILABLE
from torch.utils.data.datapipes.utils.common import _check_lambda_fn, DILL_AVAILABLE

if DILL_AVAILABLE:
import dill
@@ -42,7 +42,7 @@ def __init__(self, datapipe: IterDataPipe, key_value_fn: Optional[Callable] = No
if not isinstance(datapipe, IterDataPipe):
raise TypeError(f"IterToMapConverter can only apply on IterDataPipe, but found {type(datapipe)}")
self.datapipe = datapipe
check_lambda_fn(key_value_fn)
_check_lambda_fn(key_value_fn)
self.key_value_fn = key_value_fn # type: ignore[assignment]
self._map = None
self._length = -1

0 comments on commit c1d89fe

Please sign in to comment.