-
Notifications
You must be signed in to change notification settings - Fork 114
/
Copy pathDPN.py
151 lines (119 loc) · 5.07 KB
/
DPN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
from utils.tools import *
from network import *
import os
import torch
import torch.optim as optim
import time
import numpy as np
import random
# Switch torch.multiprocessing's tensor-sharing strategy to 'file_system'.
# NOTE(review): presumably to avoid running out of file descriptors when DataLoader
# worker processes hand tensors back to the main process — confirm.
torch.multiprocessing.set_sharing_strategy('file_system')
# DPN(IJCAI2020)
# paper [Deep Polarized Network for Supervised Learning of Accurate Binary Hashing Codes](https://www.ijcai.org/Proceedings/2020/115)
# code [DPN](https://github.com/kamwoh/DPN)
# [DPN] epoch:150, bit:48, dataset:imagenet, MAP:0.675, Best MAP: 0.688
# [DPN] epoch:70, bit:48, dataset:cifar10-1, MAP:0.778, Best MAP: 0.787
# [DPN] epoch:10, bit:48, dataset:nuswide_21, MAP:0.818, Best MAP: 0.818
# [DPN-T] epoch:10, bit:48, dataset:cifar10-1, MAP:0.134, Best MAP: 0.134
def get_config():
    """Return the DPN experiment configuration after dataset-specific completion.

    The variant is chosen through ``info``: "[DPN]" is the plain method, a "-A"
    suffix enables adaptive target updating and a "-T" suffix enables ternary
    assignment inside the loss. The assembled settings are passed to
    ``config_dataset``, and its result is returned.
    """
    optimizer_spec = {
        "type": optim.RMSprop,
        "optim_params": {"lr": 1e-5, "weight_decay": 1e-5},
    }
    settings = dict(
        m=1,
        p=0.5,
        optimizer=optimizer_spec,
        info="[DPN]",  # other variants: "[DPN-A]", "[DPN-T]", "[DPN-A-T]"
        resize_size=256,
        crop_size=224,
        batch_size=32,
        net=AlexNet,  # alternative backbone: ResNet
        dataset="cifar10-1",  # alternatives: "imagenet", "coco", "nuswide_21"
        epoch=150,
        test_map=10,  # evaluate mAP every this many epochs
        device=torch.device("cuda:1"),  # or torch.device("cpu")
        bit_list=[48],
    )
    return config_dataset(settings)
