diff --git a/pytorch_lightning/accelerators/ddp_backend.py b/pytorch_lightning/accelerators/ddp_backend.py
index c3f32004f4ef4..b144648e4754d 100644
--- a/pytorch_lightning/accelerators/ddp_backend.py
+++ b/pytorch_lightning/accelerators/ddp_backend.py
@@ -13,6 +13,8 @@
 # limitations under the License
 import os
+import torch
+import torch.distributed as torch_distrib
 import subprocess
 import sys
 from os.path import abspath
@@ -20,10 +22,13 @@
 from typing import Optional

 import numpy as np
-import torch
+from pytorch_lightning import _logger as log

 from pytorch_lightning.utilities.distributed import find_free_network_port
-from pytorch_lightning.accelerators.ddp_base_backend import DDPBase
+from pytorch_lightning.accelerators.base_backend import Accelerator
+from pytorch_lightning.utilities.distributed import rank_zero_only
+from pytorch_lightning.utilities import AMPType
+

 try:
     from hydra.utils import to_absolute_path, get_original_cwd
@@ -34,13 +39,14 @@
     HYDRA_AVAILABLE = True


-class DDPBackend(DDPBase):
+class DDPBackend(Accelerator):

     def __init__(self, trainer, mode: str = 'ddp'):
         super().__init__(trainer)
         self.task_idx = None
         self._has_spawned_children = False
         self.mode = mode
+        self.interactive_ddp_procs = []

     def setup(self, model):
         if self.mode == 'ddp':
@@ -59,6 +65,10 @@ def __torchelastic_setup(self):
         self.task_idx = int(os.environ['LOCAL_RANK'])

     def __ddp_script_mode_setup(self):
+        # do nothing when already in a ddp subprocess
+        if os.environ.get('PL_IN_DDP_SUBPROCESS', '0') == '1':
+            return
+
         assert self.trainer.global_rank == 0
         self._check_can_spawn_children()
         self._has_spawned_children = True
@@ -105,7 +115,7 @@ def __ddp_script_mode_setup(self):

         os.environ['WORLD_SIZE'] = f'{num_gpus * self.trainer.num_nodes}'

-        self.trainer.interactive_ddp_procs = []
+        self.interactive_ddp_procs = []
         for local_rank in range(1, self.trainer.num_processes):
             env_copy = os.environ.copy()
             env_copy['LOCAL_RANK'] = f'{local_rank}'
@@ -118,7 +128,7 @@ def __ddp_script_mode_setup(self):
             if HydraConfig.initialized():
                 cwd = get_original_cwd()
             proc = subprocess.Popen(command, env=env_copy, cwd=cwd)
-            self.trainer.interactive_ddp_procs.append(proc)
+            self.interactive_ddp_procs.append(proc)

             # starting all processes at once can cause issues
             # with dataloaders delay between 1-10 seconds
@@ -127,14 +137,116 @@ def __ddp_script_mode_setup(self):

         self.task_idx = 0

+        # wait for all the procs to start
+        sleep(2)
+
     def train(self):
         model = self.trainer.model
         if self.mode == 'ddp':
-            results = self.ddp_train_tmp(process_idx=self.task_idx, mp_queue=None, model=model, is_master=True)
-            del os.environ['WORLD_SIZE']
+            results = self.ddp_train(process_idx=self.task_idx, mp_queue=None, model=model, is_master=True)
+            if 'WORLD_SIZE' in os.environ:
+                del os.environ['WORLD_SIZE']
             return results
         else:
