Add siqa and csqa tasks #158

Merged (1 commit) on Dec 6, 2023
4 changes: 4 additions & 0 deletions catwalk/dependencies/lm_eval/tasks/__init__.py
@@ -55,6 +55,8 @@
from . import eurlex
from . import unfair_tos
from . import casehold
from . import siqa
from . import csqa

########################################
# Translation tasks
@@ -295,6 +297,8 @@
"blimp_wh_vs_that_no_gap_long_distance": blimp.BlimpWhVsThatNoGapLongDistance,
"blimp_wh_vs_that_with_gap": blimp.BlimpWhVsThatWithGap,
"blimp_wh_vs_that_with_gap_long_distance": blimp.BlimpWhVsThatWithGapLongDistance,
"social_iqa": siqa.SocialIQA,
"csqa": csqa.CommonsenseQA,
# Requires manual download of data.
# "storycloze_2016": storycloze.StoryCloze2016,
# "storycloze_2018": storycloze.StoryCloze2018,
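For reference, the registrations above are what make the new tasks discoverable through lm_eval's task lookup. A minimal sketch, assuming the vendored copy keeps the upstream TASK_REGISTRY dict and get_task() helper (not verified against this exact revision):

from catwalk.dependencies.lm_eval import tasks as lm_eval_tasks

# Assumes the vendored lm_eval mirrors upstream: TASK_REGISTRY maps registered
# names to task classes, and get_task() returns the class for a given name.
siqa_cls = lm_eval_tasks.TASK_REGISTRY["social_iqa"]  # siqa.SocialIQA
csqa_cls = lm_eval_tasks.get_task("csqa")             # csqa.CommonsenseQA

# Instantiating a task loads the HF dataset named by DATASET_PATH.
csqa_task = csqa_cls()
print(csqa_task.has_test_docs())  # False: only train and validation are wired up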
69 changes: 69 additions & 0 deletions catwalk/dependencies/lm_eval/tasks/csqa.py
@@ -0,0 +1,69 @@
"""
Commonsense QA
"""
from catwalk.dependencies.lm_eval.base import MultipleChoiceTask


_CITATION = """
@inproceedings{talmor-etal-2019-commonsenseqa,
title = "{C}ommonsense{QA}: A Question Answering Challenge Targeting Commonsense Knowledge",
author = "Talmor, Alon and
Herzig, Jonathan and
Lourie, Nicholas and
Berant, Jonathan",
editor = "Burstein, Jill and
Doran, Christy and
Solorio, Thamar",
booktitle = "Proceedings of the 2019 Conference of the North {A}merican Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers)",
month = jun,
year = "2019",
address = "Minneapolis, Minnesota",
publisher = "Association for Computational Linguistics",
url = "https://aclanthology.org/N19-1421",
doi = "10.18653/v1/N19-1421",
pages = "4149--4158"
}
"""


class CommonsenseQA(MultipleChoiceTask):
VERSION = 0
DATASET_PATH = "commonsense_qa"

def has_training_docs(self):
return True

def has_validation_docs(self):
return True

def has_test_docs(self):
return False

def training_docs(self):
if self._training_docs is None:
self._training_docs = list(map(self._process_doc, self.dataset["train"]))
return self._training_docs

def validation_docs(self):
return map(self._process_doc, self.dataset["validation"])

def test_docs(self):
return NotImplemented

def _process_doc(self, doc):
out_doc = {
"id": doc["id"],
"query": "Question: " + doc["question"] + "\nAnswer:",
"choices": doc["choices"]["text"],
"gold": ["A", "B", "C", "D", "E"].index(doc["answerKey"]),
}
return out_doc

def doc_to_text(self, doc):
return doc["query"]

def should_decontaminate(self):
return True

def doc_to_decontamination_query(self, doc):
return doc["query"]
73 changes: 73 additions & 0 deletions catwalk/dependencies/lm_eval/tasks/siqa.py
@@ -0,0 +1,73 @@
"""
Social IQA
"""
from catwalk.dependencies.lm_eval.base import MultipleChoiceTask


_CITATION = """
@inproceedings{sap-etal-2019-social,
title = "Social {IQ}a: Commonsense Reasoning about Social Interactions",
author = "Sap, Maarten and
Rashkin, Hannah and
Chen, Derek and
Le Bras, Ronan and
Choi, Yejin",
editor = "Inui, Kentaro and
Jiang, Jing and
Ng, Vincent and
Wan, Xiaojun",
booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP)",
month = nov,
year = "2019",
address = "Hong Kong, China",
publisher = "Association for Computational Linguistics",
url = "https://aclanthology.org/D19-1454",
doi = "10.18653/v1/D19-1454",
pages = "4463--4473"
}
"""


class SocialIQA(MultipleChoiceTask):
VERSION = 0
DATASET_PATH = "social_i_qa"

def has_training_docs(self):
return True

def has_validation_docs(self):
return True

def has_test_docs(self):
return False

def training_docs(self):
if self._training_docs is None:
self._training_docs = list(map(self._process_doc, self.dataset["train"]))
return self._training_docs

def validation_docs(self):
return map(self._process_doc, self.dataset["validation"])

def test_docs(self):
return NotImplemented

def _process_doc(self, doc):
out_doc = {
"query": "Question: " + doc["context"] + " " + doc["question"] + "\nAnswer:",
"choices": [doc["answerA"], doc["answerB"], doc["answerC"]],
"gold": int(doc["label"]) - 1,
}
return out_doc

def doc_to_text(self, doc):
return doc["query"]

def should_decontaminate(self):
return True

def doc_to_decontamination_query(self, doc):
return doc["query"]



2 changes: 2 additions & 0 deletions catwalk/tasks/tasks_lm.py
@@ -63,4 +63,6 @@
"mathqa": EleutherTask("mathqa", ranked_classification=True).add_metrics(rc_metrics(primary="acc_per_token")),
#"wsc273": EleutherTask("wsc273", ranked_classification=True).add_metrics(ENTAILMENT_METRICS),
"winogrande": EleutherTask("winogrande", ranked_classification=True).add_metrics(rc_metrics(primary="acc_per_token")),
"social_iqa": EleutherTask("social_iqa", ranked_classification=True).add_metrics(rc_metrics(primary="acc_uncond")),
"csqa": EleutherTask("csqa", ranked_classification=True).add_metrics(rc_metrics(primary="acc_uncond"))
}
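Both catwalk entries score the tasks as ranked classification with acc_uncond as the primary metric, presumably accuracy with each choice's score normalized by its unconditional likelihood. A hedged lookup sketch, assuming the dict edited above is named TASKS_LM (the variable name is not visible in this hunk):

# Sketch only; assumes the dict above is TASKS_LM in catwalk/tasks/tasks_lm.py.
from catwalk.tasks.tasks_lm import TASKS_LM

for name in ("social_iqa", "csqa"):
    task = TASKS_LM[name]
    print(name, task)  # an EleutherTask configured for ranked classification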