class DPNLoss(torch.nn.Module):
    """Polarization loss of DPN (Deep Polarized Network, IJCAI 2020).

    Every class owns a target vector ("hash center"). The loss is the hinge
    ``max(m - u * t, 0)`` averaged over all elements of the batch, where ``u`` is
    the network output and ``t`` the target of the sample's label(s).

    The module also records the latest output (``U``) and label (``Y``) of each
    training sample, indexed by dataset index. ``update_target_vectors`` (the
    adaptive "-A" variant) uses them to re-derive the class targets.
    """

    def __init__(self, config, bit):
        super().__init__()
        # Datasets listed here are treated as multi-label; all others as single-label.
        self.is_single_label = config["dataset"] not in {"nuswide_21", "nuswide_21_m", "coco"}
        self.target_vectors = self.get_target_vectors(config["n_class"], bit, config["p"]).to(config["device"])
        # Fixed 0/1 bits used to break exact ties when summing multi-label targets.
        self.multi_label_random_center = torch.randint(2, (bit,)).float().to(config["device"])
        self.m = config["m"]
        # (num_train, bit) latest outputs and (num_train, n_class) labels per sample.
        self.U = torch.zeros(config["num_train"], bit).float().to(config["device"])
        self.Y = torch.zeros(config["num_train"], config["n_class"]).float().to(config["device"])

    def forward(self, u, y, ind, config):
        """Record the batch and return its mean polarization loss.

        Args:
            u: network outputs, (batch, bit).
            y: labels, (batch, n_class); the argmax is taken for single-label
                datasets (presumably one-hot / multi-hot — confirm with get_data).
            ind: dataset indices of the batch rows, used to update ``U``/``Y``.
            config: run configuration; only ``config["info"]`` is read.

        Returns:
            Scalar tensor: ``(m - u * t).clamp(0).mean()``.
        """
        self.U[ind, :] = u.detach()
        self.Y[ind, :] = y.float()

        if "-T" in config["info"]:
            # Ternary assignment: 0 where |u| <= m, sign(u) elsewhere.
            # NOTE(review): neither the threshold mask nor sign() has a non-zero
            # gradient, so on this path the loss does not train the network.
            # Confirm this is the intended use of the "-T" variant.
            u = (u.abs() > self.m).float() * u.sign()

        t = self.label2center(y)
        return (self.m - u * t).clamp(0).mean()

    def label2center(self, y):
        """Map a (batch, n_class) label tensor to per-sample target vectors."""
        if self.is_single_label:
            return self.target_vectors[y.argmax(dim=1)]

        # Multi-label: same strategy as CSQ. Sum the targets of all active labels,
        # replace exact ties (sum == 0) with the fixed random bits, then binarize.
        center_sum = y @ self.target_vectors
        tie = center_sum == 0
        random_center = self.multi_label_random_center.repeat(center_sum.shape[0], 1)
        center_sum[tie] = random_center[tie]
        return 2 * (center_sum > 0).float() - 1

    # Random Assignments of Target Vectors
    def get_target_vectors(self, n_class, bit, p=0.5):
        """Return an (n_class, bit) tensor of random targets.

        Each row has exactly ``int(bit * p)`` entries equal to -1; the rest are +1.
        """
        target_vectors = torch.ones(n_class, bit)
        for index in range(n_class):
            # Long tensor so that an empty selection (int(bit * p) == 0) is a valid index.
            negatives = torch.tensor(random.sample(range(bit), int(bit * p)), dtype=torch.long)
            target_vectors[index, negatives] = -1
        return target_vectors

    # Adaptive Updating
    def update_target_vectors(self):
        """Re-derive each class's target from the stored outputs (the "-A" variant).

        The stored outputs are ternarized (0 where |u| <= m, sign(u) otherwise) and
        summed per class through the label matrix. The sign of each sum becomes the
        new target.
        NOTE(review): a bit whose sum is exactly 0 (including every bit of a class
        with no stored samples) gets a 0 target. Confirm this is intended.
        """
        # Ternarize a copy. Overwriting self.U would ternarize again any row that
        # forward() has not refreshed before the next update, and for m >= 1 that
        # second pass turns every +/-1 entry into 0.
        ternary_u = (self.U.abs() > self.m).float() * self.U.sign()
        self.target_vectors = (self.Y.t() @ ternary_u).sign()
def train_val(config, bit):
    """Train a DPN network producing ``bit``-bit codes and evaluate it periodically.

    Side effects: stores the training-set size in ``config["num_train"]`` (which
    ``DPNLoss`` reads when sizing its buffers) and prints per-epoch progress.
    Every ``config["test_map"]`` epochs, ``validate`` is called with the best mAP
    so far and its return value becomes the new best.
    """
    device = config["device"]
    (train_loader, test_loader, dataset_loader,
     num_train, _num_test, num_dataset) = get_data(config)
    config["num_train"] = num_train

    net = config["net"](bit).to(device)
    opt_spec = config["optimizer"]
    optimizer = opt_spec["type"](net.parameters(), **(opt_spec["optim_params"]))
    criterion = DPNLoss(config, bit)

    best_map = 0
    for epoch in range(config["epoch"]):
        stamp = time.strftime('%H:%M:%S', time.localtime(time.time()))
        header = "%s[%2d/%2d][%s] bit:%d, dataset:%s, training...." % (
            config["info"], epoch + 1, config["epoch"], stamp, bit, config["dataset"])
        print(header, end="")

        net.train()
        loss_sum = 0
        for images, labels, indices in train_loader:
            images = images.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            batch_loss = criterion(net(images), labels.float(), indices, config)
            loss_sum += batch_loss.item()
            batch_loss.backward()
            optimizer.step()

        # Adaptive variants ("-A") re-derive the class targets from the stored outputs.
        if "-A" in config["info"]:
            criterion.update_target_vectors()

        mean_loss = loss_sum / len(train_loader)
        print("\b\b\b\b\b\b\b loss:%.3f" % (mean_loss))

        is_eval_epoch = (epoch + 1) % config["test_map"] == 0
        if is_eval_epoch:
            best_map = validate(config, best_map, test_loader, dataset_loader, net, bit, epoch, num_dataset)
if __name__ == "__main__":
    config = get_config()
    print(config)
    # Group PR-curve logs by the backbone actually selected in the config
    # (e.g. AlexNet -> "alexnet", ResNet -> "resnet") so runs do not overwrite
    # or mislabel each other.
    backbone = config["net"].__name__.lower()
    for bit in config["bit_list"]:
        config["pr_curve_path"] = f"log/{backbone}/DPN_{config['dataset']}_{bit}.json"
        train_val(config, bit)