Implement DistributedReadingService #727

Closed
wants to merge 2 commits
100 changes: 100 additions & 0 deletions test/bin/elastic_training.py
@@ -0,0 +1,100 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import argparse

import torch
import torch.distributed as dist

from torch.utils.data import DataLoader
from torchdata.dataloader2 import DataLoader2, DistributedReadingService
from torchdata.datapipes.iter import IterableWrapper


def _get_dataloader(data_length: int, dl2: bool, shuffle: bool, rs=None):
    data_source = IterableWrapper(list(range(data_length)))

    dp = data_source.sharding_filter()
    if shuffle:
        dp = dp.shuffle()

    if dl2:
        if rs is None:
            rs = DistributedReadingService()
        dl = DataLoader2(dp, reading_service=rs)
    else:
        dp = dp.fullsync()
        dl = DataLoader(dp)

    return dl


def main(backend, dl2):
    dist.init_process_group(backend)
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Use a prime number to ensure uneven data sharding
    data_length = 23

    # No shuffle
    dl = _get_dataloader(data_length, dl2=dl2, shuffle=False)
    res = []
    for d in dl:
        res.append(d)
        # Simulate training synchronization
        dist.barrier()
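    # With sharding_filter, rank r receives elements r, r + world_size, r + 2 * world_size, ...
    # fullsync aligns all ranks on the shortest shard, so each rank yields exactly
    # data_length // world_size elements (e.g. 11 of the 23 when world_size == 2).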
    assert sorted(res) == list(range(rank, data_length // world_size * world_size, world_size))

    # Shuffle
    dl = _get_dataloader(data_length, dl2=dl2, shuffle=True)
    results = []
    for _ in range(2):
        res = []
        torch.manual_seed(123)
        for d in dl:
            res.append(d)
            # Simulate training synchronization
            dist.barrier()
        results.append(res)
    assert results[0] == results[1]

    # Different seed
    res = []
    torch.manual_seed(321)
    for d in dl:
        res.append(d)
        # Simulate training synchronization
        dist.barrier()
    results.append(res)
    assert len(results[0]) == len(results[2])
    assert results[0] != results[2]


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Elastic Training")
    backend_group = parser.add_mutually_exclusive_group(required=True)
    backend_group.add_argument("--gloo", action="store_true", help="GLOO backend")
    backend_group.add_argument("--nccl", action="store_true", help="NCCL backend")
    backend_group.add_argument("--mpi", action="store_true", help="MPI backend")
    dl_group = parser.add_mutually_exclusive_group(required=True)
    dl_group.add_argument("--dl1", action="store_true", help="DataLoader")
    dl_group.add_argument("--dl2", action="store_true", help="DataLoader2")

    args = parser.parse_args()

    backend = "gloo"
    if args.nccl:
        backend = "nccl"
    elif args.mpi:
        backend = "mpi"

    dl2 = True
    if args.dl1:
        dl2 = False

    main(backend, dl2)
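
For orientation, the script above is meant to be launched by Torch Elastic rather than invoked directly; a minimal sketch of such a launch (assumptions: a single node, two processes, the gloo backend, and the DataLoader2 path), mirroring the run.main() calls in test/test_distributed.py below:

from torch.distributed import run

# Roughly the programmatic form of:
#   torchrun --nnodes=1 --nproc_per_node=2 test/bin/elastic_training.py --gloo --dl2
run.main(
    [
        "--run_path",
        "--nnodes=1",
        "--nproc_per_node=2",
        "test/bin/elastic_training.py",
        "--gloo",
        "--dl2",
    ],
)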
140 changes: 130 additions & 10 deletions test/test_distributed.py
Original file line number Diff line number Diff line change
@@ -8,13 +8,16 @@
import os
import unittest

from functools import partial
from unittest import TestCase

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.testing._internal.common_utils import instantiate_parametrized_tests, parametrize
from torch.testing._internal.common_utils import instantiate_parametrized_tests, IS_WINDOWS, parametrize
from torch.utils.data import DataLoader

from torchdata.dataloader2 import DataLoader2, DistributedReadingService
from torchdata.datapipes.iter import IterableWrapper
from torchdata.datapipes.iter.util.prefetch import PrefetchTimeoutError

@@ -24,10 +27,25 @@


if not dist.is_available():
    import sys

    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)


_backends = ["gloo"]
if dist.is_mpi_available():
    _backends.append("mpi")
if dist.is_nccl_available() and torch.cuda.device_count() > 0:
    _backends.append("nccl")

backend_parametrize = parametrize("backend", _backends)


def abs_path(path):
    return os.path.join(os.path.dirname(__file__), os.path.normpath(path))


def launch_distributed_training(backend, world_size, fn):
    os.environ["MASTER_ADDR"] = TEST_MASTER_ADDR
    os.environ["MASTER_PORT"] = TEST_MASTER_PORT
@@ -68,18 +86,120 @@ def _test_fullsync(rank, world_size, backend):
        except Exception as e:
            assert isinstance(e, PrefetchTimeoutError)

    @parametrize(
        "backend",
        ["gloo", "nccl"]
        if torch.cuda.nccl.is_available([])
        else [
            "gloo",
        ],
    )
    @backend_parametrize
    def test_fullsync(self, backend) -> None:
        world_size = DEFAULT_WORLD_SIZE if backend == "gloo" else torch.cuda.device_count()
        world_size = DEFAULT_WORLD_SIZE if backend != "nccl" else torch.cuda.device_count()
        launch_distributed_training(backend, world_size, DistributedTest._test_fullsync)

    @staticmethod
    def _get_dataloader(data_length: int, dl2: bool, shuffle: bool, rs=None):
        data_source = IterableWrapper(list(range(data_length)))

        dp = data_source.sharding_filter()
        if shuffle:
            dp = dp.shuffle()

        if dl2:
            if rs is None:
                rs = DistributedReadingService()
            dl = DataLoader2(dp, reading_service=rs)
        else:
            dp = dp.fullsync()
            dl = DataLoader(dp)

        return dl

    @staticmethod
    def _test_distributed_training(dl2, rank, world_size, backend):
        dist.init_process_group(backend, rank=rank, world_size=world_size)
        # Use a prime number to ensure uneven data sharding
        data_length = 23

        # No shuffle
        dl = DistributedTest._get_dataloader(data_length, dl2=dl2, shuffle=False)
        res = []
        for d in dl:
            res.append(d)
            # Simulate training synchronization
            dist.barrier()
        assert sorted(res) == list(range(rank, data_length // world_size * world_size, world_size))

        # Shuffle
        dl = DistributedTest._get_dataloader(data_length, dl2=dl2, shuffle=True)
        results = []
        for _ in range(2):
            res = []
            torch.manual_seed(123)
            for d in dl:
                res.append(d)
                # Simulate training synchronization
                dist.barrier()
            results.append(res)
        assert results[0] == results[1]

        # Different seed
        res = []
        torch.manual_seed(321)
        for d in dl:
            res.append(d)
            # Simulate training synchronization
            dist.barrier()
        results.append(res)
        assert len(results[0]) == len(results[2])
        assert results[0] != results[2]

    @backend_parametrize
    def test_distributed_dl2(self, backend) -> None:
        world_size = DEFAULT_WORLD_SIZE if backend != "nccl" else torch.cuda.device_count()
        launch_distributed_training(backend, world_size, partial(DistributedTest._test_distributed_training, True))

    @unittest.skipIf(
        IS_WINDOWS,
        "Torch Elastic is not working properly on Windows. See: https://github.com/pytorch/pytorch/issues/85427",
    )
    @backend_parametrize
    def test_elastic_training_dl2(self, backend) -> None:
        world_size = DEFAULT_WORLD_SIZE if backend != "nccl" else torch.cuda.device_count()
        nnodes = 1
        from torch.distributed import run

        run.main(
            [
                "--run_path",
                f"--nnodes={nnodes}",
                f"--nproc_per_node={world_size}",
                abs_path("bin/elastic_training.py"),
                "--" + backend,
                "--dl2",
            ],
        )

    @backend_parametrize
    def test_distributed_dl1(self, backend) -> None:
        world_size = DEFAULT_WORLD_SIZE if backend != "nccl" else torch.cuda.device_count()
        launch_distributed_training(backend, world_size, partial(DistributedTest._test_distributed_training, False))

    @unittest.skipIf(
        IS_WINDOWS,
        "Torch Elastic is not working properly on Windows. See: https://github.com/pytorch/pytorch/issues/85427",
    )
    @backend_parametrize
    def test_elastic_training_dl1(self, backend) -> None:
        world_size = DEFAULT_WORLD_SIZE if backend != "nccl" else torch.cuda.device_count()
        nnodes = 1
        from torch.distributed import run

        run.main(
            [
                "--run_path",
                f"--nnodes={nnodes}",
                f"--nproc_per_node={world_size}",
                abs_path("bin/elastic_training.py"),
                "--" + backend,
                "--dl1",
            ],
        )


instantiate_parametrized_tests(DistributedTest)

2 changes: 2 additions & 0 deletions torchdata/dataloader2/__init__.py
@@ -8,6 +8,7 @@
from .dataloader2 import DataLoader2, DataLoader2Iterator
from .error import PauseIteration
from .reading_service import (
    DistributedReadingService,
    MultiProcessingReadingService,
    PrototypeMultiProcessingReadingService,
    ReadingServiceInterface,
@@ -17,6 +18,7 @@
__all__ = [
    "DataLoader2",
    "DataLoader2Iterator",
    "DistributedReadingService",
    "MultiProcessingReadingService",
    "PauseIteration",
    "PrototypeMultiProcessingReadingService",
76 changes: 75 additions & 1 deletion torchdata/dataloader2/reading_service.py
@@ -8,15 +8,20 @@
import functools
import multiprocessing as mp
import time

from abc import ABC, abstractmethod
from datetime import timedelta
from typing import Any, Callable, List, Optional

import torch
import torch.distributed as dist

from torch.utils.data import DataLoader
from torch.utils.data.graph import DataPipe

from torchdata._constants import default_timeout_in_s
from torchdata.dataloader2 import communication
from torchdata.datapipes.iter import IterableWrapper
from torchdata.datapipes.iter import FullSync, IterableWrapper


class ReadingServiceInterface(ABC):
@@ -235,3 +240,72 @@ def finalize(self) -> None:
        if self.persistent_workers and self.dl_ is not None and self.dl_._iterator is not None:
            self.dl_._iterator._shutdown_workers()  # type: ignore[attr-defined]
            self.dl_._iterator = None


class DistributedReadingService(ReadingServiceInterface):
    r"""
    ``DistributedReadingService`` handles distributed sharding on the graph of ``DataPipe`` and
    guarantees randomness by sharing the same seed across the distributed processes.

    Args:
        timeout: Timeout for operations executed against the process group in seconds.
            Default value equals 30 minutes.
    """

    def __init__(self, timeout: int = default_timeout_in_s):
        if not dist.is_available():
            raise RuntimeError("Torch Distributed is required to be available")
        self._world_size: int = 1
        self._rank: int = 0
        self._datapipe: Optional[DataPipe] = None
        self._timeout: int = timeout
        self._pg: Optional[dist.ProcessGroup] = None

    def initialize(self, datapipe: DataPipe) -> DataPipe:
        r"""
        Launches the ``gloo``-backend distributed process group. Carries out distributed sharding
        on the graph of ``DataPipe`` and returns the graph with a ``FullSyncIterDataPipe``
        attached at the end.
        """
        if not (dist.is_available() and dist.is_initialized()):
            raise RuntimeError("Torch Distributed is required to be initialized")
        self._world_size = dist.get_world_size()
        self._rank = dist.get_rank()
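        # A dedicated gloo group lets the CPU seed tensor in _share_seed() be broadcast
        # even when the default process group uses a different backend (e.g. nccl).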
        self._pg = dist.new_group(backend="gloo", timeout=timedelta(seconds=self._timeout))
        torch.utils.data.graph_settings.apply_sharding(
            datapipe,
            self._world_size,
            self._rank,
        )
        # Only append FullSyncIterDataPipe if it's not already present at the end of the pipeline
        if not isinstance(datapipe, FullSync):
            datapipe = datapipe.fullsync(self._timeout)
        self._datapipe = datapipe
        return datapipe

    def initialize_iteration(self) -> None:
        r"""
        Shares the same seed from rank 0 to the other ranks across the distributed processes
        and applies the random seed to the graph of ``DataPipe``.
        """
        # TODO: Seed Generator should be moved to DataLoader2 after the API
        # change of initialize_iteration is landed.
        seed = self._share_seed()
        _seed_generator = torch.Generator()
        _seed_generator.manual_seed(seed)
        assert self._datapipe is not None
        self._datapipe = torch.utils.data.graph_settings.apply_random_seed(
            self._datapipe,
            _seed_generator,
        )

    def _share_seed(self):
        shared_seed = torch.empty((), dtype=torch.int64).random_()
        dist.broadcast(shared_seed, src=0, group=self._pg)
        return shared_seed.item()

    def finalize(self) -> None:
        r"""
        Cleans up the distributed process group.
        """
        self._pg = None
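
For context, a minimal end-to-end sketch of how the new reading service is used (an illustration only; it assumes the process group environment is provided by a launcher such as torch.distributed.run, as in the tests above):

import torch
import torch.distributed as dist

from torchdata.dataloader2 import DataLoader2, DistributedReadingService
from torchdata.datapipes.iter import IterableWrapper

dist.init_process_group("gloo")

# sharding_filter() marks where DistributedReadingService applies per-rank sharding;
# initialize() also appends fullsync() so every rank yields the same number of elements.
dp = IterableWrapper(list(range(23))).sharding_filter().shuffle()
dl = DataLoader2(dp, reading_service=DistributedReadingService())

for epoch in range(2):
    # The seed set here feeds _share_seed(): rank 0 broadcasts one shared seed per
    # iteration, so the shuffle order stays consistent across ranks.
    torch.manual_seed(123 + epoch)
    for x in dl:
        ...  # training step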
2 changes: 2 additions & 0 deletions torchdata/datapipes/iter/util/prefetch.py
@@ -133,6 +133,8 @@ class FullSyncIterDataPipe(IterDataPipe[T_co]):
    """

    def __init__(self, datapipe: IterDataPipe, timeout=default_timeout_in_s):
        if not dist.is_available():
            raise RuntimeError("Torch Distributed is required to be available")
        self.datapipe = datapipe
        self.timeout = timeout
