This repository has been archived by the owner on Nov 1, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathtrain_net.py
309 lines (264 loc) · 11 KB
/
train_net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
import logging
import os
import sys
from collections import OrderedDict
import torch
from torch.nn.parallel import DistributedDataParallel
import detectron2.utils.comm as comm
from detectron2.checkpoint import DetectionCheckpointer, PeriodicCheckpointer
from detectron2.config import get_cfg
from detectron2.data import (
MetadataCatalog,
build_detection_test_loader,
build_detection_train_loader,
)
from detectron2.engine import default_argument_parser, default_setup, default_writers, launch
from detectron2.evaluation import (
CityscapesInstanceEvaluator,
CityscapesSemSegEvaluator,
COCOEvaluator,
COCOPanopticEvaluator,
DatasetEvaluators,
LVISEvaluator,
PascalVOCDetectionEvaluator,
SemSegEvaluator,
inference_on_dataset,
print_csv_format,
)
from detectron2.modeling import build_model
# from detectron2.solver import build_lr_scheduler, build_optimizer
from detectron2.solver import build_lr_scheduler
from detectron2.solver.build import maybe_add_gradient_clipping
from detectron2.utils.events import EventStorage
from detectron2.data.dataset_mapper import DatasetMapper
from torch.cuda.amp import GradScaler
from vlpart.data import (
DatasetMapperAnn,
DatasetMapperWithImage,
DatasetMapperFilterByBox,
)
from vlpart.data.custom_build_augmentation import build_custom_augmentation
from vlpart.data.custom_dataset_dataloader import (
build_custom_train_loader,
build_custom_test_loader,
)
from vlpart.solver.custom_solver import build_custom_optimizer
from vlpart.evaluation import (
PASCALPARTEvaluator,
PACOEvaluator,
AnnJsonGenerator,
RecallEvaluator,
)
from vlpart import add_vlpart_config
import warnings
# Blanket filter: every Python warning is silenced for the whole process.
warnings.simplefilter('ignore')
# Module logger shares the "detectron2" name, so it follows whatever handlers
# are configured for that logger (presumably by default_setup — confirm).
logger = logging.getLogger("detectron2")
def get_evaluator(cfg, dataset_name, output_folder=None):
    """
    Create the evaluator for a given dataset.

    Two config switches take precedence over the dataset's metadata:
    ``cfg.MODEL.ANN_GENERATOR`` returns an ``AnnJsonGenerator`` writing to
    ``cfg.OUTPUT_ANN_DIR``, and ``cfg.MODEL.EVAL_PROPOSAL`` returns a
    ``RecallEvaluator``. Otherwise the evaluator is chosen from the special
    ``evaluator_type`` stored in the dataset's ``MetadataCatalog`` entry.
    For your own dataset, you can simply create an evaluator manually in your
    script and do not have to worry about the if-else logic here.

    Args:
        cfg: config node.
        dataset_name: name of a registered dataset.
        output_folder: directory for evaluator outputs; defaults to
            ``<cfg.OUTPUT_DIR>/inference``.

    Returns:
        A single evaluator instance.

    Raises:
        NotImplementedError: if the dataset's ``evaluator_type`` has no
            matching evaluator here.
    """
    if output_folder is None:
        output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")

    # Config-driven modes are decided before touching dataset metadata, so they
    # do not depend on the dataset carrying an ``evaluator_type`` field
    # (detectron2's Metadata raises on missing attributes — verify for your version).
    if cfg.MODEL.ANN_GENERATOR:
        return AnnJsonGenerator(cfg, dataset_name=dataset_name, output_dir=cfg.OUTPUT_ANN_DIR)
    if cfg.MODEL.EVAL_PROPOSAL:
        return RecallEvaluator(dataset_name, output_dir=output_folder)

    evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type
    if evaluator_type == "coco":
        return COCOEvaluator(dataset_name, output_dir=output_folder)
    if evaluator_type == "pascal_part":
        return PASCALPARTEvaluator(dataset_name, output_dir=output_folder)
    if evaluator_type == "pascal_voc":
        # PascalVOCDetectionEvaluator takes no output directory.
        return PascalVOCDetectionEvaluator(dataset_name)
    if evaluator_type == "lvis":
        return LVISEvaluator(dataset_name, cfg, True, output_folder)
    if evaluator_type == "paco":
        return PACOEvaluator(
            dataset_name, cfg, True, output_folder, cfg.MODEL.EVAL_ATTR, cfg.MODEL.EVAL_PER
        )
    raise NotImplementedError(
        "no Evaluator for the dataset {} with the type {}".format(dataset_name, evaluator_type)
    )
