starLoss_v2.py
import torch
import torch.nn as nn
import torch.nn.functional as F

from .smoothL1Loss import SmoothL1Loss
from .wingLoss import WingLoss

torch_ver = torch.__version__.split('.')
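
# STAR loss v2 ("STAR Loss: Reducing Semantic Ambiguity in Facial Landmark
# Detection", CVPR 2023). Each predicted heatmap is treated as a discrete
# probability map: its soft-argmax mean gives the landmark coordinate, its
# weighted covariance captures the anisotropic ambiguity around that point,
# and the regression error is decomposed along the covariance eigenvectors
# so the penalty is relaxed in the ambiguous (high-variance) direction.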


def get_channel_sum(input):
    # Sum each (H, W) map down to a scalar: [B, N, H, W] -> [B, N].
    # (Helper kept from the original file; not referenced within this module.)
    temp = torch.sum(input, dim=3)
    output = torch.sum(temp, dim=2)
    return output


def expand_two_dimensions_at_end(input, dim1, dim2):
    # Broadcast a [B, N] tensor to [B, N, dim1, dim2] without copying data.
    input = input.unsqueeze(-1).unsqueeze(-1)
    input = input.expand(-1, -1, dim1, dim2)
    return input


class STARLoss_v2(nn.Module):
    def __init__(self, w=1, dist='smoothl1', num_dim_image=2, EPSILON=1e-5):
        super(STARLoss_v2, self).__init__()
        self.w = w
        self.num_dim_image = num_dim_image
        self.EPSILON = EPSILON
        self.dist = dist
        if self.dist == 'smoothl1':
            self.dist_func = SmoothL1Loss()
        elif self.dist == 'l1':
            self.dist_func = F.l1_loss
        elif self.dist == 'l2':
            self.dist_func = F.mse_loss
        elif self.dist == 'wing':
            self.dist_func = WingLoss()
        else:
            raise NotImplementedError(f"Unsupported dist: {self.dist}")

    def __repr__(self):
        return "STARLoss_v2()"

    def _make_grid(self, h, w):
        # Coordinate grids normalized to [-1, 1] along each axis. Relies on
        # meshgrid's default 'ij' indexing (newer torch emits a warning here).
        yy, xx = torch.meshgrid(
            torch.arange(h).float() / (h - 1) * 2 - 1,
            torch.arange(w).float() / (w - 1) * 2 - 1)
        return yy, xx

    def weighted_mean(self, heatmap):
        # Soft-argmax: with the heatmap normalized to sum to 1, the expected
        # grid coordinate is E[c] = sum_{y,x} p(y, x) * c(y, x).
        batch, npoints, h, w = heatmap.shape
        yy, xx = self._make_grid(h, w)
        yy = yy.view(1, 1, h, w).to(heatmap)
        xx = xx.view(1, 1, h, w).to(heatmap)
        yy_coord = (yy * heatmap).sum([2, 3])  # [batch, npoints]
        xx_coord = (xx * heatmap).sum([2, 3])  # [batch, npoints]
        coords = torch.stack([xx_coord, yy_coord], dim=-1)  # (x, y) pairs
        return coords

    def unbiased_weighted_covariance(self, htp, means, num_dim_image=2, EPSILON=1e-5):
        # Weighted covariance of the grid coordinates under the heatmap
        # distribution, with the correction V_1 - V_2 / V_1 in the denominator.
        batch_size, num_points, height, width = htp.shape
        yv, xv = self._make_grid(height, width)
        xv = xv.to(htp)
        yv = yv.to(htp)
        xmean = means[:, :, 0]
        xv_minus_mean = xv.expand(batch_size, num_points, -1, -1) \
            - expand_two_dimensions_at_end(xmean, height, width)  # [batch_size, npoints, h, w]
        ymean = means[:, :, 1]
        yv_minus_mean = yv.expand(batch_size, num_points, -1, -1) \
            - expand_two_dimensions_at_end(ymean, height, width)  # [batch_size, npoints, h, w]
        wt_xv_minus_mean = xv_minus_mean.view(batch_size * num_points, 1, height * width)  # [batch_size*npoints, 1, h*w]
        wt_yv_minus_mean = yv_minus_mean.view(batch_size * num_points, 1, height * width)  # [batch_size*npoints, 1, h*w]
        vec_concat = torch.cat((wt_xv_minus_mean, wt_yv_minus_mean), 1)  # [batch_size*npoints, 2, h*w]
        htp_vec = htp.view(batch_size * num_points, 1, height * width)
        htp_vec = htp_vec.expand(-1, 2, -1)
        covariance = torch.bmm(htp_vec * vec_concat, vec_concat.transpose(1, 2))  # [batch_size*npoints, 2, 2]
        covariance = covariance.view(batch_size, num_points, num_dim_image, num_dim_image)  # [batch_size, npoints, 2, 2]
        V_1 = htp.sum([2, 3]) + EPSILON  # sum of weights, [batch_size, npoints]
        V_2 = torch.pow(htp, 2).sum([2, 3]) + EPSILON  # sum of squared weights
        denominator = V_1 - (V_2 / V_1)
        covariance = covariance / expand_two_dimensions_at_end(denominator, num_dim_image, num_dim_image)
        return covariance
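
    # For reference, with weights w_i = htp[., ., y, x] and grid points z_i,
    # the estimator above is the standard unbiased weighted covariance
    #     Cov = sum_i w_i (z_i - mu)(z_i - mu)^T / (V_1 - V_2 / V_1),
    # where V_1 = sum_i w_i and V_2 = sum_i w_i^2 (Bessel's correction for
    # non-uniform weights).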

    def ambiguity_guided_decompose(self, error, evalues, evectors):
        # Split the prediction error into two orthogonal components derived
        # from the eigen-decomposition; each component's distance is divided
        # by the square root of the corresponding eigenvalue, relaxing the
        # penalty along the ambiguous (high-variance) direction.
        bs, npoints = error.shape[:2]
        normal_vector = evectors[:, :, 0]
        tangent_vector = evectors[:, :, 1]
        normal_error = torch.matmul(normal_vector.unsqueeze(-2), error.unsqueeze(-1))
        tangent_error = torch.matmul(tangent_vector.unsqueeze(-2), error.unsqueeze(-1))
        normal_error = normal_error.squeeze(dim=-1)
        tangent_error = tangent_error.squeeze(dim=-1)
        normal_dist = self.dist_func(normal_error, torch.zeros_like(normal_error), reduction='none')
        tangent_dist = self.dist_func(tangent_error, torch.zeros_like(tangent_error), reduction='none')
        normal_dist = normal_dist.reshape(bs, npoints, 1)
        tangent_dist = tangent_dist.reshape(bs, npoints, 1)
        dist = torch.cat((normal_dist, tangent_dist), dim=-1)
        scale_dist = dist / torch.sqrt(evalues + self.EPSILON)
        scale_dist = scale_dist.sum(-1)
        return scale_dist

    def eigenvalue_restriction(self, evalues, batch, npoints):
        # Regularizer: penalize the total spread |lambda_1| + |lambda_2| so
        # the predicted heatmaps stay concentrated.
        eigen_loss = torch.abs(evalues.view(batch, npoints, 2)).sum(-1)
        return eigen_loss

    def forward(self, heatmap, groundtruth):
        """
        heatmap:     b x n x 64 x 64
        groundtruth: b x n x 2
        output:      b x n per-landmark losses, reduced to a scalar mean
        """
        # Normalize each heatmap into a probability distribution.
        bs, npoints, h, w = heatmap.shape
        heatmap_sum = torch.clamp(heatmap.sum([2, 3]), min=1e-6)
        heatmap = heatmap / heatmap_sum.view(bs, npoints, 1, 1)
        means = self.weighted_mean(heatmap)  # [bs, npoints, 2]
        covars = self.unbiased_weighted_covariance(heatmap, means)  # [bs, npoints, 2, 2]
        # TODO: GPU-based eigen-decomposition
        # https://github.com/pytorch/pytorch/issues/60537
        _covars = covars.view(bs * npoints, 2, 2).cpu()
        if int(torch_ver[0]) > 1 or (int(torch_ver[0]) == 1 and int(torch_ver[1]) >= 8):
            evalues, evectors = torch.linalg.eigh(_covars)  # [bs*npoints, 2], [bs*npoints, 2, 2]
        else:
            evalues, evectors = _covars.symeig(eigenvectors=True)  # pre-1.8 fallback
        evalues = evalues.view(bs, npoints, 2).to(heatmap)
        evectors = evectors.view(bs, npoints, 2, 2).to(heatmap)
        # STAR loss: ambiguity-guided decomposition plus eigenvalue restriction.
        loss_trans = self.ambiguity_guided_decompose(groundtruth - means, evalues, evectors)
        loss_eigen = self.eigenvalue_restriction(evalues, bs, npoints)
        star_loss = loss_trans + self.w * loss_eigen
        return star_loss.mean()
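

# Minimal usage sketch (not part of the original file): exercises the loss
# with random tensors. Assumes the module is imported inside its package so
# the relative imports above resolve; dist='l2' is chosen so only F.mse_loss
# is exercised as the distance function.
if __name__ == '__main__':
    bs, npoints, h, w = 2, 68, 64, 64
    criterion = STARLoss_v2(w=1, dist='l2')
    heatmap = torch.rand(bs, npoints, h, w)           # raw, unnormalized heatmaps
    groundtruth = torch.rand(bs, npoints, 2) * 2 - 1  # coordinates in [-1, 1]
    loss = criterion(heatmap, groundtruth)
    print(loss.item())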