diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index 00593ae0d..ecf07a2b6 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -53,7 +53,7 @@ jobs:
          pip3 install --pre torch -f "${{ steps.pytorch_channel.outputs.value }}"
      - name: Install dependencies
        run: |
-          pip3 install requests mypy==0.812 graphviz
+          pip3 install requests mypy==0.812 graphviz numpy
      - name: Build TorchData
        run: |
          python setup.py develop
diff --git a/test/dataloader2/test_random.py b/test/dataloader2/test_random.py
new file mode 100644
index 000000000..7e575afd7
--- /dev/null
+++ b/test/dataloader2/test_random.py
@@ -0,0 +1,85 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import random
+import unittest
+
+from unittest import TestCase
+
+import numpy as np
+
+import torch
+
+from torch.testing._internal.common_utils import instantiate_parametrized_tests, IS_WINDOWS, parametrize
+from torchdata.dataloader2 import DataLoader2, DistributedReadingService, PrototypeMultiProcessingReadingService
+from torchdata.datapipes.iter import IterableWrapper
+
+
+def _random_fn(data):
+    r"""
+    Used to validate that subprocess-local RNGs are seeded deterministically.
+    """
+    py_random_num = random.randint(0, 2 ** 32)
+    np_random_num = np.random.randint(0, 2 ** 32)
+    torch_random_num = torch.randint(0, 2 ** 32, size=[]).item()
+    return (data, py_random_num, np_random_num, torch_random_num)
+
+
+class DeterminismTest(TestCase):
+    @parametrize("num_workers", [0, 8])
+    def test_proto_rs_determinism(self, num_workers):
+        data_length = 64
+        exp = list(range(data_length))
+
+        data_source = IterableWrapper(exp)
+        dp = data_source.shuffle().sharding_filter().map(_random_fn)
+        rs = PrototypeMultiProcessingReadingService(num_workers=num_workers)
+        dl = DataLoader2(dp, reading_service=rs)
+
+        # No seed
+        res = []
+        for d, *_ in dl:
+            res.append(d)
+        self.assertEqual(sorted(res), exp)
+
+        # Shuffle with seed
+        results = []
+        for _ in range(2):
+            res = []
+            ran_res = []
+            torch.manual_seed(123)
+            random.seed(123)
+            np.random.seed(123)
+            for d, *ran_nums in dl:
+                res.append(d)
+                ran_res.append(ran_nums)
+            self.assertEqual(sorted(res), exp)
+            results.append((res, ran_res))
+        # The same seed generates the same order of data and the same random state
+        self.assertEqual(results[0], results[1])
+
+        # Different seed
+        res = []
+        ran_res = []
+        torch.manual_seed(321)
+        random.seed(321)
+        np.random.seed(321)
+        for d, *ran_nums in dl:
+            res.append(d)
+            ran_res.append(ran_nums)
+        self.assertEqual(sorted(res), exp)
+        # Different shuffle order
+        self.assertNotEqual(results[0][0], res)
+        # Different subprocess-local random state
+        self.assertNotEqual(results[0][1], ran_res)
+
+
+instantiate_parametrized_tests(DeterminismTest)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/test_distributed.py b/test/test_distributed.py
index 20706d904..87ce92bb9 100644
--- a/test/test_distributed.py
+++ b/test/test_distributed.py
@@ -6,19 +6,24 @@
 
 import os
+import queue
+import random
+import socket
 import sys
 import unittest
 
 from functools import partial
 from unittest import TestCase
 
+import numpy as np
+
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 
 from torch.testing._internal.common_utils import instantiate_parametrized_tests, IS_WINDOWS, parametrize
 from torch.utils.data import DataLoader
 
-from torchdata.dataloader2 import DataLoader2, DistributedReadingService
+from torchdata.dataloader2 import DataLoader2, DistributedReadingService, PrototypeMultiProcessingReadingService
 from torchdata.datapipes.iter import IterableWrapper
 from torchdata.datapipes.iter.util.prefetch import PrefetchTimeoutError
 
@@ -45,8 +50,6 @@ def abs_path(path):
 
 
 def _get_open_port():
