forked from pytorch/data
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Unblock ProtoMPRS to control determinism of DataPipe in single/multi-…
…processing and dist/non-dist env (pytorch#827) Summary: This PR temporarily extend `PrototypingMultiProcessingReadingService` to fully control the determinism of the pipeline in the combinations of: - Single/Multi-processing - Distributed/Non-distributed When we have `SequentialReadingService` ready to combine `DistributedReadingService` and `PrototypingMultiProcessingReadingService`, a few code should be removed. And, for in-process reading service, we still need a method to isolate global RNGs to prevent data-pipeline interferes randomness against model. For multiprocessing case, it will set the same random seed for `Shuffler` and set different deterministic seeds for global RNGs including `python.random`, `torch` and `numpy` within each subprocess. For distributed case, it will share the same random seed for `Shuffler` across all distributed subprocesses to guarantee the shuffle order before sharding. Test Plan: All tests are executed in the combinations of the above environments - [x] Validate the same seed will generate the same order of data - [x] Validate different seeds will generate different order of data - [x] Validate the data after shuffle and sharding in each worker are mutually exclusive and collectively exhaustive with/without manual seed There is one missing test I will add tmrw - [x] Validate subprocess-local RNGs like `random`, `torch` and `numpy` are properly set with different seeds. Pull Request resolved: pytorch#827 Reviewed By: VitalyFedyunin, NivekT Differential Revision: D40323946 Pulled By: ejguan fbshipit-source-id: 2997d6d5dce87a6c38d5ebdf64a00f9769bb18fa
- Loading branch information
Showing
5 changed files
with
350 additions
and
74 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
|
||
import random | ||
import unittest | ||
|
||
from unittest import TestCase | ||
|
||
import numpy as np | ||
|
||
import torch | ||
|
||
from torch.testing._internal.common_utils import instantiate_parametrized_tests, IS_WINDOWS, parametrize | ||
from torchdata.dataloader2 import DataLoader2, DistributedReadingService, PrototypeMultiProcessingReadingService | ||
from torchdata.datapipes.iter import IterableWrapper | ||
|
||
|
||
def _random_fn(data): | ||
r""" | ||
Used to validate the randomness of subprocess-local RNGs are set deterministically. | ||
""" | ||
py_random_num = random.randint(0, 2 ** 32) | ||
np_random_num = np.random.randint(0, 2 ** 32) | ||
torch_random_num = torch.randint(0, 2 ** 32, size=[]).item() | ||
return (data, py_random_num, np_random_num, torch_random_num) | ||
|
||
|
||
class DeterminismTest(TestCase): | ||
@parametrize("num_workers", [0, 8]) | ||
def test_proto_rs_determinism(self, num_workers): | ||
data_length = 64 | ||
exp = list(range(data_length)) | ||
|
||
data_source = IterableWrapper(exp) | ||
dp = data_source.shuffle().sharding_filter().map(_random_fn) | ||
rs = PrototypeMultiProcessingReadingService(num_workers=num_workers) | ||
dl = DataLoader2(dp, reading_service=rs) | ||
|
||
# No seed | ||
res = [] | ||
for d, *_ in dl: | ||
res.append(d) | ||
self.assertEqual(sorted(res), exp) | ||
|
||
# Shuffle with seed | ||
results = [] | ||
for _ in range(2): | ||
res = [] | ||
ran_res = [] | ||
torch.manual_seed(123) | ||
random.seed(123) | ||
np.random.seed(123) | ||
for d, *ran_nums in dl: | ||
res.append(d) | ||
ran_res.append(ran_nums) | ||
self.assertEqual(sorted(res), exp) | ||
results.append((res, ran_res)) | ||
# Same seed generate the same order of data and the same random state | ||
self.assertEqual(results[0], results[1]) | ||
|
||
# Different seed | ||
res = [] | ||
ran_res = [] | ||
torch.manual_seed(321) | ||
random.seed(321) | ||
np.random.seed(321) | ||
for d, *ran_nums in dl: | ||
res.append(d) | ||
ran_res.append(ran_nums) | ||
self.assertEqual(sorted(res), exp) | ||
# Different shuffle order | ||
self.assertNotEqual(results[0][0], res) | ||
# Different subprocess-local random state | ||
self.assertNotEqual(results[0][1], ran_res) | ||
|
||
|
||
instantiate_parametrized_tests(DeterminismTest) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.