From 56bc88af9dce01ede519b85cc0dd24156a589d99 Mon Sep 17 00:00:00 2001 From: Kevin Tse Date: Tue, 23 May 2023 18:15:14 -0400 Subject: [PATCH] [DataPipe] Ensures Prefetcher shuts down properly ghstack-source-id: cb24f3c6ad0c01c7fbe7bf31ad532db9a0bae27e Pull Request resolved: https://github.com/pytorch/data/pull/1166 --- test/test_distributed.py | 7 +++ torchdata/datapipes/iter/util/prefetcher.py | 47 ++++++++++++++------- 2 files changed, 39 insertions(+), 15 deletions(-) diff --git a/test/test_distributed.py b/test/test_distributed.py index 80fa9f1cc..09ab0fa4f 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -157,6 +157,13 @@ def _test_fullsync(rank, world_size, backend, q): it2 = iter(dp3) # Reset next(it2) + dp4 = dp.prefetch(2) + it = iter(dp4) + next(it) + dp4.pause() + it2 = iter(dp4) # Reset + next(it2) + _finalize_distributed_queue(rank, q) @world_size_parametrize diff --git a/torchdata/datapipes/iter/util/prefetcher.py b/torchdata/datapipes/iter/util/prefetcher.py index eaf1d86ac..c12e5444a 100644 --- a/torchdata/datapipes/iter/util/prefetcher.py +++ b/torchdata/datapipes/iter/util/prefetcher.py @@ -60,6 +60,7 @@ def __init__(self, source_datapipe, buffer_size: int = 10): raise ValueError("'buffer_size' is required to be a positive integer.") self.buffer_size = buffer_size self.thread: Optional[threading.Thread] = None + self.prefetch_data: Optional[_PrefetchData] = None @staticmethod def thread_worker(prefetch_data: _PrefetchData): @@ -104,9 +105,12 @@ def __iter__(self): else: time.sleep(CONSUMER_SLEEP_INTERVAL) finally: - prefetch_data.run_prefetcher = False - prefetch_data.stop_iteration = True - thread.join() + if "prefetch_data" in locals(): + prefetch_data.run_prefetcher = False + prefetch_data.stop_iteration = True + prefetch_data.paused = False + if "thread" in locals(): + thread.join() def __getstate__(self): """ @@ -127,12 +131,7 @@ def __setstate__(self, state): @final def reset(self): - if self.thread is not None: - self.prefetch_data.run_prefetcher = False - self.prefetch_data.stop_iteration = True - self.prefetch_data.paused = False - self.thread.join() - self.thread = None + self.shutdown() def pause(self): if self.thread is not None: @@ -145,13 +144,28 @@ def pause(self): @final def resume(self): - if self.thread is not None and ( - not self.prefetch_data.stop_iteration or len(self.prefetch_data.prefetch_buffer) > 0 + if ( + self.thread is not None + and self.prefetch_data is not None + and (not self.prefetch_data.stop_iteration or len(self.prefetch_data.prefetch_buffer) > 0) ): - assert self.prefetch_data is not None self.prefetch_data.run_prefetcher = True self.prefetch_data.paused = False + @final + def shutdown(self): + if hasattr(self, "prefetch_data") and self.prefetch_data is not None: + self.prefetch_data.run_prefetcher = False + self.prefetch_data.stop_iteration = True + self.prefetch_data.paused = False + self.prefetch_data = None + if hasattr(self, "thread") and self.thread is not None: + self.thread.join() + self.thread = None + + def __del__(self): + self.shutdown() + def __len__(self) -> int: if isinstance(self.source_datapipe, Sized): return len(self.source_datapipe) @@ -235,9 +249,12 @@ def __iter__(self): else: time.sleep(CONSUMER_SLEEP_INTERVAL) finally: - prefetch_data.run_prefetcher = False - prefetch_data.stop_iteration = True - thread.join() + if "prefetch_data" in locals(): + prefetch_data.run_prefetcher = False + prefetch_data.stop_iteration = True + prefetch_data.paused = False + if "thread" in locals(): + thread.join() def __getstate__(self): state = super().__getstate__()