Add DistributedReadingService
ejguan committed Aug 15, 2022
1 parent 827b13d commit 0f715fc
Showing 4 changed files with 124 additions and 0 deletions.
37 changes: 37 additions & 0 deletions test/test_distributed.py
@@ -15,6 +15,7 @@
import torch.multiprocessing as mp
from torch.testing._internal.common_utils import instantiate_parametrized_tests, parametrize

from torchdata.dataloader2 import DataLoader2, DistributedReadingService
from torchdata.datapipes.iter import IterableWrapper
from torchdata.datapipes.iter.util.prefetch import PrefetchTimeoutError

@@ -80,6 +81,42 @@ def test_fullsync(self, backend) -> None:
world_size = DEFAULT_WORLD_SIZE if backend == "gloo" else torch.cuda.device_count()
launch_distributed_training(backend, world_size, DistributedTest._test_fullsync)

@staticmethod
def _test_distributed_rs(rank, world_size, backend):
dist.init_process_group(backend, rank=rank, world_size=world_size)
# Use a prime number to ensure the data is sharded unevenly across ranks
data_length = 23
dp = IterableWrapper(list(range(data_length))).shuffle().sharding_filter()
dl = DataLoader2(dp, reading_service=DistributedReadingService())

for _ in range(2):
res = []
for d in dl:
res.append(d)
# Simulate training synchronization
dist.barrier()
assert res == list(range(rank, data_length // world_size * world_size, world_size))

# Timeout Test
dl = DataLoader2(dp, reading_service=DistributedReadingService(timeout=0.01))
try:
for _ in range(2):
_ = list(dl)
except Exception as e:
assert isinstance(e, PrefetchTimeoutError)

@parametrize(
"backend",
["gloo", "nccl"]
if torch.cuda.nccl.is_available([])
else [
"gloo",
],
)
def test_distributed_reading_service(self, backend) -> None:
world_size = DEFAULT_WORLD_SIZE if backend == "gloo" else torch.cuda.device_count()
launch_distributed_training(backend, world_size, DistributedTest._test_distributed_rs)


instantiate_parametrized_tests(DistributedTest)

1 change: 1 addition & 0 deletions torchdata/_constants.py
@@ -6,3 +6,4 @@

# Use the same timeout as PyTorch Distributed
default_timeout_in_s = 30 * 60
default_check_interval = 0.01
2 changes: 2 additions & 0 deletions torchdata/dataloader2/__init__.py
@@ -8,6 +8,7 @@
from .dataloader2 import DataLoader2, DataLoader2Iterator
from .error import PauseIteration
from .reading_service import (
DistributedReadingService,
MultiProcessingReadingService,
PrototypeMultiProcessingReadingService,
ReadingServiceInterface,
@@ -17,6 +18,7 @@
__all__ = [
"DataLoader2",
"DataLoader2Iterator",
"DistributedReadingService",
"MultiProcessingReadingService",
"PauseIteration",
"PrototypeMultiProcessingReadingService",
84 changes: 84 additions & 0 deletions torchdata/dataloader2/reading_service.py
@@ -8,13 +8,18 @@
import functools
import multiprocessing as mp
import time

from abc import ABC, abstractmethod
from datetime import timedelta
from typing import Any, Callable, List, Optional

import torch
import torch.distributed as dist

from torch.utils.data import DataLoader
from torch.utils.data.graph import DataPipe

from torchdata._constants import default_check_interval, default_timeout_in_s
from torchdata.dataloader2 import communication
from torchdata.datapipes.iter import IterableWrapper

@@ -235,3 +240,82 @@ def finalize(self) -> None:
if self.persistent_workers and self.dl_ is not None and self.dl_._iterator is not None:
self.dl_._iterator._shutdown_workers() # type: ignore[attr-defined]
self.dl_._iterator = None


SHARED_SEED = "_dl_shared_seed"
SHARED_SEED_COUNTER = "_dl_shared_seed_recv_cnt"


class DistributedReadingService(ReadingServiceInterface):
def __init__(self, timeout: int = default_timeout_in_s, check_interval: float = default_check_interval):
if not (dist.is_available() and dist.is_initialized()):
raise RuntimeError("Torch Distributed is required to be initialized")
self._world_size: int = dist.get_world_size()
self._rank: int = dist.get_rank()
self._seed_generator = torch.Generator()
self._datapipe: Optional[DataPipe] = None
self._timeout: int = timeout
self._check_interval: float = check_interval

def initialize(self, datapipe: DataPipe) -> DataPipe:
torch.utils.data.graph_settings.apply_sharding(
datapipe,
self._world_size,
self._rank,
)
datapipe = datapipe.fullsync(self._timeout)
self._datapipe = datapipe
return datapipe

def initialize_iteration(self) -> None:
seed = self._share_seed()
self._seed_generator.manual_seed(seed)
assert self._datapipe is not None
self._datapipe = torch.utils.data.graph_settings.apply_shuffle_seed(
self._datapipe,
self._seed_generator,
)

def _share_seed(self):
_sd = torch.empty((), dtype=torch.int64).random_().item()
store = dist.distributed_c10d._get_default_store()
if self._rank == 0:
_sd_str = str(_sd)
store.set(SHARED_SEED, _sd_str)
_sd_recv_cnt = store.add(SHARED_SEED_COUNTER, 1)
start = time.time()
while _sd_recv_cnt < self._world_size:
time.sleep(self._check_interval)
_sd_recv_cnt = store.add(SHARED_SEED_COUNTER, 0)
if timedelta(seconds=(time.time() - start)) > timedelta(seconds=self._timeout):
raise RuntimeError(
"Timed out receiving the signal from the "
"distribtued store on Rank 0 that all other "
"Ranks have received the shared seed. "
f"(world_size={self._world_size}, "
f"received={_sd_recv_cnt}, "
f"timeout={self._timeout})"
)
store.set(SHARED_SEED, "")
_sd_recv_cnt = store.add(SHARED_SEED_COUNTER, -self._world_size)
assert _sd_recv_cnt == 0
else:
_sd_str = ""
start = time.time()
while len(_sd_str) == 0:
time.sleep(self._check_interval)
_sd_str = store.get(SHARED_SEED)
if timedelta(seconds=(time.time() - start)) > timedelta(seconds=self._timeout):
raise RuntimeError(
"Timed out receiving the shared seed from the "
f"distribtued store on Rank {self._rank}. "
f"(world_size={self._world_size}, "
f"timeout={self._timeout})"
)
_sd_recv_cnt = store.add(SHARED_SEED_COUNTER, 1)
while _sd_recv_cnt > 0:
time.sleep(self._check_interval)
_sd_recv_cnt = store.add(SHARED_SEED_COUNTER, 0)
_sd = int(_sd_str)

return _sd

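For context, a minimal sketch of how the new reading service is intended to be used in a distributed training script, mirroring the test added above. The gloo backend, the two-process spawn scaffolding, the MASTER_ADDR/MASTER_PORT values, and the run_worker helper are illustrative assumptions, not part of this commit:

# Usage sketch (illustrative only): each spawned process builds the same
# DataPipe graph and hands it to DataLoader2 together with a
# DistributedReadingService, which applies sharding, appends fullsync,
# and shares the shuffle seed from rank 0 at the start of every epoch.
import os

import torch.distributed as dist
import torch.multiprocessing as mp

from torchdata.dataloader2 import DataLoader2, DistributedReadingService
from torchdata.datapipes.iter import IterableWrapper


def run_worker(rank, world_size):
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Shuffle before sharding_filter so every rank iterates the same
    # permutation and keeps a disjoint slice of it.
    dp = IterableWrapper(list(range(1000))).shuffle().sharding_filter()
    dl = DataLoader2(dp, reading_service=DistributedReadingService())

    for epoch in range(2):
        for batch in dl:
            pass  # training step goes here

    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(run_worker, args=(world_size,), nprocs=world_size)

Because rank 0 publishes the shared seed through the default c10d store at the start of each iteration, every rank shuffles into the same order before sharding, and the fullsync step appended by initialize() keeps ranks in lockstep when the data splits unevenly.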