diff --git a/mmcv/utils/__init__.py b/mmcv/utils/__init__.py index 019e77f568..8159c6a1a7 100644 --- a/mmcv/utils/__init__.py +++ b/mmcv/utils/__init__.py @@ -50,6 +50,7 @@ is_rocm_pytorch) # yapf: enable from .registry import Registry, build_from_cfg + from .seed import worker_init_fn from .trace import is_jit_tracing __all__ = [ 'Config', 'ConfigDict', 'DictAction', 'collect_env', 'get_logger', @@ -70,5 +71,5 @@ 'assert_dict_has_keys', 'assert_keys_equal', 'assert_is_norm_layer', 'assert_params_all_zeros', 'check_python_script', 'is_method_overridden', 'is_jit_tracing', 'is_rocm_pytorch', - '_get_cuda_home', 'load_url', 'has_method' + '_get_cuda_home', 'load_url', 'has_method', 'worker_init_fn' ] diff --git a/mmcv/utils/seed.py b/mmcv/utils/seed.py new file mode 100644 index 0000000000..003f923677 --- /dev/null +++ b/mmcv/utils/seed.py @@ -0,0 +1,23 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import random + +import numpy as np +import torch + + +def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int): + """Function to initialize each worker. + + The seed of each worker equals to + ``num_worker * rank + worker_id + user_seed``. + + Args: + worker_id (int): Id for each worker. + num_workers (int): Number of workers. + rank (int): Rank in distributed training. + seed (int): Random seed. + """ + worker_seed = num_workers * rank + worker_id + seed + np.random.seed(worker_seed) + random.seed(worker_seed) + torch.manual_seed(worker_seed)