Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding UnZipperMapDataPipe #325

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/torchdata.datapipes.map.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,5 @@ DataPipes
Mapper
SequenceWrapper
Shuffler
UnZipper
Zipper
7 changes: 2 additions & 5 deletions test/test_datapipe.py → test/test_iterdatapipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import torchdata

from _utils._common_utils_for_test import IDP_NoLen, reset_after_n_next_calls
from torch.utils.data.datapipes.map import SequenceWrapper
from torchdata.datapipes.iter import (
BucketBatcher,
Cycler,
Expand All @@ -37,7 +36,7 @@
SampleMultiplexer,
UnZipper,
)
from torchdata.datapipes.map import MapDataPipe
from torchdata.datapipes.map import MapDataPipe, SequenceWrapper


def test_torchdata_pytorch_consistency() -> None:
Expand Down Expand Up @@ -65,7 +64,7 @@ def extract_datapipe_names(module):
raise AssertionError(msg + "\n".join(sorted(missing_datapipes)))


class TestDataPipe(expecttest.TestCase):
class TestIterDataPipe(expecttest.TestCase):
def test_in_memory_cache_holder_iterdatapipe(self) -> None:
source_dp = IterableWrapper(range(10))
cache_dp = source_dp.in_memory_cache(size=5)
Expand Down Expand Up @@ -817,8 +816,6 @@ def test_unzipper_iterdatapipe(self):
self.assertEqual(len(source_dp), len(dp2))
self.assertEqual(len(source_dp), len(dp3))

# TODO: Add testing for different stages of pickling for UnZipper

def test_itertomap_mapdatapipe(self):
# Functional Test with None key_value_fn
values = list(range(10))
Expand Down
44 changes: 44 additions & 0 deletions test/test_mapdatapipe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import expecttest
from torchdata.datapipes.map import MapDataPipe, SequenceWrapper, UnZipper


class TestMapDataPipe(expecttest.TestCase):
    def test_unzipper_mapdatapipe(self) -> None:
        """Exercise ``UnZipper`` / the ``.unzip`` functional API on MapDataPipes."""
        source_dp = SequenceWrapper([(i, i + 10, i + 20) for i in range(10)])

        # Functional Test: unzips each sequence, with `sequence_length` specified
        dp1: MapDataPipe
        dp2: MapDataPipe
        dp3: MapDataPipe
        dp1, dp2, dp3 = UnZipper(source_dp, sequence_length=3)  # type: ignore[misc]
        for offset, child_dp in zip((0, 10, 20), (dp1, dp2, dp3)):
            self.assertEqual(list(range(offset, offset + 10)), list(child_dp))

        # Functional Test: skipping over specified values
        dp2, dp3 = source_dp.unzip(sequence_length=3, columns_to_skip=[0])
        self.assertEqual(list(range(10, 20)), list(dp2))
        self.assertEqual(list(range(20, 30)), list(dp3))

        # Skipping all but one column yields a single child pipe
        (dp2,) = source_dp.unzip(sequence_length=3, columns_to_skip=[0, 2])
        self.assertEqual(list(range(10, 20)), list(dp2))

        # Longer sequences with skips at both ends
        source_dp = SequenceWrapper([(i, i + 10, i + 20, i + 30) for i in range(10)])
        dp2, dp3 = source_dp.unzip(sequence_length=4, columns_to_skip=[0, 3])
        self.assertEqual(list(range(10, 20)), list(dp2))
        self.assertEqual(list(range(20, 30)), list(dp3))

        # __len__ Test: the lengths of child DataPipes are correct
        self.assertEqual((10, 10), (len(dp2), len(dp3)))


if __name__ == "__main__":
    unittest.main()
42 changes: 41 additions & 1 deletion test/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import expecttest
import torchdata.datapipes.iter as iterdp
import torchdata.datapipes.map as mapdp
from _utils._common_utils_for_test import create_temp_dir, create_temp_files
from torch.utils.data.datapipes.utils.common import DILL_AVAILABLE
from torchdata.datapipes.iter import IterableWrapper
Expand Down Expand Up @@ -358,8 +359,47 @@ def test_serializable_with_dill(self):


class TestMapDataPipeSerialization(expecttest.TestCase):
def _serialization_test_helper(self, datapipe):
    """Round-trip ``datapipe`` through pickle and assert its contents are unchanged."""
    serialized_dp = pickle.dumps(datapipe)
    deserialized_dp = pickle.loads(serialized_dp)
    try:
        self.assertEqual(list(datapipe), list(deserialized_dp))
    except AssertionError:
        # Identify the failing pipe, then re-raise with the original traceback
        # (bare `raise` preserves it; `raise e` would rebind the frame).
        print(f"{datapipe} is failing.")
        raise

def _serialization_test_for_dp_with_children(self, dp1, dp2):
    """Run the pickle round-trip check on each of two child DataPipes."""
    for child_dp in (dp1, dp2):
        self._serialization_test_helper(child_dp)

