-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathevaluate.py
executable file
·151 lines (136 loc) · 4.25 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
#!/usr/bin/env python3
import argparse
import itertools
import sys

import stringdist
def accuracy(words):
    """Word accuracy, reported per token.

    Args:
        words: iterable of ``(gold, norm)`` pairs.

    Returns:
        A list holding one bool per pair: True when the predicted form is
        identical to the reference form. Averaging this list yields the
        word accuracy.
    """
    return [reference == predicted for reference, predicted in words]
def cer(words):
    """Character error rate (CER) for each token.

    CER is the Levenshtein distance between the reference and the
    prediction, divided by the length of the reference word.

    Args:
        words: iterable of ``(gold, norm)`` pairs.

    Returns:
        A list with one rate per pair. Exact matches yield ``0`` without
        computing a distance.
    """

    def _rate(reference, predicted):
        # Identical strings need no edit-distance computation.
        if reference == predicted:
            return 0
        return stringdist.levenshtein(reference, predicted) / len(reference)

    return [_rate(reference, predicted) for reference, predicted in words]
def main(args, stemmer=None):
    """Compare predicted normalizations against reference normalizations.

    REFFILE and NORMFILE are read in lockstep, one token per line. The
    resulting ``(gold, norm)`` pairs are optionally filtered and stemmed.
    Then either the pairs are printed (``args.print``), or the token count,
    word accuracy and average CER are reported on stdout.

    Args:
        args: namespace as produced by this script's command-line parser.
            ``reffile``, ``normfile`` and the optional ``trainfile`` are
            iterables of text lines. ``only_incorrect``, ``only_knowns``,
            ``only_unknowns``, ``stem`` and ``print`` select the filtering
            and output mode.
        stemmer: object providing ``stemWords(list_of_str)``. Required when
            ``args.stem`` is set.
    """
    vocab = None
    if args.trainfile:
        # The first column of the training data defines the known vocabulary.
        vocab = {line.strip().split("\t")[0] for line in args.trainfile}

    data = []
    # zip_longest (rather than zip) so that a length mismatch between the two
    # files is reported instead of silently truncating the evaluation set.
    for refline, normline in itertools.zip_longest(args.reffile, args.normfile):
        if refline is None or normline is None:
            print(
                "Warning: REFFILE and NORMFILE have different numbers of"
                " lines; only the aligned prefix is evaluated",
                file=sys.stderr,
            )
            break
        # Reference lines without a tab (e.g. blank separator lines) carry no
        # token. The aligned prediction line is consumed and skipped as well.
        if "\t" not in refline:
            continue
        orig, gold = refline.strip().split("\t")
        # Predictions may be one-column or tab-separated. The last column is
        # taken as the normalization; without a tab this is the whole line.
        norm = normline.strip().split("\t")[-1]
        # NOTE: "incorrect" is decided on the unstemmed forms, i.e. before
        # the optional stemming below.
        if args.only_incorrect and gold == norm:
            continue
        if args.only_knowns and orig not in vocab:
            continue
        if args.only_unknowns and orig in vocab:
            continue
        if args.stem:
            gold, norm = stemmer.stemWords([gold, norm])
        data.append((gold, norm))

    if args.print:
        for gold, norm in data:
            print("{}\t{}".format(gold, norm))
        return

    print(" Tokens: {:d}".format(len(data)))
    if not data:
        # Nothing to average over; avoid dividing by zero below.
        return
    acc = accuracy(data)
    print("Word accuracy: {:.4f}".format(sum(acc) / len(acc)))
    err = cer(data)
    print(" Average CER: {:.4f}".format(sum(err) / len(err)))
if __name__ == "__main__":
    # Command-line entry point: parse and validate options, set up the
    # optional stemmer, then run the evaluation.
    description = "Evaluate normalization quality."
    epilog = ""
    parser = argparse.ArgumentParser(description=description, epilog=epilog)
    parser.add_argument(
        "reffile",
        metavar="REFFILE",
        type=argparse.FileType("r", encoding="UTF-8"),
        help="Reference normalizations in two-column format",
    )
    parser.add_argument(
        "normfile",
        metavar="NORMFILE",
        type=argparse.FileType("r", encoding="UTF-8"),
        help=(
            "Predicted normalizations; can be one- or two-column"
            " (with normalizations expected in second column)"
        ),
    )
    parser.add_argument(
        "trainfile",
        metavar="TRAINFILE",
        type=argparse.FileType("r", encoding="UTF-8"),
        nargs="?",
        help="Training file in two-column format; required for some options",
    )
    parser.add_argument(
        "-s",
        "--stem",
        metavar="LANGUAGE",
        type=str,
        help=(
            "Stem word forms before evaluation; "
            'LANGUAGE is a language name (e.g., "english")'
        ),
    )
    parser.add_argument(
        "-i",
        "--only-incorrect",
        action="store_true",
        default=False,
        help="Only evaluate on the subset of incorrect normalizations",
    )
    parser.add_argument(
        "-k",
        "--only-knowns",
        action="store_true",
        default=False,
        help=(
            "Only evaluate on the subset of known/in-vocabulary tokens;"
            " requires TRAINFILE"
        ),
    )
    parser.add_argument(
        "-u",
        "--only-unknowns",
        action="store_true",
        default=False,
        help=(
            "Only evaluate on the subset of unknown/out-of-vocabulary "
            "tokens; requires TRAINFILE"
        ),
    )
    parser.add_argument(
        "--print",
        action="store_true",
        default=False,
        help="Print the data that will be compared",
    )

    # Without any arguments, show usage instead of argparse's terse error.
    # sys.exit is used because the bare `exit` builtin is injected by the
    # `site` module and is not guaranteed to exist.
    if len(sys.argv) < 2:
        parser.print_help()
        sys.exit(1)
    args = parser.parse_args()

    if args.only_knowns and args.only_unknowns:
        parser.error("can't select both --only-knowns and --only-unknowns")
    if (args.only_knowns or args.only_unknowns) and not args.trainfile:
        parser.error("--only-knowns/--only-unknowns requires TRAINFILE")

    stemmer = None
    if args.stem:
        # PyStemmer is imported lazily so it is only required with --stem.
        # parser.error() exits, so `Stemmer` is always bound after this.
        try:
            import Stemmer
        except ImportError:
            parser.error("--stem requires the PyStemmer package (module 'Stemmer')")
        try:
            stemmer = Stemmer.Stemmer(args.stem.lower())
        except KeyError:
            parser.error(
                "No stemming algorithm for '{}'; valid choices are: {}".format(
                    args.stem, ", ".join(Stemmer.algorithms())
                )
            )
    main(args, stemmer=stemmer)