From 6b6d65a72e667e2defd257b4522d468bd209a694 Mon Sep 17 00:00:00 2001
From: Spencer Poff
Date: Wed, 31 Mar 2021 15:30:48 -0700
Subject: [PATCH 1/6] rare word f1

---
 parlai/core/metrics.py | 66 ++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 66 insertions(+)

diff --git a/parlai/core/metrics.py b/parlai/core/metrics.py
index 13a2bf1d3d8..2bd20555d7c 100644
--- a/parlai/core/metrics.py
+++ b/parlai/core/metrics.py
@@ -530,6 +530,72 @@ def compute(guess: str, answers: List[str]) -> F1Metric:
         return F1Metric(max(f1 for p, r, f1 in scores), 1)
 
 
+class RareWordF1Metric:
+    """
+    Helper class for computing F1 with an emphasis on infrequent words.
+    """
+
+    def __init__(self, corpus: str, top_p: Optional[float] = 0.5):
+        try:
+            import nltk
+        except ImportError:
+            raise ImportError('Please install nltk (e.g. pip install nltk).')
+        words = normalize_answer(corpus).split()
+        self._freq_dist = nltk.FreqDist(words)
+        self._cutoff = RareWordF1Metric._find_cutoff_count(self._freq_dist, top_p)
+
+    @staticmethod
+    def _find_cutoff_count(freq_dist, top_p: float) -> int:
+        """
+        Finds the word occurance for which the cumulative occurances
+        are `top_p` of the overall word count.
+        """
+        assert top_p < 1
+        target = sum(freq_dist.values()) * top_p
+        cumul = 0
+        for _, v in freq_dist.most_common():
+            cumul += v
+            if cumul > target:
+                return v
+
+    @staticmethod
+    def _rarity_weight(freq_dist, word: str, cutoff: int) -> float:
+        """
+        A score multiplier that signifies how rare a word is.
+        The words with more than `cutoff` occurances in the corpus will
+        have a weight of 0, and very rare words will have a frequency near 1.
+        """
+        return max(0, (cutoff - freq_dist[word]) / cutoff)
+
+    def _weighted_f1_score(self, pred_items: List[str], gold_items: List[str]) -> float:
+        """
+        Compute f1 given a set of gold and prediction items, weighted by the infrequency of each word.
+        """
+        weights = {
+            w: RareWordF1Metric._rarity_weight(
+                freq_dist=self._freq_dist, word=w, cutoff=self._cutoff
+            )
+            for w in set(pred_items + gold_items)
+        }
+        common = Counter(gold_items) & Counter(pred_items)
+        weighted_common = {w: c * weights[w] for w, c in common.items()}
+        true_pos_score = sum(weighted_common.values())
+        if true_pos_score == 0:
+            return 0
+        precision = true_pos_score / sum(weights[w] for w in pred_items)
+        recall = true_pos_score / sum(weights[w] for w in gold_items)
+        f1 = (2 * precision * recall) / (precision + recall)
+        return f1
+
+    def compute(self, guess: str, answers: List[str]) -> F1Metric:
+        g_tokens = normalize_answer(guess).split()
+        scores = [
+            self._weighted_f1_score(g_tokens, normalize_answer(a).split())
+            for a in answers
+        ]
+        return F1Metric(max(scores), 1)
+
+
 class ExactMatchMetric(AverageMetric):
     @staticmethod
     def compute(guess: str, answers: List[str]) -> ExactMatchMetric:
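The weighting introduced in PATCH 1 is worth seeing in isolation: each token's score multiplier falls linearly from 1 (absent from the reference corpus) to 0 (count at or above the cutoff). Below is a minimal sketch of that formula on an invented toy corpus; the cutoff of 2 is made up for illustration, whereas the class derives it from `_find_cutoff_count`.

```python
from collections import Counter

def rarity_weight(freq_dist, word: str, cutoff: int) -> float:
    # Same formula as RareWordF1Metric._rarity_weight: frequent words get 0,
    # words unseen in the corpus get 1, everything else falls in between.
    return max(0, (cutoff - freq_dist[word]) / cutoff)

freq_dist = Counter("the cat sat on the mat the cat".split())
cutoff = 2  # invented for the sketch; not derived from _find_cutoff_count
for w in ("the", "cat", "mat", "zebra"):
    print(w, rarity_weight(freq_dist, w, cutoff))
# the -> 0 (count 3, above cutoff) | cat -> 0.0 | mat -> 0.5 | zebra -> 1.0
```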
From 071b61fedc05c2ece522d0bc7cda7f5c8a80b601 Mon Sep 17 00:00:00 2001
From: Spencer Poff
Date: Thu, 1 Apr 2021 08:26:30 -0700
Subject: [PATCH 2/6] small docstring fix

---
 parlai/core/metrics.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/parlai/core/metrics.py b/parlai/core/metrics.py
index 2bd20555d7c..16080c85036 100644
--- a/parlai/core/metrics.py
+++ b/parlai/core/metrics.py
@@ -563,7 +563,7 @@ def _rarity_weight(freq_dist, word: str, cutoff: int) -> float:
         """
         A score multiplier that signifies how rare a word is.
         The words with more than `cutoff` occurances in the corpus will
-        have a weight of 0, and very rare words will have a frequency near 1.
+        have a weight of 0, and very rare words will have a weight near 1.
         """
         return max(0, (cutoff - freq_dist[word]) / cutoff)
From a8f507eed11dc495e544c3e2826cecd747d64372 Mon Sep 17 00:00:00 2001
From: Spencer Poff
Date: Tue, 6 Apr 2021 20:59:00 -0700
Subject: [PATCH 3/6] simplify and add to wizard

---
 parlai/core/metrics.py                     | 47 +++++++---------------
 parlai/tasks/wizard_of_wikipedia/agents.py | 30 +++++++++++++-
 2 files changed, 44 insertions(+), 33 deletions(-)

diff --git a/parlai/core/metrics.py b/parlai/core/metrics.py
index 16080c85036..7dfbef25911 100644
--- a/parlai/core/metrics.py
+++ b/parlai/core/metrics.py
@@ -530,19 +530,21 @@ def compute(guess: str, answers: List[str]) -> F1Metric:
         return F1Metric(max(f1 for p, r, f1 in scores), 1)
 
 
-class RareWordF1Metric:
+class RareWordF1Calculator:
     """
     Helper class for computing F1 with an emphasis on infrequent words.
     """
 
-    def __init__(self, corpus: str, top_p: Optional[float] = 0.5):
+    def __init__(self, corpus: str, top_p: float = 0.5):
         try:
             import nltk
         except ImportError:
             raise ImportError('Please install nltk (e.g. pip install nltk).')
         words = normalize_answer(corpus).split()
         self._freq_dist = nltk.FreqDist(words)
-        self._cutoff = RareWordF1Metric._find_cutoff_count(self._freq_dist, top_p)
+        self._cutoff_count = RareWordF1Calculator._find_cutoff_count(
+            self._freq_dist, top_p
+        )
 
     @staticmethod
     def _find_cutoff_count(freq_dist, top_p: float) -> int:
@@ -557,43 +559,24 @@ def _find_cutoff_count(freq_dist, top_p: float) -> int:
             cumul += v
             if cumul > target:
                 return v
+        raise RuntimeError(f"Invalid top {top_p*100}% of the corpus distribution")
 
     @staticmethod
-    def _rarity_weight(freq_dist, word: str, cutoff: int) -> float:
-        """
-        A score multiplier that signifies how rare a word is.
-        The words with more than `cutoff` occurances in the corpus will
-        have a weight of 0, and very rare words will have a weight near 1.
+    def _filter(freq_dist, cutoff: int, text: str) -> str:
         """
-        return max(0, (cutoff - freq_dist[word]) / cutoff)
-
-    def _weighted_f1_score(self, pred_items: List[str], gold_items: List[str]) -> float:
+        For words that are found in the reference distribution, filters those
+        with an occurrence count less than the cutoff.
         """
-        Compute f1 given a set of gold and prediction items, weighted by the infrequency of each word.
-        """
-        weights = {
-            w: RareWordF1Metric._rarity_weight(
-                freq_dist=self._freq_dist, word=w, cutoff=self._cutoff
-            )
-            for w in set(pred_items + gold_items)
-        }
-        common = Counter(gold_items) & Counter(pred_items)
-        weighted_common = {w: c * weights[w] for w, c in common.items()}
-        true_pos_score = sum(weighted_common.values())
-        if true_pos_score == 0:
-            return 0
-        precision = true_pos_score / sum(weights[w] for w in pred_items)
-        recall = true_pos_score / sum(weights[w] for w in gold_items)
-        f1 = (2 * precision * recall) / (precision + recall)
-        return f1
+        words = normalize_answer(text).split()
+        return " ".join([w for w in words if freq_dist.get(w, cutoff) < cutoff])
 
     def compute(self, guess: str, answers: List[str]) -> F1Metric:
-        g_tokens = normalize_answer(guess).split()
-        scores = [
-            self._weighted_f1_score(g_tokens, normalize_answer(a).split())
+        guess = RareWordF1Calculator._filter(self._freq_dist, self._cutoff_count, guess)
+        answers = [
+            RareWordF1Calculator._filter(self._freq_dist, self._cutoff_count, a)
             for a in answers
         ]
-        return F1Metric(max(scores), 1)
+        return F1Metric.compute(guess, answers)
 
 
 class ExactMatchMetric(AverageMetric):
diff --git a/parlai/tasks/wizard_of_wikipedia/agents.py b/parlai/tasks/wizard_of_wikipedia/agents.py
index 29fbe815c6d..a5bb969378a 100644
--- a/parlai/tasks/wizard_of_wikipedia/agents.py
+++ b/parlai/tasks/wizard_of_wikipedia/agents.py
@@ -17,7 +17,12 @@
 from typing import Optional, Tuple
 
 from parlai.core.message import Message
-from parlai.core.metrics import AverageMetric, normalize_answer, F1Metric
+from parlai.core.metrics import (
+    AverageMetric,
+    normalize_answer,
+    F1Metric,
+    RareWordF1Calculator,
+)
 from parlai.core.params import ParlaiParser
 from parlai.core.opt import Opt
 import copy
@@ -181,6 +186,15 @@ def share(self):
 ###############################################################
 
 
+def _build_rare_word_f1(datapath: str) -> RareWordF1Calculator:
+    all_text = ''
+    data_path = os.path.join(datapath, 'wizard_of_wikipedia', 'data.json')
+    with PathManager.open(data_path) as f:
+        data = json.load(f)
+        all_text += ' '.join(m['text'] for d in data for m in d['dialog']) + ' '
+    return RareWordF1Calculator(all_text, top_p=0.5)
+
+
 class WizardDialogKnowledgeTeacher(WizardOfWikipediaTeacher):
     """
     Teacher that returns the following action dict:
@@ -210,6 +224,10 @@ def __init__(self, opt, shared=None):
         self.knowledge_separator = opt.get('include_knowledge_separator', False)
         self.chosen_topic_delimiter = opt.get('chosen_topic_delimiter', '\n')
         self.num_exs = sum(self.len_episode(i) for i in range(len(self.data)))
+        if shared and 'rare_word_f1' in shared:
+            self.rare_word_f1 = shared['rare_word_f1']
+        elif self.label_type == 'response':
+            self.rare_word_f1 = _build_rare_word_f1(opt['datapath'])
 
     @classmethod
     def add_cmdline_args(
@@ -258,6 +276,12 @@ def add_cmdline_args(
         )
         return parser
 
+    def share(self):
+        shared = super().share()
+        if hasattr(self, 'rare_word_f1'):
+            shared['rare_word_f1'] = self.rare_word_f1
+        return shared
+
     def len_episode(self, ep):
         d = self.data[ep]
         wizard_first = 'Wizard' in d['dialog'][0]['speaker']
@@ -390,6 +414,10 @@ def custom_evaluation(
                     model_response['text'], [teacher_action['checked_sentence']]
                 ),
             )
+            self.metrics.add(
+                'rare_word_f1',
+                self.rare_word_f1.compute(model_response['text'], labels),
+            )
         elif (
             self.label_type == 'chosen_sent'
             and TOKEN_KNOWLEDGE in model_response['text']
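PATCH 3 replaces the per-word weighting with a simpler scheme: filter both guess and answers down to their rare words, then score the leftovers with the ordinary `F1Metric`. One subtlety of `_filter` is that words missing from the reference distribution default to the cutoff count, so out-of-corpus words are dropped rather than treated as maximally rare. A sketch with an invented corpus (plain `split()` stands in for `normalize_answer` here):

```python
from collections import Counter

corpus = "cheese is cheese is cheese is cheese gouda brie cheddar stilton milk"
freq_dist = Counter(corpus.split())
cutoff = 3  # what _find_cutoff_count returns for this corpus at top_p=0.5

def filter_rare(text: str) -> str:
    # Mirrors RareWordF1Calculator._filter: keep only words whose corpus count
    # is below the cutoff; unseen words default to `cutoff` and are dropped.
    return " ".join(w for w in text.split() if freq_dist.get(w, cutoff) < cutoff)

print(filter_rare("gouda is nice cheese"))  # -> "gouda"
```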
From cce7f210028fa88182f3b6346b24c89135f6bd96 Mon Sep 17 00:00:00 2001
From: Spencer Poff
Date: Mon, 12 Apr 2021 16:14:07 -0700
Subject: [PATCH 4/6] don't count labels without rare words

---
 parlai/core/metrics.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/parlai/core/metrics.py b/parlai/core/metrics.py
index 7dfbef25911..a7ca0e6762d 100644
--- a/parlai/core/metrics.py
+++ b/parlai/core/metrics.py
@@ -576,6 +576,9 @@ def compute(self, guess: str, answers: List[str]) -> F1Metric:
             RareWordF1Calculator._filter(self._freq_dist, self._cutoff_count, a)
             for a in answers
         ]
+        if not any(len(a) for a in answers):
+            # no rare words in labels, set denominator to zero
+            return F1Metric(0, 0)
         return F1Metric.compute(guess, answers)
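The `F1Metric(0, 0)` in PATCH 4 leans on `F1Metric` being an `AverageMetric`, which aggregates as a ratio of summed numerators to summed denominators; an example with a zero denominator therefore drops out of the reported average instead of being counted as a zero score. A toy illustration of that aggregation rule, using plain tuples rather than the actual ParlAI classes:

```python
# Each metric is (numer, denom); the report is sum(numers) / sum(denoms).
scores = [(0.5, 1), (0.0, 0), (1.0, 1)]  # middle label had no rare words
numer = sum(n for n, _ in scores)
denom = sum(d for _, d in scores)
print(numer / denom)  # 0.75 -- the (0, 0) example is skipped, not scored as 0
```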
From 337eca04118920e5e9b0b8a5005fe6b9a6ad4ccf Mon Sep 17 00:00:00 2001
From: Spencer Poff
Date: Tue, 13 Apr 2021 20:42:04 -0700
Subject: [PATCH 5/6] move to wizard

---
 parlai/core/metrics.py                     | 52 --------------
 parlai/tasks/wizard_of_wikipedia/agents.py | 80 +++++++++++++++++-----
 2 files changed, 64 insertions(+), 68 deletions(-)

diff --git a/parlai/core/metrics.py b/parlai/core/metrics.py
index 160690381fd..b29c276b825 100644
--- a/parlai/core/metrics.py
+++ b/parlai/core/metrics.py
@@ -530,58 +530,6 @@ def compute(guess: str, answers: List[str]) -> F1Metric:
         return F1Metric(max(f1 for p, r, f1 in scores), 1)
 
 
-class RareWordF1Calculator:
-    """
-    Helper class for computing F1 with an emphasis on infrequent words.
-    """
-
-    def __init__(self, corpus: str, top_p: float = 0.5):
-        try:
-            import nltk
-        except ImportError:
-            raise ImportError('Please install nltk (e.g. pip install nltk).')
-        words = normalize_answer(corpus).split()
-        self._freq_dist = nltk.FreqDist(words)
-        self._cutoff_count = RareWordF1Calculator._find_cutoff_count(
-            self._freq_dist, top_p
-        )
-
-    @staticmethod
-    def _find_cutoff_count(freq_dist, top_p: float) -> int:
-        """
-        Finds the word occurance for which the cumulative occurances
-        are `top_p` of the overall word count.
-        """
-        assert top_p < 1
-        target = sum(freq_dist.values()) * top_p
-        cumul = 0
-        for _, v in freq_dist.most_common():
-            cumul += v
-            if cumul > target:
-                return v
-        raise RuntimeError(f"Invalid top {top_p*100}% of the corpus distribution")
-
-    @staticmethod
-    def _filter(freq_dist, cutoff: int, text: str) -> str:
-        """
-        For words that are found in the reference distribution, filters those
-        with an occurrence count less than the cutoff.
-        """
-        words = normalize_answer(text).split()
-        return " ".join([w for w in words if freq_dist.get(w, cutoff) < cutoff])
-
-    def compute(self, guess: str, answers: List[str]) -> F1Metric:
-        guess = RareWordF1Calculator._filter(self._freq_dist, self._cutoff_count, guess)
-        answers = [
-            RareWordF1Calculator._filter(self._freq_dist, self._cutoff_count, a)
-            for a in answers
-        ]
-        if not any(len(a) for a in answers):
-            # no rare words in labels, set denominator to zero
-            return F1Metric(0, 0)
-        return F1Metric.compute(guess, answers)
-
-
 class ExactMatchMetric(AverageMetric):
     @staticmethod
     def compute(guess: str, answers: List[str]) -> ExactMatchMetric:
diff --git a/parlai/tasks/wizard_of_wikipedia/agents.py b/parlai/tasks/wizard_of_wikipedia/agents.py
index a5bb969378a..299cfb75e0b 100644
--- a/parlai/tasks/wizard_of_wikipedia/agents.py
+++ b/parlai/tasks/wizard_of_wikipedia/agents.py
@@ -15,14 +15,10 @@
 E.g. `wizard_of_wikipedia:WizardDialogKnowledgeTeacher:random_split`
 """
-from typing import Optional, Tuple
+from __future__ import annotations
+from typing import List, Optional, Tuple
 
 from parlai.core.message import Message
-from parlai.core.metrics import (
-    AverageMetric,
-    normalize_answer,
-    F1Metric,
-    RareWordF1Calculator,
-)
+from parlai.core.metrics import AverageMetric, normalize_answer, F1Metric
 from parlai.core.params import ParlaiParser
 from parlai.core.opt import Opt
 import copy
@@ -105,6 +101,67 @@ def _path(opt, split='random_split'):
     return os.path.join(dp, df)
 
 
+class RareWordF1Calculator:
+    """
+    Helper class for computing F1 with an emphasis on infrequent words.
+    """
+
+    def __init__(self, corpus: str, top_p: float = 0.5):
+        try:
+            import nltk
+        except ImportError:
+            raise ImportError('Please install nltk (e.g. pip install nltk).')
+        words = normalize_answer(corpus).split()
+        self._freq_dist = nltk.FreqDist(words)
+        self._cutoff_count = RareWordF1Calculator._find_cutoff_count(
+            self._freq_dist, top_p
+        )
+
+    @staticmethod
+    def _find_cutoff_count(freq_dist, top_p: float) -> int:
+        """
+        Finds the word occurance for which the cumulative occurances
+        are `top_p` of the overall word count.
+        """
+        assert top_p < 1
+        target = sum(freq_dist.values()) * top_p
+        cumul = 0
+        for _, v in freq_dist.most_common():
+            cumul += v
+            if cumul > target:
+                return v
+        raise RuntimeError(f"Invalid top {top_p*100}% of the corpus distribution")
+
+    @staticmethod
+    def _filter(freq_dist, cutoff: int, text: str) -> str:
+        """
+        For words that are found in the reference distribution, filters those
+        with an occurrence count less than the cutoff.
+        """
+        words = normalize_answer(text).split()
+        return " ".join([w for w in words if freq_dist.get(w, cutoff) < cutoff])
+
+    def compute(self, guess: str, answers: List[str]) -> F1Metric:
+        guess = RareWordF1Calculator._filter(self._freq_dist, self._cutoff_count, guess)
+        answers = [
+            RareWordF1Calculator._filter(self._freq_dist, self._cutoff_count, a)
+            for a in answers
+        ]
+        if not any(len(a) for a in answers):
+            # no rare words in labels, set denominator to zero
+            return F1Metric(0, 0)
+        return F1Metric.compute(guess, answers)
+
+
+def _build_rare_word_f1(datapath: str) -> RareWordF1Calculator:
+    all_text = ''
+    data_path = os.path.join(datapath, 'wizard_of_wikipedia', 'data.json')
+    with PathManager.open(data_path) as f:
+        data = json.load(f)
+        all_text += ' '.join(m['text'] for d in data for m in d['dialog']) + ' '
+    return RareWordF1Calculator(all_text, top_p=0.5)
+
+
 class WizardOfWikipediaTeacher(FixedDialogTeacher):
     """
     The default teacher; essentially reads the json file and outputs the raw data.
@@ -186,15 +243,6 @@ def share(self):
 ###############################################################
 
 
-def _build_rare_word_f1(datapath: str) -> RareWordF1Calculator:
-    all_text = ''
-    data_path = os.path.join(datapath, 'wizard_of_wikipedia', 'data.json')
-    with PathManager.open(data_path) as f:
-        data = json.load(f)
-        all_text += ' '.join(m['text'] for d in data for m in d['dialog']) + ' '
-    return RareWordF1Calculator(all_text, top_p=0.5)
-
-
 class WizardDialogKnowledgeTeacher(WizardOfWikipediaTeacher):
     """
     Teacher that returns the following action dict:

From f5dd7863858a9c0e5b1d99ff0e1bbdaca3f26f23 Mon Sep 17 00:00:00 2001
From: Spencer Poff
Date: Tue, 13 Apr 2021 20:45:31 -0700
Subject: [PATCH 6/6] autoformat doc strings

---
 parlai/tasks/wizard_of_wikipedia/agents.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/parlai/tasks/wizard_of_wikipedia/agents.py b/parlai/tasks/wizard_of_wikipedia/agents.py
index 299cfb75e0b..054d4a66ae3 100644
--- a/parlai/tasks/wizard_of_wikipedia/agents.py
+++ b/parlai/tasks/wizard_of_wikipedia/agents.py
@@ -120,8 +120,8 @@ def __init__(self, corpus: str, top_p: float = 0.5):
     @staticmethod
     def _find_cutoff_count(freq_dist, top_p: float) -> int:
         """
-        Finds the word occurance for which the cumulative occurances
-        are `top_p` of the overall word count.
+        Finds the word occurance for which the cumulative occurances are `top_p` of the
+        overall word count.
         """
         assert top_p < 1
         target = sum(freq_dist.values()) * top_p
@@ -135,8 +135,8 @@ def _find_cutoff_count(freq_dist, top_p: float) -> int:
     @staticmethod
     def _filter(freq_dist, cutoff: int, text: str) -> str:
         """
-        For words that are found in the reference distribution, filters those
-        with an occurrence count less than the cutoff.
+        For words that are found in the reference distribution, filters those with an
+        occurrence count less than the cutoff.
         """
         words = normalize_answer(text).split()
         return " ".join([w for w in words if freq_dist.get(w, cutoff) < cutoff])
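With the series applied, the calculator lives in the Wizard of Wikipedia task module and feeds the `rare_word_f1` metric reported by `custom_evaluation`. A hedged end-to-end sketch follows; the corpus and utterances are invented (in the teacher the corpus comes from `data.json` via `_build_rare_word_f1`), and it assumes nltk is installed and a ParlAI tree at this point in the series:

```python
from parlai.tasks.wizard_of_wikipedia.agents import RareWordF1Calculator

# Invented reference corpus; 'cheese' and 'is' land above the frequency cutoff.
corpus = "cheese is cheese is cheese is cheese gouda brie cheddar stilton milk"
calc = RareWordF1Calculator(corpus, top_p=0.5)

# Frequent words are filtered from both sides before F1, so the score is driven
# by overlap on rare tokens like 'gouda'; here that overlap is exact (F1 = 1).
print(calc.compute("gouda is nice cheese", ["gouda is cheese"]))
```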