-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathloss.py
67 lines (51 loc) · 1.74 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
"""
Losses
"""
# pylint: disable=C0301,C0103,R0902,R0915,W0221,W0622
##
# LIBRARIES
import torch
##
def l1_loss(input, target):
    """Mean absolute error between ``input`` and ``target``.

    The result is always reduced: the absolute differences are averaged
    over every element.

    Args:
        input (FloatTensor): Predicted tensor.
        target (FloatTensor): Reference tensor; combined with ``input`` by
            ordinary tensor subtraction.

    Returns:
        FloatTensor: 0-dim tensor holding ``mean(|input - target|)``.
    """
    abs_err = (input - target).abs()
    return abs_err.mean()
##
def l2_loss(input, target, size_average=True):
    """Squared-error loss between ``input`` and ``target``.

    Args:
        input (FloatTensor): Predicted tensor.
        target (FloatTensor): Reference tensor; combined with ``input`` by
            ordinary tensor subtraction.
        size_average (bool): If truthy (the default), average the squared
            differences over all elements (mean squared error). Otherwise
            return the element-wise squared differences unreduced.

    Returns:
        FloatTensor: a 0-dim mean when ``size_average`` is truthy, else the
        tensor of element-wise squared differences.
    """
    squared = torch.pow(input - target, 2)
    # Unreduced path: hand back the element-wise squared errors as-is.
    if not size_average:
        return squared
    return torch.mean(squared)
def _mean_row_norm(pred, target):
    """Average, over rows, of the Euclidean norm of ``pred - target`` along dim 1."""
    # ``norm`` has a zero-safe subgradient: an exactly reconstructed row gets
    # gradient 0. The former ``sqrt(sum(d ** 2))`` back-propagated
    # ``inf * 0 = nan`` through the square root for such rows.
    return torch.mean((pred - target).norm(p=2, dim=1))


def loss_func(adj, A_hat, attrs, X_hat):
    """Structure and attribute reconstruction costs.

    Each cost is the Euclidean norm of the reconstruction error taken along
    dim 1 (one value per row), averaged over rows.

    Args:
        adj: Reference structure tensor (presumably the adjacency matrix of
            shape (num_nodes, num_nodes) -- confirm against callers).
        A_hat: Reconstruction of ``adj``.
        attrs: Reference attribute tensor (presumably (num_nodes, num_features)
            -- confirm against callers).
        X_hat: Reconstruction of ``attrs``.

    Returns:
        tuple: ``(structure_cost, attribute_cost)``, each a 0-dim tensor.
    """
    # Attribute reconstruction loss
    attribute_cost = _mean_row_norm(X_hat, attrs)
    # Structure reconstruction loss
    structure_cost = _mean_row_norm(A_hat, adj)
    return structure_cost, attribute_cost
def loss_cal(x, x_aug, temperature=0.2, eps=1e-8):
    """Contrastive loss between two views of a batch.

    Row ``i`` of ``x`` and row ``i`` of ``x_aug`` form the positive pair.
    Every other column ``j != i`` of the similarity matrix is a negative for
    ``x[i]``. With ``s_ij = cos(x_i, x_aug_j) / temperature``, the per-row
    loss is

        -log( exp(s_ii) / sum_{j != i} exp(s_ij) )

    averaged over the batch. The positive pair is excluded from the
    denominator.

    Args:
        x (FloatTensor): 2-D tensor of shape (batch_size, dim).
        x_aug (FloatTensor): 2-D tensor, assumed to have the same shape as
            ``x`` (confirm against callers).
        temperature (float): Scale applied to cosine similarities; the
            default 0.2 is the value previously hard-coded here.
        eps (float): Lower bound on the product of row norms, so a zero
            row yields similarity 0 instead of a 0/0 NaN.

    Returns:
        FloatTensor: 0-dim tensor with the mean loss. With a single row there
        are no negatives and the result is -inf.
    """
    batch_size, _ = x.size()
    x_abs = x.norm(dim=1)
    x_aug_abs = x_aug.norm(dim=1)
    # Cosine similarity matrix; clamp the norm product to avoid 0/0.
    denom = torch.einsum('i,j->ij', x_abs, x_aug_abs).clamp(min=eps)
    logits = torch.einsum('ik,jk->ij', x, x_aug) / denom / temperature
    pos_logits = torch.diagonal(logits)[:batch_size]
    # log(sum_{j != i} exp(s_ij)) evaluated in log space: exp(s / temperature)
    # overflows float32 for small temperatures, turning the ratio into inf/inf.
    self_mask = torch.eye(logits.size(0), logits.size(1),
                          dtype=torch.bool, device=logits.device)
    neg_lse = torch.logsumexp(logits.masked_fill(self_mask, float('-inf')), dim=1)
    # -log(exp(pos) / sum_neg) == log(sum_neg) - pos
    return (neg_lse - pos_logits).mean()