
Commit

ref: part a of #3733 (#3766)
* ref: part a of #3733

* ref: part a of #3733
williamFalcon authored Oct 1, 2020
1 parent 9a7d1a1 commit 440f837
Showing 9 changed files with 69 additions and 17 deletions.
2 changes: 1 addition & 1 deletion benchmarks/test_parity.py
@@ -11,7 +11,7 @@

@pytest.mark.parametrize('cls_model,max_diff', [
(ParityModuleRNN, 0.05),
(ParityModuleMNIST, 0.5)
(ParityModuleMNIST, 0.55)
])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
def test_pytorch_parity(tmpdir, cls_model, max_diff):
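
For orientation, here is a minimal sketch of how a parametrized tolerance like max_diff is typically consumed by a parity benchmark; the timing numbers and the helper-free body are illustrative, not the benchmark's real internals.

import pytest


@pytest.mark.parametrize('cls_model,max_diff', [
    ('ParityModuleRNN', 0.05),
    ('ParityModuleMNIST', 0.55),
])
def test_parity_tolerance_sketch(cls_model, max_diff):
    # hypothetical timings: a plain PyTorch loop vs. the Lightning Trainer on the same model
    vanilla_time, lightning_time = 10.0, 10.04
    # the assertion pattern: the gap must stay within the per-model tolerance,
    # which this commit loosens to 0.55 for the MNIST module
    assert abs(lightning_time - vanilla_time) <= max_diff
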
3 changes: 2 additions & 1 deletion pl_examples/basic_examples/autoencoder.py
@@ -97,7 +97,8 @@ def cli_main():
# ------------
# testing
# ------------
trainer.test(test_dataloaders=test_loader)
result = trainer.test(test_dataloaders=test_loader)
print(result)


if __name__ == '__main__':
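
The example now keeps what trainer.test returns instead of discarding it. Assuming the usual return shape of that era, a list with one dict of logged metrics per test dataloader, downstream code can index into the printed result, e.g.:

# hypothetical consumer of the printed result (not part of the example script);
# the list-of-dicts shape is an assumption about trainer.test's return value
result = [{'test_loss': 0.0421}]        # stand-in for: result = trainer.test(...)
test_loss = result[0]['test_loss']      # one metrics dict per test dataloader
print(f'test_loss={test_loss:.4f}')
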
9 changes: 6 additions & 3 deletions pytorch_lightning/accelerators/accelerator_connector.py
@@ -81,12 +81,11 @@ def on_trainer_init(

# override with environment flag
gpus = os.environ.get('PL_TRAINER_GPUS', gpus)
self.trainer.gpus = gpus

# for gpus allow int, string and gpu list
if auto_select_gpus and isinstance(gpus, int):
self.trainer.gpus = self.trainer.tuner.pick_multiple_gpus(gpus)
else:
self.trainer.gpus = gpus

self.trainer.data_parallel_device_ids = device_parser.parse_gpu_ids(self.trainer.gpus)
self.trainer.root_gpu = device_parser.determine_root_gpu_device(self.trainer.data_parallel_device_ids)
@@ -126,6 +125,9 @@ def on_trainer_init(
self.trainer.replace_sampler_ddp = replace_sampler_ddp

def select_accelerator(self):
if self.trainer.accelerator_backend is not None:
return self.trainer.accelerator_backend

# SLURM ddp
use_slurm_ddp = self.trainer.use_ddp and self.trainer.is_slurm_managing_tasks

@@ -294,7 +296,8 @@ def set_nvidia_flags(self, is_slurm_managing_tasks, data_parallel_device_ids):
os.environ["CUDA_VISIBLE_DEVICES"] = gpu_str

# don't make this debug... this is good UX
rank_zero_info(f'CUDA_VISIBLE_DEVICES: [{os.environ["CUDA_VISIBLE_DEVICES"]}]')
devices = os.environ["CUDA_VISIBLE_DEVICES"]
log.info(f'LOCAL_RANK: {self.trainer.local_rank} - CUDA_VISIBLE_DEVICES: [{devices}]')

def determine_local_rank(self):
if self.trainer.is_slurm_managing_tasks:
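
Taken together, the first hunk lets the PL_TRAINER_GPUS environment flag override the argument and only auto-picks devices when an integer count was requested. A standalone sketch of that resolution, with a stub picker standing in for trainer.tuner.pick_multiple_gpus:

import os


def resolve_gpus(gpus, auto_select_gpus, pick_multiple_gpus=lambda n: list(range(n))):
    """Simplified mirror of the flow above (assumes PL_TRAINER_GPUS is unset)."""
    # the env flag set by the DDP launcher wins over the argument
    gpus = os.environ.get('PL_TRAINER_GPUS', gpus)
    # auto-select only when an int count was asked for; otherwise keep the value as given
    if auto_select_gpus and isinstance(gpus, int):
        return pick_multiple_gpus(gpus)
    return gpus


print(resolve_gpus(2, auto_select_gpus=True))      # -> [0, 1] with the stub picker
print(resolve_gpus('0,1', auto_select_gpus=True))  # -> '0,1' (strings pass through)
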
3 changes: 3 additions & 0 deletions pytorch_lightning/accelerators/base_backend.py
@@ -170,6 +170,9 @@ def _clip_gradients(self, optimizer):
def on_train_epoch_end(self):
pass

def on_train_end(self):
pass

def early_stopping_should_stop(self, pl_module):
return self.trainer.should_stop

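
The new on_train_end hook is a deliberate no-op: the trainer can call it unconditionally, and concrete accelerators override it only when they have cleanup to do. A self-contained illustration of that pattern (the subclass is hypothetical):

class Accelerator:
    # the base class ships no-op hooks so callers never need hasattr checks
    def on_train_epoch_end(self):
        pass

    def on_train_end(self):
        pass


class MyGPUAccelerator(Accelerator):
    def on_train_end(self):
        # e.g. tear down process groups, flush queues, release devices
        print('cleaning up after training')


for acc in (Accelerator(), MyGPUAccelerator()):
    acc.on_train_end()   # safe on both: base is a no-op, the subclass does real work
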
12 changes: 9 additions & 3 deletions pytorch_lightning/accelerators/ddp_backend.py
@@ -91,21 +91,25 @@ def __ddp_script_mode_setup(self):
# when the trainer script was called the device has already been scoped by the time
# code reaches this point. so, to call the scripts, we need to leave cuda visible devices alone
# but forward the GPUs selected via environment variables
os.environ['PL_TRAINER_GPUS'] = ','.join([str(i) for i in self.trainer.data_parallel_device_ids])
os.environ['PL_IN_DDP_SUBPROCESS'] = '1'

if self.trainer.logger is not None:
os.environ['PL_EXP_VERSION'] = str(self.trainer.logger.version)

gpu_ids = os.environ.get('CUDA_VISIBLE_DEVICES', '')
if len(gpu_ids) == 1:
gpu_ids = f'{gpu_ids},'

num_gpus = max(1, len(gpu_ids.split(',')))

# set the flag for ddp scripts
os.environ['PL_TRAINER_GPUS'] = gpu_ids

os.environ['WORLD_SIZE'] = f'{num_gpus * self.trainer.num_nodes}'

self.trainer.interactive_ddp_procs = []
for local_rank in range(1, self.trainer.num_processes):
env_copy = os.environ.copy()
env_copy['LOCAL_RANK'] = f'{local_rank}'
env_copy['PL_DDP_PID'] = str(self.trainer.data_parallel_device_ids[local_rank])

# start process
# if hydra is available and initialized, make sure to set the cwd correctly
@@ -155,6 +159,8 @@ def model_to_device(self, model, process_idx, is_master):
available_gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
gpu_idx = int(available_gpus[self.trainer.local_rank])

gpu_idx = int(os.environ.get('PL_DDP_PID', gpu_idx))

self.trainer.root_gpu = gpu_idx
torch.cuda.set_device(self.trainer.root_gpu)
model.cuda(self.trainer.root_gpu)
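
The launcher changes are mostly environment bookkeeping: PL_TRAINER_GPUS is now derived from CUDA_VISIBLE_DEVICES, and each spawned child receives a LOCAL_RANK plus a PL_DDP_PID that model_to_device reads back. A simplified, hedged sketch of that forwarding pattern (WORLD_SIZE handling omitted; not the backend itself):

import os
import subprocess
import sys


def launch_ddp_workers(script, device_ids, logger_version=None):
    """Rank 0 re-launches the training script once per extra process and hands
    each child its state through environment variables."""
    base_env = os.environ.copy()
    base_env['PL_IN_DDP_SUBPROCESS'] = '1'
    base_env['PL_TRAINER_GPUS'] = ','.join(str(d) for d in device_ids)
    if logger_version is not None:
        base_env['PL_EXP_VERSION'] = str(logger_version)  # one logger version across ranks
    procs = []
    for local_rank in range(1, len(device_ids)):
        env = base_env.copy()
        env['LOCAL_RANK'] = str(local_rank)              # each child reads its rank from the env
        env['PL_DDP_PID'] = str(device_ids[local_rank])  # device the child later binds to
        procs.append(subprocess.Popen([sys.executable, script], env=env))
    return procs
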
5 changes: 3 additions & 2 deletions pytorch_lightning/accelerators/ddp_base_backend.py
@@ -58,7 +58,8 @@ def test_step(self, args):
return output

def barrier(self, name: str = None):
torch_distrib.barrier()
if torch_distrib.is_initialized():
torch_distrib.barrier()

def early_stopping_should_stop(self, pl_module):
stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device)
@@ -132,7 +133,7 @@ def ddp_train_tmp(self, process_idx, mp_queue, model, is_master=False, proc_offs
self.trainer.call_setup_hook(model)

# on world_size=0 let everyone know training is starting
if self.trainer.is_global_zero:
if self.trainer.is_global_zero and not torch.distributed.is_initialized():
log.info('-' * 100)
log.info(f'distributed_backend={self.trainer.distributed_backend}')
log.info(f'All DDP processes registered. Starting ddp with {self.trainer.world_size} processes')
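
The guard added to barrier() generalizes to any collective call: it is only legal after init_process_group() has run, so checking is_initialized() turns it into a no-op for single-process runs. A minimal standalone sketch:

import torch.distributed as torch_distrib


def safe_barrier():
    # collective ops require an initialized process group; the guard makes this
    # harmless outside DDP and a real synchronization point inside it
    if torch_distrib.is_available() and torch_distrib.is_initialized():
        torch_distrib.barrier()


safe_barrier()   # no-op in a plain single-process script
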
14 changes: 8 additions & 6 deletions pytorch_lightning/core/lightning.py
@@ -1029,12 +1029,14 @@ def init_ddp_connection(
)

torch_backend = "nccl" if self.trainer.on_gpu else "gloo"
log.info(
f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank+1}/{world_size}"
)
torch_distrib.init_process_group(
torch_backend, rank=global_rank, world_size=world_size
)

if not torch.distributed.is_initialized():
log.info(
f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}"
)
torch_distrib.init_process_group(
torch_backend, rank=global_rank, world_size=world_size
)

def configure_sync_batchnorm(self, model: "LightningModule") -> "LightningModule":
"""
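
Wrapping init_process_group() in an is_initialized() check makes DDP connection setup idempotent: a subprocess that already joined the group will not raise on a second call. A simplified mirror of that hunk (not the LightningModule method itself):

import torch
import torch.distributed as torch_distrib


def init_ddp_connection_sketch(global_rank: int, world_size: int, on_gpu: bool = False):
    torch_backend = 'nccl' if on_gpu else 'gloo'
    # calling init_process_group() twice raises, so re-entrant setups simply skip it
    if not torch.distributed.is_initialized():
        torch_distrib.init_process_group(torch_backend, rank=global_rank, world_size=world_size)
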
5 changes: 4 additions & 1 deletion pytorch_lightning/trainer/connectors/logger_connector.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
from pytorch_lightning.core import memory
from pytorch_lightning.loggers import TensorBoardLogger, LoggerCollection
@@ -40,10 +41,12 @@ def on_trainer_init(self, logger, log_save_interval, row_log_interval):

def configure_logger(self, logger):
if logger is True:
version = os.environ.get('PL_EXP_VERSION', self.trainer.slurm_job_id)

# default logger
self.trainer.logger = TensorBoardLogger(
save_dir=self.trainer.default_root_dir,
version=self.trainer.slurm_job_id,
version=version,
name='lightning_logs'
)
elif logger is False:
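
With this change, a DDP child launched with PL_EXP_VERSION reuses the parent's logger version instead of minting a new one, so every rank writes into the same TensorBoard version directory. A small standalone sketch of the resolution order:

import os


def default_logger_version(slurm_job_id=None):
    """PL_EXP_VERSION wins; otherwise fall back to the SLURM job id (or None)."""
    return os.environ.get('PL_EXP_VERSION', slurm_job_id)


os.environ['PL_EXP_VERSION'] = '7'
print(default_logger_version())           # -> '7'   (all ranks log under version_7)
del os.environ['PL_EXP_VERSION']
print(default_logger_version('1234'))     # -> '1234'
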
33 changes: 33 additions & 0 deletions tests/utilities/dist.py
@@ -0,0 +1,33 @@
import os
import subprocess
from subprocess import TimeoutExpired
import sys
from pathlib import Path

import pytorch_lightning


def call_training_script(module_file, cli_args, method, tmpdir, timeout=60):
file = Path(module_file.__file__).absolute()
cli_args = cli_args.split(' ') if cli_args else []
cli_args += ['--tmpdir', str(tmpdir)]
cli_args += ['--trainer_method', method]
command = [sys.executable, str(file)] + cli_args

# need to set the PYTHONPATH in case pytorch_lightning was not installed into the environment
env = os.environ.copy()
env['PYTHONPATH'] = f'{pytorch_lightning.__file__}:' + env.get('PYTHONPATH', '')

    # for running in ddp mode, we need to launch it in its own process or pytest will get stuck
p = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env)

try:
std, err = p.communicate(timeout=timeout)
err = str(err.decode("utf-8"))
if 'Exception' in err:
raise Exception(err)
except TimeoutExpired:
p.kill()
std, err = p.communicate()

return std, err
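
A hedged example of how a test might drive this helper; the imported ddp_script module and the CLI flags are placeholders, since the helper only requires a module with a __file__ attribute whose script parses --tmpdir and --trainer_method.

# hypothetical pytest usage; `ddp_script` is an illustrative placeholder module
from tests.utilities.dist import call_training_script
from tests.models.data import ddp_script


def test_multi_gpu_fit_ddp(tmpdir):
    std, err = call_training_script(ddp_script, '--gpus 2', 'fit', tmpdir, timeout=120)
    assert 'Exception' not in str(err)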
