-
Notifications
You must be signed in to change notification settings - Fork 114
/
Copy pathUnsupervised_BiHalf.py
135 lines (102 loc) · 4.36 KB
/
Unsupervised_BiHalf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
from utils.tools import *
from network import *
import os
import torch
import torch.optim as optim
import torch.nn.functional as F
import time
import numpy as np
# Switch torch.multiprocessing to the 'file_system' tensor-sharing strategy.
# NOTE(review): presumably chosen to avoid running out of file descriptors with
# the default strategy when data-loader worker processes share tensors — confirm.
torch.multiprocessing.set_sharing_strategy('file_system')
# Deep Unsupervised Image Hashing by Maximizing Bit Entropy (AAAI 2021)
# paper [Deep Unsupervised Image Hashing by Maximizing Bit Entropy](https://arxiv.org/pdf/2012.12334.pdf)
# code [Deep-Unsupervised-Image-Hashing](https://github.com/liyunqianggyn/Deep-Unsupervised-Image-Hashing)
# [BiHalf Unsupervised] epoch:40, bit:64, dataset:cifar10-2, MAP:0.593, Best MAP: 0.593
def get_config():
    """Build the run configuration for unsupervised BiHalf hashing.

    The base settings are passed through ``config_dataset`` (defined elsewhere;
    presumably it adds dataset-specific entries — confirm). Afterwards ``topK``
    is fixed to 1000, presumably the retrieval cut-off used when computing mAP.

    Returns:
        dict: the configuration consumed by ``train_val``.
    """
    sgd_settings = dict(
        type=optim.SGD,
        epoch_lr_decrease=30,  # lr is multiplied by 0.1 every 30 epochs in train_val
        optim_params=dict(lr=0.0001, weight_decay=5e-4, momentum=0.9),
    )
    # Keyword order below is the key order shown by print(config); keep it stable.
    cfg = dict(
        gamma=6,  # weight of the (U - B) term BiHalf adds to the gradient
        optimizer=sgd_settings,
        info="[BiHalf Unsupervised]",
        resize_size=256,
        crop_size=224,
        batch_size=64,
        net=BiHalfModelUnsupervised,
        dataset="cifar10-2",  # in paper BiHalf is "Cifar-10(I)"
        epoch=200,
        test_map=5,  # evaluate mAP every 5 epochs
        device=torch.device("cuda:1"),  # use torch.device("cpu") to run without a GPU
        bit_list=[64],
    )
    cfg = config_dataset(cfg)
    cfg["topK"] = 1000
    return cfg
