diff --git a/CHANGELOG.md b/CHANGELOG.md index c927d92a..d5993ed9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +### Fixed + +- Fixed the way we compute SQuAD metrics. + + ## [v0.2.2](https://github.com/allenai/catwalk/releases/tag/v0.2.2) - 2023-01-27 ### Changed diff --git a/catwalk/__main__.py b/catwalk/__main__.py index 33a1c097..bb0105ee 100644 --- a/catwalk/__main__.py +++ b/catwalk/__main__.py @@ -7,25 +7,25 @@ from catwalk.tasks import TASK_SETS -def main(): - initialize_logging(log_level="WARNING") +_parser = argparse.ArgumentParser() +_parser.add_argument('--model', type=str, required=True) +_parser.add_argument('--task', type=str, nargs="+") +_parser.add_argument('--split', type=str) +_parser.add_argument('--batch_size', type=int, default=32) +_parser.add_argument('--num_shots', type=int) +_parser.add_argument('--fewshot_seed', type=int) +_parser.add_argument('--limit', type=int) +_parser.add_argument( + '-d', '-w', + type=str, + default=None, + metavar="workspace", + dest="workspace", + help="the Tango workspace with the cache") + - parser = argparse.ArgumentParser() - parser.add_argument('--model', type=str, required=True) - parser.add_argument('--task', type=str, nargs="+") - parser.add_argument('--split', type=str) - parser.add_argument('--batch_size', type=int, default=32) - parser.add_argument('--num_shots', type=int) - parser.add_argument('--fewshot_seed', type=int) - parser.add_argument('--limit', type=int) - parser.add_argument( - '-d', '-w', - type=str, - default=None, - metavar="workspace", - dest="workspace", - help="the Tango workspace with the cache") - args = parser.parse_args() +def main(args: argparse.Namespace): + initialize_logging(log_level="WARNING") if args.workspace is None: workspace = None @@ -71,4 +71,4 @@ def main(): if __name__ == "__main__": - main() + main(_parser.parse_args()) diff --git a/catwalk/model.py b/catwalk/model.py index 37d27466..d4fa171e 100644 --- a/catwalk/model.py +++ b/catwalk/model.py @@ -1,7 +1,7 @@ import inspect from abc import ABC from copy import deepcopy -from typing import Sequence, Dict, Any, Iterator, Tuple, List, Optional +from typing import Sequence, Dict, Any, Iterator, Tuple, List, Optional, Union import torch from tango.common import Registrable, Tqdm @@ -41,6 +41,17 @@ def unsqueeze_args(args: Tuple[Any]) -> Tuple[Any, ...]: return tuple(fixed_args) +_TorchmetricsResult = Union[torch.Tensor, Dict[str, '_TorchmetricsResult']] +_CatwalkResult = Union[float, Dict[str, '_CatwalkResult']] + + +def recursive_tolist(args: _TorchmetricsResult) -> _CatwalkResult: + if isinstance(args, dict): + return { key: recursive_tolist(value) for key, value in args.items() } + else: + return args.tolist() + + class Model(Registrable, DetHashWithVersion, ABC): VERSION = "002lst" @@ -60,7 +71,7 @@ def calculate_metrics(self, task: Task, predictions: Sequence[Dict[str, Any]]) - metric_args = unsqueeze_args(metric_args) metric.update(*metric_args) return { - metric_name: metric.compute().tolist() + metric_name: recursive_tolist(metric.compute()) for metric_name, metric in metrics.items() } diff --git a/catwalk/models/huggingface.py b/catwalk/models/huggingface.py index b53318c1..2139d871 100644 --- a/catwalk/models/huggingface.py +++ b/catwalk/models/huggingface.py @@ -79,7 +79,9 @@ def _predict_qa( tokenizer: PreTrainedTokenizer, batch_size: int = 32 ) -> Iterator[Dict[str, Any]]: - pipe = QuestionAnsweringPipeline(model=model, tokenizer=tokenizer, device=model.device.index) + # The type annotation for QuestionAnsweringPipeline says `device` has to be an `int`, but when you look + # at the code, that's not actually correct. + pipe = QuestionAnsweringPipeline(model=model, tokenizer=tokenizer, device=model.device) # type: ignore contexts = [instance.context for instance in instances] questions = [instance.question for instance in instances] diff --git a/tests/test_spotchecks.py b/tests/test_spotchecks.py new file mode 100644 index 00000000..ba245067 --- /dev/null +++ b/tests/test_spotchecks.py @@ -0,0 +1,13 @@ +import sys + +import catwalk.__main__ + + +def test_squad(): + args = catwalk.__main__._parser.parse_args([ + "--model", "bert-base-uncased", + "--task", "squad", + "--split", "validation", + "--limit", "100" + ]) + catwalk.__main__.main(args)