-    import socket
-
     s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
     s.bind(("", 0))
     port = s.getsockname()[1]
@@ -54,23 +57,116 @@ def _get_open_port():
     return str(port)
 
 
-def launch_distributed_training(backend, world_size, fn):
+class TerminateSignal:
+    pass
+
+
+# TODO(ejguan): Use queue for all distributed tests
+def launch_distributed_training(backend, world_size, *args, fn):
     os.environ["MASTER_ADDR"] = TEST_MASTER_ADDR
     os.environ["MASTER_PORT"] = _get_open_port()
-    mp.spawn(
-        fn,
-        args=(
-            world_size,
-            backend,
-        ),
-        nprocs=world_size,
-        join=True,
-    )
+    ctx = mp.get_context("spawn")
+    q = ctx.Queue()
+    ps = []
+    for rank in range(world_size):
+        p = ctx.Process(
+            target=fn,
+            args=(
+                rank,
+                world_size,
+                backend,
+                q,
+                *args,
+            ),
+        )
+        p.start()
+        ps.append(p)
+    res = []
+    while True:
+        try:
+            d = q.get()
+            if isinstance(d, TerminateSignal):
+                break
+            res.append(d)
+        except queue.Empty:
+            continue
+    for p in ps:
+        p.join()
+    return res
+
+
+def _dist_iterate_one_epoch(dl, seed=None):
+    r"""
+    Iterate a full epoch of the DataLoader and set seeds for the global RNGs if provided.
+    """
+    if seed is not None:
+        torch.manual_seed(seed)
+        random.seed(seed)
+        np.random.seed(seed)
+    res = []
+    for d in dl:
+        res.append(d)
+    # Simulate training synchronization
+    dist.barrier()
+    return res
+
+
+def _finalize_distributed_queue(rank, q):
+    r"""
+    Synchronize all distributed processes to guarantee that all data has been put into
+    the multiprocessing queue.
+    """
+    pg = dist.new_group(backend="gloo")
+    end_tensor = torch.tensor([rank], dtype=torch.int64)
+    dist.all_reduce(end_tensor, group=pg)
+    if rank == 0:
+        q.put(TerminateSignal())
+
+    dist.destroy_process_group(pg)
+
+
+def _random_fn(data):
+    r"""
+    Used to validate that subprocess-local RNGs are seeded deterministically.
+ """ + py_random_num = random.randint(0, 2 ** 32) + np_random_num = np.random.randint(0, 2 ** 32) + torch_random_num = torch.randint(0, 2 ** 32, size=[]).item() + return (data, py_random_num, np_random_num, torch_random_num) + + +def _test_proto_distributed_training(rank, world_size, backend, q, num_workers): + dist.init_process_group(backend, rank=rank, world_size=world_size) + # Balanced data + data_length = world_size * 8 + if num_workers > 0: + data_length *= num_workers + + data_source = IterableWrapper(list(range(data_length))) + dp = data_source.shuffle().sharding_filter().map(_random_fn) + rs = PrototypeMultiProcessingReadingService(num_workers=num_workers) + dl = DataLoader2(dp, reading_service=rs) + + # No seed + res = _dist_iterate_one_epoch(dl, seed=None) + q.put((0, rank, res)) + + # Shuffle with seed + for epoch in range(2): + res = _dist_iterate_one_epoch(dl, seed=123) + q.put((epoch + 1, rank, res)) + + # Different seed + res = _dist_iterate_one_epoch(dl, seed=321) + q.put((3, rank, res)) + + _finalize_distributed_queue(rank, q) + dl.shutdown() class DistributedTest(TestCase): @staticmethod - def _test_fullsync(rank, world_size, backend): + def _test_fullsync(rank, world_size, backend, q): dist.init_process_group(backend, rank=rank, world_size=world_size) # Use a prime number to make sure uneven data sharding data_length = 23 @@ -79,11 +175,7 @@ def _test_fullsync(rank, world_size, backend): dp1 = dp.fullsync() for _ in range(2): - res = [] - for d in dp1: - res.append(d) - # Simulate training synchronization - dist.barrier() + res = _dist_iterate_one_epoch(dp1) assert res == list(range(rank, data_length // world_size * world_size, world_size)) # Timeout Test @@ -94,10 +186,12 @@ def _test_fullsync(rank, world_size, backend): except Exception as e: assert isinstance(e, PrefetchTimeoutError) + _finalize_distributed_queue(rank, q) + @backend_parametrize def test_fullsync(self, backend) -> None: world_size = DEFAULT_WORLD_SIZE if backend != "nccl" else torch.cuda.device_count() - launch_distributed_training(backend, world_size, DistributedTest._test_fullsync) + launch_distributed_training(backend, world_size, fn=DistributedTest._test_fullsync) @staticmethod def _get_dataloader(data_length: int, dl2: bool, shuffle: bool, rs=None): @@ -118,48 +212,38 @@ def _get_dataloader(data_length: int, dl2: bool, shuffle: bool, rs=None): return dl @staticmethod - def _test_distributed_training(dl2, rank, world_size, backend): + def _test_distributed_training(dl2, rank, world_size, backend, q): dist.init_process_group(backend, rank=rank, world_size=world_size) # Use a prime number to make sure uneven data sharding data_length = 23 # No shuffle dl = DistributedTest._get_dataloader(data_length, dl2=dl2, shuffle=False) - res = [] - for d in dl: - res.append(d) - # Simulate training synchronization - dist.barrier() + res = _dist_iterate_one_epoch(dl) assert sorted(res) == list(range(rank, data_length // world_size * world_size, world_size)) # Shuffle dl = DistributedTest._get_dataloader(data_length, dl2=dl2, shuffle=True) results = [] for _ in range(2): - res = [] - torch.manual_seed(123) - for d in dl: - res.append(d) - # Simulate training synchronization - dist.barrier() + res = _dist_iterate_one_epoch(dl, seed=123) results.append(res) assert results[0] == results[1] # Different seed - res = [] - torch.manual_seed(321) - for d in dl: - res.append(d) - # Simulate training synchronization - dist.barrier() + res = _dist_iterate_one_epoch(dl, seed=321) results.append(res) assert len(results[0]) 
         assert len(results[0]) == len(results[2])
         assert results[0] != results[2]
 
