-
Notifications
You must be signed in to change notification settings - Fork 19
/
calculate_distance.py
52 lines (38 loc) · 1.26 KB
/
calculate_distance.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import argparse
import sys
import cPickle as pickle
import datetime, math, sys, time
from sklearn.datasets import fetch_mldata
import numpy as np
import cupy as cp
import chainer
import chainer.functions as F
import chainer.links as L
from chainer import FunctionSet, Variable, optimizers, cuda, serializers
# Command-line options: which CUDA device to use and which dataset to process.
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', type=int, help='which gpu device to use', default=1)
parser.add_argument('--dataset', type=str, default='mnist')
args = parser.parse_args()

# Make the requested GPU current; the CuPy arrays created later in the script
# are allocated on it.
chainer.cuda.get_device(args.gpu).use()

if args.dataset == 'mnist':
    # load_mnist.py lives in the ./mnist directory, not on the default path.
    sys.path.append('mnist')
    # Import only the loader actually used instead of a wildcard import,
    # so nothing from load_mnist silently shadows names in this script.
    from load_mnist import load_mnist_whole
    # NOTE(review): presumably applies x * scale + shift to raw pixel values
    # (0..255 -> roughly [-1, 1)); verify against load_mnist.
    whole = load_mnist_whole(PATH='mnist/', scale=1.0 / 128.0, shift=-1.0)
else:
    # Errors go to stderr. Use sys.exit rather than the bare `exit` builtin,
    # which the `site` module injects for interactive use only and which is
    # absent under `python -S` or in frozen executables.
    sys.stderr.write('The dataset is not supported.\n')
    sys.exit(-1)
# Copy the whole dataset to the current GPU once; every row is compared
# against all rows below.
data = cuda.to_gpu(whole.data)

# Neighbour ranks to record, as 0-based positions in each sample's ascending
# list of distances to the *other* samples. One output file per rank.
num_data = [10]
print(num_data)

# dist_list[j] collects, for every sample, its distance at rank num_data[j].
dist_list = [[] for _ in num_data]
for i in range(len(data)):
    if i % 1000 == 0:
        print(i)  # progress indicator
    # Euclidean distance from sample i to every sample, itself included.
    dist = cp.sqrt(cp.sum((data - data[i]) ** 2, axis=1))
    # Exclude the sample itself by pushing its zero self-distance to the very
    # end of the ordering. inf (instead of a magic 1000) stays correct for any
    # data scale.
    dist[i] = np.inf
    # np.partition puts every requested rank exactly where np.sort would,
    # but in linear rather than n log n time per row.
    host_dist = np.partition(cuda.to_cpu(dist), num_data)
    # NOTE(review): with the self-distance moved last, index k is the
    # (k+1)-th nearest other sample, so index 10 is the 11th neighbour even
    # though the output file is named "10th_neighbor" - confirm which rank
    # downstream code expects before changing it.
    for ranked, k in zip(dist_list, num_data):
        ranked.append(host_dist[k])

# Write one text file per rank into the dataset directory.
for k, dists in zip(num_data, dist_list):
    np.savetxt(args.dataset + '/' + str(k) + 'th_neighbor.txt', np.array(dists))