forked from ShomyLiu/pytorch-relation-extraction
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
74 lines (59 loc) · 1.93 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# -*- coding: utf-8 -*-
import numpy as np
import time
def now():
    """Return the current local time formatted as 'YYYY-MM-DD HH:MM:SS'."""
    stamp_format = '%Y-%m-%d %H:%M:%S'
    return time.strftime(stamp_format, time.localtime())
def save_pr(out_dir, name, epoch, pre, rec, fp_res=None, opt=None):
    '''
    Write precision/recall points (and optionally false positives) to text files.

    Args:
        out_dir: directory the files are written into (it is not created here).
        name: prefix used in every output file name.
        epoch: zero-based epoch index; file names use ``epoch + 1``.
        pre: sequence of precision values.
        rec: sequence of recall values, paired element-wise with ``pre``
            (``zip`` truncates to the shorter of the two).
        fp_res: optional iterable of ``(idx, r, p)`` triples, written one per
            line as "idx r p" to ``{out_dir}/{name}_{epoch+1}_FP.txt``.
        opt: optional tag; when given the PR file is
            ``{out_dir}/{name}_{opt}_{epoch+1}_PR.txt`` instead of
            ``{out_dir}/{name}_{epoch+1}_PR.txt``.

    Each PR line is "precision recall".
    '''
    if opt is None:
        pr_path = '{}/{}_{}_PR.txt'.format(out_dir, name, epoch + 1)
    else:
        pr_path = '{}/{}_{}_{}_PR.txt'.format(out_dir, name, opt, epoch + 1)
    # Context managers close (and flush) both files even if a write raises;
    # the PR file is still opened before the FP file, as before.
    with open(pr_path, 'w') as out:
        if fp_res is not None:
            # NOTE(review): the FP file name does not include ``opt``, so runs that
            # differ only in ``opt`` overwrite each other's FP file -- confirm intended.
            fp_path = '{}/{}_{}_FP.txt'.format(out_dir, name, epoch + 1)
            with open(fp_path, 'w') as fp_out:
                for idx, r, p in fp_res:
                    fp_out.write('{} {} {}\n'.format(idx, r, p))
        for p, r in zip(pre, rec):
            out.write('{} {}\n'.format(p, r))
def eval_metric(true_y, pred_y, pred_p):
    '''
    Calculate the precision and recall points of a P-R curve, ignoring the NA relation.

    Instances are visited in descending order of ``pred_p``. After each one,
    precision and recall are recomputed, and the point is recorded only when it
    differs from the previously recorded point.

    Args:
        true_y: per-instance gold labels. Each entry is indexable; ``entry[0] == 0``
            marks an NA instance and ``entry[0] > 0`` counts as a positive. A
            prediction is a true positive if it matches any label before the first
            ``-1`` (presumably padding -- confirm with the data loader).
        pred_y: predicted relation id per instance (0 means NA).
        pred_p: score per instance used to rank the predictions.

    Returns:
        (precisions, recalls, fp_res). ``fp_res`` lists ``(index, predicted_rel, score)``
        for every NA instance that was predicted as a non-NA relation.

    If ``true_y`` contains no positive instance, recall is reported as 0.0 instead
    of raising ZeroDivisionError.
    '''
    assert len(true_y) == len(pred_y)
    positive_num = len([i for i in true_y if i[0] > 0])
    # Indices of instances sorted by descending score.
    order = np.argsort(pred_p)[::-1]
    tp = 0
    fp = 0
    fn = 0
    # Sentinel (0, 0) start point; it is stripped from the returned lists.
    all_pre = [0]
    all_rec = [0]
    fp_res = []
    # Slicing keeps the original visiting rule: exactly len(true_y) ranked indices.
    for pos in order[:len(true_y)]:
        labels = true_y[pos]
        pred = pred_y[pos]
        if labels[0] == 0:  # NA relation
            if pred > 0:
                fp_res.append((pos, pred, pred_p[pos]))
                fp += 1
        elif pred == 0:
            fn += 1
        else:
            for k in labels:
                if k == -1:
                    break
                if k == pred:
                    tp += 1
                    break
        precision = 1.0 if fp + tp == 0 else tp * 1.0 / (tp + fp)
        # Guard against a dataset with no positive instances.
        recall = tp * 1.0 / positive_num if positive_num else 0.0
        if precision != all_pre[-1] or recall != all_rec[-1]:
            all_pre.append(precision)
            all_rec.append(recall)
    print("tp={}; fp={}; fn={}; positive_num={}".format(tp, fp, fn, positive_num))
    return all_pre[1:], all_rec[1:], fp_res