Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Rare Word F1 #3566

Merged
merged 7 commits into from
Apr 14, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 77 additions & 1 deletion parlai/tasks/wizard_of_wikipedia/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
E.g. `wizard_of_wikipedia:WizardDialogKnowledgeTeacher:random_split`
"""

from typing import Optional, Tuple
from __future__ import annotations
from typing import List, Optional, Tuple
from parlai.core.message import Message
from parlai.core.metrics import AverageMetric, normalize_answer, F1Metric
from parlai.core.params import ParlaiParser
Expand Down Expand Up @@ -100,6 +101,67 @@ def _path(opt, split='random_split'):
return os.path.join(dp, df)


class RareWordF1Calculator:
    """
    Helper class for computing F1 with an emphasis on infrequent words.

    Words are ranked by frequency in a reference ``corpus``; a cutoff count is
    chosen so that the most frequent words account for ``top_p`` of all word
    occurrences. When computing F1, both the guess and the answers are filtered
    down to only the words rarer than that cutoff.
    """

    def __init__(self, corpus: str, top_p: float = 0.5):
        """
        :param corpus: reference text used to build the word frequency distribution.
        :param top_p: fraction (must be < 1) of total word occurrences treated as
            "common"; words at or above the induced cutoff count are ignored.
        """
        try:
            import nltk
        except ImportError as err:
            # chain the original error so the real import failure stays visible
            raise ImportError('Please install nltk (e.g. pip install nltk).') from err
        words = normalize_answer(corpus).split()
        self._freq_dist = nltk.FreqDist(words)
        self._cutoff_count = RareWordF1Calculator._find_cutoff_count(
            self._freq_dist, top_p
        )

    @staticmethod
    def _find_cutoff_count(freq_dist, top_p: float) -> int:
        """
        Finds the word occurrence count for which the cumulative occurrences are
        `top_p` of the overall word count.

        Iterates from the most to the least common word; returns the count of
        the first word at which the running total of occurrences exceeds
        ``top_p`` of all occurrences.
        """
        assert top_p < 1
        target = sum(freq_dist.values()) * top_p
        cumul = 0
        for _, v in freq_dist.most_common():
            cumul += v
            if cumul > target:
                return v
        raise RuntimeError(f"Invalid top {top_p*100}% of the corpus distribution")

    @staticmethod
    def _filter(freq_dist, cutoff: int, text: str) -> str:
        """
        Keeps only the rare words of ``text``: those present in the reference
        distribution with an occurrence count strictly below ``cutoff``. Words
        never seen in the reference corpus are dropped as well (they default to
        the cutoff count itself and so fail the strict comparison).
        """
        words = normalize_answer(text).split()
        return " ".join([w for w in words if freq_dist.get(w, cutoff) < cutoff])

    def compute(self, guess: str, answers: List[str]) -> F1Metric:
        """
        Compute F1 between ``guess`` and ``answers`` after restricting both to
        rare words only.

        Returns ``F1Metric(0, 0)`` (a zero-denominator metric that does not
        contribute to the average) when inputs are missing or when no label
        contains any rare word.
        """
        if guess is None or answers is None:
            # nothing to score; report a no-op metric rather than crash
            return F1Metric(0, 0)
        guess = RareWordF1Calculator._filter(self._freq_dist, self._cutoff_count, guess)
        answers = [
            RareWordF1Calculator._filter(self._freq_dist, self._cutoff_count, a)
            for a in answers
        ]
        if not any(len(a) for a in answers):
            # no rare words in labels, set denominator to zero
            return F1Metric(0, 0)
        return F1Metric.compute(guess, answers)


def _build_rare_word_f1(datapath: str) -> RareWordF1Calculator:
    """
    Build a RareWordF1Calculator whose reference corpus is the text of every
    dialogue turn in the Wizard of Wikipedia data file.
    """
    corpus_file = os.path.join(datapath, 'wizard_of_wikipedia', 'data.json')
    with PathManager.open(corpus_file) as f:
        episodes = json.load(f)
    turns = (turn['text'] for episode in episodes for turn in episode['dialog'])
    corpus = ' '.join(turns) + ' '
    return RareWordF1Calculator(corpus, top_p=0.5)


class WizardOfWikipediaTeacher(FixedDialogTeacher):
"""
The default teacher; essentially reads the json file and outputs the raw data.
Expand Down Expand Up @@ -210,6 +272,10 @@ def __init__(self, opt, shared=None):
self.knowledge_separator = opt.get('include_knowledge_separator', False)
self.chosen_topic_delimiter = opt.get('chosen_topic_delimiter', '\n')
self.num_exs = sum(self.len_episode(i) for i in range(len(self.data)))
if shared and 'rare_word_f1' in shared:
self.rare_word_f1 = shared['rare_word_f1']
elif self.label_type == 'response':
self.rare_word_f1 = _build_rare_word_f1(opt['datapath'])

@classmethod
def add_cmdline_args(
Expand Down Expand Up @@ -258,6 +324,12 @@ def add_cmdline_args(
)
return parser

def share(self):
    """
    Share state with child teachers, including the rare-word F1 calculator
    when this teacher has built (or received) one.
    """
    shared = super().share()
    try:
        # only present on teachers that computed or inherited the calculator
        shared['rare_word_f1'] = self.rare_word_f1
    except AttributeError:
        pass
    return shared

def len_episode(self, ep):
d = self.data[ep]
wizard_first = 'Wizard' in d['dialog'][0]['speaker']
Expand Down Expand Up @@ -390,6 +462,10 @@ def custom_evaluation(
model_response['text'], [teacher_action['checked_sentence']]
),
)
self.metrics.add(
'rare_word_f1',
self.rare_word_f1.compute(model_response['text'], labels),
)
elif (
self.label_type == 'chosen_sent'
and TOKEN_KNOWLEDGE in model_response['text']
Expand Down