def do_test(cfg, model):
    """
    Run inference and evaluation on every dataset listed in ``cfg.DATASETS.TEST``.

    Each dataset gets its own evaluator (see ``get_evaluator``) writing under
    ``<cfg.OUTPUT_DIR>/inference/<dataset_name>``. On the main process the
    per-dataset results are also logged in CSV form.

    Args:
        cfg: config node.
        model: model to evaluate.

    Returns:
        The results of the only dataset when exactly one test dataset is
        configured; otherwise an ``OrderedDict`` of dataset name -> results.
    """
    # A custom mapper is only needed in annotation-generation mode; otherwise
    # ``None`` is passed and the loader picks its own mapper (presumably the default).
    if cfg.MODEL.ANN_GENERATOR:
        mapper = DatasetMapperAnn(cfg, is_train=False)
    else:
        mapper = None

    all_results = OrderedDict()
    for name in cfg.DATASETS.TEST:
        loader = build_custom_test_loader(cfg, name, mapper=mapper)
        evaluator = get_evaluator(cfg, name, os.path.join(cfg.OUTPUT_DIR, "inference", name))
        dataset_results = inference_on_dataset(model, loader, evaluator)
        all_results[name] = dataset_results
        if comm.is_main_process():
            logger.info("Evaluation results for {} in csv format:".format(name))
            print_csv_format(dataset_results)

    # Unwrap the single-dataset case so callers get the results directly.
    if len(all_results) == 1:
        return next(iter(all_results.values()))
    return all_results
def build_optimizer(cfg, model):
    """
    Build a plain optimizer over every trainable parameter of ``model``.

    Each trainable parameter gets its own param group with
    ``cfg.SOLVER.BASE_LR`` and ``cfg.SOLVER.WEIGHT_DECAY``; a parameter
    reachable from several modules is added only once.

    Args:
        cfg: config node; reads ``cfg.SOLVER.OPTIMIZER``, ``BASE_LR``,
            ``WEIGHT_DECAY``, ``MOMENTUM`` and ``CLIP_GRADIENTS``.
        model: ``torch.nn.Module`` whose trainable parameters are optimized.

    Returns:
        A ``torch.optim.SGD`` (with ``MOMENTUM``), ``AdamW`` or ``RMSprop``
        optimizer, selected by ``cfg.SOLVER.OPTIMIZER`` being ``"SGD"``,
        ``"ADAMW"`` or ``"RMSprop"``. When ``CLIP_GRADIENTS.ENABLED`` is set,
        ``CLIP_TYPE == "full_model"`` and ``CLIP_VALUE > 0``, the optimizer
        clips the global gradient norm of all its parameters to ``CLIP_VALUE``
        before every step. For any other ``CLIP_TYPE``, detectron2's
        ``maybe_add_gradient_clipping`` is applied instead.

    Raises:
        NotImplementedError: if ``cfg.SOLVER.OPTIMIZER`` is not supported.
    """
    # ``itertools`` was previously used in ``step`` without being imported,
    # which raised NameError on the first full-model-clipped step. The typing
    # names back the annotations below.
    import itertools
    from typing import Any, Dict, List, Set

    solver_cfg = cfg.SOLVER
    params: List[Dict[str, Any]] = []
    memo: Set[torch.nn.parameter.Parameter] = set()
    for _, value in model.named_parameters(recurse=True):
        # Skip frozen parameters and avoid duplicating shared ones.
        if not value.requires_grad or value in memo:
            continue
        memo.add(value)
        params.append(
            {
                "params": [value],
                "lr": solver_cfg.BASE_LR,
                "weight_decay": solver_cfg.WEIGHT_DECAY,
            }
        )

    def maybe_add_full_model_gradient_clipping(optim):  # optim: the optimizer class
        # detectron2 only offers per-parameter clipping, so global-norm
        # ("full_model") clipping is added by subclassing ``optim``.
        clip_norm_val = solver_cfg.CLIP_GRADIENTS.CLIP_VALUE
        enable = (
            solver_cfg.CLIP_GRADIENTS.ENABLED
            and solver_cfg.CLIP_GRADIENTS.CLIP_TYPE == "full_model"
            and clip_norm_val > 0.0
        )
        if not enable:
            return optim

        class FullModelGradientClippingOptimizer(optim):
            def step(self, closure=None):
                all_params = itertools.chain(*[x["params"] for x in self.param_groups])
                torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val)
                # Propagate the closure's loss like torch optimizers do.
                return super().step(closure=closure)

        return FullModelGradientClippingOptimizer

    optimizer_type = solver_cfg.OPTIMIZER
    if optimizer_type == "SGD":
        optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)(
            params, solver_cfg.BASE_LR, momentum=solver_cfg.MOMENTUM
        )
    elif optimizer_type == "ADAMW":
        optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)(
            params, solver_cfg.BASE_LR
        )
    elif optimizer_type == "RMSprop":
        optimizer = maybe_add_full_model_gradient_clipping(torch.optim.RMSprop)(
            params, solver_cfg.BASE_LR
        )
    else:
        raise NotImplementedError(f"no optimizer type {optimizer_type}")

    # Per-parameter clipping (if enabled) is handled by detectron2 for every
    # clip type other than the full-model one implemented above.
    if solver_cfg.CLIP_GRADIENTS.CLIP_TYPE != "full_model":
        optimizer = maybe_add_gradient_clipping(cfg, optimizer)
    return optimizer
