Skip to content

Commit

Permalink
Adding DataLoader2 Adapter API and shuffle adapter (#484)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #484

Test Plan: Imported from OSS

Reviewed By: NivekT, Miiira

Differential Revision: D36866291

Pulled By: VitalyFedyunin

fbshipit-source-id: 1df16c3c45518a069dd87a5565d418ced750ae09
  • Loading branch information
VitalyFedyunin authored and facebook-github-bot committed Jun 3, 2022
1 parent 6f1f81d commit 5aac88f
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 8 deletions.
41 changes: 41 additions & 0 deletions test/test_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import warnings

from unittest import TestCase

from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService, ReadingServiceInterface
from torchdata.dataloader2.adapter import Shuffle
from torchdata.datapipes.iter import IterableWrapper, IterDataPipe


class AdapterTest(TestCase):
def test_shuffle(self) -> None:
size = 500
dp = IterableWrapper(range(size))

dl = DataLoader2(datapipe=dp)
self.assertEqual(list(range(size)), list(dl))

with warnings.catch_warnings(record=True) as wa:
dl = DataLoader2(datapipe=dp, datapipe_adapter_fn=Shuffle(True))
self.assertNotEqual(list(range(size)), list(dl))
self.assertEqual(1, len(wa))

dp = IterableWrapper(range(size)).shuffle()

dl = DataLoader2(datapipe=dp)
self.assertNotEqual(list(range(size)), list(dl))

dl = DataLoader2(dp, Shuffle(True))
self.assertNotEqual(list(range(size)), list(dl))

dl = DataLoader2(dp, [Shuffle(None)])
self.assertNotEqual(list(range(size)), list(dl))

dl = DataLoader2(dp, [Shuffle(False)])
self.assertEqual(list(range(size)), list(dl))
47 changes: 47 additions & 0 deletions torchdata/dataloader2/adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from abc import abstractmethod

import torch

from torchdata.datapipes.iter import IterDataPipe

__all__ = [
"Adapter",
"Shuffle",
]

assert __all__ == sorted(__all__)


class Adapter:
@abstractmethod
def __call__(self, datapipe: IterDataPipe) -> IterDataPipe:
pass


class Shuffle(Adapter):
r"""
Shuffle DataPipes adapter allows control over all existing Shuffler (``shuffle``) DataPipes in the graph.
Args:
enable: Optional[Boolean] = True
Shuffle(enable = True) - enables all previously disabled Shuffler DataPipes. If none exists, it will add a new `shuffle` at the end of the graph.
Shuffle(enable = False) - disables all Shuffler DataPipes in the graph.
Shuffle(enable = None) - Is noop. Introduced for backward compatibility.
Example:
>>> dp = IterableWrapper(range(size)).shuffle()
>>> dl = DataLoader2(dp, [Shuffle(False)])
>>> self.assertEqual(list(range(size)), list(dl))
"""

def __init__(self, enable=True):
self.enable = enable

def __call__(self, datapipe: IterDataPipe) -> IterDataPipe:
return torch.utils.data.graph_settings.apply_shuffle_settings(datapipe, shuffle=self.enable)
24 changes: 16 additions & 8 deletions torchdata/dataloader2/dataloader2.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@

import pickle
from dataclasses import dataclass
from typing import Any, Callable, Dict, Generic, Iterator, Optional, TypeVar
from typing import Any, Dict, Generic, Iterable, Iterator, Optional, TypeVar, Union

from torchdata.dataloader2.adapter import Adapter

from torchdata.datapipes.iter import IterDataPipe

Expand Down Expand Up @@ -45,22 +47,27 @@ class DataLoader2(Generic[T_co]):
def __init__(
self,
datapipe: IterDataPipe,
# TODO: Change into Iterable[DPAdapter] and apply them sequentially for OSS use case
datapipe_adapter_fn: Optional[Callable[[IterDataPipe], IterDataPipe]] = None,
datapipe_adapter_fn: Optional[Union[Iterable[Adapter], Adapter]] = None,
reading_service: Optional[ReadingServiceInterface] = None,
) -> None:
self.datapipe = datapipe
self._adapted: bool = False
self._datapipe_iter: Optional[Iterator[T_co]] = None
self._reset_iter: bool = True
# TODO(VitalyFedyunin): Some ReadingServices might want to validate adapters, we can add this feature
self.datapipe_adapter_fn = datapipe_adapter_fn
if datapipe_adapter_fn is None:
self.datapipe_adapter_fns = None
elif isinstance(datapipe_adapter_fn, Iterable):
self.datapipe_adapter_fns = datapipe_adapter_fn
else:
self.datapipe_adapter_fns = [datapipe_adapter_fn]
self.reading_service = reading_service
self.reading_service_state: Optional[bytes] = None
self._terminated: bool = False

if self.datapipe_adapter_fn is not None:
self.datapipe = self.datapipe_adapter_fn(self.datapipe)
if self.datapipe_adapter_fns is not None:
for adapter_fn in self.datapipe_adapter_fns:
self.datapipe = adapter_fn(self.datapipe)
self._datapipe_before_reading_service_adapt: IterDataPipe = self.datapipe

def __iter__(self) -> Iterator[T_co]:
Expand Down Expand Up @@ -182,6 +189,7 @@ def load_state_dict(self, state: Dict[str, Any]) -> None:
self.reading_service_state = reading_service_state

# re-initialize datapipe_adapter_fn and _datapipe_before_reading_service_adapt
if self.datapipe_adapter_fn is not None:
self.datapipe = self.datapipe_adapter_fn(self.datapipe)
if self.datapipe_adapter_fns is not None:
for adapter_fn in self.datapipe_adapter_fns:
self.datapipe = adapter_fn(self.datapipe)
self._datapipe_before_reading_service_adapt = self.datapipe

0 comments on commit 5aac88f

Please sign in to comment.