Skip to content

Commit

Permalink
Implement pause/resume for FullSync (#1130)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #1130

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D45015591

Pulled By: NivekT

fbshipit-source-id: 6ff38037d18d35aec8bc0a33727fc78f6746256f
  • Loading branch information
NivekT authored and facebook-github-bot committed May 1, 2023
1 parent 85d54a5 commit 3e7eaae
Showing 1 changed file with 20 additions and 5 deletions.
25 changes: 20 additions & 5 deletions torchdata/datapipes/iter/util/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

import threading

import time

from collections import deque
from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError
from dataclasses import dataclass
Expand All @@ -18,6 +20,8 @@
from torchdata._constants import default_timeout_in_s
from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe
from torchdata.datapipes.iter.util.prefetcher import PRODUCER_SLEEP_INTERVAL


T_co = TypeVar("T_co", covariant=True)

Expand Down Expand Up @@ -66,6 +70,7 @@ def __init__(
self._futures: Deque[Future] = deque()
self._lock = threading.RLock()
self._end_flag = False
self._paused = False
self._idx = 0
for _ in range(prefetch_size):
with self._lock:
Expand All @@ -78,7 +83,11 @@ def __init__(
self._idx += 1

def fetch_next(self):
return next(self.datapipe_iterator)
while self._paused:
time.sleep(PRODUCER_SLEEP_INTERVAL * 10)

res = next(self.datapipe_iterator)
return res

def _done_callback_fn(self, index: int, f: Future):
if f.exception():
Expand Down Expand Up @@ -107,6 +116,12 @@ def return_next(self):
def shutdown(self):
self._executor.shutdown(wait=True)

def pause(self):
self._paused = True

def resume(self):
self._paused = False


@functional_datapipe("fullsync")
class FullSyncIterDataPipe(IterDataPipe[T_co]):
Expand Down Expand Up @@ -237,10 +252,10 @@ def __setstate__(self, state):

@final
def pause(self):
if self._world_size > 1 and self._executor is not None:
raise RuntimeError("`pause` is not supported for FullSync at the moment.")
if self._executor is not None:
self._executor.pause()

@final
def resume(self):
if self._world_size > 1 and self._executor is not None:
raise RuntimeError("`resume` is not supported for FullSync at the moment.")
if self._executor is not None:
self._executor.resume()

0 comments on commit 3e7eaae

Please sign in to comment.