This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Wizint crowdsourcing #3924

Merged (18 commits) on Aug 18, 2021
13 changes: 13 additions & 0 deletions parlai/crowdsourcing/tasks/wizard_of_internet/README.md
@@ -0,0 +1,13 @@
# Wizard of Internet

This is the crowdsourcing task from the Internet-Augmented Dialogue Generation paper ([link](https://arxiv.org/abs/2107.07566)).
It uses the [Mephisto](https://github.com/facebookresearch/Mephisto) platform to collect dialogue data from human workers on Amazon Mechanical Turk.

## How to use
Having set up your ParlAI and Mephisto environments properly (make sure you can run the Mephisto demos), you should be able to run this task easily. Most of the configuration for running the task is in the `conf/dev.yaml` file. Note the files needed in the `data` directory:
*sample_personas.txt* and *sample_locations.txt* are needed to create the curated personas.

You need to have a functional search server running, and to set its address in `search_server` in the `conf/dev.yaml` file. You may set the server up to search the internet or any other knowledge source of your choosing.
This server responds to the search requests sent by the worker who takes the *wizard* role during this task:
it receives a JSON payload with two keys, `q` and `n`, holding a string that is the search query and an integer that is the number of pages to return, respectively.
It sends its response back as JSON under a key named `response`, which holds a list of documents retrieved for the received search query. Each document is a mapping (dictionary) of *string->string* with at least 3 fields: `url`, `title`, and `content` (see [SearchEngineRetriever](https://github.com/facebookresearch/ParlAI/blob/70ee4a2c63008774fc9e66a8392847554920a14d/parlai/agents/rag/retrieve_api.py#L73) for more info on how this task interacts with the search server).
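As a concrete illustration of the request/response contract described above, here is a minimal sketch of a compatible search server built on Python's standard library. The static `DOCUMENTS` corpus, the `run_search` helper, and the naive substring matching are all assumptions made for this example; a real deployment would query the internet or another knowledge source:

```python
import json
from http.server import BaseHTTPRequestHandler, HTTPServer

# Hypothetical static corpus standing in for a real retrieval backend.
DOCUMENTS = [
    {'url': 'https://example.com/a', 'title': 'Doc A', 'content': 'text about music'},
    {'url': 'https://example.com/b', 'title': 'Doc B', 'content': 'text about cooking'},
]


def run_search(query: str, num_pages: int) -> dict:
    """Return up to `num_pages` documents matching `query` (naive substring match)."""
    words = query.lower().split()
    hits = [doc for doc in DOCUMENTS if any(w in doc['content'] for w in words)]
    return {'response': hits[:num_pages]}


class SearchHandler(BaseHTTPRequestHandler):
    def do_POST(self):
        # The request body is a JSON object with the two keys described above:
        # `q` (the query string) and `n` (the number of pages to return).
        body = self.rfile.read(int(self.headers['Content-Length']))
        request = json.loads(body)
        payload = json.dumps(run_search(request['q'], int(request['n'])))
        self.send_response(200)
        self.send_header('Content-Type', 'application/json')
        self.end_headers()
        self.wfile.write(payload.encode('utf-8'))


def serve(port: int = 3005):
    # Blocks forever; point `search_server` in conf/dev.yaml at this address.
    HTTPServer(('localhost', port), SearchHandler).serve_forever()
```

Only the request/response shapes here follow the contract; everything else (matching, corpus, port) is illustrative.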
5 changes: 5 additions & 0 deletions parlai/crowdsourcing/tasks/wizard_of_internet/__init__.py
@@ -0,0 +1,5 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
236 changes: 236 additions & 0 deletions parlai/crowdsourcing/tasks/wizard_of_internet/acceptability.py
@@ -0,0 +1,236 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
from typing import Iterable, List
from nltk.stem import PorterStemmer
from parlai.crowdsourcing.utils.acceptability import (
AcceptabilityChecker,
normalize_answer,
)
import parlai.utils.logging as logging


# Bad persona violations
PERSONA_REPEATS_PROMPT = 'repeated the prompt text'
ASKED_WIZARD_QUESTION = 'asked wizard in the persona details'
COPIED_EXTENDED_PERSONA = 'extended persona copies the main persona'
GENERIC_EXTENDED_PERSONA = 'extended persona is generic'

QUESTION_PHRASE = 'what is your'

# Wizard knowledge violations
DEFAULT_KNOWLEDGE_OVERLAP_THRESHOLD = 0.05

POOR_SEARCH_QUERIES = 'poor search queries'
IRRELEVANT_SEARCH__QUERIES = 'irrelevant search terms'
NOT_ENOUGH_SEARCH = 'not enough selected knowledge sources'
SELECTED_SHORT_PIECES = 'short knowledge pieces selected'
LOW_KNOWLEDGE_OVERLAP = 'low knowledge overlap'


def tokenize_text(text, stemmer, as_set=True):
text = normalize_answer(text)
tokens = [stemmer.stem(word) for word in text.split(' ')]
if as_set:
tokens = set(tokens)
return tokens


def overlap_ratios(a: set, b: set) -> float:
    """
    Calculate the Jaccard similarity between the two sets.

    The small constant in the denominator avoids a division by zero when both
    sets are empty.
    """
overlap = a.intersection(b)
union = a.union(b)
return len(overlap) / (len(union) + 0.001)


def is_valid_agent_chat_message(message, agent_id):
return (
message.get('text')
and message.get('id') == agent_id
and not message.get('is_search_query', False)
)


def bad_persona(persona, stemmer):
"""
Check for poor persona selection by apprentice.
"""
persona_parts = persona.split('\n')

    # Skip personas that are not in the expected format (e.g., personas used
    # during the pilot).

Contributor: do we need to include this? i.e. are we supporting personas from the pilot?

Contributor (Author): Actually this one is because of the extra explanation that the apprentice adds, or the locations that we are adding. Each of those is going to be a separate line.


if not (
len(persona_parts) == 2
or (len(persona_parts) == 3 and 'I live in ' in persona_parts[0])
):
logging.warning(f'Old fashioned persona: {persona}')
return

# Removing the location ('I live in X') part
if len(persona_parts) == 3:
persona_parts = persona_parts[1:]

main_pers, ext_pers = [p.lower() for p in persona_parts]

violations = []

    # Bad main persona response
    if main_pers.startswith('my favorite '):
        # Remove the original 'my favorite' prefix (note: main_pers was
        # lower-cased above, so the prefix must be matched in lower case).
        persona_core = main_pers[len('my favorite ') :]
        for phrase in ('i like', 'my favorite'):
            if phrase in persona_core:
                violations.append(PERSONA_REPEATS_PROMPT)
                break

# Extended persona that asks questions
for phrase in (QUESTION_PHRASE,):
if phrase in ext_pers:
violations.append(ASKED_WIZARD_QUESTION)

# Extended persona that mostly repeats the main persona
main_pers_tokens = tokenize_text(main_pers, stemmer)
ext_pers_tokens = tokenize_text(ext_pers, stemmer)
if len(ext_pers_tokens.difference(main_pers_tokens)) < 2:
violations.append(COPIED_EXTENDED_PERSONA)

    # Extended persona made up almost entirely of generic words.
common_phrases = ('i', 'it', 'like', 'very', 'much', 'favorite', 'is', 'am')
tokens = [w.strip() for w in ext_pers.split(' ') if w]
ext_useful_words = [t for t in tokens if t not in common_phrases]
if len(tokens) > 4 and len(ext_useful_words) < 2:
violations.append(GENERIC_EXTENDED_PERSONA)

return violations


def poor_knowledge_selection(messages, persona, stemmer, knwldg_ovlp_thrshld):
"""
Check for poor search and knowledge selection by wizard.
"""
# Collecting search and knowledge selections
search_terms = []
selected_knowledge = []
message_history_tokens = tokenize_text(persona, stemmer)

n_search_query_not_in_history = 0
for msg in messages:
if msg.get('text', None):
message_history_tokens = message_history_tokens.union(
tokenize_text(msg['text'], stemmer)
)

if msg['id'] != 'Wizard':
continue

selections = msg.get('task_data', {}).get('selected_text_candaidtes')
if not selections or selections[0][0]:
continue

search_query = msg['task_data']['search_query']
search_terms.append(search_query)
if message_history_tokens.isdisjoint(tokenize_text(search_query, stemmer)):
n_search_query_not_in_history += 1

selected_parts = []
for doc_id in range(1, len(selections)):
doc_selections = selections[doc_id]
for sentence_id in range(len(doc_selections)):
if doc_selections[sentence_id]:
selected_parts.append(
msg['task_data']['text_candidates'][doc_id - 1]['content'][
sentence_id
]
)

selected_knowledge.append(
{'text': msg['text'], 'knowledge': ' '.join(selected_parts)}
)

knowledge_length = []
knowledge_overlaps = []
for knwldg in selected_knowledge:
knowledge_tokens = tokenize_text(knwldg['knowledge'], stemmer)
knowledge_length.append(len(knowledge_tokens))

response_tokens = tokenize_text(knwldg['text'], stemmer)
knowledge_overlaps.append(overlap_ratios(knowledge_tokens, response_tokens))

violations = []

# Repeated the same search queries
if len(search_terms) - len(set(search_terms)) > 3:
violations.append(POOR_SEARCH_QUERIES)

# Search doesn't have overlap with message history
if n_search_query_not_in_history > 2:
violations.append(IRRELEVANT_SEARCH__QUERIES)

# No selection
if not knowledge_length:
violations.append(NOT_ENOUGH_SEARCH)

    # Only selecting short sentences (guard against averaging an empty list)
    if knowledge_length and np.average(knowledge_length) < 5:
        violations.append(SELECTED_SHORT_PIECES)

    # Small overlap between response and the selected knowledge parts
    if knowledge_overlaps:
        knowledge_overlap_avg = np.average(knowledge_overlaps)
        if knowledge_overlap_avg < knwldg_ovlp_thrshld:
            violations.append(f'{LOW_KNOWLEDGE_OVERLAP} ({knowledge_overlap_avg})')

return violations


class WizardOfInternetAcceptabilityChecker(AcceptabilityChecker):
"""
    ParlAI general acceptability checker customized for the Wizard of Internet task.
"""

def __init__(self):
self.knowledge_overlap_threshold = DEFAULT_KNOWLEDGE_OVERLAP_THRESHOLD
self.post_stemmer = PorterStemmer()
super().__init__()

def check_messages(
self,
agent_id: str,
persona: str,
        messages: List[dict],
is_worker_0: bool,
violation_types: Iterable[str] = (),
) -> str:
violations = []
general_chat_violations = super().check_messages(
self.get_conversation_messages(messages, agent_id),
is_worker_0,
violation_types,
)
if general_chat_violations:
violations.extend(general_chat_violations.split(','))

if agent_id == 'Apprentice':
persona_violations = bad_persona(persona, self.post_stemmer)
if persona_violations:
violations.extend(persona_violations)

if agent_id == 'Wizard':
knowledge_violations = poor_knowledge_selection(
messages, persona, self.post_stemmer, self.knowledge_overlap_threshold
)
if knowledge_violations:
violations.extend(knowledge_violations)

return ','.join(violations)

def get_conversation_messages(self, agent_messages, agent_id):
return [
msg['text']
for msg in agent_messages
if is_valid_agent_chat_message(msg, agent_id)
]
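For intuition, the `overlap_ratios` helper above is a Jaccard (intersection-over-union) similarity on stemmed token sets. A dependency-free sketch of the same computation, using plain whitespace tokenization in place of ParlAI's `normalize_answer` and the NLTK stemmer (the sample sentences are invented for illustration):

```python
def jaccard(a: set, b: set) -> float:
    # Intersection over union; the small constant keeps two empty sets at 0.0
    # instead of raising ZeroDivisionError.
    return len(a & b) / (len(a | b) + 0.001)


# A wizard response and the knowledge sentence it was (supposedly) grounded on.
response = set('i love hiking in the mountains'.split())
knowledge = set('the mountains are great for hiking'.split())

score = jaccard(response, knowledge)  # 3 shared tokens out of 9 distinct ones
```

With the default `DEFAULT_KNOWLEDGE_OVERLAP_THRESHOLD` of 0.05, a score like this (~0.33) would clear the `LOW_KNOWLEDGE_OVERLAP` check comfortably.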
42 changes: 42 additions & 0 deletions parlai/crowdsourcing/tasks/wizard_of_internet/conf/dev.yaml
@@ -0,0 +1,42 @@

#@package _global_
mephisto:
blueprint:
onboarding_qualification: wizard-onboarding-dev
block_qualification: wizard-block-dev
role_qualification: wizard-role-trained-dev
custom_source_dir: ${task_dir}/webapp
world_file: ${task_dir}/worlds.py
task_description_file: ${task_dir}/task_description.html
num_conversations: 1
min_turns: 4
wizard_time_out: 180
apprentice_time_out: 120
search_warning_turn: 2
search_warning_threshold: 1
select_warning_turn: 3
select_warning_threshold: 1
personas_file: "${task_dir}/data/sample_personas.txt"
persona_counts_file: "${task_dir}/data/persona_use_count.txt"
banned_words_file: "${task_dir}/data/bad_words.txt"
max_times_persona_use: 1
locations_file: "${task_dir}/data/sample_locations.txt"
use_personas_with_replacement: true
shuffle_persona: false
search_server: "http://localhost:3005/search_server"

task:
task_name: wizard-of-internet-dev
task_title: "Have a knowledgeable conversation!"
task_description:
"In this task, you will have a conversation with a chat partner.
One of you will play the role of a given character description,
and will discuss your interests.
The other of you will use information from the internet
to discuss your partner's interests in depth."
task_reward: 2.0
task_tags: "chat,dialog"
assignment_duration_in_seconds: 600

mturk:
worker_blocklist_paths: ""