Validation #2

Merged
merged 7 commits into from Jul 27, 2023
Changes from 1 commit
Added PR calculations from VPR tutorial, fixed colormaps, combined training sets in similarity matrices
AdamDHines committed Jul 18, 2023
commit 7f0ea1d5789f0221ef96a689e9dee5e5bb3a4e94
181 changes: 141 additions & 40 deletions VPRTempo.py
@@ -28,6 +28,7 @@
import torch
import gc
import re
import math
import sys
sys.path.append('./src')
sys.path.append('./weights')
@@ -41,6 +42,7 @@

from os import path
from alive_progress import alive_bar
from metrics import createPR


'''
@@ -51,7 +53,7 @@ def __init__(self):
super().__init__()

# Image and patch normalization settings
self.dataPath = '/home/adam/data/VPRTempo_training/training_data/' # training datapath
self.trainingPath = '/home/adam/data/VPRTempo_training/training_data/' # training datapath
self.testPath = '/home/adam/data/VPRTempo_training/testing_data/'
self.imWidth = 28 # image width for patch norm
self.imHeight = 28 # image height for patch norm
@@ -80,7 +82,8 @@ def __init__(self):
for b in bound:
self.img_load.append(lines[b])

# save the image list here for testing so it's the same as the training
with open('./output/images.pkl', 'wb') as f:
pickle.dump(self.img_load, f)

if self.location_repeat > 1:
for n in bound:
@@ -161,7 +164,12 @@ def patch_normalise_pad(self):

# Process the loaded images - resize, normalize color, & patch normalize
def processImage(self):

# gamma correct images
mid = 0.5
mean = np.mean(self.img)
gamma = math.log(mid*255)/math.log(mean)
self.img = np.power(self.img,gamma).clip(0,255).astype(np.uint8)

# resize image to 28x28 and patch normalize
self.img = cv2.resize(self.img,(self.imWidth, self.imHeight))
self.patch_normalise_pad()
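
The gamma value computed in processImage() is chosen so that the image's mean intensity is remapped to mid-grey, i.e. mean ** gamma == 0.5 * 255. A minimal stand-alone sketch of that step, assuming an 8-bit grayscale input (the array here is a dummy, not from the repository):

import math
import numpy as np

img = np.random.randint(1, 256, (28, 28), dtype=np.uint8)      # dummy 8-bit image
mid = 0.5
gamma = math.log(mid * 255) / math.log(np.mean(img))           # solves mean ** gamma == 127.5
corrected = np.power(img, gamma).clip(0, 255).astype(np.uint8)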
@@ -173,6 +181,11 @@ def loadImages(self):

if self.test_true:
self.dataPath = self.testPath
range_pop = range(self.test_t,self.train_img)
for m in reversed(range_pop):
self.img_load.pop(m)
else:
self.dataPath = self.trainingPath

self.ims = []
self.ids = []
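
The reversed(range(...)) above matters: popping from the end of img_load keeps the remaining indices valid while the list shrinks. A tiny illustration with hypothetical values:

img_load = ['a', 'b', 'c', 'd', 'e']
test_t, train_img = 2, 5
for m in reversed(range(test_t, train_img)):   # pops indices 4, 3, 2 in that order
    img_load.pop(m)
print(img_load)                                # ['a', 'b']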
@@ -187,14 +200,6 @@ def loadImages(self):
self.processImage()
self.ims.append(self.img)
self.ids.append(m)

# pickle image ids to keep track of shuffling
if self.test_true:
with open('./output/test_ids.pkl','wb') as f:
pickle.dump(self.ids,f)
else:
with open('./output/train_ids.pkl','wb') as f:
pickle.dump(self.ids,f)

data = {'x': np.array(self.ims), 'y': np.array(self.ids),
'rows': self.imWidth, 'cols': self.imHeight}
@@ -400,54 +405,80 @@ def set_spikes():

# network testing function
def test_network():
# Test the output
# set up variables to store the network output
self.correct_idx = []
self.numcorrect = 0
self.mat_dict = {}
self.net_x = np.array([])
self.tp = 0
self.tn = 0
self.fp = 0
self.fn = 0
idx = []

# reset the output similarity matrices
for n in range(self.location_repeat):
self.mat_dict[str(n)] = []
idx.append(n)

# create a dictionary of locations, with location repeats
match_dict = {}
for n in range(self.test_t):
match_dict[str(n)] = []
for n in range(self.location_repeat):
index_add = self.test_t*n
for m in range(self.test_t):
match_dict[str(m)].append(m + index_add)
pause=1
# run the network
for t in range(self.test_t):
blitnet.runSim(net,1)
nidx = np.argmax(net['x'][-1])
if nidx < int(self.train_img/self.location_repeat):
if self.train_ids[nidx] == self.test_ids[t] or self.train_ids[(nidx+int(self.train_img/self.location_repeat))] == self.test_ids[t]:
self.numcorrect = self.numcorrect+1
self.correct_idx.append(t)
# output the index of highest amplitude spike
tonump = net['x'][-1].detach().cpu().numpy()
nidx = np.argmax(tonump)

# if output value is 0, classify as false negative
if np.max(tonump) == 0:
self.fn = self.fn + 1

# evaluate whether the output is a TP or an FP
else:
if self.train_ids[nidx] == self.test_ids[t] or self.train_ids[(nidx-int(self.train_img/self.location_repeat))] == self.test_ids[t]:
index_lst = match_dict[str(t)]
if nidx in index_lst:
self.numcorrect = self.numcorrect+1
self.tp = self.tp+1
self.correct_idx.append(t)
tonump = net['x'][-1].detach().cpu().numpy()
else:
self.fp = self.fp+1

self.net_x = np.concatenate((self.net_x,tonump),axis=0)
split_mat = np.split(tonump,self.location_repeat)
for n in range(self.location_repeat):
self.mat_dict[str(idx[n])]= np.concatenate((self.mat_dict[str(idx[n])],split_mat[n]),axis=0)

yield

# calculate and plot distance matrices & PR curves
def plotit(netx):
reshape_mat = np.reshape(temp_mat,(self.test_t,int(self.train_img/self.location_repeat)))
# calculate and plot distance matrices
def plotit(netx,name):
reshape_mat = np.reshape(netx,(self.test_t,int(self.train_img/self.location_repeat)))
# plot the matrix
fig = plt.figure()
plt.matshow(reshape_mat,fig, cmap=plt.cm.prism)
fig.suptitle(("Similarity VPRTempo: Location "+str(int(n)+1)),fontsize = 12)
plt.xlabel("Database",fontsize = 12)
plt.ylabel("Query",fontsize = 12)
plt.matshow(reshape_mat,fig, cmap=plt.cm.Greens)
plt.colorbar(label="Spike amplitude")
fig.suptitle("Similarity "+name,fontsize = 12)
plt.xlabel("Query",fontsize = 12)
plt.ylabel("Database",fontsize = 12)
plt.show()

# calculate PR curves


# network validation using alternative place matching algorithms and P@R calculation
def network_validator():
# reload training images for the comparisons
self.test_true= False # reset the test true flag
self.dataPath = '/home/adam/data/VPRTempo_training/training_data/'
self.test_imgs = self.ims.copy()
with open('./nordland_imageNames.txt') as file:
lines = [line.rstrip() for line in file]
self.img_load = lines
self.loadImages()
# run sum of absolute differences calculation
validate.SAD(self)
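
For reference, a small worked example of the match_dict built in test_network(), with illustrative values test_t = 3 and location_repeat = 2: each query index maps to the database indices it is allowed to match across the location repeats, and a spike-amplitude argmax landing on any of them counts as a true positive.

test_t, location_repeat = 3, 2
match_dict = {str(n): [] for n in range(test_t)}
for n in range(location_repeat):
    offset = test_t * n                        # index_add in the code above
    for m in range(test_t):
        match_dict[str(m)].append(m + offset)
print(match_dict)                              # {'0': [0, 3], '1': [1, 4], '2': [2, 5]}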
@@ -466,10 +497,63 @@ def network_validator():
# Reset number of epochs to 1 to run testing network once
self.epoch = 1

# get the ground truth matrix
self.loadImages()
# Set input layer spikes as the testing images
print('Setting spike times')
with alive_bar(len(self.spike_rates)) as sbar:
for i in set_spikes():
sbar()

# run the test network
print('Getting ground truth matrix')
with alive_bar(self.test_t) as testbar:
for i in test_network():
testbar()

# combine the location repeats and plot the similarity matrix
append_mat = []
for n in self.mat_dict:
if int(n) != 0:
append_mat = append_mat + self.mat_dict[str(n)]
else:
append_mat = np.copy(self.mat_dict[str(n)])
plot_name = "training images"
plotit(append_mat,plot_name)

# build, plot, and pickle the ground truth matrices
reshape_mat = np.reshape(append_mat,(self.test_t,int(self.train_img/self.location_repeat)))
boolval = reshape_mat > 0
GTsoft = boolval.astype(int)
GT = np.zeros((self.test_t,self.test_t), dtype=int)

for n in range(len(GT)):
GT[n,n] = 1
plot_name = "Similarity absolute ground truth"
fig = plt.figure()
plt.matshow(GT,fig, cmap=plt.cm.Greens)
plt.colorbar(label="Spike amplitude")
fig.suptitle(plot_name,fontsize = 12)
plt.xlabel("Query",fontsize = 12)
plt.ylabel("Database",fontsize = 12)
plt.show()
with open('./output/GT.pkl', 'wb') as f:
pickle.dump(GT, f)

self.VPRTempo_correct = 100*self.numcorrect/self.test_t
print(self.VPRTempo_correct,'% correct')

# Clear the network output spikes
blitnet.setSpikeTimes(net,2,[])

# Reset network details
net['set_spks'][0] = []
#net['rec_spks'] = [True,True,True]
net['sspk_idx'] = [0,0,0]
net['step_num'] = 0
net['spikes'] = [[],[],[]]

# Load the testing images
with open('./nordland_testNames.txt') as file:
lines = [line.rstrip() for line in file]
self.img_load = lines
self.test_true = True # Set image path to the testing images
self.loadImages()
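
In the ground-truth pass above, GT is an identity matrix (query n is only correct against database entry n), while GTsoft flags every entry with a non-zero training-versus-training response as an acceptable soft match. A minimal sketch with a hypothetical response matrix, assuming the database and query sets are the same length:

import numpy as np

test_t = 4
responses = np.random.rand(test_t, test_t) * (np.random.rand(test_t, test_t) > 0.7)  # sparse stand-in
GTsoft = (responses > 0).astype(int)           # tolerant matches
GT = np.eye(test_t, dtype=int)                 # exact frame-to-frame matches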

@@ -479,12 +563,6 @@ def network_validator():
for i in set_spikes():
sbar()

# load the training and testing IDs for correct matching
with open('./output/train_ids.pkl', 'rb') as f:
self.train_ids = pickle.load(f)
with open('./output/test_ids.pkl', 'rb') as f:
self.test_ids = pickle.load(f)

# run the test network
print('Running testing dataset on trained network')
with alive_bar(self.test_t) as testbar:
@@ -496,10 +574,32 @@ def network_validator():
print(self.VPRTempo_correct,'% correct')

# plot the similarity matrices for each location repetition
append_mat = []
for n in self.mat_dict:
temp_mat = self.mat_dict[str(n)]
plotit(temp_mat)
if int(n) != 0:
append_mat = append_mat + self.mat_dict[str(n)]
else:
append_mat = np.copy(self.mat_dict[str(n)])
plot_name = "VPRTempo"
plotit(append_mat,plot_name)
# pickle the output similarity matrix
S_in = np.reshape(append_mat,(self.test_t,int(self.train_img/self.location_repeat)))
with open('./output/S_in.pkl', 'wb') as f:
pickle.dump(S_in, f)

# calculate the precision and recall of the system
self.precision = self.tp/(self.tp+self.fp)
self.recall = self.tp/self.test_t
P, R = createPR(S_in, GT, GTsoft)
# plot PR curve
fig = plt.figure()
plt.plot(R,P)
fig.suptitle("VPRTempo Precision Recall curve",fontsize = 12)
plt.xlabel("Recall",fontsize = 12)
plt.ylabel("Precision",fontsize = 12)
plt.show()

# plot spikes if they were recorded
if net['set_spks'][0] == True:
blitnet.plotSpikes(net,0)

@@ -510,6 +610,7 @@ def network_validator():
# if validation is set to True, run comparison methods
if self.validation:
network_validator()
pause=1

'''
Run the network
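
createPR, imported from the VPR tutorial's metrics module, takes the query-by-database similarity matrix S_in together with GT and GTsoft and returns the precision and recall values plotted above. The usual way such a curve is built is by sweeping a decision threshold over S_in and scoring each threshold against the ground truth; a hedged sketch of that idea (not createPR itself):

import numpy as np

def pr_sweep(S, GT, n_thresh=100):
    # sweep a threshold over the similarity matrix and score it against the ground truth
    P, R = [], []
    gt = GT.astype(bool)
    for t in np.linspace(S.min(), S.max(), n_thresh):
        matches = S >= t
        tp = np.sum(matches & gt)
        fp = np.sum(matches & ~gt)
        fn = np.sum(~matches & gt)
        P.append(tp / max(tp + fp, 1))
        R.append(tp / max(tp + fn, 1))
    return P, R

For the single-point precision and recall above: with 100 test queries, 80 true positives and 15 false positives give precision 80 / (80 + 15) ≈ 0.84 and, assuming every query has exactly one ground-truth match, recall 80 / 100 = 0.80 (numbers illustrative).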
Binary file removed output/test_ids.pkl
Binary file not shown.
Binary file removed output/train_ids.pkl
Binary file not shown.
55 changes: 2 additions & 53 deletions src/validation.py
@@ -56,62 +56,11 @@ def run_sad():
test_mat = np.concatenate((test_mat,test),axis=0)

# create the similarity matrix
sim_mat = np.copy(np.reshape(test_mat,(self.test_t,self.train_img)))
sort_mat = np.sort(sim_mat,axis=1)
self.sim_mat = np.copy(np.reshape(test_mat,(self.test_t,self.train_img)))

# perform P@0-100%R calculations going through all the threshold values
tp_corr = {}
fp_corr = {}
tn_corr = {}
fn_corr = {}
for xdx, x in enumerate(range(200)):
thresh_vals = sort_mat[:,x]
tp_corr[str(xdx)] = 0
fp_corr[str(xdx)] = 0
tn_corr[str(xdx)] = 0
fn_corr[str(xdx)] = 0

ap_pos = []
for m, mdx in enumerate(thresh_vals):
temp_row = np.copy(sim_mat[m,:])
temp_row[temp_row>mdx] = -1
nonzero = np.where(temp_row>0)

positives = np.argwhere(temp_row!=-1)
ap_pos.append(len(positives[0]))
negatives = np.argwhere(temp_row==-1)

# calculate true positives and false positives
if m in positives or m+self.test_t in positives:
tp_corr[str(xdx)] = tp_corr[str(xdx)] + 1
fp_corr[str(xdx)] = tp_corr[str(xdx)] - 1
if m not in negatives or m+self.test_t not in negatives:
tn_corr[str(xdx)] = tn_corr[str(xdx)] + len(negatives)
fp_corr[str(xdx)] = len(positives) - tp_corr[str(xdx)]
fn_corr[str(xdx)] = len(negatives)-tn_corr[str(xdx)]

tp_corr[str(xdx)] = tp_corr[str(xdx)]/self.test_t
# store the output
self.SAD_correct = 100*(numcorrect/len(self.test_imgs))
tp = np.array([])
fp = np.array([])
for n in range(len(tp_corr)):
tp = np.append(tp,np.array(tp_corr[str(n)]))
fp = np.append(fp,np.array(fp_corr[str(n)]))
precision = (tp)/(tp+fp)
# plot out the PR curve
array_pr = np.array([])
for e in range(len(tp_corr)):
array_pr = np.append(array_pr,np.array(tp_corr[str(e)]))
x_axes = np.linspace(0,1,num=len(tp))
y_axes = np.linspace(0,1,num=6)
fig = plt.figure()
plt.plot(x_axes,np.flip(precision))
plt.ylim(0,1)
fig.suptitle("PR curve SAD",fontsize = 12)
plt.xlabel("Recall",fontsize = 12)
plt.ylabel("Precision",fontsize = 12)
plt.show()


run_sad()
print('Number of correct images with SAD: '+str(self.SAD_correct)+'%')
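
validate.SAD itself is not shown in full in this diff; as a rough sketch of sum-of-absolute-differences place matching, under the assumption that the query and database images are flattened uint8 vectors of equal length (the function name below is hypothetical):

import numpy as np

def sad_best_match(query, database):
    # database is (num_places, num_pixels); a lower SAD means a closer match
    diffs = np.sum(np.abs(database.astype(np.int32) - query.astype(np.int32)), axis=1)
    return int(np.argmin(diffs))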