def do_train(cfg, model, resume=False):
    """
    Train ``model`` with a hand-written loop (no detectron2 Trainer class).

    Builds the optimizer, LR scheduler, checkpointers, event writers, dataset
    mapper and train loader from ``cfg``, then runs iterations
    ``start_iter .. cfg.SOLVER.MAX_ITER - 1``. ``start_iter`` is one past the
    ``"iteration"`` value returned by ``checkpointer.resume_or_load``, or 0
    when no such value is present.

    Args:
        cfg: frozen config node.
        model: model to train (DDP-wrapped when distributed). Called on a batch,
            it must return a dict of single-element loss tensors.
        resume: forwarded to ``DetectionCheckpointer.resume_or_load``.

    Side effects:
        The checkpointer and event writers are rooted at ``cfg.OUTPUT_DIR``.
        ``do_test`` runs every ``cfg.TEST.EVAL_PERIOD`` iterations except on
        the final iteration.
    """
    model.train()
    if cfg.SOLVER.USE_CUSTOM_SOLVER:
        optimizer = build_custom_optimizer(cfg, model)
    else:
        # The plain ``build_optimizer`` path only supports this configuration;
        # anything else requires cfg.SOLVER.USE_CUSTOM_SOLVER.
        assert cfg.SOLVER.OPTIMIZER == 'SGD'
        assert cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE != 'full_model'
        assert cfg.SOLVER.BACKBONE_MULTIPLIER == 1.
        optimizer = build_optimizer(cfg, model)
    # optimizer = build_optimizer(cfg, model)
    scheduler = build_lr_scheduler(cfg, optimizer)
    checkpointer = DetectionCheckpointer(
        model, cfg.OUTPUT_DIR, optimizer=optimizer, scheduler=scheduler
    )
    start_iter = (
        checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume).get("iteration", -1) + 1
    )
    max_iter = cfg.SOLVER.MAX_ITER
    periodic_checkpointer = PeriodicCheckpointer(
        checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD, max_iter=max_iter
    )
    # Only the main process writes metrics/events.
    writers = default_writers(cfg.OUTPUT_DIR, max_iter) if comm.is_main_process() else []
    use_with_image_mapper = cfg.WITH_IMAGE_LABELS
    MapperClass = DatasetMapperWithImage if use_with_image_mapper else DatasetMapperFilterByBox
    # An empty cfg.INPUT.CUSTOM_AUG keeps the mapper's own augmentations;
    # otherwise the VLPart custom augmentation list is passed in.
    mapper = MapperClass(cfg, True) if cfg.INPUT.CUSTOM_AUG == '' else \
        MapperClass(cfg, True, augmentations=build_custom_augmentation(cfg, True))
    # The two stock samplers go through detectron2's loader; any other sampler
    # name is handled by the custom VLPart loader.
    if cfg.DATALOADER.SAMPLER_TRAIN in ['TrainingSampler', 'RepeatFactorTrainingSampler']:
        data_loader = build_detection_train_loader(cfg, mapper=mapper)
    else:
        data_loader = build_custom_train_loader(cfg, mapper=mapper)
    if cfg.FP16:
        # NOTE(review): only loss scaling happens in this function; no autocast
        # context is entered here. Presumably the model handles mixed precision
        # itself — confirm.
        scaler = GradScaler()
    logger.info("Starting training from iteration {}".format(start_iter))
    with EventStorage(start_iter) as storage:
        # Stops when the range is exhausted. Because the loader is zip's first
        # argument, one extra batch is fetched before the loop terminates.
        for data, iteration in zip(data_loader, range(start_iter, max_iter)):
            storage.iter = iteration
            loss_dict = model(data)
            losses = sum(loss_dict.values())
            # Abort on NaN/Inf loss (note: asserts are stripped under `python -O`).
            assert torch.isfinite(losses).all(), loss_dict
            # Reduced copies are used only for logging; backward uses the local
            # `losses`. Called on every rank (presumably a collective op), so it
            # must stay outside rank-specific branches.
            loss_dict_reduced = {k: v.item() for k, v in comm.reduce_dict(loss_dict).items()}
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())
            if comm.is_main_process():
                storage.put_scalars(total_loss=losses_reduced, **loss_dict_reduced)
            optimizer.zero_grad()
            if cfg.FP16:
                scaler.scale(losses).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                losses.backward()
                optimizer.step()
            # Logged before scheduler.step(), i.e. the LR used for this iteration.
            storage.put_scalar("lr", optimizer.param_groups[0]["lr"], smoothing_hint=False)
            scheduler.step()
            # Periodic evaluation; the last iteration is skipped because main()
            # evaluates the final model after do_train returns.
            if (
                cfg.TEST.EVAL_PERIOD > 0
                and (iteration + 1) % cfg.TEST.EVAL_PERIOD == 0
                and iteration != max_iter - 1
            ):
                torch.cuda.empty_cache()
                do_test(cfg, model)
                # Keep all ranks in step before resuming training.
                comm.synchronize()
            # Flush writers every 20 iterations (once more than 5 iterations
            # past start_iter) and on the final iteration.
            if iteration - start_iter > 5 and (
                (iteration + 1) % 20 == 0 or iteration == max_iter - 1
            ):
                for writer in writers:
                    writer.write()
            # Called every iteration; the periodic checkpointer decides when to save.
            periodic_checkpointer.step(iteration)
