Skip to content

Commit

Permalink
Adding RepeaterIterDataPipe (#748)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #748

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D38881275

Pulled By: NivekT

fbshipit-source-id: 1d33d17b4cc9020633a1bc84f4d6d68ba7ad3045
  • Loading branch information
NivekT authored and facebook-github-bot committed Aug 24, 2022
1 parent 3a61e76 commit 6bd0eb7
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/source/torchdata.datapipes.iter.rst
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ These DataPipes help to augment your samples.
Cycler
Enumerator
IndexAdder
Repeater

Combinatorial DataPipes
-----------------------------
Expand Down
26 changes: 26 additions & 0 deletions test/test_iterdatapipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
MapKeyZipper,
MaxTokenBucketizer,
ParagraphAggregator,
Repeater,
Rows2Columnar,
SampleMultiplexer,
UnZipper,
Expand Down Expand Up @@ -257,6 +258,31 @@ def odd_even_bug(i: int) -> int:
result_dp = source_dp.zip_with_map(map_dp, odd_even)
self.assertEqual(len(source_dp), len(result_dp))

def test_repeater_iterdatapipe(self) -> None:
    """Exercise ``Repeater``: output order, argument validation, reset and length."""

    def repeat_each(values, times):
        # Reference result: every value emitted `times` times in a row.
        return [value for value in values for _ in range(times)]

    source_dp = IterableWrapper(range(5))

    # Functional Test: each element comes out three times before the next one
    thrice_dp = source_dp.repeat(3)
    self.assertEqual(repeat_each(range(5), 3), list(thrice_dp))

    # Functional Test: a `times` of 1 is rejected at construction
    with self.assertRaisesRegex(ValueError, "The number of repetition must be > 1"):
        source_dp.repeat(1)

    # Reset Test: read four items, reset, then read the full output again
    repeater_dp = Repeater(source_dp, times=2)
    res_before_reset, res_after_reset = reset_after_n_next_calls(repeater_dp, 4)
    self.assertEqual([0, 0, 1, 1], res_before_reset)
    self.assertEqual(repeat_each(range(5), 2), res_after_reset)

    # __len__ Test: 5 source elements repeated twice
    self.assertEqual(10, len(repeater_dp))

def test_cycler_iterdatapipe(self) -> None:
source_dp = IterableWrapper(range(5))

Expand Down
1 change: 1 addition & 0 deletions test/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,7 @@ def test_serializable(self):
(),
{},
),
(iterdp.Repeater, None, (2,), {}),
(iterdp.SampleMultiplexer, {IterableWrapper([0] * 10): 0.5, IterableWrapper([1] * 10): 0.5}, (), {}),
(
iterdp.Saver,
Expand Down
3 changes: 2 additions & 1 deletion torchdata/datapipes/iter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@
IterKeyZipperIterDataPipe as IterKeyZipper,
MapKeyZipperIterDataPipe as MapKeyZipper,
)
from torchdata.datapipes.iter.util.cycler import CyclerIterDataPipe as Cycler
from torchdata.datapipes.iter.util.cycler import CyclerIterDataPipe as Cycler, RepeaterIterDataPipe as Repeater
from torchdata.datapipes.iter.util.dataframemaker import (
DataFrameMakerIterDataPipe as DataFrameMaker,
ParquetDFLoaderIterDataPipe as ParquetDataFrameLoader,
Expand Down Expand Up @@ -189,6 +189,7 @@
"ParagraphAggregator",
"ParquetDataFrameLoader",
"RarArchiveLoader",
"Repeater",
"RoutedDecoder",
"Rows2Columnar",
"S3FileLister",
Expand Down
39 changes: 39 additions & 0 deletions torchdata/datapipes/iter/util/cycler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ class CyclerIterDataPipe(IterDataPipe[T_co]):
Cycles the specified input in perpetuity by default, or for the specified number
of times (functional name: ``cycle``).
If the ordering does not matter (e.g. because you plan to ``shuffle`` later) or if you would like to
repeat an element multiple times before moving onto the next element, use :class:`.Repeater`.
Args:
source_datapipe: source DataPipe that will be cycled through
count: the number of times to read through ``source_datapipe`` (if ``None``, it will cycle in perpetuity)
Expand Down Expand Up @@ -49,3 +52,39 @@ def __len__(self) -> int:
)
else:
return self.count * len(self.source_datapipe)


@functional_datapipe("repeat")
class RepeaterIterDataPipe(IterDataPipe[T_co]):
    """
    Yields every element of the source DataPipe ``times`` times in a row, then advances
    to the next element (functional name: ``repeat``). Elements are not copied: the very
    same object is handed out on each repetition.

    To replay the whole DataPipe from start to end several times instead,
    use :class:`.Cycler`.

    Args:
        source_datapipe: source DataPipe whose elements are repeated
        times: how many times each element of ``source_datapipe`` is yielded before moving on

    Raises:
        ValueError: if ``times`` is less than or equal to 1

    Example:
        >>> from torchdata.datapipes.iter import IterableWrapper
        >>> dp = IterableWrapper(range(3))
        >>> dp = dp.repeat(2)
        >>> list(dp)
        [0, 0, 1, 1, 2, 2]
    """

    def __init__(self, source_datapipe: IterDataPipe[T_co], times: int) -> None:
        self.source_datapipe, self.times = source_datapipe, times
        # Repeating once (or fewer times) is rejected rather than silently accepted.
        if times <= 1:
            message = f"The number of repetition must be > 1, got {times}"
            raise ValueError(message)

    def __iter__(self) -> Iterator[T_co]:
        for item in self.source_datapipe:
            # Hand out the same object `times` times before pulling the next one.
            yield from (item for _ in range(self.times))

    def __len__(self) -> int:
        # Each source element contributes exactly `times` outputs.
        repetitions = self.times
        return repetitions * len(self.source_datapipe)

0 comments on commit 6bd0eb7

Please sign in to comment.