diff --git a/CHANGELOG.md b/CHANGELOG.md
index df7e1bdd8188f..c6551782e5d86 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -243,6 +243,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `self.device` not returning the correct device in replicas of data-parallel ([#6414](https://github.com/PyTorchLightning/pytorch-lightning/pull/6414))
 
 
+- Fixed process rank not being available right away after `Trainer` instantiation ([#6941](https://github.com/PyTorchLightning/pytorch-lightning/pull/6941))
+
+
 ## [1.2.7] - 2021-04-06
 
 ### Fixed
diff --git a/pytorch_lightning/plugins/environments/cluster_environment.py b/pytorch_lightning/plugins/environments/cluster_environment.py
index f3fb2fbeabaa2..9728fba932874 100644
--- a/pytorch_lightning/plugins/environments/cluster_environment.py
+++ b/pytorch_lightning/plugins/environments/cluster_environment.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from abc import ABC, abstractmethod
-from typing import Optional
 
 
 class ClusterEnvironment(ABC):
@@ -31,9 +30,21 @@ def master_port(self) -> int:
         """ An open and configured port in the master node through which all processes communicate. """
 
     @abstractmethod
-    def world_size(self) -> Optional[int]:
+    def world_size(self) -> int:
         """ The number of processes across all devices and nodes. """
 
+    @abstractmethod
+    def set_world_size(self, size: int) -> None:
+        pass
+
+    @abstractmethod
+    def global_rank(self) -> int:
+        """ The rank (index) of the currently running process across all nodes and devices. """
+
+    @abstractmethod
+    def set_global_rank(self, rank: int) -> None:
+        pass
+
     @abstractmethod
     def local_rank(self) -> int:
         """ The rank (index) of the currently running process inside of the current node. """
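For context on the interface change above: `world_size()` now has to return an `int`, and the getter/setter pairs for world size and global rank become part of the abstract contract. The sketch below shows what a minimal third-party implementation of the new interface could look like; the class name and the `MY_*` environment variables are hypothetical and not part of this PR.

```python
import os

from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment


class MyClusterEnvironment(ClusterEnvironment):
    """Hypothetical environment whose launcher exports MY_* variables."""

    def creates_children(self) -> bool:
        # the external launcher, not Lightning, spawns the processes
        return True

    def master_address(self) -> str:
        return os.environ["MY_MASTER_ADDR"]

    def master_port(self) -> int:
        return int(os.environ["MY_MASTER_PORT"])

    def world_size(self) -> int:
        return int(os.environ["MY_WORLD_SIZE"])

    def set_world_size(self, size: int) -> None:
        pass  # read-only: the launcher already fixed the world size

    def global_rank(self) -> int:
        return int(os.environ["MY_RANK"])

    def set_global_rank(self, rank: int) -> None:
        pass  # read-only

    def local_rank(self) -> int:
        return int(os.environ["MY_LOCAL_RANK"])

    def node_rank(self) -> int:
        return int(os.environ["MY_NODE_RANK"])
```

Such an environment would typically be handed to the `Trainer` through its `plugins` argument, the same way the built-in environments below are selected automatically.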
""" diff --git a/pytorch_lightning/plugins/environments/lightning_environment.py b/pytorch_lightning/plugins/environments/lightning_environment.py index 6b71122b065bf..67752535fe4e1 100644 --- a/pytorch_lightning/plugins/environments/lightning_environment.py +++ b/pytorch_lightning/plugins/environments/lightning_environment.py @@ -14,9 +14,9 @@ import os import socket -from typing import Optional from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment +from pytorch_lightning.utilities import rank_zero_only class LightningEnvironment(ClusterEnvironment): @@ -34,6 +34,8 @@ class LightningEnvironment(ClusterEnvironment): def __init__(self): super().__init__() self._master_port = None + self._global_rank: int = 0 + self._world_size: int = 1 def creates_children(self) -> bool: return False @@ -46,8 +48,18 @@ def master_port(self) -> int: self._master_port = os.environ.get("MASTER_PORT", find_free_network_port()) return int(self._master_port) - def world_size(self) -> Optional[int]: - return None + def world_size(self) -> int: + return self._world_size + + def set_world_size(self, size: int) -> None: + self._world_size = size + + def global_rank(self) -> int: + return self._global_rank + + def set_global_rank(self, rank: int) -> None: + self._global_rank = rank + rank_zero_only.rank = rank def local_rank(self) -> int: return int(os.environ.get("LOCAL_RANK", 0)) diff --git a/pytorch_lightning/plugins/environments/slurm_environment.py b/pytorch_lightning/plugins/environments/slurm_environment.py index 3cba5d101a159..0c91c064e391c 100644 --- a/pytorch_lightning/plugins/environments/slurm_environment.py +++ b/pytorch_lightning/plugins/environments/slurm_environment.py @@ -23,9 +23,6 @@ class SLURMEnvironment(ClusterEnvironment): - def __init__(self): - super().__init__() - def creates_children(self) -> bool: return True @@ -69,8 +66,17 @@ def master_port(self) -> int: return int(default_port) - def world_size(self): - return None + def world_size(self) -> int: + return int(os.environ["SLURM_NTASKS"]) + + def set_world_size(self, size: int) -> None: + log.debug("SLURMEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.") + + def global_rank(self) -> int: + return int(os.environ["SLURM_PROCID"]) + + def set_global_rank(self, rank: int) -> None: + log.debug("SLURMEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.") def local_rank(self) -> int: return int(os.environ['SLURM_LOCALID']) diff --git a/pytorch_lightning/plugins/environments/torchelastic_environment.py b/pytorch_lightning/plugins/environments/torchelastic_environment.py index c3a59fbfd75bc..bdaf148906cf1 100644 --- a/pytorch_lightning/plugins/environments/torchelastic_environment.py +++ b/pytorch_lightning/plugins/environments/torchelastic_environment.py @@ -24,8 +24,11 @@ class TorchElasticEnvironment(ClusterEnvironment): - def __init__(self): - super().__init__() + @staticmethod + def is_using_torchelastic() -> bool: + """ Returns ``True`` if the current process was launched using the torchelastic command. 
""" + required_env_vars = ("RANK", "GROUP_RANK", "LOCAL_RANK", "LOCAL_WORLD_SIZE") + return all(v in os.environ for v in required_env_vars) def creates_children(self) -> bool: return True @@ -51,6 +54,17 @@ def world_size(self) -> Optional[int]: world_size = os.environ.get('WORLD_SIZE') return int(world_size) if world_size is not None else world_size + def set_world_size(self, size: int) -> None: + log.debug("TorchElasticEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.") + + def global_rank(self) -> int: + return int(os.environ["RANK"]) + + def set_global_rank(self, rank: int) -> None: + log.debug( + "TorchElasticEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored." + ) + def local_rank(self) -> int: return int(os.environ['LOCAL_RANK']) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 5f411b65ae769..7e9624d9a0122 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -78,11 +78,11 @@ def __init__( self._ddp_kwargs = kwargs self._has_spawned_children = False self.task_idx = None - self.node_rank = 0 self.num_processes = len(parallel_devices) if parallel_devices is not None else parallel_devices self._ddp_comm_state = ddp_comm_state self._ddp_comm_hook = ddp_comm_hook self._ddp_comm_wrapper = ddp_comm_wrapper + self.set_world_ranks() @property def root_device(self): @@ -193,7 +193,7 @@ def setup_distributed(self): # set up server using proc 0's ip address # try to init for 20 times at max in case ports are taken # where to store ip_table - self.init_ddp_connection(self.global_rank, self.world_size) + self.init_ddp_connection() # on world_size=0 let everyone know training is starting if self.is_global_zero and not torch.distributed.is_initialized(): @@ -213,11 +213,11 @@ def _check_can_spawn_children(self): " This is not supported in DDP mode, switch to `distributed_backend='ddp_spawn'` instead." 
diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py
index 5f411b65ae769..7e9624d9a0122 100644
--- a/pytorch_lightning/plugins/training_type/ddp.py
+++ b/pytorch_lightning/plugins/training_type/ddp.py
@@ -78,11 +78,11 @@ def __init__(
         self._ddp_kwargs = kwargs
         self._has_spawned_children = False
         self.task_idx = None
-        self.node_rank = 0
         self.num_processes = len(parallel_devices) if parallel_devices is not None else parallel_devices
         self._ddp_comm_state = ddp_comm_state
         self._ddp_comm_hook = ddp_comm_hook
         self._ddp_comm_wrapper = ddp_comm_wrapper
+        self.set_world_ranks()
 
     @property
     def root_device(self):
@@ -193,7 +193,7 @@ def setup_distributed(self):
         # set up server using proc 0's ip address
         # try to init for 20 times at max in case ports are taken
         # where to store ip_table
-        self.init_ddp_connection(self.global_rank, self.world_size)
+        self.init_ddp_connection()
 
         # on world_size=0 let everyone know training is starting
         if self.is_global_zero and not torch.distributed.is_initialized():
@@ -213,11 +213,11 @@ def _check_can_spawn_children(self):
                 " This is not supported in DDP mode, switch to `distributed_backend='ddp_spawn'` instead."
             )
 
-    def set_world_ranks(self):
-        self.local_rank = self.task_idx
-        self.node_rank = self.cluster_environment.node_rank()
-        self.global_rank = self.node_rank * self.num_processes + self.local_rank
-        self.world_size = self.num_nodes * self.num_processes
+    def set_world_ranks(self) -> None:
+        if self.cluster_environment is not None:
+            self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
+            self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
+            rank_zero_only.rank = self.cluster_environment.global_rank()
 
     def pre_configure_ddp(self):
         # if unset, default `find_unused_parameters` `True`
@@ -260,11 +260,11 @@ def determine_ddp_device_ids(self):
             return None
         return [self.root_device.index]
 
-    def init_ddp_connection(self, global_rank: int, world_size: int) -> None:
-        os.environ["MASTER_ADDR"] = str(self.cluster_environment.master_address())
+    def init_ddp_connection(self, global_rank: Optional[int] = None, world_size: Optional[int] = None) -> None:
+        global_rank = global_rank if global_rank is not None else self.cluster_environment.global_rank()
+        world_size = world_size if world_size is not None else self.cluster_environment.world_size()
+        os.environ["MASTER_ADDR"] = self.cluster_environment.master_address()
         os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
-        os.environ["WORLD_SIZE"] = str(self.cluster_environment.world_size())
-
         if not torch.distributed.is_initialized():
             log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
             torch_distrib.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size)
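`set_world_ranks()` no longer stores the ranks on the plugin; it derives them and pushes them into the cluster environment, which the `global_rank`/`world_size` properties added to `ParallelPlugin` below then read back. The arithmetic itself is unchanged; a short worked example with hypothetical numbers:

```python
# Worked example of the rank arithmetic in DDPPlugin.set_world_ranks()
# for a hypothetical setup with 2 nodes and 4 processes (GPUs) per node.
num_nodes = 2
num_processes = 4

world_size = num_nodes * num_processes  # 8, pushed via set_world_size()
for node_rank in range(num_nodes):
    for local_rank in range(num_processes):
        global_rank = node_rank * num_processes + local_rank  # pushed via set_global_rank()
        print(f"node {node_rank}, local {local_rank} -> global {global_rank}/{world_size}")

# e.g. node 1, local 1 -> global 5/8, which is what the cluster environment
# reports through global_rank() and world_size() right after construction.
```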
""" @@ -95,12 +100,12 @@ def setup(self, model): smp = mp.get_context("spawn") self.mp_queue = smp.SimpleQueue() - def set_world_ranks(self, process_idx): - self.local_rank = process_idx - self.node_rank = self.cluster_environment.node_rank() - self.task_idx = self.cluster_environment.local_rank() - self.global_rank = self.node_rank * self.num_processes + self.local_rank - self.world_size = self.num_nodes * self.num_processes + def set_world_ranks(self, process_idx: int = 0) -> None: + self._local_rank = process_idx + if self.cluster_environment is not None: + self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank) + self.cluster_environment.set_world_size(self.num_nodes * self.num_processes) + rank_zero_only.rank = self.cluster_environment.global_rank() @property def mp_spawn_kwargs(self): @@ -213,11 +218,12 @@ def configure_ddp(self): ) self._register_ddp_hooks() - def init_ddp_connection(self, global_rank: int, world_size: int) -> None: + def init_ddp_connection(self, global_rank: Optional[int], world_size: Optional[int]) -> None: # TODO: this code is duplicated in DDP and DDPSpawn, make this a function - os.environ["MASTER_ADDR"] = str(self.cluster_environment.master_address()) + global_rank = global_rank if global_rank is not None else self.cluster_environment.global_rank() + world_size = world_size if world_size is not None else self.cluster_environment.world_size() + os.environ["MASTER_ADDR"] = self.cluster_environment.master_address() os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) - os.environ["WORLD_SIZE"] = str(self.cluster_environment.world_size()) if not torch.distributed.is_initialized(): log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}") diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py index a8e42e0fa747a..131a134ca724d 100644 --- a/pytorch_lightning/plugins/training_type/dp.py +++ b/pytorch_lightning/plugins/training_type/dp.py @@ -27,6 +27,22 @@ class DataParallelPlugin(ParallelPlugin): def __init__(self, parallel_devices: Optional[List[torch.device]]): super().__init__(parallel_devices=parallel_devices, cluster_environment=None) + @property + def global_rank(self) -> int: + return 0 + + @property + def local_rank(self) -> int: + return 0 + + @property + def node_rank(self) -> int: + return 0 + + @property + def world_size(self) -> int: + return 1 + def setup(self, model): # model needs to be moved to the device before it is wrapped model.to(self.root_device) diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index 8d0add27cbb29..6c2e6f3dfb1df 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -31,6 +31,19 @@ class HorovodPlugin(ParallelPlugin): def __init__(self, parallel_devices: Optional[List[torch.device]] = None): super().__init__(parallel_devices=parallel_devices, cluster_environment=None) + rank_zero_only.rank = self.global_rank + + @property + def global_rank(self) -> int: + return hvd.rank() + + @property + def local_rank(self) -> int: + return hvd.local_rank() + + @property + def world_size(self) -> int: + return hvd.size() @property def root_device(self): @@ -38,17 +51,11 @@ def root_device(self): @property def distributed_sampler_kwargs(self): - distributed_sampler_kwargs = dict(num_replicas=hvd.size(), rank=hvd.rank()) + distributed_sampler_kwargs = 
diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py
index a8e42e0fa747a..131a134ca724d 100644
--- a/pytorch_lightning/plugins/training_type/dp.py
+++ b/pytorch_lightning/plugins/training_type/dp.py
@@ -27,6 +27,22 @@ class DataParallelPlugin(ParallelPlugin):
     def __init__(self, parallel_devices: Optional[List[torch.device]]):
         super().__init__(parallel_devices=parallel_devices, cluster_environment=None)
 
+    @property
+    def global_rank(self) -> int:
+        return 0
+
+    @property
+    def local_rank(self) -> int:
+        return 0
+
+    @property
+    def node_rank(self) -> int:
+        return 0
+
+    @property
+    def world_size(self) -> int:
+        return 1
+
     def setup(self, model):
         # model needs to be moved to the device before it is wrapped
         model.to(self.root_device)
diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py
index 8d0add27cbb29..6c2e6f3dfb1df 100644
--- a/pytorch_lightning/plugins/training_type/horovod.py
+++ b/pytorch_lightning/plugins/training_type/horovod.py
@@ -31,6 +31,19 @@ class HorovodPlugin(ParallelPlugin):
 
     def __init__(self, parallel_devices: Optional[List[torch.device]] = None):
         super().__init__(parallel_devices=parallel_devices, cluster_environment=None)
+        rank_zero_only.rank = self.global_rank
+
+    @property
+    def global_rank(self) -> int:
+        return hvd.rank()
+
+    @property
+    def local_rank(self) -> int:
+        return hvd.local_rank()
+
+    @property
+    def world_size(self) -> int:
+        return hvd.size()
 
     @property
     def root_device(self):
@@ -38,17 +51,11 @@ def root_device(self):
 
     @property
     def distributed_sampler_kwargs(self):
-        distributed_sampler_kwargs = dict(num_replicas=hvd.size(), rank=hvd.rank())
+        distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank)
         return distributed_sampler_kwargs
 
     def setup(self, model):
         self._model = model
-
-        self.global_rank = hvd.rank()
-        self.local_rank = hvd.local_rank()
-        self.world_size = hvd.size()
-        rank_zero_only.rank = self.global_rank
-
         self.model_to_device()
 
     def pre_dispatch(self):
@@ -63,14 +70,14 @@ def _unpack_lightning_optimizer(opt):
         # increased total batch size
         for optimizer in optimizers:
             for param_group in optimizer.param_groups:
-                param_group["lr"] *= hvd.size()
+                param_group["lr"] *= self.world_size
 
         # Horovod: adjust base LR used by schedulers to match scaled optimizer initial LR
         lr_schedulers = self.lightning_module.trainer.lr_schedulers
         for scheduler in lr_schedulers:
             scheduler = scheduler["scheduler"]
             if isinstance(scheduler, _LRScheduler):
-                scheduler.base_lrs = [lr * hvd.size() for lr in scheduler.base_lrs]
+                scheduler.base_lrs = [lr * self.world_size for lr in scheduler.base_lrs]
 
         # Horovod: broadcast parameters & optimizer state to ensure consistent initialization
         hvd.broadcast_parameters(self.lightning_module.state_dict(), root_rank=0)
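`DataParallelPlugin` and `HorovodPlugin` have no cluster environment, so their rank properties are either constants (DP runs in a single process) or direct reads from Horovod at access time. A minimal sketch of the DP case; the Horovod properties behave analogously once `hvd.init()` has run:

```python
import torch

from pytorch_lightning.plugins import DataParallelPlugin

# DP runs in a single process, so the rank information is constant.
plugin = DataParallelPlugin(parallel_devices=[torch.device("cpu")])
assert plugin.global_rank == 0
assert plugin.local_rank == 0
assert plugin.node_rank == 0
assert plugin.world_size == 1
```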
diff --git a/pytorch_lightning/plugins/training_type/parallel.py b/pytorch_lightning/plugins/training_type/parallel.py
index d9a8e70588c43..023bdcd0172ff 100644
--- a/pytorch_lightning/plugins/training_type/parallel.py
+++ b/pytorch_lightning/plugins/training_type/parallel.py
@@ -36,9 +36,6 @@ def __init__(
         super().__init__()
         self.parallel_devices = parallel_devices
         self.cluster_environment = cluster_environment
-        self.global_rank = 0
-        self.world_size = 1
-        self.local_rank = 0
 
     @property
     @abstractmethod
@@ -53,6 +50,22 @@ def on_gpu(self):
     def lightning_module(self):
         return unwrap_lightning_module(self._model)
 
+    @property
+    def global_rank(self) -> int:
+        return self.cluster_environment.global_rank() if self.cluster_environment is not None else 0
+
+    @property
+    def local_rank(self) -> int:
+        return self.cluster_environment.local_rank() if self.cluster_environment is not None else 0
+
+    @property
+    def node_rank(self) -> int:
+        return self.cluster_environment.node_rank() if self.cluster_environment is not None else 0
+
+    @property
+    def world_size(self) -> int:
+        return self.cluster_environment.world_size() if self.cluster_environment is not None else 1
+
     @property
     def is_global_zero(self) -> bool:
         return self.global_rank == 0
diff --git a/pytorch_lightning/plugins/training_type/rpc_sequential.py b/pytorch_lightning/plugins/training_type/rpc_sequential.py
index ba26fc9f58ec5..37b7ae994585b 100644
--- a/pytorch_lightning/plugins/training_type/rpc_sequential.py
+++ b/pytorch_lightning/plugins/training_type/rpc_sequential.py
@@ -100,8 +100,8 @@ def __init__(
 
     def init_ddp_connection(
         self,
-        global_rank: int,
-        world_size: int,
+        global_rank: Optional[int] = None,
+        world_size: Optional[int] = None,
     ) -> None:
         if self.lightning_module.trainer.amp_backend is not None:
             raise MisconfigurationException(
@@ -110,10 +110,10 @@ def init_ddp_connection(
 
         if self._skip_init_connections():
             return
-        super().init_ddp_connection(
-            global_rank=global_rank,
-            world_size=world_size,
-        )
+
+        global_rank = global_rank if global_rank is not None else self.cluster_environment.global_rank()
+        world_size = world_size if world_size is not None else self.cluster_environment.world_size()
+        super().init_ddp_connection(global_rank, world_size)
         super().init_rpc_connection(global_rank=global_rank, world_size=world_size)
         model = self.lightning_module
         self.gpus_per_model = self._infer_check_num_gpus()
diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index b072a29c7fbc6..73e4b071bd976 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -15,7 +15,7 @@
 import os
 import re
 import time
-from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING
+from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union
 
 import torch
 import torch.multiprocessing as mp
@@ -41,7 +41,6 @@
 if _OMEGACONF_AVAILABLE:
     from omegaconf import DictConfig, ListConfig, OmegaConf
 
-
 if TYPE_CHECKING:
     from torch.nn import Module
     from torch.utils.data import DataLoader
@@ -52,8 +51,21 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
     def __init__(self, parallel_devices: Optional[List[int]] = None, **kwargs: Dict[str, Any]) -> None:
         super().__init__(parallel_devices, num_nodes=1, cluster_environment=None, sync_batchnorm=False)
         self.tpu_local_core_rank = 0
+        self.tpu_global_core_rank = 0
         self.start_method = None
 
+    @property
+    def global_rank(self) -> int:
+        return self.tpu_local_core_rank
+
+    @property
+    def local_rank(self) -> int:
+        return self.tpu_local_core_rank
+
+    @property
+    def world_size(self) -> int:
+        return self.num_processes
+
     @staticmethod
     def _validate_dataloader(dataloaders: Union[List['DataLoader'], 'DataLoader']):
         if not isinstance(dataloaders, list):
@@ -115,11 +127,9 @@ def configure_ddp(self) -> None:
     def init_ddp_connection(self, global_rank: int, world_size: int) -> None:
         pass
 
-    def set_world_ranks(self, process_idx: int) -> None:
+    def set_world_ranks(self, process_idx: int = 0) -> None:
         self.tpu_local_core_rank = xm.get_local_ordinal()
         self.tpu_global_core_rank = xm.get_ordinal()
-        self.global_rank = self.tpu_local_core_rank
-        self.world_size = self.num_nodes * self.num_processes
 
     def new_process(self, process_idx: int, trainer, mp_queue) -> None:
         self.mp_queue = mp_queue
@@ -128,7 +138,7 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:
         if seed is not None:
             seed_everything(int(seed))
 
-        self.set_world_ranks(process_idx)
+        self.set_world_ranks()
 
         # set warning rank
         rank_zero_only.rank = self.global_rank
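With `ParallelPlugin` now delegating `global_rank`, `local_rank`, `node_rank` and `world_size` to its cluster environment, a freshly constructed plugin already reports correct ranks, which is the point of this PR. The sketch below mirrors the new integration test added at the end of this diff; the environment variable values are illustrative:

```python
import os
from unittest import mock

import torch

from pytorch_lightning.plugins import DDPPlugin
from pytorch_lightning.plugins.environments import LightningEnvironment

# Example: second process on the second node, with 2 nodes x 2 devices.
with mock.patch.dict(os.environ, {"LOCAL_RANK": "1", "NODE_RANK": "1"}):
    plugin = DDPPlugin(
        parallel_devices=[torch.device("cpu"), torch.device("cpu")],
        num_nodes=2,
        cluster_environment=LightningEnvironment(),
    )
    assert plugin.local_rank == 1
    assert plugin.node_rank == 1
    assert plugin.global_rank == 3  # node_rank * num_processes + local_rank
    assert plugin.world_size == 4   # num_nodes * num_processes
```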
diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index aa52ec1c40d82..1f086bbee8ca3 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -60,7 +60,7 @@
     DeviceType,
     DistributedType,
 )
-from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_warn
+from pytorch_lightning.utilities.distributed import rank_zero_deprecation, rank_zero_info, rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 if _HOROVOD_AVAILABLE:
@@ -302,8 +302,18 @@ def root_gpu(self) -> Optional[int]:
 
     @property
     def is_using_torchelastic(self) -> bool:
-        te_flags_passed = "WORLD_SIZE" in os.environ and ("GROUP_RANK" in os.environ or "NODE_RANK" in os.environ)
-        return te_flags_passed
+        """
+        .. deprecated:: v1.3
+            Will be removed in v1.5.0.
+
+        Returns:
+            ``True`` if the current process was launched using the torchelastic command.
+        """
+        rank_zero_deprecation(
+            "The property `AcceleratorConnector.is_using_torchelastic` was deprecated in v1.3"
+            " and will be removed in 1.5. Use `TorchElasticEnvironment.is_using_torchelastic()` instead.",
+        )
+        return TorchElasticEnvironment.is_using_torchelastic()
 
     def select_precision_plugin(self) -> PrecisionPlugin:
         # set precision type
@@ -358,7 +368,12 @@ def select_precision_plugin(self) -> PrecisionPlugin:
 
     def select_training_type_plugin(self) -> TrainingTypePlugin:
         if self.use_ddp2:
-            plugin = DDP2Plugin(parallel_devices=self.parallel_devices, cluster_environment=self.cluster_environment)
+            plugin = DDP2Plugin(
+                parallel_devices=self.parallel_devices,
+                num_nodes=self.num_nodes,
+                cluster_environment=self.cluster_environment,
+                sync_batchnorm=self.sync_batchnorm,
+            )
         elif self.use_ddp and self.use_deepspeed:
             plugin = DeepSpeedPlugin(
                 num_nodes=self.num_nodes,
@@ -367,11 +382,11 @@ def select_training_type_plugin(self) -> TrainingTypePlugin:
             )
         elif self.use_ddp:
             use_slurm_ddp = self.use_ddp and self.is_slurm_managing_tasks
-            use_torchelastic_ddp = self.use_ddp and self.is_using_torchelastic
+            use_torchelastic_ddp = self.use_ddp and TorchElasticEnvironment.is_using_torchelastic()
             use_ddp_spawn = self._distrib_type == DistributedType.DDP_SPAWN
             use_ddp_cpu_spawn = self.use_ddp and self.on_cpu
             use_tpu_spawn = self.on_tpu and self._distrib_type == DistributedType.TPU_SPAWN
-            use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and self.is_using_torchelastic
+            use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and TorchElasticEnvironment.is_using_torchelastic()
             use_ddp_cpu_slurm = use_ddp_cpu_spawn and self.is_slurm_managing_tasks
             use_ddp_sharded = self._distrib_type == DistributedType.DDP_SHARDED
             use_ddp_sharded_spawn = self._distrib_type == DistributedType.DDP_SHARDED_SPAWN
@@ -459,7 +474,7 @@ def select_cluster_environment(self) -> ClusterEnvironment:
             return self._cluster_environment
         if self.is_slurm_managing_tasks:
             env = SLURMEnvironment()
-        elif self.is_using_torchelastic:
+        elif TorchElasticEnvironment.is_using_torchelastic():
            env = TorchElasticEnvironment()
         else:
             env = LightningEnvironment()
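The selection logic above keeps its precedence order: SLURM-managed jobs win, then a torchelastic launch, and otherwise the default `LightningEnvironment` is used. A simplified sketch of that decision, not the exact implementation of `select_cluster_environment()`:

```python
from pytorch_lightning.plugins.environments import (
    LightningEnvironment,
    SLURMEnvironment,
    TorchElasticEnvironment,
)


def pick_environment(is_slurm_managing_tasks: bool):
    """Simplified sketch of the precedence in select_cluster_environment()."""
    if is_slurm_managing_tasks:
        return SLURMEnvironment()
    if TorchElasticEnvironment.is_using_torchelastic():
        return TorchElasticEnvironment()
    return LightningEnvironment()
```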
diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py
index 79a17df074e35..de927aa5fdd1a 100644
--- a/tests/accelerators/test_accelerator_connector.py
+++ b/tests/accelerators/test_accelerator_connector.py
@@ -95,7 +95,8 @@ def test_accelerator_choice_ddp_spawn(cuda_available_mock, device_count_mock):
         "SLURM_NTASKS": "2",
         "SLURM_JOB_NAME": "SOME_NAME",
         "SLURM_NODEID": "0",
-        "SLURM_LOCALID": "10"
+        "SLURM_PROCID": "1",
+        "SLURM_LOCALID": "1",
     }
 )
 @mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
@@ -109,8 +110,8 @@ def on_fit_start(self, trainer, pl_module):
             assert isinstance(trainer.accelerator, GPUAccelerator)
             assert isinstance(trainer.training_type_plugin, DDPPlugin)
             assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
-            assert trainer.training_type_plugin.cluster_environment.local_rank() == 10
-            assert trainer.training_type_plugin.task_idx == 10
+            assert trainer.training_type_plugin.cluster_environment.local_rank() == 1
+            assert trainer.training_type_plugin.task_idx == 1
             raise SystemExit()
 
     model = BoringModel()
@@ -125,15 +126,15 @@ def on_fit_start(self, trainer, pl_module):
         trainer.fit(model)
 
 
-@RunIf(min_gpus=1)
+@RunIf(min_gpus=2)
 @mock.patch.dict(
     os.environ, {
         "CUDA_VISIBLE_DEVICES": "0,1",
         "SLURM_NTASKS": "2",
         "SLURM_JOB_NAME": "SOME_NAME",
         "SLURM_NODEID": "0",
-        "LOCAL_RANK": "0",
-        "SLURM_LOCALID": "10"
+        "SLURM_PROCID": "1",
+        "SLURM_LOCALID": "1"
     }
 )
 @mock.patch('torch.cuda.device_count', return_value=2)
@@ -148,8 +149,8 @@ def on_fit_start(self, trainer, pl_module):
             assert isinstance(trainer.accelerator, GPUAccelerator)
             assert isinstance(trainer.training_type_plugin, DDP2Plugin)
             assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
-            assert trainer.training_type_plugin.cluster_environment.local_rank() == 10
-            assert trainer.training_type_plugin.task_idx == 10
+            assert trainer.training_type_plugin.cluster_environment.local_rank() == 1
+            assert trainer.training_type_plugin.task_idx == 1
             raise SystemExit()
 
     model = BoringModel()
@@ -165,7 +166,16 @@ def on_fit_start(self, trainer, pl_module):
 
 
 @RunIf(min_gpus=1)
-@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1", "WORLD_SIZE": "2", "LOCAL_RANK": "10", "NODE_RANK": "0"})
+@mock.patch.dict(
+    os.environ, {
+        "CUDA_VISIBLE_DEVICES": "0,1",
+        "WORLD_SIZE": "2",
+        "LOCAL_WORLD_SIZE": "2",
+        "RANK": "1",
+        "LOCAL_RANK": "1",
+        "GROUP_RANK": "0",
+    }
+)
 @mock.patch('torch.cuda.device_count', return_value=2)
 @mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
 def test_accelerator_choice_ddp_te(device_count_mock, setup_distributed_mock):
@@ -177,8 +187,8 @@ def on_fit_start(self, trainer, pl_module):
             assert isinstance(trainer.accelerator, GPUAccelerator)
             assert isinstance(trainer.training_type_plugin, DDPPlugin)
             assert isinstance(trainer.training_type_plugin.cluster_environment, TorchElasticEnvironment)
-            assert trainer.training_type_plugin.cluster_environment.local_rank() == 10
-            assert trainer.training_type_plugin.task_idx == 10
+            assert trainer.training_type_plugin.cluster_environment.local_rank() == 1
+            assert trainer.training_type_plugin.task_idx == 1
             raise SystemExit()
 
     model = BoringModel()
@@ -194,7 +204,16 @@ def on_fit_start(self, trainer, pl_module):
 
 
 @RunIf(min_gpus=1)
-@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1", "WORLD_SIZE": "2", "LOCAL_RANK": "10", "NODE_RANK": "0"})
+@mock.patch.dict(
+    os.environ, {
+        "CUDA_VISIBLE_DEVICES": "0,1",
+        "WORLD_SIZE": "2",
+        "LOCAL_WORLD_SIZE": "2",
+        "RANK": "1",
+        "LOCAL_RANK": "1",
+        "GROUP_RANK": "0",
+    }
+)
 @mock.patch('torch.cuda.device_count', return_value=2)
 @mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
 def test_accelerator_choice_ddp2_te(device_count_mock, setup_distributed_mock):
@@ -206,8 +225,8 @@ def on_fit_start(self, trainer, pl_module):
             assert isinstance(trainer.accelerator, GPUAccelerator)
             assert isinstance(trainer.training_type_plugin, DDP2Plugin)
             assert isinstance(trainer.training_type_plugin.cluster_environment, TorchElasticEnvironment)
-            assert trainer.training_type_plugin.cluster_environment.local_rank() == 10
-            assert trainer.training_type_plugin.task_idx == 10
+            assert trainer.training_type_plugin.cluster_environment.local_rank() == 1
+            assert trainer.training_type_plugin.task_idx == 1
             raise SystemExit()
 
     model = BoringModel()
@@ -222,11 +241,15 @@ def on_fit_start(self, trainer, pl_module):
         trainer.fit(model)
 
 
-@mock.patch.dict(os.environ, {
-    "WORLD_SIZE": "1",
-    "LOCAL_RANK": "10",
-    "NODE_RANK": "0",
-})
+@mock.patch.dict(
+    os.environ, {
+        "WORLD_SIZE": "2",
+        "LOCAL_WORLD_SIZE": "2",
+        "RANK": "1",
+        "LOCAL_RANK": "1",
+        "GROUP_RANK": "0",
+    }
+)
 @mock.patch('torch.cuda.device_count', return_value=0)
 @mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
 def test_accelerator_choice_ddp_cpu_te(device_count_mock, setup_distributed_mock):
@@ -238,8 +261,8 @@ def on_fit_start(self, trainer, pl_module):
            assert isinstance(trainer.accelerator, CPUAccelerator)
             assert isinstance(trainer.training_type_plugin, DDPPlugin)
             assert isinstance(trainer.training_type_plugin.cluster_environment, TorchElasticEnvironment)
-            assert trainer.training_type_plugin.cluster_environment.local_rank() == 10
-            assert trainer.training_type_plugin.task_idx == 10
+            assert trainer.training_type_plugin.cluster_environment.local_rank() == 1
+            assert trainer.training_type_plugin.task_idx == 1
             raise SystemExit()
 
     model = BoringModel()
@@ -260,7 +283,8 @@ def on_fit_start(self, trainer, pl_module):
         "SLURM_JOB_NAME": "SOME_NAME",
         "SLURM_NODEID": "0",
         "LOCAL_RANK": "0",
-        "SLURM_LOCALID": "0"
+        "SLURM_PROCID": "0",
+        "SLURM_LOCALID": "0",
     }
 )
 @mock.patch('torch.cuda.device_count', return_value=0)
@@ -296,7 +320,8 @@ def on_fit_start(self, trainer, pl_module):
         "SLURM_JOB_NAME": "SOME_NAME",
         "SLURM_NODEID": "0",
         "LOCAL_RANK": "0",
-        "SLURM_LOCALID": "0"
+        "SLURM_PROCID": "0",
+        "SLURM_LOCALID": "0",
     }
 )
 @mock.patch('torch.cuda.device_count', return_value=0)
@@ -378,7 +403,8 @@ class TrainTypePlugin(SingleDevicePlugin):
         "SLURM_JOB_NAME": "SOME_NAME",
         "SLURM_NODEID": "0",
         "LOCAL_RANK": "0",
-        "SLURM_LOCALID": "0"
+        "SLURM_PROCID": "0",
+        "SLURM_LOCALID": "0",
     }
 )
 @mock.patch('torch.cuda.device_count', return_value=0)
diff --git a/tests/models/test_amp.py b/tests/models/test_amp.py
index 0b9d6776c1aaa..13b2e90592fa7 100644
--- a/tests/models/test_amp.py
+++ b/tests/models/test_amp.py
@@ -147,7 +147,8 @@ def test_amp_multi_gpu_ddp_spawn(tmpdir):
         "SLURM_JOB_NAME": "SOME_NAME",
         "SLURM_NODEID": "0",
         "LOCAL_RANK": "0",
-        "SLURM_LOCALID": "0"
+        "SLURM_LOCALID": "0",
+        "SLURM_PROCID": "0",
     }
 )
 def test_amp_gpu_ddp_slurm_managed(tmpdir):
diff --git a/tests/plugins/environments/test_lightning_environment.py b/tests/plugins/environments/test_lightning_environment.py
index 83d26cb0fcf91..3f89b88bfc215 100644
--- a/tests/plugins/environments/test_lightning_environment.py
+++ b/tests/plugins/environments/test_lightning_environment.py
@@ -11,7 +11,7 @@ def test_default_attributes():
     assert not env.creates_children()
     assert env.master_address() == "127.0.0.1"
     assert isinstance(env.master_port(), int)
-    assert env.world_size() is None
+    assert env.world_size() == 1
     assert env.local_rank() == 0
     assert env.node_rank() == 0
 
@@ -27,9 +27,14 @@ def test_attributes_from_environment_variables():
     env = LightningEnvironment()
     assert env.master_address() == "1.2.3.4"
     assert env.master_port() == 500
-    assert env.world_size() is None
+    assert env.world_size() == 1
+    assert env.global_rank() == 0
     assert env.local_rank() == 2
     assert env.node_rank() == 3
+    env.set_global_rank(100)
+    assert env.global_rank() == 100
+    env.set_world_size(100)
+    assert env.world_size() == 100
 
 
 @mock.patch.dict(os.environ, {
diff --git a/tests/plugins/environments/test_slurm_environment.py b/tests/plugins/environments/test_slurm_environment.py
index 8e82434846e68..0be88dbeb91c6 100644
--- a/tests/plugins/environments/test_slurm_environment.py
+++ b/tests/plugins/environments/test_slurm_environment.py
@@ -1,3 +1,4 @@
+import logging
 import os
 from unittest import mock
 
@@ -13,7 +14,9 @@ def test_default_attributes():
     assert env.creates_children()
     assert env.master_address() == "127.0.0.1"
     assert env.master_port() == 12910
-    assert env.world_size() is None
+    with pytest.raises(KeyError):
+        # world size is required to be passed as env variable
+        env.world_size()
     with pytest.raises(KeyError):
         # local rank is required to be passed as env variable
         env.local_rank()
@@ -26,19 +29,33 @@ def test_default_attributes():
     os.environ, {
         "SLURM_NODELIST": "1.1.1.1, 1.1.1.2",
         "SLURM_JOB_ID": "0001234",
-        "WORLD_SIZE": "20",
+        "SLURM_NTASKS": "20",
         "SLURM_LOCALID": "2",
+        "SLURM_PROCID": "1",
         "SLURM_NODEID": "3",
     }
 )
-def test_attributes_from_environment_variables():
+def test_attributes_from_environment_variables(caplog):
     """ Test that the SLURM cluster environment takes the attributes from the environment variables. """
     env = SLURMEnvironment()
     assert env.master_address() == "1.1.1.1"
     assert env.master_port() == 15000 + 1234
-    assert env.world_size() is None
+    assert env.world_size() == 20
+    assert env.global_rank() == 1
     assert env.local_rank() == 2
     assert env.node_rank() == 3
+    # setter should be no-op
+    with caplog.at_level(logging.DEBUG, logger="pytorch_lightning.plugins.environments"):
+        env.set_global_rank(100)
+    assert env.global_rank() == 1
+    assert "setting global rank is not allowed" in caplog.text
+
+    caplog.clear()
+
+    with caplog.at_level(logging.DEBUG, logger="pytorch_lightning.plugins.environments"):
+        env.set_world_size(100)
+    assert env.world_size() == 20
+    assert "setting world size is not allowed" in caplog.text
 
 
 @pytest.mark.parametrize(
""" env = TorchElasticEnvironment() assert env.master_address() == "1.2.3.4" assert env.master_port() == 500 assert env.world_size() == 20 + assert env.global_rank() == 1 assert env.local_rank() == 2 assert env.node_rank() == 3 + # setter should be no-op + with caplog.at_level(logging.DEBUG, logger="pytorch_lightning.plugins.environments"): + env.set_global_rank(100) + assert env.global_rank() == 1 + assert "setting global rank is not allowed" in caplog.text + + caplog.clear() + + with caplog.at_level(logging.DEBUG, logger="pytorch_lightning.plugins.environments"): + env.set_world_size(100) + assert env.world_size() == 20 + assert "setting world size is not allowed" in caplog.text diff --git a/tests/plugins/test_amp_plugins.py b/tests/plugins/test_amp_plugins.py index fc3cd54327288..328cb0a59f08e 100644 --- a/tests/plugins/test_amp_plugins.py +++ b/tests/plugins/test_amp_plugins.py @@ -26,6 +26,7 @@ class MyApexPlugin(ApexMixedPrecisionPlugin): "SLURM_JOB_NAME": "SOME_NAME", "SLURM_NODEID": "0", "LOCAL_RANK": "0", + "SLURM_PROCID": "0", "SLURM_LOCALID": "0", } ) diff --git a/tests/plugins/test_cluster_integration.py b/tests/plugins/test_cluster_integration.py new file mode 100644 index 0000000000000..032276dd674d0 --- /dev/null +++ b/tests/plugins/test_cluster_integration.py @@ -0,0 +1,114 @@ +import os +from unittest import mock + +import pytest +import torch + +from pytorch_lightning import Trainer +from pytorch_lightning.plugins import DDP2Plugin, DDPPlugin, DDPShardedPlugin, DeepSpeedPlugin, RPCSequentialPlugin +from pytorch_lightning.plugins.environments import LightningEnvironment, SLURMEnvironment, TorchElasticEnvironment +from pytorch_lightning.utilities import rank_zero_only +from tests.helpers.runif import RunIf + + +def environment_combinations(): + expected = dict(global_rank=3, local_rank=1, node_rank=1, world_size=4) + # Lightning + variables = { + "CUDA_VISIBLE_DEVICES": "0,1,2,4", + "LOCAL_RANK": "1", + "NODE_RANK": "1", + "WORLD_SIZE": "8", + } + environment = LightningEnvironment() + yield environment, variables, expected + # SLURM + variables = { + "CUDA_VISIBLE_DEVICES": "0,1,2,4", + "SLURM_JOB_NAME": "SOME_NAME", + "SLURM_LOCALID": "1", + "SLURM_NODEID": "1", + "SLURM_PROCID": "3", + "SLURM_NTASKS": "4", + } + environment = SLURMEnvironment() + yield environment, variables, expected + # TorchElastic + variables = { + "CUDA_VISIBLE_DEVICES": "0,1,2,4", + "LOCAL_RANK": "1", + "GROUP_RANK": "1", + "RANK": "3", + "WORLD_SIZE": "4", + "LOCAL_WORLD_SIZE": "2", + } + environment = TorchElasticEnvironment() + yield environment, variables, expected + + +@pytest.mark.parametrize( + "plugin_cls", [ + DDPPlugin, + DDPShardedPlugin, + DDP2Plugin, + pytest.param(DeepSpeedPlugin, marks=RunIf(deepspeed=True)), + pytest.param(RPCSequentialPlugin, marks=RunIf(fairscale_pipe=True)), + ] +) +def test_ranks_available_manual_plugin_selection(plugin_cls): + """ Test that the rank information is readily available after Trainer initialization. 
""" + num_nodes = 2 + for cluster, variables, expected in environment_combinations(): + + if plugin_cls == DDP2Plugin: + expected.update(global_rank=expected["node_rank"], world_size=num_nodes) + + with mock.patch.dict(os.environ, variables): + plugin = plugin_cls( + parallel_devices=[torch.device("cuda", 1), torch.device("cuda", 2)], + num_nodes=num_nodes, + cluster_environment=cluster, + ) + trainer = Trainer(plugins=[plugin]) + assert rank_zero_only.rank == expected["global_rank"] + assert trainer.global_rank == expected["global_rank"] + assert trainer.local_rank == expected["local_rank"] + assert trainer.node_rank == expected["node_rank"] + assert trainer.world_size == expected["world_size"] + + +@pytest.mark.parametrize( + "trainer_kwargs", [ + dict(accelerator="ddp", gpus=[1, 2]), + dict(accelerator="ddp_sharded", gpus=[1, 2]), + dict(accelerator="ddp2", gpus=[1, 2]), + dict(accelerator="ddp_cpu", num_processes=2), + dict(accelerator="ddp_spawn", gpus=[1, 2]), + ] +) +@mock.patch("torch.cuda.is_available", return_value=True) +@mock.patch("torch.cuda.device_count", return_value=4) +def test_ranks_available_automatic_plugin_selection(mock0, mock1, trainer_kwargs): + """ Test that the rank information is readily available after Trainer initialization. """ + num_nodes = 2 + trainer_kwargs.update(num_nodes=num_nodes) + + for cluster, variables, expected in environment_combinations(): + + if trainer_kwargs["accelerator"] == "ddp2": + expected.update(global_rank=expected["node_rank"], world_size=num_nodes) + if trainer_kwargs["accelerator"] in ("ddp_cpu", "ddp_spawn"): + if isinstance(cluster, (SLURMEnvironment, TorchElasticEnvironment)): + # slurm and torchelastic do not work with spawn plugins + continue + # when using spawn, we don't reach rank > 0 until we call Trainer.fit() + expected.update(global_rank=(expected["node_rank"] * 2), local_rank=0) + + with mock.patch.dict(os.environ, variables): + trainer = Trainer(**trainer_kwargs) + assert type(trainer.training_type_plugin.cluster_environment) == type(cluster) + assert rank_zero_only.rank == expected["global_rank"] + assert trainer.global_rank == expected["global_rank"] + assert trainer.local_rank == expected["local_rank"] + assert trainer.node_rank == expected["node_rank"] + assert trainer.world_size == expected["world_size"] diff --git a/tests/plugins/test_rpc_plugin.py b/tests/plugins/test_rpc_plugin.py index 9ecc93a9b5055..7abf9fcbd5039 100644 --- a/tests/plugins/test_rpc_plugin.py +++ b/tests/plugins/test_rpc_plugin.py @@ -19,6 +19,7 @@ "SLURM_JOB_NAME": "SOME_NAME", "SLURM_NODEID": "0", "LOCAL_RANK": "0", + "SLURM_PROCID": "0", "SLURM_LOCALID": "0", }, ) diff --git a/tests/plugins/test_tpu_spawn.py b/tests/plugins/test_tpu_spawn.py index bb587827c3a3f..53baa8e54461b 100644 --- a/tests/plugins/test_tpu_spawn.py +++ b/tests/plugins/test_tpu_spawn.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
diff --git a/tests/plugins/test_rpc_plugin.py b/tests/plugins/test_rpc_plugin.py
index 9ecc93a9b5055..7abf9fcbd5039 100644
--- a/tests/plugins/test_rpc_plugin.py
+++ b/tests/plugins/test_rpc_plugin.py
@@ -19,6 +19,7 @@
         "SLURM_JOB_NAME": "SOME_NAME",
         "SLURM_NODEID": "0",
         "LOCAL_RANK": "0",
+        "SLURM_PROCID": "0",
         "SLURM_LOCALID": "0",
     },
 )
diff --git a/tests/plugins/test_tpu_spawn.py b/tests/plugins/test_tpu_spawn.py
index bb587827c3a3f..53baa8e54461b 100644
--- a/tests/plugins/test_tpu_spawn.py
+++ b/tests/plugins/test_tpu_spawn.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from unittest import mock
 from unittest.mock import MagicMock
 
 import pytest
@@ -51,8 +52,9 @@ def predict_dataloader(self):
         (None, [_loader, _loader_no_len], None, None),
     ],
 )
+@mock.patch("pytorch_lightning.plugins.training_type.tpu_spawn.xm")
 def test_error_patched_iterable_dataloaders(
-    tmpdir, train_dataloader, val_dataloaders, test_dataloaders, predict_dataloaders
+    _, tmpdir, train_dataloader, val_dataloaders, test_dataloaders, predict_dataloaders
 ):
     model = BoringModelNoDataloaders()
     connector = DataConnector(MagicMock())
@@ -69,6 +71,7 @@ def test_error_patched_iterable_dataloaders(
     TPUSpawnPlugin(MagicMock()).connect(model)
 
 
-def test_error_process_iterable_dataloader(tmpdir):
+@mock.patch("pytorch_lightning.plugins.training_type.tpu_spawn.xm")
+def test_error_process_iterable_dataloader(_):
     with pytest.raises(MisconfigurationException, match="TPUs do not currently support"):
         TPUSpawnPlugin(MagicMock()).process_dataloader(_loader_no_len)