import multiprocessing
import os
import socket
from typing import Union, Optional
import nnunetv2
import torch.cuda
import torch.distributed as dist
import torch.multiprocessing as mp
from batchgenerators.utilities.file_and_folder_operations import join, isfile, load_json
from nnunetv2.paths import nnUNet_preprocessed
from nnunetv2.run.load_pretrained_weights import load_pretrained_weights
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
from torch.backends import cudnn


def find_free_network_port() -> int:
"""Finds a free port on localhost.
It is useful in single-node training when we don't want to connect to a real main node but have to set the
`MASTER_PORT` environment variable.
"""
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind(("", 0))
port = s.getsockname()[1]
s.close()
return port
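

# A minimal usage sketch (mirroring what run_training does below for single-node DDP):
# only derive MASTER_PORT from this helper if the user has not set one themselves.
#
#   if 'MASTER_PORT' not in os.environ:
#       os.environ['MASTER_PORT'] = str(find_free_network_port())
#
# Note: the port is only guaranteed free while the socket is open, so there is a small
# race window between closing the socket and the process group binding to it.

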
def get_trainer_from_args(dataset_name_or_id: Union[int, str],
configuration: str,
fold: int,
trainer_name: str = 'nnUNetTrainer',
plans_identifier: str = 'nnUNetPlans',
use_compressed: bool = False,
device: torch.device = torch.device('cuda')):
# load nnunet class and do sanity checks
nnunet_trainer = recursive_find_python_class(join(nnunetv2.__path__[0], "training", "nnUNetTrainer"),
trainer_name, 'nnunetv2.training.nnUNetTrainer')
if nnunet_trainer is None:
raise RuntimeError(f'Could not find requested nnunet trainer {trainer_name} in '
f'nnunetv2.training.nnUNetTrainer ('
f'{join(nnunetv2.__path__[0], "training", "nnUNetTrainer")}). If it is located somewhere '
f'else, please move it there.')
assert issubclass(nnunet_trainer, nnUNetTrainer), 'The requested nnunet trainer class must inherit from ' \
'nnUNetTrainer'
    # handle dataset input. If it's an ID given as a string we need to convert it to int
    # (calling .startswith on an int would crash, so guard with isinstance)
    if isinstance(dataset_name_or_id, str) and not dataset_name_or_id.startswith('Dataset'):
        try:
            dataset_name_or_id = int(dataset_name_or_id)
        except ValueError:
            raise ValueError(f'dataset_name_or_id must either be an integer or a valid dataset name with the pattern '
                             f'DatasetXXX_YYY where XXX are the three(!) task ID digits. Your '
                             f'input: {dataset_name_or_id}')
# initialize nnunet trainer
preprocessed_dataset_folder_base = join(nnUNet_preprocessed, maybe_convert_to_dataset_name(dataset_name_or_id))
plans_file = join(preprocessed_dataset_folder_base, plans_identifier + '.json')
plans = load_json(plans_file)
dataset_json = load_json(join(preprocessed_dataset_folder_base, 'dataset.json'))
nnunet_trainer = nnunet_trainer(plans=plans, configuration=configuration, fold=fold,
dataset_json=dataset_json, unpack_dataset=not use_compressed, device=device)
return nnunet_trainer
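

# For illustration (hypothetical values, not part of this file): building and running a
# trainer directly, bypassing the CLI:
#
#   trainer = get_trainer_from_args(4, '3d_fullres', 0)
#   trainer.run_training()

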
def maybe_load_checkpoint(nnunet_trainer: nnUNetTrainer, continue_training: bool, validation_only: bool,
                          pretrained_weights_file: Optional[str] = None):
    # when continuing (--c), checkpoints are tried in the order checkpoint_final.pth,
    # checkpoint_latest.pth, checkpoint_best.pth; if none of them exists, a new training is started
if continue_training and pretrained_weights_file is not None:
raise RuntimeError('Cannot both continue a training AND load pretrained weights. Pretrained weights can only '
'be used at the beginning of the training.')
if continue_training:
expected_checkpoint_file = join(nnunet_trainer.output_folder, 'checkpoint_final.pth')
if not isfile(expected_checkpoint_file):
expected_checkpoint_file = join(nnunet_trainer.output_folder, 'checkpoint_latest.pth')
# special case where --c is used to run a previously aborted validation
if not isfile(expected_checkpoint_file):
expected_checkpoint_file = join(nnunet_trainer.output_folder, 'checkpoint_best.pth')
if not isfile(expected_checkpoint_file):
print(f"WARNING: Cannot continue training because there seems to be no checkpoint available to "
f"continue from. Starting a new training...")
expected_checkpoint_file = None
elif validation_only:
expected_checkpoint_file = join(nnunet_trainer.output_folder, 'checkpoint_final.pth')
if not isfile(expected_checkpoint_file):
raise RuntimeError(f"Cannot run validation because the training is not finished yet!")
else:
if pretrained_weights_file is not None:
if not nnunet_trainer.was_initialized:
nnunet_trainer.initialize()
load_pretrained_weights(nnunet_trainer.network, pretrained_weights_file, verbose=True)
expected_checkpoint_file = None
if expected_checkpoint_file is not None:
nnunet_trainer.load_checkpoint(expected_checkpoint_file)


def setup_ddp(rank, world_size):
    # initialize the default process group; MASTER_ADDR and MASTER_PORT must be set in the
    # environment before this is called (run_training takes care of that)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def cleanup_ddp():
    dist.destroy_process_group()


def run_ddp(rank, dataset_name_or_id, configuration, fold, tr, p, use_compressed, disable_checkpointing, c, val,
pretrained_weights, npz, val_with_best, world_size):
setup_ddp(rank, world_size)
torch.cuda.set_device(torch.device('cuda', dist.get_rank()))
nnunet_trainer = get_trainer_from_args(dataset_name_or_id, configuration, fold, tr, p,
use_compressed)
if disable_checkpointing:
nnunet_trainer.disable_checkpointing = disable_checkpointing
assert not (c and val), f'Cannot set --c and --val flag at the same time. Dummy.'
maybe_load_checkpoint(nnunet_trainer, c, val, pretrained_weights)
if torch.cuda.is_available():
cudnn.deterministic = False
cudnn.benchmark = True
if not val:
nnunet_trainer.run_training()
if val_with_best:
nnunet_trainer.load_checkpoint(join(nnunet_trainer.output_folder, 'checkpoint_best.pth'))
nnunet_trainer.perform_actual_validation(npz)
cleanup_ddp()
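

# run_ddp is the per-process worker spawned by mp.spawn in run_training below; rank is
# supplied by mp.spawn, all other arguments are forwarded from run_training.

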
def run_training(dataset_name_or_id: Union[str, int],
configuration: str, fold: Union[int, str],
trainer_class_name: str = 'nnUNetTrainer',
plans_identifier: str = 'nnUNetPlans',
pretrained_weights: Optional[str] = None,
num_gpus: int = 1,
use_compressed_data: bool = False,
export_validation_probabilities: bool = False,
continue_training: bool = False,
only_run_validation: bool = False,
disable_checkpointing: bool = False,
val_with_best: bool = False,
device: torch.device = torch.device('cuda')):
if plans_identifier == 'nnUNetPlans':
print("\n############################\n"
"INFO: You are using the old nnU-Net default plans. We have updated our recommendations. "
"Please consider using those instead! "
"Read more here: https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/resenc_presets.md"
"\n############################\n")
if isinstance(fold, str):
if fold != 'all':
try:
fold = int(fold)
except ValueError as e:
                print(f'Unable to convert given value for fold to int: {fold}. fold must be either "all" or an integer!')
raise e
if val_with_best:
assert not disable_checkpointing, '--val_best is not compatible with --disable_checkpointing'
if num_gpus > 1:
assert device.type == 'cuda', f"DDP training (triggered by num_gpus > 1) is only implemented for cuda devices. Your device: {device}"
os.environ['MASTER_ADDR'] = 'localhost'
        if 'MASTER_PORT' not in os.environ:
            port = str(find_free_network_port())
            print(f"using port {port}")
            os.environ['MASTER_PORT'] = port
mp.spawn(run_ddp,
args=(
dataset_name_or_id,
configuration,
fold,
trainer_class_name,
plans_identifier,
use_compressed_data,
disable_checkpointing,
continue_training,
only_run_validation,
pretrained_weights,
export_validation_probabilities,
val_with_best,
num_gpus),
nprocs=num_gpus,
join=True)
else:
nnunet_trainer = get_trainer_from_args(dataset_name_or_id, configuration, fold, trainer_class_name,
plans_identifier, use_compressed_data, device=device)
if disable_checkpointing:
nnunet_trainer.disable_checkpointing = disable_checkpointing
assert not (continue_training and only_run_validation), f'Cannot set --c and --val flag at the same time. Dummy.'
maybe_load_checkpoint(nnunet_trainer, continue_training, only_run_validation, pretrained_weights)
if torch.cuda.is_available():
cudnn.deterministic = False
cudnn.benchmark = True
if not only_run_validation:
nnunet_trainer.run_training()
if val_with_best:
nnunet_trainer.load_checkpoint(join(nnunet_trainer.output_folder, 'checkpoint_best.pth'))
nnunet_trainer.perform_actual_validation(export_validation_probabilities)
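

# A minimal sketch of programmatic use (illustrative values): train fold 0 of the
# 3d_fullres configuration of dataset 4 on a single GPU:
#
#   run_training(4, '3d_fullres', 0, num_gpus=1, device=torch.device('cuda'))

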
def run_training_entry():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('dataset_name_or_id', type=str,
help="Dataset name or ID to train with")
parser.add_argument('configuration', type=str,
help="Configuration that should be trained")
    parser.add_argument('fold', type=str,
                        help='Fold of the 5-fold cross-validation. Should be an int between 0 and 4, or "all" to '
                             'train on all training cases.')
parser.add_argument('-tr', type=str, required=False, default='nnUNetTrainer',
help='[OPTIONAL] Use this flag to specify a custom trainer. Default: nnUNetTrainer')
parser.add_argument('-p', type=str, required=False, default='nnUNetPlans',
help='[OPTIONAL] Use this flag to specify a custom plans identifier. Default: nnUNetPlans')
parser.add_argument('-pretrained_weights', type=str, required=False, default=None,
help='[OPTIONAL] path to nnU-Net checkpoint file to be used as pretrained model. Will only '
'be used when actually training. Beta. Use with caution.')
parser.add_argument('-num_gpus', type=int, default=1, required=False,
help='Specify the number of GPUs to use for training')
parser.add_argument("--use_compressed", default=False, action="store_true", required=False,
help="[OPTIONAL] If you set this flag the training cases will not be decompressed. Reading compressed "
"data is much more CPU and (potentially) RAM intensive and should only be used if you "
"know what you are doing")
parser.add_argument('--npz', action='store_true', required=False,
help='[OPTIONAL] Save softmax predictions from final validation as npz files (in addition to predicted '
'segmentations). Needed for finding the best ensemble.')
parser.add_argument('--c', action='store_true', required=False,
help='[OPTIONAL] Continue training from latest checkpoint')
parser.add_argument('--val', action='store_true', required=False,
help='[OPTIONAL] Set this flag to only run the validation. Requires training to have finished.')
parser.add_argument('--val_best', action='store_true', required=False,
help='[OPTIONAL] If set, the validation will be performed with the checkpoint_best instead '
'of checkpoint_final. NOT COMPATIBLE with --disable_checkpointing! '
'WARNING: This will use the same \'validation\' folder as the regular validation '
'with no way of distinguishing the two!')
    parser.add_argument('--disable_checkpointing', action='store_true', required=False,
                        help='[OPTIONAL] Set this flag to disable checkpointing. Ideal for testing things out when '
                             'you don\'t want to flood your hard drive with checkpoints.')
parser.add_argument('-device', type=str, default='cuda', required=False,
help="Use this to set the device the training should run with. Available options are 'cuda' "
"(GPU), 'cpu' (CPU) and 'mps' (Apple M1/M2). Do NOT use this to set which GPU ID! "
"Use CUDA_VISIBLE_DEVICES=X nnUNetv2_train [...] instead!")
args = parser.parse_args()
assert args.device in ['cpu', 'cuda', 'mps'], f'-device must be either cpu, mps or cuda. Other devices are not tested/supported. Got: {args.device}.'
    if args.device == 'cpu':
        # let's allow torch to use hella threads (multiprocessing is already imported at module level)
        torch.set_num_threads(multiprocessing.cpu_count())
        device = torch.device('cpu')
device = torch.device('cpu')
elif args.device == 'cuda':
# multithreading in torch doesn't help nnU-Net if run on GPU
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
device = torch.device('cuda')
else:
device = torch.device('mps')
run_training(args.dataset_name_or_id, args.configuration, args.fold, args.tr, args.p, args.pretrained_weights,
args.num_gpus, args.use_compressed, args.npz, args.c, args.val, args.disable_checkpointing, args.val_best,
device=device)
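

# When nnU-Net is installed, this entry point is typically exposed as the nnUNetv2_train
# console command. Illustrative invocations (dataset ID and fold are examples):
#
#   nnUNetv2_train 4 3d_fullres 0 --npz
#   CUDA_VISIBLE_DEVICES=0 nnUNetv2_train 4 3d_fullres 0 --c

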
if __name__ == '__main__':
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'
os.environ['OPENBLAS_NUM_THREADS'] = '1'
# multiprocessing.set_start_method("spawn")
run_training_entry()