-            self.ddp_train_tmp(process_idx=self.task_idx, mp_queue=None, model=model)
+            return self.ddp_train(process_idx=self.task_idx, mp_queue=None, model=model)
+
+    def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0):
+        """
+        Entry point for ddp
+        Args:
+            process_idx: index of the current process on this node
+            mp_queue: multiprocessing queue used to return results to the master process
+            model: the LightningModule to train or test
+            is_master: whether this is the interactive (rank 0) master process
+            proc_offset: offset added to ``process_idx`` to get the effective process id
+        Returns: the output of the train or test routine
+        """
+        # offset the process id if requested
+        process_idx = process_idx + proc_offset
+
+        # show progressbar only on progress_rank 0
+        if (self.trainer.node_rank != 0 or process_idx != 0) and self.trainer.progress_bar_callback is not None:
+            self.trainer.progress_bar_callback.disable()
+
+        # determine which process we are and world size
+        self.set_world_ranks(process_idx)
+
+        # set warning rank
+        rank_zero_only.rank = self.trainer.global_rank
+
+        # set up server using proc 0's ip address
+        # try to init for 20 times at max in case ports are taken
+        # where to store ip_table
+        model.trainer = self.trainer
+        model.init_ddp_connection(
+            self.trainer.global_rank,
+            self.trainer.world_size,
+            self.trainer.is_slurm_managing_tasks
+        )
+
+        # call setup after the ddp process has connected
+        self.trainer.call_setup_hook(model)
+
+        # on world_size=0 let everyone know training is starting
+        if self.trainer.is_global_zero and not torch.distributed.is_initialized():
+            log.info('-' * 100)
+            log.info(f'distributed_backend={self.trainer.distributed_backend}')
+            log.info(f'All DDP processes registered. Starting ddp with {self.trainer.world_size} processes')
+            log.info('-' * 100)
+
+        # call sync_bn before .cuda(), configure_apex and configure_ddp
+        if self.trainer.sync_batchnorm:
+            model = model.configure_sync_batchnorm(model)
+
+        # MODEL
+        # copy model to each gpu
+        self.model_to_device(model, process_idx, is_master)
+
+        # CHOOSE OPTIMIZER
+        # allow for lr schedulers as well
+        self.setup_optimizers(model)
+
+        # set model properties before going into wrapper
+        self.trainer.model_connector.copy_trainer_model_properties(model)
+
+        # AMP - run through amp wrapper before going to distributed DP
+        # DDP uses all GPUs on the machine
+        device_ids = self.get_device_ids()
+
+        # allow user to configure ddp
+        model = model.configure_ddp(model, device_ids)
+
+        # set up training routine
+        self.barrier('ddp_setup')
+        self.trainer.train_loop.setup_training(model)
+
+        # train or test
+        results = self.train_or_test()
+
+        # clean up memory
+        torch.cuda.empty_cache()
+
+        return results
+
+    def training_step(self, args):
+        if self.trainer.amp_backend == AMPType.NATIVE:
+            with torch.cuda.amp.autocast():
+                output = self.trainer.model(*args)
+        else:
+            output = self.trainer.model(*args)
+        return output
+
+    def validation_step(self, args):
+        output = self.training_step(args)
+        return output
+
+    def test_step(self, args):
+        output = self.training_step(args)
+        return output
+
+    def barrier(self, name: str = None):
+        if torch_distrib.is_initialized():
+            torch_distrib.barrier()

     def _check_can_spawn_children(self):
         if self._has_spawned_children:
@@ -149,15 +261,7 @@ def set_world_ranks(self, process_idx):
         self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes

     def model_to_device(self, model, process_idx, is_master):
-        gpu_idx = process_idx
-
-        # when using ddp, the master process (proc 0) continues running as the main one
-        # this means that the local rank will always be 0
-        # (even if cuda visible devices has other visible gpus)
-        # this means that the master process needs to pull the 0th visible index as the device number
-        if is_master:
-            available_gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
-            gpu_idx = int(available_gpus[self.trainer.local_rank])
+        gpu_idx = int(os.environ.get('PL_DDP_PID', process_idx))

         gpu_idx = int(os.environ.get('PL_DDP_PID', gpu_idx))

@@ -168,3 +272,6 @@
     def get_device_ids(self):
         device_ids = [self.trainer.root_gpu]
         return device_ids
+
+    def on_train_end(self):
+        pass
diff --git a/tests/backends/test_ddp.py b/tests/backends/test_ddp.py
index 91f22c4d7c59d..08f9a7e109808 100644
--- a/tests/backends/test_ddp.py
+++ b/tests/backends/test_ddp.py
@@ -37,21 +37,21 @@ def test_multi_gpu_model_ddp_test_only(tmpdir, cli_args):
     assert result['status'] == 'complete'

-# @pytest.mark.parametrize('cli_args', [
-#     pytest.param('--max_epochs 1 --gpus 2 --distributed_backend ddp'),
-# ])
-# @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
-# def test_multi_gpu_model_ddp_fit_test(tmpdir, cli_args):
-#     # call the script
-#     call_training_script(ddp_model, cli_args, 'fit_test', tmpdir, timeout=20)
-#
-#     # load the results of the script
-#     result_path = os.path.join(tmpdir, 'ddp.result')
-#     result = torch.load(result_path)
-#
-#     # verify the file wrote the expected outputs
-#     assert result['status'] == 'complete'
-#
-#     model_outs = result['result']
-#     for out in model_outs:
-#         assert out['test_acc'] > 0.90
+@pytest.mark.parametrize('cli_args', [
+    pytest.param('--max_epochs 1 --gpus 2 --distributed_backend ddp'),
+])
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+def test_multi_gpu_model_ddp_fit_test(tmpdir, cli_args):
+    # call the script
+    call_training_script(ddp_model, cli_args, 'fit_test', tmpdir, timeout=20)
+
+    # load the results of the script
+    result_path = os.path.join(tmpdir, 'ddp.result')
+    result = torch.load(result_path)
+
+    # verify the file wrote the expected outputs
+    assert result['status'] == 'complete'
+
+    model_outs = result['result']
+    for out in model_outs:
+        assert out['test_acc'] > 0.90