-
Notifications
You must be signed in to change notification settings - Fork 3.4k
/
Copy pathdistrib_data_parallel.py
542 lines (431 loc) · 19.6 KB
/
distrib_data_parallel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
"""
Lightning supports model training on a cluster managed by SLURM in the following cases:

1. Training on a single CPU or single GPU.
2. Training on multiple GPUs on the same node using DataParallel or DistributedDataParallel.
3. Training across multiple GPUs on multiple different nodes via DistributedDataParallel.

.. note:: A node means a single machine, which typically has multiple GPUs.
Running grid search on a cluster
--------------------------------
To use lightning to run a hyperparameter search (grid-search or random-search) on a cluster do 4 things:
(1). Define the parameters for the grid search
.. code-block:: python
from test_tube import HyperOptArgumentParser
# subclass of argparse
parser = HyperOptArgumentParser(strategy='random_search')
parser.add_argument('--learning_rate', default=0.002, type=float, help='the learning rate')
# let's enable optimizing over the number of layers in the network
parser.opt_list('--nb_layers', default=2, type=int, tunable=True, options=[2, 4, 8])
hparams = parser.parse_args()
.. note:: You must set `tunable=True` for an argument to be included in the search space.
Otherwise test-tube will use the default value. This flag is useful when you don't want
to search over an argument and want to use the default instead.
(2). Define the cluster options in the
`SlurmCluster object <https://williamfalcon.github.io/test-tube/hpc/SlurmCluster>`_ (here: 5 nodes with 8 GPUs each)
.. code-block:: python
from test_tube.hpc import SlurmCluster
# hyperparameters is a test-tube hyper params object
# see https://williamfalcon.github.io/test-tube/hyperparameter_optimization/HyperOptArgumentParser/
hyperparams = parser.parse_args()
# init cluster
cluster = SlurmCluster(
hyperparam_optimizer=hyperparams,
log_path='/path/to/log/results/to',
python_cmd='python3'
)
# let the cluster know where to email for a change in job status (ie: complete, fail, etc...)
cluster.notify_job_status(email='some@email.com', on_done=True, on_fail=True)
# set the job options. In this instance, each trial (one hyperparameter configuration)
# is trained on 5 nodes with 8 GPUs each (i.e. 40 GPUs per trial)
cluster.per_experiment_nb_gpus = 8
cluster.per_experiment_nb_nodes = 5
# we'll request 10GB of memory per node
cluster.memory_mb_per_node = 10000
# set a walltime of 10 minutes
cluster.job_time = '10:00'
(3). Make a main function with your model and trainer. Each job will call this function with a particular
hparams configuration.::
from pytorch_lightning import Trainer
def train_fx(trial_hparams, cluster_manager, _):
# trial_hparams has a specific set of hyperparams
my_model = MyLightningModel()
# give the trainer the cluster object
trainer = Trainer()
trainer.fit(my_model)
(4). Start the grid/random search::
# run the models on the cluster
cluster.optimize_parallel_cluster_gpu(
train_fx,
nb_trials=20,
job_name='my_grid_search_exp_name',
job_display_name='my_exp')
.. note:: `nb_trials` specifies how many of the possible permutations to use. If using `grid_search` it will use
the depth first ordering. If using `random_search` it will use the first k shuffled options. FYI, random search
has been shown to be just as good as any Bayesian optimization method when using a reasonable number of samples (60),
see this `paper <http://www.jmlr.org/papers/volume13/bergstra12a/bergstra12a.pdf>`_ for more information.
Walltime auto-resubmit
----------------------
Lightning automatically resubmits jobs when they reach the walltime. Make sure to set the SIGUSR1 signal in
your SLURM script.::
# send SIGUSR1 90 seconds before the walltime is reached
#SBATCH --signal=SIGUSR1@90
When lightning receives the SIGUSR1 signal it will:
1. save a checkpoint with 'hpc_ckpt' in the name.
2. resubmit the job using the SLURM_JOB_ID
When the script starts again, Lightning will:
1. search for a 'hpc_ckpt' checkpoint.
2. restore the model, optimizers, schedulers, epoch, etc...
"""
import os
import re
from abc import ABC, abstractmethod
from typing import Union
import subprocess
import sys
from time import sleep
import numpy as np
from os.path import abspath
import torch
from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn
# Optional backends. Each flag records whether the corresponding import succeeded;
# `amp` (NVIDIA apex) and `hvd` (Horovod) are only bound when it did.
try:
    from apex import amp
    APEX_AVAILABLE = True
