torchdata datapipes drop addition (#725)
Summary:
Fixes #656

### Changes
Part 1 of the feature requests to manipulate DataPipe columns tracked in #656: this adds `drop` functionality to `IterDataPipe`s.
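A minimal usage sketch of the new functional form, mirroring the tuple cases exercised in the tests below (the sample data is illustrative):

```python
from torchdata.datapipes.iter import IterableWrapper

# Drop the column at index 1, or several columns at once, from every tuple.
dp = IterableWrapper([(0, 1, 2), (3, 4, 5), (6, 7, 8)])
print(list(dp.drop(1)))       # [(0, 2), (3, 5), (6, 8)]
print(list(dp.drop([0, 2])))  # [(1,), (4,), (7,)]
```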

Pull Request resolved: #725

Reviewed By: ejguan

Differential Revision: D38608512

Pulled By: dbish

fbshipit-source-id: 273ff7bfde001baf4e35961fe3056fece8ced502
Diamond Bishop authored and facebook-github-bot committed Aug 11, 2022
1 parent 85ae838 commit 5dade9a
Showing 6 changed files with 122 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/source/torchdata.datapipes.iter.rst
@@ -192,6 +192,7 @@ These DataPipes helps you select specific samples within a DataPipe.

    Filter
    Header
    Dropper

Text DataPipes
-----------------------------
48 changes: 48 additions & 0 deletions test/test_iterdatapipe.py
@@ -945,6 +945,54 @@ def test_zip_longest_iterdatapipe(self):
        # __len__ Test: length matches the length of the shortest input
        self.assertEqual(len(output_dp), 10)

    def test_drop_iterdatapipe(self):
        # tuple tests
        input_dp = IterableWrapper([(0, 1, 2), (3, 4, 5), (6, 7, 8)])

        # Functional Test: single index drop for tuple elements
        drop_dp = input_dp.drop(1)
        self.assertEqual([(0, 2), (3, 5), (6, 8)], list(drop_dp))

        # Functional Test: multiple indices drop for tuple elements
        drop_dp = input_dp.drop([0, 2])
        self.assertEqual([(1,), (4,), (7,)], list(drop_dp))

        # dict tests
        input_dp = IterableWrapper([{"a": 1, "b": 2, "c": 3}, {"a": 3, "b": 4, "c": 5}, {"a": 5, "b": 6, "c": 7}])

        # Functional Test: single key drop for dict elements
        drop_dp = input_dp.drop("a")
        self.assertEqual([{"b": 2, "c": 3}, {"b": 4, "c": 5}, {"b": 6, "c": 7}], list(drop_dp))

        # Functional Test: multiple key drop for dict elements
        drop_dp = input_dp.drop(["a", "b"])
        self.assertEqual([{"c": 3}, {"c": 5}, {"c": 7}], list(drop_dp))

        # list tests
        input_dp = IterableWrapper([[0, 1, 2], [3, 4, 5], [6, 7, 8]])

        # Functional Test: single index drop for list elements
        drop_dp = input_dp.drop(2)
        self.assertEqual([[0, 1], [3, 4], [6, 7]], list(drop_dp))

        # Functional Test: multiple index drop for list elements
        drop_dp = input_dp.drop([0, 1])
        self.assertEqual([[2], [5], [8]], list(drop_dp))

        # Reset Test:
        n_elements_before_reset = 2
        input_dp = IterableWrapper([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
        drop_dp = input_dp.drop([0, 1])
        expected_res = [[2], [5], [8]]
        res_before_reset, res_after_reset = reset_after_n_next_calls(drop_dp, n_elements_before_reset)
        self.assertEqual(expected_res[:n_elements_before_reset], res_before_reset)
        self.assertEqual(expected_res, res_after_reset)

        # __len__ Test:
        input_dp = IterableWrapper([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
        drop_dp = input_dp.drop([0, 1])
        self.assertEqual(3, len(drop_dp))


if __name__ == "__main__":
unittest.main()
1 change: 1 addition & 0 deletions test/test_serialization.py
@@ -188,6 +188,7 @@ def test_serializable(self):
            (iterdp.Cycler, None, (2,), {}),
            (iterdp.DataFrameMaker, IterableWrapper([(i,) for i in range(3)]), (), {"dtype": DTYPE}),
            (iterdp.Decompressor, None, (), {}),
            (iterdp.Dropper, IterableWrapper([(0, 0), (0, 0), (0, 0), (0, 0)]), ([1],), {}),
            (iterdp.Enumerator, None, (2,), {}),
            (iterdp.FlatMapper, None, (_fake_fn_ls,), {}),
            (iterdp.FSSpecFileLister, ".", (), {}),
2 changes: 2 additions & 0 deletions torchdata/datapipes/iter/__init__.py
@@ -67,6 +67,7 @@
)
from torchdata.datapipes.iter.transform.callable import (
    BatchMapperIterDataPipe as BatchMapper,
    DropperIterDataPipe as Dropper,
    FlatMapperIterDataPipe as FlatMapper,
)
from torchdata.datapipes.iter.util.bz2fileloader import Bz2FileLoaderIterDataPipe as Bz2FileLoader
@@ -143,6 +144,7 @@
"DataFrameMaker",
"Decompressor",
"Demultiplexer",
"Dropper",
"EndOnDiskCacheHolder",
"Enumerator",
"Extractor",
2 changes: 1 addition & 1 deletion torchdata/datapipes/iter/__init__.pyi.in
@@ -12,7 +12,7 @@ from torchdata.datapipes.map import MapDataPipe
from torch.utils.data import DataChunk, IterableDataset, default_collate
from torch.utils.data.datapipes._typing import _DataPipeMeta

from typing import Any, Callable, Dict, List, Optional, Sequence, TypeVar, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, TypeVar, Union, Hashable

try:
    import torcharrow
70 changes: 69 additions & 1 deletion torchdata/datapipes/iter/transform/callable.py
@@ -4,7 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, Iterator, List, TypeVar
import warnings
from typing import Callable, Hashable, Iterator, List, Sized, TypeVar, Union

from torch.utils.data import functional_datapipe, IterDataPipe
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn
@@ -137,3 +138,70 @@ def __iter__(self) -> Iterator[T_co]:

    def __len__(self) -> int:
        raise TypeError(f"{type(self).__name__}'s length relies on the output of its function.")


@functional_datapipe("drop")
class DropperIterDataPipe(IterDataPipe[T_co]):
r"""
Drop columns/elements in input DataPipe via its indices (functional name: ``drop``).
Args:
datapipe: IterDataPipe with columns to be dropped
indices: a single column index to be dropped or a list of indices
Example:
>>> from torchdata.datapipes.iter import IterableWrapper, ZipperMapDataPipe
>>> dp1 = IterableWrapper(range(5))
>>> dp2 = IterableWrapper(range(10, 15))
>>> dp = dp1.zip(dp2)
>>> list(dp)
[(0, 10), (1, 11), (2, 12), (3, 13), (4, 14)]
>>> drop_dp = dp.drop(1)
>>> list(drop_dp)
[(0), (1), (2), (3), (4)]
"""
    datapipe: IterDataPipe

    def __init__(
        self,
        datapipe: IterDataPipe,
        indices: Union[Hashable, List[Hashable]],
    ) -> None:
        super().__init__()
        self.datapipe = datapipe
        if isinstance(indices, list):
            self.indices = set(indices)
        else:
            self.indices = {indices}

    def __iter__(self) -> Iterator[T_co]:
        for old_item in self.datapipe:
            if isinstance(old_item, tuple):
                new_item = tuple(x for i, x in enumerate(old_item) if i not in self.indices)  # type: ignore[assignment]
            elif isinstance(old_item, list):
                new_item = [x for i, x in enumerate(old_item) if i not in self.indices]  # type: ignore[assignment]
            elif isinstance(old_item, dict):
                new_item = {k: v for (k, v) in old_item.items() if k not in self.indices}  # type: ignore[assignment]
            else:
                new_item = old_item
                warnings.warn(
                    "The next item was not an iterable and cannot be filtered, "
                    "please be aware that no filter was done or new item created."
                )

            # check to make sure all indices requested were in the item. warn if not
            try:
                for i in self.indices:
                    old_item[i]
            except (IndexError, KeyError):
                warnings.warn(
                    "At least one index in the filter is not present in the item being returned,"
                    " please be aware that expected columns/keys may be missing."
                )

            yield new_item  # type: ignore[misc]

    def __len__(self) -> int:
        if isinstance(self.datapipe, Sized):
            return len(self.datapipe)
        raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
