This repository has been archived by the owner on Feb 3, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 27
/
Copy pathloss.py
60 lines (48 loc) · 1.68 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Author: An Tao
@Contact: [email protected]
@File: loss.py
@Time: 2020/1/2 10:26 AM
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class ChamferLoss(nn.Module):
    """Symmetric Chamfer distance between two batched point clouds.

    For predictions P and ground truth G (each of shape (bs, n, dim)),
    returns sum over all predicted points of the squared distance to their
    nearest ground-truth point, plus the symmetric term — summed over the
    batch, not averaged.
    """

    def __init__(self):
        super(ChamferLoss, self).__init__()
        # Kept for interface parity with the original implementation;
        # nothing below reads it.
        self.use_cuda = torch.cuda.is_available()

    def batch_pairwise_dist(self, x, y):
        """Return the (bs, nx, ny) matrix of squared distances ||x_i - y_j||^2.

        Uses the identity ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x.y so the
        cross term is a single batched matmul.
        """
        sq_x = x.pow(2).sum(dim=-1)               # (bs, nx)
        sq_y = y.pow(2).sum(dim=-1)               # (bs, ny)
        cross = torch.bmm(x, y.transpose(2, 1))   # (bs, nx, ny)
        # Broadcast ||x_i||^2 down rows and ||y_j||^2 across columns.
        return sq_x.unsqueeze(2) + sq_y.unsqueeze(1) - 2 * cross

    def forward(self, preds, gts):
        """Chamfer loss between *preds* and *gts* (scalar, summed reduction)."""
        dists = self.batch_pairwise_dist(gts, preds)
        # For each predicted point, distance to its closest ground-truth point.
        nearest_to_pred = dists.min(dim=1)[0]
        # For each ground-truth point, distance to its closest prediction.
        nearest_to_gt = dists.min(dim=2)[0]
        return nearest_to_pred.sum() + nearest_to_gt.sum()
class CrossEntropyLoss(nn.Module):
    """Cross-entropy loss with optional label smoothing.

    Args:
        smoothing: if True, use a manually smoothed target distribution;
            otherwise fall back to plain ``F.cross_entropy``.
        eps: label-smoothing factor (previously hard-coded as 0.2). The
            true class receives probability ``1 - eps`` and the remaining
            mass is spread uniformly over the other ``C - 1`` classes.
            Default 0.2 preserves the original behavior.
    """

    def __init__(self, smoothing=True, eps=0.2):
        super(CrossEntropyLoss, self).__init__()
        self.smoothing = smoothing
        self.eps = eps

    def forward(self, preds, gts):
        """Compute the mean loss.

        Args:
            preds: (N, C) unnormalized logits.
            gts: integer class labels; flattened to shape (N,).

        Returns:
            Scalar tensor, mean loss over the batch.
        """
        gts = gts.contiguous().view(-1)
        if self.smoothing:
            eps = self.eps
            n_class = preds.size(1)
            # One-hot targets softened: true class -> 1-eps,
            # each other class -> eps / (C - 1).
            one_hot = torch.zeros_like(preds).scatter(1, gts.view(-1, 1), 1)
            one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
            log_prb = F.log_softmax(preds, dim=1)
            loss = -(one_hot * log_prb).sum(dim=1).mean()
        else:
            loss = F.cross_entropy(preds, gts, reduction='mean')
        return loss