Adding usage examples to all IterDataPipes (#249)
Summary: Pull Request resolved: #249

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D34433825

Pulled By: NivekT

fbshipit-source-id: c2bd60eb2ea957f064486064baf5978e7dcb3441
NivekT authored and facebook-github-bot committed Feb 24, 2022
1 parent ebee4ca commit 2cf1f20
Showing 26 changed files with 375 additions and 8 deletions.
2 changes: 2 additions & 0 deletions docs/source/conf.py
@@ -95,6 +95,8 @@
"torch.utils.data.datapipes.map.grouping.T": "T",
"torch.utils.data.datapipes.map.combining.T_co": "T_co",
"torch.utils.data.datapipes.map.combinatorics.T_co": "T_co",
"torchdata.datapipes.iter.util.cycler.T_co": "T_co",
"torchdata.datapipes.iter.util.paragraphaggregator.T_co": "T_co",
"typing.": "",
}

2 changes: 1 addition & 1 deletion docs/source/torchdata.datapipes.iter.rst
@@ -48,7 +48,7 @@ These DataPipes help opening and decompressing archive files of different format
:toctree: generated/
:template: datapipe.rst

Extractor
Decompressor
RarArchiveLoader
TarArchiveLoader
XzFileLoader
18 changes: 18 additions & 0 deletions torchdata/datapipes/iter/load/fsspec.py
@@ -36,6 +36,10 @@ class FSSpecFileListerIterDataPipe(IterDataPipe[str]):
Args:
root: The root `fsspec` path directory to list files from
masks: Unix style filter string or string list for filtering file name(s)
Example:
>>> from torchdata.datapipes.iter import FSSpecFileLister
>>> datapipe = FSSpecFileLister(root=dir_path)
"""

def __init__(
@@ -82,6 +86,11 @@ class FSSpecFileOpenerIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
Args:
source_datapipe: Iterable DataPipe that provides the pathnames or URLs
mode: An optional string that specifies the mode in which the file is opened (``"r"`` by default)
Example:
>>> from torchdata.datapipes.iter import FSSpecFileLister
>>> datapipe = FSSpecFileLister(root=dir_path)
>>> file_dp = datapipe.open_file_by_fsspec()
"""

def __init__(self, source_datapipe: IterDataPipe[str], mode: str = "r") -> None:
@@ -111,6 +120,15 @@ class FSSpecSaverIterDataPipe(IterDataPipe[str]):
source_datapipe: Iterable DataPipe with tuples of metadata and data
mode: Mode in which the file will be opened for writing the data (``"w"`` by default)
filepath_fn: Function that takes in metadata and returns the target path of the new file
Example:
>>> from torchdata.datapipes.iter import IterableWrapper
>>> def filepath_fn(name: str) -> str:
>>> return dir_path + name
>>> name_to_data = {"1.txt": b"DATA1", "2.txt": b"DATA2", "3.txt": b"DATA3"}
>>> source_dp = IterableWrapper(sorted(name_to_data.items()))
>>> fsspec_saver_dp = source_dp.save_by_fsspec(filepath_fn=filepath_fn, mode="wb")
>>> res_file_paths = list(fsspec_saver_dp)
"""

def __init__(
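Taken together, the lister, opener, and saver above compose into a single pipeline. A rough, illustrative sketch (assuming ``fsspec`` is installed; the directory ``dir_path``, the file names, and the data are made up for illustration):
>>> import os
>>> from torchdata.datapipes.iter import IterableWrapper, FSSpecFileLister
>>> dir_path = "/tmp/fsspec_example/"  # hypothetical local directory; any fsspec-supported path should work
>>> os.makedirs(dir_path, exist_ok=True)
>>> def filepath_fn(name: str) -> str:
>>>     return dir_path + name
>>> name_to_data = {"1.txt": b"DATA1", "2.txt": b"DATA2"}
>>> # Writing only happens when the saver pipe is consumed
>>> saved_paths = list(IterableWrapper(sorted(name_to_data.items())).save_by_fsspec(filepath_fn=filepath_fn, mode="wb"))
>>> # List the files back and read their contents
>>> reader_dp = FSSpecFileLister(root=dir_path).open_file_by_fsspec(mode="rb")
>>> for path, stream in reader_dp:
>>>     print(path, stream.read())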
18 changes: 18 additions & 0 deletions torchdata/datapipes/iter/load/iopath.py
@@ -47,6 +47,10 @@ class IoPathFileListerIterDataPipe(IterDataPipe[str]):
Note:
Default ``PathManager`` currently supports local file path, normal HTTP URL and OneDrive URL.
S3 URL is supported only with ``iopath``>=0.1.9.
Example:
>>> from torchdata.datapipes.iter import IoPathFileLister
>>> datapipe = IoPathFileLister(root=S3URL)
"""

def __init__(
@@ -93,6 +97,11 @@ class IoPathFileOpenerIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
Note:
Default ``PathManager`` currently supports local file path, normal HTTP URL and OneDrive URL.
S3 URL is supported only with `iopath`>=0.1.9.
Example:
>>> from torchdata.datapipes.iter import IoPathFileLister
>>> datapipe = IoPathFileLister(root=S3URL)
>>> file_dp = datapipe.open_file_by_iopath()
"""

def __init__(self, source_datapipe: IterDataPipe[str], mode: str = "r", pathmgr=None) -> None:
@@ -135,6 +144,15 @@ class IoPathSaverIterDataPipe(IterDataPipe[str]):
Note:
Default ``PathManager`` currently supports local file path, normal HTTP URL and OneDrive URL.
S3 URL is supported only with `iopath`>=0.1.9.
Example:
>>> from torchdata.datapipes.iter import IterableWrapper
>>> def filepath_fn(name: str) -> str:
>>> return S3URL + name
>>> name_to_data = {"1.txt": b"DATA1", "2.txt": b"DATA2", "3.txt": b"DATA3"}
>>> source_dp = IterableWrapper(sorted(name_to_data.items()))
>>> iopath_saver_dp = source_dp.save_by_iopath(filepath_fn=filepath_fn, mode="wb")
>>> res_file_paths = list(iopath_saver_dp)
"""

def __init__(
36 changes: 36 additions & 0 deletions torchdata/datapipes/iter/load/online.py
@@ -33,6 +33,18 @@ class HTTPReaderIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
Args:
source_datapipe: a DataPipe that contains URLs
timeout: timeout in seconds for HTTP request
Example:
>>> from torchdata.datapipes.iter import IterableWrapper, HttpReader
>>> file_url = "https://raw.githubusercontent.com/pytorch/data/main/LICENSE"
>>> http_reader_dp = HttpReader(IterableWrapper([file_url]))
>>> reader_dp = http_reader_dp.readlines()
>>> it = iter(reader_dp)
>>> path, line = next(it)
>>> path
https://raw.githubusercontent.com/pytorch/data/main/LICENSE
>>> line
b'BSD 3-Clause License'
"""

def __init__(self, source_datapipe: IterDataPipe[str], timeout: Optional[float] = None) -> None:
@@ -85,6 +97,18 @@ class GDriveReaderDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
Args:
source_datapipe: a DataPipe that contains URLs to GDrive files
timeout: timeout in seconds for HTTP request
Example:
>>> from torchdata.datapipes.iter import IterableWrapper, GDriveReader
>>> gdrive_file_url = "https://drive.google.com/uc?export=download&id=SomeIDToAGDriveFile"
>>> gdrive_reader_dp = GDriveReader(IterableWrapper([gdrive_file_url]))
>>> reader_dp = gdrive_reader_dp.readlines()
>>> it = iter(reader_dp)
>>> path, line = next(it)
>>> path
https://drive.google.com/uc?export=download&id=SomeIDToAGDriveFile
>>> line
<First line from the GDrive File>
"""
source_datapipe: IterDataPipe[str]

@@ -108,6 +132,18 @@ class OnlineReaderIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
Args:
source_datapipe: a DataPipe that contains URLs
timeout: timeout in seconds for HTTP request
Example:
>>> from torchdata.datapipes.iter import IterableWrapper, OnlineReader
>>> file_url = "https://raw.githubusercontent.com/pytorch/data/main/LICENSE"
>>> online_reader_dp = OnlineReader(IterableWrapper([file_url]))
>>> reader_dp = online_reader_dp.readlines()
>>> it = iter(reader_dp)
>>> path, line = next(it)
>>> path
https://raw.githubusercontent.com/pytorch/data/main/LICENSE
>>> line
b'BSD 3-Clause License'
"""
source_datapipe: IterDataPipe[str]

25 changes: 22 additions & 3 deletions torchdata/datapipes/iter/transform/bucketbatcher.py
@@ -23,14 +23,33 @@ class BucketBatcherIterDataPipe(IterDataPipe[DataChunk[T_co]]):
dimension will be added as ``batch_size`` if ``drop_last`` is set to ``True``,
or ``length % batch_size`` for the last batch if ``drop_last`` is set to ``False``.
The purpose of this DataPipe is to batch samples with some similarity according to the sorting function
being passed. For example, in the text domain, it may batch samples with a similar number of tokens
to minimize padding and increase throughput.
Args:
datapipe: Iterable DataPipe being batched
batch_size: The size of each batch
drop_last: Option to drop the last batch if it's not full
batch_num: Number of batches within a bucket (i.e. `bucket_size = batch_size * batch_num`)
bucket_num: Number of buckets that make up a pool for shuffling (i.e. `pool_size = bucket_size * bucket_num`)
sort_key: Callable to specify the comparison key for sorting within bucket
in_batch_shuffle: Option to do in-batch shuffle or buffer shuffle
sort_key: Callable to sort a bucket (list)
in_batch_shuffle: If True, do in-batch shuffle; if False, buffer shuffle
Example:
>>> from torchdata.datapipes.iter import IterableWrapper
>>> source_dp = IterableWrapper(range(10))
>>> batch_dp = source_dp.bucketbatch(batch_size=3, drop_last=True)
>>> list(batch_dp)
[[5, 6, 7], [9, 0, 1], [4, 3, 2]]
>>> def sort_bucket(bucket):
>>> return sorted(bucket)
>>> batch_dp = source_dp.bucketbatch(
>>> batch_size=3, drop_last=True, batch_num=100,
>>> bucket_num=1, in_batch_shuffle=False, sort_key=sort_bucket
>>> )
>>> list(batch_dp)
[[3, 4, 5], [6, 7, 8], [0, 1, 2]]
"""
datapipe: IterDataPipe[T_co]
batch_size: int
@@ -71,7 +90,7 @@ def __new__(
datapipe = datapipe.batch(batch_size, drop_last=drop_last)
# Shuffle the batched data
if sort_key is not None:
# In-batch shuffle each bucket seems not that useful
# In-batch shuffling of each bucket does not seem that useful, and it may be misleading since .batch is called prior.
if in_batch_shuffle:
datapipe = datapipe.batch(batch_size=bucket_num, drop_last=False).map(fn=_in_batch_shuffle_fn).unbatch()
else:
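To connect the docstring example back to its text-domain motivation, a rough sketch (the token sequences and the helper name are made up for illustration) that sorts each bucket by sequence length so samples in the same batch need little padding:
>>> from torchdata.datapipes.iter import IterableWrapper
>>> sequences = [["tok"] * n for n in [5, 1, 3, 2, 4, 6, 8, 7, 9, 10]]  # hypothetical tokenized samples
>>> source_dp = IterableWrapper(sequences)
>>> def sort_by_len(bucket):
>>>     # order the samples within a bucket by token count
>>>     return sorted(bucket, key=len)
>>> batch_dp = source_dp.bucketbatch(
>>>     batch_size=2, drop_last=False, batch_num=5,
>>>     bucket_num=1, in_batch_shuffle=False, sort_key=sort_by_len,
>>> )
>>> for batch in batch_dp:
>>>     print([len(seq) for seq in batch])  # lengths within each batch are close together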
9 changes: 9 additions & 0 deletions torchdata/datapipes/iter/transform/flatmap.py
@@ -24,6 +24,15 @@ class FlatMapperIterDataPipe(IterDataPipe[T_co]):
Args:
datapipe: Source IterDataPipe
fn: the function to be applied to each element in the DataPipe, the output must be a Sequence
Example:
>>> from torchdata.datapipes.iter import IterableWrapper
>>> def fn(e):
>>> return [e, e * 10]
>>> source_dp = IterableWrapper(list(range(5)))
>>> flatmapped_dp = source_dp.flatmap(fn)
>>> list(flatmapped_dp)
[0, 0, 1, 10, 2, 20, 3, 30, 4, 40]
"""

def __init__(self, datapipe: IterDataPipe, fn: Callable):
22 changes: 21 additions & 1 deletion torchdata/datapipes/iter/util/cacheholder.py
@@ -32,6 +32,13 @@ class InMemoryCacheHolderIterDataPipe(IterDataPipe[T_co]):
Args:
source_dp: source DataPipe from which elements are read and stored in memory
size: The maximum size (in megabytes) that this DataPipe can hold in memory. This defaults to unlimited.
Example:
>>> from torchdata.datapipes.iter import IterableWrapper
>>> source_dp = IterableWrapper(range(10))
>>> cache_dp = source_dp.in_memory_cache(size=5)
>>> list(cache_dp)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
"""
size: Optional[int] = None
idx: int
@@ -125,13 +132,14 @@ class OnDiskCacheHolderIterDataPipe(IterDataPipe):
the given file path from ``filepath_fn``.
Example:
>>> from torchdata.datapipes.iter import IterableWrapper, HttpReader
>>> url = IterableWrapper(["https://path/to/filename", ])
>>> def _filepath_fn(url):
>>> temp_dir = tempfile.gettempdir()
>>> return os.path.join(temp_dir, os.path.basename(url))
>>> _hash_dict = {"expected_filepath": expected_MD5_hash}
>>> cache_dp = url.on_disk_cache(filepath_fn=_filepath_fn, hash_dict=_hash_dict, hash_type="md5")
You must call ``.end_caching`` at a later point to stop tracing and save the results to local files.
>>> # You must call ``.end_caching`` at a later point to stop tracing and save the results to local files.
>>> cache_dp = HttpReader(cache_dp).end_caching(mode="wb", filepath_fn=_filepath_fn)
"""

@@ -235,6 +243,18 @@ class EndOnDiskCacheHolderIterDataPipe(IterDataPipe):
same_filepath_fn: Set to ``True`` to use same ``filepath_fn`` from the ``OnDiskCacheHolder``.
skip_read: Boolean value to skip reading the file handle from ``datapipe``.
By default, reading is enabled and reading function is created based on the ``mode``.
Example:
>>> from torchdata.datapipes.iter import IterableWrapper, HttpReader
>>> url = IterableWrapper(["https://path/to/filename", ])
>>> def _filepath_fn(url):
>>> temp_dir = tempfile.gettempdir()
>>> return os.path.join(temp_dir, os.path.basename(url))
>>> _hash_dict = {"expected_filepath": expected_MD5_hash}
>>> # You must call ``.on_disk_cache`` at some point before ``.end_caching``
>>> cache_dp = url.on_disk_cache(filepath_fn=_filepath_fn, hash_dict=_hash_dict, hash_type="md5")
>>> # You must call ``.end_caching`` at a later point to stop tracing and save the results to local files.
>>> cache_dp = HttpReader(cache_dp).end_caching(mode="wb", filepath_fn=_filepath_fn)
"""

def __new__(cls, datapipe, mode="wb", filepath_fn=None, *, same_filepath_fn=False, skip_read=False):
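Nothing is downloaded or written until the resulting DataPipe is consumed. Continuing the docstring examples above (which use the hypothetical URL, ``_filepath_fn``, and ``_hash_dict``), a sketch of consuming the cached pipe:
>>> file_paths = list(cache_dp)  # triggers the HTTP reads and writes files to the paths returned by _filepath_fn
>>> # On subsequent runs, entries that already exist on disk (and match the given hash) can be
>>> # served from the cache rather than downloaded again.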
24 changes: 24 additions & 0 deletions torchdata/datapipes/iter/util/combining.py
@@ -30,6 +30,18 @@ class IterKeyZipperIterDataPipe(IterDataPipe[T_co]):
If it's specified as ``None``, the buffer size is set as infinite.
merge_fn: Function that combines the item from ``source_datapipe`` and the item from ``ref_datapipe``,
by default a tuple is created
Example:
>>> from torchdata.datapipes.iter import IterableWrapper
>>> from operator import itemgetter
>>> def merge_fn(t1, t2):
>>> return t1[1] + t2[1]
>>> dp1 = IterableWrapper([('a', 100), ('b', 200), ('c', 300)])
>>> dp2 = IterableWrapper([('a', 1), ('b', 2), ('c', 3), ('d', 4)])
>>> res_dp = dp1.zip_with_iter(dp2, key_fn=itemgetter(0),
>>> ref_key_fn=itemgetter(0), keep_key=True, merge_fn=merge_fn)
>>> list(res_dp)
[('a', 101), ('b', 202), ('c', 303)]
"""

def __init__(
@@ -105,6 +117,18 @@ class MapKeyZipperIterDataPipe(IterDataPipe[T_co]):
key_fn: Function that maps each item from ``source_iterdatapipe`` to a key that exists in ``map_datapipe``
merge_fn: Function that combines the item from ``source_iterdatapipe`` and the matching item
from ``map_datapipe``, by default a tuple is created
Example:
>>> from torchdata.datapipes.iter import IterableWrapper
>>> from torchdata.datapipes.map import SequenceWrapper
>>> from operator import itemgetter
>>> def merge_fn(tuple_from_iter, value_from_map):
>>> return tuple_from_iter[0], tuple_from_iter[1] + value_from_map
>>> dp1 = IterableWrapper([('a', 1), ('b', 2), ('c', 3)])
>>> mapdp = SequenceWrapper({'a': 100, 'b': 200, 'c': 300, 'd': 400})
>>> res_dp = dp1.zip_with_map(map_datapipe=mapdp, key_fn=itemgetter(0), merge_fn=merge_fn)
>>> list(res_dp)
[('a', 101), ('b', 202), ('c', 303)]
"""

def __init__(
7 changes: 7 additions & 0 deletions torchdata/datapipes/iter/util/cycler.py
@@ -16,6 +16,13 @@ class CyclerIterDataPipe(IterDataPipe[T_co]):
Args:
source_datapipe: source DataPipe that will be cycled through
count: the number of times to read through ``source_datapipe`` (if ``None``, it will cycle in perpetuity)
Example:
>>> from torchdata.datapipes.iter import IterableWrapper
>>> dp = IterableWrapper(range(3))
>>> dp = dp.cycle(2)
>>> list(dp)
[0, 1, 2, 0, 1, 2]
"""

def __init__(self, source_datapipe: IterDataPipe[T_co], count: Optional[int] = None) -> None:
29 changes: 29 additions & 0 deletions torchdata/datapipes/iter/util/dataframemaker.py
@@ -31,6 +31,21 @@ class DataFrameMakerIterDataPipe(IterDataPipe): # IterDataPipe[torcharrow.IData
dtype: specify the `TorchArrow` dtype for the DataFrame, use ``torcharrow.dtypes.DType``
columns: List of str that specifies the column names of the DataFrame
device: specify the device on which the DataFrame will be stored
Example:
>>> from torchdata.datapipes.iter import IterableWrapper
>>> import torcharrow.dtypes as dt
>>> source_data = [(i,) for i in range(3)]
>>> source_dp = IterableWrapper(source_data)
>>> DTYPE = dt.Struct([dt.Field("Values", dt.int32)])
>>> df_dp = source_dp.dataframe(dtype=DTYPE)
>>> list(df_dp)[0]
index Values
------- --------
0 0
1 1
2 2
dtype: Struct([Field('Values', int32)]), count: 3, null_count: 0
"""

def __new__(
@@ -64,6 +79,20 @@ class ParquetDFLoaderIterDataPipe(IterDataPipe): # IterDataPipe[torcharrow.IDat
use_threads: if ``True``, Parquet reader will perform multi-threaded column reads
dtype: specify the `TorchArrow` dtype for the DataFrame, use ``torcharrow.dtypes.DType``
device: specify the device on which the DataFrame will be stored
Example:
>>> from torchdata.datapipes.iter import FileLister
>>> import torcharrow.dtypes as dt
>>> DTYPE = dt.Struct([dt.Field("Values", dt.int32)])
>>> source_dp = FileLister(".", masks="df*.parquet")
>>> parquet_df_dp = source_dp.load_parquet_as_df(dtype=DTYPE)
>>> list(parquet_df_dp)[0]
index Values
------- --------
0 0
1 1
2 2
dtype: Struct([Field('Values', int32)]), count: 3, null_count: 0
"""

def __init__(
9 changes: 9 additions & 0 deletions torchdata/datapipes/iter/util/decompressor.py
@@ -32,6 +32,15 @@ class DecompressorIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
Args:
source_datapipe: IterDataPipe containing tuples of path and compressed stream of data
file_type: Optional `string` or ``CompressionType`` that represents the compression format of the inputs
Example:
>>> from torchdata.datapipes.iter import FileLister, FileOpener
>>> tar_file_dp = FileLister(dir_path, "*.tar")
>>> tar_load_dp = FileOpener(tar_file_dp, mode="b")
>>> tar_decompress_dp = Decompressor(tar_load_dp, file_type="tar")
>>> for _, stream in tar_decompress_dp:
>>> print(stream.read())
b'0123456789abcdef'
"""

types = CompressionType
14 changes: 13 additions & 1 deletion torchdata/datapipes/iter/util/hashchecker.py
@@ -28,7 +28,19 @@ class HashCheckerIterDataPipe(IterDataPipe[Tuple[str, U]]):
does not work with non-seekable stream, e.g. HTTP)
Example:
>>> dp = dp.check_hash({'train.py':'0d8b94d9fa9fb1ad89b9e3da9e1521495dca558fc5213b0fd7fd7b71c23f9921'})
>>> from torchdata.datapipes.iter import IterableWrapper, HttpReader
>>> file_url = "https://raw.githubusercontent.com/pytorch/data/main/LICENSE"
>>> expected_MD5_hash = "bb9675028dd39d2dd2bf71002b93e66c"
>>> http_reader_dp = HttpReader(IterableWrapper([file_url]))
>>> # An exception is only raised when the hash doesn't match, otherwise (path, stream) is returned
>>> check_hash_dp = http_reader_dp.check_hash({file_url: expected_MD5_hash}, "md5", rewind=False)
>>> reader_dp = check_hash_dp.readlines()
>>> it = iter(reader_dp)
>>> path, line = next(it)
>>> path
https://raw.githubusercontent.com/pytorch/data/main/LICENSE
>>> line
b'BSD 3-Clause License'
"""

def __init__(
