span_based_f1.py
# -*- coding: utf-8 -*-
# @Author: Wenwen Yu
# @Created Time: 7/11/2020 10:39 PM
from typing import Callable, Dict, List, Optional, Set
from collections import defaultdict
import torch
from torchtext.vocab import Vocab
from allennlp.common.checks import ConfigurationError
from allennlp.nn.util import get_lengths_from_binary_sequence_mask
from allennlp.training.metrics.metric import Metric
from allennlp.data.dataset_readers.dataset_utils.span_utils import (
bio_tags_to_spans,
bioul_tags_to_spans,
iob1_tags_to_spans,
bmes_tags_to_spans,
TypedStringSpan
)
'''
Copied from ``allennlp.training.metrics.span_based_f1_measure``
with modifications:
* add an accuracy measure, mEA (mean Entity Accuracy)
* rename precision, recall and F1 to mEP, mER and mEF
* add small epsilons for numerical stability
'''
TAGS_TO_SPANS_FUNCTION_TYPE = Callable[
[List[str], Optional[List[str]]], List[TypedStringSpan]] # pylint: disable=invalid-name
class SpanBasedF1Measure(Metric):
"""
    The CoNLL SRL metrics are based on exact span matching. This metric
implements span-based precision and recall metrics for a BIO tagging
scheme. It will produce precision, recall and F1 measures per tag, as
well as overall statistics. Note that the implementation of this metric
    is not exactly the same as the Perl script used to evaluate the CoNLL 2005
data - particularly, it does not consider continuations or reference spans
as constituents of the original span. However, it is a close proxy, which
can be helpful for judging model performance during training. This metric
works properly when the spans are unlabeled (i.e., your labels are
simply "B", "I", "O" if using the "BIO" label encoding).
"""
def __init__(self,
vocab: Vocab = None,
ignore_classes: List[str] = None,
label_encoding: Optional[str] = "BIO",
tags_to_spans_function: Optional[TAGS_TO_SPANS_FUNCTION_TYPE] = None) -> None:
"""
        Parameters
        ----------
        vocab : ``torchtext.vocab.Vocab``, required.
            A vocabulary mapping tag strings to indices. The vocabulary is
            expected to contain the "O" tag; any "<pad>" or "<unk>" entries are
            remapped to "O". This metric assumes a BIO-style format in which
            labels look like ["B-LABEL", "I-LABEL"].
        ignore_classes : ``List[str]``, optional.
            Span labels which will be ignored when computing span metrics.
            A "span label" is the part that comes after the BIO tag, so it
            would be "ARG1" for the tag "B-ARG1". For example, by passing
            ``ignore_classes=["V"]``, the following sequence would not consider
            the "V" span at index (2, 3) when computing the precision, recall
            and F1 metrics:
            ["O", "O", "B-V", "I-V", "B-ARG1", "I-ARG1"]
            This is helpful, for instance, to avoid computing metrics for "V"
            spans in a BIO tagging scheme, which are typically not included.
        label_encoding : ``str``, optional (default = "BIO")
            The encoding used to specify label span endpoints in the sequence.
            Valid options are "BIO", "IOB1", "BIOUL" or "BMES".
        tags_to_spans_function : ``Callable``, optional (default = ``None``)
            If ``label_encoding`` is ``None``, ``tags_to_spans_function`` will be
            used to generate spans.
        """
if label_encoding and tags_to_spans_function:
raise ConfigurationError(
'Both label_encoding and tags_to_spans_function are provided. '
'Set "label_encoding=None" explicitly to enable tags_to_spans_function.'
)
if label_encoding:
if label_encoding not in ["BIO", "IOB1", "BIOUL", "BMES"]:
raise ConfigurationError("Unknown label encoding - expected 'BIO', 'IOB1', 'BIOUL', 'BMES'.")
elif tags_to_spans_function is None:
raise ConfigurationError(
                'At least one of label_encoding or tags_to_spans_function must be provided.'
)
self._label_encoding = label_encoding
self._tags_to_spans_function = tags_to_spans_function
self._label_vocabulary = vocab
self._ignore_classes: List[str] = ignore_classes or []
# These will hold per label span counts.
self._true_positives: Dict[str, int] = defaultdict(int)
self._false_positives: Dict[str, int] = defaultdict(int)
self._false_negatives: Dict[str, int] = defaultdict(int)
self._total: Dict[str, int] = defaultdict(int)
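        # Map every vocabulary index to itself, except '<pad>' and '<unk>', which
        # are redirected to the index of 'O' so that padding and unknown positions
        # can never start or continue a span.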
self.mapped_class = []
for k, v in self._label_vocabulary.stoi.items():
if k == '<pad>' or k == '<unk>':
self.mapped_class.append(self._label_vocabulary.stoi['O'])
else:
self.mapped_class.append(v)
def __call__(self,
predictions: torch.Tensor,
gold_labels: torch.Tensor,
mask: Optional[torch.Tensor] = None,
prediction_map: Optional[torch.Tensor] = None):
"""
Parameters
----------
predictions : ``torch.Tensor``, required.
A tensor of predictions of shape (batch_size, sequence_length, num_classes).
gold_labels : ``torch.Tensor``, required.
            A tensor of integer class labels of shape (batch_size, sequence_length). It must be the same
shape as the ``predictions`` tensor without the ``num_classes`` dimension.
mask: ``torch.Tensor``, optional (default = None).
A masking tensor the same size as ``gold_labels``.
prediction_map: ``torch.Tensor``, optional (default = None).
A tensor of size (batch_size, num_classes) which provides a mapping from the index of predictions
to the indices of the label vocabulary. If provided, the output label at each timestep will be
            ``vocabulary.itos[prediction_map[batch, argmax(predictions[batch, t])]]``,
            rather than simply ``vocabulary.itos[argmax(predictions[batch, t])]``.
This is useful in cases where each Instance in the dataset is associated with a different possible
            subset of labels from a large label space (e.g. FrameNet, where each frame has a different set of
possible roles associated with it).
"""
if mask is None:
mask = torch.ones_like(gold_labels)
predictions, gold_labels, mask, prediction_map = self.detach_tensors(predictions,
gold_labels,
mask, prediction_map)
num_classes = predictions.size(-1)
if (gold_labels >= num_classes).any():
raise ConfigurationError("A gold label passed to SpanBasedF1Measure contains an "
"id >= {}, the number of classes.".format(num_classes))
sequence_lengths = get_lengths_from_binary_sequence_mask(mask).long()
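        # Greedily decode the highest-scoring class at each timestep, then remap
        # both prediction and gold ids through ``prediction_map`` (the default
        # map sends '<pad>' and '<unk>' to the id of 'O').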
argmax_predictions = predictions.max(-1)[1]
if prediction_map is None:
batch_size = gold_labels.size(0)
prediction_map = torch.tensor([self.mapped_class for i in range(batch_size)]).long().to(gold_labels.device)
argmax_predictions = torch.gather(prediction_map, 1, argmax_predictions)
gold_labels = torch.gather(prediction_map, 1, gold_labels.long())
argmax_predictions = argmax_predictions.float()
# Iterate over timesteps in batch.
batch_size = gold_labels.size(0)
for i in range(batch_size):
sequence_prediction = argmax_predictions[i, :]
sequence_gold_label = gold_labels[i, :]
length = sequence_lengths[i]
if length == 0:
# It is possible to call this metric with sequences which are
# completely padded. These contribute nothing, so we skip these rows.
continue
predicted_string_labels = [self._label_vocabulary.itos[int(label_id)]
for label_id in sequence_prediction[:length].tolist()]
gold_string_labels = [self._label_vocabulary.itos[int(label_id)]
for label_id in sequence_gold_label[:length].tolist()]
# print('pred_str: {} \n gold_str: {}'.format(predicted_string_labels, gold_string_labels))
tags_to_spans_function = None
# `label_encoding` is empty and `tags_to_spans_function` is provided.
if self._label_encoding is None and self._tags_to_spans_function:
tags_to_spans_function = self._tags_to_spans_function
# Search by `label_encoding`.
elif self._label_encoding == "BIO":
tags_to_spans_function = bio_tags_to_spans
elif self._label_encoding == "IOB1":
tags_to_spans_function = iob1_tags_to_spans
elif self._label_encoding == "BIOUL":
tags_to_spans_function = bioul_tags_to_spans
elif self._label_encoding == "BMES":
tags_to_spans_function = bmes_tags_to_spans
predicted_spans = tags_to_spans_function(predicted_string_labels, self._ignore_classes)
gold_spans = tags_to_spans_function(gold_string_labels, self._ignore_classes)
predicted_spans = self._handle_continued_spans(predicted_spans)
gold_spans = self._handle_continued_spans(gold_spans)
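            # Every gold span counts toward ``self._total`` for its label; this
            # is the denominator used for the mEA accuracy.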
for span in gold_spans:
self._total[span[0]] += 1
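            # Match predicted spans against gold spans: an exact (label, offsets)
            # match counts as a true positive and the gold span is removed so it
            # cannot be matched twice; unmatched predictions are false positives.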
for span in predicted_spans:
if span in gold_spans:
self._true_positives[span[0]] += 1
gold_spans.remove(span)
else:
self._false_positives[span[0]] += 1
# These spans weren't predicted.
for span in gold_spans:
self._false_negatives[span[0]] += 1
@staticmethod
def _handle_continued_spans(spans: List[TypedStringSpan]) -> List[TypedStringSpan]:
"""
        The official CoNLL 2012 evaluation script for SRL treats continued spans (i.e. spans which
have a `C-` prepended to another valid tag) as part of the span that they are continuing.
This is basically a massive hack to allow SRL models which produce a linear sequence of
predictions to do something close to structured prediction. However, this means that to
compute the metric, these continuation spans need to be merged into the span to which
they refer. The way this is done is to simply consider the span for the continued argument
to start at the start index of the first occurrence of the span and end at the end index
of the last occurrence of the span. Handling this is important, because predicting continued
        spans is difficult and typically will affect the overall average F1 score by ~2 points.
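        For example (illustrative), the spans ``[("ARG1", (1, 2)), ("C-ARG1", (5, 6))]``
        would be merged into the single span ``("ARG1", (1, 6))``.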
Parameters
----------
spans : ``List[TypedStringSpan]``, required.
A list of (label, (start, end)) spans.
Returns
-------
A ``List[TypedStringSpan]`` with continued arguments replaced with a single span.
"""
span_set: Set[TypedStringSpan] = set(spans)
continued_labels: List[str] = [label[2:] for (label, span) in span_set if label.startswith("C-")]
for label in continued_labels:
continued_spans = {span for span in span_set if label in span[0]}
span_start = min(span[1][0] for span in continued_spans)
span_end = max(span[1][1] for span in continued_spans)
replacement_span: TypedStringSpan = (label, (span_start, span_end))
span_set.difference_update(continued_spans)
span_set.add(replacement_span)
return list(span_set)
def get_metric(self, reset: bool = False):
"""
Returns
-------
        A Dict per label containing the following span-based metrics:
        mEP-<label> (precision) : float
        mER-<label> (recall) : float
        mEF-<label> (F1-measure) : float
        mEA-<label> (accuracy) : float
        Additionally, ``-overall`` variants of these keys are included, which
        provide the precision, recall, F1-measure and accuracy over all spans.
"""
all_tags: Set[str] = set()
all_tags.update(self._true_positives.keys())
all_tags.update(self._false_positives.keys())
all_tags.update(self._false_negatives.keys())
all_tags.update(self._total.keys())
all_metrics = {}
for tag in all_tags:
precision, recall, f1_measure = self._compute_metrics(self._true_positives[tag],
self._false_positives[tag],
self._false_negatives[tag])
precision_key = "mEP" + "-" + tag
recall_key = "mER" + "-" + tag
f1_key = "mEF" + "-" + tag
accuracy_key = "mEA" + "-" + tag
all_metrics[precision_key] = precision
all_metrics[recall_key] = recall
all_metrics[f1_key] = f1_measure
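            # mEA: fraction of gold spans of this tag that were predicted exactly.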
all_metrics[accuracy_key] = self._true_positives[tag] / (self._total[tag] + 1e-13)
# Compute the precision, recall and f1 for all spans jointly.
precision, recall, f1_measure = self._compute_metrics(sum(self._true_positives.values()),
sum(self._false_positives.values()),
sum(self._false_negatives.values()))
all_metrics["mEP-overall"] = precision
all_metrics["mER-overall"] = recall
all_metrics["mEF-overall"] = f1_measure
if sum(self._total.values()) != 0:
            all_metrics["mEA-overall"] = sum(self._true_positives.values()) / sum(self._total.values())
        else:
            all_metrics["mEA-overall"] = 0
if reset:
self.reset()
return all_metrics
@staticmethod
def _compute_metrics(true_positives: int, false_positives: int, false_negatives: int):
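        # Standard precision / recall / F1, with a small epsilon in each
        # denominator for numerical stability (avoids division by zero):
        #   P = TP / (TP + FP), R = TP / (TP + FN), F1 = 2 * P * R / (P + R)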
precision = float(true_positives) / float(true_positives + false_positives + 1e-13)
recall = float(true_positives) / float(true_positives + false_negatives + 1e-13)
f1_measure = 2. * ((precision * recall) / (precision + recall + 1e-13))
return precision, recall, f1_measure
def reset(self):
self._true_positives = defaultdict(int)
self._false_positives = defaultdict(int)
self._false_negatives = defaultdict(int)
self._total = defaultdict(int)
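

if __name__ == '__main__':
    # A minimal smoke-test sketch (not part of the original module): the tag
    # set, vocabulary construction and tensor shapes below are illustrative
    # assumptions. It relies on the legacy ``torchtext.vocab.Vocab`` that
    # exposes ``.stoi``/``.itos``, which is what this metric expects.
    from collections import Counter

    tags = ['O', 'B-company', 'I-company', 'B-date', 'I-date']
    label_vocab = Vocab(Counter(tags), specials=['<pad>', '<unk>'])

    metric = SpanBasedF1Measure(vocab=label_vocab, label_encoding='BIO')

    batch_size, seq_len, num_classes = 2, 6, len(label_vocab)
    predictions = torch.randn(batch_size, seq_len, num_classes)  # dummy logits
    gold_labels = torch.randint(0, num_classes, (batch_size, seq_len))
    mask = torch.ones(batch_size, seq_len, dtype=torch.long)

    metric(predictions, gold_labels, mask)
    # Keys follow the mEP/mER/mEF/mEA-<tag|overall> naming used above.
    print(metric.get_metric(reset=True))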