-
Notifications
You must be signed in to change notification settings - Fork 69
/
Copy pathloss.py
120 lines (89 loc) · 4.56 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from core.config import cfg
from funcs_utils import stop
class CoordLoss(nn.Module):
    """Mean L1 loss between predicted and target coordinates.

    When ``has_valid`` is True, both tensors are multiplied element-wise by
    ``target_valid`` before the distance is computed, so invalid entries
    contribute zero to the numerator (the mean is still taken over all
    elements, matching ``nn.L1Loss(reduction='mean')``).
    """

    def __init__(self, has_valid=False):
        super(CoordLoss, self).__init__()
        self.has_valid = has_valid
        self.criterion = nn.L1Loss(reduction='mean')

    def forward(self, pred, target, target_valid):
        if self.has_valid:
            # Zero out entries flagged invalid on both sides of the comparison.
            pred = pred * target_valid
            target = target * target_valid
        return self.criterion(pred, target)
class LaplacianLoss(nn.Module):
    """Uniform-Laplacian smoothness loss over a fixed SMPL mesh topology.

    Builds the degree-normalized graph Laplacian from face connectivity once
    at construction; ``forward`` penalizes the squared norm of each vertex's
    Laplacian coordinate (how far it sits from the mean of its neighbors).

    Args:
        faces: (nf, 3) integer array of triangle vertex indices.
        average: if True, return sum over vertices divided by batch size;
            otherwise return the mean over all (batch, vertex) entries.
    """

    def __init__(self, faces, average=False):
        super(LaplacianLoss, self).__init__()
        self.nv = 6890  # SMPL vertex count (topology is fixed)
        self.nf = faces.shape[0]
        self.average = average
        # Dense (nv, nv) graph Laplacian: -1 for every mesh edge (both
        # directions of each triangle edge), diagonal = vertex degree.
        laplacian = np.zeros([self.nv, self.nv]).astype(np.float32)
        laplacian[faces[:, 0], faces[:, 1]] = -1
        laplacian[faces[:, 1], faces[:, 0]] = -1
        laplacian[faces[:, 1], faces[:, 2]] = -1
        laplacian[faces[:, 2], faces[:, 1]] = -1
        laplacian[faces[:, 1], faces[:, 2]] = -1
        laplacian[faces[:, 2], faces[:, 0]] = -1
        laplacian[faces[:, 0], faces[:, 2]] = -1
        r, c = np.diag_indices(laplacian.shape[0])
        laplacian[r, c] = -laplacian.sum(1)
        # Row-normalize by vertex degree, vectorized (was a Python loop over
        # all 6890 rows); eps guards isolated vertices with degree 0.
        laplacian /= (laplacian[r, c][:, None] + 1e-8)
        # register_buffer makes the matrix follow module .to(device)/.cuda()
        # calls, so it must NOT be pinned to CUDA here (the original called
        # .cuda(), which broke CPU-only use and was redundant on GPU).
        self.register_buffer('laplacian', torch.from_numpy(laplacian).float())

    def forward(self, x):
        """x: (batch, nv, 3) vertex coordinates. Returns a scalar loss."""
        batch_size = x.size(0)
        # (nv, nv) @ (batch, nv, 3) broadcasts over the batch dimension,
        # replacing the original per-sample Python loop + torch.cat.
        x = torch.matmul(self.laplacian, x)
        x = x.pow(2).sum(2)
        if self.average:
            return x.sum() / batch_size
        else:
            return x.mean()
