-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathtest.py
85 lines (69 loc) · 2.75 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Wed Aug 2 16:30:21 2017
@author: user
"""
import tensorflow as tf
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
#%% Evaluate one image
# When training, comment out the following code.
#from PIL import Image
#import matplotlib.pyplot as plt
#
def get_one_image(train):
    """Pick one random image from ``train`` and return it as a 208x208 RGB array.

    The image is passed to ``plt.imshow`` as opened (before conversion and
    resizing).

    Args:
        train: Non-empty sequence of image file paths.

    Returns:
        ndarray of shape (208, 208, 3) holding the resized RGB image.

    Raises:
        ValueError: If ``train`` is empty.
    """
    # len() rather than truthiness: ``train`` may be a list or an ndarray.
    n = len(train)
    if n == 0:
        # np.random.randint(0, 0) would only report an opaque "low >= high".
        raise ValueError('train must contain at least one image path')

    ind = np.random.randint(0, n)
    img_dir = train[ind]

    # The context manager releases the file handle once the pixels are copied.
    with Image.open(img_dir) as opened:
        plt.imshow(opened)
        # Force three channels: evaluate_one_image reshapes the result to
        # [1, 208, 208, 3], which fails for grayscale, palette or RGBA files.
        resized = opened.convert('RGB').resize([208, 208])
        return np.array(resized)
def evaluate_one_image():
    """Classify one random training image with the latest saved checkpoint.

    Builds a fresh graph, restores the newest checkpoint found under
    ``logs_train_dir``, evaluates the softmax output for a single image and
    prints the raw probabilities plus the most probable class. Returns early,
    without predicting, when no checkpoint is available.

    NOTE(review): ``input_data`` and ``model`` are not imported in the visible
    part of this file -- confirm they are in scope before calling this.
    """
    # you need to change the directories to yours.
    train_dir = '/home/user/Desktop/flower-tensorflow/train/'
    # you need to change the directories to yours.
    logs_train_dir = '/home/user/Desktop/flower-tensorflow/train_logits/'

    # One message per class index. The text is runtime output and is kept
    # exactly as originally printed.
    messages = (
        'This is a daisy with possibility %.6f',
        'This is a roses with possibility %.6f',
        'This is a sunflowers with possibility %.6f',
        'This is a dandelion with possibility %.6f',
        'This is a tuplits with possibility %.6f',
    )

    train, _ = input_data.get_files(train_dir)
    image_array = get_one_image(train)

    with tf.Graph().as_default():
        BATCH_SIZE = 1
        N_CLASSES = 5

        # Build the network on the placeholder so the image supplied through
        # feed_dict is the one that is classified. Previously the graph was
        # built from a constant and the fed placeholder was unused.
        x = tf.placeholder(tf.float32, shape=[208, 208, 3])
        image = tf.image.per_image_standardization(x)
        image = tf.reshape(image, [1, 208, 208, 3])
        logit = model.inference(image, BATCH_SIZE, N_CLASSES)
        logit = tf.nn.softmax(logit)

        saver = tf.train.Saver()

        with tf.Session() as sess:
            print("Reading checkpoints...")
            ckpt = tf.train.get_checkpoint_state(logs_train_dir)
            if not (ckpt and ckpt.model_checkpoint_path):
                print('No checkpoint file found')
                # Nothing was restored; evaluating the graph now would
                # presumably fail on uninitialized variables, so stop here.
                return

            # Checkpoint files are named '<prefix>-<global_step>'.
            global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
            saver.restore(sess, ckpt.model_checkpoint_path)
            print('Loading success, global_step is %s' % global_step)

            prediction = sess.run(logit, feed_dict={x: image_array})
            print(prediction)

            # assumes prediction has shape (1, N_CLASSES) -- TODO confirm
            # against model.inference; scalar indexing keeps '%.6f' from
            # receiving a 1-element array.
            max_index = int(np.argmax(prediction))
            print(messages[max_index] % prediction[0, max_index])
#%%