Skip to content

Commit

Permalink
Adding UnZipperMapDataPipe (#325)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #325

This `MapDataPipe` seems simple enough to add but we should talk about what is the general guideline for adding `MapDataPipe` before actually doing it.

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D35979906

Pulled By: NivekT

fbshipit-source-id: fce7bad9b6c2dad10c815a3d708f12349519823f
  • Loading branch information
NivekT authored and facebook-github-bot committed May 2, 2022
1 parent 1171ec2 commit 47e6fcc
Show file tree
Hide file tree
Showing 6 changed files with 176 additions and 7 deletions.
1 change: 1 addition & 0 deletions docs/source/torchdata.datapipes.map.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,5 @@ MapDataPipes
Mapper
SequenceWrapper
Shuffler
UnZipper
Zipper
7 changes: 2 additions & 5 deletions test/test_datapipe.py → test/test_iterdatapipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import torchdata

from _utils._common_utils_for_test import IDP_NoLen, reset_after_n_next_calls
from torch.utils.data.datapipes.map import SequenceWrapper
from torchdata.datapipes.iter import (
BucketBatcher,
Cycler,
Expand All @@ -37,7 +36,7 @@
SampleMultiplexer,
UnZipper,
)
from torchdata.datapipes.map import MapDataPipe
from torchdata.datapipes.map import MapDataPipe, SequenceWrapper


def test_torchdata_pytorch_consistency() -> None:
Expand Down Expand Up @@ -65,7 +64,7 @@ def extract_datapipe_names(module):
raise AssertionError(msg + "\n".join(sorted(missing_datapipes)))


class TestDataPipe(expecttest.TestCase):
class TestIterDataPipe(expecttest.TestCase):
def test_in_memory_cache_holder_iterdatapipe(self) -> None:
source_dp = IterableWrapper(range(10))
cache_dp = source_dp.in_memory_cache(size=5)
Expand Down Expand Up @@ -817,8 +816,6 @@ def test_unzipper_iterdatapipe(self):
self.assertEqual(len(source_dp), len(dp2))
self.assertEqual(len(source_dp), len(dp3))

# TODO: Add testing for different stages of pickling for UnZipper

def test_itertomap_mapdatapipe(self):
# Functional Test with None key_value_fn
values = list(range(10))
Expand Down
44 changes: 44 additions & 0 deletions test/test_mapdatapipe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import expecttest
from torchdata.datapipes.map import MapDataPipe, SequenceWrapper, UnZipper


class TestMapDataPipe(expecttest.TestCase):
    def test_unzipper_mapdatapipe(self) -> None:
        """Exercise ``UnZipper`` via its constructor and its ``unzip`` functional form."""
        input_dp = SequenceWrapper([(i, i + 10, i + 20) for i in range(10)])

        # Functional Test: unzips each sequence, with `sequence_length` specified
        first_dp: MapDataPipe
        second_dp: MapDataPipe
        third_dp: MapDataPipe
        first_dp, second_dp, third_dp = UnZipper(input_dp, sequence_length=3)  # type: ignore[misc]
        expected_columns = [list(range(10)), list(range(10, 20)), list(range(20, 30))]
        for want, child in zip(expected_columns, (first_dp, second_dp, third_dp)):
            self.assertEqual(want, list(child))

        # Functional Test: skipping over specified values
        second_dp, third_dp = input_dp.unzip(sequence_length=3, columns_to_skip=[0])
        self.assertEqual(list(range(10, 20)), list(second_dp))
        self.assertEqual(list(range(20, 30)), list(third_dp))

        (second_dp,) = input_dp.unzip(sequence_length=3, columns_to_skip=[0, 2])
        self.assertEqual(list(range(10, 20)), list(second_dp))

        input_dp = SequenceWrapper([(i, i + 10, i + 20, i + 30) for i in range(10)])
        second_dp, third_dp = input_dp.unzip(sequence_length=4, columns_to_skip=[0, 3])
        self.assertEqual(list(range(10, 20)), list(second_dp))
        self.assertEqual(list(range(20, 30)), list(third_dp))

        # __len__ Test: the lengths of child DataPipes are correct
        self.assertEqual((10, 10), (len(second_dp), len(third_dp)))


if __name__ == "__main__":
    unittest.main()
42 changes: 41 additions & 1 deletion test/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import expecttest
import torchdata.datapipes.iter as iterdp
import torchdata.datapipes.map as mapdp
from _utils._common_utils_for_test import create_temp_dir, create_temp_files
from torch.utils.data.datapipes.utils.common import DILL_AVAILABLE
from torchdata.datapipes.iter import IterableWrapper
Expand Down Expand Up @@ -359,8 +360,47 @@ def test_serializable_with_dill(self):


class TestMapDataPipeSerialization(expecttest.TestCase):
def _serialization_test_helper(self, datapipe):
    """Pickle-round-trip ``datapipe`` and assert the restored copy yields the same elements."""
    restored = pickle.loads(pickle.dumps(datapipe))
    try:
        self.assertEqual(list(datapipe), list(restored))
    except AssertionError as e:
        # Identify the offending DataPipe before propagating the failure
        print(f"{datapipe} is failing.")
        raise e

def _serialization_test_for_dp_with_children(self, dp1, dp2):
    """Run the pickle round-trip check on each child DataPipe in turn."""
    for child_dp in (dp1, dp2):
        self._serialization_test_helper(child_dp)

def test_serializable(self):
    """Every MapDataPipe listed below must survive a pickle round trip.

    Each entry is (datapipe class, custom input, positional args, keyword args);
    a ``None`` input defaults to ``SequenceWrapper(range(10))``.
    """
    picklable_datapipes: List = [
        (mapdp.UnZipper, SequenceWrapper([(i, i + 10) for i in range(10)]), (), {"sequence_length": 2}),
    ]

    dp_skip_comparison = set()
    # These DataPipes produce multiple DataPipes as outputs and those should be compared
    dp_compare_children = {mapdp.UnZipper}

    for dpipe, custom_input, dp_args, dp_kwargs in picklable_datapipes:
        try:
            # Creating input (usually a DataPipe) for the specific dpipe being tested
            if custom_input is None:
                custom_input = SequenceWrapper(range(10))

            if dpipe in dp_skip_comparison:  # Make sure they are picklable and loadable (no value comparison)
                datapipe = dpipe(custom_input, *dp_args, **dp_kwargs)  # type: ignore[call-arg]
                serialized_dp = pickle.dumps(datapipe)
                _ = pickle.loads(serialized_dp)
            elif dpipe in dp_compare_children:  # DataPipes that have children
                dp1, dp2 = dpipe(custom_input, *dp_args, **dp_kwargs)  # type: ignore[call-arg]
                self._serialization_test_for_dp_with_children(dp1, dp2)
            else:  # Single DataPipe that requires comparison
                datapipe = dpipe(custom_input, *dp_args, **dp_kwargs)  # type: ignore[call-arg]
                self._serialization_test_helper(datapipe)
        except Exception as e:
            print(f"{dpipe} is failing.")
            raise e

def test_serializable_with_dill(self):
"""Only for DataPipes that take in a function as argument"""
Expand Down
13 changes: 12 additions & 1 deletion torchdata/datapipes/map/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,21 @@
from torch.utils.data import MapDataPipe

from torch.utils.data.datapipes.map import Batcher, Concater, Mapper, SequenceWrapper, Shuffler, Zipper
from torchdata.datapipes.map.util.unzipper import UnZipperMapDataPipe as UnZipper

from torchdata.datapipes.map.util.utils import IterToMapConverterMapDataPipe as IterToMapConverter

# Public API of torchdata.datapipes.map; the stale single-line assignment that
# previously preceded this list was removed — it was immediately shadowed anyway.
__all__ = [
    "Batcher",
    "Concater",
    "IterToMapConverter",
    "MapDataPipe",
    "Mapper",
    "SequenceWrapper",
    "Shuffler",
    "UnZipper",
    "Zipper",
]

# Please keep this list sorted
assert __all__ == sorted(__all__)
76 changes: 76 additions & 0 deletions torchdata/datapipes/map/util/unzipper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional, Sequence, TypeVar

from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.map import MapDataPipe


T = TypeVar("T")


@functional_datapipe("unzip")
class UnZipperMapDataPipe(MapDataPipe):
    """
    Takes in a DataPipe of Sequences, unpacks each Sequence, and returns the elements in separate DataPipes
    based on their position in the Sequence (functional name: ``unzip``). The number of instances produced
    equals to the ``sequence_length`` minus the number of columns to skip.

    Note:
        Each sequence within the DataPipe should have the same length, specified by
        the input argument `sequence_length`.

    Args:
        source_datapipe: MapDataPipe with sequences of data
        sequence_length: Length of the sequence within the source_datapipe. All elements should have the same length.
        columns_to_skip: optional indices of columns that the DataPipe should skip (each index should be
            an integer from 0 to sequence_length - 1)

    Example:
        >>> from torchdata.datapipes.map import SequenceWrapper
        >>> source_dp = SequenceWrapper([(i, i + 10, i + 20) for i in range(3)])
        >>> dp1, dp2, dp3 = source_dp.unzip(sequence_length=3)
        >>> list(dp1)
        [0, 1, 2]
        >>> list(dp2)
        [10, 11, 12]
        >>> list(dp3)
        [20, 21, 22]
    """

    def __new__(
        cls,
        source_datapipe: MapDataPipe[Sequence[T]],
        sequence_length: int,
        columns_to_skip: Optional[Sequence[int]] = None,
    ):
        if sequence_length < 1:
            raise ValueError(f"Expected `sequence_length` larger than 0, but {sequence_length} is found")
        if columns_to_skip is None:
            instance_ids = list(range(sequence_length))
        else:
            skips = set(columns_to_skip)
            instance_ids = [i for i in range(sequence_length) if i not in skips]

        if len(instance_ids) == 0:
            raise RuntimeError(
                f"All instances are being filtered out in {cls.__name__}. Please check "
                "the input `sequence_length` and `columns_to_skip`."
            )
        # Returns a list of child DataPipes, one per kept column, instead of an
        # instance of this class — `__new__` acts as a factory here.
        return [_UnZipperMapDataPipe(source_datapipe, i) for i in instance_ids]


class _UnZipperMapDataPipe(MapDataPipe[T]):
    # Child DataPipe produced by UnZipperMapDataPipe: projects out one fixed
    # column (`instance_id`) of every sequence in `main_datapipe`.
    # NOTE(review): attribute names are part of the pickled state relied on by
    # the serialization tests — do not rename without checking those.

    def __init__(self, main_datapipe: MapDataPipe[Sequence[T]], instance_id: int):
        # main_datapipe: source of sequences; instance_id: column index to extract
        self.main_datapipe = main_datapipe
        self.instance_id = instance_id

    def __getitem__(self, index) -> T:
        # Fetch the sequence at `index` from the source, then take our column
        return self.main_datapipe[index][self.instance_id]

    def __len__(self) -> int:
        # One projected element per source sequence
        return len(self.main_datapipe)

0 comments on commit 47e6fcc

Please sign in to comment.