Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
Crowd analysis (#3935)
Browse files Browse the repository at this point in the history
* minor cleanup

* minor cleanup

* accpt filter
  • Loading branch information
mojtaba-komeili authored Aug 13, 2021
1 parent cdc8c05 commit 6a56b5c
Showing 1 changed file with 34 additions and 12 deletions.
46 changes: 34 additions & 12 deletions parlai/crowdsourcing/utils/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import pandas as pd

from parlai.core.opt import Opt
import parlai.utils.logging as logging

# Defining the class only if Mephisto is installed, since it relies on Mephisto
Expand Down Expand Up @@ -47,7 +48,7 @@ def setup_args(cls):
)
return parser

def __init__(self, opt: Dict[str, Any]):
def __init__(self, opt: Opt):
self.output_folder = opt.get('output_folder')
self.results_format = opt['results_format']

Expand All @@ -61,6 +62,19 @@ def get_results_path_base(self) -> str:
f'{self.__class__.__name__}__{now.strftime("%Y%m%d_%H%M%S")}',
)

def unit_acceptable(self, unit_data: Dict[str, Any]) -> bool:
"""
Helps filtering units that are compiled. Override for use.
Returning False means that the unit data will be discarded.
"""
if not unit_data:
# Add your task-specific qualificaiton logic that justifies
# discarding this unit, based on it data content.
return False

return True

@abstractmethod
def compile_results(self) -> pd.DataFrame:
"""
Expand Down Expand Up @@ -114,7 +128,7 @@ def setup_args(cls):
)
return parser

def __init__(self, opt: Dict[str, Any]):
def __init__(self, opt: Opt):

super().__init__(opt)

Expand Down Expand Up @@ -154,9 +168,9 @@ def setup_args(cls):
)
return parser

def __init__(self, opt: Dict[str, Any]):
def __init__(self, opt: Opt):
super().__init__(opt)
self.task_name = opt["task_name"]
self.task_name = opt['task_name']
self._mephisto_db = None
self._mephisto_data_browser = None

Expand Down Expand Up @@ -185,18 +199,26 @@ def get_task_units(self, task_name: str) -> List[Unit]:
data_browser = self.get_mephisto_data_browser()
return data_browser.get_units_for_task_name(task_name)

def get_units_data(self, task_units: List[Unit]) -> List[dict]:
def get_data_from_unit(self, unit: Unit) -> Dict[str, Any]:
"""
Retrieves task data for a single unit.
"""
try:
data_browser = self.get_mephisto_data_browser()
return data_browser.get_data_from_unit(unit)
except (IndexError, AssertionError):
logging.warning(
f'Skipping unit {unit.db_id}. No message found for this unit.'
)

def get_units_data(self, task_units: List[Unit]) -> List[Dict[str, Any]]:
"""
Retrieves task data for a list of Mephisto task units.
"""
data_browser = self.get_mephisto_data_browser()
task_data = []
for unit in task_units:
try:
unit_data = data_browser.get_data_from_unit(unit)
unit_data = self.get_data_from_unit(unit)
if unit_data and self.unit_acceptable(unit_data):
task_data.append(unit_data)
except (IndexError, AssertionError):
logging.warning(
f"Skipping unit {unit.db_id}. No message found for this unit."
)

return task_data

0 comments on commit 6a56b5c

Please sign in to comment.