Skip to content

Commit

Permalink
Reland: [DataLoader2] Removing delegation for 'pause', 'limit', and '…
Browse files Browse the repository at this point in the history
…resume'

ghstack-source-id: 96dfad68c49af524797fceb8ac541a60e4b79e4c
Pull Request resolved: #1067
  • Loading branch information
NivekT committed Apr 10, 2023
1 parent 851e26a commit 6d4d82d
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 12 deletions.
23 changes: 11 additions & 12 deletions torchdata/dataloader2/dataloader2.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,6 @@ def resume(self) -> None:
Restarts the threads within ``DataLoader2`` and allows it to yield additional batches.
"""
self.dataloader._resume()
if self.dataloader._datapipe_iter and hasattr(self.dataloader._datapipe_iter, "resume"):
self.dataloader._datapipe_iter.resume() # type: ignore[attr-defined]

def limit(self, num_batches: Optional[int]) -> None:
"""
Expand All @@ -112,8 +110,7 @@ def limit(self, num_batches: Optional[int]) -> None:
"""
self.limit_counter = 0
self.limit_threshold = num_batches
if self.dataloader._datapipe_iter and hasattr(self.dataloader._datapipe_iter, "limit"):
self.dataloader._datapipe_iter.limit(num_batches) # type: ignore[attr-defined]
self.dataloader._limit(num_batches)

def __getattr__(self, name):
"""
Expand Down Expand Up @@ -369,11 +366,8 @@ def _pause(self):
if hasattr(self.reading_service, "_pause"):
self._is_paused = True
self.reading_service._pause()
# TODO: the condition should be `else` once `self._datapipe_iter.pause/limit()` is no longer used
elif self._datapipe_iter is None or not (
hasattr(self._datapipe_iter, "limit") or hasattr(self._datapipe_iter, "pause")
):
warnings.warn("ReadingService doesn't support pause.")
else:
warnings.warn("ReadingService doesn't support `pause`.")

def _resume(self):
if hasattr(self.reading_service, "_resume"):
Expand All @@ -382,6 +376,11 @@ def _resume(self):
else:
self.reading_service._resume()
self._is_paused = False
# TODO: the condition should be `else` once `self._datapipe_iter.resume()` is no longer used
elif self._datapipe_iter is None or not hasattr(self._datapipe_iter, "resume"):
warnings.warn("ReadingService doesn't support resume.")
else:
warnings.warn("ReadingService doesn't support `resume`.")

def _limit(self, num_batches: Optional[int]) -> None:
if hasattr(self.reading_service, "_limit"):
self.reading_service._limit(num_batches)
else:
warnings.warn("ReadingService doesn't support `limit`.")
7 changes: 7 additions & 0 deletions torchdata/dataloader2/reading_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,13 @@ def _resume(self):
if self.main_prefetch_cnt > 0 and self.num_workers > 0:
self._main_prefetch_datapipe.resume() # type: ignore[union-attr]

def _limit(self, num_batches: Optional[int]) -> None:
"""
For this ReadingService, `DataLoader2Iterator` and `DataLoader2` should sufficiently handle
the limit operation, such that nothing needs to be done here.
"""
pass


class DistributedReadingService(ReadingServiceInterface):
r"""
Expand Down

0 comments on commit 6d4d82d

Please sign in to comment.