except ImportError:
    APEX_AVAILABLE = False

try:
    import horovod.torch as hvd
    HOROVOD_AVAILABLE = True
except ImportError:
    HOROVOD_AVAILABLE = False
class TrainerDDPMixin(ABC):
    """Distributed-training helpers mixed into the Trainer.

    Resolves which parallel mode to use (DP, DDP, DDP2, DDP on CPU, Horovod), detects whether
    SLURM is managing the worker processes, launches local DDP child processes and runs the
    per-process DDP training entry point.

    The attributes annotated below are only declared here; the class inheriting this mixin is
    responsible for initialising them.
    """

    # this is just a summary on variables used in this abstract class,
    # the proper values/initialisation should be done in child class
    on_gpu: bool
    num_gpu_nodes: int
    logger: Union[LightningLoggerBase, bool]
    checkpoint_callback: Union[ModelCheckpoint, bool]
    data_parallel_device_ids: ...
    distributed_backend: str
    amp_level: str
    use_tpu: bool
    default_root_dir: str
    use_native_amp: bool
    progress_bar_callback: ...
    num_processes: int

    @property
    @abstractmethod
    def num_gpus(self) -> int:
        """Warning: this is just empty shell for code implemented in other class."""

    @property
    @abstractmethod
    def use_amp(self) -> bool:
        """Warning: this is just empty shell for code implemented in other class."""

    @abstractmethod
    def copy_trainer_model_properties(self, *args):
        """Warning: this is just empty shell for code implemented in other class."""

    @abstractmethod
    def run_pretrain_routine(self, *args):
        """Warning: this is just empty shell for code implemented in other class."""

    @abstractmethod
    def init_optimizers(self, *args):
        """Warning: this is just empty shell for code implemented in other class."""

    def init_tpu(self):
        """Switch to TPU mode: clear the distributed backend and set ``use_tpu``."""
        # turn off all the GPU stuff
        self.distributed_backend = None

        # enable tpu
        self.use_tpu = True

    def set_distributed_mode(self, distributed_backend):
        """Translate the requested backend into the ``use_*`` / ``single_gpu`` flags.

        Args:
            distributed_backend: ``None``, ``'dp'``, ``'ddp'``, ``'ddp2'``, ``'ddp_cpu'`` or
                ``'horovod'``. When ``None`` the mode is inferred: Horovod if launched through
                ``horovodrun``; DDP on CPU if there are no GPUs but several nodes/processes; single
                GPU for one GPU; and ``'ddp'`` (with a warning) for several GPUs.

        Raises:
            MisconfigurationException: if ``num_nodes > 1`` and the resolved mode is neither DDP
                nor DDP2, or (through ``check_horovod``) if Horovod was requested but is unusable.
        """
        self.use_dp = False
        self.use_ddp = False
        self.use_ddp2 = False
        self.use_horovod = False
        self.single_gpu = False

        if distributed_backend is None:
            if self.has_horovodrun():
                self._set_horovod_backend()
            elif self.num_gpus == 0:
                if self.num_nodes > 1 or self.num_processes > 1:
                    self.use_ddp = True  # ddp_cpu
            elif self.num_gpus == 1:
                self.single_gpu = True
            elif self.num_gpus > 1:
                rank_zero_warn('You requested multiple GPUs but did not specify a backend, e.g.'
                               ' Trainer(distributed_backend=dp) (or ddp, ddp2).'
                               ' Setting distributed_backend=ddp for you.')
                self.distributed_backend = 'ddp'
                distributed_backend = 'ddp'

        # deliberately `if` (not `elif`): a backend inferred as 'ddp' above is configured below
        if distributed_backend == "dp":
            # do nothing if num_gpus == 0
            if self.num_gpus == 1:
                self.single_gpu = True
                self.use_dp = True
            elif self.num_gpus > 1:
                self.use_dp = True

        elif distributed_backend == "ddp":
            if self.num_gpus == 0:
                if self.num_nodes > 1 or self.num_processes > 1:
                    self.use_ddp = True  # ddp_cpu
            elif self.num_gpus == 1:
                self.single_gpu = True
                self.use_ddp = True
            elif self.num_gpus > 1:
                self.use_ddp = True
                self.num_processes = self.num_gpus

        elif distributed_backend == "ddp2":
            # do nothing if num_gpus == 0
            if self.num_gpus >= 1:
                self.use_ddp2 = True

        elif distributed_backend == "ddp_cpu":
            if self.num_gpus > 0:
                rank_zero_warn('You requested one or more GPUs, but set the backend to `ddp_cpu`.'
                               ' Training will not use GPUs.')
            self.use_ddp = True
            self.data_parallel_device_ids = None
            self.on_gpu = False

        elif distributed_backend == 'horovod':
            self._set_horovod_backend()

        # throw error to force user ddp or ddp2 choice
        if self.num_nodes > 1 and not (self.use_ddp2 or self.use_ddp):
            raise MisconfigurationException(
                'DataParallel does not support num_nodes > 1. '
                'To train on multiple nodes set distributed_backend=ddp or distributed_backend=ddp2'
            )

        log.info(f'GPU available: {torch.cuda.is_available()}, used: {self.on_gpu}')

    def configure_slurm_ddp(self, num_gpu_nodes):
        """Decide whether SLURM is managing the DDP processes.

        Sets ``self.is_slurm_managing_tasks``. Under DDP it is True when ``SLURM_NTASKS`` equals
        ``num_gpus * num_gpu_nodes`` and the job is not an interactive ``bash`` session. The
        ``FAKE_SLURM_MANAGING_TASKS`` environment variable (tests only) forces it to True when
        it is a non-zero integer.

        Args:
            num_gpu_nodes: number of nodes used to compute the expected number of SLURM tasks.
        """
        self.is_slurm_managing_tasks = False

        # whenever we have the correct number of tasks, we let slurm manage processes
        # otherwise we launch the required number of processes
        if self.use_ddp:
            self.num_requested_gpus = self.num_gpus * num_gpu_nodes
            self.num_slurm_tasks = 0
            try:
                self.num_slurm_tasks = int(os.environ['SLURM_NTASKS'])
                self.is_slurm_managing_tasks = self.num_slurm_tasks == self.num_requested_gpus

                # in interactive mode we don't manage tasks
                job_name = os.environ['SLURM_JOB_NAME']
                if job_name == 'bash':
                    self.is_slurm_managing_tasks = False

            except (KeyError, ValueError):
                # SLURM variables missing or malformed: likely not on slurm
                self.is_slurm_managing_tasks = False

        # used for tests only, set this flag to simulate slurm managing a task
        try:
            should_fake = int(os.environ['FAKE_SLURM_MANAGING_TASKS'])
            if should_fake:
                self.is_slurm_managing_tasks = True
        except (KeyError, ValueError):
            pass

        # notify user the that slurm is managing tasks
        if self.is_slurm_managing_tasks:
            log.info('Multi-processing is handled by Slurm.')

    def determine_ddp_node_rank(self):
        """Return the rank of this node.

        Uses ``SLURM_NODEID`` when SLURM manages the tasks. Otherwise uses ``GROUP_RANK``
        (set by torchelastic), then ``NODE_RANK``, and falls back to 0 when neither is set.

        Returns:
            int: the node rank.
        """
        if self.is_slurm_managing_tasks:
            return int(os.environ['SLURM_NODEID'])

        # torchelastic uses the envvar GROUP_RANK, whereas other systems(?) use NODE_RANK.
        # Listed in priority order: GROUP_RANK wins when both are set, matching the
        # precedence used by spawn_ddp_children.
        env_vars = ['GROUP_RANK', 'NODE_RANK']
        node_ids = [(k, os.environ.get(k, None)) for k in env_vars]
        node_ids = [(k, v) for k, v in node_ids if v is not None]
        if not node_ids:
            log.warning("No environment variable for node rank defined. Set as 0.")
            return 0
        if len(node_ids) > 1:
            log.warning(f"Multiple environment variables ({node_ids}) defined for node rank. "
                        f"Using the first one.")
        k, rank = node_ids[0]
        log.info(f"Using environment variable {k} for node rank ({rank}).")
        return int(rank)

    def set_nvidia_flags(self, is_slurm_managing_tasks, data_parallel_device_ids):
        """Set ``CUDA_DEVICE_ORDER`` and, if needed, ``CUDA_VISIBLE_DEVICES``.

        Does nothing when ``data_parallel_device_ids`` is None. ``CUDA_VISIBLE_DEVICES`` is only
        written when SLURM is not managing tasks and the variable is not already set. An int
        ``n`` means devices ``0..n-1``; otherwise it is an iterable of device ids.

        Args:
            is_slurm_managing_tasks: whether SLURM already assigned the visible devices.
            data_parallel_device_ids: None, an int device count, or an iterable of device ids.
        """
        if data_parallel_device_ids is None:
            return

        # set the correct cuda visible devices (using pci order)
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

        # when slurm is managing the task it sets the visible devices
        if not is_slurm_managing_tasks and 'CUDA_VISIBLE_DEVICES' not in os.environ:
            if isinstance(data_parallel_device_ids, int):
                device_ids = range(data_parallel_device_ids)
            else:
                device_ids = data_parallel_device_ids
            os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, device_ids))

        # don't make this debug... this is good UX
        log.info(f'CUDA_VISIBLE_DEVICES: [{os.environ["CUDA_VISIBLE_DEVICES"]}]')

    def __set_random_port(self):
        """Ensure ``MASTER_PORT`` is set.

        When running DDP NOT managed by SLURM, several jobs on one machine could otherwise
        collide on the same port. An existing ``MASTER_PORT`` is kept; if it is missing, a
        random port in [10000, 19000] is chosen.
        """
        default_port = os.environ.get('MASTER_PORT')
        if default_port is None:
            import random
            default_port = random.randint(10000, 19000)
        os.environ['MASTER_PORT'] = str(default_port)

    def spawn_ddp_children(self, model):
        """Launch ``num_processes - 1`` DDP child processes, then train here as local rank 0.

        Each child re-runs the current script (``sys.argv``, first entry made absolute) with the
        interpreter running this process. It gets a copy of the environment with ``LOCAL_RANK``
        set to its local rank. The ``--gpus`` argument of the child command line is replaced by
        the number of devices listed in ``CUDA_VISIBLE_DEVICES``.

        Args:
            model: the model to train in this (master) process.
        """
        self.__set_random_port()
        port = os.environ['MASTER_PORT']
        master_address = os.environ.get('MASTER_ADDR', '127.0.0.1')
        os.environ['MASTER_PORT'] = f'{port}'
        os.environ['MASTER_ADDR'] = f'{master_address}'

        # allow the user to pass the node rank; GROUP_RANK takes precedence over NODE_RANK
        node_rank = os.environ.get('GROUP_RANK', os.environ.get('NODE_RANK', '0'))
        os.environ['NODE_RANK'] = node_rank
        os.environ['LOCAL_RANK'] = '0'

        # copy sys.argv so resolving the script path does not mutate the parent's argv
        command = list(sys.argv)
        command[0] = abspath(command[0])
        # run children with the same interpreter (respects virtualenvs) instead of whatever
        # `python` resolves to on PATH
        command = [sys.executable or 'python'] + command

        # since this script sets the visible devices we replace the gpus flag with a number
        num_gpus = len(os.environ['CUDA_VISIBLE_DEVICES'].split(','))

        # if script called without a flag, pass in a flag anyhow (its value is overwritten below)
        if '--gpus' not in command:
            arg_gpus = len(self.gpus) if isinstance(self.gpus, list) else self.gpus
            command += ['--gpus', arg_gpus]

        gpu_flag_idx = command.index('--gpus')
        command[gpu_flag_idx + 1] = f'{num_gpus}'

        os.environ['WORLD_SIZE'] = f'{num_gpus * self.num_nodes}'

        self.interactive_ddp_procs = []
        for local_rank in range(1, self.num_processes):
            env_copy = os.environ.copy()
            env_copy['LOCAL_RANK'] = f'{local_rank}'

            # start process
            proc = subprocess.Popen(command, env=env_copy)
            self.interactive_ddp_procs.append(proc)

            # starting all processes at once can cause issues
            # with dataloaders; delay between 1-5 seconds
            delay = np.random.uniform(1, 5, 1)[0]
            sleep(delay)

        local_rank = 0
        self.ddp_train(local_rank, model, is_master=True)

    def ddp_train(self, process_idx, model, is_master=False):
        """Entry point of one DDP / DDP2 training process.

        Computes this process's global rank and the world size, and initialises the process
        group through ``model.init_ddp_connection``. It then creates optimizers, moves the model
        to its GPU (when ``on_gpu``), optionally applies apex AMP, wraps the model through
        ``model.configure_ddp`` and starts the training routine.

        Args:
            process_idx: index of this process on its node (0 for the master process).
            model: the model to train.
            is_master: when True the GPU index is taken from
                ``CUDA_VISIBLE_DEVICES[LOCAL_RANK]`` instead of ``process_idx``.
        """
        # show progressbar only on progress_rank 0
        if (self.node_rank != 0 or process_idx != 0) and self.progress_bar_callback is not None:
            self.progress_bar_callback.disable()

        # determine which process we are and world size
        if self.use_ddp:
            self.proc_rank = self.node_rank * self.num_processes + process_idx
            self.world_size = self.num_nodes * self.num_processes

        elif self.use_ddp2:
            self.proc_rank = self.node_rank
            self.world_size = self.num_nodes

        # set warning rank
        rank_zero_only.rank = self.proc_rank

        # set up the process group (the model's hook decides how to reach proc 0)
        model.trainer = self
        model.init_ddp_connection(self.proc_rank, self.world_size, self.is_slurm_managing_tasks)

        # CHOOSE OPTIMIZER
        # allow for lr schedulers as well
        self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model)

        # MODEL
        # copy model to each gpu
        if self.on_gpu:
            gpu_idx = process_idx
            if is_master:
                # source of truth is cuda for gpu idx
                gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
                local_rank = int(os.environ['LOCAL_RANK'])
                gpu_idx = int(gpus[local_rank])

            self.root_gpu = gpu_idx
            torch.cuda.set_device(self.root_gpu)
            model.cuda(self.root_gpu)

        # set model properties before going into wrapper
        self.copy_trainer_model_properties(model)

        # AMP
        # run through amp wrapper before going to distributed DP
        # TODO: remove in v0.8.0
        if self.use_amp and not self.use_native_amp:
            model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
            self.optimizers = optimizers
            self.reinit_scheduler_properties(self.optimizers, self.lr_schedulers)

        # 'ddp' on GPU: one device per process. DDP2 uses all GPUs on the machine.
        # CPU modules (ddp_cpu, or 'ddp' without GPUs) must be wrapped with device_ids=None.
        if self.distributed_backend == 'ddp' and self.on_gpu:
            device_ids = [self.root_gpu]
        elif self.use_ddp2:
            device_ids = self.data_parallel_device_ids
        else:
            device_ids = None

        # allow user to configure ddp
        model = model.configure_ddp(model, device_ids)

        # continue training routine
        self.run_pretrain_routine(model)

    def save_spawn_weights(self, model):
        """Dump a temporary checkpoint after DDP ends so the weights can leave the process.

        Only the process with ``proc_rank == 0`` writes
        ``<default_root_dir>/__temp_weight_ddp_end.ckpt`` via ``self.save_checkpoint``.

        Args:
            model: unused here; the checkpoint is produced by ``self.save_checkpoint``.
        """
        if self.proc_rank == 0:
            path = os.path.join(self.default_root_dir, '__temp_weight_ddp_end.ckpt')
            self.save_checkpoint(path)

    def load_spawn_weights(self, original_model):
        """Recover the weights trained inside the DDP process.

        On ``proc_rank == 0`` this loads the temporary checkpoint written by
        ``save_spawn_weights``, copies its ``state_dict`` into ``original_model`` and deletes the
        file. On other ranks ``original_model`` is returned untouched.

        Args:
            original_model: the model instance to receive the trained weights.

        Returns:
            The model loaded from the checkpoint on rank 0, otherwise ``original_model``.
        """
        loaded_model = original_model

        if self.proc_rank == 0:
            # load weights saved in ddp
            path = os.path.join(self.default_root_dir, '__temp_weight_ddp_end.ckpt')
            loaded_model = original_model.__class__.load_from_checkpoint(path)

            # copy loaded weights to old model
            original_model.load_state_dict(loaded_model.state_dict())

            # remove ddp weights
            os.remove(path)

        return loaded_model

    def resolve_root_node_address(self, root_node):
        """Reduce a SLURM-style compressed node list to its first host name.

        Examples: ``'node[1-3,5]'`` -> ``'node1'``, ``'host[07]'`` -> ``'host07'``. Names without
        ``'['`` are returned unchanged.
        """
        if '[' in root_node:
            name, numbers = root_node.split('[', maxsplit=1)
            number = numbers.split(',', maxsplit=1)[0]
            if '-' in number:
                number = number.split('-')[0]

            number = re.sub('[^0-9]', '', number)
            root_node = name + number

        return root_node

    def _set_horovod_backend(self):
        """Enable Horovod: validate the setup, call ``hvd.init()`` and pick the local GPU."""
        self.check_horovod()
        self.use_horovod = True

        # Initialize Horovod to get rank / size info
        hvd.init()
        if self.on_gpu:
            # Horovod assigns one local GPU per process
            self.root_gpu = hvd.local_rank()

    def check_horovod(self):
        """Raises a `MisconfigurationException` if the Trainer is not configured correctly for Horovod."""
        if not HOROVOD_AVAILABLE:
            raise MisconfigurationException(
                'Requested `distributed_backend="horovod"`, but Horovod is not installed. '
                'Install with \n $HOROVOD_WITH_PYTORCH=1 pip install horovod[pytorch]'
            )

        if self.num_gpus > 1 or self.num_nodes > 1:
            raise MisconfigurationException(
                'Horovod does not support setting num_nodes / num_gpus explicitly. Use '
                'horovodrun / mpirun to configure the number of processes.'
            )

    @staticmethod
    def has_horovodrun():
        """Returns True if running with `horovodrun` using Gloo or OpenMPI."""
        return 'OMPI_COMM_WORLD_RANK' in os.environ or 'HOROVOD_RANK' in os.environ