Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Torch integration #3

Merged
merged 11 commits into from
Jul 27, 2023
Prev Previous commit
Next Next commit
Working on PR curves
AdamDHines committed Jul 11, 2023

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
commit 4602e793e97478c71f29e0db4cebdc00a800dd76
5 changes: 4 additions & 1 deletion VPRTempo.py
Original file line number Diff line number Diff line change
@@ -389,6 +389,7 @@ def test_network():
self.correct_idx = []
self.numcorrect = 0
self.mat_dict = {}
self.net_x = np.array([])
idx = []
for n in range(self.location_repeat):
self.mat_dict[str(n)] = []
@@ -405,6 +406,7 @@ def test_network():
self.numcorrect = self.numcorrect+1
self.correct_idx.append(t)
tonump = net['x'][-1].detach().cpu().numpy()
self.net_x = np.concatenate((self.net_x,tonump),axis=0)
split_mat = np.split(tonump,self.location_repeat)
for n in range(self.location_repeat):
self.mat_dict[str(idx[n])]= np.concatenate((self.mat_dict[str(idx[n])],split_mat[n]),axis=0)
@@ -473,11 +475,12 @@ def network_validator():
reshape_mat = np.reshape(temp_mat,(self.test_t,int(self.train_img/self.location_repeat)))
# plot the matrix
fig = plt.figure()
plt.matshow(reshape_mat,fig)
plt.matshow(reshape_mat,fig, cmap=plt.cm.prism)
fig.suptitle(("Similarity VPRTempo: Location "+str(int(n)+1)),fontsize = 12)
plt.xlabel("Database",fontsize = 12)
plt.ylabel("Query",fontsize = 12)
plt.show()

if net['set_spks'][0] == True:
blitnet.plotSpikes(net,0)

Binary file modified output/test_ids.pkl
Binary file not shown.
Binary file modified output/train_ids.pkl
Binary file not shown.
Binary file modified src/__pycache__/blitnet_open.cpython-311.pyc
Binary file not shown.
57 changes: 53 additions & 4 deletions src/validation.py
Original file line number Diff line number Diff line change
@@ -55,15 +55,64 @@ def run_sad():

test_mat = np.concatenate((test_mat,test),axis=0)

sim_mat = np.reshape(test_mat,(self.test_t,self.train_img))
sort_mat = np.sort(sim_mat,axis=1)[:,::-1]
for x in range(len(sim_mat[1])):
# create the similarity matrix
sim_mat = np.copy(np.reshape(test_mat,(self.test_t,self.train_img)))
sort_mat = np.sort(sim_mat,axis=1)

# perform P@0-100%R calculations going through all the threshold values
tp_corr = {}
fp_corr = {}
tn_corr = {}
fn_corr = {}
for xdx, x in enumerate(range(200)):
thresh_vals = sort_mat[:,x]
tp_corr[str(xdx)] = 0
fp_corr[str(xdx)] = 0
tn_corr[str(xdx)] = 0
fn_corr[str(xdx)] = 0


ap_pos = []
for m, mdx in enumerate(thresh_vals):
temp_row = np.copy(sim_mat[m,:])
temp_row[temp_row>mdx] = -1
nonzero = np.where(temp_row>0)

positives = np.argwhere(temp_row!=-1)
ap_pos.append(len(positives[0]))
negatives = np.argwhere(temp_row==-1)

# calculate true positives and false positives
if m in positives or m+self.test_t in positives:
tp_corr[str(xdx)] = tp_corr[str(xdx)] + 1
fp_corr[str(xdx)] = tp_corr[str(xdx)] - 1
if m not in negatives or m+self.test_t not in negatives:
tn_corr[str(xdx)] = tn_corr[str(xdx)] + len(negatives)
fp_corr[str(xdx)] = len(positives) - tp_corr[str(xdx)]
fn_corr[str(xdx)] = len(negatives)-tn_corr[str(xdx)]

tp_corr[str(xdx)] = tp_corr[str(xdx)]/self.test_t
# store the output
self.SAD_correct = 100*(numcorrect/len(self.test_imgs))
tp = np.array([])
fp = np.array([])
for n in range(len(tp_corr)):
tp = np.append(tp,np.array(tp_corr[str(n)]))
fp = np.append(fp,np.array(fp_corr[str(n)]))
precision = (tp)/(tp+fp)
# plot out the PR curve
array_pr = np.array([])
for e in range(len(tp_corr)):
array_pr = np.append(array_pr,np.array(tp_corr[str(e)]))
x_axes = np.linspace(0,1,num=len(tp))
y_axes = np.linspace(0,1,num=6)
fig = plt.figure()
plt.plot(x_axes,np.flip(precision))
plt.ylim(0,1)
fig.suptitle("PR curve SAD",fontsize = 12)
plt.xlabel("Recall",fontsize = 12)
plt.ylabel("Precision",fontsize = 12)
plt.show()

run_sad()
print('Number of correct images with SAD: '+str(self.SAD_correct)+'%')