def setup(args):
    """
    Create the frozen config for this run and perform basic setups.

    Starts from detectron2's defaults extended by ``add_vlpart_config``, then
    merges ``args.config_file`` and the command-line overrides in ``args.opts``.
    If ``OUTPUT_DIR`` contains ``/auto``, every occurrence of that substring is
    replaced by ``/<config file name without its extension>``.

    Args:
        args: parsed command-line namespace; ``config_file`` and ``opts`` are
            read, and the namespace is forwarded to ``default_setup``.

    Returns:
        The frozen config node.
    """
    cfg = get_cfg()
    add_vlpart_config(cfg)
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    if '/auto' in cfg.OUTPUT_DIR:
        # Strip the real extension rather than a fixed 5 characters: the old
        # slice was only right for ".yaml" and mangled ".yml" or extension-less
        # names. The result for ".yaml" files is unchanged.
        file_name = os.path.splitext(os.path.basename(args.config_file))[0]
        cfg.OUTPUT_DIR = cfg.OUTPUT_DIR.replace('/auto', '/{}'.format(file_name))
        logger.info('OUTPUT_DIR: {}'.format(cfg.OUTPUT_DIR))
    cfg.freeze()
    default_setup(
        cfg, args
    )  # if you don't like any of the default setup, write your own setup code
    return cfg
def main(args):
    """
    Per-process entry point handed to ``launch``.

    Builds the config and model. When ``args.eval_only`` is set, it loads
    ``cfg.MODEL.WEIGHTS`` (honouring ``args.resume``) and only evaluates.
    Otherwise it wraps the model in ``DistributedDataParallel`` when more than
    one process is running, trains it, and evaluates the trained model.

    Returns:
        The evaluation results produced by ``do_test``.
    """
    cfg = setup(args)
    model = build_model(cfg)
    logger.info("Model:\n{}".format(model))

    if args.eval_only:
        checkpointer = DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR)
        checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=args.resume)
        return do_test(cfg, model)

    if comm.get_world_size() > 1:
        model = DistributedDataParallel(
            model,
            device_ids=[comm.get_local_rank()],
            broadcast_buffers=False,
            find_unused_parameters=cfg.FIND_UNUSED_PARAM,
        )

    do_train(cfg, model, resume=args.resume)
    return do_test(cfg, model)
if __name__ == "__main__":
    # Parse the standard detectron2 CLI, then spawn `main` once per GPU/machine.
    cli_args = default_argument_parser().parse_args()
    print("Command Line Args:", cli_args)
    launch(
        main,
        cli_args.num_gpus,
        num_machines=cli_args.num_machines,
        machine_rank=cli_args.machine_rank,
        dist_url=cli_args.dist_url,
        args=(cli_args,),
    )