
Add ONNX export for ViT #15658

Merged: lewtun merged 38 commits into master from vision-onnx-export on Mar 9, 2022

Conversation

lewtun
Member

@lewtun lewtun commented Feb 15, 2022

What does this PR do?

This PR enables the export of Vision Transformers (ViT) to ONNX with the following features:

  • default
  • image-classification

To enable this new modality, I had to significantly refactor the internals of the ONNX exporter because we need a way to pass the feature extractor instead of the tokenizer.

Thanks to a tip from @LysandreJik, I replaced the positional tokenizer argument in various functions with a new preprocessor argument that can be a tokenizer or a feature extractor (and possibly a processor in the future). This should guarantee backwards compatibility for users who choose the Python API instead of the transformers.onnx CLI (a sketch of that path follows the usage example below).

Usage

import requests
import numpy as np
from PIL import Image
from onnxruntime import InferenceSession
from transformers import AutoConfig, AutoFeatureExtractor, AutoModelForImageClassification

# Export ViT checkpoint with image classification head
model_ckpt = "google/vit-base-patch16-224"
!python -m transformers.onnx --model={model_ckpt} --feature=image-classification onnx/

# Download an image of two cute cats - naturally ;-)
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# Instantiate config and feature extractor
config = AutoConfig.from_pretrained(model_ckpt)
feature_extractor = AutoFeatureExtractor.from_pretrained(model_ckpt)
inputs = feature_extractor(image, return_tensors="np")

# Create ONNX Runtime session
session = InferenceSession("onnx/model.onnx", providers=["CPUExecutionProvider"])
outputs = session.run(["logits"], dict(inputs))
predicted_class_idx = np.argmax(outputs[0])
# Returns Predicted class: Egyptian cat
print("Predicted class:", config.id2label[predicted_class_idx])

Here are two Colab notebooks comparing the inference gains with ORT vs. vanilla PyTorch (~20-30% faster on CPU, ~5% faster on GPU):

Todo

  • Add deprecation warning if user passes tokenizer as keyword argument (a rough sketch follows this list)
  • Run an inference test to see if we get any speed-up over vanilla PyTorch (maybe)
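
For the deprecation item, something like the following could work (illustrative only, not the final implementation):

import warnings

def export(preprocessor=None, *args, tokenizer=None, **kwargs):
    # Keep accepting the old `tokenizer` keyword for backwards compatibility,
    # but warn that it is deprecated in favour of `preprocessor`.
    if tokenizer is not None:
        warnings.warn(
            "The `tokenizer` argument is deprecated; pass `preprocessor` instead.",
            FutureWarning,
        )
        if preprocessor is None:
            preprocessor = tokenizer
    # ... rest of the export logic stays the same ...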

@HuggingFaceDocBuilder
Copy link

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@lewtun changed the title from "Add ONNX export for vision models" to "Add ONNX export for ViT" on Feb 15, 2022
@@ -270,6 +276,7 @@ def _onnx_export(self, test_name, name, model_name, feature, onnx_config_class_c
@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS))
@slow
@require_torch
@require_vision
Member Author

I added the vision requirement here to test the ViT checkpoint. Please let me know if this isn't "good practice" because it mixes multiple modalities together.

Collaborator

I don't think the vision modality is installed for ONNX tests, so you'd have to double check this actually ends up being tested.

model = model_class.from_config(config)
onnx_config = onnx_config_class_constructor(model.config)

# Check the modality of the inputs and instantiate the appropriate preprocessor
Member

If this becomes a piece of code we use often, maybe we can refactor this into a function?
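
Something along these lines could work as a shared helper (hypothetical names, just to illustrate the idea):

from transformers import AutoFeatureExtractor, AutoTokenizer

def get_preprocessor(model_name, onnx_config):
    # Pick the preprocessor based on the inputs the ONNX config expects:
    # vision models take pixel_values, text models take token ids.
    if "pixel_values" in onnx_config.inputs:
        return AutoFeatureExtractor.from_pretrained(model_name)
    return AutoTokenizer.from_pretrained(model_name)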

data = np.random.rand(image_height, image_width, num_channels) * 255
images.append(Image.fromarray(data.astype("uint8")).convert("RGB"))
return images

def generate_dummy_inputs(
Member Author

This base method now has a mix of arguments for text and image modalities. I'm not 100% sure if we should split the modalities apart ...

Member

You split it now right? Just checking to make sure.

Collaborator

@sgugger sgugger left a comment

Thanks for adding this!

Regarding the tokenizer optional kwarg, it's very good to keep it like this, but there should be a deprecation warning when it's actually used, and it shouldn't be documented.


@lewtun
Member Author

lewtun commented Feb 17, 2022

While testing this branch on Colab, I discovered a weird bug when trying to run inference in ONNX Runtime with torch v1.10.2:

RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Reshape node. Name:'Reshape_42' Status Message: /Users/runner/work/1/s/onnxruntime/core/providers/cpu/tensor/reshape_helper.h:42 onnxruntime::ReshapeHelper::ReshapeHelper(const onnxruntime::TensorShape &, std::vector<int64_t> &, bool) gsl::narrow_cast<int64_t>(input_shape.Size()) == size was false. The input tensor cannot be reshaped to the requested shape. Input shape:{1,197,768}, requested shape:{2,197,12,64}

Curiously, there is no problem running inference with torch v1.9, so something seems to have changed in the torch ONNX exporter in the latest version. I'm currently investigating what the source of the problem is ...

from .utils import ParameterFormat, compute_effective_axis_dimension, compute_serialized_parameters_size


if TYPE_CHECKING:
Member Author

Since I was already sorting out the relative imports, I also went ahead and fixed the imports that are just used for type checking.

Collaborator

❤️❤️ ❤️ ❤️

DEFAULT_FIXED_SEQUENCE = 8

_TASKS_TO_COMMON_OUTPUTS = {
default_fixed_batch = 2
Member Author

These class variables are now snake_case to prevent confusion / disaster with global constants

Collaborator

@sgugger sgugger left a comment

Thank you so much for making this file more resilient and less prone to cyclical import errors :-)

Comment on lines +29 to +33
if is_torch_available():
from ..modeling_utils import PreTrainedModel

if is_tf_available():
from ..modeling_tf_utils import TFPreTrainedModel
Collaborator

Thank you for this 😍!


from ..feature_extraction_utils import FeatureExtractionMixin
from ..file_utils import TensorType, is_torch_available, is_vision_available
from ..tokenization_utils_base import PreTrainedTokenizerBase
Collaborator

Last step: since this file is imported at a very low level, it would be great to import those (PreTrainedTokenizerBase and FeatureExtractionMixin) under TYPE_CHECKING (for the type checks) and then only import them dynamically when we do the instance check.
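
A minimal sketch of that pattern in config.py (assuming the instance check lives in the generate_dummy_inputs method):

from typing import TYPE_CHECKING, Union

if TYPE_CHECKING:
    # Only evaluated by static type checkers, never at runtime
    from ..feature_extraction_utils import FeatureExtractionMixin
    from ..tokenization_utils_base import PreTrainedTokenizerBase

def generate_dummy_inputs(self, preprocessor: Union["PreTrainedTokenizerBase", "FeatureExtractionMixin"]):
    # Import lazily, right where the runtime isinstance check happens
    from ..feature_extraction_utils import FeatureExtractionMixin

    if isinstance(preprocessor, FeatureExtractionMixin):
        ...  # build image dummy inputs
    else:
        ...  # build text dummy inputs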

Member Author

Sounds good!

.gitignore Outdated
Comment on lines 168 to 169
scratch/
Contributor

Not sure if we want to add this to the general .gitignore of Transformers?

Member Author

Oops! Will fix that!

Member Author

Fixed :)

Member

@michaelbenayoun michaelbenayoun left a comment

Awesome work @lewtun !


@@ -326,7 +401,7 @@ def num_attention_heads(self) -> int:

def generate_dummy_inputs(
self,
- tokenizer: PreTrainedTokenizer,
+ tokenizer: "PreTrainedTokenizerBase",
Member

Why?

Member Author

Are you asking about the change to PreTrainedTokenizerBase or the use of strings for the typing? Here are the reasons in both cases:

  • I chose PreTrainedTokenizerBase because it covers both slow and fast tokenizers. The alternative would have been something like Union[PreTrainedTokenizer, PreTrainedTokenizerFast], but that felt clunky
  • I used strings for the typing following @sgugger's suggestion to use the TYPE_CHECKING constant to fix the circular imports

Member

I was asking about the change of class, and it makes sense to me now, thanks for the explanation!

Member

@LysandreJik LysandreJik left a comment

Looks good! Thanks @lewtun for iterating and @sgugger for the great reviews!

@@ -178,6 +196,21 @@ def atol_for_validation(self) -> float:
"""
return 1e-5

@property
def is_torch_support_available(self) -> bool:
Member

For torch.fx we have a requirement on a specific torch version. If you have validated that it doesn't work with a specific torch version, I would see no problem in printing a warning mentioning exactly that. If it's going to fail, then raising an error is also fine.
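
For reference, a hedged sketch of what that property could check (the 1.11 floor is an assumption based on the reshape bug seen with torch 1.10.x earlier in this thread, not the final implementation):

from packaging import version

from ..file_utils import is_torch_available

@property
def is_torch_support_available(self) -> bool:
    # Compare the installed torch version against the minimum version known
    # to export this architecture correctly
    if not is_torch_available():
        return False

    import torch

    return version.parse(torch.__version__) >= version.parse("1.11")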

@lewtun lewtun merged commit 50dd314 into master Mar 9, 2022
@lewtun lewtun deleted the vision-onnx-export branch March 9, 2022 16:37
@davanstrien
Member

Super happy to see this merged! 🤗
