diff --git a/.circleci/config.yml b/.circleci/config.yml
index 7cd25f75b3c3..47ff2c6f10c5 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -783,7 +783,7 @@ jobs:
             - v0.4-torch-{{ checksum "setup.py" }}
             - v0.4-{{ checksum "setup.py" }}
       - run: pip install --upgrade pip
-      - run: pip install .[torch,testing,sentencepiece,onnxruntime]
+      - run: pip install .[torch,testing,sentencepiece,onnxruntime,vision]
       - save_cache:
           key: v0.4-onnx-{{ checksum "setup.py" }}
          paths:
@@ -816,7 +816,7 @@ jobs:
             - v0.4-torch-{{ checksum "setup.py" }}
             - v0.4-{{ checksum "setup.py" }}
       - run: pip install --upgrade pip
-      - run: pip install .[torch,testing,sentencepiece,onnxruntime]
+      - run: pip install .[torch,testing,sentencepiece,onnxruntime,vision]
       - save_cache:
           key: v0.4-onnx-{{ checksum "setup.py" }}
          paths:
diff --git a/docs/source/serialization.mdx b/docs/source/serialization.mdx
index aee21535aca4..7a972abe715f 100644
--- a/docs/source/serialization.mdx
+++ b/docs/source/serialization.mdx
@@ -60,6 +60,7 @@ Ready-made configurations include the following architectures:
 - PLBart
 - RoBERTa
 - T5
+- ViT
 - XLM-RoBERTa
 - XLM-RoBERTa-XL
diff --git a/src/transformers/models/bart/configuration_bart.py b/src/transformers/models/bart/configuration_bart.py
index 81d9c12d814d..4d0ce02ae067 100644
--- a/src/transformers/models/bart/configuration_bart.py
+++ b/src/transformers/models/bart/configuration_bart.py
@@ -358,13 +358,13 @@ def _generate_dummy_inputs_for_sequence_classification_and_question_answering(
         # Did not use super(OnnxConfigWithPast, self).generate_dummy_inputs for code clarity.
         # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
         batch_size = compute_effective_axis_dimension(
-            batch_size, fixed_dimension=OnnxConfig.DEFAULT_FIXED_BATCH, num_token_to_add=0
+            batch_size, fixed_dimension=OnnxConfig.default_fixed_batch, num_token_to_add=0
         )
 
         # If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
         token_to_add = tokenizer.num_special_tokens_to_add(is_pair)
         seq_length = compute_effective_axis_dimension(
-            seq_length, fixed_dimension=OnnxConfig.DEFAULT_FIXED_SEQUENCE, num_token_to_add=token_to_add
+            seq_length, fixed_dimension=OnnxConfig.default_fixed_sequence, num_token_to_add=token_to_add
         )
 
         # Generate dummy inputs according to compute batch and sequence
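Reviewer note: the Marian and MBart configs below repeat the same rename — `DEFAULT_FIXED_BATCH`/`DEFAULT_FIXED_SEQUENCE` become the lowercase class attributes `default_fixed_batch`/`default_fixed_sequence` so subclasses can override them. For orientation, here is a sketch of the behavior these call sites rely on; this is an illustration inferred from the comments above, not the source of `transformers.onnx.utils.compute_effective_axis_dimension`:

```python
def compute_effective_axis_dimension(dimension: int, fixed_dimension: int, num_token_to_add: int = 0) -> int:
    # A dynamic axis is requested with a non-positive value (-1); substitute a small
    # fixed size so the ONNX tracer cannot specialize on a single-sample batch.
    if dimension <= 0:
        dimension = fixed_dimension
    # Leave room for the special tokens the tokenizer will add ([CLS], [SEP], ...).
    dimension -= num_token_to_add
    return dimension
```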
diff --git a/src/transformers/models/marian/configuration_marian.py b/src/transformers/models/marian/configuration_marian.py
index a37e2f207481..9eafbf9363af 100644
--- a/src/transformers/models/marian/configuration_marian.py
+++ b/src/transformers/models/marian/configuration_marian.py
@@ -346,13 +346,13 @@ def _generate_dummy_inputs_for_encoder_and_decoder(
         # Did not use super(OnnxConfigWithPast, self).generate_dummy_inputs for code clarity.
         # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
         batch_size = compute_effective_axis_dimension(
-            batch_size, fixed_dimension=OnnxConfig.DEFAULT_FIXED_BATCH, num_token_to_add=0
+            batch_size, fixed_dimension=OnnxConfig.default_fixed_batch, num_token_to_add=0
         )
 
         # If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
         token_to_add = tokenizer.num_special_tokens_to_add(is_pair)
         seq_length = compute_effective_axis_dimension(
-            seq_length, fixed_dimension=OnnxConfig.DEFAULT_FIXED_SEQUENCE, num_token_to_add=token_to_add
+            seq_length, fixed_dimension=OnnxConfig.default_fixed_sequence, num_token_to_add=token_to_add
         )
 
         # Generate dummy inputs according to compute batch and sequence
diff --git a/src/transformers/models/mbart/configuration_mbart.py b/src/transformers/models/mbart/configuration_mbart.py
index fc0775511cea..cf1d87835ed5 100644
--- a/src/transformers/models/mbart/configuration_mbart.py
+++ b/src/transformers/models/mbart/configuration_mbart.py
@@ -343,13 +343,13 @@ def _generate_dummy_inputs_for_sequence_classification_and_question_answering(
         # Did not use super(OnnxConfigWithPast, self).generate_dummy_inputs for code clarity.
         # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
         batch_size = compute_effective_axis_dimension(
-            batch_size, fixed_dimension=OnnxConfig.DEFAULT_FIXED_BATCH, num_token_to_add=0
+            batch_size, fixed_dimension=OnnxConfig.default_fixed_batch, num_token_to_add=0
         )
 
         # If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
         token_to_add = tokenizer.num_special_tokens_to_add(is_pair)
         seq_length = compute_effective_axis_dimension(
-            seq_length, fixed_dimension=OnnxConfig.DEFAULT_FIXED_SEQUENCE, num_token_to_add=token_to_add
+            seq_length, fixed_dimension=OnnxConfig.default_fixed_sequence, num_token_to_add=token_to_add
         )
 
         # Generate dummy inputs according to compute batch and sequence
diff --git a/src/transformers/models/vit/__init__.py b/src/transformers/models/vit/__init__.py
index 92c3681a4cce..ec0990fccaff 100644
--- a/src/transformers/models/vit/__init__.py
+++ b/src/transformers/models/vit/__init__.py
@@ -21,7 +21,7 @@
 _import_structure = {
-    "configuration_vit": ["VIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTConfig"],
+    "configuration_vit": ["VIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTConfig", "ViTOnnxConfig"],
 }
 
 if is_vision_available():
@@ -51,7 +51,7 @@
 ]
 
 if TYPE_CHECKING:
-    from .configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig
+    from .configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig, ViTOnnxConfig
 
 if is_vision_available():
     from .feature_extraction_vit import ViTFeatureExtractor
""" ViT model configuration""" +from collections import OrderedDict +from typing import Mapping + +from packaging import version + from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig from ...utils import logging @@ -119,3 +125,20 @@ def __init__( self.num_channels = num_channels self.qkv_bias = qkv_bias self.encoder_stride = encoder_stride + + +class ViTOnnxConfig(OnnxConfig): + + torch_onnx_minimum_version = version.parse("1.11") + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("pixel_values", {0: "batch", 1: "sequence"}), + ] + ) + + @property + def atol_for_validation(self) -> float: + return 1e-4 diff --git a/src/transformers/onnx/__main__.py b/src/transformers/onnx/__main__.py index bb547172894b..6686626ea4bd 100644 --- a/src/transformers/onnx/__main__.py +++ b/src/transformers/onnx/__main__.py @@ -15,8 +15,9 @@ from argparse import ArgumentParser from pathlib import Path -from transformers.models.auto import AutoTokenizer - +from ..models.auto import AutoConfig, AutoFeatureExtractor, AutoTokenizer +from ..models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING_NAMES +from ..models.auto.tokenization_auto import TOKENIZER_MAPPING_NAMES from ..utils import logging from .convert import export, validate_model_outputs from .features import FeaturesManager @@ -46,8 +47,17 @@ def main(): if not args.output.parent.exists(): args.output.parent.mkdir(parents=True) + # Check the modality of the inputs and instantiate the appropriate preprocessor + # TODO(lewtun): Refactor this as a function if we need to check modalities elsewhere as well + config = AutoConfig.from_pretrained(args.model) + if config.model_type in TOKENIZER_MAPPING_NAMES: + preprocessor = AutoTokenizer.from_pretrained(args.model) + elif config.model_type in FEATURE_EXTRACTOR_MAPPING_NAMES: + preprocessor = AutoFeatureExtractor.from_pretrained(args.model) + else: + raise ValueError(f"Unsupported model type: {config.model_type}") + # Allocate the model - tokenizer = AutoTokenizer.from_pretrained(args.model) model = FeaturesManager.get_model_from_feature(args.feature, args.model) model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(model, feature=args.feature) onnx_config = model_onnx_config(model.config) @@ -62,12 +72,18 @@ def main(): f"At least {onnx_config.default_onnx_opset} is required." ) - onnx_inputs, onnx_outputs = export(tokenizer, model, onnx_config, args.opset, args.output) + onnx_inputs, onnx_outputs = export( + preprocessor, + model, + onnx_config, + args.opset, + args.output, + ) if args.atol is None: args.atol = onnx_config.atol_for_validation - validate_model_outputs(onnx_config, tokenizer, model, args.output, onnx_outputs, args.atol) + validate_model_outputs(onnx_config, preprocessor, model, args.output, onnx_outputs, args.atol) logger.info(f"All good, model saved at: {args.output.as_posix()}") diff --git a/src/transformers/onnx/config.py b/src/transformers/onnx/config.py index 65cedbaa5917..91cfee0e0784 100644 --- a/src/transformers/onnx/config.py +++ b/src/transformers/onnx/config.py @@ -13,15 +13,31 @@ # limitations under the License. 
diff --git a/src/transformers/onnx/__main__.py b/src/transformers/onnx/__main__.py
index bb547172894b..6686626ea4bd 100644
--- a/src/transformers/onnx/__main__.py
+++ b/src/transformers/onnx/__main__.py
@@ -15,8 +15,9 @@
 from argparse import ArgumentParser
 from pathlib import Path
 
-from transformers.models.auto import AutoTokenizer
-
+from ..models.auto import AutoConfig, AutoFeatureExtractor, AutoTokenizer
+from ..models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING_NAMES
+from ..models.auto.tokenization_auto import TOKENIZER_MAPPING_NAMES
 from ..utils import logging
 from .convert import export, validate_model_outputs
 from .features import FeaturesManager
@@ -46,8 +47,17 @@ def main():
     if not args.output.parent.exists():
         args.output.parent.mkdir(parents=True)
 
+    # Check the modality of the inputs and instantiate the appropriate preprocessor
+    # TODO(lewtun): Refactor this as a function if we need to check modalities elsewhere as well
+    config = AutoConfig.from_pretrained(args.model)
+    if config.model_type in TOKENIZER_MAPPING_NAMES:
+        preprocessor = AutoTokenizer.from_pretrained(args.model)
+    elif config.model_type in FEATURE_EXTRACTOR_MAPPING_NAMES:
+        preprocessor = AutoFeatureExtractor.from_pretrained(args.model)
+    else:
+        raise ValueError(f"Unsupported model type: {config.model_type}")
+
     # Allocate the model
-    tokenizer = AutoTokenizer.from_pretrained(args.model)
     model = FeaturesManager.get_model_from_feature(args.feature, args.model)
     model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(model, feature=args.feature)
     onnx_config = model_onnx_config(model.config)
@@ -62,12 +72,18 @@ def main():
             f"At least {onnx_config.default_onnx_opset} is required."
         )
 
-    onnx_inputs, onnx_outputs = export(tokenizer, model, onnx_config, args.opset, args.output)
+    onnx_inputs, onnx_outputs = export(
+        preprocessor,
+        model,
+        onnx_config,
+        args.opset,
+        args.output,
+    )
 
     if args.atol is None:
         args.atol = onnx_config.atol_for_validation
 
-    validate_model_outputs(onnx_config, tokenizer, model, args.output, onnx_outputs, args.atol)
+    validate_model_outputs(onnx_config, preprocessor, model, args.output, onnx_outputs, args.atol)
     logger.info(f"All good, model saved at: {args.output.as_posix()}")
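The TODO above suggests hoisting the modality check into a reusable function. A possible shape for that refactor, shown only as a sketch built from the same mapping constants — the helper name `get_preprocessor` is hypothetical and not part of this patch:

```python
from transformers import AutoConfig, AutoFeatureExtractor, AutoTokenizer
from transformers.models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING_NAMES
from transformers.models.auto.tokenization_auto import TOKENIZER_MAPPING_NAMES


def get_preprocessor(model_name_or_path: str):
    """Hypothetical helper: pick a tokenizer or a feature extractor based on the model type."""
    config = AutoConfig.from_pretrained(model_name_or_path)
    if config.model_type in TOKENIZER_MAPPING_NAMES:
        return AutoTokenizer.from_pretrained(model_name_or_path)
    if config.model_type in FEATURE_EXTRACTOR_MAPPING_NAMES:
        return AutoFeatureExtractor.from_pretrained(model_name_or_path)
    raise ValueError(f"Unsupported model type: {config.model_type}")
```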
diff --git a/src/transformers/onnx/config.py b/src/transformers/onnx/config.py
index 65cedbaa5917..91cfee0e0784 100644
--- a/src/transformers/onnx/config.py
+++ b/src/transformers/onnx/config.py
@@ -13,15 +13,31 @@
 # limitations under the License.
 import copy
 import dataclasses
+import warnings
 from abc import ABC, abstractmethod
 from collections import OrderedDict
-from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Union
 
-from transformers import PretrainedConfig, PreTrainedTokenizer, TensorType, is_torch_available
+import numpy as np
+from packaging import version
 
+from ..file_utils import TensorType, is_torch_available, is_vision_available
+from ..utils import logging
 from .utils import ParameterFormat, compute_effective_axis_dimension, compute_serialized_parameters_size
 
+
+if TYPE_CHECKING:
+    from ..configuration_utils import PretrainedConfig
+    from ..feature_extraction_utils import FeatureExtractionMixin
+    from ..tokenization_utils_base import PreTrainedTokenizerBase
+
+
+if is_vision_available():
+    from PIL import Image
+
+logger = logging.get_logger(__name__)
+
+
 DEFAULT_ONNX_OPSET = 11
 
 # 2 Gb
@@ -54,10 +70,10 @@ class OnnxConfig(ABC):
     Base class for ONNX exportable model describing metadata on how to export the model through the ONNX format.
     """
 
-    DEFAULT_FIXED_BATCH = 2
-    DEFAULT_FIXED_SEQUENCE = 8
-
-    _TASKS_TO_COMMON_OUTPUTS = {
+    default_fixed_batch = 2
+    default_fixed_sequence = 8
+    torch_onnx_minimum_version = version.parse("1.8")
+    _tasks_to_common_outputs = {
         "default": OrderedDict({"last_hidden_state": {0: "batch", 1: "sequence"}}),
         "masked-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
         "causal-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
@@ -71,14 +87,15 @@ class OnnxConfig(ABC):
                 "end_logits": {0: "batch", 1: "sequence"},
             }
         ),
+        "image-classification": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
     }
 
-    def __init__(self, config: PretrainedConfig, task: str = "default", patching_specs: List[PatchingSpec] = None):
+    def __init__(self, config: "PretrainedConfig", task: str = "default", patching_specs: List[PatchingSpec] = None):
         self._config = config
 
-        if task not in self._TASKS_TO_COMMON_OUTPUTS:
+        if task not in self._tasks_to_common_outputs:
             raise ValueError(
-                f"{task} is not a supported task, supported tasks: {self._TASKS_TO_COMMON_OUTPUTS.keys()}"
+                f"{task} is not a supported task, supported tasks: {self._tasks_to_common_outputs.keys()}"
             )
         self.task = task
@@ -90,7 +107,7 @@ def __init__(self, config: PretrainedConfig, task: str = "default", patching_spe
             self._patching_specs.append(final_spec)
 
     @classmethod
-    def from_model_config(cls, config: PretrainedConfig, task: str = "default") -> "OnnxConfig":
+    def from_model_config(cls, config: "PretrainedConfig", task: str = "default") -> "OnnxConfig":
         """
         Instantiate a OnnxConfig for a specific model
@@ -121,7 +138,7 @@ def outputs(self) -> Mapping[str, Mapping[int, str]]:
         Returns:
             For each output: its name associated to the axes symbolic name and the axis position within the tensor
         """
-        common_outputs = self._TASKS_TO_COMMON_OUTPUTS[self.task]
+        common_outputs = self._tasks_to_common_outputs[self.task]
         return copy.deepcopy(common_outputs)
 
     @property
@@ -146,7 +163,7 @@ def default_batch_size(self) -> int:
             Integer > 0
         """
         # Using 2 avoid ONNX making assumption about single sample batch
-        return OnnxConfig.DEFAULT_FIXED_BATCH
+        return OnnxConfig.default_fixed_batch
 
     @property
     def default_sequence_length(self) -> int:
@@ -156,7 +173,7 @@ def default_sequence_length(self) -> int:
         Returns:
             Integer > 0
         """
-        return OnnxConfig.DEFAULT_FIXED_SEQUENCE
+        return OnnxConfig.default_fixed_sequence
 
     @property
     def default_onnx_opset(self) -> int:
@@ -178,6 +195,21 @@ def atol_for_validation(self) -> float:
         """
         return 1e-5
 
+    @property
+    def is_torch_support_available(self) -> bool:
+        """
+        Whether the installed version of PyTorch is compatible with the minimum version required to export the model.
+
+        Returns:
+            `bool`: Whether the installed version of PyTorch is compatible with the model.
+        """
+        if is_torch_available():
+            from transformers.file_utils import torch_version
+
+            return torch_version >= self.torch_onnx_minimum_version
+        else:
+            return False
+
     @staticmethod
     def use_external_data_format(num_parameters: int) -> bool:
         """
@@ -195,42 +227,85 @@ def use_external_data_format(num_parameters: int) -> bool:
             >= EXTERNAL_DATA_FORMAT_SIZE_LIMIT
         )
 
+    def _generate_dummy_images(
+        self, batch_size: int = 2, num_channels: int = 3, image_height: int = 40, image_width: int = 40
+    ):
+        images = []
+        for _ in range(batch_size):
+            data = np.random.rand(image_height, image_width, num_channels) * 255
+            images.append(Image.fromarray(data.astype("uint8")).convert("RGB"))
+        return images
+
     def generate_dummy_inputs(
         self,
-        tokenizer: PreTrainedTokenizer,
+        preprocessor: Union["PreTrainedTokenizerBase", "FeatureExtractionMixin"],
         batch_size: int = -1,
         seq_length: int = -1,
         is_pair: bool = False,
         framework: Optional[TensorType] = None,
+        num_channels: int = 3,
+        image_width: int = 40,
+        image_height: int = 40,
+        tokenizer: "PreTrainedTokenizerBase" = None,
     ) -> Mapping[str, Any]:
         """
         Generate inputs to provide to the ONNX exporter for the specific framework
 
         Args:
-            tokenizer: The tokenizer associated with this model configuration
-            batch_size: The batch size (int) to export the model for (-1 means dynamic axis)
-            seq_length: The sequence length (int) to export the model for (-1 means dynamic axis)
-            is_pair: Indicate if the input is a pair (sentence 1, sentence 2)
-            framework: The framework (optional) the tokenizer will generate tensor for
+            preprocessor: ([`PreTrainedTokenizerBase`] or [`FeatureExtractionMixin`]):
+                The preprocessor associated with this model configuration.
+            batch_size (`int`, *optional*, defaults to -1):
+                The batch size to export the model for (-1 means dynamic axis).
+            seq_length (`int`, *optional*, defaults to -1):
+                The sequence length to export the model for (-1 means dynamic axis).
+            is_pair (`bool`, *optional*, defaults to `False`):
+                Indicate if the input is a pair (sentence 1, sentence 2)
+            framework (`TensorType`, *optional*, defaults to `None`):
+                The framework (PyTorch or TensorFlow) that the tokenizer will generate tensors for.
+            num_channels (`int`, *optional*, defaults to 3):
+                The number of channels of the generated images.
+            image_width (`int`, *optional*, defaults to 40):
+                The width of the generated images.
+            image_height (`int`, *optional*, defaults to 40):
+                The height of the generated images.
 
         Returns:
             Mapping[str, Tensor] holding the kwargs to provide to the model's forward function
         """
-
-        # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
-        batch_size = compute_effective_axis_dimension(
-            batch_size, fixed_dimension=OnnxConfig.DEFAULT_FIXED_BATCH, num_token_to_add=0
-        )
-
-        # If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
-        token_to_add = tokenizer.num_special_tokens_to_add(is_pair)
-        seq_length = compute_effective_axis_dimension(
-            seq_length, fixed_dimension=OnnxConfig.DEFAULT_FIXED_SEQUENCE, num_token_to_add=token_to_add
-        )
-
-        # Generate dummy inputs according to compute batch and sequence
-        dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
-        return dict(tokenizer(dummy_input, return_tensors=framework))
+        from ..feature_extraction_utils import FeatureExtractionMixin
+        from ..tokenization_utils_base import PreTrainedTokenizerBase
+
+        if isinstance(preprocessor, PreTrainedTokenizerBase) and tokenizer is not None:
+            raise ValueError("You cannot provide both a tokenizer and a preprocessor to generate dummy inputs.")
+        if tokenizer is not None:
+            warnings.warn(
+                "The `tokenizer` argument is deprecated and will be removed in version 5 of Transformers. Use `preprocessor` instead.",
+                FutureWarning,
+            )
+            logger.warning("Overwriting the `preprocessor` argument with `tokenizer` to generate dummy inputs.")
+            preprocessor = tokenizer
+        if isinstance(preprocessor, PreTrainedTokenizerBase):
+            # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
+            batch_size = compute_effective_axis_dimension(
+                batch_size, fixed_dimension=OnnxConfig.default_fixed_batch, num_token_to_add=0
+            )
+            # If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
+            token_to_add = preprocessor.num_special_tokens_to_add(is_pair)
+            seq_length = compute_effective_axis_dimension(
+                seq_length, fixed_dimension=OnnxConfig.default_fixed_sequence, num_token_to_add=token_to_add
+            )
+            # Generate dummy inputs according to compute batch and sequence
+            dummy_input = [" ".join([preprocessor.unk_token]) * seq_length] * batch_size
+            return dict(preprocessor(dummy_input, return_tensors=framework))
+        elif isinstance(preprocessor, FeatureExtractionMixin) and preprocessor.model_input_names[0] == "pixel_values":
+            # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
+            batch_size = compute_effective_axis_dimension(batch_size, fixed_dimension=OnnxConfig.default_fixed_batch)
+            dummy_input = self._generate_dummy_images(batch_size, num_channels, image_height, image_width)
+            return dict(preprocessor(images=dummy_input, return_tensors=framework))
+        else:
+            raise ValueError(
+                "Unable to generate dummy inputs for the model. Please provide a tokenizer or a preprocessor."
+            )
 
     def patch_ops(self):
         for spec in self._patching_specs:
@@ -264,7 +339,7 @@ def flatten_output_collection_property(cls, name: str, field: Iterable[Any]) ->
 class OnnxConfigWithPast(OnnxConfig, ABC):
     def __init__(
         self,
-        config: PretrainedConfig,
+        config: "PretrainedConfig",
         task: str = "default",
         patching_specs: List[PatchingSpec] = None,
         use_past: bool = False,
@@ -273,7 +348,7 @@ def __init__(
         self.use_past = use_past
 
     @classmethod
-    def with_past(cls, config: PretrainedConfig, task: str = "default") -> "OnnxConfigWithPast":
+    def with_past(cls, config: "PretrainedConfig", task: str = "default") -> "OnnxConfigWithPast":
         """
         Instantiate a OnnxConfig with `use_past` attribute set to True
 
@@ -326,7 +401,7 @@ def num_attention_heads(self) -> int:
 
     def generate_dummy_inputs(
         self,
-        tokenizer: PreTrainedTokenizer,
+        tokenizer: "PreTrainedTokenizerBase",
         batch_size: int = -1,
         seq_length: int = -1,
         is_pair: bool = False,
@@ -445,7 +520,7 @@ def num_attention_heads(self) -> Tuple[int]:
 
     def generate_dummy_inputs(
         self,
-        tokenizer: PreTrainedTokenizer,
+        tokenizer: "PreTrainedTokenizerBase",
         batch_size: int = -1,
         seq_length: int = -1,
         is_pair: bool = False,
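To see the two `generate_dummy_inputs` branches side by side, a short sketch (assumes `torch`, `Pillow`, and hub access; the shapes in the comments are what the defaults above should produce — batch 2 and sequence 8 for text, and the feature extractor's 224x224 resize for images):

```python
from transformers import AutoFeatureExtractor, AutoTokenizer, BertConfig, TensorType, ViTConfig
from transformers.models.bert import BertOnnxConfig
from transformers.models.vit import ViTOnnxConfig

# Tokenizer branch: a batch of unk-token strings, padded out by special tokens.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
text_inputs = BertOnnxConfig(BertConfig()).generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
print(text_inputs["input_ids"].shape)  # torch.Size([2, 8])

# Feature-extractor branch: random PIL images produced by _generate_dummy_images.
feature_extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
image_inputs = ViTOnnxConfig(ViTConfig()).generate_dummy_inputs(feature_extractor, framework=TensorType.PYTORCH)
print(image_inputs["pixel_values"].shape)  # torch.Size([2, 3, 224, 224])
```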
""" + if tokenizer is not None: + warnings.warn( + "The `tokenizer` argument is deprecated and will be removed in version 5 of Transformers. Use `preprocessor` instead.", + FutureWarning, + ) if issubclass(type(model), PreTrainedModel): import torch from torch.onnx import export as onnx_export @@ -106,7 +123,9 @@ def export_pytorch( # Ensure inputs match # TODO: Check when exporting QA we provide "is_pair=True" - model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH) + model_inputs = config.generate_dummy_inputs( + preprocessor, tokenizer=tokenizer, framework=TensorType.PYTORCH + ) inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys()) onnx_outputs = list(config.outputs.keys()) @@ -150,18 +169,19 @@ def export_pytorch( def export_tensorflow( - tokenizer: PreTrainedTokenizer, - model: TFPreTrainedModel, + preprocessor: Union["PreTrainedTokenizer", "FeatureExtractionMixin"], + model: "TFPreTrainedModel", config: OnnxConfig, opset: int, output: Path, + tokenizer: "PreTrainedTokenizer" = None, ) -> Tuple[List[str], List[str]]: """ Export a TensorFlow model to an ONNX Intermediate Representation (IR) Args: - tokenizer ([`PreTrainedTokenizer`]): - The tokenizer used for encoding the data. + preprocessor: ([`PreTrainedTokenizer`] or [`FeatureExtractionMixin`]): + The preprocessor used for encoding the data. model ([`TFPreTrainedModel`]): The model to export. config ([`~onnx.config.OnnxConfig`]): @@ -180,6 +200,12 @@ def export_tensorflow( import onnx import tf2onnx + if tokenizer is not None: + warnings.warn( + "The `tokenizer` argument is deprecated and will be removed in version 5 of Transformers. Use `preprocessor` instead.", + FutureWarning, + ) + model.config.return_dict = True # Check if we need to override certain configuration item @@ -190,7 +216,7 @@ def export_tensorflow( setattr(model.config, override_config_key, override_config_value) # Ensure inputs match - model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.TENSORFLOW) + model_inputs = config.generate_dummy_inputs(preprocessor, tokenizer=tokenizer, framework=TensorType.TENSORFLOW) inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys()) onnx_outputs = list(config.outputs.keys()) @@ -203,18 +229,19 @@ def export_tensorflow( def export( - tokenizer: PreTrainedTokenizer, - model: Union[PreTrainedModel, TFPreTrainedModel], + preprocessor: Union["PreTrainedTokenizer", "FeatureExtractionMixin"], + model: Union["PreTrainedModel", "TFPreTrainedModel"], config: OnnxConfig, opset: int, output: Path, + tokenizer: "PreTrainedTokenizer" = None, ) -> Tuple[List[str], List[str]]: """ Export a Pytorch or TensorFlow model to an ONNX Intermediate Representation (IR) Args: - tokenizer ([`PreTrainedTokenizer`]): - The tokenizer used for encoding the data. + preprocessor: ([`PreTrainedTokenizer`] or [`FeatureExtractionMixin`]): + The preprocessor used for encoding the data. model ([`PreTrainedModel`] or [`TFPreTrainedModel`]): The model to export. config ([`~onnx.config.OnnxConfig`]): @@ -233,26 +260,37 @@ def export( "Cannot convert because neither PyTorch nor TensorFlow are not installed. " "Please install torch or tensorflow first." ) + if tokenizer is not None: + warnings.warn( + "The `tokenizer` argument is deprecated and will be removed in version 5 of Transformers. 
+            FutureWarning,
+        )
+
     if is_torch_available():
-        from transformers.file_utils import torch_version
+        from ..file_utils import torch_version
 
         if not is_torch_onnx_dict_inputs_support_available():
             raise AssertionError(f"Unsupported PyTorch version, minimum required is 1.8.0, got: {torch_version}")
 
+        if not config.is_torch_support_available:
+            logger.warning(
+                f"Unsupported PyTorch version for this model. Minimum required is {config.torch_onnx_minimum_version}, got: {torch_version}"
+            )
+
     if is_torch_available() and issubclass(type(model), PreTrainedModel):
-        return export_pytorch(tokenizer, model, config, opset, output)
+        return export_pytorch(preprocessor, model, config, opset, output, tokenizer=tokenizer)
     elif is_tf_available() and issubclass(type(model), TFPreTrainedModel):
-        return export_tensorflow(tokenizer, model, config, opset, output)
+        return export_tensorflow(preprocessor, model, config, opset, output, tokenizer=tokenizer)
 
 
 def validate_model_outputs(
     config: OnnxConfig,
-    tokenizer: PreTrainedTokenizer,
-    reference_model: Union[PreTrainedModel, TFPreTrainedModel],
+    preprocessor: Union["PreTrainedTokenizer", "FeatureExtractionMixin"],
+    reference_model: Union["PreTrainedModel", "TFPreTrainedModel"],
     onnx_model: Path,
     onnx_named_outputs: List[str],
     atol: float,
+    tokenizer: "PreTrainedTokenizer" = None,
 ):
     from onnxruntime import InferenceSession, SessionOptions
 
@@ -261,9 +299,13 @@ def validate_model_outputs(
     # TODO: generate inputs with a different batch_size and seq_len that was used for conversion to properly test
     # dynamic input shapes.
     if issubclass(type(reference_model), PreTrainedModel):
-        reference_model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
+        reference_model_inputs = config.generate_dummy_inputs(
+            preprocessor, tokenizer=tokenizer, framework=TensorType.PYTORCH
+        )
     else:
-        reference_model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.TENSORFLOW)
+        reference_model_inputs = config.generate_dummy_inputs(
+            preprocessor, tokenizer=tokenizer, framework=TensorType.TENSORFLOW
+        )
 
     # Create ONNX Runtime session
     options = SessionOptions()
@@ -341,7 +383,7 @@ def validate_model_outputs(
 
 
 def ensure_model_and_config_inputs_match(
-    model: Union[PreTrainedModel, TFPreTrainedModel], model_inputs: Iterable[str]
+    model: Union["PreTrainedModel", "TFPreTrainedModel"], model_inputs: Iterable[str]
 ) -> Tuple[bool, List[str]]:
     """
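With `export` and `validate_model_outputs` now accepting either preprocessor type, the programmatic vision flow mirrors what `__main__.py` does. A minimal sketch (assumes `torch` and `onnxruntime` are installed; writes `vit.onnx` to the working directory):

```python
from pathlib import Path

from transformers import AutoFeatureExtractor, ViTModel
from transformers.models.vit import ViTOnnxConfig
from transformers.onnx import export, validate_model_outputs

ckpt = "google/vit-base-patch16-224"
preprocessor = AutoFeatureExtractor.from_pretrained(ckpt)
model = ViTModel.from_pretrained(ckpt)

onnx_config = ViTOnnxConfig(model.config)
output = Path("vit.onnx")

# Note the first positional argument is now a preprocessor, not a tokenizer.
onnx_inputs, onnx_outputs = export(preprocessor, model, onnx_config, onnx_config.default_onnx_opset, output)
validate_model_outputs(onnx_config, preprocessor, model, output, onnx_outputs, onnx_config.atol_for_validation)
```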
diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py
index 58db3ed3f4d7..86772725e61f 100644
--- a/src/transformers/onnx/features.py
+++ b/src/transformers/onnx/features.py
@@ -16,6 +16,7 @@
 from ..models.mbart import MBartOnnxConfig
 from ..models.roberta import RobertaOnnxConfig
 from ..models.t5 import T5OnnxConfig
+from ..models.vit import ViTOnnxConfig
 from ..models.xlm_roberta import XLMRobertaOnnxConfig
 from ..utils import logging
 from .config import OnnxConfig
@@ -27,6 +28,7 @@
     from transformers.models.auto import (
         AutoModel,
         AutoModelForCausalLM,
+        AutoModelForImageClassification,
         AutoModelForMaskedLM,
         AutoModelForMultipleChoice,
         AutoModelForQuestionAnswering,
@@ -89,6 +91,7 @@ class FeaturesManager:
             "token-classification": AutoModelForTokenClassification,
             "multiple-choice": AutoModelForMultipleChoice,
             "question-answering": AutoModelForQuestionAnswering,
+            "image-classification": AutoModelForImageClassification,
         }
     elif is_tf_available():
         _TASKS_TO_AUTOMODELS = {
@@ -240,6 +243,7 @@ class FeaturesManager:
             "question-answering",
             onnx_config_cls=ElectraOnnxConfig,
         ),
+        "vit": supported_features_mapping("default", "image-classification", onnx_config_cls=ViTOnnxConfig),
     }
 
     AVAILABLE_FEATURES = sorted(reduce(lambda s1, s2: s1 | s2, (v.keys() for v in _SUPPORTED_MODEL_TYPE.values())))
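A quick way to sanity-check the registration above, using the same resolution path the CLI takes in `__main__.py`; the values in the comments are what I would expect from this mapping:

```python
from transformers.onnx.features import FeaturesManager

# "image-classification" should now appear in the global feature list.
print("image-classification" in FeaturesManager.AVAILABLE_FEATURES)  # True

# Same lookup chain as __main__.py: model from feature, then config factory.
model = FeaturesManager.get_model_from_feature("image-classification", "google/vit-base-patch16-224")
model_kind, onnx_config_factory = FeaturesManager.check_supported_model_or_raise(model, feature="image-classification")
onnx_config = onnx_config_factory(model.config)
print(model_kind, type(onnx_config).__name__)  # vit ViTOnnxConfig
```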
diff --git a/tests/onnx/test_onnx_v2.py b/tests/onnx/test_onnx_v2.py
index 0cd53f885abf..0c968f82cb9a 100644
--- a/tests/onnx/test_onnx_v2.py
+++ b/tests/onnx/test_onnx_v2.py
@@ -3,23 +3,25 @@
 from unittest import TestCase
 from unittest.mock import patch
 
+import pytest
+
 from parameterized import parameterized
-from transformers import AutoConfig, AutoTokenizer, is_tf_available, is_torch_available
+from transformers import AutoConfig, AutoFeatureExtractor, AutoTokenizer, is_tf_available, is_torch_available
 from transformers.onnx import (
     EXTERNAL_DATA_FORMAT_SIZE_LIMIT,
     OnnxConfig,
+    OnnxConfigWithPast,
     ParameterFormat,
     export,
     validate_model_outputs,
 )
-from transformers.onnx.config import OnnxConfigWithPast
 
 
 if is_torch_available() or is_tf_available():
     from transformers.onnx.features import FeaturesManager
 
 from transformers.onnx.utils import compute_effective_axis_dimension, compute_serialized_parameters_size
-from transformers.testing_utils import require_onnx, require_tf, require_torch, slow
+from transformers.testing_utils import require_onnx, require_tf, require_torch, require_vision, slow
 
 
 @require_onnx
@@ -178,6 +180,7 @@ def test_values_override(self):
     ("roberta", "roberta-base"),
     ("xlm-roberta", "xlm-roberta-base"),
     ("layoutlm", "microsoft/layoutlm-base-uncased"),
+    ("vit", "google/vit-base-patch16-224"),
 }
 
 PYTORCH_EXPORT_WITH_PAST_MODELS = {
@@ -240,25 +243,38 @@ class OnnxExportTestCaseV2(TestCase):
     def _onnx_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
         from transformers.onnx import export
 
-        tokenizer = AutoTokenizer.from_pretrained(model_name)
-        config = AutoConfig.from_pretrained(model_name)
-
-        # Useful for causal lm models that do not use pad tokens.
-        if not getattr(config, "pad_token_id", None):
-            config.pad_token_id = tokenizer.eos_token_id
-
         model_class = FeaturesManager.get_model_class_for_feature(feature)
+        config = AutoConfig.from_pretrained(model_name)
         model = model_class.from_config(config)
         onnx_config = onnx_config_class_constructor(model.config)
 
+        if is_torch_available():
+            from transformers.file_utils import torch_version
+
+            if torch_version < onnx_config.torch_onnx_minimum_version:
+                pytest.skip(
+                    f"Skipping due to incompatible PyTorch version. Minimum required is {onnx_config.torch_onnx_minimum_version}, got: {torch_version}"
+                )
+
+        # Check the modality of the inputs and instantiate the appropriate preprocessor
+        if model.main_input_name == "input_ids":
+            preprocessor = AutoTokenizer.from_pretrained(model_name)
+            # Useful for causal lm models that do not use pad tokens.
+            if not getattr(config, "pad_token_id", None):
+                config.pad_token_id = preprocessor.eos_token_id
+        elif model.main_input_name == "pixel_values":
+            preprocessor = AutoFeatureExtractor.from_pretrained(model_name)
+        else:
+            raise ValueError(f"Unsupported model input name: {model.main_input_name}")
+
         with NamedTemporaryFile("w") as output:
             try:
                 onnx_inputs, onnx_outputs = export(
-                    tokenizer, model, onnx_config, onnx_config.default_onnx_opset, Path(output.name)
+                    preprocessor, model, onnx_config, onnx_config.default_onnx_opset, Path(output.name)
                 )
                 validate_model_outputs(
                     onnx_config,
-                    tokenizer,
+                    preprocessor,
                     model,
                     Path(output.name),
                     onnx_outputs,
@@ -270,6 +286,7 @@ def _onnx_export(self, test_name, name, model_name, feature, onnx_config_class_c
     @parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS))
     @slow
     @require_torch
+    @require_vision
     def test_pytorch_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
         self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
 
@@ -290,6 +307,7 @@ def test_pytorch_export_seq2seq_with_past(
     @parameterized.expand(_get_models_to_test(TENSORFLOW_EXPORT_DEFAULT_MODELS))
     @slow
     @require_tf
+    @require_vision
     def test_tensorflow_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
         self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
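End to end, this PR unlocks `python -m transformers.onnx --model=google/vit-base-patch16-224 onnx/` followed by inference on the exported graph. A closing sketch of the inference side with `onnxruntime`, assuming the export sketch above produced `vit.onnx` for the default feature (the 197 in the comment is 196 patches plus the class token for the base 224/16 checkpoint):

```python
import numpy as np
from onnxruntime import InferenceSession
from PIL import Image

from transformers import AutoFeatureExtractor

feature_extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
session = InferenceSession("vit.onnx")

# Any RGB image works; a random one matches the dummy-image generator used at export time.
image = Image.fromarray((np.random.rand(224, 224, 3) * 255).astype("uint8"))
inputs = feature_extractor(images=image, return_tensors="np")

(last_hidden_state,) = session.run(output_names=["last_hidden_state"], input_feed=dict(inputs))
print(last_hidden_state.shape)  # (1, 197, 768)
```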