Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
data compiler (#4034)
Browse files Browse the repository at this point in the history
  • Loading branch information
mojtaba-komeili authored Sep 22, 2021
1 parent 6b88843 commit edb19d3
Show file tree
Hide file tree
Showing 2 changed files with 196 additions and 0 deletions.
10 changes: 10 additions & 0 deletions parlai/crowdsourcing/projects/wizard_of_internet/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,13 @@ You need to have a functional search server running, and sets its address in `se
This server responds to the search requests sent by the worker who takes *wizard* role during this task:
It receieves a json with two keys: `q` and `n`, which are a string that is the search query, and an integer that is the number of pages to return, respectively.
It sends its response also as a json under a key named `response` which has a list of documents retrieved for the received search query. Each document is a mapping (dictionary) of *string->string* with at least 3 fields: `url`, `title`, and `content` (see [SearchEngineRetriever](https://github.com/facebookresearch/ParlAI/blob/70ee4a2c63008774fc9e66a8392847554920a14d/parlai/agents/rag/retrieve_api.py#L73) for more info on how this task interacts with the search server).

## Creating the dataset

Having collected data from crowdsourcing task, you may use `compile_resullts.py` to create your dataset, as a json file.
For example, if you called your task `wizard-of-internet` (you set this name in the config file that you ran with your task from `hydra_config`),
the following code creates your dataset as a json file in the directory specified by `--output-folder` flag:

```.python
python compile_results.py --task-name wizard-of-internet --output-folder=/dataset/wizard-internet
```
186 changes: 186 additions & 0 deletions parlai/crowdsourcing/projects/wizard_of_internet/compile_results.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Compiles the final dataset (a json file) from the this Mephisto crowdsourcing task.
Example use:
python compile_results.py --task-name wizard-of-internet --output-folder=/dataset/wizard-internet
"""

from typing import Dict, Union
import parlai.utils.logging as logging
from parlai.crowdsourcing.projects.wizard_of_internet.wizard_internet_blueprint import ( # noqa: F401
WIZARD_INTERNET_PARLAICHAT_BLUEPRINT,
)
from parlai.crowdsourcing.utils.analysis import AbstractResultsCompiler
from mephisto.abstractions.blueprint import AgentState

# Roles of the task
WIZARD = 'Wizard'
APPRENTICE = 'Apprentice'
SEARCH_AGENT = 'SearchAgent'

# Constant names with many reuse
CONTENTS = 'contents'
SELECTED_CONTENTS = 'selected_contents'
DIALOG_HISTORY_KEY = 'dialog_history'


def is_persona_form_response(message):
k1 = 'task_data'
k2 = 'form_responses'
return message and k1 in message and k2 in message[k1]


def get_human_sender(message):
sender = message['id']
if sender in (WIZARD, APPRENTICE, SEARCH_AGENT):
return sender


def remove_keys_from_dict(dictionary_data, keys):
for rm_key in keys:
del dictionary_data[rm_key]
return dictionary_data


def chat_interruption(message):
for k in ('requested_finish', 'MEPHISTO_is_submit'):
if k in message and message[k]:
return True
return False


class ChatMessage:
"""
Container for keeping the content of one interaction between agents.
"""

def __init__(self, message_dict: dict) -> None:
self._message = message_dict
self._sender = get_human_sender(self._message)

def _format_action(self, receiver):
return f'{self._sender} => {receiver}'

def get_action(self):
if self._sender == APPRENTICE:
return self._format_action(WIZARD)

if self._sender == SEARCH_AGENT:
return self._format_action(WIZARD)

# Must be WIZARD
k = 'is_search_query'
if k in self._message and self._message[k]:
return self._format_action(SEARCH_AGENT)
else:
return self._format_action(APPRENTICE)

def get_text(self):
if self._sender == SEARCH_AGENT:
return ''
return self._message.get('text', '')

def get_context(self):
if self._sender == APPRENTICE:
return {}

context_data = self._message.get('task_data', None)

if self._sender == WIZARD:
if not context_data:
return {}
else:
if 'form_responses' in context_data:
return {'Persona': context_data['form_responses'][0]['response']}
return {
CONTENTS: context_data.get('text_candidates', ''),
SELECTED_CONTENTS: context_data.get('selected_text_candaidtes', ''),
}
# Must be SEARCH_AGENT
return {CONTENTS: context_data['search_results']}

def compile_message(self):
d = dict()
d['action'] = self.get_action()
d['text'] = self.get_text()
d['context'] = self.get_context()
return d


class WizardOfInternetResultsCompiler(AbstractResultsCompiler):
"""
Compiles the results of Wizard of Internet crowdsourcing task into a json dataset.
"""

def is_unit_acceptable(self, unit_data):
# Depending on the situation (in practice) we may be able to salvage incomplete data.
# Here, we only keep completed and approved ones (discarding the rest).
return unit_data['status'] in (
AgentState.STATUS_ACCEPTED,
AgentState.STATUS_APPROVED,
)

def format_chat_data(self, dialog_history):
data_dict = {'apprentice_persona': '', DIALOG_HISTORY_KEY: []}

for message in dialog_history:
if message['id'] == 'PersonaAgent':
data_dict['apprentice_persona'] = message['task_data'][
'apprentice_persona'
]
continue

if is_persona_form_response(message) or not get_human_sender(message):
continue

if chat_interruption(message):
# There was interruption (could be a clean Finish)
# Ignoring the rest of messages
break

else:
compiled_message = ChatMessage(message).compile_message()
data_dict[DIALOG_HISTORY_KEY].append(compiled_message)

return data_dict

def compile_results(self) -> Dict[str, Dict[str, Union[dict, str]]]:

logging.info('Retrieving task data from Mephisto.')
task_units_data = self.get_task_data()
logging.info(f'Data for {len(task_units_data)} units loaded successfully.')

results = dict()
for work_unit in task_units_data:
assignment_id = work_unit['assignment_id']

data = work_unit['data']
agent_name = data['agent_name']
if agent_name == 'Wizard':
# Only collecting Wizard side data. Because it contains
# data that is needed from the apprentice side too.
formatted_chat_data = self.format_chat_data(data['messages'])
results[assignment_id] = formatted_chat_data

logging.info(f'{len(results)} dialogues compiled.')
return results


if __name__ == '__main__':
parser_ = WizardOfInternetResultsCompiler.setup_args()
args = parser_.parse_args()
opt = {
'task_name': args.task_name,
'results_format': 'json',
'output_folder': args.output_folder,
}
wizard_data_compiler = WizardOfInternetResultsCompiler(opt)
wizard_data_compiler.compile_and_save_results()

0 comments on commit edb19d3

Please sign in to comment.