-
Notifications
You must be signed in to change notification settings - Fork 67
/
Copy pathtest.py
125 lines (88 loc) · 3.65 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import numpy as np
import os,sys,inspect
import tensorflow as tf
import time
from datetime import datetime
import os
import hickle as hkl
import os.path as osp
from glob import glob
import sklearn.metrics as metrics
import math
from input import Dataset
import globals as g_
# Append the parent of this file's directory to sys.path.  NOTE(review):
# presumably this is so the ``import model`` that follows resolves to a module
# living one directory up, independent of the working directory -- confirm.
# ``currentdir`` is the absolute directory containing this source file.
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
# ``parentdir`` is one level above ``currentdir``.
parentdir = os.path.dirname(currentdir)
sys.path.append(parentdir)
import model
# Command-line flags for this script, registered on TensorFlow's global flag
# registry.  NOTE(review): ``FLAGS.batch_size`` (read in test()) is not defined
# here and presumably comes from ``model`` -- confirm.
FLAGS = tf.app.flags.FLAGS

# Default log/checkpoint directory: ``tmp/`` next to the launched script.
# ``osp.abspath`` is required: for ``python test.py`` the bare dirname of
# sys.argv[0] is '' and string concatenation with '/tmp/' would silently
# point at the filesystem-root /tmp/ instead.
tf.app.flags.DEFINE_string('train_dir',
                           osp.join(osp.dirname(osp.abspath(sys.argv[0])), 'tmp/'),
                           """Directory where to write event logs """
                           """and checkpoint.""")
# Forwarded to tf.ConfigProto(log_device_placement=...) in test().
tf.app.flags.DEFINE_boolean('log_device_placement', False,
                            """Whether to log device placement.""")
# Checkpoint path handed to test() by main(); empty by default.
tf.app.flags.DEFINE_string('weights', '',
                           """finetune with a pretrained model""")

# Print numpy arrays with 3 digits of precision.
np.set_printoptions(precision=3)
def test(dataset, ckptfile):
    """Restore a multi-view checkpoint, run it over ``dataset`` and print accuracy.

    Builds the graph via ``model.inference_multiview`` / ``model.loss`` /
    ``model.classify`` and restores every global variable from ``ckptfile``.
    It then feeds each ``(batch_x, batch_y)`` pair yielded by
    ``dataset.batches(batch_size)`` with ``keep_prob = 1.0``.  Finally it prints
    ``sklearn.metrics.accuracy_score`` of the collected predictions, in percent.

    Args:
        dataset: object exposing ``size()`` and ``batches(batch_size)``.
            NOTE(review): batches are assumed to be numpy arrays, since both
            predictions and labels are converted with ``.tolist()``.
            ``batch_x`` must match the ``(None, V, 227, 227, 3)`` placeholder --
            confirm against ``input.Dataset``.
        ckptfile: checkpoint path passed to ``tf.train.Saver.restore``.

    Returns:
        The accuracy as a fraction, or ``None`` if no batches were produced.

    Raises:
        AssertionError: if a batch produces a NaN loss.
    """
    print('test() called')
    V = g_.NUM_VIEWS
    batch_size = FLAGS.batch_size
    data_size = dataset.size()
    print('dataset size: %s' % data_size)

    predictions = []
    labels = []

    with tf.Graph().as_default():
        startstep = 0
        global_step = tf.Variable(startstep, trainable=False)

        view_ = tf.placeholder('float32', shape=(None, V, 227, 227, 3), name='im0')
        y_ = tf.placeholder('int64', shape=None, name='y')
        keep_prob_ = tf.placeholder('float32')

        fc8 = model.inference_multiview(view_, g_.NUM_CLASSES, keep_prob_)
        loss = model.loss(fc8, y_)
        # The training op is built but never run.  NOTE(review): presumably it
        # is kept so the graph's variable set (e.g. optimizer state) matches the
        # training checkpoint restored below -- confirm before removing.
        model.train(loss, global_step, data_size)
        prediction = model.classify(fc8)

        # tf.global_variables() is the non-deprecated name of tf.all_variables().
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=1000)

        config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
        # The context manager guarantees the session is closed, even on error.
        with tf.Session(config=config) as sess:
            saver.restore(sess, ckptfile)
            print('restore variables done')

            step = startstep
            print("Start testing")
            print("Size: %s" % data_size)
            # Integer ceiling division: a trailing partial batch still costs an
            # iteration.  (ceil() of a Python 2 int/int quotient would floor first.)
            num_batches = (data_size + batch_size - 1) // batch_size
            print("It'll take %d iterations." % num_batches)

            for batch_x, batch_y in dataset.batches(batch_size):
                step += 1
                start_time = time.time()
                # keep_prob = 1.0: presumably disables dropout at evaluation time.
                feed_dict = {view_: batch_x,
                             y_: batch_y,
                             keep_prob_: 1.0}
                pred, loss_value = sess.run([prediction, loss], feed_dict=feed_dict)
                duration = time.time() - start_time

                assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

                if step % 10 == 0:
                    sec_per_batch = float(duration)
                    # Use the real batch length (the last batch may be short) and
                    # avoid dividing by a zero duration from a coarse clock.
                    if duration > 0:
                        examples_per_sec = len(batch_y) / duration
                    else:
                        examples_per_sec = float('inf')
                    print('%s: step %d, loss=%.2f (%.1f examples/sec; %.3f sec/batch)'
                          % (datetime.now(), step, loss_value,
                             examples_per_sec, sec_per_batch))

                predictions.extend(pred.tolist())
                labels.extend(batch_y.tolist())

    # Accuracy over zero examples is undefined; report it instead of printing nan.
    if not labels:
        print('acc: no test examples were evaluated')
        return None

    acc = metrics.accuracy_score(labels, predictions)
    print('acc: %s' % (acc * 100))
    return acc
