Add ONNX export for ViT #15658
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
tests/test_onnx_v2.py
Outdated
@@ -270,6 +276,7 @@ def _onnx_export(self, test_name, name, model_name, feature, onnx_config_class_c
@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS))
@slow
@require_torch
@require_vision
I added the vision requirement here to test the ViT checkpoint. Please let me know if this isn't a "good practice" because it mixes multiple modalities together
I don't think the vision modality is installed for ONNX tests, so you'd have to double check this actually ends up being tested.
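For reference, these decorators gate tests via skips, so on a runner without the vision extras the ViT case would be skipped silently rather than fail. A minimal sketch, assuming the standard decorators from `transformers.testing_utils`:

```python
# Minimal sketch, assuming the standard decorators from transformers.testing_utils:
# require_vision skips (rather than fails) the test when Pillow is not installed,
# so the ViT case could be silently skipped on a runner without vision extras.
from transformers.testing_utils import require_torch, require_vision, slow


@slow
@require_torch
@require_vision
def test_vit_onnx_export_placeholder():
    # Placeholder body; the real test exports the checkpoint and validates outputs.
    pass
```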
tests/test_onnx_v2.py
Outdated
model = model_class.from_config(config)
onnx_config = onnx_config_class_constructor(model.config)

# Check the modality of the inputs and instantiate the appropriate preprocessor
If this becomes a piece of code we use often, maybe we can refactor this into a function?
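If it does get reused, one possible shape for such a helper (a sketch with hypothetical names, not code from this PR) would be to dispatch on the inputs declared by the ONNX config:

```python
# Hypothetical helper (illustrative only, not part of this PR): choose the right
# preprocessor for a checkpoint based on the inputs declared by its OnnxConfig.
from transformers import AutoFeatureExtractor, AutoTokenizer


def get_preprocessor_for_export(onnx_config, model_name: str):
    # Vision models such as ViT declare "pixel_values" as an ONNX input and
    # therefore need a feature extractor; text models fall back to a tokenizer.
    if "pixel_values" in onnx_config.inputs:
        return AutoFeatureExtractor.from_pretrained(model_name)
    return AutoTokenizer.from_pretrained(model_name)
```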
data = np.random.rand(image_height, image_width, num_channels) * 255
images.append(Image.fromarray(data.astype("uint8")).convert("RGB"))
return images

def generate_dummy_inputs(
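For context, the excerpt above is the tail of a dummy-image generator; a self-contained sketch of that kind of helper (the exact name, signature, and defaults in the PR may differ) looks like this:

```python
# Self-contained sketch of a dummy-image generator like the one excerpted above;
# the exact name, signature, and defaults used in the PR may differ.
from typing import List

import numpy as np
from PIL import Image


def generate_dummy_images(
    batch_size: int = 2, num_channels: int = 3, image_height: int = 40, image_width: int = 40
) -> List[Image.Image]:
    images = []
    for _ in range(batch_size):
        # Random uint8 RGB image with values in [0, 255]
        data = np.random.rand(image_height, image_width, num_channels) * 255
        images.append(Image.fromarray(data.astype("uint8")).convert("RGB"))
    return images
```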
This base method now has a mix of arguments for text and image modalities. I'm not 100% sure if we should split the modalities apart ...
You split it now right? Just checking to make sure.
Thanks for adding this!
Regarding the `tokenizer` optional kwarg, it's very good to keep it like this, but there should be a deprecation warning when it's actually used, and it shouldn't be documented.
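A minimal sketch of the kind of deprecation warning being suggested; the argument names, message wording, and placement here are assumptions rather than the PR's final code:

```python
# Minimal sketch of the suggested deprecation warning; argument names, message
# wording, and placement are assumptions, not the PR's final code.
import warnings


def generate_dummy_inputs(preprocessor=None, tokenizer=None, **kwargs):
    if tokenizer is not None:
        warnings.warn(
            "The `tokenizer` argument is deprecated; pass the tokenizer via `preprocessor` instead.",
            FutureWarning,
        )
        preprocessor = tokenizer
    # ... build and return the dummy inputs using `preprocessor` ...
    return {}
```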
While testing this branch on Colab, I discovered a weird bug when trying to run inference in ONNX Runtime with
Curiously, there is no problem running inference with
Co-authored-by: Sylvain Gugger <[email protected]>
from .utils import ParameterFormat, compute_effective_axis_dimension, compute_serialized_parameters_size


if TYPE_CHECKING:
Since I was already sorting out the relative imports, I also went ahead and fixed the imports that are just used for type checking.
❤️❤️ ❤️ ❤️
DEFAULT_FIXED_SEQUENCE = 8

_TASKS_TO_COMMON_OUTPUTS = {
default_fixed_batch = 2
These class variables are now snake_case to prevent confusion / disaster with global constants
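To illustrate the convention (a sketch, not the file's exact contents): the old module-level constant stays SCREAMING_SNAKE_CASE, while per-config defaults become lowercase class attributes, so the two can no longer be confused.

```python
# Illustration of the naming convention described above (not the exact file contents).

# Old style: a module-level constant, easy to confuse with per-class defaults.
DEFAULT_FIXED_SEQUENCE = 8


class OnnxConfigSketch:
    # New style: lowercase class attributes, clearly scoped to the config class
    # and overridable by subclasses.
    default_fixed_batch = 2
    default_fixed_sequence = 8
```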
Thank you so much for making this file more resilient and less prone to cyclical import errors :-)
if is_torch_available():
    from ..modeling_utils import PreTrainedModel

if is_tf_available():
    from ..modeling_tf_utils import TFPreTrainedModel
Thank you for this 😍!
src/transformers/onnx/config.py
Outdated

from ..feature_extraction_utils import FeatureExtractionMixin
from ..file_utils import TensorType, is_torch_available, is_vision_available
from ..tokenization_utils_base import PreTrainedTokenizerBase
Last step: since this file is imported at a very low level, it would be great to import those (PreTrainedTokenizerBase and FeatureExtractionMixin) under TYPE_CHECKING (for type checks) and then only dynamically when we do the instance check.
Sounds good!
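A sketch of that pattern, shown here with absolute imports for self-containment (inside the package these would be relative imports):

```python
# Sketch of the TYPE_CHECKING pattern described above, using absolute imports for
# self-containment; inside transformers these would be relative imports.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by static type checkers, so it cannot trigger circular imports.
    from transformers.tokenization_utils_base import PreTrainedTokenizerBase


def uses_a_tokenizer(preprocessor: "PreTrainedTokenizerBase") -> bool:
    # Import lazily, only at the point where the runtime isinstance check happens.
    from transformers.tokenization_utils_base import PreTrainedTokenizerBase

    return isinstance(preprocessor, PreTrainedTokenizerBase)
```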
.gitignore
Outdated
scratch/
Not sure if we want to add this in the general gitignore of Transformers?
Oops! Will fix that!
Fixed :)
Awesome work @lewtun !
@@ -326,7 +401,7 @@ def num_attention_heads(self) -> int:

def generate_dummy_inputs(
    self,
    tokenizer: PreTrainedTokenizer,
    tokenizer: "PreTrainedTokenizerBase",
Why?
Are you asking about the change to `PreTrainedTokenizerBase` or the use of strings for the typing? Here are the reasons in both cases:
- I chose `PreTrainedTokenizerBase` because it covers both slow and fast tokenizers. The alternative would have been something like `Union[PreTrainedTokenizer, PreTrainedTokenizerFast]`, but that felt clunky.
- I used strings for the typing following @sgugger's suggestion to use the `TYPE_CHECKING` constant to fix the circular imports.
I was asking about the change of class, and it makes sense to me now, thanks for the explanation!
@@ -178,6 +196,21 @@ def atol_for_validation(self) -> float:
    """
    return 1e-5

@property
def is_torch_support_available(self) -> bool:
For torch.fx we have a requirement on a specific torch version. If you have validated that it doesn't work with a specific torch version, I would see no problem in printing a warning mentioning exactly that. If it's going to fail, then raising an error is also fine.
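A sketch of that kind of version gate; the minimum version and the message below are placeholders, not the PR's actual values:

```python
# Sketch of a torch version gate with a warning; the minimum version and the
# message are placeholders, not the values used in the PR.
import logging

import torch
from packaging import version

logger = logging.getLogger(__name__)

MIN_TORCH_VERSION = version.parse("1.8.0")  # placeholder


def is_torch_support_available() -> bool:
    supported = version.parse(torch.__version__) >= MIN_TORCH_VERSION
    if not supported:
        logger.warning(
            f"ONNX export requires torch >= {MIN_TORCH_VERSION}, found {torch.__version__}; "
            "the export may fail or produce an invalid graph."
        )
    return supported
```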
Super happy to see this merged! 🤗
What does this PR do?
This PR enables the export of Vision Transformers (ViT) to ONNX with the following features:
- `default`
- `image-classification`

To enable this new modality, I had to significantly refactor the internals of the ONNX exporter because we need a way to pass the feature extractor instead of the tokenizer.
Thanks to a tip from @LysandreJik I replaced the positional `tokenizer` argument in various functions with a new `preprocessor` argument that can be a tokenizer or feature extractor (and possibly a processor in future). This should guarantee backwards compatibility for users who chose to use the Python API instead of the `transformers.onnx` CLI.

Usage
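A hedged sketch of what the Python API could look like after this PR for a ViT checkpoint; the import paths, keyword names, and checkpoint below are assumptions and may differ from the merged code:

```python
# Hedged sketch of the Python export API after this PR; import paths, keyword
# names, and the checkpoint are assumptions and may differ from the merged code.
from pathlib import Path

from transformers import AutoFeatureExtractor, AutoModelForImageClassification
from transformers.models.vit import ViTOnnxConfig
from transformers.onnx import export

checkpoint = "google/vit-base-patch16-224"  # example checkpoint
feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
model = AutoModelForImageClassification.from_pretrained(checkpoint)

onnx_config = ViTOnnxConfig(model.config, task="image-classification")
onnx_inputs, onnx_outputs = export(
    preprocessor=feature_extractor,  # a feature extractor instead of a tokenizer
    model=model,
    config=onnx_config,
    opset=onnx_config.default_onnx_opset,
    output=Path("vit.onnx"),
)
```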
Here are two Colab notebooks comparing the inference gains with ORT vs vanilla PyTorch (~20-30% faster on CPU, ~5% faster on GPU):
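Not the notebooks themselves, but roughly the kind of ONNX Runtime inference being compared there (the checkpoint, file paths, and output layout are assumptions):

```python
# Rough sketch of ONNX Runtime inference with an exported ViT model; the
# checkpoint, file paths, and output layout are assumptions.
import numpy as np
from onnxruntime import InferenceSession
from PIL import Image

from transformers import AutoFeatureExtractor

feature_extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
session = InferenceSession("vit.onnx")

image = Image.open("example.jpg")
inputs = feature_extractor(images=image, return_tensors="np")

# The exported graph takes `pixel_values`; output names depend on the chosen feature.
(logits,) = session.run(None, {"pixel_values": inputs["pixel_values"]})
predicted_class = int(np.argmax(logits, axis=-1)[0])
```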
Todo
- `tokenizer` as keyword argument