-
Notifications
You must be signed in to change notification settings - Fork 24
/
Copy patheval_supervised.py
206 lines (123 loc) · 5.81 KB
/
eval_supervised.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
from model.geometry import *
import os
import torch
import sys
sys.path.append("correspondence")
from tqdm import tqdm
import argparse
from model.registration import Registration
import yaml
from easydict import EasyDict as edict
from model.loss import compute_flow_metrics
from utils.benchmark_utils import setup_seed
from utils.utils import Logger, AverageMeter
from utils.tiktok import Timers
from correspondence.landmark_estimator import Landmark_Model
def join(loader, node):
    """YAML constructor for the ``!join`` tag.

    Builds the sequence stored under ``node`` with
    ``loader.construct_sequence`` and returns its items, each converted
    with ``str``, concatenated with underscores.

    Args:
        loader: object exposing ``construct_sequence(node)``.
        node: the YAML node handed to the constructor.

    Returns:
        The underscore-joined string (empty string for an empty sequence).
    """
    seq = loader.construct_sequence(node)
    return '_'.join(str(item) for item in seq)
# Register `join` as the constructor for the `!join` YAML tag, so that tag is
# handled when the config file is loaded further down.
yaml.add_constructor('!join', join)
# Seed through the project's `setup_seed` helper with a fixed value of 0.
# Note this runs at import time, not only when executed as a script.
setup_seed(0)
if __name__ == "__main__":
    # Evaluation entry point: load the YAML config, build the landmark and
    # registration models, run them over every sample of each test split and
    # log the averaged flow metrics plus the accumulated timings.
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, help='Path to the config file.')
    parser.add_argument('--visualize', action='store_true',
                        help='visualize the registration results')
    args = parser.parse_args()

    # NOTE(security): yaml.Loader can construct arbitrary Python objects.
    # Only point --config at files you trust.
    with open(args.config, 'r') as f:
        config = yaml.load(f, Loader=yaml.Loader)

    config['snapshot_dir'] = 'snapshot/%s/%s' % (config['folder'], config['exp_dir'])
    os.makedirs(config['snapshot_dir'], exist_ok=True)

    config = edict(config)

    # backup the experiment (best effort: the exit status of cp is ignored)
    for backup_src in ('config', 'data', 'model', 'utils'):
        os.system(f'cp -r {backup_src} {config.snapshot_dir}')

    if config.gpu_mode:
        config.device = torch.cuda.current_device()
    else:
        config.device = torch.device('cpu')

    ldmk_model = Landmark_Model(config_file=config.ldmk_config, device=config.device)
    config['kpfcn_config'] = ldmk_model.kpfcn_config

    model = Registration(config)
    timer = Timers()

    from correspondence.datasets._4dmatch import _4DMatch
    from correspondence.datasets.dataloader import get_dataloader

    splits = ['4DMatch-F', '4DLoMatch-F']

    for split in splits:
        config.split['test'] = split

        stats_meter = None

        test_set = _4DMatch(config, 'test', data_augmentation=False)
        test_loader, _ = get_dataloader(test_set, config, shuffle=False)
        logger = Logger(os.path.join(config.snapshot_dir, config.split["test"] + ".log"))

        num_iter = len(test_set)
        # Use the builtin iterator protocol; iterators have no `.next()`
        # method in Python 3.
        c_loader_iter = iter(test_loader)

        for c_iter in tqdm(range(num_iter)):
            inputs = next(c_loader_iter)

            # Move the batch to the target device. Lists are moved item by
            # item; dicts, floats, None and numpy arrays are left untouched.
            for k, v in inputs.items():
                if isinstance(v, list):
                    inputs[k] = [item.to(config.device) for item in v]
                elif isinstance(v, (dict, float, type(None), np.ndarray)):
                    pass
                else:
                    inputs[k] = v.to(config.device)

            # predict landmarks (the two inlier rates are not used here)
            ldmk_s, ldmk_t, _inlier_rate, _inlier_rate_2 = ldmk_model.inference(
                inputs,
                reject_outliers=config.reject_outliers,
                inlier_thr=config.inlier_thr,
                timer=timer,
            )

            src_pcd, tgt_pcd = inputs["src_pcd_list"][0], inputs["tgt_pcd_list"][0]
            s2t_flow = inputs['sflow_list'][0]
            rot, trn = inputs['batched_rot'][0], inputs['batched_trn'][0]
            correspondence = inputs['correspondences_list'][0]

            # compute scene flow GT: apply the stored flow, then the rigid
            # transform (rot, trn), and take the displacement from the source.
            src_pcd_deformed = src_pcd + s2t_flow
            s_pc_wrapped = (rot @ src_pcd_deformed.T + trn).T
            s2t_flow = s_pc_wrapped - src_pcd
            flow_gt = s2t_flow.to(config.device)

            # compute overlap mask: True for every source index listed in the
            # first column of the correspondences. The mask is allocated on
            # the same device as the index tensor so the indexed assignment
            # does not mix devices.
            overlap = torch.zeros(len(src_pcd), dtype=torch.bool,
                                  device=correspondence.device)
            overlap[correspondence[:, 0]] = True
            overlap = overlap.to(config.device)

            if config.deformation_model in ["NDP"]:
                model.load_pcds(src_pcd, tgt_pcd, landmarks=(ldmk_s, ldmk_t))

                timer.tic("registration")
                # `register` hands back the timer; keep using the returned one.
                warped_pcd, iter_stats, timer = model.register(
                    visualize=args.visualize, timer=timer)
                timer.toc("registration")
                flow = warped_pcd - model.src_pcd

                for key, value in iter_stats.items():
                    timer.tictoc(key, value)

            elif config.deformation_model == "ED":  # Lepard+NICP
                model.load_pcds(src_pcd, tgt_pcd)

                depth_paths = inputs['depth_paths_list'][0]
                cam_intrin = inputs['cam_intrin']

                # get pixel landmarks
                uv_src = xyz_2_uv(ldmk_s, cam_intrin)
                uv_tgt = xyz_2_uv(ldmk_t, cam_intrin)
                landmarks = (uv_src.to(config.device), uv_tgt.to(config.device))

                timer.tic("graph construction")
                model.load_raw_pcds_from_depth(
                    depth_paths[0], depth_paths[1], cam_intrin, landmarks=landmarks)
                timer.toc("graph construction")

                timer.tic("registration")
                warped_pcd, point_mask = model.register(visualize=args.visualize)
                timer.toc("registration")

                # Restrict prediction, ground truth and overlap to the points
                # selected by point_mask.
                flow = warped_pcd - model.src_pcd[point_mask]
                flow_gt = flow_gt[point_mask]
                overlap = overlap[point_mask]

            else:
                raise KeyError(
                    f'unsupported deformation_model: {config.deformation_model!r}')

            metric_info = compute_flow_metrics(flow, flow_gt, overlap=overlap)

            # Create one running-average meter per metric on the first sample.
            if stats_meter is None:
                stats_meter = {key: AverageMeter() for key in metric_info}
            for key, value in metric_info.items():
                stats_meter[key].update(value)

        # NOTE(review): the original indentation was lost; the summary below
        # is assumed to be written once per split, after the sample loop.
        if stats_meter is None:
            # Empty split: there is no c_iter and nothing to average.
            logger.write(f'no samples evaluated for split {split}\n')
            continue

        message = f'{c_iter}/{len(test_set)}: '
        for key, value in stats_meter.items():
            message += f'{key}: {value.avg:.3f}\t'
        logger.write(message + '\n')

        print("score on ", split, '\n', message)

    # note down average time cost (written to the last split's log file)
    print('time cost average')
    for ele in timer.get_strings():
        logger.write(ele + '\n')
        print(ele)