Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reland: [DataLoader2] Removing delegation for 'pause', 'limit', and 'resume' #1067

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 11 additions & 12 deletions torchdata/dataloader2/dataloader2.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,6 @@ def resume(self) -> None:
Restarts the threads within ``DataLoader2`` and allows it to yield additional batches.
"""
self.dataloader._resume()
if self.dataloader._datapipe_iter and hasattr(self.dataloader._datapipe_iter, "resume"):
self.dataloader._datapipe_iter.resume() # type: ignore[attr-defined]

def limit(self, num_batches: Optional[int]) -> None:
"""
Expand All @@ -112,8 +110,7 @@ def limit(self, num_batches: Optional[int]) -> None:
"""
self.limit_counter = 0
self.limit_threshold = num_batches
if self.dataloader._datapipe_iter and hasattr(self.dataloader._datapipe_iter, "limit"):
self.dataloader._datapipe_iter.limit(num_batches) # type: ignore[attr-defined]
self.dataloader._limit(num_batches)

def __getattr__(self, name):
"""
Expand Down Expand Up @@ -369,11 +366,8 @@ def _pause(self):
if hasattr(self.reading_service, "_pause"):
self._is_paused = True
self.reading_service._pause()
# TODO: the condition should be `else` once `self._datapipe_iter.pause/limit()` is no longer used
elif self._datapipe_iter is None or not (
hasattr(self._datapipe_iter, "limit") or hasattr(self._datapipe_iter, "pause")
):
warnings.warn("ReadingService doesn't support pause.")
else:
warnings.warn("ReadingService doesn't support `pause`.")

def _resume(self):
if hasattr(self.reading_service, "_resume"):
Expand All @@ -382,6 +376,11 @@ def _resume(self):
else:
self.reading_service._resume()
self._is_paused = False
# TODO: the condition should be `else` once `self._datapipe_iter.resume()` is no longer used
elif self._datapipe_iter is None or not hasattr(self._datapipe_iter, "resume"):
warnings.warn("ReadingService doesn't support resume.")
else:
warnings.warn("ReadingService doesn't support `resume`.")

def _limit(self, num_batches: Optional[int]) -> None:
if hasattr(self.reading_service, "_limit"):
self.reading_service._limit(num_batches)
else:
warnings.warn("ReadingService doesn't support `limit`.")
7 changes: 7 additions & 0 deletions torchdata/dataloader2/reading_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,13 @@ def _resume(self):
if self.main_prefetch_cnt > 0 and self.num_workers > 0:
self._main_prefetch_datapipe.resume() # type: ignore[union-attr]

def _limit(self, num_batches: Optional[int]) -> None:
"""
For this ReadingService, `DataLoader2Iterator` and `DataLoader2` should sufficiently handle
the limit operation, such that nothing needs to be done here.
"""
pass


class DistributedReadingService(ReadingServiceInterface):
r"""
Expand Down