class NormalVectorLoss(nn.Module):
    """Surface-normal consistency loss between predicted and GT meshes.

    For every triangle, the three (unit) edge vectors of the predicted mesh
    should be perpendicular to the ground-truth face normal; the loss is the
    mean absolute cosine between each predicted edge and that normal.

    Args:
        face: (nf, 3) integer array of triangle vertex indices.
    """

    def __init__(self, face):
        super(NormalVectorLoss, self).__init__()
        self.face = face

    def forward(self, coord_out, coord_gt):
        # Build the index tensor on whatever device the inputs live on.
        # (The original hard-coded .cuda(), which broke CPU-only use.)
        face = torch.as_tensor(self.face, dtype=torch.long, device=coord_out.device)
        # Predicted edge vectors, L2-normalized to unit length.
        v1_out = coord_out[:, face[:, 1], :] - coord_out[:, face[:, 0], :]
        v1_out = F.normalize(v1_out, p=2, dim=2)
        v2_out = coord_out[:, face[:, 2], :] - coord_out[:, face[:, 0], :]
        v2_out = F.normalize(v2_out, p=2, dim=2)
        v3_out = coord_out[:, face[:, 2], :] - coord_out[:, face[:, 1], :]
        v3_out = F.normalize(v3_out, p=2, dim=2)
        # Ground-truth unit face normal from two GT edges.
        v1_gt = coord_gt[:, face[:, 1], :] - coord_gt[:, face[:, 0], :]
        v1_gt = F.normalize(v1_gt, p=2, dim=2)
        v2_gt = coord_gt[:, face[:, 2], :] - coord_gt[:, face[:, 0], :]
        v2_gt = F.normalize(v2_gt, p=2, dim=2)
        normal_gt = torch.cross(v1_gt, v2_gt, dim=2)
        normal_gt = F.normalize(normal_gt, p=2, dim=2)
        # |cos| between each predicted edge and the GT normal; 0 when the
        # predicted face lies in the GT face plane.
        cos1 = torch.abs(torch.sum(v1_out * normal_gt, 2, keepdim=True))
        cos2 = torch.abs(torch.sum(v2_out * normal_gt, 2, keepdim=True))
        cos3 = torch.abs(torch.sum(v3_out * normal_gt, 2, keepdim=True))
        loss = torch.cat((cos1, cos2, cos3), 1)
        return loss.mean()
class EdgeLengthLoss(nn.Module):
    """Edge-length consistency loss between predicted and GT meshes.

    Computes the mean absolute difference between the lengths of the three
    edges of every triangle in the predicted mesh and the corresponding
    edges in the ground-truth mesh.

    Args:
        face: (nf, 3) integer array of triangle vertex indices.
    """

    def __init__(self, face):
        super(EdgeLengthLoss, self).__init__()
        self.face = face

    def forward(self, coord_out, coord_gt):
        # Build the index tensor on whatever device the inputs live on.
        # (The original hard-coded .cuda(), which broke CPU-only use.)
        face = torch.as_tensor(self.face, dtype=torch.long, device=coord_out.device)
        # Per-triangle edge lengths of the predicted mesh.
        d1_out = torch.sqrt(
            torch.sum((coord_out[:, face[:, 0], :] - coord_out[:, face[:, 1], :]) ** 2, 2, keepdim=True))
        d2_out = torch.sqrt(
            torch.sum((coord_out[:, face[:, 0], :] - coord_out[:, face[:, 2], :]) ** 2, 2, keepdim=True))
        d3_out = torch.sqrt(
            torch.sum((coord_out[:, face[:, 1], :] - coord_out[:, face[:, 2], :]) ** 2, 2, keepdim=True))
        # Corresponding edge lengths of the ground-truth mesh.
        d1_gt = torch.sqrt(torch.sum((coord_gt[:, face[:, 0], :] - coord_gt[:, face[:, 1], :]) ** 2, 2, keepdim=True))
        d2_gt = torch.sqrt(torch.sum((coord_gt[:, face[:, 0], :] - coord_gt[:, face[:, 2], :]) ** 2, 2, keepdim=True))
        d3_gt = torch.sqrt(torch.sum((coord_gt[:, face[:, 1], :] - coord_gt[:, face[:, 2], :]) ** 2, 2, keepdim=True))
        diff1 = torch.abs(d1_out - d1_gt)
        diff2 = torch.abs(d2_out - d2_gt)
        diff3 = torch.abs(d3_out - d3_gt)
        loss = torch.cat((diff1, diff2, diff3), 1)
        return loss.mean()
def get_loss(faces):
    """Build the tuple of loss modules used during training.

    Args:
        faces: (nf, 3) integer array of mesh triangle indices, consumed by
            the normal-vector and edge-length losses.

    Returns:
        A 5-tuple: (coord loss, normal-vector loss, edge-length loss,
        coord loss, coord loss) — each coord loss masks by validity.
    """
    return (
        CoordLoss(has_valid=True),
        NormalVectorLoss(faces),
        EdgeLengthLoss(faces),
        CoordLoss(has_valid=True),
        CoordLoss(has_valid=True),
    )