-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path explore_vgg16.py
76 lines (52 loc) · 2.23 KB
/
explore_vgg16.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
"""Plot 2-D t-SNE projections of the features a pretrained VGG-16 extracts
from echocardiogram videos.

Video records are read from ``RECORD_PATH`` (parsed by ``parse_example`` from
``read_vid``). Every frame of a video is passed through the ImageNet-pretrained
VGG-16 restored from ``IMG_NET_WEIGHT_PATH``, and its ``fc7`` activations are
kept. All collected frames are embedded with t-SNE and scattered, coloured
either per patient id or per outcome label (see ``COLOR_BY``).
"""
import numpy as np
import tensorflow as tf
from read_vid import *
import tensorflow.contrib.slim.nets as nets
from sklearn.manifold import TSNE

# NOTE(review): `plt` is used below but is not imported in this file; it is
# presumably re-exported by `from read_vid import *` -- confirm, or import
# matplotlib.pyplot explicitly.

IMG_NET_WEIGHT_PATH = '/home/yasaman/HN/image_net_trained/vgg_16.ckpt'
RECORD_PATH = "/home/yasaman/echo/subj_video.tfrecords"
# Maximum number of video records to pull from the (shuffled) dataset. The
# loop below stops early if the record file holds fewer than this.
NUM_VIDEOS = 40
# How to colour the scatter plot: 'ptid' (one colour per patient id) or
# 'outcome' (outcome labels 0 / 1 / 2 drawn in default / red / black).
COLOR_BY = 'ptid'

vgg = nets.vgg

# Input pipeline: the record path is fed at initialisation time; each element
# is one video, parsed into (patient id, outcome label, frames).
input_files = tf.placeholder(tf.string, shape=None)
dataset = tf.data.TFRecordDataset(input_files)
dataset = dataset.map(parse_example)
dataset = dataset.shuffle(buffer_size=1000)
iterator = dataset.make_initializable_iterator()
next_ptid, next_outcome, next_frames = iterator.get_next()

# The frames of one video form the batch fed through VGG-16.
_, end_points = vgg.vgg_16(next_frames, is_training=False, spatial_squeeze=False)
fc7 = end_points['vgg_16/fc7']
# With spatial_squeeze=False, fc7 keeps its two spatial axes. Squeeze only
# those (axes 1 and 2): a bare tf.squeeze would also drop the batch axis for a
# single-frame video, so shape[0] would no longer be the frame count.
# Assumes the spatial axes are size 1 (e.g. 224x224 frames) -- TODO confirm.
fc7 = tf.squeeze(fc7, axis=[1, 2])

restorer = tf.train.Saver()

features = []       # one (num_frames, feature_dim) array per video
labels = []         # per video: its outcome label, repeated once per frame
ids = []            # per video: its patient id, repeated once per frame
unique_ids = set()  # distinct patient ids seen

with tf.Session() as sess:
    restorer.restore(sess, IMG_NET_WEIGHT_PATH)
    sess.run(iterator.initializer, feed_dict={input_files: RECORD_PATH})
    for _ in range(NUM_VIDEOS):
        try:
            extracted_features, label, ptid = sess.run(
                (fc7, next_outcome, next_ptid))
        except tf.errors.OutOfRangeError:
            # The record file holds fewer than NUM_VIDEOS videos.
            break
        num_frames = extracted_features.shape[0]
        # ptid may be a 1-element array or a scalar string depending on
        # parse_example; reduce it to one element either way (indexing a
        # scalar bytes value with [0] would yield its first byte as an int).
        patient_id = np.asarray(ptid).ravel()[0]
        print("label", label, "ptid", ptid, num_frames)
        features.append(extracted_features)
        labels.append(label * np.ones(num_frames))
        ids.append(np.repeat(patient_id, num_frames))
        unique_ids.add(patient_id)

if not features:
    raise SystemExit("No video records were read from %s" % RECORD_PATH)

all_frames = np.concatenate(features, axis=0)
all_labels = np.concatenate(labels, axis=0)
all_ids = np.concatenate(ids, axis=0)
embedded_frames = TSNE().fit_transform(all_frames)

plt.figure()
ax1 = plt.subplot(111)
if COLOR_BY == 'ptid':
    # Sort ids so each patient gets the same colour on every run
    # (set iteration order is not stable).
    patient_ids = sorted(unique_ids)
    cmap = plt.get_cmap('viridis')
    id_col = dict(zip(patient_ids, cmap(np.linspace(0, 1, len(patient_ids)))))
    for patient_id, colour in id_col.items():
        mask = all_ids == patient_id
        ax1.plot(embedded_frames[mask, 0], embedded_frames[mask, 1], 'o',
                 label=patient_id, c=colour)
else:
    for outcome, fmt in ((0, 'o'), (1, 'ro'), (2, 'ko')):
        mask = all_labels == outcome
        ax1.plot(embedded_frames[mask, 0], embedded_frames[mask, 1], fmt,
                 label=str(outcome))
plt.legend()
plt.show()