+        _finalize_distributed_queue(rank, q)
+        if dl2:
+            dl.shutdown()
+
     @backend_parametrize
     def test_distributed_dl2(self, backend) -> None:
         world_size = DEFAULT_WORLD_SIZE if backend != "nccl" else torch.cuda.device_count()
-        launch_distributed_training(backend, world_size, partial(DistributedTest._test_distributed_training, True))
+        launch_distributed_training(backend, world_size, fn=partial(DistributedTest._test_distributed_training, True))
 
     @unittest.skipIf(
         IS_WINDOWS,
@@ -185,7 +269,7 @@ def test_elastic_training_dl2(self, backend) -> None:
     @backend_parametrize
     def test_distributed_dl1(self, backend) -> None:
         world_size = DEFAULT_WORLD_SIZE if backend != "nccl" else torch.cuda.device_count()
-        launch_distributed_training(backend, world_size, partial(DistributedTest._test_distributed_training, False))
+        launch_distributed_training(backend, world_size, fn=partial(DistributedTest._test_distributed_training, False))
 
     @unittest.skipIf(
         IS_WINDOWS,
@@ -209,6 +293,35 @@ def test_elastic_training_dl1(self, backend) -> None:
             ],
         )
 
+    @backend_parametrize
+    @parametrize("num_workers", [0, 8])
+    def test_proto_rs_dl2(self, backend, num_workers) -> None:
+        world_size = DEFAULT_WORLD_SIZE if backend != "nccl" else torch.cuda.device_count()
+        res = launch_distributed_training(backend, world_size, num_workers, fn=_test_proto_distributed_training)
+        result = ({}, {}, {}, {})
+        for epoch, rank, r in res:
+            d, *ran_nums = list(zip(*r))
+            result[epoch][rank] = (d, ran_nums)
+        # The same seed generates the same order of data and the same random state
+        self.assertEqual(result[1], result[2])
+        # Different seeds
+        for rank in range(world_size):
+            # Different shuffle order
+            self.assertNotEqual(result[1][rank][0], result[3][rank][0])
+            # Different subprocess-local random state
+            self.assertNotEqual(result[1][rank][1], result[3][rank][1])
+
+        # Mutually exclusive and collectively exhaustive with/without seed
+        data_length = world_size * 8
+        if num_workers > 0:
+            data_length *= num_workers
+        exp = list(range(data_length))
+        for res in result:
+            concat_res = []
+            for r in res.values():
+                concat_res.extend(r[0])
+            self.assertEqual(sorted(concat_res), exp)
+
 
 instantiate_parametrized_tests(DistributedTest)
 