def test_serializable(self):
    """Each registered MapDataPipe should survive a pickle round-trip.

    Entries are ``(datapipe_class, custom_input, args, kwargs)``; a ``None``
    input falls back to a default ``SequenceWrapper(range(10))``.
    """
    # Removed a leftover dead `pass` statement that preceded this list.
    picklable_datapipes: List = [
        (mapdp.UnZipper, SequenceWrapper([(i, i + 10) for i in range(10)]), (), {"sequence_length": 2}),
    ]

    dp_skip_comparison = set()
    # These DataPipes produce multiple DataPipes as outputs and those should be compared
    dp_compare_children = {mapdp.UnZipper}

    for dpipe, custom_input, dp_args, dp_kwargs in picklable_datapipes:
        try:
            # Creating input (usually a DataPipe) for the specific dpipe being tested
            if custom_input is None:
                custom_input = SequenceWrapper(range(10))

            if dpipe in dp_skip_comparison:  # Make sure they are picklable and loadable (no value comparison)
                datapipe = dpipe(custom_input, *dp_args, **dp_kwargs)  # type: ignore[call-arg]
                serialized_dp = pickle.dumps(datapipe)
                _ = pickle.loads(serialized_dp)
            elif dpipe in dp_compare_children:  # DataPipes that have children
                dp1, dp2 = dpipe(custom_input, *dp_args, **dp_kwargs)  # type: ignore[call-arg]
                self._serialization_test_for_dp_with_children(dp1, dp2)
            else:  # Single DataPipe that requires comparison
                datapipe = dpipe(custom_input, *dp_args, **dp_kwargs)  # type: ignore[call-arg]
                self._serialization_test_helper(datapipe)
        except Exception:
            # Name the failing DataPipe, then re-raise with the original traceback.
            print(f"{dpipe} is failing.")
            raise

def test_serializable_with_dill(self):
"""Only for DataPipes that take in a function as argument"""
Expand Down
13 changes: 12 additions & 1 deletion torchdata/datapipes/map/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,21 @@
from torch.utils.data import MapDataPipe

from torch.utils.data.datapipes.map import Batcher, Concater, Mapper, SequenceWrapper, Shuffler, Zipper
from torchdata.datapipes.map.util.unzipper import UnZipperMapDataPipe as UnZipper

from torchdata.datapipes.map.util.utils import IterToMapConverterMapDataPipe as IterToMapConverter

__all__ = ["Batcher", "Concater", "IterToMapConverter", "Mapper", "SequenceWrapper", "Shuffler", "Zipper"]
# Public, re-exported API of ``torchdata.datapipes.map``.
__all__ = [
    "Batcher",
    "Concater",
    "IterToMapConverter",
    "MapDataPipe",
    "Mapper",
    "SequenceWrapper",
    "Shuffler",
    "UnZipper",
    "Zipper",
]

# Please keep this list sorted
assert __all__ == sorted(__all__)
76 changes: 76 additions & 0 deletions torchdata/datapipes/map/util/unzipper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional, Sequence, TypeVar

from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.map import MapDataPipe


T = TypeVar("T")


@functional_datapipe("unzip")
class UnZipperMapDataPipe(MapDataPipe):
    """
    Takes in a DataPipe of Sequences, unpacks each Sequence, and returns the elements in separate DataPipes
    based on their position in the Sequence (functional name: ``unzip``). The number of instances produced
    equals ``sequence_length`` minus the number of columns to skip.

    Note:
        Each sequence within the DataPipe should have the same length, specified by
        the input argument ``sequence_length``.

    Args:
        source_datapipe: MapDataPipe with sequences of data
        sequence_length: Length of the sequence within the source_datapipe. All elements should have the same length.
        columns_to_skip: optional indices of columns that the DataPipe should skip (each index should be
            an integer from 0 to sequence_length - 1)

    Example:
        >>> from torchdata.datapipes.map import SequenceWrapper
        >>> source_dp = SequenceWrapper([(i, i + 10, i + 20) for i in range(3)])
        >>> dp1, dp2, dp3 = source_dp.unzip(sequence_length=3)
        >>> list(dp1)
        [0, 1, 2]
        >>> list(dp2)
        [10, 11, 12]
        >>> list(dp3)
        [20, 21, 22]
    """

    def __new__(
        cls,
        source_datapipe: MapDataPipe[Sequence[T]],
        sequence_length: int,
        columns_to_skip: Optional[Sequence[int]] = None,
    ):
        if sequence_length < 1:
            raise ValueError(f"Expected `sequence_length` larger than 0, but {sequence_length} is found")
        if columns_to_skip is None:
            instance_ids = list(range(sequence_length))
        else:
            # NOTE(review): out-of-range skip indices are silently ignored here —
            # only columns within [0, sequence_length) are considered.
            skips = set(columns_to_skip)
            instance_ids = [i for i in range(sequence_length) if i not in skips]

        if len(instance_ids) == 0:
            # Fixed: the implicitly-concatenated message was missing a space
            # between "check" and "the".
            raise RuntimeError(
                f"All instances are being filtered out in {cls.__name__}. Please check "
                "the input `sequence_length` and `columns_to_skip`."
            )
        # __new__ returns a list of child DataPipes instead of a single instance,
        # enabling `dp1, dp2, ... = source.unzip(...)` style unpacking.
        return [_UnZipperMapDataPipe(source_datapipe, i) for i in instance_ids]


class _UnZipperMapDataPipe(MapDataPipe[T]):
    """Child DataPipe that yields one fixed column of each sequence in ``main_datapipe``."""

    def __init__(self, main_datapipe: MapDataPipe[Sequence[T]], instance_id: int) -> None:
        # Attribute names are kept as-is: they form the pickled state of the pipe.
        self.main_datapipe = main_datapipe
        self.instance_id = instance_id

    def __getitem__(self, index) -> T:
        # Look up the full sequence, then project out this child's column.
        sequence = self.main_datapipe[index]
        return sequence[self.instance_id]

    def __len__(self) -> int:
        # Same length as the source: one projected element per source sequence.
        return len(self.main_datapipe)