-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathloss.py
79 lines (59 loc) · 2.84 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
""" Loss module.
References:
[1] S. Sabour, N. Frosst, and G. E. Hinton, “Dynamic routing between capsules,” in NIPS, pp. 3859–3869, 2017.
"""
from torch import nn
from torch.nn.modules.loss import _Loss
import torch.nn.functional as F
from utils import one_hot
class CapsuleLoss(_Loss):
    """ Margin Loss
    Margin loss as defined in [1], optionally combined with a scaled
    reconstruction (MSE) loss.

    Args:
        m_plus (float): m+ in the margin loss — lower bound the correct class
            logit is pushed above.
        m_min (float): m- in the margin loss — upper bound the absent class
            logits are pushed below.
        alpha (float): scalar controlling the contribution of the
            reconstruction loss to the total.
        include_recon (bool): if True, compute the reconstruction loss and add
            alpha * recon_loss to the total.
    """

    def __init__(self, m_plus, m_min, alpha, include_recon):
        super(CapsuleLoss, self).__init__()

        self.m_plus = m_plus
        self.m_min = m_min
        self.alpha = alpha
        self.include_recon = include_recon

        # Element-wise MSE; reduction is done manually in forward
        # (sum per sample, then mean over the batch).
        self.recon_loss = nn.MSELoss(reduction="none")

    def forward(self, images, labels, logits, recon):
        """ Forward pass.

        Args:
            images (FloatTensor): Original images. Shape: [batch, channel, height, width].
            labels (LongTensor): Class labels. Shape: [batch].
            logits (FloatTensor): Class logits (lengths of the final capsules).
                Shape: [batch, classes].
            recon (FloatTensor): Reconstructed image. Same shape as images.
                Ignored when include_recon is False.

        Returns:
            total_loss (FloatTensor): Sum of all losses. Single value.
            margin_loss (FloatTensor): Margin loss defined in [1]. Single value.
            recon_loss (FloatTensor): MSE loss of the reconstructed image,
                summed per sample and averaged over the batch. None if not included.
        """
        num_classes = logits.shape[1]

        # Float one-hot encoding of the targets; torch-native equivalent of the
        # project's utils.one_hot helper, cast to the logits dtype so the
        # products below stay in floating point.
        labels_one_hot = F.one_hot(labels, num_classes).to(logits.dtype)

        # The factor 0.5 in front of both terms is not in the paper, but is
        # used in the reference source code. (No inplace relu: it operated on
        # a temporary anyway, so results are identical and the code is safer.)
        present_loss = 0.5 * F.relu(self.m_plus - logits) ** 2
        absent_loss = 0.5 * F.relu(logits - self.m_min) ** 2

        # The factor 0.5 is the lambda down-weight for absent classes in [1].
        margin_loss = labels_one_hot * present_loss + 0.5 * (1. - labels_one_hot) * absent_loss

        # Sum over classes, average over the batch.
        margin_loss = margin_loss.sum(dim=1).mean()

        if self.include_recon:
            # Sum the squared error over all non-batch dimensions in one call.
            # The original summed dim=-1 three times, hard-coding a 4-D input;
            # flattening generalizes to any input rank while matching it on 4-D.
            recon_loss = self.recon_loss(recon, images).flatten(start_dim=1).sum(dim=1)
            assert recon_loss.dim() == 1, "Only batch dimension should be left after in recon loss."

            # Average of the per-sample sums over the batch dimension.
            recon_loss = recon_loss.mean()
        else:
            recon_loss = None

        total_loss = margin_loss
        if self.include_recon:
            # Scale the reconstruction term so it does not dominate the margin loss.
            total_loss = total_loss + self.alpha * recon_loss

        return total_loss, margin_loss, recon_loss