Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
dirkgr committed Jan 31, 2023
2 parents 84cbcbf + d8c28f2 commit 3d50b9a
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 22 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Fixed

- Fixed the way we compute SQuAD metrics.


## [v0.2.2](https://github.com/allenai/catwalk/releases/tag/v0.2.2) - 2023-01-27

### Changed
Expand Down
38 changes: 19 additions & 19 deletions catwalk/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,25 @@
from catwalk.tasks import TASK_SETS


def main():
initialize_logging(log_level="WARNING")
# Command-line parser for the catwalk entry point. It is built at module level
# (rather than inside `main`) so that callers can run
# `_parser.parse_args([...])` on an explicit argument list and hand the
# resulting namespace to `main(args)`.
# `--model` is the only required option; `-d` and `-w` are two spellings of
# the same option and both store into `args.workspace`.
_parser = argparse.ArgumentParser()
_parser.add_argument('--model', type=str, required=True)
_parser.add_argument('--task', type=str, nargs="+")
_parser.add_argument('--split', type=str)
_parser.add_argument('--batch_size', type=int, default=32)
_parser.add_argument('--num_shots', type=int)
_parser.add_argument('--fewshot_seed', type=int)
_parser.add_argument('--limit', type=int)
_parser.add_argument(
'-d', '-w',
type=str,
default=None,
metavar="workspace",
dest="workspace",
help="the Tango workspace with the cache")


parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, required=True)
parser.add_argument('--task', type=str, nargs="+")
parser.add_argument('--split', type=str)
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--num_shots', type=int)
parser.add_argument('--fewshot_seed', type=int)
parser.add_argument('--limit', type=int)
parser.add_argument(
'-d', '-w',
type=str,
default=None,
metavar="workspace",
dest="workspace",
help="the Tango workspace with the cache")
args = parser.parse_args()
def main(args: argparse.Namespace):
initialize_logging(log_level="WARNING")

if args.workspace is None:
workspace = None
Expand Down Expand Up @@ -71,4 +71,4 @@ def main():


if __name__ == "__main__":
main()
main(_parser.parse_args())
15 changes: 13 additions & 2 deletions catwalk/model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import inspect
from abc import ABC
from copy import deepcopy
from typing import Sequence, Dict, Any, Iterator, Tuple, List, Optional
from typing import Sequence, Dict, Any, Iterator, Tuple, List, Optional, Union

import torch
from tango.common import Registrable, Tqdm
Expand Down Expand Up @@ -41,6 +41,17 @@ def unsqueeze_args(args: Tuple[Any]) -> Tuple[Any, ...]:
return tuple(fixed_args)


# A metric result is either a tensor or an arbitrarily nested dict of tensors.
_TorchmetricsResult = Union[torch.Tensor, Dict[str, '_TorchmetricsResult']]
# The same structure after every tensor has been converted with `.tolist()`.
_CatwalkResult = Union[float, Dict[str, '_CatwalkResult']]


def recursive_tolist(args: _TorchmetricsResult) -> _CatwalkResult:
    """Convert a (possibly nested) metric result into plain Python values.

    Dicts are rebuilt with the same keys and recursively converted values.
    Lists and tuples are converted element-wise and returned as lists.
    Any leaf that exposes a callable ``tolist`` attribute (for example a
    ``torch.Tensor``) is replaced by the result of calling it. Every other
    leaf, such as a value that is already a Python ``float`` or ``int``, is
    returned unchanged instead of raising ``AttributeError``.

    :param args: The value to convert.
    :return: The same structure with every tensor-like leaf converted.
    """
    if isinstance(args, dict):
        return {key: recursive_tolist(value) for key, value in args.items()}
    if isinstance(args, (list, tuple)):
        # A result can also be a sequence of tensors; convert each element.
        return [recursive_tolist(value) for value in args]  # type: ignore
    to_list = getattr(args, "tolist", None)
    if callable(to_list):
        return to_list()
    # Already a plain Python value; there is nothing left to convert.
    return args  # type: ignore


class Model(Registrable, DetHashWithVersion, ABC):
VERSION = "002lst"

Expand All @@ -60,7 +71,7 @@ def calculate_metrics(self, task: Task, predictions: Sequence[Dict[str, Any]]) -
metric_args = unsqueeze_args(metric_args)
metric.update(*metric_args)
return {
metric_name: metric.compute().tolist()
metric_name: recursive_tolist(metric.compute())
for metric_name, metric in metrics.items()
}

Expand Down
4 changes: 3 additions & 1 deletion catwalk/models/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,9 @@ def _predict_qa(
tokenizer: PreTrainedTokenizer,
batch_size: int = 32
) -> Iterator[Dict[str, Any]]:
pipe = QuestionAnsweringPipeline(model=model, tokenizer=tokenizer, device=model.device.index)
# The type annotation for QuestionAnsweringPipeline says `device` has to be an `int`, but the
# implementation also accepts a `torch.device`, so we pass `model.device` through directly.
pipe = QuestionAnsweringPipeline(model=model, tokenizer=tokenizer, device=model.device) # type: ignore

contexts = [instance.context for instance in instances]
questions = [instance.question for instance in instances]
Expand Down
13 changes: 13 additions & 0 deletions tests/test_spotchecks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import sys

import catwalk.__main__


def test_squad():
    """Spot check: run the catwalk entry point on the squad task.

    Parses an explicit argument list with the module-level parser and passes
    the resulting namespace to ``main``. The test has no assertions of its
    own; it passes as long as ``main`` returns without raising.
    """
    argv = [
        "--model", "bert-base-uncased",
        "--task", "squad",
        "--split", "validation",
        "--limit", "100",
    ]
    parsed = catwalk.__main__._parser.parse_args(argv)
    catwalk.__main__.main(parsed)

0 comments on commit 3d50b9a

Please sign in to comment.