Fixing tests for Perceiver (#14739)
* Adding some slow tests to check Perceiver at least at a high level.

* Re-enabling fast tests for Perceiver ImageClassification.

* Perceiver might try to run some text-only pipelines without a Tokenizer (no Fast variant exists) but with a FeatureExtractor.

* Oops.

* Adding a comment for `update_config_with_model_class`.

* Remove `model_architecture` to get `tiny_config`.

* Finalize rebase.

* Smarter way to handle undefined FastTokenizer.

* Remove old code.

* Addressing some nits.

* Don't instantiate `None`.
Narsil authored Dec 14, 2021
1 parent 322d416 commit 546a91a
Showing 4 changed files with 83 additions and 19 deletions.
10 changes: 10 additions & 0 deletions src/transformers/models/perceiver/modeling_perceiver.py
@@ -1268,6 +1268,7 @@ def forward(
         output_hidden_states=None,
         labels=None,
         return_dict=None,
+        pixel_values=None,
     ):
         r"""
         labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1296,6 +1297,10 @@ def forward(
         >>> predicted_class_idx = logits.argmax(-1).item()
         >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
         """
+        if inputs is not None and pixel_values is not None:
+            raise ValueError("You cannot use both `inputs` and `pixel_values`")
+        elif inputs is None and pixel_values is not None:
+            inputs = pixel_values
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         outputs = self.perceiver(
@@ -1399,6 +1404,7 @@ def forward(
         output_hidden_states=None,
         labels=None,
         return_dict=None,
+        pixel_values=None,
     ):
         r"""
         labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1427,6 +1433,10 @@ def forward(
         >>> predicted_class_idx = logits.argmax(-1).item()
         >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
         """
+        if inputs is not None and pixel_values is not None:
+            raise ValueError("You cannot use both `inputs` and `pixel_values`")
+        elif inputs is None and pixel_values is not None:
+            inputs = pixel_values
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         outputs = self.perceiver(
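The two hunks above add identical aliasing logic to both Perceiver image-classification heads. A minimal sketch of the resulting call semantics, assuming the `deepmind/vision-perceiver-learned` checkpoint used by the slow tests below (the 224x224 input size is an assumption based on that checkpoint's preprocessing):

# `pixel_values` now works as a drop-in alias for `inputs`, which is the
# argument name the image-classification pipeline passes to the model.
import torch
from transformers import PerceiverForImageClassificationLearned

model = PerceiverForImageClassificationLearned.from_pretrained("deepmind/vision-perceiver-learned")
pixel_values = torch.rand(1, 3, 224, 224)  # dummy batch standing in for feature-extractor output

outputs = model(pixel_values=pixel_values)  # equivalent to model(inputs=pixel_values)
print(outputs.logits.shape)

# Passing both arguments is now an explicit error instead of a silent choice:
# model(inputs=pixel_values, pixel_values=pixel_values)  raises ValueError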
4 changes: 2 additions & 2 deletions src/transformers/pipelines/__init__.py
@@ -528,8 +528,8 @@ def pipeline(
     load_tokenizer = type(model_config) in TOKENIZER_MAPPING or model_config.tokenizer_class is not None
     load_feature_extractor = type(model_config) in FEATURE_EXTRACTOR_MAPPING or feature_extractor is not None

-    if task in {"audio-classification"}:
-        # Audio classification will never require a tokenizer.
+    if task in {"audio-classification", "image-classification"}:
+        # These will never require a tokenizer.
         # the model on the other hand might have a tokenizer, but
         # the files could be missing from the hub, instead of failing
         # on such repos, we just force to not load it.
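A hedged sketch of the effect (requires network access to the Hub): an image-classification pipeline can now be built from a repo that ships no tokenizer files at all, and no tokenizer is loaded for the task.

# "deepmind/vision-perceiver-conv" is one of the checkpoints exercised by the
# new slow test below; its repo ships no tokenizer files.
from transformers import pipeline

classifier = pipeline("image-classification", model="deepmind/vision-perceiver-conv")
assert classifier.tokenizer is None  # tokenizer loading is forced off for this task
print(classifier("http://images.cocodataset.org/val2017/000000039769.jpg")[0]["label"])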
28 changes: 23 additions & 5 deletions tests/test_pipelines_common.py
@@ -77,12 +77,15 @@ def get_tiny_config_from_class(configuration_class):
     model_tester = model_tester_class(parent=None)

     if hasattr(model_tester, "get_pipeline_config"):
-        return model_tester.get_pipeline_config()
+        config = model_tester.get_pipeline_config()
     elif hasattr(model_tester, "get_config"):
-        return model_tester.get_config()
+        config = model_tester.get_config()
     else:
+        config = None
         logger.warning(f"Model tester {model_tester_class.__name__} has no `get_config()`.")

+    return config
+

 @lru_cache(maxsize=100)
 def get_tiny_tokenizer_from_checkpoint(checkpoint):
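Put differently, a model tester that defines neither hook now yields `config = None` (plus a warning) instead of raising. A tiny, self-contained illustration with hypothetical tester classes:

# Hypothetical testers covering the branches of the lookup above.
class WithConfig:
    def get_config(self):
        return "config"

class WithNeither:
    pass

def tiny_config(tester):
    # Mirrors get_tiny_config_from_class: prefer get_pipeline_config,
    # then get_config, otherwise fall back to None instead of raising.
    if hasattr(tester, "get_pipeline_config"):
        return tester.get_pipeline_config()
    if hasattr(tester, "get_config"):
        return tester.get_config()
    return None

assert tiny_config(WithConfig()) == "config"
assert tiny_config(WithNeither()) is None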
@@ -100,11 +103,17 @@ def get_tiny_tokenizer_from_checkpoint(checkpoint):
     return tokenizer


-def get_tiny_feature_extractor_from_checkpoint(checkpoint, tiny_config):
+def get_tiny_feature_extractor_from_checkpoint(checkpoint, tiny_config, feature_extractor_class):
     try:
         feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
     except Exception:
-        feature_extractor = None
+        try:
+            if feature_extractor_class is not None:
+                feature_extractor = feature_extractor_class()
+            else:
+                feature_extractor = None
+        except Exception:
+            feature_extractor = None
     if hasattr(tiny_config, "image_size") and feature_extractor:
         feature_extractor = feature_extractor.__class__(size=tiny_config.image_size, crop_size=tiny_config.image_size)
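A hedged sketch of the new fallback: when a checkpoint has no preprocessor files on the Hub, the helper now default-constructs the class passed in rather than giving up immediately. `ViTFeatureExtractor` and the checkpoint name are illustrative stand-ins only.

from transformers import AutoFeatureExtractor, ViTFeatureExtractor

def load_or_default(checkpoint, feature_extractor_class):
    # Mirrors the fallback in get_tiny_feature_extractor_from_checkpoint above.
    try:
        feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
    except Exception:
        try:
            if feature_extractor_class is not None:
                feature_extractor = feature_extractor_class()  # defaults, no Hub files needed
            else:
                feature_extractor = None
        except Exception:
            feature_extractor = None
    return feature_extractor

# A text-only repo has no preprocessor_config.json, so this falls back to defaults:
fe = load_or_default("hf-internal-testing/tiny-random-bert", ViTFeatureExtractor)
print(type(fe).__name__)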
@@ -168,7 +177,9 @@ def test(self):
                 self.skipTest(f"Ignoring {ModelClass}, cannot create a simple tokenizer")
             else:
                 tokenizer = None
-            feature_extractor = get_tiny_feature_extractor_from_checkpoint(checkpoint, tiny_config)
+            feature_extractor = get_tiny_feature_extractor_from_checkpoint(
+                checkpoint, tiny_config, feature_extractor_class
+            )

             if tokenizer is None and feature_extractor is None:
                 self.skipTest(
@@ -218,6 +229,13 @@ def data(n):
             if not tokenizer_classes:
                 # We need to test even if there are no tokenizers.
                 tokenizer_classes = [None]
+            else:
+                # Remove the non defined tokenizers
+                # ByT5 and Perceiver are bytes-level and don't define
+                # FastTokenizer, we can just ignore those.
+                tokenizer_classes = [
+                    tokenizer_class for tokenizer_class in tokenizer_classes if tokenizer_class is not None
+                ]

             for tokenizer_class in tokenizer_classes:
                 if tokenizer_class is not None:
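The filter is easiest to see on a toy value: bytes-level models such as ByT5 and Perceiver register `(SlowTokenizer, None)` because no fast tokenizer exists, and the `None` slot must be dropped rather than instantiated.

# Toy illustration of the branch above; PerceiverTokenizer is the real slow
# tokenizer, the None slot stands for the missing fast one.
from transformers import PerceiverTokenizer

tokenizer_classes = [PerceiverTokenizer, None]
if not tokenizer_classes:
    tokenizer_classes = [None]  # still test pipelines that have no tokenizer at all
else:
    tokenizer_classes = [t for t in tokenizer_classes if t is not None]

assert tokenizer_classes == [PerceiverTokenizer]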
60 changes: 48 additions & 12 deletions tests/test_pipelines_image_classification.py
@@ -14,12 +14,7 @@

 import unittest

-from transformers import (
-    MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
-    PerceiverConfig,
-    PreTrainedTokenizer,
-    is_vision_available,
-)
+from transformers import MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, PreTrainedTokenizer, is_vision_available
 from transformers.pipelines import ImageClassificationPipeline, pipeline
 from transformers.testing_utils import (
     is_pipeline_test,
@@ -28,6 +23,7 @@
     require_tf,
     require_torch,
     require_vision,
+    slow,
 )

 from .test_pipelines_common import ANY, PipelineTestCaseMeta
@@ -50,12 +46,7 @@ class ImageClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
     model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING

     def get_test_pipeline(self, model, tokenizer, feature_extractor):
-        if isinstance(model.config, PerceiverConfig):
-            self.skipTest(
-                "Perceiver model tester is defined with a language one, which has no feature_extractor, so the automated test cannot work here"
-            )
-
-        image_classifier = ImageClassificationPipeline(model=model, feature_extractor=feature_extractor)
+        image_classifier = ImageClassificationPipeline(model=model, feature_extractor=feature_extractor, top_k=2)
         examples = [
             Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png"),
             "http://images.cocodataset.org/val2017/000000039769.jpg",
@@ -167,3 +158,48 @@ def test_custom_tokenizer(self):
         image_classifier = pipeline("image-classification", model="lysandre/tiny-vit-random", tokenizer=tokenizer)

         self.assertIs(image_classifier.tokenizer, tokenizer)
+
+    @slow
+    @require_torch
+    def test_perceiver(self):
+        # Perceiver is not tested by `run_pipeline_test` properly.
+        # That is because the type of feature_extractor and model preprocessor need to be kept
+        # in sync, which is not the case in the current design.
+        image_classifier = pipeline("image-classification", model="deepmind/vision-perceiver-conv")
+        outputs = image_classifier("http://images.cocodataset.org/val2017/000000039769.jpg")
+        self.assertEqual(
+            nested_simplify(outputs, decimals=4),
+            [
+                {"score": 0.4385, "label": "tabby, tabby cat"},
+                {"score": 0.321, "label": "tiger cat"},
+                {"score": 0.0502, "label": "Egyptian cat"},
+                {"score": 0.0137, "label": "crib, cot"},
+                {"score": 0.007, "label": "radiator"},
+            ],
+        )
+
+        image_classifier = pipeline("image-classification", model="deepmind/vision-perceiver-fourier")
+        outputs = image_classifier("http://images.cocodataset.org/val2017/000000039769.jpg")
+        self.assertEqual(
+            nested_simplify(outputs, decimals=4),
+            [
+                {"score": 0.5658, "label": "tabby, tabby cat"},
+                {"score": 0.1309, "label": "tiger cat"},
+                {"score": 0.0722, "label": "Egyptian cat"},
+                {"score": 0.0707, "label": "remote control, remote"},
+                {"score": 0.0082, "label": "computer keyboard, keypad"},
+            ],
+        )
+
+        image_classifier = pipeline("image-classification", model="deepmind/vision-perceiver-learned")
+        outputs = image_classifier("http://images.cocodataset.org/val2017/000000039769.jpg")
+        self.assertEqual(
+            nested_simplify(outputs, decimals=4),
+            [
+                {"score": 0.3022, "label": "tabby, tabby cat"},
+                {"score": 0.2362, "label": "Egyptian cat"},
+                {"score": 0.1856, "label": "tiger cat"},
+                {"score": 0.0324, "label": "remote control, remote"},
+                {"score": 0.0096, "label": "quilt, comforter, comfort, puff"},
+            ],
+        )
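The new tests are gated by `@slow`, so they only run when `RUN_SLOW=1` is set, per the usual transformers convention. A sketch of a local invocation:

# Run just the new Perceiver pipeline test with slow tests enabled.
import os
import subprocess

subprocess.run(
    ["python", "-m", "pytest", "-v", "tests/test_pipelines_image_classification.py", "-k", "test_perceiver"],
    env=dict(os.environ, RUN_SLOW="1"),
    check=True,
)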