def main(argv):
    """Script entry point: load the test list and evaluate ``FLAGS.weights``.

    Reads ``(listfile, label)`` pairs from ``g_.TEST_LOL`` and wraps them in an
    ``input.Dataset`` (no mean subtraction, ``g_.NUM_VIEWS`` views).  It then
    hands that dataset and the ``--weights`` checkpoint to ``test()``.

    Args:
        argv: command-line arguments.  They are unused here because flags are
            read through ``FLAGS``; the parameter is kept for the ``main(argv)``
            entry-point convention.

    Raises:
        ValueError: if ``--weights`` was not supplied.
    """
    # Fail fast.  An empty path cannot name a checkpoint, so the restore in
    # test() would presumably fail -- and only after the whole dataset loaded.
    if not FLAGS.weights:
        raise ValueError('--weights must be set to the checkpoint to evaluate')

    st = time.time()
    print('start loading data')
    listfiles, labels = read_lists(g_.TEST_LOL)
    dataset = Dataset(listfiles, labels, subtract_mean=False, V=g_.NUM_VIEWS)
    print('done loading data, time= %s' % (time.time() - st))
    # Deliberately not returned: if main is run via tf.app.run, its return
    # value becomes the sys.exit() status.
    test(dataset, FLAGS.weights)
def read_lists(list_of_lists_file):
    """Parse a whitespace-separated ``<path> <integer label>`` list file.

    Args:
        list_of_lists_file: path to a text file with one entry per line.  The
            first column is a file path and the second an integer class label.

    Returns:
        ``(listfiles, labels)``: a tuple of path strings and a tuple of ints,
        in file order.  Both are empty when the file has no rows.
    """
    # ndmin=2 keeps a single-line file as a 1x2 table.  Without it loadtxt
    # returns a flat 2-element row, whose strings would then be indexed
    # character by character.
    rows = np.loadtxt(list_of_lists_file, dtype=str, ndmin=2).tolist()
    if not rows:
        # zip(*[]) yields nothing, so the two-way unpacking below would fail.
        return (), ()
    listfiles, labels = zip(*[(row[0], int(row[1])) for row in rows])
    return listfiles, labels
if __name__ == '__main__':
    # Executed only when run as a script, never on import: hand the raw
    # command line to the entry point.
    main(argv=sys.argv)