-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcorr.py
83 lines (51 loc) · 2.31 KB
/
corr.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import torch
import torch.nn.functional as F
def sampler(img, coords):
    """Sample ``img`` at pixel coordinates; thin wrapper around ``F.grid_sample``.

    Args:
        img: Tensor whose last two dimensions are ``(H, W)``. It is passed
            unchanged to ``F.grid_sample`` as the input.
        coords: Tensor whose last dimension has size 2 and holds ``(x, y)``
            positions in pixel units, where ``x`` runs along the width of
            ``img`` and ``y`` along its height. It is not modified.
            NOTE(review): for a 4-D ``img`` ``F.grid_sample`` expects this
            to be ``(N, H_out, W_out, 2)`` - confirm against callers.

    Returns:
        The result of ``F.grid_sample(img, grid, align_corners=True)`` using
        that function's default interpolation and padding modes.
    """
    H, W = img.shape[-2:]
    xgrid, ygrid = coords.split([1, 1], dim=-1)

    # A size-1 axis would make the denominator zero and fill the grid with
    # NaN/inf. Clamping it to 1 sends coordinate 0 to -1, which is exactly
    # pixel 0 when align_corners=True.
    x_span = max(W - 1, 1)
    y_span = max(H - 1, 1)

    # Rescale pixel indices [0, size - 1] onto the [-1, 1] range that
    # grid_sample works in. The per-axis temporaries live only inside the
    # list and are released as soon as the concatenation is done.
    grid = torch.cat(
        [2 * xgrid / x_span - 1, 2 * ygrid / y_span - 1],
        dim=-1,
    )
    return F.grid_sample(img, grid, align_corners=True)
# used only for images with relative poses
class GuidedCorrBlock:
    """Correlation lookup restricted to a line of candidate matches.

    For every pixel a candidate match position is parameterised by a scalar
    ``flow`` as ``flow_bas + flow * flow_dir``. ``get_cost`` samples the
    second feature map along that line, at integer offsets around ``flow``
    and on every level of an average-pooled pyramid, and correlates the
    samples with the first feature map.

    Attributes set by ``__init__``:
        D: size of dimension 1 of ``fmap2``; used to normalise correlations.
        corr_levels: number of pyramid levels.
        corr_radius: offsets ``-corr_radius .. corr_radius`` are evaluated.
        fmap2_pyra: list of ``corr_levels`` tensors; entry 0 is ``fmap2`` and
            each later entry is the previous one after 2x2 average pooling.
        flow_bas: base match position (``flow == 0``).
        flow_dir: direction along which the match position moves.
        flow_mask: for each location, the index along dim 1 at which
            ``|flow_dir|`` is largest (an index tensor, not a boolean mask).
    """

    def __init__(self,
                 fmap2,
                 flow_bas, flow_dir,
                 corr_levels, corr_radius):
        """Build the ``fmap2`` pyramid and store the line parameterisation.

        Args:
            fmap2: feature tensor to be sampled; dimension 1 is treated as
                the feature dimension and the last two as spatial.
            flow_bas: base match position.
                NOTE(review): presumably (B, 2, H, W) pixel coordinates in
                (x, y) order, given how ``get_cost`` feeds it to the
                sampler - confirm against callers.
            flow_dir: per-pixel direction, combined with ``flow_bas`` by
                broadcasting.
            corr_levels: number of pyramid levels to build.
            corr_radius: half-width of the offset window used by ``get_cost``.
        """
        self.D = fmap2.shape[1]
        self.corr_levels = corr_levels
        self.corr_radius = corr_radius

        # Level 0 is the full-resolution map; each further level halves the
        # spatial resolution of the one before it.
        self.fmap2_pyra = [fmap2]
        for _ in range(1, self.corr_levels):
            self.fmap2_pyra.append(
                F.avg_pool2d(self.fmap2_pyra[-1], 2, stride=2)
            )

        self.flow_bas = flow_bas
        self.flow_dir = flow_dir
        # max(...) returns (values, indices); only the indices are kept.
        self.flow_mask = flow_dir.detach().abs().max(dim=1, keepdim=True)[1]

    def get_flow(self, match):
        """Project a match position onto the line and return the scalar flow.

        Computes ``((match - flow_bas) * flow_dir).sum(dim=1, keepdim=True)``.
        The subtraction is done out of place so the caller's ``match`` tensor
        is left untouched.
        """
        offset = match - self.flow_bas
        return (offset * self.flow_dir).sum(dim=1, keepdim=True)

    def get_match(self, flow):
        """Return the match position ``flow_bas + flow * flow_dir``."""
        return self.flow_bas + flow * self.flow_dir

    def _shifted_corr(self, fmap1, flow, level, shift):
        """Correlate ``fmap1`` with one pyramid level at ``flow + shift``.

        Kept as a separate method so the intermediate coordinate and sample
        tensors go out of scope as soon as one correlation map is produced.
        """
        match = self.flow_bas + (flow + shift) * self.flow_dir
        # Channels-last layout for the sampler; coordinates are rescaled to
        # the resolution of the requested pyramid level.
        match = match.permute(0, 2, 3, 1) / (2 ** level)
        warped = sampler(self.fmap2_pyra[level], match)
        # Dot product over the feature dimension, normalised by its size.
        return (fmap1 * warped).sum(dim=1, keepdim=True) / self.D

    def get_cost(self, fmap1_pyra, flow):
        """Return correlations around ``flow`` for every level and offset.

        Args:
            fmap1_pyra: feature tensor multiplied element-wise with the
                samples taken from each ``fmap2`` pyramid level. The same
                tensor is used for all levels.
            flow: current scalar flow along ``flow_dir``.

        Returns:
            Tensor with ``corr_levels * (2 * corr_radius + 1)`` entries along
            dimension 1, ordered level by level and, within a level, from
            offset ``-corr_radius`` to ``+corr_radius``.
        """
        corrs = []
        for level in range(self.corr_levels):
            for shift in range(-self.corr_radius, self.corr_radius + 1):
                corrs.append(
                    self._shifted_corr(fmap1_pyra, flow, level, shift)
                )
        return torch.cat(corrs, dim=1)