diff --git a/torchdata/dataloader2/communication/eventloop.py b/torchdata/dataloader2/communication/eventloop.py
index 6d43402a0..ce89d279a 100644
--- a/torchdata/dataloader2/communication/eventloop.py
+++ b/torchdata/dataloader2/communication/eventloop.py
@@ -31,9 +31,9 @@
 ]
 
 
-def DataPipeToQueuesLoop(source_datapipe, req_queue, res_queue, call_locally_fn=None, call_on_reset_epoch=None):
-    if call_locally_fn is not None:
-        call_locally_fn(source_datapipe)
+def DataPipeToQueuesLoop(source_datapipe, req_queue, res_queue, call_on_process_init=None, call_on_epoch_reset=None):
+    if call_on_process_init is not None:
+        call_on_process_init(source_datapipe)
     if isinstance(source_datapipe, IterDataPipe):
         pipe_type = communication.iter
         protocol_type = communication.protocol.IterDataPipeQueueProtocolServer
@@ -48,16 +48,16 @@ def DataPipeToQueuesLoop(source_datapipe, req_queue, res_queue, call_locally_fn=
         source_datapipe,
         protocol_type(req_queue, res_queue),
         blocking_request_get=True,
-        reset_epoch_fn=call_on_reset_epoch,
+        reset_epoch_fn=call_on_epoch_reset,
     ):
         pass
 
 
-def SpawnProcessForDataPipeline(multiprocessing_ctx, datapipe, call_locally_fn=None, call_on_reset_epoch=None):
+def SpawnProcessForDataPipeline(multiprocessing_ctx, datapipe, call_on_process_init=None, call_on_epoch_reset=None):
     req_queue = multiprocessing_ctx.Queue()
     res_queue = multiprocessing_ctx.Queue()
     process = multiprocessing_ctx.Process(
-        target=DataPipeToQueuesLoop, args=(datapipe, req_queue, res_queue, call_locally_fn, call_on_reset_epoch)
+        target=DataPipeToQueuesLoop, args=(datapipe, req_queue, res_queue, call_on_process_init, call_on_epoch_reset)
     )
     return process, req_queue, res_queue
 
diff --git a/torchdata/dataloader2/reading_service.py b/torchdata/dataloader2/reading_service.py
index ec7e0b678..a445c9b85 100644
--- a/torchdata/dataloader2/reading_service.py
+++ b/torchdata/dataloader2/reading_service.py
@@ -7,6 +7,7 @@
 
 import functools
 import multiprocessing as mp
+import random
 
 from abc import ABC, abstractmethod
 
@@ -23,6 +24,13 @@
 from torchdata.dataloader2.graph import DataPipe
 from torchdata.datapipes.iter import FullSync, IterableWrapper, IterDataPipe
 