class BiHalfModelUnsupervised(nn.Module):
    """Unsupervised BiHalf hashing head on top of a frozen VGG-16 backbone.

    In training mode ``forward`` returns a scalar loss. In eval mode it returns
    the sign of the ``bit``-dimensional encoding.

    Args:
        bit: Number of hash bits produced by ``fc_encode``.
        gamma: Weight of the ``(U - B)`` term that ``Hash.backward`` adds to the
            incoming gradient. When ``None`` (the default), the value is read
            from the module-level ``config["gamma"]`` at backward time, exactly
            as the script-level code has always done.
    """

    def __init__(self, bit, gamma=None):
        super().__init__()
        self.vgg = models.vgg16(pretrained=True)
        # Keep only the first six classifier layers, so the classifier emits
        # the 4096-d features that fc_encode consumes.
        self.vgg.classifier = nn.Sequential(*list(self.vgg.classifier.children())[:6])
        # The backbone is frozen; only fc_encode receives gradient updates.
        for param in self.vgg.parameters():
            param.requires_grad = False
        self.fc_encode = nn.Linear(4096, bit)
        self.gamma = gamma

    class Hash(torch.autograd.Function):
        """Balanced ("half and half") binarization with a straight-through-style gradient.

        For every column (bit), the rows holding the top ``N // 2`` values get
        +1 and the remaining rows get -1.
        """

        @staticmethod
        def forward(ctx, U, gamma=None):
            # Yunqiang for half and half (optimal transport)
            _, index = U.sort(0, descending=True)
            N, D = U.shape
            half = N // 2
            # Build B on U's own device/dtype: this needs no global config, and
            # scatter_ requires src and destination dtypes to match.
            B_creat = torch.cat((
                torch.ones(half, D, dtype=U.dtype, device=U.device),
                -torch.ones(N - half, D, dtype=U.dtype, device=U.device),
            ))
            # Row index[i, j] of column j receives B_creat[i, j], i.e. the i-th
            # largest entry of each column gets +1 when i < half.
            B = torch.zeros_like(U).scatter_(0, index, B_creat)
            ctx.save_for_backward(U, B)
            ctx.gamma = gamma
            return B

        @staticmethod
        def backward(ctx, g):
            U, B = ctx.saved_tensors
            # Fall back to the script-level config for backward compatibility.
            gamma = ctx.gamma if ctx.gamma is not None else config["gamma"]
            add_g = (U - B) / B.numel()
            # The second gradient slot corresponds to the non-tensor `gamma` input.
            return g + gamma * add_g, None

    def forward(self, x):
        """Return hash codes (eval mode) or the similarity-preservation loss (train mode)."""
        x = self.vgg.features(x)
        x = x.view(x.size(0), -1)
        x = self.vgg.classifier(x)
        h = self.fc_encode(x)
        if not self.training:
            return h.sign()
        b = BiHalfModelUnsupervised.Hash.apply(h, self.gamma)
        return self._half_pair_loss(b, x)

    @staticmethod
    def _half_pair_loss(b, x):
        """MSE between code-space and feature-space cosine similarity of paired rows.

        Row ``i`` is paired with row ``i + N // 2``. For an odd batch size the
        final unpaired row is dropped, so both halves always have equal length.
        """
        half = x.size(0) // 2
        target_b = F.cosine_similarity(b[:half], b[half:2 * half])
        target_x = F.cosine_similarity(x[:half], x[half:2 * half])
        return F.mse_loss(target_b, target_x)
def train_val(config, bit):
    """Train ``config["net"](bit)`` and periodically evaluate it with ``validate``.

    The learning rate follows a step schedule: the configured base ``lr`` is
    multiplied by 0.1 once every ``epoch_lr_decrease`` epochs. Every
    ``test_map`` epochs, ``validate`` is called and the best mAP it returns is
    carried into the next call.

    Side effect: stores the training-set size in ``config["num_train"]``.
    """
    device = config["device"]
    (train_loader, test_loader, dataset_loader,
     num_train, _num_test, num_dataset) = get_data(config)
    config["num_train"] = num_train

    net = config["net"](bit).to(device)
    opt_cfg = config["optimizer"]
    optimizer = opt_cfg["type"](net.parameters(), **(opt_cfg["optim_params"]))

    best_map = 0
    for epoch in range(config["epoch"]):
        decay_steps = epoch // opt_cfg["epoch_lr_decrease"]
        lr = opt_cfg["optim_params"]["lr"] * (0.1 ** decay_steps)
        for group in optimizer.param_groups:
            group['lr'] = lr

        stamp = time.strftime('%H:%M:%S', time.localtime())
        print("%s[%2d/%2d][%s] bit:%d, lr:%.9f, dataset:%s, training...." % (
            config["info"], epoch + 1, config["epoch"], stamp, bit, lr, config["dataset"]), end="")

        net.train()
        loss_total = 0
        for image, _, _ in train_loader:
            batch = image.to(device)
            optimizer.zero_grad()
            loss = net(batch)  # in train mode the network returns its loss directly
            loss_total += loss.item()
            loss.backward()
            optimizer.step()
        mean_loss = loss_total / len(train_loader)
        # Seven backspaces step back over the tail of the "training...." line.
        print("\b\b\b\b\b\b\b loss:%.9f" % (mean_loss))

        if (epoch + 1) % config["test_map"] != 0:
            continue
        best_map = validate(config, best_map, test_loader, dataset_loader, net, bit, epoch, num_dataset)
if __name__ == "__main__":
    # NOTE: `config` must remain a module-level global (do not move this into a
    # function): BiHalfModelUnsupervised.Hash reads the global `config`.
    config = get_config()
    print(config)
    # Train and evaluate one model per requested hash-code length.
    for bit in config["bit_list"]:
        train_val(config, bit)