diff --git a/torchdata/dataloader2/reading_service.py b/torchdata/dataloader2/reading_service.py index f9c2baa85..0cfcb367c 100644 --- a/torchdata/dataloader2/reading_service.py +++ b/torchdata/dataloader2/reading_service.py @@ -204,6 +204,8 @@ def _resume(self): Resumes DataPipes' activities. This is required to be called after `_pause` before the DataLoader can keep yielding elements. """ + assert self._end_datapipe is not None + dp_list = list_dps(traverse_dps(self._end_datapipe)) # Reversed order for dp in dp_list[::-1]: @@ -435,6 +437,7 @@ def _pause(self): in order to collect state. """ assert self._end_datapipe is not None + dp_list = list_dps(traverse_dps(self._end_datapipe)) for dp in dp_list: # TODO: Combine QueueWrapper and _IterateQueueDataPipes, @@ -451,6 +454,8 @@ def _resume(self): Resumes DataPipes' activities. This is required to be called after `_pause` before the DataLoader can keep yielding elements. """ + assert self._end_datapipe is not None + self._worker_consumer_datapipe.request_resume() # type: ignore[union-attr] dp_list = list_dps(traverse_dps(self._end_datapipe)) # Reversed order