Unblock ProtoMPRS to control determinism of DataPipe in single/multi-processing and dist/non-dist env (pytorch#827)

Summary:
This PR temporarily extends `PrototypeMultiProcessingReadingService` to fully control the determinism of the pipeline under the following combinations:
- Single/Multi-processing
- Distributed/Non-distributed
Once `SequentialReadingService` is ready to combine `DistributedReadingService` and `PrototypeMultiProcessingReadingService`, some of this code should be removed. And, for the in-process reading service, we still need a method to isolate the global RNGs so that the data pipeline does not interfere with the model's randomness.
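
As an illustration of the kind of isolation meant here, the sketch below saves and restores the global RNG states around one pass over a datapipe, so that iterating the data pipeline does not advance the RNG streams the model depends on. The helper name `_isolated_iteration` and the seeding scheme are hypothetical; this is not the code added by this PR.

```python
import random

import numpy as np
import torch


def _isolated_iteration(datapipe, seed: int = 0):
    # Snapshot the global RNG states before touching them.
    py_state = random.getstate()
    np_state = np.random.get_state()
    with torch.random.fork_rng():  # torch RNG state is restored on exit
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        data = list(datapipe)
    # Restore Python and NumPy RNG states for the training loop.
    random.setstate(py_state)
    np.random.set_state(np_state)
    return data
```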

In the multiprocessing case, it sets the same random seed for `Shuffler` in every worker and different deterministic seeds for the global RNGs (`python.random`, `torch` and `numpy`) within each subprocess.
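
A minimal sketch of that seeding policy is shown below, assuming a hypothetical `_seed_worker` hook that each subprocess would call; the actual `PrototypeMultiProcessingReadingService` wiring differs, but the idea is the same: the shared seed drives the `Shuffler`, while a worker-specific seed drives the global RNGs.

```python
import random

import numpy as np
import torch


def _seed_worker(shared_seed: int, worker_id: int) -> None:
    # Every worker derives a distinct, deterministic seed from the shared one,
    # so subprocess-local RNGs differ across workers but stay reproducible.
    worker_seed = (shared_seed + worker_id + 1) % (2 ** 32)
    random.seed(worker_seed)
    np.random.seed(worker_seed)
    torch.manual_seed(worker_seed)
    # The Shuffler itself is seeded with `shared_seed` (unchanged) in every
    # worker, so the pre-sharding shuffle order is identical across workers.
```
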
In the distributed case, it shares the same random seed for `Shuffler` across all distributed processes to guarantee the same shuffle order before sharding.
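
Conceptually, sharing one shuffle seed across ranks can be done with a broadcast from rank 0, as in the sketch below; this only illustrates the mechanism described above and is not the code path introduced by this PR.

```python
import torch
import torch.distributed as dist


def _shared_shuffle_seed() -> int:
    # Rank 0 draws a seed; every rank receives the same value, so the
    # pre-sharding shuffle order matches across the process group.
    seed = torch.empty((), dtype=torch.int64)
    if dist.get_rank() == 0:
        seed.random_(0, 2 ** 31)
    dist.broadcast(seed, src=0)
    return int(seed.item())
```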

Test Plan:
All tests are executed in the combinations of the environments above:
- [x] Validate the same seed will generate the same order of data
- [x] Validate different seeds will generate different order of data
- [x] Validate that the data after shuffling and sharding in each worker is mutually exclusive and collectively exhaustive, with and without a manual seed (see the sketch after this list)

There is one missing test that I will add tomorrow:
- [x] Validate that subprocess-local RNGs (`random`, `torch` and `numpy`) are properly set with different seeds.
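
For reference, the "mutually exclusive and collectively exhaustive" check boils down to the assertion sketched below (`per_worker_results` is a hypothetical list of the items each worker/rank produced):

```python
def assert_mutually_exclusive_collectively_exhaustive(per_worker_results, data_length):
    pooled = [d for worker_res in per_worker_results for d in worker_res]
    # No element is duplicated across workers and none is dropped.
    assert sorted(pooled) == list(range(data_length))
```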

Pull Request resolved: pytorch#827

Reviewed By: VitalyFedyunin, NivekT

Differential Revision: D40323946

Pulled By: ejguan

fbshipit-source-id: 2997d6d5dce87a6c38d5ebdf64a00f9769bb18fa
ejguan committed Oct 23, 2022
1 parent 940bc40 commit 7fb44ee
Showing 5 changed files with 350 additions and 74 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
@@ -53,7 +53,7 @@ jobs:
          pip3 install --pre torch -f "${{ steps.pytorch_channel.outputs.value }}"
      - name: Install dependencies
        run: |
          pip3 install requests mypy==0.812 graphviz
          pip3 install requests mypy==0.812 graphviz numpy
      - name: Build TorchData
        run: |
          python setup.py develop
85 changes: 85 additions & 0 deletions test/dataloader2/test_random.py
@@ -0,0 +1,85 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import random
import unittest

from unittest import TestCase

import numpy as np

import torch

from torch.testing._internal.common_utils import instantiate_parametrized_tests, IS_WINDOWS, parametrize
from torchdata.dataloader2 import DataLoader2, DistributedReadingService, PrototypeMultiProcessingReadingService
from torchdata.datapipes.iter import IterableWrapper


def _random_fn(data):
    r"""
    Used to validate that subprocess-local RNGs are set deterministically.
    """
    py_random_num = random.randint(0, 2 ** 32)
    np_random_num = np.random.randint(0, 2 ** 32)
    torch_random_num = torch.randint(0, 2 ** 32, size=[]).item()
    return (data, py_random_num, np_random_num, torch_random_num)


class DeterminismTest(TestCase):
    @parametrize("num_workers", [0, 8])
    def test_proto_rs_determinism(self, num_workers):
        data_length = 64
        exp = list(range(data_length))

        data_source = IterableWrapper(exp)
        dp = data_source.shuffle().sharding_filter().map(_random_fn)
        rs = PrototypeMultiProcessingReadingService(num_workers=num_workers)
        dl = DataLoader2(dp, reading_service=rs)

        # No seed
        res = []
        for d, *_ in dl:
            res.append(d)
        self.assertEqual(sorted(res), exp)

        # Shuffle with seed
        results = []
        for _ in range(2):
            res = []
            ran_res = []
            torch.manual_seed(123)
            random.seed(123)
            np.random.seed(123)
            for d, *ran_nums in dl:
                res.append(d)
                ran_res.append(ran_nums)
            self.assertEqual(sorted(res), exp)
            results.append((res, ran_res))
        # Same seed generates the same order of data and the same random state
        self.assertEqual(results[0], results[1])

        # Different seed
        res = []
        ran_res = []
        torch.manual_seed(321)
        random.seed(321)
        np.random.seed(321)
        for d, *ran_nums in dl:
            res.append(d)
            ran_res.append(ran_nums)
        self.assertEqual(sorted(res), exp)
        # Different shuffle order
        self.assertNotEqual(results[0][0], res)
        # Different subprocess-local random state
        self.assertNotEqual(results[0][1], ran_res)


instantiate_parametrized_tests(DeterminismTest)


if __name__ == "__main__":
unittest.main()
193 changes: 153 additions & 40 deletions test/test_distributed.py
@@ -6,19 +6,24 @@


import os
import queue
import random
import socket
import sys
import unittest

from functools import partial
from unittest import TestCase

import numpy as np

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.testing._internal.common_utils import instantiate_parametrized_tests, IS_WINDOWS, parametrize
from torch.utils.data import DataLoader

from torchdata.dataloader2 import DataLoader2, DistributedReadingService
from torchdata.dataloader2 import DataLoader2, DistributedReadingService, PrototypeMultiProcessingReadingService
from torchdata.datapipes.iter import IterableWrapper
from torchdata.datapipes.iter.util.prefetch import PrefetchTimeoutError

@@ -45,32 +50,123 @@ def abs_path(path):


def _get_open_port():
import socket

s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind(("", 0))
port = s.getsockname()[1]
s.close()
return str(port)


def launch_distributed_training(backend, world_size, fn):
class TerminateSignal:
pass


# TODO(ejguan): Use queue for all distributed tests
def launch_distributed_training(backend, world_size, *args, fn):
os.environ["MASTER_ADDR"] = TEST_MASTER_ADDR
os.environ["MASTER_PORT"] = _get_open_port()
mp.spawn(
fn,
args=(
world_size,
backend,
),
nprocs=world_size,
join=True,
)
ctx = mp.get_context("spawn")
q = ctx.Queue()
ps = []
for rank in range(world_size):
p = ctx.Process(
target=fn,
args=(
rank,
world_size,
backend,
q,
*args,
),
)
p.start()
ps.append(p)
res = []
while True:
try:
d = q.get()
if isinstance(d, TerminateSignal):
break
res.append(d)
except queue.Empty:
continue
for p in ps:
p.join()
return res


def _dist_iterate_one_epoch(dl, seed=None):
    r"""
    Iterate a full epoch of DataLoader and set seeds for global RNGs if provided.
    """
    if seed is not None:
        torch.manual_seed(seed)
        random.seed(seed)
        np.random.seed(seed)
    res = []
    for d in dl:
        res.append(d)
        # Simulate training synchronization
        dist.barrier()
    return res


def _finalize_distributed_queue(rank, q):
    r"""
    Synchronize all distributed processes to guarantee all data have been put into
    the Multiprocessing Queue.
    """
    pg = dist.new_group(backend="gloo")
    end_tensor = torch.tensor([rank], dtype=torch.int64)
    dist.all_reduce(end_tensor, group=pg)
    if rank == 0:
        q.put(TerminateSignal())

    dist.destroy_process_group(pg)


def _random_fn(data):
    r"""
    Used to validate that subprocess-local RNGs are set deterministically.
    """
    py_random_num = random.randint(0, 2 ** 32)
    np_random_num = np.random.randint(0, 2 ** 32)
    torch_random_num = torch.randint(0, 2 ** 32, size=[]).item()
    return (data, py_random_num, np_random_num, torch_random_num)


def _test_proto_distributed_training(rank, world_size, backend, q, num_workers):
    dist.init_process_group(backend, rank=rank, world_size=world_size)
    # Balanced data
    data_length = world_size * 8
    if num_workers > 0:
        data_length *= num_workers

    data_source = IterableWrapper(list(range(data_length)))
    dp = data_source.shuffle().sharding_filter().map(_random_fn)
    rs = PrototypeMultiProcessingReadingService(num_workers=num_workers)
    dl = DataLoader2(dp, reading_service=rs)

    # No seed
    res = _dist_iterate_one_epoch(dl, seed=None)
    q.put((0, rank, res))

    # Shuffle with seed
    for epoch in range(2):
        res = _dist_iterate_one_epoch(dl, seed=123)
        q.put((epoch + 1, rank, res))

    # Different seed
    res = _dist_iterate_one_epoch(dl, seed=321)
    q.put((3, rank, res))

    _finalize_distributed_queue(rank, q)
    dl.shutdown()


class DistributedTest(TestCase):
@staticmethod
def _test_fullsync(rank, world_size, backend):
def _test_fullsync(rank, world_size, backend, q):
dist.init_process_group(backend, rank=rank, world_size=world_size)
# Use a prime number to make sure uneven data sharding
data_length = 23
@@ -79,11 +175,7 @@ def _test_fullsync(rank, world_size, backend):

dp1 = dp.fullsync()
for _ in range(2):
res = []
for d in dp1:
res.append(d)
# Simulate training synchronization
dist.barrier()
res = _dist_iterate_one_epoch(dp1)
assert res == list(range(rank, data_length // world_size * world_size, world_size))

# Timeout Test
@@ -94,10 +186,12 @@ def _test_fullsync(rank, world_size, backend):
except Exception as e:
assert isinstance(e, PrefetchTimeoutError)

_finalize_distributed_queue(rank, q)

@backend_parametrize
def test_fullsync(self, backend) -> None:
world_size = DEFAULT_WORLD_SIZE if backend != "nccl" else torch.cuda.device_count()
launch_distributed_training(backend, world_size, DistributedTest._test_fullsync)
launch_distributed_training(backend, world_size, fn=DistributedTest._test_fullsync)

@staticmethod
def _get_dataloader(data_length: int, dl2: bool, shuffle: bool, rs=None):
@@ -118,48 +212,38 @@ def _get_dataloader(data_length: int, dl2: bool, shuffle: bool, rs=None):
return dl

@staticmethod
def _test_distributed_training(dl2, rank, world_size, backend):
def _test_distributed_training(dl2, rank, world_size, backend, q):
dist.init_process_group(backend, rank=rank, world_size=world_size)
# Use a prime number to make sure uneven data sharding
data_length = 23

# No shuffle
dl = DistributedTest._get_dataloader(data_length, dl2=dl2, shuffle=False)
res = []
for d in dl:
res.append(d)
# Simulate training synchronization
dist.barrier()
res = _dist_iterate_one_epoch(dl)
assert sorted(res) == list(range(rank, data_length // world_size * world_size, world_size))

# Shuffle
dl = DistributedTest._get_dataloader(data_length, dl2=dl2, shuffle=True)
results = []
for _ in range(2):
res = []
torch.manual_seed(123)
for d in dl:
res.append(d)
# Simulate training synchronization
dist.barrier()
res = _dist_iterate_one_epoch(dl, seed=123)
results.append(res)
assert results[0] == results[1]

# Different seed
res = []
torch.manual_seed(321)
for d in dl:
res.append(d)
# Simulate training synchronization
dist.barrier()
res = _dist_iterate_one_epoch(dl, seed=321)
results.append(res)
assert len(results[0]) == len(results[2])
assert results[0] != results[2]

_finalize_distributed_queue(rank, q)
if dl2:
dl.shutdown()

@backend_parametrize
def test_distributed_dl2(self, backend) -> None:
world_size = DEFAULT_WORLD_SIZE if backend != "nccl" else torch.cuda.device_count()
launch_distributed_training(backend, world_size, partial(DistributedTest._test_distributed_training, True))
launch_distributed_training(backend, world_size, fn=partial(DistributedTest._test_distributed_training, True))

@unittest.skipIf(
IS_WINDOWS,
@@ -185,7 +269,7 @@ def test_elastic_training_dl2(self, backend) -> None:
@backend_parametrize
def test_distributed_dl1(self, backend) -> None:
world_size = DEFAULT_WORLD_SIZE if backend != "nccl" else torch.cuda.device_count()
launch_distributed_training(backend, world_size, partial(DistributedTest._test_distributed_training, False))
launch_distributed_training(backend, world_size, fn=partial(DistributedTest._test_distributed_training, False))

@unittest.skipIf(
IS_WINDOWS,
@@ -209,6 +293,35 @@ def test_elastic_training_dl1(self, backend) -> None:
],
)

    @backend_parametrize
    @parametrize("num_workers", [0, 8])
    def test_proto_rs_dl2(self, backend, num_workers) -> None:
        world_size = DEFAULT_WORLD_SIZE if backend != "nccl" else torch.cuda.device_count()
        res = launch_distributed_training(backend, world_size, num_workers, fn=_test_proto_distributed_training)
        result = ({}, {}, {}, {})
        for epoch, rank, r in res:
            d, *ran_nums = list(zip(*r))
            result[epoch][rank] = (d, ran_nums)
        # Same seed generates the same order of data and the same random state
        self.assertEqual(result[1], result[2])
        # Different seeds
        for rank in range(world_size):
            # Different shuffle order
            self.assertNotEqual(result[1][rank][0], result[3][rank][0])
            # Different subprocess-local random state
            self.assertNotEqual(result[1][rank][1], result[3][rank][1])

        # Mutually exclusive and collectively exhaustive with/without seed
        data_length = world_size * 8
        if num_workers > 0:
            data_length *= num_workers
        exp = list(range(data_length))
        for res in result:
            concat_res = []
            for r in res.values():
                concat_res.extend(r[0])
            self.assertEqual(sorted(concat_res), exp)


instantiate_parametrized_tests(DistributedTest)
