-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest.py
153 lines (130 loc) · 5.06 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
# This code is taken from https://github.com/open-mmlab/mmediting
# Modified by Raymond Wong
import argparse
import os
import mmcv
import torch
from mmcv import Config, DictAction
from mmcv.parallel import MMDataParallel
from mmcv.runner import get_dist_info, init_dist, load_checkpoint
from mmderain.apis import multi_gpu_test, set_random_seed, single_gpu_test
from mmderain.core.distributed_wrapper import DistributedDataParallelWrapper
from mmderain.datasets import build_dataloader, build_dataset
from mmderain.models import build_model
def parse_args():
    """Parse command-line arguments for the test script.

    Returns:
        argparse.Namespace: the parsed arguments.

    Side effects:
        Sets the ``LOCAL_RANK`` environment variable from ``--local_rank`` /
        ``--local-rank`` when it is not already present in the environment.
    """
    parser = argparse.ArgumentParser(description='mmediting tester')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument('--seed', type=int, default=None, help='random seed')
    parser.add_argument(
        '--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')
    parser.add_argument('--out', help='output result pickle file')
    parser.add_argument(
        '--gpu-collect',
        action='store_true',
        help='whether to use gpu to collect results')
    parser.add_argument(
        '--save-path',
        default=None,
        type=str,
        help='path to store images and if not given, will not save image')
    parser.add_argument('--tmpdir', help='tmp dir for writing some results')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b. '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]". '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    # Newer PyTorch launchers pass ``--local-rank`` (hyphen) while older ones
    # pass ``--local_rank``; accept both spellings. The dest is taken from the
    # first long option, so it stays ``args.local_rank``.
    parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
    args = parser.parse_args()
    # Mirror the rank into the environment for downstream code that reads
    # LOCAL_RANK (not visible here), without overriding a launcher-set value.
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)
    return args
def main():
    """Run a trained model over the test split described by a config file.

    Loads ``args.config`` (optionally overridden by ``--cfg-options``),
    builds the test dataset, dataloader and model, and restores weights
    from ``args.checkpoint``. It then runs ``single_gpu_test`` when
    ``--launcher`` is ``'none'`` and ``multi_gpu_test`` otherwise.

    On rank 0:
      * metrics from ``dataset.evaluate`` are printed when the outputs
        carry an ``'eval_result'`` entry;
      * the raw outputs are dumped to ``--out`` whenever it is given, even if
        no metrics were computed.
    """
    args = parse_args()

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    rank, _ = get_dist_info()

    # set random seeds
    if args.seed is not None:
        if rank == 0:
            print('set random seed to', args.seed)
        set_random_seed(args.seed, deterministic=args.deterministic)

    # build the dataloader
    # TODO: support multiple images per gpu (only minor changes are needed)
    dataset = build_dataset(cfg.data.test)
    # Later entries win: fixed test-time defaults override workers_per_gpu,
    # and an explicit cfg.data.test_dataloader overrides both.
    loader_cfg = {
        **dict((k, cfg.data[k]) for k in ['workers_per_gpu'] if k in cfg.data),
        **dict(
            samples_per_gpu=1,
            drop_last=False,
            shuffle=False,
            dist=distributed),
        **cfg.data.get('test_dataloader', {})
    }
    data_loader = build_dataloader(dataset, **loader_cfg)

    # build the model and load checkpoint
    model = build_model(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)

    args.save_image = args.save_path is not None
    empty_cache = cfg.get('empty_cache', False)
    if not distributed:
        _ = load_checkpoint(model, args.checkpoint, map_location='cpu')
        model = MMDataParallel(model, device_ids=[0])
        outputs = single_gpu_test(
            model,
            data_loader,
            save_path=args.save_path,
            save_image=args.save_image)
    else:
        find_unused_parameters = cfg.get('find_unused_parameters', False)
        model = DistributedDataParallelWrapper(
            model,
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=find_unused_parameters)
        device_id = torch.cuda.current_device()
        # Load the weights straight onto this process's GPU.
        _ = load_checkpoint(
            model,
            args.checkpoint,
            map_location=lambda storage, loc: storage.cuda(device_id))
        outputs = multi_gpu_test(
            model,
            data_loader,
            args.tmpdir,
            args.gpu_collect,
            save_path=args.save_path,
            save_image=args.save_image,
            empty_cache=empty_cache)

    # Only rank 0 reports. NOTE(review): presumably non-zero ranks do not
    # receive the collected results from multi_gpu_test — confirm.
    if rank != 0:
        return

    # Guard the index: an empty result list would otherwise raise IndexError.
    if outputs and 'eval_result' in outputs[0]:
        print('')
        # print metrics
        stats = dataset.evaluate(outputs)
        for stat in stats:
            print(f'Eval-{stat}: {stats[stat]}')

    # save result pickle — honoured even when no metrics were computed
    if args.out:
        print(f'writing results to {args.out}')
        mmcv.dump(outputs, args.out)


if __name__ == '__main__':
    main()