-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathconfig.py
141 lines (118 loc) · 5.96 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import argparse
import math
import random
import sys
import time

import torch  # bind `torch` explicitly: `import torch.nn as nn` only binds `nn`
import torch.nn as nn
import torch.nn.functional as F
import xlrd

# NOTE: the two star-imports stay in their original relative order — names such
# as `transforms` and `data` may come from either module, and reordering could
# change which definition wins.
from dataset import *
from torch.utils import *
# Command-line interface for the retrieval experiments.
parser = argparse.ArgumentParser(description='retrieval')
parser.add_argument('--dataset', type=str, default='voc2007',
                    help="dataset name")  # coco, flickr, voc2007, voc2012, nuswide
parser.add_argument('--hash_bit', type=int, default=48,
                    help="number of hash code bits")  # 12, 16, 24, 32, 36, 48, 64
parser.add_argument('--batch_size', type=int, default=100,
                    help="batch size")
parser.add_argument('--epochs', type=int, default=100,
                    help="epochs")
parser.add_argument('--cuda', type=int, default=0,
                    help="cuda id")
parser.add_argument('--backbone', type=str, default='googlenet',
                    help="backbone")  # googlenet, resnet, alexnet
parser.add_argument('--beta', type=float, default=0.5,
                    help="hyper-parameter for regularization")
parser.add_argument('--retrieve', type=int, default=0,
                    help="retrieval number")
parser.add_argument('--no_save', action='store_true', default=False,
                    help="No save")
parser.add_argument('--seed', type=int, default=0,
                    help="random seed")
parser.add_argument('--rate', type=float, default=0.02,
                    help="rate")
parser.add_argument('--test', action='store_true', default=False,
                    help="testing")  # for testing
args = parser.parse_args()
# Hyper-parameters unpacked from the parsed CLI arguments.
train_flag = not args.test      # equivalent to the original bool(1 - args.test)
backbone = args.backbone
retrieve = args.retrieve        # 0 means "use a dataset-specific default" (filled in later)
save_flag = not args.no_save    # equivalent to the original bool(1 - args.no_save)
dataset = args.dataset
num_epochs = args.epochs
batch_size = args.batch_size
# Backbone-specific learning rate for the feature extractor.
# NOTE(review): 'resnet' is listed as a valid --backbone but gets no
# feature_rate here, so any later read of feature_rate under resnet raises
# NameError.  Left unchanged rather than guessing a rate — confirm intent.
if backbone == 'googlenet':
    feature_rate = 0.02
elif backbone == 'alexnet':
    feature_rate = 0.01
criterion_rate = args.rate      # learning rate for the criterion / hash layer (--rate)
num_bits = args.hash_bit        # hash code length in bits
# hyper-parameters
beta = args.beta                # regularization weight (--beta)
seed = args.seed                # random seed (--seed)
# path for loading and saving models
path = './result/' + dataset + '_' + backbone + '_' + str(num_bits)
model_path = path + '.ckpt'
if train_flag and save_flag:
    file_path = path + '.txt'
    # Log file is deliberately kept open for the rest of the run;
    # presumably written/closed by the training script — confirm there.
    f = open(file_path, 'w')
# Device configuration
device = torch.device('cuda:' + str(args.cuda) if torch.cuda.is_available() else 'cpu')
# data pre-treatment: build the train/val torchvision pipelines, sized per backbone.
def _make_transforms(side):
    """Return {'train', 'val'} pipelines resizing inputs to `side` x `side`.

    Train adds a random horizontal flip; both normalize with the standard
    ImageNet channel statistics.
    """
    return {
        "train": transforms.Compose([
            transforms.Resize((side, side)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([
            transforms.Resize((side, side)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

if backbone == 'googlenet':
    data_transform = _make_transforms(448)   # this GoogLeNet setup uses 448x448 inputs
elif backbone in ['resnet', 'alexnet']:
    data_transform = _make_transforms(224)   # standard 224x224 inputs
# load train data
# Select the (train / query / retrieval-database) dataset triple named on the
# command line.  Each branch sets `num_classes` and, when --retrieve was left
# at its default 0, a dataset-specific retrieval count.
# NOTE(review): Flickr25k, VOCBase and ImageList come from `from dataset
# import *`; their exact semantics are assumed from usage here — confirm in
# dataset.py.
if dataset == 'flickr':
    num_classes = 38
    if retrieve == 0:
        retrieve = 1000
    # Presumably partitions MIRFLICKR-25K into 1000 query / 4000 training
    # images — confirm against Flickr25k.init's signature.
    Flickr25k.init('./data/flickr25k/', 1000, 4000)
    trainset = Flickr25k('./data/flickr25k/', 'train', transform = data_transform['train'])
    testset = Flickr25k('./data/flickr25k/', 'query', transform = data_transform['val'])
    database = Flickr25k('./data/flickr25k/', 'retrieval', transform = data_transform['val'])
elif dataset == 'voc2007':
    if retrieve == 0:
        retrieve = 5011  # size of the VOC2007 trainval split, i.e. the whole database
    num_classes = 20
    # The retrieval database reuses the training split, but with the
    # deterministic 'val' transform (no random flip).
    trainset = VOCBase(root = './data', year = '2007', image_set = 'trainval', download = True, transform = data_transform['train'])
    testset = VOCBase(root = './data', year = '2007', image_set = 'test', download = True, transform = data_transform['val'])
    database = VOCBase(root = './data', year = '2007', image_set = 'trainval', download = True, transform = data_transform['val'])
elif dataset == 'voc2012':
    num_classes = 20
    if retrieve == 0:
        retrieve = 5717  # size of the VOC2012 train split
    trainset = VOCBase(root = './data', year = '2012', image_set = 'train', download = True, transform = data_transform['train'])
    testset = VOCBase(root = './data', year = '2012', image_set = 'val', download = True, transform = data_transform['val'])
    database = VOCBase(root = './data', year = '2012', image_set = 'train', download = True, transform = data_transform['val'])
elif dataset == 'nuswide':
    if retrieve == 0:
        retrieve = 5000
    num_classes = 21
    # NUS-WIDE splits are defined by image-list text files (path + labels per line).
    trainset = ImageList(open('./data/nus_wide/train.txt', 'r').readlines(), transform = data_transform['train'])
    testset = ImageList(open('./data/nus_wide/test.txt', 'r').readlines(), transform = data_transform['val'])
    database = ImageList(open('./data/nus_wide/database.txt', 'r').readlines(), transform = data_transform['val'])
# Split sizes used downstream (e.g. for evaluation loops).
train_num = len(trainset)
test_num = len(testset)
database_num = len(database)
def _make_loader(ds, shuffle):
    # Shared DataLoader settings: configured batch size, 8 worker processes.
    return data.DataLoader(dataset = ds,
                           batch_size = batch_size,
                           shuffle = shuffle,
                           num_workers = 8)

trainloader = _make_loader(trainset, True)       # shuffled for training
testloader = _make_loader(testset, False)        # deterministic order for evaluation
databaseloader = _make_loader(database, False)   # deterministic order for retrieval
# find the value of ζ
# Look up a threshold from the first sheet of codetable.xlsx, indexed by the
# hash-code length (row) and ceil(log2(num_classes)) (column).
# NOTE(review): the exact meaning of ζ is defined by the code table / paper —
# confirm against codetable.xlsx.
workbook = xlrd.open_workbook('codetable.xlsx')
code_table = workbook.sheet_by_index(0)
column = math.ceil(math.log(num_classes, 2))
threshold = code_table.cell_value(num_bits, column)
print(threshold)
print('------------- data prepared -------------')