+try:
+    import numpy
+
+    HAS_NUMPY = True
+except ModuleNotFoundError:
+    HAS_NUMPY = False
+
 
 class ReadingServiceInterface(ABC):
     r"""
@@ -99,6 +107,10 @@ def _collate_no_op(batch):
     return batch[0]
 
 
+def _generate_random_seed(rng: Optional[torch.Generator] = None, dtype: torch.dtype = torch.int64) -> torch.Tensor:
+    return torch.empty((), dtype=dtype).random_(generator=rng)
+
+
 class _IterateQueueDataPipes(IterDataPipe):
     def __init__(self, datapipes):
         # TODO(VitalyFedyunin): Consider combining _IterateQueueDataPipes and QueueWrapper
@@ -162,7 +174,11 @@ class PrototypeMultiProcessingReadingService(ReadingServiceInterface):
     num_workers: int
     processes: List
     datapipes: List
-    combined_datapipes: Optional[IterDataPipe]
+    end_datapipe: Optional[DataPipe]
+    _mp: bool
+    _pg: Optional[dist.ProcessGroup]
+    _world_size: int
+    _rank: int
 
     def __init__(
         self,
@@ -178,42 +194,75 @@ def __init__(
         self.prefetch_mainloop = prefetch_mainloop
         self.processes = []
         self.datapipes = []
-        self.combined_datapipes = None
+        self.end_datapipe = None
+        self._mp = num_workers > 0
+        self._pg = None
+        self._world_size = 1
+        self._rank = 0
 
     @staticmethod
-    def init_datapipe_process(num_workers, worker_id, datapipe):
-        # TODO(614): Add distributed support
-        # TODO(615): Add shuffle determinism support
-        torch.utils.data.graph_settings.apply_sharding(datapipe, num_workers, worker_id)
+    def _process_init_fn(world_size, rank, num_workers, worker_id, datapipe):
+        global_worker_id = worker_id * world_size + rank
+        total_num_workers = num_workers * world_size
+        torch.utils.data.graph_settings.apply_sharding(datapipe, total_num_workers, global_worker_id)
 
     @staticmethod
-    def call_on_epoch_reset(datapipe, *args):
+    def _process_reset_fn(world_size, rank, num_workers, worker_id, datapipe, shared_seed):
         # This function will receive worker local copy of datapipe and args value from initialize_iteration
-        pass
+        worker_seed_generator = torch.Generator()
+        worker_seed_generator.manual_seed(shared_seed)
+        torch.utils.data.graph_settings.apply_random_seed(
+            datapipe,
+            worker_seed_generator,
+        )
+        # Set different seeds across distributed workers
+        global_worker_id = worker_id * world_size + rank
+        worker_seed_generator.manual_seed(shared_seed + global_worker_id)
+
+        py_seed = _generate_random_seed(worker_seed_generator).item()
+        random.seed(py_seed)
+
+        torch_seed = _generate_random_seed(worker_seed_generator).item()
+        torch.manual_seed(torch_seed)
+
+        if HAS_NUMPY:
+            # Numpy only accepts uint32 as the seed
+            np_seed = _generate_random_seed(worker_seed_generator, torch.int32).item()
+            if np_seed < 0:
+                np_seed = 2 ** 32 + np_seed
+            numpy.random.seed(np_seed)
 
     def initialize(self, datapipe: DataPipe) -> DataPipe:
         r"""
