From 47e6fcc3eda9e3183866724739938b52f1faff61 Mon Sep 17 00:00:00 2001 From: Kevin Tse Date: Mon, 2 May 2022 07:53:36 -0700 Subject: [PATCH] Adding UnZipperMapDataPipe (#325) Summary: Pull Request resolved: https://github.com/pytorch/data/pull/325 This `MapDataPipe` seems simple enough to add but we should talk about what is the general guideline for adding `MapDataPipe` before actually doing it. Test Plan: Imported from OSS Reviewed By: ejguan Differential Revision: D35979906 Pulled By: NivekT fbshipit-source-id: fce7bad9b6c2dad10c815a3d708f12349519823f --- docs/source/torchdata.datapipes.map.rst | 1 + ...{test_datapipe.py => test_iterdatapipe.py} | 7 +- test/test_mapdatapipe.py | 44 +++++++++++ test/test_serialization.py | 42 +++++++++- torchdata/datapipes/map/__init__.py | 13 +++- torchdata/datapipes/map/util/unzipper.py | 76 +++++++++++++++++++ 6 files changed, 176 insertions(+), 7 deletions(-) rename test/{test_datapipe.py => test_iterdatapipe.py} (99%) create mode 100644 test/test_mapdatapipe.py create mode 100644 torchdata/datapipes/map/util/unzipper.py diff --git a/docs/source/torchdata.datapipes.map.rst b/docs/source/torchdata.datapipes.map.rst index 31702cccd..ccff1f9b6 100644 --- a/docs/source/torchdata.datapipes.map.rst +++ b/docs/source/torchdata.datapipes.map.rst @@ -40,4 +40,5 @@ MapDataPipes Mapper SequenceWrapper Shuffler + UnZipper Zipper diff --git a/test/test_datapipe.py b/test/test_iterdatapipe.py similarity index 99% rename from test/test_datapipe.py rename to test/test_iterdatapipe.py index 219402826..52584aba6 100644 --- a/test/test_datapipe.py +++ b/test/test_iterdatapipe.py @@ -18,7 +18,6 @@ import torchdata from _utils._common_utils_for_test import IDP_NoLen, reset_after_n_next_calls -from torch.utils.data.datapipes.map import SequenceWrapper from torchdata.datapipes.iter import ( BucketBatcher, Cycler, @@ -37,7 +36,7 @@ SampleMultiplexer, UnZipper, ) -from torchdata.datapipes.map import MapDataPipe +from torchdata.datapipes.map import 
MapDataPipe, SequenceWrapper def test_torchdata_pytorch_consistency() -> None: @@ -65,7 +64,7 @@ def extract_datapipe_names(module): raise AssertionError(msg + "\n".join(sorted(missing_datapipes))) -class TestDataPipe(expecttest.TestCase): +class TestIterDataPipe(expecttest.TestCase): def test_in_memory_cache_holder_iterdatapipe(self) -> None: source_dp = IterableWrapper(range(10)) cache_dp = source_dp.in_memory_cache(size=5) @@ -817,8 +816,6 @@ def test_unzipper_iterdatapipe(self): self.assertEqual(len(source_dp), len(dp2)) self.assertEqual(len(source_dp), len(dp3)) - # TODO: Add testing for different stages of pickling for UnZipper - def test_itertomap_mapdatapipe(self): # Functional Test with None key_value_fn values = list(range(10)) diff --git a/test/test_mapdatapipe.py b/test/test_mapdatapipe.py new file mode 100644 index 000000000..87f85b220 --- /dev/null +++ b/test/test_mapdatapipe.py @@ -0,0 +1,44 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import unittest + +import expecttest +from torchdata.datapipes.map import MapDataPipe, SequenceWrapper, UnZipper + + +class TestMapDataPipe(expecttest.TestCase): + def test_unzipper_mapdatapipe(self) -> None: + source_dp = SequenceWrapper([(i, i + 10, i + 20) for i in range(10)]) + + # Functional Test: unzips each sequence, with `sequence_length` specified + dp1: MapDataPipe + dp2: MapDataPipe + dp3: MapDataPipe + dp1, dp2, dp3 = UnZipper(source_dp, sequence_length=3) # type: ignore[misc] + self.assertEqual(list(range(10)), list(dp1)) + self.assertEqual(list(range(10, 20)), list(dp2)) + self.assertEqual(list(range(20, 30)), list(dp3)) + + # Functional Test: skipping over specified values + dp2, dp3 = source_dp.unzip(sequence_length=3, columns_to_skip=[0]) + self.assertEqual(list(range(10, 20)), list(dp2)) + self.assertEqual(list(range(20, 30)), list(dp3)) + + (dp2,) = source_dp.unzip(sequence_length=3, columns_to_skip=[0, 2]) + self.assertEqual(list(range(10, 20)), list(dp2)) + + source_dp = SequenceWrapper([(i, i + 10, i + 20, i + 30) for i in range(10)]) + dp2, dp3 = source_dp.unzip(sequence_length=4, columns_to_skip=[0, 3]) + self.assertEqual(list(range(10, 20)), list(dp2)) + self.assertEqual(list(range(20, 30)), list(dp3)) + + # __len__ Test: the lengths of child DataPipes are correct + self.assertEqual((10, 10), (len(dp2), len(dp3))) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_serialization.py b/test/test_serialization.py index 113d9be87..fc4a56f32 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -15,6 +15,7 @@ import expecttest import torchdata.datapipes.iter as iterdp +import torchdata.datapipes.map as mapdp from _utils._common_utils_for_test import create_temp_dir, create_temp_files from torch.utils.data.datapipes.utils.common import DILL_AVAILABLE from torchdata.datapipes.iter import IterableWrapper @@ -359,8 +360,47 @@ def test_serializable_with_dill(self): class 
TestMapDataPipeSerialization(expecttest.TestCase): + def _serialization_test_helper(self, datapipe): + serialized_dp = pickle.dumps(datapipe) + deserialized_dp = pickle.loads(serialized_dp) + try: + self.assertEqual(list(datapipe), list(deserialized_dp)) + except AssertionError as e: + print(f"{datapipe} is failing.") + raise e + + def _serialization_test_for_dp_with_children(self, dp1, dp2): + self._serialization_test_helper(dp1) + self._serialization_test_helper(dp2) + def test_serializable(self): - pass + picklable_datapipes: List = [ + (mapdp.UnZipper, SequenceWrapper([(i, i + 10) for i in range(10)]), (), {"sequence_length": 2}), + ] + + dp_skip_comparison = set() + # These DataPipes produce multiple DataPipes as outputs and those should be compared + dp_compare_children = {mapdp.UnZipper} + + for dpipe, custom_input, dp_args, dp_kwargs in picklable_datapipes: + try: + # Creating input (usually a DataPipe) for the specific dpipe being tested + if custom_input is None: + custom_input = SequenceWrapper(range(10)) + + if dpipe in dp_skip_comparison: # Make sure they are picklable and loadable (no value comparison) + datapipe = dpipe(custom_input, *dp_args, **dp_kwargs) # type: ignore[call-arg] + serialized_dp = pickle.dumps(datapipe) + _ = pickle.loads(serialized_dp) + elif dpipe in dp_compare_children: # DataPipes that have children + dp1, dp2 = dpipe(custom_input, *dp_args, **dp_kwargs) # type: ignore[call-arg] + self._serialization_test_for_dp_with_children(dp1, dp2) + else: # Single DataPipe that requires comparison + datapipe = dpipe(custom_input, *dp_args, **dp_kwargs) # type: ignore[call-arg] + self._serialization_test_helper(datapipe) + except Exception as e: + print(f"{dpipe} is failing.") + raise e def test_serializable_with_dill(self): """Only for DataPipes that take in a function as argument""" diff --git a/torchdata/datapipes/map/__init__.py b/torchdata/datapipes/map/__init__.py index a4cbcf46d..56acc2612 100644 --- 
a/torchdata/datapipes/map/__init__.py +++ b/torchdata/datapipes/map/__init__.py @@ -7,10 +7,21 @@ from torch.utils.data import MapDataPipe from torch.utils.data.datapipes.map import Batcher, Concater, Mapper, SequenceWrapper, Shuffler, Zipper +from torchdata.datapipes.map.util.unzipper import UnZipperMapDataPipe as UnZipper from torchdata.datapipes.map.util.utils import IterToMapConverterMapDataPipe as IterToMapConverter -__all__ = ["Batcher", "Concater", "IterToMapConverter", "Mapper", "SequenceWrapper", "Shuffler", "Zipper"] +__all__ = [ + "Batcher", + "Concater", + "IterToMapConverter", + "MapDataPipe", + "Mapper", + "SequenceWrapper", + "Shuffler", + "UnZipper", + "Zipper", +] # Please keep this list sorted assert __all__ == sorted(__all__) diff --git a/torchdata/datapipes/map/util/unzipper.py b/torchdata/datapipes/map/util/unzipper.py new file mode 100644 index 000000000..4fef763a4 --- /dev/null +++ b/torchdata/datapipes/map/util/unzipper.py @@ -0,0 +1,76 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Sequence, TypeVar + +from torchdata.datapipes import functional_datapipe +from torchdata.datapipes.map import MapDataPipe + + +T = TypeVar("T") + + +@functional_datapipe("unzip") +class UnZipperMapDataPipe(MapDataPipe): + """ + Takes in a DataPipe of Sequences, unpacks each Sequence, and returns the elements in separate DataPipes + based on their position in the Sequence (functional name: ``unzip``). The number of instances produced + equals the ``sequence_length`` minus the number of columns to skip. + + Note: + Each sequence within the DataPipe should have the same length, specified by + the input argument `sequence_length`. + + Args: + source_datapipe: Map DataPipe with sequences of data + sequence_length: Length of the sequence within the source_datapipe. 
All elements should have the same length. + columns_to_skip: optional indices of columns that the DataPipe should skip (each index should be + an integer from 0 to sequence_length - 1) + + Example: + >>> from torchdata.datapipes.map import SequenceWrapper + >>> source_dp = SequenceWrapper([(i, i + 10, i + 20) for i in range(3)]) + >>> dp1, dp2, dp3 = source_dp.unzip(sequence_length=3) + >>> list(dp1) + [0, 1, 2] + >>> list(dp2) + [10, 11, 12] + >>> list(dp3) + [20, 21, 22] + """ + + def __new__( + cls, + source_datapipe: MapDataPipe[Sequence[T]], + sequence_length: int, + columns_to_skip: Optional[Sequence[int]] = None, + ): + if sequence_length < 1: + raise ValueError(f"Expected `sequence_length` larger than 0, but {sequence_length} is found") + if columns_to_skip is None: + instance_ids = list(range(sequence_length)) + else: + skips = set(columns_to_skip) + instance_ids = [i for i in range(sequence_length) if i not in skips] + + if len(instance_ids) == 0: + raise RuntimeError( + f"All instances are being filtered out in {cls.__name__}. Please check " + "the input `sequence_length` and `columns_to_skip`." + ) + return [_UnZipperMapDataPipe(source_datapipe, i) for i in instance_ids] + + class _UnZipperMapDataPipe(MapDataPipe[T]): + def __init__(self, main_datapipe: MapDataPipe[Sequence[T]], instance_id: int): + self.main_datapipe = main_datapipe + self.instance_id = instance_id + + def __getitem__(self, index) -> T: + return self.main_datapipe[index][self.instance_id] + + def __len__(self) -> int: + return len(self.main_datapipe)