diff --git a/test/test_distributed.py b/test/test_distributed.py
index 80fa9f1cc..09ab0fa4f 100644
--- a/test/test_distributed.py
+++ b/test/test_distributed.py
@@ -157,6 +157,13 @@ def _test_fullsync(rank, world_size, backend, q):
     it2 = iter(dp3)  # Reset
     next(it2)
 
+    dp4 = dp.prefetch(2)
+    it = iter(dp4)
+    next(it)
+    dp4.pause()
+    it2 = iter(dp4)  # Reset
+    next(it2)
+
     _finalize_distributed_queue(rank, q)
 
 @world_size_parametrize
diff --git a/torchdata/datapipes/iter/util/prefetcher.py b/torchdata/datapipes/iter/util/prefetcher.py
index eaf1d86ac..e61cde849 100644
--- a/torchdata/datapipes/iter/util/prefetcher.py
+++ b/torchdata/datapipes/iter/util/prefetcher.py
@@ -60,6 +60,7 @@ def __init__(self, source_datapipe, buffer_size: int = 10):
             raise ValueError("'buffer_size' is required to be a positive integer.")
         self.buffer_size = buffer_size
         self.thread: Optional[threading.Thread] = None
+        self.prefetch_data: Optional[_PrefetchData] = None
 
     @staticmethod
     def thread_worker(prefetch_data: _PrefetchData):
@@ -83,6 +84,8 @@ def thread_worker(prefetch_data: _PrefetchData):
             time.sleep(PRODUCER_SLEEP_INTERVAL * 10)
 
     def __iter__(self):
+        if self.thread is not None:
+            self.shutdown()
         try:
             prefetch_data = _PrefetchData(self.source_datapipe, self.buffer_size)
             self.prefetch_data = prefetch_data
@@ -104,9 +107,7 @@ def __iter__(self):
                 else:
                     time.sleep(CONSUMER_SLEEP_INTERVAL)
         finally:
-            prefetch_data.run_prefetcher = False
-            prefetch_data.stop_iteration = True
-            thread.join()
+            self.shutdown()
 
     def __getstate__(self):
         """
@@ -127,12 +128,7 @@ def __setstate__(self, state):
 
     @final
     def reset(self):
-        if self.thread is not None:
-            self.prefetch_data.run_prefetcher = False
-            self.prefetch_data.stop_iteration = True
-            self.prefetch_data.paused = False
-            self.thread.join()
-            self.thread = None
+        self.shutdown()
 
     def pause(self):
         if self.thread is not None:
@@ -145,13 +141,28 @@
 
     @final
     def resume(self):
-        if self.thread is not None and (
-            not self.prefetch_data.stop_iteration or len(self.prefetch_data.prefetch_buffer) > 0
+        if (
+            self.thread is not None
+            and self.prefetch_data is not None
+            and (not self.prefetch_data.stop_iteration or len(self.prefetch_data.prefetch_buffer) > 0)
         ):
-            assert self.prefetch_data is not None
             self.prefetch_data.run_prefetcher = True
             self.prefetch_data.paused = False
 
+    @final
+    def shutdown(self):
+        if self.prefetch_data is not None:
+            self.prefetch_data.run_prefetcher = False
+            self.prefetch_data.stop_iteration = True
+            self.prefetch_data.paused = False
+            self.prefetch_data = None
+        if self.thread is not None:
+            self.thread.join()
+            self.thread = None
+
+    def __del__(self):
+        self.shutdown()
+
     def __len__(self) -> int:
         if isinstance(self.source_datapipe, Sized):
             return len(self.source_datapipe)
@@ -185,6 +196,7 @@ def __init__(self, source_datapipe, device=None, pin_memory_fn=pin_memory_fn):
             device = torch.cuda.current_device()
         self.device = device
        self.pin_memory_fn = pin_memory_fn
+        self.prefetch_data: Optional[_PrefetchData] = None
 
     def is_replicable(self) -> bool:
         return False
@@ -210,6 +222,8 @@ def thread_worker(prefetch_data: _PrefetchData, pin_memory_fn, device):  # type:
             time.sleep(PRODUCER_SLEEP_INTERVAL * 10)
 
     def __iter__(self):
+        if self.thread is not None:
+            self.shutdown()
         try:
             prefetch_data = _PrefetchData(self.source_datapipe, self.buffer_size)
             self.prefetch_data = prefetch_data
@@ -235,9 +249,7 @@ def __iter__(self):
                 else:
                     time.sleep(CONSUMER_SLEEP_INTERVAL)
         finally:
-            prefetch_data.run_prefetcher = False
-            prefetch_data.stop_iteration = True
-            thread.join()
+            self.shutdown()
 
     def __getstate__(self):
         state = super().__getstate__()
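
For context, the following is a minimal usage sketch (not part of the patch) of the lifecycle that the new `_test_fullsync` block exercises and that the consolidated `shutdown()` method is meant to make safe. It assumes torchdata's `IterableWrapper` together with the `.prefetch()` and `.pause()` APIs referenced above; the pipeline and values are illustrative only.

from torchdata.datapipes.iter import IterableWrapper

# Build a prefetching pipeline; .prefetch(2) wraps the pipe in PrefetcherIterDataPipe
# with buffer_size=2, which owns the background worker thread patched above.
dp = IterableWrapper(range(100)).prefetch(2)

it = iter(dp)     # starting iteration spawns the prefetch thread
first = next(it)

dp.pause()        # the worker stops filling the buffer but the thread stays alive

# Creating a new iterator resets the datapipe. With this patch, reset() delegates to
# shutdown(), which clears prefetch_data and joins the paused thread before __iter__
# starts a fresh one, so re-iterating no longer risks hanging while paused.
it2 = iter(dp)
second = next(it2)

del dp            # __del__ also calls shutdown(), joining any live worker thread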