-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutil.py
68 lines (54 loc) · 2.09 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import torch
import numpy as np
from scipy.optimize import linear_sum_assignment
from sklearn.metrics import normalized_mutual_info_score, confusion_matrix
def seed_everything(seed):
    """
    Seed every random number generator this project may draw from.

    Seeds Python's built-in `random`, NumPy's global generator and PyTorch.
    When CUDA is available, cuDNN is additionally switched to deterministic
    mode (auto-tuner disabled) and every CUDA device is seeded.

    :param seed: integer seed shared by all generators
    :return: None; the chosen seed is printed to stdout
    """
    # Local import: the module does not import `random` at file level.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # Deterministic cuDNN kernels trade some speed for repeatable results.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        torch.cuda.manual_seed_all(seed)
    print('Global seed:', seed)
def clustering_accuracy(y_true, y_pred):
    """
    Calculate clustering accuracy. Requires SciPy installed.

    Predicted cluster ids are matched one-to-one to true class ids with
    `scipy.optimize.linear_sum_assignment` so that the number of agreeing
    samples is maximal; the accuracy is that count divided by `n_samples`.

    # Arguments
        y_true: true labels, array-like of non-negative integers with shape `(n_samples,)`
        y_pred: predicted labels, array-like of non-negative integers with shape `(n_samples,)`
    # Return
        accuracy, in [0,1]
    # Raises
        ValueError: if the inputs are empty or do not have the same size
    """
    # Cast both inputs: they are used as integer indices into the count matrix.
    y_true = np.asarray(y_true).astype(np.int64).ravel()
    y_pred = np.asarray(y_pred).astype(np.int64).ravel()
    if y_pred.size != y_true.size:
        raise ValueError(
            'y_true and y_pred must have the same size, got %d and %d'
            % (y_true.size, y_pred.size)
        )
    if y_pred.size == 0:
        raise ValueError('y_true and y_pred must not be empty')

    n_labels = max(y_pred.max(), y_true.max()) + 1
    # w[p, t] counts the samples predicted as cluster p whose true class is t.
    w = np.zeros((n_labels, n_labels), dtype=np.int64)
    # np.add.at accumulates correctly when the same (p, t) pair repeats.
    np.add.at(w, (y_pred, y_true), 1)

    # linear_sum_assignment minimises cost, so turn counts into costs.
    row_ind, col_ind = linear_sum_assignment(w.max() - w)
    return w[row_ind, col_ind].sum() / y_pred.size
def measure_cluster(y_pred, y_true):
    """
    Evaluate a clustering against ground-truth labels.

    Note the argument order: predictions first, ground truth second.

    :param y_pred: predicted cluster labels, one per sample
    :param y_true: ground-truth class labels, one per sample
    :return: tuple `(acc, nmi, pur)` of clustering accuracy, normalized
        mutual information (geometric averaging) and purity
    """
    acc = clustering_accuracy(y_true, y_pred)
    nmi = normalized_mutual_info_score(y_true, y_pred, average_method='geometric')
    # confusion_matrix puts true classes on rows and predicted clusters on columns.
    cm = confusion_matrix(y_true, y_pred)
    # Purity credits each cluster (column) with its most frequent true class,
    # so the maximum is taken down each column, not along each row.
    majority_total = cm.max(axis=0).sum()
    pur = majority_total / cm.sum()
    return acc, nmi, pur
def target_distribution(batch: torch.Tensor) -> torch.Tensor:
    """
    Build the auxiliary target distribution p_ij from soft assignments q_ij
    (Section 3.1.3, Equation 3 of Xie/Girshick/Farhadi); it serves as the
    target in the KL-divergence clustering loss.

    Each assignment is squared and divided by its cluster's total over the
    batch, then every row is renormalised to sum to one.

    :param batch: [batch size, number of clusters] Tensor of dtype float
    :return: [batch size, number of clusters] Tensor of dtype float
    """
    cluster_totals = batch.sum(dim=0)
    sharpened = batch.pow(2) / cluster_totals
    # keepdim lets the per-row totals broadcast without transposing.
    row_totals = sharpened.sum(dim=1, keepdim=True)
    return sharpened / row_totals
def print_network(net):
    """
    Print a network followed by its total parameter count.

    :param net: object exposing `parameters()`, an iterable of tensors
        (e.g. a `torch.nn.Module`)
    :return: None; output goes to stdout
    """
    total = sum(tensor.numel() for tensor in net.parameters())
    print(net)
    print('Total number of parameters: %d' % total)