-        ``MultiProcessingReadingService`` finds information about sharding,
+        ``PrototypeMultiProcessingReadingService`` finds information about sharding,
         separates graph by multiple pieces and reconnects it using queues.
         creates subprocesses.
         """
-        if self.num_workers == 0:
+        if dist.is_available() and dist.is_initialized():
+            self._world_size = dist.get_world_size()
+            self._rank = dist.get_rank()
+            self._pg = dist.new_group(backend="gloo")
+        if not self._mp:
             # TODO(616): Warn and recommend usage of InProcessReadingService
+            self._process_init_fn(self._world_size, self._rank, 1, 0, datapipe)
+            self.end_datapipe = datapipe
             return datapipe
 
         if self.prefetch_worker > 0:
             datapipe = datapipe.prefetch(self.prefetch_worker)
 
         for worker_id in range(self.num_workers):
-            # TODO(617): Separate into function, because we also need to apply distributed seed
-            # and call it inside process
-            call_inside_process = functools.partial(self.init_datapipe_process, self.num_workers, worker_id)
-            call_on_epoch_reset = self.call_on_epoch_reset
+            call_on_process_init = functools.partial(
+                self._process_init_fn, self._world_size, self._rank, self.num_workers, worker_id
+            )
+            call_on_epoch_reset = functools.partial(
+                self._process_reset_fn, self._world_size, self._rank, self.num_workers, worker_id
+            )
             ctx = mp.get_context(self.multiprocessing_context)
             (process, req_queue, res_queue) = communication.eventloop.SpawnProcessForDataPipeline(
                 ctx,
                 datapipe,
-                call_inside_process,
+                call_on_process_init,
                 call_on_epoch_reset,
             )
             process.daemon = True
@@ -224,28 +273,48 @@ def initialize(self, datapipe: DataPipe) -> DataPipe:
             )
             self.datapipes.append(local_datapipe)
 
-        self.combined_datapipes = _IterateQueueDataPipes(self.datapipes)
+        self.end_datapipe = _IterateQueueDataPipes(self.datapipes)  # type: ignore[assignment]
         if self.prefetch_mainloop > 0:
-            self.combined_datapipes = self.combined_datapipes.prefetch(self.prefetch_mainloop)
-        return self.combined_datapipes  # type: ignore[return-value]
+            self.end_datapipe = self.end_datapipe.prefetch(self.prefetch_mainloop)  # type: ignore[union-attr]
+        return self.end_datapipe  # type: ignore[return-value]
 
     def initialize_iteration(self) -> None:
-        if self.combined_datapipes is not None:
+        shared_seed = _generate_random_seed()
+        if self._pg is not None:
+            dist.broadcast(shared_seed, src=0, group=self._pg)
+        shared_seed_int: int = shared_seed.item()  # type: ignore[assignment]
+        _seed_generator = torch.Generator()
+        _seed_generator.manual_seed(shared_seed_int)
+        torch.utils.data.graph_settings.apply_random_seed(
+            self.end_datapipe,  # type: ignore[arg-type]
+            _seed_generator,
+        )
+
+        assert self.end_datapipe is not None
+        if self._mp:
             if self.prefetch_mainloop > 0:
                 # Stop prefetching first
-                self.combined_datapipes.reset()
-                self.combined_datapipes.source_datapipe.reset_epoch()
-                self.combined_datapipes.source_datapipe.reset()
+                self.end_datapipe.reset()  # type: ignore[union-attr]
+                end_datapipe: DataPipe = self.end_datapipe.source_datapipe
             else:
-                self.combined_datapipes.reset_epoch()
-                self.combined_datapipes.reset()
+                end_datapipe = self.end_datapipe
+            # Send the shared seed to subprocesses
+            end_datapipe.reset_epoch(shared_seed_int)
+            end_datapipe.reset()
+        # In-process (num_workers == 0)
+        else:
+            # Technically speaking, we should call `_process_reset_fn` to reset the global RNGs
+            # for data-related operations. However, it would pollute the state of the global RNGs
However, it would pollute the state of global RNGs + # (random, torch and numpy), if users have already seeded them in the main process + # TODO(ejguan): This should be fixed by adding a method to isolate global RNGs + pass def __del__(self): self.finalize() def finalize(self) -> None: r""" - ``MultiProcessingReadingService`` invalidate states & properly exits all subprocesses. + ``PrototypeMultiProcessingReadingService`` invalidate states & properly exits all subprocesses. """ # TODO(618): Check if anyone stuck with messages def clean_me(process, req_queue, res_queue): @@ -267,6 +336,10 @@ def clean_me(process, req_queue, res_queue): self.processes = [] + if self._pg is not None: + dist.destroy_process_group(self._pg) + self._pg = None + class MultiProcessingReadingService(ReadingServiceInterface): r""" @@ -384,12 +457,17 @@ def initialize_iteration(self) -> None: ) def _share_seed(self): - shared_seed = torch.empty((), dtype=torch.int64).random_() + shared_seed = _generate_random_seed() dist.broadcast(shared_seed, src=0, group=self._pg) return shared_seed.item() + def __del__(self): + self.finalize() + def finalize(self) -> None: r""" Clean up the distributed process group. """ - self._pg = None + if self._pg is not None: + dist.destroy_process_group(self._pg) + self._pg = None