Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
woi: filter knowledge (#4114)
Browse files Browse the repository at this point in the history
* woi: filter knowledge

* moar

* moar
  • Loading branch information
jaseweston authored Oct 28, 2021
1 parent 33ce8b5 commit 86aa8e4
Showing 1 changed file with 48 additions and 2 deletions.
50 changes: 48 additions & 2 deletions parlai/tasks/wizard_of_internet/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,14 @@
from parlai.utils.data import DatatypeHelper
import parlai.utils.logging as logging
import parlai.tasks.wizard_of_internet.constants as CONST
from parlai.core.mutators import register_mutator
from parlai.core.mutators import register_mutator, ManyEpisodeMutator
from parlai.tasks.wizard_of_wikipedia.agents import (
AddLabel as AddLabelWizWiki,
AddLabelLM as AddLabelLMWizWiki,
CheckedSentenceAsLabel as CheckedSentenceAsLabelWizWiki,
AddCheckedSentence as AddCheckedSentenceWizWiki,
)

import random
from .build import build


Expand Down Expand Up @@ -638,3 +637,50 @@ class AddLabelLM(AddLabelLMWizWiki):
"""

pass


@register_mutator("woi_filter_no_passage_used")
class WoiFilterNoPassageUsed(ManyEpisodeMutator):
    """
    Drop every example where no passage was selected to base the wizard reply on.

    Each example that is kept is emitted as its own single-example episode, so this
    works best in flattened mode. E.g. run with: parlai display_data -t
    wizard_of_internet -n 100 -dt valid --mutators
    flatten+woi_filter_no_passage_used
    """

    def many_episode_mutation(self, episode):
        """
        Split ``episode`` into one-example episodes, skipping passage-less examples.

        :param episode:
            iterable of example dicts; the selected sentences are read from the
            ``CONST.SELECTED_SENTENCES`` field of each example.
        :return:
            a list of episodes, each holding exactly one retained example.
        """
        out_episodes = []
        for e in episode:
            checked_sentences = e.get(CONST.SELECTED_SENTENCES)
            if checked_sentences is None:
                # Field absent: there is no evidence that a passage was used, so
                # drop the example rather than crash on ' '.join(None).
                continue
            # The "nothing selected" case is marked by the field joining to
            # exactly the sentinel token.
            if ' '.join(checked_sentences) != CONST.NO_SELECTED_SENTENCES_TOKEN:
                out_episodes.append([e])
        return out_episodes


@register_mutator("woi_filter_selected_knowledge_in_retrieved_docs")
class WoiFilterSelectedKnowledgeInRetrievedDocs(ManyEpisodeMutator):
    """
    Keep only the examples whose '__retrieved-docs__' field contains every one of
    the '__selected-sentences__'.

    Examples where no sentence was selected, or where at least one selected
    sentence cannot be found in the retrieved docs, are dropped. Each example that
    is kept is emitted as its own single-example episode.
    """

    def many_episode_mutation(self, episode):
        """
        Split ``episode`` into one-example episodes, keeping verifiable examples only.

        :param episode:
            iterable of example dicts; reads the ``CONST.SELECTED_SENTENCES`` and
            ``'__retrieved-docs__'`` fields of each example.
        :return:
            a list of episodes, each holding exactly one retained example.
        """
        out_episodes = []
        for e in episode:
            checked_sentences = e.get(CONST.SELECTED_SENTENCES)
            if checked_sentences is None:
                # Field absent: nothing to look up in the docs, so drop the example.
                continue
            if ' '.join(checked_sentences) == CONST.NO_SELECTED_SENTENCES_TOKEN:
                # Sentinel for "the wizard selected nothing".
                continue
            # A missing docs field is treated as "no docs retrieved": no selected
            # sentence can then be found, and the example is dropped below.
            docs = ' '.join(e.get('__retrieved-docs__') or [])
            # Substring match against the space-joined docs; surrounding spaces
            # on a selected sentence are ignored.
            if all(sent.strip(' ') in docs for sent in checked_sentences):
                out_episodes.append([e])
        return out_episodes

0 comments on commit 86aa8e4

Please sign in to comment.