-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy pathevaluate-v0.1.py
70 lines (57 loc) · 2.3 KB
/
evaluate-v0.1.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
# coding=utf-8
# eval with test set.
d = "given id:idx json test result, evaluate with buckets,generated by getBucket-*.py"
import sys,os,json,operator
import argparse
def get_args():
    """Parse the evaluation script's command line.

    Positional arguments (all strings): ``qafile``, ``testIds`` and
    ``output``. Optional ``--bucket`` is a string that defaults to ``""``.

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args()``.
    """
    cli = argparse.ArgumentParser(description=d)
    # The three positionals differ only in name and help text.
    positionals = (
        ("qafile", "/path/to/qa.json"),
        ("testIds", "/path/to/test_question.ids"),
        ("output", "/path/to/output.json"),
    )
    for arg_name, help_text in positionals:
        cli.add_argument(arg_name, action="store", type=str, help=help_text)
    cli.add_argument(
        "--bucket",
        action="store",
        type=str,
        default="",
        help="bucket like question type",
    )
    return cli.parse_args()
def load_json(path):
    """Parse and return the JSON document at *path*, closing the file afterwards."""
    with open(path, "r") as handle:
        return json.load(handle)


def count_correct(pred, gt2answers, qids):
    """Count the ids in *qids* whose prediction matches the ground truth.

    Args:
        pred: mapping of question id -> predicted choice index (int or
            int-like string).
        gt2answers: mapping of question id -> ground-truth choice index.
        qids: iterable of question ids to score.

    Returns:
        Number of ids for which ``int(pred[qid]) == int(gt2answers[qid])``.

    Raises:
        KeyError: if an id in *qids* is missing from either mapping.
    """
    return sum(1 for qid in qids if int(pred[qid]) == int(gt2answers[qid]))


def accuracy(correct, total):
    """Return ``correct / total`` as a float, or 0.0 when *total* is 0."""
    if not total:
        # An empty test set or empty bucket must not abort the whole report.
        return 0.0
    return correct / float(total)


if __name__ == "__main__":
    args = get_args()

    qadata = load_json(args.qafile)
    # A set gives O(1) membership tests in the loop over all QA records.
    with open(args.testIds, "r") as handle:
        testIds = set(line.strip() for line in handle)

    # question id -> index of the correct answer within its 4 choices
    gt2answers = {}
    for data in qadata:
        qid = str(data['question_id'])
        if qid in testIds:
            gt2answers[qid] = data['multiple_choices_4'].index(data['answer'])
    print("got %s ground truth qa" % len(gt2answers))

    # Optional question groups, each scored separately; ordered by name.
    groups = []
    if args.bucket != "":
        for group in load_json(args.bucket):
            groups.append({"ids": group['ids'], "name": group['name']})
        groups.sort(key=operator.itemgetter("name"))

    pred = load_json(args.output)
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if len(pred) != len(gt2answers):
        raise ValueError(
            "test output has different number of QA! (%s predictions vs %s ground truth)"
            % (len(pred), len(gt2answers)))
    unknown = [qid for qid in pred if qid not in gt2answers]
    if unknown:
        raise ValueError(
            "test output has %s question ids without ground truth, e.g. %s"
            % (len(unknown), unknown[0]))

    total = len(pred)
    correct = count_correct(pred, gt2answers, pred)
    overall = accuracy(correct, total)
    print("Overall acc: %s (%s/%s)" % (overall, correct, total))

    separator = " "  # stupid ubuntu libreoffice
    if groups:
        accs = []
        for group in groups:
            gtotal = len(group['ids'])
            gcorrect = count_correct(pred, gt2answers, group['ids'])
            gacc = accuracy(gcorrect, gtotal)
            accs.append({"name": group['name'], "acc": gacc})
            print("\tgroup %s: %s (%s/%s)" % (group['name'], gacc, gcorrect, gtotal))
        # print for fast copy to excel; accs is already in name order
        # because groups was sorted by name above.
        print(separator.join(["%s" % one['name'] for one in accs] + ["overall"]))
        print(separator.join(["%s" % one['acc'] for one in accs] + ["%s" % overall]))