diff --git a/parlai/crowdsourcing/tasks/wizard_of_internet/README.md b/parlai/crowdsourcing/tasks/wizard_of_internet/README.md new file mode 100644 index 00000000000..3b9db71a96e --- /dev/null +++ b/parlai/crowdsourcing/tasks/wizard_of_internet/README.md @@ -0,0 +1,13 @@ +# Wizard of Internet + +This is the crowdsourcing task from the Internet-Augmented Dialogue Generation paper ([link](https://arxiv.org/abs/2107.07566)). +It uses the [Mephisto](https://github.com/facebookresearch/Mephisto) platform to collect dialogue data using human workers on Amazon Mechanical Turk. + +## How to use +Having set up your ParlAI and Mephisto environment properly (make sure you can run Mephisto demos), you should be able to run this task easily. Most of the configuration for running the task is in the `conf/dev.yaml` file. Note the files needed in the `data` directory: +*sample_personas.txt* and *sample_locations.txt* are needed to create the curated personas. + +You need to have a functional search server running, and set its address in `search_server` in the `conf/dev.yaml` file. You may set the server up to search the internet or any knowledge source of your choosing. +This server responds to the search requests sent by the worker who takes the *wizard* role during this task: +It receives a JSON object with two keys: `q` and `n`, which are a string that is the search query, and an integer that is the number of pages to return, respectively. +It sends its response also as JSON under a key named `response` which has a list of documents retrieved for the received search query. Each document is a mapping (dictionary) of *string->string* with at least 3 fields: `url`, `title`, and `content` (see [SearchEngineRetriever](https://github.com/facebookresearch/ParlAI/blob/70ee4a2c63008774fc9e66a8392847554920a14d/parlai/agents/rag/retrieve_api.py#L73) for more info on how this task interacts with the search server). 
diff --git a/parlai/crowdsourcing/tasks/wizard_of_internet/__init__.py b/parlai/crowdsourcing/tasks/wizard_of_internet/__init__.py new file mode 100644 index 00000000000..240697e3247 --- /dev/null +++ b/parlai/crowdsourcing/tasks/wizard_of_internet/__init__.py @@ -0,0 +1,5 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. diff --git a/parlai/crowdsourcing/tasks/wizard_of_internet/acceptability.py b/parlai/crowdsourcing/tasks/wizard_of_internet/acceptability.py new file mode 100644 index 00000000000..6026264c168 --- /dev/null +++ b/parlai/crowdsourcing/tasks/wizard_of_internet/acceptability.py @@ -0,0 +1,236 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import numpy as np +from typing import Iterable, List +from nltk.stem import PorterStemmer +from parlai.crowdsourcing.utils.acceptability import ( + AcceptabilityChecker, + normalize_answer, +) +import parlai.utils.logging as logging + + +# Bad persona violations +PERSONA_REPEATS_PROMPT = 'repeated the prompt text' +ASKED_WIZARD_QUESTION = 'asked wizard in the persona details' +COPIED_EXTENDED_PERSONA = 'extended persona copies the main persona' +GENERIC_EXTENDED_PERSONA = 'extended persona is generic' + +QUESTION_PHRASE = 'what is your' + +# Wizard knowledge violations +DEFAULT_KNOWLEDGE_OVERLAP_THRESHOLD = 0.05 + +POOR_SEARCH_QUERIES = 'poor search queries' +IRRELEVANT_SEARCH__QUERIES = 'irrelevant search terms' +NOT_ENOUGH_SEARCH = 'not enough selected knowledge sources' +SELECTED_SHORT_PIECES = 'short knowledge pieces selected.' 
+LOW_KNOWLEDGE_OVERLAP = 'low knowledge overlap' + + +def tokenize_text(text, stemmer, as_set=True): + text = normalize_answer(text) + tokens = [stemmer.stem(word) for word in text.split(' ')] + if as_set: + tokens = set(tokens) + return tokens + + +def overlap_ratios(a: set, b: set) -> float: + """ + Calculates the Jacard distance between two sets. + """ + overlap = a.intersection(b) + union = a.union(b) + return len(overlap) / (len(union) + 0.001) + + +def is_valid_agent_chat_message(message, agent_id): + return ( + message.get('text') + and message.get('id') == agent_id + and not message.get('is_search_query', False) + ) + + +def bad_persona(persona, stemmer): + """ + Check for poor persona selection by apprentice. + """ + persona_parts = persona.split('\n') + + # It is not from the persona selection ones (personas used during the pilot). + if not ( + len(persona_parts) == 2 + or (len(persona_parts) == 3 and 'I live in ' in persona_parts[0]) + ): + logging.warning(f'Old fashioned persona: {persona}') + return + + # Removing the location ('I live in X') part + if len(persona_parts) == 3: + persona_parts = persona_parts[1:] + + main_pers, ext_pers = [p.lower() for p in persona_parts] + + violations = [] + + # Bad main persona response + if main_pers.startswith('My favorite '): + for phrase in ('i like', 'my favorite'): + persona_core = main_pers + # Remove the original My favorite + persona_core = main_pers[len('My favorite ') :] + if phrase in persona_core.lower(): + violations.append(PERSONA_REPEATS_PROMPT) + break + + # Extended persona that asks questions + for phrase in (QUESTION_PHRASE,): + if phrase in ext_pers: + violations.append(ASKED_WIZARD_QUESTION) + + # Extended persona that mostly repeats the main persona + main_pers_tokens = tokenize_text(main_pers, stemmer) + ext_pers_tokens = tokenize_text(ext_pers, stemmer) + if len(ext_pers_tokens.difference(main_pers_tokens)) < 2: + violations.append(COPIED_EXTENDED_PERSONA) + + # Use of non-generic words in 
persona. + common_phrases = ('i', 'it', 'like', 'very', 'much', 'favorite', 'is', 'am') + tokens = [w.strip() for w in ext_pers.split(' ') if w] + ext_useful_words = [t for t in tokens if t not in common_phrases] + if len(tokens) > 4 and len(ext_useful_words) < 2: + violations.append(GENERIC_EXTENDED_PERSONA) + + return violations + + +def poor_knowledge_selection(messages, persona, stemmer, knwldg_ovlp_thrshld): + """ + Check for poor search and knowledge selection by wizard. + """ + # Collecting search and knowledge selections + search_terms = [] + selected_knowledge = [] + message_history_tokens = tokenize_text(persona, stemmer) + + n_search_query_not_in_history = 0 + for msg in messages: + if msg.get('text', None): + message_history_tokens = message_history_tokens.union( + tokenize_text(msg['text'], stemmer) + ) + + if msg['id'] != 'Wizard': + continue + + selections = msg.get('task_data', {}).get('selected_text_candaidtes') + if not selections or selections[0][0]: + continue + + search_query = msg['task_data']['search_query'] + search_terms.append(search_query) + if message_history_tokens.isdisjoint(tokenize_text(search_query, stemmer)): + n_search_query_not_in_history += 1 + + selected_parts = [] + for doc_id in range(1, len(selections)): + doc_selections = selections[doc_id] + for sentence_id in range(len(doc_selections)): + if doc_selections[sentence_id]: + selected_parts.append( + msg['task_data']['text_candidates'][doc_id - 1]['content'][ + sentence_id + ] + ) + + selected_knowledge.append( + {'text': msg['text'], 'knowledge': ' '.join(selected_parts)} + ) + + knowledge_length = [] + knowledge_overlaps = [] + for knwldg in selected_knowledge: + knowledge_tokens = tokenize_text(knwldg['knowledge'], stemmer) + knowledge_length.append(len(knowledge_tokens)) + + response_tokens = tokenize_text(knwldg['text'], stemmer) + knowledge_overlaps.append(overlap_ratios(knowledge_tokens, response_tokens)) + + violations = [] + + # Repeated the same search queries + if 
len(search_terms) - len(set(search_terms)) > 3: + violations.append(POOR_SEARCH_QUERIES) + + # Search doesn't have overlap with message history + if n_search_query_not_in_history > 2: + violations.append(IRRELEVANT_SEARCH__QUERIES) + + # No selection + if not knowledge_length: + violations.append(NOT_ENOUGH_SEARCH) + + # Only selecting short sentences + if np.average(knowledge_length) < 5: + violations.append(SELECTED_SHORT_PIECES) + + # Small overlap between response and the selected knowledge parts + knowledge_overlap_avg = np.average(knowledge_overlaps) + if knowledge_overlap_avg < knwldg_ovlp_thrshld: + violations.append(f'{LOW_KNOWLEDGE_OVERLAP} ({knowledge_overlap_avg})') + + return violations + + +class WizardOfInternetAcceptabilityChecker(AcceptabilityChecker): + """ + ParlAI general acceptabilty checker customized for the wizard of internet. + """ + + def __init__(self): + self.knowledge_overlap_threshold = DEFAULT_KNOWLEDGE_OVERLAP_THRESHOLD + self.post_stemmer = PorterStemmer() + super().__init__() + + def check_messages( + self, + agent_id: str, + persona: str, + messages: List[str], + is_worker_0: bool, + violation_types: Iterable[str] = (), + ) -> str: + violations = [] + general_chat_violations = super().check_messages( + self.get_conversation_messages(messages, agent_id), + is_worker_0, + violation_types, + ) + if general_chat_violations: + violations.extend(general_chat_violations.split(',')) + + if agent_id == 'Apprentice': + persona_violations = bad_persona(persona, self.post_stemmer) + if persona_violations: + violations.extend(persona_violations) + + if agent_id == 'Wizard': + knowledge_violations = poor_knowledge_selection( + messages, persona, self.post_stemmer, self.knowledge_overlap_threshold + ) + if knowledge_violations: + violations.extend(knowledge_violations) + + return ','.join(violations) + + def get_conversation_messages(self, agent_messages, agent_id): + return [ + msg['text'] + for msg in agent_messages + if 
is_valid_agent_chat_message(msg, agent_id) + ] diff --git a/parlai/crowdsourcing/tasks/wizard_of_internet/conf/dev.yaml b/parlai/crowdsourcing/tasks/wizard_of_internet/conf/dev.yaml new file mode 100644 index 00000000000..083ff0eb5db --- /dev/null +++ b/parlai/crowdsourcing/tasks/wizard_of_internet/conf/dev.yaml @@ -0,0 +1,42 @@ + +#@package _global_ +mephisto: + blueprint: + onboarding_qualification: wizard-onboarding-dev + block_qualification: wizard-block-dev + role_qualification: wizard-role-trained-dev + custom_source_dir: ${task_dir}/webapp + world_file: ${task_dir}/worlds.py + task_description_file: ${task_dir}/task_description.html + num_conversations: 1 + min_turns: 4 + wizard_time_out: 180 + apprentice_time_out: 120 + search_warning_turn: 2 + search_warning_threshold: 1 + select_warning_turn: 3 + select_warning_threshold: 1 + personas_file: "${task_dir}/data/sample_personas.txt" + persona_counts_file: "${task_dir}/data/persona_use_count.txt" + banned_words_file: "${task_dir}/data/bad_words.txt" + max_times_persona_use: 1 + locations_file: "${task_dir}/data/sample_locations.txt" + use_personas_with_replacement: true + shuffle_persona: false + search_server: "http://localhost:3005/search_server" + + task: + task_name: wizard-of-internet-dev + task_title: "Have a knowledgeable conversation!" + task_description: + "In this task, you will have a conversation with a chat partner. + One of you will play the role of a given character description, + and will discuss your interests. + The other of you will use information from the internet + to discuss your partner's interests in depth." 
+ task_reward: 2.0 + task_tags: "chat,dialog" + assignment_duration_in_seconds: 600 + +mturk: + worker_blocklist_paths: "" \ No newline at end of file diff --git a/parlai/crowdsourcing/tasks/wizard_of_internet/constants.py b/parlai/crowdsourcing/tasks/wizard_of_internet/constants.py new file mode 100644 index 00000000000..8ca681314c6 --- /dev/null +++ b/parlai/crowdsourcing/tasks/wizard_of_internet/constants.py @@ -0,0 +1,309 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# Possible roles of an agent during task +NO_ROLE = 0 +WIZARD = 1 +APPRENTICE = 2 +IN_TRAINING = 10 +WIZARD_IN_TRAINING = WIZARD + IN_TRAINING +APPRENTICE_IN_TRAINING = APPRENTICE + IN_TRAINING +# The role_id to role_name mapping +ROLE_NAMES = {WIZARD: 'Wizard', APPRENTICE: 'Apprentice'} + +# The keys to get agent qualification data from opt. +SAVED_DATA_WORKER_KEY = 'worker' +SAVED_DATA_IS_WIZARD_KEY = 'is_wizard' +SAVED_DATA_ROLE_QUALIFICATION_DATA_KEY = 'qualification_dict' +ROLE_QUALIFICATION_NAME_KEY = 'role_qname' + +# OnBoardingSteps +# NOTE: Make sure these number are consistent with OnboardingSteps, +# as they are defined in the SidePane.jsx frontend file. +ONBOARDING_STEPS = { + 'NOT_ONBOARDING': 0, + 'CHAT_INTERFACE': 1, + 'TRY_SEARCH': 2, + 'PERSONA_WIZARD': 3, + 'PERSONA_APPRENTICE': 4, + 'WAITING': 10, +} + +# Name of (bot)agents involved in the task world +ONBOARDING_AGENT = 'OnboardingBot' +PERSONA_AGENT = 'PersonaAgent' +SEARCH_AGENT = 'SearchAgent' +COORDINATOR_AGENT = 'Coordinator' + +# NOTE: do not forget to change ONBOARDING_PERSONA_KEYWORDS below if changing ONBOARDING_PERSONA. +# During the onboarding we are checking for worker responses to have sufficient overlap with the +# list of words in ONBOARDING_PERSONA_KEYWORDS, to ensure they are talking about something relevant +# to the persona topic. 
Thus, changing ONBOARDING_PERSONA means you need to come up with a relevant +# list of keywords for it in ONBOARDING_PERSONA_KEYWORDS. +ONBOARDING_PERSONA = 'I do yoga on beach every morning.' +# The keywords related to the Onboarding Persona +ONBOARDING_PERSONA_KEYWORDS = ( + 'beach', + 'exercise', + 'gym', + 'healthy', + 'lake', + 'meditat', + 'morning', + 'ocean', + 'outdoor', + 'peace', + 'pose', + 'relax', + 'sea', + 'sport', + 'stress', + 'sunrise', + 'yoga', +) + + +# The wait time in seconds to allow the agents read the instructions during the onboarding. +# After this, we allow them to continue after a small action (for example, type anything). +# The keys are the onboarding tutorial step; values are the wait times corresponding to that. +TUTORIAL_WAIT_TIMES = {'chat-interface': 1, 'persona': 2, 'knowledge': 2} + +# Constants for checking onboarding work quality +WORKER_REJECT_REASON = 'reason_to_reject' +MIN_AVG_CHAR_LENGTH_UTTERANCES = 10 +MIN_AVG_WORD_LENGTH_UTTERANCES = 5 +MIN_NUM_SEARCH_ONBOARDING = 2 +MIN_NUM_SELECTED_SENTENCES_ONBOARDING = 2 + +# A prefix token for curated personas that means this persona requires a location. +# We assign a random location (from a list of cities in US) to this persona. +PERSONA_NEEDS_LOCATION_TOKEN = '*' +PROBABILITY_CHOOSING_TEMPLATE_PERSONA = 0.7 +# Number of topics that its items are shown to agent to pick persona +CURATED_PERSONA_CHOICES = 3 +TEMPLATE_PERSONAS_CHOICES = 2 +# Persona template items bundled based on topic +TEMPLATE_PERSONAS_TOPICS = [ + 'fashion brand,fashion designer,clothing type', + 'book,author', + 'artist,music band,song,singer', + 'tv show,movie,actor,director', + 'sports team,athlete', + 'hobby,game', + 'item to buy,item recently bought', +] +PERSONA_EXPANSION_MIN_LEN_CHAR = 20 + +# Controlling the number of retrieved docs +NUM_RETRIEVED_SEARCH_NEWS = 2 +NUM_RETRIEVED_SEARCH_DOCS = 5 + +# The time (in second) for the cached role counts to be considered fresh. 
+# Updating this count requires quering the Database, thus is slow. +TALLY_CACHE_TIMEOUT = 10 + +# Long messages +ONBOARDING_WELCOME = ( + 'Welcome onboard!\n' + 'Here you will have an engaging, ' + 'knowledgeable chat with another person. ' + 'This is the chat interface you will be using.\n' + 'Our interactive tutorial introduces you to the main task. ' + 'If you finish all the steps successfully, ' + 'and in reasonable time, we redirect you to the main task.\n' + 'Please have a friendly chitchat pretending you live in a ' + 'world unaffected by covid and recent controversial events.' +) + +ONBOARDING_ACKNOWLEDGE_UNDERSTOOD = ( + 'Please acknowledge that this is clear ' + 'in your response message ' + '(for example, type \'I understand.\' in response.)' +) + +FINISHED_ONBOARDING = ( + 'Good job, you now know how this task works!\n' + 'You can check the task instructions on the left at any time ' + 'during the task. Please wait while we pair ' + 'you with another participant.' +) + +WIZARD_INTRODUCE_KNOWLEDGE = ( + 'During this chat you must pretend that you are a knowledgeable ' + 'entity with conversational ability rather than a human being ' + '(imagine a digital friend on a smartphone).' + 'So you can talk about the world, but your character is NOT able to ' + 'engage in physical activities such as sport activities or eating.' +) + +WIZARD_INTRODUCE_SEARCH = ( + 'We will provide a search bar for you ' + 'to look up useful knowledge about topics that interest ' + 'your chat partner during the conversation.\n' + 'You may try search as many times as you may like ' + 'to find useful information that helps you craft ' + 'engaging and informative messages.\n' + 'Please conduct a natural conversation and avoid copy/paste.' +) + +WIZARD_TRY_SEARCH = ( + 'See the blinking area (in the left panel) ' + 'for the search bar you will be using during this task. ' + 'During the task, when you use this search bar, ' + 'it will bring up a number of articles from the internet. 
' + 'You can click on an article to show ' + 'it\'s content, that is split into sentences. ' + 'Use information from these sentences ' + 'to have an informed conversation.\n\n' + 'When you use knowledge from one or more sentences, ' + 'please select them (click the checkbox next to those ' + 'sentences) before sending your message.\n' + 'If you do not use any knowledge from search results, ' + 'select the checkbox for ' + '"Did not use search results for this message."\n\n' + 'Now try out the search functionality to ' + 'craft a message with information on a topic of your choise ' + '(Yoga, sushi, Star wars, anything you choose). ' + 'Here are the steps :\n' + ' 1- use the bar to search.\n' + ' 2- check the search results for finding useful information.\n' + ' 3- write your message using knowledge you find in the search results.\n' + ' 4- make sure you select the checkmark for sentences you used.\n' + ' 5- send the message.' +) + +WIZARD_INTRODUCE_APPRENTICE_PERSONA = ( + 'You can see your partner\'s assigned persona ' + 'description in the left pane (see the blinking box). ' + 'The purpose of the task is to have an in-depth conversation ' + 'with your chat partner about THEIR assigned interests.\n' + 'It is very important to keep in mind that this is a chitchat: ' + 'unless it is necessary, do NOT bring up random facts in the middle of conversation. ' + 'For example, if your chat partner likes a music band ' + 'do not keep talking about band members names or birthdays.\n\n' + 'Use your search bar on the left and craft a message ' + 'that interests your partner, based on their persona, ' + 'using information you find on internet.\n' + 'Don\'t forget to select the sentences from the ' + 'search results that helped you craft that message.' +) + +WIZARD_PERSONA_EMPHASIZE = ( + 'Don\'t forget the focus of this conversation is the interests of your partner (not you). 
' + 'Do NOT talk about yourself or your interests and activities; ' + 'talk about theirs (you will see their interests in the blue box in the left panel). ' + 'Have an engaging and knowledgeable chitchat 😀, but avoid sending random or boring facts about the topic. ' + 'For example, if your partner likes Mount Everest, DO NOT say things such as ' + '"Did you know Mount Everest is Earth\'s highest mountain." or ' + '"Its elevation is 8,848.86 meters from the sea level" as this is dull. 😒' +) + +WIZARD_STARTING_INSTRUCTION = ( + 'Please begin the conversation ' + 'by discussing one of your partner’s interests. ' + 'For example, if your partner likes tennis, ' + 'you might discuss whether Roger Federer is better than Rafael Nadal.' +) + +APPRENTICE_INTRODUCE_PERSONA = ( + 'At the beginning of this task we will ask you to ' + 'choose a persona for yourself. ' + 'We keep your selected persona in the left pane ' + '(See the example persona inside the blinking box).\n' + 'During this chat you play the role of someone with that persona. ' + 'The purpose of the task is to have ' + 'an in-depth conversation with your chat partner ' + 'about the interests of someone with your assigned persona.' +) + +APPRENTICE_INTRODUCE_WIZARD = ( + 'Imagine your chat partner is a non-human entity ' + 'you can chat to, for example a digital friend living inside ' + 'your phone. So you can ask their opinion about the world, ' + 'but they are not able to do physical activities, ' + 'such as playing basketball or eating. Don\'t forget that ' + 'the conversation should focus on the interests ' + 'of the persona that you play during this task.' +) + +APPRENTICE_INTRODUCE_WIZARD_KNOWLEDGE = ( + 'Your chat partner has extensive knowledge ' + 'about many things, and access to lots of information.\n' + 'Your partner will strive to enlighten you ' + 'about your topics of interest, according to your persona. ' + 'Feel free to dive deep discussing these topics.' 
+) + +APPRENTICE_PERSONA_ROLE_INSTRUCTION = ( + 'Let\'s assume you are in the main task and you ' + 'have the example persona that we show now ' + '(blue box on the left). ' + 'Please say something interesting about ' + 'your role\'s persona to continue. ' + 'Don\'t forget to assume the role of someone with ' + 'that persona for the rest of this task.' +) + +APPRENTICE_CHITCHAT_INSTRUCTION = ( + 'Imagine you are in the main chat. Go ahead and send them a ' + 'chitchat message, assuming your assigned persona.' +) + +APPRENTICE_PERSONA_MSG_INSTRUCTION = ( + 'Go ahead and try writing a message ' + 'about your example role\'s persona to your partner. ' + 'You may even ask them questions if you want.' +) + +APPRENTICE_CHOOSE_PERSONA_TEMPLATE_REQUEST = ( + 'Please use the form below to define a persona for the character that ' + 'you will be playing during this task: ' + 'use the first two fields in this form to define an interest for ' + 'the persona. Then add a sentence to refine it ' + 'and to make it more interesting or engaging. ' + 'Create an interesting character that you want to play. ' + 'Remember that the main topic of conversation should be around ' + 'this persona. Be creative and imaginative.' +) + +APPRENTICE_CHOOSE_CURATED_PERSONA_REQUEST = ( + 'Please use the form below to define a persona for the role that ' + 'you will be playing during this task: ' + 'choose the first characteristic and then add a sentence to refine it ' + 'and to make it more interesting or engaging. Remember that ' + 'the main topic of conversation should be around this persona. ' + 'Be creative and imaginative.\n' + 'For example, if you choose "I like swimming", ' + 'you may add "I recently won a national medal in Butterfly Stroke".' +) + +APPRENTICE_STARTING_INSTRUCTION = ( + 'Please assume the character of the person you see in the left panel ' + '(play that role), and have an engaging conversation. 
' + 'For example, if your character likes tennis, ' + 'you might discuss whether Roger Federer is better than Rafael Nadal.' +) + +USE_SEARCH_WARNING_MESSAGE = ( + 'Please try to use the search bar more often to look up ' + 'useful information about the interests of your chat partner.' +) + +USE_SEARCH_RESULTS_WARNING_MESSAGE = ( + 'Make sure you search for finding relevant information, ' + 'and select the sentences from search results for ' + 'crafting your messages more often.' +) + +# List of reasons for flagging low quality work. These are common reasons +# handled by parlai.crowdsourcing.utils.acceptability.AcceptabilityChecker module +ACCEPTABILITY_VIOLATIONS = ( + 'all_caps', + 'exact_match', + 'safety', + 'min_words', + 'penalize_greetings', +) diff --git a/parlai/crowdsourcing/tasks/wizard_of_internet/run.py b/parlai/crowdsourcing/tasks/wizard_of_internet/run.py new file mode 100644 index 00000000000..f0b1e5489ab --- /dev/null +++ b/parlai/crowdsourcing/tasks/wizard_of_internet/run.py @@ -0,0 +1,310 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import random +import os +from collections import defaultdict +import hydra +from omegaconf import DictConfig +from dataclasses import dataclass, field +from typing import List, Dict, Any + +from parlai.crowdsourcing.tasks.wizard_of_internet import constants +from parlai.crowdsourcing.tasks.wizard_of_internet.wizard_internet_blueprint import ( + WIZARD_INTERNET_PARLAICHAT_BLUEPRINT, +) +from parlai.crowdsourcing.utils.mturk import MTurkRunScriptConfig +import parlai.utils.logging as logging + +from mephisto.abstractions.databases.local_database import LocalMephistoDB +from mephisto.abstractions.blueprints.parlai_chat.parlai_chat_blueprint import ( + SharedParlAITaskState, +) +from mephisto.operations.operator import Operator +from mephisto.operations.hydra_config import register_script_config +from mephisto.tools.scripts import load_db_and_process_config + + +TASK_DIRECTORY = os.path.dirname(os.path.abspath(__file__)) + +defaults = [ + {'mephisto/blueprint': WIZARD_INTERNET_PARLAICHAT_BLUEPRINT}, + {'mephisto/architect': 'local'}, + {'mephisto/provider': 'mock'}, + {'conf': 'dev'}, +] + + +@dataclass +class ScriptConfig(MTurkRunScriptConfig): + defaults: List[Any] = field(default_factory=lambda: defaults) + task_dir: str = TASK_DIRECTORY + turn_timeout: int = field( + default=300, + metadata={ + 'help': 'Maximum response time before kicking ' + 'a worker out, default 300 seconds' + }, + ) + + +register_script_config(name='scriptconfig', module=ScriptConfig) + + +def load_apprentice_persona_list(personas_fpath: str, shuffle: bool): + """ + Reads a list of curated apprentice personas. + """ + logging.info('Loading personas.') + with open(personas_fpath, 'r') as pf: + personas = [p.strip() for p in pf if p.strip()] + logging.info(f'{len(personas)} personas loaded.') + if shuffle: + random.shuffle(personas) + return personas + + +def load_previously_used_personas_counts(fpath: str): + """ + Loads an existing count for how many times each persona was used. 
+ + This is useful if the task was restarted after some initial data collection. + """ + logging.info('Loading the previous runs persona counts.') + personas_count = defaultdict(int) + try: + with open(fpath, 'r') as fi: + for pl in fi: + if not pl.strip(): + continue + persona, count = pl.strip().lower().split(';') + personas_count[persona.strip()] = int(count) + except FileNotFoundError: + logging.info( + f'Persona count file not found in {fpath}. Starting new persona use counter.' + ) + + logging.info(f'{len(personas_count)} previously used persona counts loaded.') + return personas_count + + +def get_persona_locations(locations_fpath: str): + """ + Reads a list of locations. + """ + locations = [] + logging.info('Loading the locations file.') + with open(locations_fpath) as lf: + for line in lf: + s = line.strip() + if not s: + continue + locations.append(s) + logging.info(f'{len(locations)} location loaded') + return locations + + +def remove_overused_persona( + personas: List[str], persona_use_count: Dict[str, int], max_persona_use: int +): + """ + Removes personas that were used too often from the list of personas. + """ + if not max_persona_use or not persona_use_count: + return personas + cleaned_personas = [] + for p in personas: + if persona_use_count[p.lower()] < max_persona_use: + cleaned_personas.append(p) + logging.info( + f'{len(cleaned_personas)} out of {len(personas)} personas accepted for use, ' + f'based on use count being less than maximum allowed of {max_persona_use}' + ) + return cleaned_personas + + +def get_world_opt(config: DictConfig): + """ + Generates the main chat world opt from Mephisto config. 
+ """ + blueprint_data = config.mephisto.blueprint + previous_personas_count = load_previously_used_personas_counts( + blueprint_data.persona_counts_file + ) + num_max_persona_use = blueprint_data.max_times_persona_use + personas = load_apprentice_persona_list( + blueprint_data.personas_file, blueprint_data.shuffle_persona + ) + personas = remove_overused_persona( + personas, previous_personas_count, num_max_persona_use + ) + locations = get_persona_locations(blueprint_data.locations_file) + return { + 'send_task_data': True, + 'min_turns': blueprint_data.min_turns, + 'wizard_time_out': blueprint_data.wizard_time_out, + 'apprentice_time_out': blueprint_data.apprentice_time_out, + 'search_warning_turn': blueprint_data.search_warning_turn, + 'search_warning_threshold': blueprint_data.search_warning_threshold, + 'select_warning_turn': blueprint_data.select_warning_turn, + 'select_warning_threshold': blueprint_data.select_warning_threshold, + 'personas': personas, + 'prev_persona_count': previous_personas_count, + 'max_times_persona_use': num_max_persona_use, + 'locations': locations, + 'pick_persona_with_replacement': blueprint_data.use_personas_with_replacement, + 'search_server': blueprint_data.search_server, + 'num_passages_retrieved': blueprint_data.num_passages_retrieved, + 'soft_block_qname': blueprint_data.block_qualification, + constants.ROLE_QUALIFICATION_NAME_KEY: blueprint_data.role_qualification, + } + + +def get_onboarding_world_opt(config: DictConfig): + """ + Generates onboarding world opt from Mephisto config. 
+ """ + blueprint_data = config.mephisto.blueprint + return { + 'wizard_time_out': blueprint_data.wizard_time_out, + 'apprentice_time_out': blueprint_data.apprentice_time_out, + 'send_task_data': False, + 'is_onboarding': True, + 'search_server': blueprint_data.search_server, + 'num_passages_retrieved': blueprint_data.num_passages_retrieved, + 'onboarding_qualification': blueprint_data.onboarding_qualification, + constants.ROLE_QUALIFICATION_NAME_KEY: blueprint_data.role_qualification, + } + + +def get_worker_eval_function(role_qname: str, onboarding_qname: str): + """ + Returns the callback function that is used for checking worker qualification. + + Check `worker_can_do_unit` of `SharedTaskState` in Mephisto. + """ + + def worker_eval_function(worker, unit): + """ + Checks the worker qualification for the task, based on their existing records. + """ + worker_qualification = worker.get_granted_qualification(role_qname) + if not worker_qualification: + # has not done any onboarding training yet + logging.debug('Worker does not have any qualifications (new worker).') + return True + + qualification_status = worker_qualification.value + logging.debug(f'Worker role qualification is {qualification_status}') + if qualification_status in ( + constants.WIZARD_IN_TRAINING, + constants.APPRENTICE_IN_TRAINING, + ): + # The agent had started the onboarding training but was not finished + onboarding_qual = worker.get_granted_qualification(onboarding_qname) + return not onboarding_qual or not onboarding_qual.value + + # The agent has successfully finished the onboarding training + if unit.unit_index == 0: + return qualification_status == constants.WIZARD + else: + return qualification_status == constants.APPRENTICE + + return worker_eval_function + + +def check_role_training_qualification( + db: LocalMephistoDB, qname: str, requester_name: str +): + """ + Initializes the qualification name in DB, if it does not exist. 
+ """ + + logging.info(f'Checking for "{qname}"" qualification.') + if not db.find_qualifications(qname): + logging.info('Creating the qualification.') + db.make_qualification(qname) + reqs = db.find_requesters(requester_name=requester_name, provider_type='mturk') + requester = reqs[-1] + requester._create_new_mturk_qualification(qname) + else: + logging.info('Qualification exists.') + + +def update_persona_use_counts_file( + fptah: str, counts: Dict[str, int], sorted_order=True +): + """ + Writes the persona use counts to file. + + This is to keep track of use counts for the next time that the task was restarted. + See `load_previously_used_personas_counts` function above. + """ + logging.info(f'Writting new persona counts to {fptah}') + items = counts.items() + if sorted_order: + items = sorted(items, key=lambda x: x[1], reverse=True) + saved_count = 0 + with open(fptah, 'w') as fo: + for p, c in items: + if c > 0: + saved_count += 1 + fo.write(f'{p} ; {c}\n') + logging.info(f'Saved {saved_count} recent persona counts successfully.') + + +def add_banned_words_frontend_conf(task_state, fpath: str = None): + """ + Adds the list of banned words to the task config to be used later in the frontend. + + It reads the text file specified in fpath to populate a list banned words. Then adds + this list to Mephisto `task_config` to make it accessible for the front-end app. The + file specified by `fpath` is a plain text file where each line contains a single + banned word/phrase. 
+ """ + banned_words = [] + if fpath and os.path.exists(fpath): + with open(fpath, 'r') as fin: + banned_words.extend([w.strip().lower() for w in fin if w.strip()]) + + task_state.task_config['bannedWords'] = banned_words + + +@hydra.main(config_name='scriptconfig') +def main(cfg: DictConfig) -> None: + db, cfg = load_db_and_process_config(cfg) + world_opt = get_world_opt(cfg) + onboarding_world_opt = get_onboarding_world_opt(cfg) + shared_state = SharedParlAITaskState( + world_opt=world_opt, onboarding_world_opt=onboarding_world_opt + ) + + check_role_training_qualification( + db=db, + qname=world_opt[constants.ROLE_QUALIFICATION_NAME_KEY], + requester_name=cfg.mephisto.provider.requester_name, + ) + + shared_state.task_config['minTurns'] = world_opt['min_turns'] + shared_state.task_config['onboardingPersona'] = constants.ONBOARDING_PERSONA + shared_state.worker_can_do_unit = get_worker_eval_function( + world_opt[constants.ROLE_QUALIFICATION_NAME_KEY], + onboarding_world_opt['onboarding_qualification'], + ) + + banned_words_fpath = cfg.mephisto.blueprint.banned_words_file + add_banned_words_frontend_conf(shared_state, banned_words_fpath) + + operator = Operator(db) + operator.validate_and_run_config(cfg.mephisto, shared_state) + operator.wait_for_runs_then_shutdown(skip_input=True, log_rate=300) + update_persona_use_counts_file( + cfg.mephisto.blueprint.persona_counts_file, world_opt['prev_persona_count'] + ) + + +if __name__ == '__main__': + main() diff --git a/parlai/crowdsourcing/tasks/wizard_of_internet/task_description.html b/parlai/crowdsourcing/tasks/wizard_of_internet/task_description.html new file mode 100644 index 00000000000..1ddee0e7226 --- /dev/null +++ b/parlai/crowdsourcing/tasks/wizard_of_internet/task_description.html @@ -0,0 +1,57 @@ +

Description

+In this task, you will have a conversation with another person. +You both see a character description that we randomly assign to one of you; +that person will assume that personality for the duration of this chat. +The other person has access to internet search and can use that to find +information about the topics that interest the first person (given their personality description). +The goal of this task is to go into depth about the topics that interest the first person, +based on their assigned personality. +Please chitchat pretending you live in a world WITHOUT COVID. +
+

Sample Character description for Person 1

+I love board games. +
+

Sample Conversation

+Person 1: Hi! I spent yesterday finding a new board game to try. +
+Person 2: Oo, what type of board games? +
+Person 1: I like strategy games, especially ones that are sci-fi +
+Person 2: I love Risk, but it takes place on earth, so not sci-fi, +and it takes forever +
+Person 1: Right? How do you feel about cards against humanity? +
+Person 2: Cards against humanity is fun but a little too risque for me +
+
+ +

Reward/Bonus

+We will reward engaging and knowledgeable chats with a bonus. +
+
+

Close Window/Timeout/Return HIT

+Once the conversation has started, closing the window, timing out, or returning the HIT during the +chat will result in +HIT EXPIRED for you and NO reward paid. +
+You need to continue the conversation for at least 5 rounds. +
+

Important Notice

+1. Be aware that the conversations you have will be made public, so act as you +would on a public social network such as Twitter. +
+2. Please do not send long messages: messages cannot exceed 30 words. +
+3. Please do not reference the task or MTurk itself during the conversation, +but speak naturally to the other person. +
+4. Please do not send any message that could make others uncomfortable, +including any level of discrimination, racism, sexism and offensive +religious/political comments; otherwise the submission will be rejected. +
+
+

If you are ready, please click "Accept HIT" to start this task.

+
+
diff --git a/parlai/crowdsourcing/tasks/wizard_of_internet/webapp/components/InfoPanel.jsx b/parlai/crowdsourcing/tasks/wizard_of_internet/webapp/components/InfoPanel.jsx new file mode 100644 index 00000000000..664b00c883b --- /dev/null +++ b/parlai/crowdsourcing/tasks/wizard_of_internet/webapp/components/InfoPanel.jsx @@ -0,0 +1,98 @@ +/* + * Copyright (c) 2017-present, Facebook, Inc. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + */ + +import React, { useState } from "react"; + +export default function InfoPanel({ isWizard, personaDesc }) { + return (
+ + +
+ ); +} + +function Persona({ isWizard, personaDesc }) { + const persona_str = personaDesc.toString(); + const persona_style = "info-pane " + "persona-pane" + if (persona_str === "") { + return
; + } + + const persona_lines = personaDesc.split("\n").map((s, i) => {return
  • {s}
  • }) + + const header = (isWizard === true) ? + "Your chat partner has the following personality and interests." : + "You will assume the following character, with appropriate related interests."; + + return (
    +

    {header}

    + +
    ); +} + +function TaskDescription({ isWizard }) { + const className = "info-pane " + "instruction-pane"; + // Apprentice + const apprentice_description = (
    +

    Have a conversation with your chat partner about your favorite topic.

    +

    + In this task, you will have a conversation with a chat partner who has knowledge + about many things, and access to lots of information. + You will be assigned a persona; + the purpose of the task is to then have an in-depth conversation about your assigned interests. + Your partner will strive to enlighten you on these topics. + Note that your conversational partner will not share any interests with you; + the conversation should, and will, focus entirely on your assigned interests. +

    +
    ); + + // Wizard + const [isDetailsHidden, setIsDetailsHidden] = useState(true); + const wizard_task_details = isDetailsHidden ? "" : + (
    +

    + You can look up the information that you need during the conversation + by searching the internet with the search bar provided here. + The outcome of this search shows you a number of internet articles, + separated into sentences. + Try to use the information from these sentences to have an informed conversation. + When you use the knowledge from one or more sentences, + select those sentences before sending the message you crafted. + Please conduct a natural conversation and avoid copy/paste. +

    +

    + Your role in this conversation is to assist your partner in learning + and discussing their interests in detail. + Pretend that you are a knowledgeable entity with conversational ability; + your personal interests do not matter for this conversation. + At the end of the conversation, + your partner should be happy to have talked with you, + but should not know anything about you. +

    +
    ); + const triangleCharCode = isDetailsHidden ? 9658 : 9660; + const hide_show_buttons_text = isDetailsHidden ? " more details" : " less details"; + + const wizard_description = (
    +

    Have a conversation with your chat partner about their favorite topics.

    +

    + You will have a conversation with a chat partner who is interested in a few topics. + Your partner’s interests will be displayed to you ahead of time; + the purpose of the conversation is to discuss your partner’s interests in detail. +

    + + {wizard_task_details} +
    ); + + return (isWizard === true) ? wizard_description : apprentice_description; +} \ No newline at end of file diff --git a/parlai/crowdsourcing/tasks/wizard_of_internet/webapp/components/Moderator.js b/parlai/crowdsourcing/tasks/wizard_of_internet/webapp/components/Moderator.js new file mode 100644 index 00000000000..9783ad5a140 --- /dev/null +++ b/parlai/crowdsourcing/tasks/wizard_of_internet/webapp/components/Moderator.js @@ -0,0 +1,205 @@ +/* + * Copyright (c) 2017-present, Facebook, Inc. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + */ + +// Main exported functions in this file (valid_utterance and valid_search_query) +// are simple checks on the validity of user input (text or search query). +// There might also be additional checks on the submitted input in ParlAI world. 
+ + +const STOPWORDS = ['', "\'\'", "\'d", "\'ll", "\'m", "\'re", "\'s", "\'ve", '*', + ',', '--', '.', '?', '\`\`', 'a', 'about', 'above', 'after', + 'again', 'against', 'ain', 'all', 'also', 'am', 'an', 'and', + 'any', 'are', 'aren', 'as', 'at', 'be', 'because', 'been', + 'before', 'being', 'below', 'between', 'both', 'but', 'by', + 'can', 'couldn', 'd', 'did', 'didn', 'do', 'does', 'doesn', + 'doing', 'don', 'down', 'during', 'each', 'few', 'for', 'from', + 'further', 'had', 'hadn', 'has', 'hasn', 'have', 'haven', + 'having', 'he', 'her', 'here', 'hers', 'herself', 'him', + 'himself', 'his', 'how', 'i', 'if', 'in', 'into', 'is', 'isn', + 'it', 'its', 'itself', 'just', 'know', 'll', 'm', 'ma', 'me', + 'mightn', 'more', 'most', 'mustn', 'my', 'myself', "n't", + 'needn', 'no', 'nor', 'not', 'now', 'o', 'of', 'off', 'on', + 'once', 'only', 'or', 'other', 'our', 'ours', 'ourselves', + 'out', 'over', 'own', 'people', 're', 'really', 's', 'same', + 'see', 'shan', 'she', 'should', 'shouldn', 'so', 'some', + 'such', 't', 'than', 'that', 'the', 'their', 'theirs', 'them', + 'themselves', 'then', 'there', 'these', 'they', 'this', + 'those', 'through', 'to', 'too', 'under', 'until', 'up', 've', + 'very', 'want', 'was', 'wasn', 'we', 'were', 'weren', 'what', + 'when', 'where', 'which', 'while', 'who', 'whom', 'why', 'will', + 'with', 'won', 'wouldn', 'y', 'you', 'your', 'yours', + 'yourself', 'yourselves']; + +// MIN_TEXT_LENGTH_TO_CHECK_COPY must always be less than OVERLAP_LENGTH_CHECK +// otherwise the check for copy/paste always passes +const MIN_TEXT_LENGTH_TO_CHECK_COPY = 20; +const OVERLAP_LENGTH_CHECK = 30 +const MIN_OVERLAP_REQUIRED = 1; +const MIN_NUM_WORDS_PER_UTTERANCE = 5 + +const GREETING_FAREWELL_WORDS = ['hi', 'hello', 'bye', 'goodbye']; + +function split_tokenize(text) { + const res = text.replace(/[.|. . 
.|,|;|:|!|\?|\(|\)]/g, function (x) { + return ` ${x} `; + }); + return res.split(" ").filter((w) => w !== ""); +} + +export default function valid_utterance(text, search_results, selected_results, isOnboarding, taskConfig) { + const bWords = taskConfig.bannedWords; + const lowered_text = text.toLowerCase(); + return !(is_too_short(lowered_text, isOnboarding) || + looks_like_greetings(lowered_text, isOnboarding) || + has_did_you_know(lowered_text) || + has_banned_words(lowered_text, bWords) || + is_copy_pasted(lowered_text, search_results) || + has_turker_words(lowered_text) || + needs_more_overlap_with_selected(lowered_text, search_results, selected_results)); +} + +export function valid_search_query(search_query, taskConfig) { + const bWords = taskConfig.bannedWords; + const lowered_search_query = search_query.toLowerCase(); + return !has_banned_words(lowered_search_query, bWords); +} + +function is_too_short(text, isOnboarding) { + if (isOnboarding) { + return false; + } + + const tokenized_text = split_tokenize(text); + if (tokenized_text.length < MIN_NUM_WORDS_PER_UTTERANCE) { + alert("Your message was too short. Please try again and use longer and more engaging messages."); + return true; + } + return false; +} + +function looks_like_greetings(text, isOnboarding) { + if (isOnboarding) { + return false; + } + const first_word = split_tokenize(text)[0]; + if (GREETING_FAREWELL_WORDS.includes(first_word)) { + alert("Your message looks like a greeting or farewell. Please try again and use more engaging messages."); + return true; + } + return false; +} + +function has_did_you_know(text) { + if (text.includes("did you know") || text.includes("did u know")) { + alert("Please try to be more engaging, and not use the phrase \'did you know\' :)."); + return true; + } + return false; +} + +function has_turker_words(text) { + if (text.includes("turker") || text.includes("turk")) { + return !confirm("Please do not mention the mechanical turk task in the conversation." 
+ + "Press \"Cancel\", to go back and edit, if your message does that, or \"OK\" to send the message."); + } + return false +} + +function has_banned_words(text, banned_words_list) { + const tokenized_text = split_tokenize(text); + + // Checking for banned words + const banned_words = tokenized_text.filter((w) => banned_words_list.indexOf(w) !== -1) + if (banned_words.length > 0) { + const detected_banned_words = banned_words.join(', '); + alert("We have detected the following offensive/banned language in your message: \"" + + detected_banned_words + + "\". Please edit and send again."); + return true + } + return false; +} + +function is_copy_pasted(text, docs) { + if (!docs || docs.length === 0 || text.length < MIN_TEXT_LENGTH_TO_CHECK_COPY) { + return false; + } + + function too_much_char_overlap(check_sentence, source_snetence) { + const n = check_sentence.length; + for (var s = 0; s < n; s += OVERLAP_LENGTH_CHECK) { + const e = Math.min(n, s + OVERLAP_LENGTH_CHECK); + if ((e - s) < MIN_TEXT_LENGTH_TO_CHECK_COPY) { + continue; + } + const small_substr = check_sentence.substring(s, e); + if (source_snetence.includes(small_substr)) { + return true; + } + } + return false; + } + + for (var doc_id in docs) { + const document = docs[doc_id]; + for (var sentence_id in document.content) { + const sentence = document.content[sentence_id].toLocaleLowerCase(); + if (too_much_char_overlap(text, sentence)) { + const q = "\""; + alert("Your message has too much overlap with one of the candidates sentences. 
" + + "Please retry and avoid copying and pasting."); + return true; + } + } + } + return false; +} + +function needs_more_overlap_with_selected(text, docs, selected) { + if (!docs || + docs.length === 0 || + !selected || + selected.length === 0 || + selected[0][0] === true) { + return false; + } + + var selected_sentences = []; + for (var doc_id = 1; doc_id < selected.length; doc_id++) { + const docSelection = selected[doc_id]; + for (var sent_id = 0; sent_id < docSelection.length; sent_id++) { + if (docSelection[sent_id] === false) { // This sentence was not selected + continue; + } + // "doc_id-1" because selection has an extra value at index 0 (nothing selected) + const lower_selected_sentence = docs[doc_id - 1].content[sent_id].toLowerCase(); + selected_sentences.push(lower_selected_sentence); + } + } + const num_overlaps = overlap_number(selected_sentences.join(" "), text); + if (MIN_OVERLAP_REQUIRED >= num_overlaps) { + return !confirm("Are you sure you are using the right checked sentence? 
" + + "We have detected a lack of similarity between the checked sentence and your message " + + "(please press \"OK\" if you intended to send this message, " + + "or \"Cancel\" to go back for edit)."); + } + return false; +} + +function overlap_number(selected_sentences, text) { + const PREFIX_LENGTH = 4; + function reduce_to_prefix(s) { + return split_tokenize(s) + .filter((w) => STOPWORDS.indexOf(w) === -1) + .map((w) => w.slice(0, PREFIX_LENGTH)); + } + + const text_tokens = reduce_to_prefix(text); + const sentence_token = reduce_to_prefix(selected_sentences); + return text_tokens.filter((word) => sentence_token.indexOf(word) !== -1).length; +} diff --git a/parlai/crowdsourcing/tasks/wizard_of_internet/webapp/components/OnBoardingSidePane.jsx b/parlai/crowdsourcing/tasks/wizard_of_internet/webapp/components/OnBoardingSidePane.jsx new file mode 100644 index 00000000000..5f897a75c49 --- /dev/null +++ b/parlai/crowdsourcing/tasks/wizard_of_internet/webapp/components/OnBoardingSidePane.jsx @@ -0,0 +1,167 @@ +/* + * Copyright (c) 2017-present, Facebook, Inc. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + */ + +import React, { useState } from "react"; +import SearchPanel from "./SearchPanel.jsx"; + +// NOTE: these need to match ONBOARDING_STEPS dict in constants.py +export const OnboardingSteps = { + NOT_ONBOARDING: 0, + CHAT_INTERFACE: 1, + TRY_SEARCH: 2, + PERSONA_WIZARD: 3, + PERSONA_APPRENTICE: 4, + WAITING: 10 +}; + +export default function OnboardingSidePane({ onBoardingStep, mephistoContext, + searchResults, selected, handleSelect, setSearchQuery }) { + let tutorialComponent =
    Error: Unknown OnBoarding Step.
    ; + switch (onBoardingStep) { + case OnboardingSteps.CHAT_INTERFACE: + tutorialComponent = ; + break; + + case OnboardingSteps.TRY_SEARCH: + tutorialComponent = ; + break; + + case OnboardingSteps.PERSONA_WIZARD: + tutorialComponent = ; + break; + + case OnboardingSteps.PERSONA_APPRENTICE: + tutorialComponent = ; + break; + + case OnboardingSteps.WAITING: + tutorialComponent = ; + break; + + default: + console.error("Unrecognized onboarding step " + onBoardingStep); + } + + return ( +
    + {tutorialComponent} +
    + + ) +} + +function OnboardingSidePanel({ + mephistoContext, + searchResults, + selected, + handleSelect, + setSearchQuery, + hideSearchBar, + blinkSearchBar, + hidePersona, + blinkPersona }) { + const peronaDescription = mephistoContext.taskConfig["onboardingPersona"]; + const SearchBar = hideSearchBar ? null + : (
    +
    + +
    +
    ); + return ( +
    +
    + + +
    + {SearchBar} +
    + + ); +} + +function Waiting() { + return ( +
    +
    +

    Instruction.

    +

    + Please wait while we pair you with other participants. +

    +
    + +
    + ); +} + +function Persona({ isBlinking, personaDesc }) { + const blinkingStyle = (isBlinking === true) ? "blinking" : "non-blinking"; + const persona = (personaDesc === "") ? "" : + (
    +

    Character description

    +

    + {personaDesc} +

    +
    ); + return ( +
    + {persona} +
    ); +} + +function TaskDescription() { + return ( +
    +

    Instruction.

    +

    + Our OnboardingBot is sending you instructions (see the right pane). + Follow them to get familiar with this task and its environment, + before we start the main task. +

    +
    ); +} \ No newline at end of file diff --git a/parlai/crowdsourcing/tasks/wizard_of_internet/webapp/components/SearchPanel.jsx b/parlai/crowdsourcing/tasks/wizard_of_internet/webapp/components/SearchPanel.jsx new file mode 100644 index 00000000000..dea3da0f32a --- /dev/null +++ b/parlai/crowdsourcing/tasks/wizard_of_internet/webapp/components/SearchPanel.jsx @@ -0,0 +1,203 @@ +/* + * Copyright (c) 2017-present, Facebook, Inc. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + */ + +import React, { useState } from "react"; +import { Button } from "react-bootstrap"; +import { valid_search_query } from "./Moderator.js"; + +export default function SearchPanel({ mephistoContext, + searchResults, selected, handleSelect, setSearchQuery, isWizard }) { + const { sendMessage } = mephistoContext; + const { agentState } = mephistoContext; + const { taskConfig } = mephistoContext; + + function handleQuery(props) { + const { query } = props; + setSearchQuery(query); + GetSearchResults(query, sendMessage); + } + + const wizard_sidepane = (agentState.wants_act) ?
    + + + +
    :

    Wait for response!

    ; + + const apprentice_sidepane =
    Enjoy the conversation!
    ; + + const sidepane = (isWizard === true) ? wizard_sidepane : apprentice_sidepane; + + return sidepane; +} + +function SearchResults({ search_results, selected, selectedChange }) { + function SearchDocsGenerator(doc, doc_index) { + const { title } = doc; + const shifted_row_index = doc_index + 1; + const sel_sen = selected[shifted_row_index].slice(); + return ( +
    + +
    + + ); + } + const elements = search_results.map(SearchDocsGenerator); + return (
    + {elements} +
    ); +} + +function GetSearchResults(query, sendMessage) { + const q = query; + const ts = Date.now(); + const message = { + timestamp: ts, + episode_done: false, + id: "Wizard", + text: query, + is_search_query: true + }; + sendMessage(message); +} + + +function SearchBar({ onSubmit, taskConfig }) { + const [text, setText] = useState(""); + + function handleSubmit() { + if (text === "") { + alert("Please insert a term for search in the searh bar"); + return; + } + + if (valid_search_query(text, taskConfig)) { + onSubmit({ query: text }); + } + } + + function handleKeyPresse(event) { + if (event.key === "Enter") { + handleSubmit(); + } + } + + return ( + + ); +} + +function SearcDocTitle({ title, opened, onOpenSelected }) { + // choose the icon unicde (down or right pointing icon) + const triangleCharCode = opened ? 9660 : 9658; + const text = String.fromCharCode(triangleCharCode) + title; + + return ( +
    + +
    + ); +} + +function SearchDocSentence({ sentence, selected, onChange, loc }) { + const sent_id = loc[1]; + const isChecked = selected[sent_id]; + return ( +
    + onChange(loc)} + /> + {sentence} +
    + ); +} + +function SearchDoc({ document, doc_id, selected_sentences, onChange }) { + const title = document.title; + const sentences = document.content; + const [opened, setOpened] = useState(false); + + function SentenceCheckBoxGenerator(sentence, sentenc_id) { + const location = [doc_id, sentenc_id]; + const key = doc_id.toString() + "_" + sentenc_id.toString(); + return ( + + ); + } + + const sents = sentences.slice(); + const sentence_selectors = sents.map(SentenceCheckBoxGenerator); + return ( +
    + (setOpened(!opened))} /> + {(opened) ? sentence_selectors : null} +
    + ); +} + +function NoDocumentSelected({ selected, selectedChange }) { + if ((!selected) || (selected.length === 0)) { + return null; + } + const loc = [0, 0]; + const isChecked = selected[0][0]; + return ( +
    + selectedChange(loc)} + /> + + Did not use search results for this message. + +
    + ) +} \ No newline at end of file diff --git a/parlai/crowdsourcing/tasks/wizard_of_internet/webapp/components/SidePane.jsx b/parlai/crowdsourcing/tasks/wizard_of_internet/webapp/components/SidePane.jsx new file mode 100644 index 00000000000..2bfb7eadb0c --- /dev/null +++ b/parlai/crowdsourcing/tasks/wizard_of_internet/webapp/components/SidePane.jsx @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2017-present, Facebook, Inc. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + */ + +import React from "react"; +import InfoPanel from "./InfoPanel.jsx"; +import SearchPanel from "./SearchPanel.jsx"; +import OnboardingSidePane from "./OnBoardingSidePane.jsx"; + +export default function SidePane({ + mephistoContext, + appContext, + searchResults, + selected, + handleSelect, + setSearchQuery, + onBoardingStep, + isWizard, + apprenticePersona }) { + + const { currentAgentNames } = appContext.taskContext; + if (!currentAgentNames) { // The task is not started yet. + return ; + } + + if (onBoardingStep > 0) { + return ; + } + + // Hidding the search bar while the agents are choosing persona + const searchPanel = (!apprenticePersona || apprenticePersona === "") ? null : + (
    + +
    ) + + return ( +
    +
    + +
    + {searchPanel} +
    + ) +} + +function WaitForStart() { + return
    + Please wait! +

    System is adding the partner and setting up the service.

    + Matching may take up to 15 minutes (based on the number of online people). Please do not leave the chat. +
    +} \ No newline at end of file diff --git a/parlai/crowdsourcing/tasks/wizard_of_internet/webapp/components/styles.css b/parlai/crowdsourcing/tasks/wizard_of_internet/webapp/components/styles.css new file mode 100644 index 00000000000..42c5df6dc41 --- /dev/null +++ b/parlai/crowdsourcing/tasks/wizard_of_internet/webapp/components/styles.css @@ -0,0 +1,115 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + #ui-container { + display: contents; +} + +.side-pane { + width: 55%; +} + +#info-bar { + display: flex; + flex-direction: row; +} + +.info-pane { + width: 50%; + padding: 2%; +} + +.instruction-pane { + background-color: rgb(250, 210, 230); +} + +.persona-pane { + background-color: lightblue; +} + +#search-area { + background-color: rgb(128, 202, 140); + padding: 1%; +} + +#no-answer { + background-color: rgb(245, 129, 129); + border-radius: 5px; + margin: 2px; + margin-bottom: 1%; + padding: 2%; +} + +#search-bar { + background-color: rgb(143, 245, 129); + border-radius: 5px; + margin: 2px; + margin-bottom: 1%; + padding: 2%; +} + +.search-results { + background-color: rgb(103, 211, 89); + padding: 2px; + margin: 2px; + margin-bottom: 2%; + border-radius: 5px; +} + +.doc-title { + background-color: rgb(166, 223, 74); + padding: 2px; + margin: 2px; + border-width: 0; + border-radius: 3px; +} + +.doc-sentence { + background-color: rgb(101, 179, 124); + padding: 2px; + margin: 2px; + border-radius: 3px; +} + +.doc-button { + border-width: 0; + background-color: rgba(0, 0, 0, 0); + text-align: left +} + +.expand-details-button { + border-width: 0; + background-color: rgba(0, 0, 0, 0); + text-align: left +} + +.blinking { + border: 5px; + border-style: solid; + border-radius: 20px; + border-color: rgba(0, 0, 0, 0); + animation-name: blink-animation; + animation-duration: 1s; + animation-delay: 1s; + 
animation-iteration-count: 10; + animation-direction: alternate; + animation-timing-function: ease-in-out; + } + + .non-blinking { + border: 5px; + border-style: solid; + border-color: rgba(0, 0, 0, 0); + } + + @keyframes blink-animation { + from { + border-color: rgba(200, 0, 0, 50); + } + to { + border-color: rgba(0, 0, 0, 0); + } + } \ No newline at end of file diff --git a/parlai/crowdsourcing/tasks/wizard_of_internet/webapp/main.js b/parlai/crowdsourcing/tasks/wizard_of_internet/webapp/main.js new file mode 100644 index 00000000000..94a608b1ad3 --- /dev/null +++ b/parlai/crowdsourcing/tasks/wizard_of_internet/webapp/main.js @@ -0,0 +1,343 @@ +/* + * Copyright (c) 2017-present, Facebook, Inc. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + */ + +import React, { useState } from "react"; +import ReactDOM from "react-dom"; +import "bootstrap-chat/styles.css"; + +import { FormControl, Button } from "react-bootstrap"; +import { ChatApp, ChatMessage, INPUT_MODE } from "bootstrap-chat"; + +import SidePane from "./components/SidePane.jsx"; +import { OnboardingSteps } from "./components/OnBoardingSidePane.jsx"; +import valid_utterance from "./components/Moderator.js"; + +import "./components/styles.css"; + +function isOnboardingMessage(message) { + const sender_id = message["id"].toLocaleLowerCase(); + return sender_id.startsWith("onboarding"); +} + +function newMessageHandler( + messages, + setApprenticePersona, + setOnBoardingStep, + updateSearchResults, + updateSelectedSearchResults) { + + if ((!messages) || (messages.length < 1)) { + return; + } + + function resetSelected(search_results) { + var selected = [[false]]; + for (let doc_id = 0; doc_id < search_results.length; doc_id++) { + const doc = search_results[doc_id]; + var selected_sentences = []; + for 
(let sentence_id = 0; sentence_id < doc.content.length; sentence_id++) { + selected_sentences.push(false); + } + selected.push(selected_sentences); + } + updateSelectedSearchResults(selected); + } + + const msg = messages[messages.length - 1] + if (msg.id === "SearchAgent") { + const searh_res = msg.task_data["search_results"]; + resetSelected(searh_res); + updateSearchResults(searh_res); + } else if (msg.id === "PersonaAgent") { + const persona = msg['task_data']['apprentice_persona']; + setApprenticePersona(persona); + } else if (msg.id === "OnboardingBot") { + const taskDataKey = 'task_data'; + const onboardingStepKey = 'on_boarding_step'; + if ((taskDataKey in msg) && (onboardingStepKey in msg[taskDataKey])) { + const stepDuringOnboarding = msg['task_data'][onboardingStepKey]; + console.log("Setting onboarding step to ", stepDuringOnboarding); + setOnBoardingStep(stepDuringOnboarding); + } + } +} + +function isWizard(mephistoContext, appContext, onBoardingStep) { + if ((onBoardingStep === OnboardingSteps.TRY_SEARCH) || + (onBoardingStep === OnboardingSteps.PERSONA_WIZARD)) { + return true; + } + + const { agentId } = mephistoContext; + const { currentAgentNames } = appContext.taskContext; + if (!currentAgentNames) { + return false; + } + const agentName = currentAgentNames[agentId]; + const answer = (agentName === "Wizard") ? true : false; + return answer; +} + +function RenderChatMessage({ message, mephistoContext, appContext, setIsMinTurnsMet, isInMainTask }) { + if (message.text === "") { + return null; + } + + // TODO: replace this hacky solution for removing remaining message from onboarding + // with a better solution that purges them from the messages list. 
+ if (isInMainTask && isOnboardingMessage(message)) { + return null; + } + + const { agentId } = mephistoContext; + const { currentAgentNames } = appContext.taskContext; + const taskDataKey = 'task_data'; + const searchResultsKey = 'search_results'; + if ((taskDataKey in message) && (searchResultsKey in message[taskDataKey])) { + // The received message comes from Search query: DO NOT ADD TO CHAT. + return null; + } + + const nTurnsKeys = "utterance_count"; // Because this is observed and comes from seen messages + if (("task_data" in message) && (nTurnsKeys in message.task_data)) { + const numTurns = message.task_data[nTurnsKeys]; + const minNumTurns = mephistoContext.taskConfig["minTurns"]; + if (numTurns > minNumTurns) { + setIsMinTurnsMet(true); + } + } + + const isSelf = ((message.id === agentId) || (message.id in currentAgentNames)); + let shownName; + if (isSelf) { + shownName = "You"; + } else { + if (["Wizard", "Apprentice"].includes(message.id)) { + shownName = "Your Partner"; + } else { + shownName = message.id; + } + } + + return ( +
    + +
    + ); +} + +function CustomTextResponse({ + taskConfig, + onMessageSend, + active, + searchQuery, + setSearchQuery, + searchResults, + setSearchResults, + selectedSearchResults, + setSelectedSearchResults, + isMinTurnMet, + isWizard, + isOnboarding, +}) { + const [textValue, setTextValue] = React.useState(""); + const [sending, setSending] = React.useState(false); + + const inputRef = React.useRef(); + + React.useEffect(() => { + if (active && inputRef.current && inputRef.current.focus) { + inputRef.current.focus(); + } + }, [active]); + + const trySignalFinishChat = () => { + if (active && !sending) { + setSending(true); + onMessageSend({ text: "", requested_finish: true }).then( + () => { setSending(false); }) + } + }; + + function needSelection(selMatrix) { + if (!isWizard) { + return false; + } + for (var i = 0; i < selMatrix.length; i++) { + for (var j = 0; j < selMatrix[i].length; j++) { + if (selMatrix[i][j]) { + return false + } + } + } + return true; + } + + const tryMessageSend = React.useCallback(() => { + if (needSelection(selectedSearchResults)) { + alert("Please select an option from the left panel.") + return; + } + if (textValue !== "" && + active && + !sending && + valid_utterance(textValue, searchResults, selectedSearchResults, isOnboarding, taskConfig)) { + setSending(true); + onMessageSend({ + timestamp: Date.now(), + text: textValue, + task_data: { + search_query: searchQuery, + text_candidates: searchResults, + selected_text_candaidtes: selectedSearchResults, + } + }).then(() => { + setTextValue(""); + setSearchQuery(""); + setSearchResults([]); + setSelectedSearchResults([[false]]); + setSending(false); + }); + } + }, [textValue, active, sending, onMessageSend, selectedSearchResults]); + + const handleKeyPress = React.useCallback( + (e) => { + if (e.key === "Enter") { + tryMessageSend(); + e.stopPropagation(); + e.nativeEvent.stopImmediatePropagation(); + } + }, + [tryMessageSend] + ); + + const finishButton = (isMinTurnMet) ? 
+ : null; + + return ( +
    +
    + { + inputRef.current = ref; + }} + value={textValue} + placeholder="Enter your message here..." + onKeyPress={(e) => handleKeyPress(e)} + onChange={(e) => setTextValue(e.target.value)} + disabled={!active || sending} + /> + + {finishButton} +
    +
    + ); +} + +function MainApp() { + const [searchQuery, setSearchQuery] = useState(""); + const [searchResults, setSearchResults] = useState([]); + const [selectedSearchResults, setSelectedSearchResults] = useState([[false]]); + const [isMinTurnMet, setIsMinTurnsMet] = useState(false); + const [apprenticePersona, setApprenticePersona] = useState(""); + const [onBoardingStep, setOnBoardingStep] = useState(OnboardingSteps.NOT_ONBOARDING); + + function handleSelect(loc) { + const [doc_id, sentence_id] = loc; + const new_selected = selectedSearchResults.slice(); + if ((doc_id === 0) && (sentence_id === 0)) { // No sentce selected + new_selected[0][0] = !new_selected[0][0]; + for (var i = 1; i < new_selected.length; i++) { + for (var j = 0; j < new_selected[i].length; j++) { + new_selected[i][j] = false; + } + } + } else { // Any other selected + new_selected[0][0] = false; + const prev_val = selectedSearchResults[doc_id][sentence_id]; + new_selected[doc_id][sentence_id] = !prev_val; + } + setSelectedSearchResults(new_selected); + } + + return ( +
    + ( + + )} + renderTextResponse={({ onMessageSend, inputMode, mephistoContext, appContext }) => + () + } + onMessagesChange={(messages) => ( + newMessageHandler(messages, + setApprenticePersona, + setOnBoardingStep, + setSearchResults, + setSelectedSearchResults))} + + renderSidePane={({ mephistoContext, appContext }) => ( + + )} + /> +
    + ); +} + +ReactDOM.render(, document.getElementById("app")); diff --git a/parlai/crowdsourcing/tasks/wizard_of_internet/wizard_internet_blueprint.py b/parlai/crowdsourcing/tasks/wizard_of_internet/wizard_internet_blueprint.py new file mode 100644 index 00000000000..d2b1b206e1e --- /dev/null +++ b/parlai/crowdsourcing/tasks/wizard_of_internet/wizard_internet_blueprint.py @@ -0,0 +1,195 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field +from mephisto.operations.registry import register_mephisto_abstraction +from mephisto.abstractions.blueprints.parlai_chat.parlai_chat_blueprint import ( + ParlAIChatBlueprintArgs, + ParlAIChatBlueprint, +) +from mephisto.abstractions.blueprint import SharedTaskState +from omegaconf import DictConfig, MISSING + +WIZARD_INTERNET_PARLAICHAT_BLUEPRINT = 'wizard_internet_parlaichat_blueprint' + + +@dataclass +class WizardOfInternetBlueprintArgs(ParlAIChatBlueprintArgs): + _blueprint_type: str = WIZARD_INTERNET_PARLAICHAT_BLUEPRINT + _group: str = field( + default='WizardInternetParlAIChatBlueprint', + metadata={ + 'help': """ + ParlAI chat between two agents with one agent having access to a search API + that retrieves data from internet (common crawl snapshot). In order to run, + the search API needs to be up and running. + """ + }, + ) + + role_qualification: str = field( + default=MISSING, + metadata={ + 'help': """ + Specify the role (wizard or apprentice) that agents are trained on + during their onboarding. + """ + }, + ) + + min_turns: int = field( + default=MISSING, + metadata={ + 'help': """ + The minimum number of turns before showing the finish button on chat interface + and allowing the agents to end the conversations cleanly. 
+ """ + }, + ) + + wizard_time_out: int = field( + default=180, + metadata={'help': 'Maximum allowed time (seconds) for Wizard, each round.'}, + ) + + apprentice_time_out: int = field( + default=60, + metadata={'help': 'Maximum allowed time (seconds) for Apprentice, each round.'}, + ) + + personas_file: str = field( + default=MISSING, + metadata={'help': 'Path to a text file that keeps a list of curated personas.'}, + ) + + persona_counts_file: str = field( + default=MISSING, + metadata={ + 'help': 'A semicolon seperated list of personas and their count (file)' + }, + ) + + shuffle_persona: bool = field( + default=True, metadata={'help': 'Whether to shuffle the persona list'} + ) + + use_personas_with_replacement: bool = field( + default=False, + metadata={'help': 'Using true does not discard personas after use.'}, + ) + + banned_words_file: str = field( + default=MISSING, + metadata={ + 'help': """ + Path to a text file with a list of banned words to block in the UI. + Each row in the file is one word/phrase. + User will receieve an alert and are asked to rephrase, if there is an exact match. + """ + }, + ) + + max_times_persona_use: int = field( + default=0, + metadata={ + 'help': """ + Maximum number of times to allow a particular persona to be used. + Default (0) mean no limit. + """ + }, + ) + + locations_file: str = field( + default=MISSING, + metadata={ + 'help': """ + Path to a text file that keeps a list of locations that will be added + to some of the curated personas (marked for needing persona). + """ + }, + ) + + search_server: str = field( + default=MISSING, metadata={'help': 'Address to the search API.'} + ) + + num_passages_retrieved: int = field( + default=5, + metadata={'help': 'The number of documents to request from search API.'}, + ) + + search_warning_turn: int = field( + default=2, + metadata={ + 'help': 'The round that wizard may receive warning for using more search.' 
+ }, + ) + + search_warning_threshold: int = field( + default=2, + metadata={ + 'help': """ + The minimum number of times that wizard needs to use the search bar, + at the rounds that we check for sending them a warning + (warning is not send if wizard has used search more than this many times). + """ + }, + ) + + select_warning_turn: int = field( + default=3, + metadata={ + 'help': 'The round that Wizard may receive warning to select more search results' + }, + ) + + select_warning_threshold: int = field( + default=2, + metadata={ + 'help': """ + The minimum number of knowledge selections that wizard needs to have, + at the round that we check for sending them a warning + (warning is not send if Wizard has selected saerch results + at least this many times so far). + """ + }, + ) + + +@register_mephisto_abstraction() +class WizardOfInternetBlueprint(ParlAIChatBlueprint): + BLUEPRINT_TYPE = WIZARD_INTERNET_PARLAICHAT_BLUEPRINT + ArgsClass = WizardOfInternetBlueprintArgs + + @classmethod + def assert_task_args( + cls, args: 'DictConfig', shared_state: 'SharedTaskState' + ) -> None: + """ + Ensure that arguments are properly configured to launch this task. + """ + ParlAIChatBlueprint.assert_task_args(args=args, shared_state=shared_state) + blueprint = args.get('blueprint') + # Check search module is valid + assert hasattr(blueprint, 'search_server'), 'Provide search API address.' + + assert hasattr(blueprint, 'use_personas_with_replacement') + assert hasattr(shared_state, 'world_opt') + assert 'personas' in shared_state.world_opt + + # Number of personas is enough for running without replacement + if not blueprint.get('use_personas_with_replacement'): + n_personas = len(shared_state.world_opt['personas']) + n_conversations = blueprint.get('num_conversations') + assert ( + n_personas >= n_conversations + ), f'{n_personas} personas are not enought to use uniquely for {n_conversations} conversations.' 
def sec_to_min_pretty(time_secs: int) -> str:
    """
    Returns a short, human readable string for a duration converted to minutes.

    Whole minutes are printed as an integer (120 -> '2'). Other values are printed
    with two significant digits (90 -> '1.5', 1000 -> '17'). Durations that two
    significant digits would render in scientific notation (about 100 minutes and
    above) are printed as a rounded whole number instead (6001 -> '100').

    :param time_secs: duration in seconds.
    :return: the duration in minutes, formatted for showing to workers.
    """
    whole_minutes, leftover_secs = divmod(time_secs, 60)
    if leftover_secs == 0:
        return f'{whole_minutes}'

    minutes = time_secs / 60
    pretty = f'{minutes:.2g}'
    if 'e+' in pretty:
        # The 'g' presentation type falls back to exponent notation once the value
        # needs more digits than the precision (e.g. '1e+02'); that is unreadable
        # in a worker-facing message, so fall back to a rounded integer.
        pretty = f'{minutes:.0f}'
    return pretty
def _has_selected_sentence_from_search_results(action: Union[Dict, 'Message']):
    """
    Whether any knowledge sentence was selected along with this message.

    The selection matrix is read from `action['task_data']` under the
    'selected_text_candaidtes' key (the spelling is the key sent by the front-end).
    Its first cell, `[0][0]`, is the flag for the "no sentence used" option, so a
    falsy first cell is reported as "something was selected".

    :param action: the act received from the agent.
    :return: True if the "no sentence used" flag is not set; False if it is set, or
        if the message carries no (or an empty) selection matrix.
    """
    task_data = action.get('task_data')
    if not isinstance(task_data, dict):
        # Missing or malformed task data: nothing could have been selected.
        return False

    selection = task_data.get('selected_text_candaidtes')
    if not selection or not selection[0]:
        # An empty matrix (or empty first row) has no "no sentence used" cell to
        # read; treat it the same as a message without any selection data.
        return False

    # Boolean value that user has not selected any option
    return not selection[0][0]
+ """ + return search_client.retrieve([q], n)[0] + + def _dedupl_docs(docs_list): + uniq_docs = [] + seen_urls = set() + for d in docs_list: + url = d['url'] + if url in seen_urls: + continue + uniq_docs.append(d) + if len(uniq_docs) == constants.NUM_RETRIEVED_SEARCH_DOCS: + return uniq_docs + seen_urls.add(url) + logging.warning( + f'Only retrieved {len(uniq_docs)}, not {constants.NUM_RETRIEVED_SEARCH_DOCS}' + ) + return uniq_docs + + def _wiki_sort_key(doc): + """ + Helper function to put the Wikipedia pages last in ranking retrieved doc + results. + """ + url = doc['url'] + return 1 if url.startswith('https://en.wikipedia') else -1 + + if not search_client: + logging.error('No search client; can not run search request.') + return + logging.info(f'Running search for query "{query}"') + + # getting query with news + query_had_news = 'news' in query + if not query_had_news: + search_results = _search(f'{query} news', constants.NUM_RETRIEVED_SEARCH_NEWS) + else: + search_results = [] + + # getting web documents for the main search query + search_results.extend(_search(query, constants.NUM_RETRIEVED_SEARCH_DOCS)) + + # Remove a doc that was fetched by both news and regular search + # and reduce the number of dosc to NUM_RETRIEVED_SEARCH_DOCS + if not query_had_news: + # We did not have two separate queries if query_had_news was True. + search_results = _dedupl_docs(search_results) + + # Sorting retrieved docs based on their URL: Wikipedia pages go last. + search_results.sort(key=_wiki_sort_key) + + return Message( + { + 'id': constants.SEARCH_AGENT, + 'text': '*** SEARCH AGENT RESULTS (CHECK ACCOMPANIED DATA FOR RETRIEVED DOCS) ***', + 'task_data': {'search_results': search_results}, + } + ) + + +def _coordinator_send_message( + agent, message: str = '', task_data: Dict = None, episode_done: bool = False +): + """ + Sends a message to 'agent' from the coordinator. + + We use this to send a message to only one of the agents. 
def _form_response_get_field(form_response: Dict[str, Any], filed_num: int):
    """
    Extracts the value of a certain field from the Mephisto response.

    :param form_response: the act returned by the agent; the answers are read from
        `form_response['task_data']['form_responses']`.
    :param filed_num: index of the requested field. Negative values index from the
        end (-1 is the last field).
    :return: the 'response' value of the requested field, or None when the message
        has no form responses or the index is out of range in either direction.
    """
    task_data = form_response.get('task_data') or {}
    responses = task_data.get('form_responses') or []
    # Bound the index on both sides: a negative index below -len(responses) would
    # otherwise raise IndexError instead of reporting "no such field".
    if -len(responses) <= filed_num < len(responses):
        return responses[filed_num]['response']
    return None
class SharedOnboardWorld(CrowdOnboardWorld):
    """
    The parent (base) onboarding class for both agents.

    Subclasses are expected to set `self.turn_timeout` (it is read by
    `wait_for_response`) and to implement `parley`.
    """

    def __init__(self, opt: Opt, mturk_agent: Agent):
        super().__init__(opt, mturk_agent)
        self.agent.agent_id = 'Participant'
        # Name of the qualification that records which role the worker trained for.
        self.role_training_qname = opt[constants.ROLE_QUALIFICATION_NAME_KEY]
        self._world_name = self._get_world_name()
        self._num_rounds = 0
        # Acts received from the agent during onboarding.
        self.messages = []

    def _get_world_name(self):
        """
        Assigns a name to this world.
        """
        dt = datetime.now()
        return f'onboarding_world_{dt.strftime("%H-%M-%S")}'

    def wait_for_response(self, message: str = None, delay_time: int = 0):
        """
        Starts waiting for a response from the agent, after `delay_time` many seconds.

        The received act is appended to `self.messages`.
        """
        self._num_rounds += 1
        logging.info(
            f'{self._world_name} waiting for response at round {self._num_rounds}'
        )
        if delay_time > 0:
            time.sleep(delay_time)
        self.agent.observe(
            {'id': constants.ONBOARDING_AGENT, 'text': message, 'episode_done': False}
        )
        self.messages.append(self.agent.act(timeout=self.turn_timeout))

    def send_message(
        self,
        message: str,
        onboarding_step: int = None,
        done: bool = False,
        delay_time: int = 0,
    ):
        """
        Sends the next onboarding instruction to the agent.

        :param message: the text shown to the agent.
        :param onboarding_step: if given, forwarded in the task data under the
            'on_boarding_step' key.
        :param done: value of the 'episode_done' flag of the sent act.
        :param delay_time: seconds to wait before sending.
        """
        task_data = dict()
        # NOTE(review): this is a truthiness check, so a step value of 0 would not
        # be forwarded -- confirm that no ONBOARDING_STEPS value passed here is 0.
        if onboarding_step:
            task_data['on_boarding_step'] = onboarding_step
        act = {
            'id': constants.ONBOARDING_AGENT,
            'text': message,
            'episode_done': done,
            'task_data': task_data,
        }
        if delay_time > 0:
            time.sleep(delay_time)
        self.agent.observe(act)

    def introduce_chat_interface(self):
        """
        Showing the welcome onboard message to the agent, the first step during the
        onboarding.
        """
        self.send_message(
            message=constants.ONBOARDING_WELCOME,
            onboarding_step=constants.ONBOARDING_STEPS["CHAT_INTERFACE"],
            done=True,
        )

    def go_for_start(self):
        """
        The onboarding graduation message.
        """
        self.send_message(message=constants.FINISHED_ONBOARDING, done=False)
        # waiting for agent to read the final message
        # then ending the onboarding (purging the onboarding world).
        time.sleep(5)
        self.send_message(
            message="", onboarding_step=constants.ONBOARDING_STEPS["WAITING"], done=True
        )

    def parley(self):
        """
        Provides a step by step scripted interactive onboarding for the agents.

        In each step, we introduce the agent to one part of their experience and
        expectations in this task (eg, persona, chat interface etc.). Then, after a
        short delay (parametrized by TUTORIAL_WAIT_TIMES), we ask them to send a
        response to move to the next step. This method needs to be implemented for
        Wizard and Apprentice separately, as they have different onboarding experiences.
        """
        error_message = "Implement parley for each role individually."
        raise NotImplementedError(error_message)

    def get_worker(self):
        """
        Returns the Mephisto worker behind the onboarding agent.
        """
        return get_worker_from_agent(self.agent)

    def get_worker_name(self):
        """
        Returns the name of the Mephisto worker behind the onboarding agent.
        """
        return self.get_worker().worker_name

    def grant_agent_training_qualification(self, role_id: int):
        """
        Granting the onboarding qualification to the agent, based on their assigned
        role.
        """
        # The readable role name lives in ROLE_NAMES; ROLE_QUALIFICATION_NAME_KEY is
        # the opt key of the qualification name, and indexing it by role id does not
        # yield a role name. The name is only used for logging, so an unknown id
        # falls back to the raw id rather than blocking the grant.
        try:
            role = constants.ROLE_NAMES[role_id]
        except (KeyError, IndexError):
            role = role_id
        logging.info(f'Granting worker qualification for {role} role.')
        worker = self.get_worker()
        worker.grant_qualification(self.role_training_qname, role_id)

    def reason_to_reject(self):
        """
        Check for bad behavior for poor quality of work from agent.

        :return: a string describing the reason for rejecting the work, or None if
            there is no reason to reject.
        """
        if not self.episodeDone:
            return 'left/diconnected before the task was over.'

        # Only the texts that came from this agent count.
        own_texts = [
            msg['text'] for msg in self.messages if msg['id'] == self.agent.agent_id
        ]

        # messages were too short
        # (no message at all counts as an average length of 0 instead of dividing
        # by zero)
        if own_texts:
            msg_char_length_avg = sum(len(text) for text in own_texts) / len(own_texts)
        else:
            msg_char_length_avg = 0.0
        if msg_char_length_avg < constants.MIN_AVG_CHAR_LENGTH_UTTERANCES:
            return (
                'messages were too short for meaningfull conversations '
                f'(average message length: {msg_char_length_avg:.2f} chars).'
            )

        # how many times talked about persona
        n_persona_keyword_mentions = 0
        for text in own_texts:
            lowered_text = text.lower()
            for keyword in constants.ONBOARDING_PERSONA_KEYWORDS:
                if keyword in lowered_text:
                    n_persona_keyword_mentions += 1

        if n_persona_keyword_mentions < 1:
            return (
                'Did not talk enough about the persona. '
                f'Number of keyword overlaps: {n_persona_keyword_mentions}.'
            )

        # returning None means no reason to reject
        return None

    def shutdown(self):
        """
        Shuts the world down, logging before and after the parent's shutdown.
        """
        logging.info(f'Shutting down {self._world_name}')
        super().shutdown()
        logging.info('Shutdown completed successfully.')
def wait_for_response_with_search(self, message: str = None, delay_time: int = 0):
    """
    Optionally sends `message` to the Wizard, then waits for their reply.

    Every search query received while waiting is executed and its results are shown
    to the Wizard; the time that round took is deducted from the remaining turn
    time. The first act that is not a search query is stored in `self.messages`
    and ends the wait. If the turn time runs out first, nothing is stored.
    """
    if message:
        self.send_message(message=message, delay_time=delay_time)

    wizard = self.agent
    remaining_time = self.turn_timeout
    while remaining_time > 0:
        round_started_at = time.time()
        act = wizard.act(timeout=remaining_time)

        if not _is_query(act):
            # A regular reply: keep it and stop waiting.
            self.messages.append(act)
            return

        self.num_searches += 1
        search_query = act['text']
        search_res = run_search_query(search_query, self._search_client)
        n = len(search_res['task_data']['search_results'])
        logging.info(f'Retrieved {n} documents for search query "{search_query}".')
        wizard.observe(search_res)

        # The time spent on this search round counts against the turn time.
        remaining_time -= time.time() - round_started_at
def reason_to_reject(self):
    """
    Checks the wizard's onboarding for signs of poor quality work.

    The wizard must have searched and must have selected search results often
    enough; after that the checks of the parent class apply.

    :return: a string with the reason to reject, or whatever the parent check
        returns when the wizard specific checks pass.
    """
    # Has used search enough
    if self.num_searches < constants.MIN_NUM_SEARCH_ONBOARDING:
        return f'did not use search enough (number of use {self.num_searches}).'

    def _used_search_results(msg) -> bool:
        task_data = msg.get('task_data')
        if not task_data or not isinstance(task_data, dict):
            return False
        sel_options = task_data.get('selected_text_candaidtes')
        # A single row means there were no search results to choose from.
        if not sel_options or len(sel_options) == 1:
            return False
        # sel_options[0][0] is the "did not use ..." option.
        return not sel_options[0][0]

    # Has selected enough sentences
    num_selections = sum(1 for msg in self.messages if _used_search_results(msg))
    if num_selections < constants.MIN_NUM_SELECTED_SENTENCES_ONBOARDING:
        return (
            'did not use or select search results enough times '
            f'(number of times used: {num_selections})'
        )

    return super().reason_to_reject()
class ApprenticeOnboardingWorld(SharedOnboardWorld):
    """
    The onboarding world for the apprentice agent.
    """

    def __init__(self, opt, mturk_agent):
        # Time allowed for each of the apprentice's turns.
        self.turn_timeout = opt['apprentice_time_out']
        super().__init__(opt, mturk_agent)

    def _get_world_name(self):
        base_name = super()._get_world_name()
        return f'apprentice-{base_name}'

    def introduce_persona(self):
        step = constants.ONBOARDING_STEPS['PERSONA_APPRENTICE']
        self.send_message(
            message=constants.APPRENTICE_INTRODUCE_PERSONA, onboarding_step=step
        )

    def introduce_partner_entity(self):
        self.send_message(message=constants.APPRENTICE_INTRODUCE_WIZARD)

    def introduce_partner_knowledge(self):
        self.send_message(message=constants.APPRENTICE_INTRODUCE_WIZARD_KNOWLEDGE)

    def parley(self):
        """
        Runs the scripted, interactive onboarding of the Apprentice.

        Each step shows one introduction, waits for the delay configured for that
        step, and then asks the agent for a response before moving on.
        """
        wait_times = constants.TUTORIAL_WAIT_TIMES
        # (introduction to show, prompt asking for a response, wait time key)
        tutorial_steps = (
            (
                self.introduce_chat_interface,
                'Please type a greeting message to continue.',
                'chat-interface',
            ),
            (
                self.introduce_persona,
                constants.APPRENTICE_PERSONA_ROLE_INSTRUCTION,
                'persona',
            ),
            (
                self.introduce_partner_entity,
                constants.APPRENTICE_CHITCHAT_INSTRUCTION,
                'persona',
            ),
            (
                self.introduce_partner_knowledge,
                constants.APPRENTICE_PERSONA_MSG_INSTRUCTION,
                'knowledge',
            ),
        )
        for introduce, prompt, wait_key in tutorial_steps:
            introduce()
            self.wait_for_response(message=prompt, delay_time=wait_times[wait_key])

        self.go_for_start()
        self.episodeDone = True

    def prep_save_data(self, agent: Agent):
        """
        Builds the session data that is saved after the world is closed.

        The apprentice role is recorded only if the onboarding was completed.
        """
        rejection_reason = self.reason_to_reject()
        if self.episodeDone:
            qualified_role = constants.APPRENTICE
        else:
            qualified_role = constants.NO_ROLE
        role_qualification = (self.role_training_qname, qualified_role)
        return {
            constants.SAVED_DATA_IS_WIZARD_KEY: False,
            constants.SAVED_DATA_WORKER_KEY: self.get_worker_name(),
            constants.SAVED_DATA_ROLE_QUALIFICATION_DATA_KEY: role_qualification,
            constants.WORKER_REJECT_REASON: rejection_reason,
        }
def get_agent_order_mask(self, agent_index: int):
    """
    Returns the agent at `agent_index` in the current speaking order.

    The speaking order may be flipped (see `_change_agents_order`); in that case
    index 0 maps to the second agent and index 1 to the first one. Use this method
    whenever agents are accessed by order, instead of indexing `self.agents`.
    """
    assert agent_index in (0, 1), 'Invalid index for accessing agents.'
    # When the order is flipped: 0->1 and 1->0
    position = 1 - agent_index if self._change_agents_order else agent_index
    return self.agents[position]
def _send_task_objective_reminders(self, agent: Agent):
    """
    Reminds the wizard of the task goals when their activity stats fall behind.

    Nothing is sent to any other agent. For the wizard, at most one reminder is
    sent per call: the reminder to search more takes precedence over the reminder
    to select more search results.
    """
    if agent.agent_id != constants.ROLE_NAMES[constants.WIZARD]:
        return

    searched_too_little = (
        self.turn_idx >= self.search_warning_turn
        and self.num_search_queries < self.search_warning_threshold
    )
    if searched_too_little:
        # Wizard has not used search enough so far
        reminder = constants.USE_SEARCH_WARNING_MESSAGE
    elif (
        self.turn_idx >= self.select_warning_turn
        and self.num_times_search_resutls_selected < self.select_warning_threshold
    ):
        # Wizard has not selected search results enough times so far
        reminder = constants.USE_SEARCH_RESULTS_WARNING_MESSAGE
    else:
        return

    _coordinator_send_message(agent, message=reminder)
def sample_personas(self):
    """
    Builds the list of personas that the apprentice will choose from.

    With replacement enabled, the choices are a random sample and the list of
    available personas is left untouched. Otherwise the choices are taken off the
    end of the list, which removes them from the available personas.
    """
    pool = self.personas_list
    num_choices = constants.CURATED_PERSONA_CHOICES
    logging.info(
        f'Randomly choosing {num_choices} personas from {len(pool)} available ones.'
    )
    if not self.persona_replacement:
        taken = []
        for _ in range(num_choices):
            taken.append(pool.pop())
        return taken
    return random.sample(pool, k=num_choices)
+ starting_role = None + for agent_index in range(len(self.agents)): + agent = self.get_agent_order_mask(agent_index) + worker = get_worker_from_agent(agent) + qual = worker.get_granted_qualification(self.role_training_qname) + assert qual + role_qual = qual.value + if role_qual == constants.WIZARD: + agent.agent_id = 'Wizard' + elif role_qual == constants.APPRENTICE: + agent.agent_id = 'Apprentice' + else: + raise ValueError(f'Unrecognized role qulification {role_qual}.') + if not starting_role: # sets it the first time that loop runs + starting_role = role_qual + + logging.info('Agent roles assigned.') + logging.info(f'Agent with {self.get_agent_order_mask(0).agent_id} role starts.') + return starting_role + + def _get_apprentice(self): + if _is_wiz(self.agents[0]): + return self.agents[1] + else: + return self.agents[0] + + def receive_form_response(self, agent: Agent, check_persona_overuse: bool = False): + """ + Extracts the selected persona from the response form and validates it. + """ + + def generate_persona_key(persona_desc): + ret = persona_desc.strip().lower() + for sym in ('.', ',', ';', '!', '?'): + ret = ret.replace(sym, ' ') + return ' '.join([s for s in ret.split(' ') if s]) + + # Repeat asking for persona until having a valid one. 
+ acceptable_response = False + while not acceptable_response: + agent_resp = agent.act(timeout=self.wizard_time_out) + + pers_exp = _form_response_persona_expantion(agent_resp) + + # Too short + if not pers_exp or len(pers_exp) < constants.PERSONA_EXPANSION_MIN_LEN_CHAR: + _send_persona_too_short_warning(agent, pers_exp) + continue + + # Persona was selected before + if check_persona_overuse: + persona_key = generate_persona_key( + _form_response_main_persona(agent_resp) + ) + if self.prev_persona_count[persona_key] >= self.max_times_persona_use: + _send_persona_overuse_warning(agent, persona_key) + continue + self.prev_persona_count[persona_key] += 1 + + acceptable_response = True + return agent_resp + + def _update_curated_personas_use(self, persona: str): + """ + Updates the persona use count. + + Increases the count for the number of times that the selected `persona` was + used, and removes it from available list of personas if it was selected too many + times. + """ + lower_persona = persona.lower() + self.prev_persona_count[lower_persona] += 1 + if self.prev_persona_count[lower_persona] < self.max_times_persona_use: + return + + logging.info(f'Trying to remove "{persona}" from list of personas.') + if len(persona) < constants.CURATED_PERSONA_CHOICES: + logging.warning( + 'Not enough personas may remain after removing, canceling removal.' + ) + return + + self.personas_list.remove(persona) + logging.info( + f'New number of available personas is "{len(self.personas_list)}".' + ) + + def _choose_curated_persona(self): + """ + Asks apprentice to choose a persona from the curated list of personas. 
+ """ + persona_opts = self.sample_personas() + apprentice_agent = self._get_apprentice() + + # Removing PERSONA_NEEDS_LOCATION_TOKEN from what agents will see + persona_opts_views = [ + p.replace(constants.PERSONA_NEEDS_LOCATION_TOKEN, '') for p in persona_opts + ] + persona_selection_form = [ + { + 'type': 'choices', + 'question': 'Choose one of these personas to start:', + 'choices': persona_opts_views, + }, + {'type': 'text', 'question': 'Add something imaginative to refine it:'}, + ] + _coordinator_send_message( + apprentice_agent, + message=constants.APPRENTICE_CHOOSE_CURATED_PERSONA_REQUEST, + task_data={'respond_with_form': persona_selection_form}, + ) + agent_response = self.receive_form_response(apprentice_agent) + + rs = [r['response'] for r in agent_response['task_data']['form_responses']] + assert len(rs) == 2, 'Persona response form length is not 2.' + selected_persona, added_persona = rs + apprentice_persona = f'{selected_persona}\n{added_persona}' + worker_name = self.worker_names[apprentice_agent] + logging.info(f'Agent ({worker_name}) selected a persona: {apprentice_persona}') + + selected_persona_ind = persona_opts_views.index(selected_persona) + # Checking if persona needs location + if constants.PERSONA_NEEDS_LOCATION_TOKEN in persona_opts[selected_persona_ind]: + apprentice_location = self.random_location() + logging.info(f'Persona needs a location. {apprentice_location} selected.') + apprentice_persona = ( + f'I live in {apprentice_location}.\n{apprentice_persona}' + ) + + # Checking if the persona was used too often and needs to be removed. + self._update_curated_personas_use(persona_opts[selected_persona_ind]) + + return apprentice_persona + + def _choose_templated_topics_persona(self): + """ + Asks apprentice to choose a persona using the provided template. 
+ """ + + topic_bundles = random.sample( + constants.TEMPLATE_PERSONAS_TOPICS, k=constants.TEMPLATE_PERSONAS_CHOICES + ) # Each topic bundle is string of comma-seperated related topics, eg. "book,author" + topics = [] + for tb in topic_bundles: + topics.extend(tb.split(',')) + + apprentice_agent = self._get_apprentice() + persona_selection_form = [ + { + 'type': 'choices', + 'question': 'My character\'s favorite ', + 'choices': topics, + }, + {'type': 'text', 'question': 'is '}, + {'type': 'text', 'question': 'Add something imaginative to refine it:'}, + ] + _coordinator_send_message( + apprentice_agent, + message=constants.APPRENTICE_CHOOSE_PERSONA_TEMPLATE_REQUEST, + task_data={'respond_with_form': persona_selection_form}, + ) + agent_response = self.receive_form_response( + apprentice_agent, check_persona_overuse=True + ) + + rs = [r['response'] for r in agent_response['task_data']['form_responses']] + assert len(rs) == 3, 'Template persona response form length is not 3.' + topic, topic_item, extra_details = rs + apprentice_persona = persona_from_template_values( + topic, topic_item, extra_details + ) + worker_name = self.worker_names[apprentice_agent] + logging.info(f'Agent ({worker_name}) selected a persona: {apprentice_persona}') + + return apprentice_persona + + def _reset_to_text_response(self, agent): + """ + Returns Mephisto response from form to text. + """ + _coordinator_send_message(agent=agent, task_data={'respond_with_form': False}) + + def apprentice_choose_persona(self): + """ + Randomly selects a persona selection type (template, curated) and asks agent. 
+ """ + logging.info('Randomly choosing persona selection type.') + choose_from_templates = ( + random.random() < constants.PROBABILITY_CHOOSING_TEMPLATE_PERSONA + ) + if choose_from_templates: + logging.info('Choosing persona persona from template.') + resp = self._choose_templated_topics_persona() + else: + logging.info('Choosing persona persona from curated cases.') + resp = self._choose_curated_persona() + self._reset_to_text_response(self._get_apprentice()) + return resp + + def send_time_length_info(self): + """ + Sends a message to agents informing them about the length of the task (turns, + and timeout). + """ + min_rounds = self.min_num_turns + wiz_time = sec_to_min_pretty(self.wizard_time_out) + app_time = sec_to_min_pretty(self.apprentice_time_out) + for agent in self.agents: + message = f'This conversation continues for at least {min_rounds} rounds.\n' + t = wiz_time if _is_wiz(agent) else app_time + message += ( + f'In your turn, please send your message within {t} minutes. ' + 'Otherwise you may be disqualified. ' + ) + if not _is_wiz(agent): + message += ( + f'Note that you might have to wait up to {wiz_time} ' + 'mintes to receive a response from the other person.' + ) + agent.observe( + { + 'id': constants.COORDINATOR_AGENT, + 'text': message, + 'episode_done': False, + } + ) + + def send_starter_instruction(self, role: int): + """ + Sends a reminder about the role and goals in the beginning of chat. + """ + message_text = None + if role == constants.WIZARD: + message_text = constants.WIZARD_STARTING_INSTRUCTION + else: + assert role == constants.APPRENTICE + message_text = constants.APPRENTICE_STARTING_INSTRUCTION + start_instruction_message = { + 'id': constants.COORDINATOR_AGENT, + 'text': message_text, + 'episode_done': False, + } + self.get_agent_order_mask(0).observe(start_instruction_message) + + def send_wizard_persona_emphasize_message(self): + """ + Sends a message to wizard emphasizing on main goal here (apprentice persona). 
+ """ + for agent in self.agents: + if not _is_wiz(agent): + continue + agent.observe( + { + 'id': constants.COORDINATOR_AGENT, + 'text': constants.WIZARD_PERSONA_EMPHASIZE, + 'episode_done': False, + } + ) + + def setup_roles_and_persona(self): + """ + Prepares the chat environment and states before starting agent interactions. + """ + logging.info('Setting up roles, orders, persona.') + self.end_onboarding_state() + self.broadcast_apprentice_persona('') # clear onboarding persona + starting_role = self.assign_roles() + self.send_wizard_persona_emphasize_message() + self.selected_persona = self.apprentice_choose_persona() + self.broadcast_apprentice_persona(self.selected_persona) + self.send_time_length_info() + self.send_starter_instruction(starting_role) + + def parley(self): + """ + parley process for the agents: running the chat world. + """ + if self.turn_idx == 0: + self.setup_roles_and_persona() + + self.turn_idx += 1 + logging.info( + f'{self.world_tag} is at turn {self.turn_idx}...\n' + f'Wizard has searched {self.num_search_queries} times and ' + f'selected results {self.num_times_search_resutls_selected} times.' + ) + + for idx in range(len(self.agents)): + agent = self.get_agent_order_mask(idx) + act = self.next_utterance(agent) + self.messages.append(deepcopy(act)) + if self.send_task_data: + act.force_set( + 'task_data', + { + 'last_acting_agent': agent.agent_id, + 'current_dialogue_turn': self.turn_idx, + 'utterance_count': self.turn_idx + idx, + }, + ) + + if 'requested_finish' in act and act['requested_finish']: + # One of the agents has requested for end of the chat. + self.episodeDone = True + break + + for other_agent in self.agents: + if other_agent != agent: + other_agent.observe(validate(act)) + + # Reminds wizard about searching and selecting knowledge if needed. + self._send_task_objective_reminders(agent) + + def _reason_to_disqualify(self, agent: Agent): + """ + Determining if agents had low quality work or had unsafe behaviour. 
+ """ + # Disconncet or timeout + mephisto_agent = agent.mephisto_agent + if mephisto_agent.get_status() in ( + AgentState.STATUS_EXPIRED, + AgentState.STATUS_TIMEOUT, + ): + return 'agent was disconnected.' + + # Wizard not using search enough + if agent.agent_id == 'Wizard' and ( + (self.num_search_queries < self.search_warning_threshold) + or (self.num_times_search_resutls_selected < self.select_warning_threshold) + ): + return ( + 'blocked for not enough search activity ' + f'({self.num_search_queries} searches; ' + f'{self.num_times_search_resutls_selected} selected sentecnes).' + ) + + acceptability_checker_results = self.acceptability_checker.check_messages( + agent.agent_id, + self.selected_persona, + messages=self.messages, + is_worker_0=False, + violation_types=constants.ACCEPTABILITY_VIOLATIONS, + ) + if acceptability_checker_results: + return f'ParlAI acceptability checker found violations: "{acceptability_checker_results}"' + + def _soft_block_agent(self, agent): + """ + Softblocking the agent: they can not participate in this task anymore. + """ + worker = get_worker_from_agent(agent) + logging.warning(f'Soft blocking {worker.worker_name}') + worker.grant_qualification(self.soft_block_qname) + + def prep_save_data(self, agent_as_list): + """ + Saving the chat data, after checking its quality and safety. 
+ """ + agent = agent_as_list[0] + agent_id = agent.agent_id + + logging.info(f'Preparing saved data for {agent_id}') + ret = {'agent_id': agent_id, 'message_history_copy': self.messages} + disqualify_reason = self._reason_to_disqualify(agent) + if disqualify_reason: + logging.info(f'Disqualified submission detecetd: "{disqualify_reason}"') + ret['disqualify_reason'] = disqualify_reason + self._soft_block_agent(agent) + + return ret + + def episode_done(self): + return self.episodeDone + + def shutdown(self): + """ + Shutdown all mturk agents in parallel, otherwise if one mturk agent is + disconnected then it could prevent other mturk agents from completing. + """ + global shutdown_agent + + def shutdown_agent(agent): + try: + agent.shutdown(timeout=None) + except Exception: + agent.shutdown() # not MTurkAgent + + Parallel(n_jobs=len(self.agents), backend='threading')( + delayed(shutdown_agent)(agent) for agent in self.agents + ) + + +def onboarding_mode_toggle_message(onboarding_step): + """ + Formats a message to be sent to front-end to detemine the state of onboarding. + """ + return { + 'id': constants.ONBOARDING_AGENT, + 'text': '', + 'episode_done': False, + 'task_data': {'on_boarding_step': onboarding_step}, + } + + +def _get_cached_roll_tally(): + """ + Returns the role tally counts from cache, if the cache is not expired. + """ + utime = ROLE_TALLY_CHACHE['last_update'] + if not utime: + logging.info('Initiated rolls tally cache.') + return None + + dt = time.time() - utime + logging.info(f'The last rolls tally cached {dt:.2f} seconds ago.') + if dt > constants.TALLY_CACHE_TIMEOUT: + logging.info( + 'Rolls tally cache is outdated ' + f'(is greater than {constants.TALLY_CACHE_TIMEOUT} s).' + ) + return None + + logging.info( + 'Rolls tally is fresh enough to use ' + f'(is less than {constants.TALLY_CACHE_TIMEOUT} s).' 
+ ) + return ROLE_TALLY_CHACHE['data'] + + +def _cache_roll_tally(rolls_tally: Dict[int, int]): + """ + Updates the content of the roles tally cache. + """ + logging.info('Setting rolls tally cache.') + ROLE_TALLY_CHACHE['last_update'] = time.time() + ROLE_TALLY_CHACHE['data'] = rolls_tally + + +def find_needed_role(agent, rqname: str): + """ + Determines the role that the agent starting the onboarding needs to go through. + + Checks the number of agents who passed the onboarding and are waiting to be matched, + and the agents who are currently in the onboarding. Based the number of roles in the + pool decides what role a newcoming agent needs to be trained on. + + It caches the recent values of tally to avoid heavy DB queries. + The cache value is handled by ROLE_TALLY_CHACHE global variable. + To control the cache freshness and time out set TALLY_CACHE_TIMEOUT (in seconds). + """ + role_tally = _get_cached_roll_tally() + if not role_tally: + role_tally = {constants.WIZARD: 0, constants.APPRENTICE: 0} + db = agent.mephisto_agent.db + task_run_id = agent.mephisto_agent.task_run_id + agents_need_paring = db.find_onboarding_agents( + status=AgentState.STATUS_ONBOARDING, task_run_id=task_run_id + ) + agents_need_paring.extend( + db.find_agents(status=AgentState.STATUS_WAITING, task_run_id=task_run_id) + ) + + no_qual = 0 + unk_qual = 0 + this_agent_id = agent.mephisto_agent.get_agent_id() + for ag in agents_need_paring: + if ag.get_agent_id() == this_agent_id: + continue + worker = ag.get_worker() + worker_qualification = worker.get_granted_qualification(rqname) + if not worker_qualification: + no_qual += 1 + continue + qstatus = worker_qualification.value + if qstatus in (constants.WIZARD, constants.WIZARD_IN_TRAINING): + role_tally[constants.WIZARD] += 1 + elif qstatus in (constants.APPRENTICE, constants.APPRENTICE_IN_TRAINING): + role_tally[constants.APPRENTICE] += 1 + else: + unk_qual += 1 + if no_qual or unk_qual: + logging.warning( + f'\tNo qualifications: 
{no_qual}\tUnknown qualifications: {unk_qual}' + ) + _cache_roll_tally(role_tally) + + logging.info( + f'Wizard: {role_tally[constants.WIZARD]}\tApprentices: {role_tally[constants.APPRENTICE]}' + ) + if role_tally[constants.WIZARD] > role_tally[constants.APPRENTICE]: + logging.info('Onboarding a new Apprentice.') + role_tally[constants.APPRENTICE] += 1 + return constants.APPRENTICE + else: + logging.info('Onboarding a new Wizard.') + role_tally[constants.WIZARD] += 1 + return constants.WIZARD + + +def make_onboarding_world(opt, agent: Agent): + """ + Assigns agents to apporopraite onboarding worlds to balance the roles. + """ + role_qual_name = opt[constants.ROLE_QUALIFICATION_NAME_KEY] + + def assign_role_based_on_ques(agent): + worker = get_worker_from_agent(agent) + needed_worker = find_needed_role(agent, role_qual_name) + if needed_worker == constants.WIZARD: + worker.grant_qualification(role_qual_name, constants.WIZARD_IN_TRAINING) + return WizardOnboardingWorld(opt, agent) + else: + worker.grant_qualification(role_qual_name, constants.APPRENTICE_IN_TRAINING) + return ApprenticeOnboardingWorld(opt, agent) + + # sends a message to UI to set the onboarding step. + agent.observe( + onboarding_mode_toggle_message(constants.ONBOARDING_STEPS['CHAT_INTERFACE']) + ) + worker_qualification = get_worker_from_agent(agent).get_granted_qualification( + role_qual_name + ) + + if not worker_qualification: # Has not started onboarding before + return assign_role_based_on_ques(agent) + else: # Had been in onboarding but didn't finish + qstatus = worker_qualification.value + if qstatus == constants.WIZARD_IN_TRAINING: + return WizardOnboardingWorld(opt, agent) + elif qstatus == constants.APPRENTICE_IN_TRAINING: + return ApprenticeOnboardingWorld(opt, agent) + else: + logging.warning( + f'Unknown qualification status "{qstatus}" during creating onboarding workds' + + 'Assigning the roles based on waiting and onboarding agents queue size.' 
+ ) + return assign_role_based_on_ques(agent) + + +def assign_role_training_qualification( + worker, role_qulification_name: str, role_qulification_value: int +): + """ + Syncs the training qualification of the agent (worker) with the DB. + """ + if not role_qulification_value or role_qulification_value == constants.NO_ROLE: + logging.warning('Agent did not qualify for a role.') + return False + role_name = constants.ROLE_NAMES[role_qulification_value] + logging.info(f'Agent qulified for {role_name} role. Granting worker qualification.') + worker.grant_qualification(role_qulification_name, role_qulification_value) + return True + + +def validate_onboarding(data: Dict): + """ + Check the contents of the data to ensure they are valid and safe. + """ + try: + saved_data = data['outputs']['messages'][-1]['data']['WORLD_DATA'] + role = ( + 'Wizard' if saved_data[constants.SAVED_DATA_IS_WIZARD_KEY] else 'Apprentice' + ) + logging.info(f'Validating {role} onboarding.') + except (IndexError, KeyError) as e: + logging.warning( + 'Incomplete data to validate agent onboarding.' + f'Onboarding saved_data error: {e}' + ) + return False + + rejection_reason = saved_data[constants.WORKER_REJECT_REASON] + if rejection_reason: + logging.warning(f'Rejected: {rejection_reason}') + return False + + # Role qualification + worker = get_worker_by_name(saved_data[constants.SAVED_DATA_WORKER_KEY]) + qual_name, qual_val = saved_data[constants.SAVED_DATA_ROLE_QUALIFICATION_DATA_KEY] + if not assign_role_training_qualification(worker, qual_name, qual_val): + return False + + logging.info('Onboarding work accepted.') + return True + + +def make_world(opt, agents: Agent): + return MTurkMultiAgentDialogWorld(opt, agents) + + +def get_world_params(): + return {'agent_count': 2}