Clean up environment access in plugins (#6941)
Co-authored-by: ananthsub <[email protected]>
Co-authored-by: Kaushik B <[email protected]>
3 people authored Apr 13, 2021
1 parent 89074fa commit 33cc9fe
Showing 23 changed files with 403 additions and 100 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -261,6 +261,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `self.device` not returning the correct device in replicas of data-parallel ([#6414](https://github.com/PyTorchLightning/pytorch-lightning/pull/6414))


- Fixed process rank not being available right away after `Trainer` instantiation ([#6941](https://github.com/PyTorchLightning/pytorch-lightning/pull/6941))


## [1.2.7] - 2021-04-06

### Fixed
15 changes: 13 additions & 2 deletions pytorch_lightning/plugins/environments/cluster_environment.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Optional


class ClusterEnvironment(ABC):
@@ -31,9 +30,21 @@ def master_port(self) -> int:
""" An open and configured port in the master node through which all processes communicate. """

@abstractmethod
def world_size(self) -> Optional[int]:
def world_size(self) -> int:
""" The number of processes across all devices and nodes. """

@abstractmethod
def set_world_size(self, size: int) -> None:
pass

@abstractmethod
def global_rank(self) -> int:
""" The rank (index) of the currently running process across all nodes and devices. """

@abstractmethod
def set_global_rank(self, rank: int) -> None:
pass

@abstractmethod
def local_rank(self) -> int:
""" The rank (index) of the currently running process inside of the current node. """
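Not part of the commit, but for context: a minimal sketch of a custom environment implementing the interface above. The class name and the MY_* variables are made up for illustration, and the abstract methods not visible in this hunk (creates_children, master_address, node_rank) are assumed to keep the signatures used elsewhere in this diff.

import os

from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment


class MyClusterEnvironment(ClusterEnvironment):
    """Hypothetical environment that reads everything from launcher-provided variables."""

    def creates_children(self) -> bool:
        # the external launcher spawns the processes, not Lightning
        return True

    def master_address(self) -> str:
        return os.environ["MY_MASTER_ADDR"]

    def master_port(self) -> int:
        return int(os.environ["MY_MASTER_PORT"])

    def world_size(self) -> int:
        return int(os.environ["MY_WORLD_SIZE"])

    def set_world_size(self, size: int) -> None:
        pass  # the launcher owns the world size; ignore attempts to overwrite it

    def global_rank(self) -> int:
        return int(os.environ["MY_RANK"])

    def set_global_rank(self, rank: int) -> None:
        pass  # likewise, the rank is fixed by the launcher

    def local_rank(self) -> int:
        return int(os.environ["MY_LOCAL_RANK"])

    def node_rank(self) -> int:
        return int(os.environ["MY_NODE_RANK"])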
18 changes: 15 additions & 3 deletions pytorch_lightning/plugins/environments/lightning_environment.py
@@ -14,9 +14,9 @@

import os
import socket
from typing import Optional

from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.utilities import rank_zero_only


class LightningEnvironment(ClusterEnvironment):
@@ -34,6 +34,8 @@ class LightningEnvironment(ClusterEnvironment):
def __init__(self):
super().__init__()
self._master_port = None
self._global_rank: int = 0
self._world_size: int = 1

def creates_children(self) -> bool:
return False
@@ -46,8 +48,18 @@ def master_port(self) -> int:
self._master_port = os.environ.get("MASTER_PORT", find_free_network_port())
return int(self._master_port)

def world_size(self) -> Optional[int]:
return None
def world_size(self) -> int:
return self._world_size

def set_world_size(self, size: int) -> None:
self._world_size = size

def global_rank(self) -> int:
return self._global_rank

def set_global_rank(self, rank: int) -> None:
self._global_rank = rank
rank_zero_only.rank = rank

def local_rank(self) -> int:
return int(os.environ.get("LOCAL_RANK", 0))
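Illustrative only, not from the diff: with these defaults, a fresh LightningEnvironment reports a single-process setup, and the training type plugin is expected to push the real values in; set_global_rank also keeps rank_zero_only.rank in sync. The numbers below are arbitrary.

from pytorch_lightning.plugins.environments.lightning_environment import LightningEnvironment
from pytorch_lightning.utilities import rank_zero_only

env = LightningEnvironment()
assert env.world_size() == 1 and env.global_rank() == 0  # defaults before the plugin configures them

# a training type plugin pushes the computed ranks into the environment
env.set_world_size(4)
env.set_global_rank(2)

assert env.world_size() == 4 and env.global_rank() == 2
assert rank_zero_only.rank == 2  # updated as a side effect of set_global_rank()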
16 changes: 11 additions & 5 deletions pytorch_lightning/plugins/environments/slurm_environment.py
@@ -23,9 +23,6 @@

class SLURMEnvironment(ClusterEnvironment):

def __init__(self):
super().__init__()

def creates_children(self) -> bool:
return True

@@ -69,8 +66,17 @@ def master_port(self) -> int:

return int(default_port)

def world_size(self):
return None
def world_size(self) -> int:
return int(os.environ["SLURM_NTASKS"])

def set_world_size(self, size: int) -> None:
log.debug("SLURMEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")

def global_rank(self) -> int:
return int(os.environ["SLURM_PROCID"])

def set_global_rank(self, rank: int) -> None:
log.debug("SLURMEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.")

def local_rank(self) -> int:
return int(os.environ['SLURM_LOCALID'])
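A quick sketch of how the SLURM variables map onto the new interface (not part of the commit; the values are ones a SLURM launcher would normally export):

import os

from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment

# pretend we are task 3 of 8, the second task on this node
os.environ.update({"SLURM_NTASKS": "8", "SLURM_PROCID": "3", "SLURM_LOCALID": "1"})

env = SLURMEnvironment()
assert env.world_size() == 8
assert env.global_rank() == 3
assert env.local_rank() == 1

env.set_global_rank(0)  # ignored: SLURM owns the rank, only a debug message is logged
assert env.global_rank() == 3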
18 changes: 16 additions & 2 deletions pytorch_lightning/plugins/environments/torchelastic_environment.py
@@ -24,8 +24,11 @@

class TorchElasticEnvironment(ClusterEnvironment):

def __init__(self):
super().__init__()
@staticmethod
def is_using_torchelastic() -> bool:
""" Returns ``True`` if the current process was launched using the torchelastic command. """
required_env_vars = ("RANK", "GROUP_RANK", "LOCAL_RANK", "LOCAL_WORLD_SIZE")
return all(v in os.environ for v in required_env_vars)

def creates_children(self) -> bool:
return True
@@ -51,6 +54,17 @@ def world_size(self) -> Optional[int]:
world_size = os.environ.get('WORLD_SIZE')
return int(world_size) if world_size is not None else world_size

def set_world_size(self, size: int) -> None:
log.debug("TorchElasticEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")

def global_rank(self) -> int:
return int(os.environ["RANK"])

def set_global_rank(self, rank: int) -> None:
log.debug(
"TorchElasticEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored."
)

def local_rank(self) -> int:
return int(os.environ['LOCAL_RANK'])

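Similarly, a sketch (not from the diff) of how torchelastic-launched processes are detected and read; the values mimic what torchelastic exports for the second worker on a node:

import os

from pytorch_lightning.plugins.environments.torchelastic_environment import TorchElasticEnvironment

# torchelastic exports these (among others) for every worker it launches
os.environ.update({"RANK": "1", "GROUP_RANK": "0", "LOCAL_RANK": "1", "LOCAL_WORLD_SIZE": "2"})

assert TorchElasticEnvironment.is_using_torchelastic()

env = TorchElasticEnvironment()
assert env.global_rank() == 1
assert env.local_rank() == 1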
22 changes: 11 additions & 11 deletions pytorch_lightning/plugins/training_type/ddp.py
@@ -78,11 +78,11 @@ def __init__(
self._ddp_kwargs = kwargs
self._has_spawned_children = False
self.task_idx = None
self.node_rank = 0
self.num_processes = len(parallel_devices) if parallel_devices is not None else parallel_devices
self._ddp_comm_state = ddp_comm_state
self._ddp_comm_hook = ddp_comm_hook
self._ddp_comm_wrapper = ddp_comm_wrapper
self.set_world_ranks()

@property
def root_device(self):
@@ -193,7 +193,7 @@ def setup_distributed(self):
# set up server using proc 0's ip address
# try to init for 20 times at max in case ports are taken
# where to store ip_table
self.init_ddp_connection(self.global_rank, self.world_size)
self.init_ddp_connection()

# on world_size=0 let everyone know training is starting
if self.is_global_zero and not torch.distributed.is_initialized():
@@ -213,11 +213,11 @@ def _check_can_spawn_children(self):
" This is not supported in DDP mode, switch to `distributed_backend='ddp_spawn'` instead."
)

def set_world_ranks(self):
self.local_rank = self.task_idx
self.node_rank = self.cluster_environment.node_rank()
self.global_rank = self.node_rank * self.num_processes + self.local_rank
self.world_size = self.num_nodes * self.num_processes
def set_world_ranks(self) -> None:
if self.cluster_environment is not None:
self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
rank_zero_only.rank = self.cluster_environment.global_rank()

def pre_configure_ddp(self):
# if unset, default `find_unused_parameters` `True`
@@ -260,11 +260,11 @@ def determine_ddp_device_ids(self):
return None
return [self.root_device.index]

def init_ddp_connection(self, global_rank: int, world_size: int) -> None:
os.environ["MASTER_ADDR"] = str(self.cluster_environment.master_address())
def init_ddp_connection(self, global_rank: Optional[int] = None, world_size: Optional[int] = None) -> None:
global_rank = global_rank if global_rank is not None else self.cluster_environment.global_rank()
world_size = world_size if world_size is not None else self.cluster_environment.world_size()
os.environ["MASTER_ADDR"] = self.cluster_environment.master_address()
os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
os.environ["WORLD_SIZE"] = str(self.cluster_environment.world_size())

if not torch.distributed.is_initialized():
log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
torch_distrib.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size)
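The rank bookkeeping that set_world_ranks() now pushes into the cluster environment reduces to simple arithmetic; a standalone sketch (the helper name is made up):

def ddp_ranks(node_rank: int, local_rank: int, num_nodes: int, num_processes: int) -> tuple:
    """Rank math used by DDPPlugin.set_world_ranks(): one process per device on every node."""
    global_rank = node_rank * num_processes + local_rank
    world_size = num_nodes * num_processes
    return global_rank, world_size


# e.g. 2 nodes x 4 processes: local_rank 2 on node 1 becomes global rank 6 out of 8
assert ddp_ranks(node_rank=1, local_rank=2, num_nodes=2, num_processes=4) == (6, 8)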
14 changes: 10 additions & 4 deletions pytorch_lightning/plugins/training_type/ddp2.py
@@ -19,6 +19,14 @@

class DDP2Plugin(DDPPlugin):

@property
def global_rank(self) -> int:
return self.node_rank

@property
def world_size(self) -> int:
return self.num_nodes

def setup(self, model):
self._model = model
# set the task idx
@@ -64,7 +72,5 @@ def _is_single_process_single_device(self) -> bool:
return False

def set_world_ranks(self):
self.local_rank = self.task_idx
self.node_rank = self.cluster_environment.node_rank()
self.global_rank = self.node_rank
self.world_size = self.num_nodes
self.cluster_environment.set_global_rank(self.node_rank)
self.cluster_environment.set_world_size(self.num_nodes)
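For comparison, an illustrative one-liner (again a made-up helper): DDP2 runs a single process per node, so its ranks collapse to the node-level values.

def ddp2_ranks(node_rank: int, num_nodes: int) -> tuple:
    """Rank math used by DDP2Plugin: global_rank == node_rank, world_size == num_nodes."""
    return node_rank, num_nodes


assert ddp2_ranks(node_rank=1, num_nodes=2) == (1, 2)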
26 changes: 16 additions & 10 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -61,11 +61,16 @@ def __init__(
self._ddp_kwargs = kwargs
self.dist = LightningDistributed()
self.num_processes = len(parallel_devices) if parallel_devices is not None else 0
self.node_rank = 0
self.mp_queue = None
self._ddp_comm_state = ddp_comm_state
self._ddp_comm_hook = ddp_comm_hook
self._ddp_comm_wrapper = ddp_comm_wrapper
self._local_rank = 0
self.set_world_ranks()

@property
def local_rank(self) -> int:
return self._local_rank

def __getstate__(self):
""" Makes this plugin pickleable without destroying the queue in the current process. """
@@ -95,12 +100,12 @@ def setup(self, model):
smp = mp.get_context("spawn")
self.mp_queue = smp.SimpleQueue()

def set_world_ranks(self, process_idx):
self.local_rank = process_idx
self.node_rank = self.cluster_environment.node_rank()
self.task_idx = self.cluster_environment.local_rank()
self.global_rank = self.node_rank * self.num_processes + self.local_rank
self.world_size = self.num_nodes * self.num_processes
def set_world_ranks(self, process_idx: int = 0) -> None:
self._local_rank = process_idx
if self.cluster_environment is not None:
self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
rank_zero_only.rank = self.cluster_environment.global_rank()

@property
def mp_spawn_kwargs(self):
@@ -213,11 +218,12 @@ def configure_ddp(self):
)
self._register_ddp_hooks()

def init_ddp_connection(self, global_rank: int, world_size: int) -> None:
def init_ddp_connection(self, global_rank: Optional[int], world_size: Optional[int]) -> None:
# TODO: this code is duplicated in DDP and DDPSpawn, make this a function
os.environ["MASTER_ADDR"] = str(self.cluster_environment.master_address())
global_rank = global_rank if global_rank is not None else self.cluster_environment.global_rank()
world_size = world_size if world_size is not None else self.cluster_environment.world_size()
os.environ["MASTER_ADDR"] = self.cluster_environment.master_address()
os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
os.environ["WORLD_SIZE"] = str(self.cluster_environment.world_size())

if not torch.distributed.is_initialized():
log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
16 changes: 16 additions & 0 deletions pytorch_lightning/plugins/training_type/dp.py
@@ -27,6 +27,22 @@ class DataParallelPlugin(ParallelPlugin):
def __init__(self, parallel_devices: Optional[List[torch.device]]):
super().__init__(parallel_devices=parallel_devices, cluster_environment=None)

@property
def global_rank(self) -> int:
return 0

@property
def local_rank(self) -> int:
return 0

@property
def node_rank(self) -> int:
return 0

@property
def world_size(self) -> int:
return 1

def setup(self, model):
# model needs to be moved to the device before it is wrapped
model.to(self.root_device)
25 changes: 16 additions & 9 deletions pytorch_lightning/plugins/training_type/horovod.py
@@ -31,24 +31,31 @@ class HorovodPlugin(ParallelPlugin):

def __init__(self, parallel_devices: Optional[List[torch.device]] = None):
super().__init__(parallel_devices=parallel_devices, cluster_environment=None)
rank_zero_only.rank = self.global_rank

@property
def global_rank(self) -> int:
return hvd.rank()

@property
def local_rank(self) -> int:
return hvd.local_rank()

@property
def world_size(self) -> int:
return hvd.size()

@property
def root_device(self):
return self.parallel_devices[self.local_rank]

@property
def distributed_sampler_kwargs(self):
distributed_sampler_kwargs = dict(num_replicas=hvd.size(), rank=hvd.rank())
distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank)
return distributed_sampler_kwargs

def setup(self, model):
self._model = model

self.global_rank = hvd.rank()
self.local_rank = hvd.local_rank()
self.world_size = hvd.size()
rank_zero_only.rank = self.global_rank

self.model_to_device()

def pre_dispatch(self):
@@ -63,14 +70,14 @@ def _unpack_lightning_optimizer(opt):
# increased total batch size
for optimizer in optimizers:
for param_group in optimizer.param_groups:
param_group["lr"] *= hvd.size()
param_group["lr"] *= self.world_size

# Horovod: adjust base LR used by schedulers to match scaled optimizer initial LR
lr_schedulers = self.lightning_module.trainer.lr_schedulers
for scheduler in lr_schedulers:
scheduler = scheduler["scheduler"]
if isinstance(scheduler, _LRScheduler):
scheduler.base_lrs = [lr * hvd.size() for lr in scheduler.base_lrs]
scheduler.base_lrs = [lr * self.world_size for lr in scheduler.base_lrs]

# Horovod: broadcast parameters & optimizer state to ensure consistent initialization
hvd.broadcast_parameters(self.lightning_module.state_dict(), root_rank=0)
19 changes: 16 additions & 3 deletions pytorch_lightning/plugins/training_type/parallel.py
@@ -36,9 +36,6 @@
super().__init__()
self.parallel_devices = parallel_devices
self.cluster_environment = cluster_environment
self.global_rank = 0
self.world_size = 1
self.local_rank = 0

@property
@abstractmethod
@@ -53,6 +50,22 @@ def on_gpu(self):
def lightning_module(self):
return unwrap_lightning_module(self._model)

@property
def global_rank(self) -> int:
return self.cluster_environment.global_rank() if self.cluster_environment is not None else 0

@property
def local_rank(self) -> int:
return self.cluster_environment.local_rank() if self.cluster_environment is not None else 0

@property
def node_rank(self) -> int:
return self.cluster_environment.node_rank() if self.cluster_environment is not None else 0

@property
def world_size(self) -> int:
return self.cluster_environment.world_size() if self.cluster_environment is not None else 1

@property
def is_global_zero(self) -> bool:
return self.global_rank == 0
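A rough sketch of the delegation pattern these properties establish (a simplified stand-in, not the real ParallelPlugin): ranks come from the attached cluster environment when there is one, and fall back to single-process defaults otherwise.

from typing import Optional

from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.environments.lightning_environment import LightningEnvironment


class RankDelegatingPlugin:
    """Illustrative stand-in for ParallelPlugin's new read-only rank properties."""

    def __init__(self, cluster_environment: Optional[ClusterEnvironment] = None):
        self.cluster_environment = cluster_environment

    @property
    def global_rank(self) -> int:
        return self.cluster_environment.global_rank() if self.cluster_environment is not None else 0

    @property
    def world_size(self) -> int:
        return self.cluster_environment.world_size() if self.cluster_environment is not None else 1


single = RankDelegatingPlugin()  # no environment attached, e.g. the DataParallel case
assert single.global_rank == 0 and single.world_size == 1

env = LightningEnvironment()
env.set_world_size(8)
distributed = RankDelegatingPlugin(cluster_environment=env)
assert distributed.world_size == 8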