Skip to content

Commit

Permalink
[DataPipe] Ensures Prefetcher shuts down properly
Browse files Browse the repository at this point in the history
ghstack-source-id: cb24f3c6ad0c01c7fbe7bf31ad532db9a0bae27e
Pull Request resolved: #1166
  • Loading branch information
NivekT committed May 23, 2023
1 parent ba31745 commit 56bc88a
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 15 deletions.
7 changes: 7 additions & 0 deletions test/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,13 @@ def _test_fullsync(rank, world_size, backend, q):
it2 = iter(dp3) # Reset
next(it2)

dp4 = dp.prefetch(2)
it = iter(dp4)
next(it)
dp4.pause()
it2 = iter(dp4) # Reset
next(it2)

_finalize_distributed_queue(rank, q)

@world_size_parametrize
Expand Down
47 changes: 32 additions & 15 deletions torchdata/datapipes/iter/util/prefetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def __init__(self, source_datapipe, buffer_size: int = 10):
raise ValueError("'buffer_size' is required to be a positive integer.")
self.buffer_size = buffer_size
self.thread: Optional[threading.Thread] = None
self.prefetch_data: Optional[_PrefetchData] = None

@staticmethod
def thread_worker(prefetch_data: _PrefetchData):
Expand Down Expand Up @@ -104,9 +105,12 @@ def __iter__(self):
else:
time.sleep(CONSUMER_SLEEP_INTERVAL)
finally:
prefetch_data.run_prefetcher = False
prefetch_data.stop_iteration = True
thread.join()
if "prefetch_data" in locals():
prefetch_data.run_prefetcher = False
prefetch_data.stop_iteration = True
prefetch_data.paused = False
if "thread" in locals():
thread.join()

def __getstate__(self):
"""
Expand All @@ -127,12 +131,7 @@ def __setstate__(self, state):

@final
def reset(self):
if self.thread is not None:
self.prefetch_data.run_prefetcher = False
self.prefetch_data.stop_iteration = True
self.prefetch_data.paused = False
self.thread.join()
self.thread = None
self.shutdown()

def pause(self):
if self.thread is not None:
Expand All @@ -145,13 +144,28 @@ def pause(self):

@final
def resume(self):
if self.thread is not None and (
not self.prefetch_data.stop_iteration or len(self.prefetch_data.prefetch_buffer) > 0
if (
self.thread is not None
and self.prefetch_data is not None
and (not self.prefetch_data.stop_iteration or len(self.prefetch_data.prefetch_buffer) > 0)
):
assert self.prefetch_data is not None
self.prefetch_data.run_prefetcher = True
self.prefetch_data.paused = False

@final
def shutdown(self):
    """
    Ask the prefetch worker to stop, wait for its thread, and drop both handles.

    The method is idempotent and tolerates an instance on which ``prefetch_data``
    or ``thread`` was never assigned (it is also invoked from ``__del__``), so
    both attributes are looked up defensively.
    """
    prefetch_data = getattr(self, "prefetch_data", None)
    if prefetch_data is not None:
        # Clear the control flags so the worker is asked to stop and is no
        # longer marked as paused.
        prefetch_data.run_prefetcher = False
        prefetch_data.stop_iteration = True
        prefetch_data.paused = False
        self.prefetch_data = None

    thread = getattr(self, "thread", None)
    if thread is not None:
        # Drop the handle first so a failing join cannot leave a stale thread
        # reference behind for the next shutdown()/reset() call.
        self.thread = None
        # Thread.join() raises RuntimeError when called from the thread being
        # joined; that can happen when a finalizer triggers shutdown() while
        # the worker itself is running, so skip the join in that case.
        if thread is not threading.current_thread():
            thread.join()

# Finalizer hook: delegates to self.shutdown() when the object is being destroyed.
def __del__(self):
self.shutdown()

def __len__(self) -> int:
if isinstance(self.source_datapipe, Sized):
return len(self.source_datapipe)
Expand Down Expand Up @@ -235,9 +249,12 @@ def __iter__(self):
else:
time.sleep(CONSUMER_SLEEP_INTERVAL)
finally:
prefetch_data.run_prefetcher = False
prefetch_data.stop_iteration = True
thread.join()
if "prefetch_data" in locals():
prefetch_data.run_prefetcher = False
prefetch_data.stop_iteration = True
prefetch_data.paused = False
if "thread" in locals():
thread.join()

def __getstate__(self):
state = super().__getstate__()
Expand Down

0 comments on commit 56bc88a

Please sign in to comment.