-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathstats.py
98 lines (86 loc) · 2.98 KB
/
stats.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import sys
import os
import numpy as np
import argparse
import cPickle as pickle
from triples import Triples
class Stats:
    """Frequency statistics over the training triples of a dataset object.

    After construction the instance exposes:
      * ``eFreq`` / ``rFreq`` -- int32 occurrence counts per entity id and
        per relation id, counted over ``t.train``;
      * ``eIndices`` / ``rIndices`` -- ids ordered by descending count (for
        entities with ``usePR=True``: by descending ``t.pagerank``), so that
        position ``k`` holds the id of rank ``k``.

    ``getEnts`` / ``getRels`` randomly sample ids whose rank lies in a band.
    """

    def __init__(self, t, usePR=False):
        """
        :param t: dataset object providing ``eNames``, ``rNames`` and
            ``train`` (iterable of ``(entity, entity, relation)`` integer id
            triples); when ``usePR`` is true it must also provide
            ``pagerank`` (presumably a numpy array indexed by entity id --
            TODO confirm).
        :param usePR: rank entities by PageRank instead of training frequency.
        """
        self.t = t
        self.setStats(usePR)
        # Seeds numpy's *global* RNG so band sampling is reproducible; note
        # this also affects any other code that uses np.random.
        np.random.seed(123123)

    def setStats(self, usePR):
        """(Re)compute frequency counts and rank orderings from ``self.t.train``.

        :param usePR: if true, order entities by ``self.t.pagerank`` instead of
            by training frequency (relations are always ordered by frequency).
        """
        ne = len(self.t.eNames)
        nr = len(self.t.rNames)
        self.eFreq = np.zeros((ne,), dtype='int32')
        self.rFreq = np.zeros((nr,), dtype='int32')
        # Both entity slots of a triple count towards entity frequency; the
        # third slot is the relation id.
        for s, o, p in self.t.train:
            self.eFreq[s] += 1
            self.eFreq[o] += 1
            self.rFreq[p] += 1
        if usePR:
            print("Using pagerank")
            self.eIndices = (-self.t.pagerank).argsort()
        else:
            self.eIndices = (-self.eFreq).argsort()
        # Negating before argsort yields descending order.
        self.rIndices = (-self.rFreq).argsort()

    def _sampleBand(self, ranked, rankBand, nSamples):
        """Sample up to ``nSamples`` ids without replacement from a rank band.

        :param ranked: 1-D array of ids ordered by rank.
        :param rankBand: ``(lo, hi)`` rank positions, ``hi`` exclusive; a
            negative ``hi`` means "up to the last rank".  ``hi`` is clamped to
            ``len(ranked)`` so a band reaching past the available ids yields
            fewer samples instead of raising ``IndexError``.
        :param nSamples: maximum number of ids to return (used as a slice
            bound on the shuffled band).
        :returns: array of ids in random order (uses numpy's global RNG).
        """
        lo, hi = rankBand[0], rankBand[1]
        n = ranked.shape[0]
        if hi < 0 or hi > n:
            hi = n
        positions = np.arange(max(hi - lo, 0))
        np.random.shuffle(positions)
        positions = positions[:nSamples] + lo
        return ranked[positions]

    def getEnts(self, rankBand, nSamples):
        """Return up to ``nSamples`` entity ids whose rank lies in ``rankBand``.

        See ``_sampleBand`` for the band convention.
        """
        return self._sampleBand(self.eIndices, rankBand, nSamples)

    def getRels(self, rankBand, nSamples):
        """Return up to ``nSamples`` relation ids whose rank lies in ``rankBand``.

        See ``_sampleBand`` for the band convention.
        """
        return self._sampleBand(self.rIndices, rankBand, nSamples)
def getParser():
    """Return the ``argparse`` parser for this script's command line.

    Options:
      -d / --datafile  path of the pickled triple file (required)
      -o / --outfile   where to save the stats (optional; main() uses it as
                       a filename prefix)
    """
    ap = argparse.ArgumentParser(description="parser for arguments")
    ap.add_argument("-d", "--datafile", required=True, type=str,
                    help="pickled triple file")
    ap.add_argument("-o", "--outfile", type=str,
                    help="file to save stats")
    return ap
def _relCategory(relName):
    """Return the top-level category of a relation name.

    For a name such as ``"/a/b/c"`` this is the first path component after
    the leading slash (``"a"``).  A name containing no ``"/"`` is returned
    whole instead of raising ``IndexError``.
    """
    parts = relName.split("/")
    return parts[1] if len(parts) > 1 else relName


def _writeRelReport(fout, bands, rels):
    """For each rank band write its ``(lo, hi)`` header line, then its sampled relation names one per line."""
    for band, curRels in zip(bands, rels):
        fout.write("%s\n" % str(band))
        for rel in curRels:
            fout.write("%s\n" % rel)


def _writeCatReport(fout, bands, cats):
    """For each rank band write its header line, then one tab-separated line of ``category:count`` pairs."""
    for band, curCats in zip(bands, cats):
        fout.write("%s\n" % str(band))
        for cat, count in curCats.items():
            fout.write("%s:%d\t" % (cat, count))
        fout.write("\n")


def main():
    """Sample relations per frequency-rank band and report their categories.

    Loads the triple file given by ``-d``, ranks relations by training
    frequency (via ``Stats``), samples relations from several rank bands and
    tallies their top-level categories.  With ``-o PREFIX`` the results are
    written to ``PREFIX.rel.txt`` and ``PREFIX.cat.txt``; otherwise both
    reports are written to stdout.  Invalid arguments print the full help
    and exit with status 1.
    """
    parser = getParser()
    try:
        args = parser.parse_args()
    except SystemExit as exc:
        # argparse raises SystemExit(0) after printing -h/--help; let that
        # through instead of printing help a second time and exiting 1.
        if not exc.code:
            raise
        parser.print_help()
        sys.exit(1)

    t = Triples(args.datafile)
    stats = Stats(t)
    # (rank band, number of relations to sample from it).  The last band runs
    # from rank 500 to t.nr (presumably the relation count -- TODO confirm).
    rRanges = [((0, 50), 50), ((50, 100), 50), ((100, 200), 100),
               ((200, 500), 300), ((500, t.nr), max(0, t.nr - 500))]

    rels = []
    cats = []
    for rankBand, nSamples in rRanges:
        curRels = [t.rNames[idx] for idx in stats.getRels(rankBand, nSamples)]
        curCats = {}
        for rel in curRels:
            cat = _relCategory(rel)
            curCats[cat] = curCats.get(cat, 0) + 1
        rels.append(curRels)
        cats.append(curCats)

    bands = [rankBand for rankBand, _ in rRanges]
    if args.outfile:
        with open(args.outfile + '.rel.txt', "w") as fout:
            _writeRelReport(fout, bands, rels)
        with open(args.outfile + '.cat.txt', "w") as fout:
            _writeCatReport(fout, bands, cats)
    else:
        _writeRelReport(sys.stdout, bands, rels)
        _writeCatReport(sys.stdout, bands, cats)
# Run the command-line entry point only when executed as a script, not when imported.
if __name__ == "__main__":
    main()