Add ONNX export for ViT #15658

Merged · 38 commits · Mar 9, 2022

Commits (the diff below shows changes from 8 of the 38 commits)
8fce819 Add ONNX support for ViT (lewtun, Feb 11, 2022)
5e15830 Refactor to use generic preprocessor (lewtun, Feb 15, 2022)
be90f25 Refactor (lewtun, Feb 15, 2022)
81287ec Fix ONNX conversion for models with fast tokenizers (lewtun, Feb 15, 2022)
8103587 Fix copies (lewtun, Feb 15, 2022)
bcbdbd9 Add vision to tests (lewtun, Feb 15, 2022)
12a5306 Remove fixed ViT outputs (lewtun, Feb 15, 2022)
ba0a7b0 Extend ONNX slow tests to ViT (lewtun, Feb 15, 2022)
0ebbcae Add dummy image generator (lewtun, Feb 15, 2022)
d179861 Use model_type to determine modality (lewtun, Feb 15, 2022)
b1b4f61 Add deprecation warnings for tokenizer argument (lewtun, Feb 15, 2022)
b0491e8 Add warning when overwriting the preprocessor (lewtun, Feb 15, 2022)
7f03d43 Add optional args to docstrings (lewtun, Feb 15, 2022)
26dcdde Add TODO (lewtun, Feb 15, 2022)
99dd9a7 Add minimum PyTorch version to OnnxConfig (lewtun, Feb 23, 2022)
41fa7e0 Merge branch 'master' into vision-onnx-export (lewtun, Feb 23, 2022)
5ce5801 Fix minimum torch version (lewtun, Feb 23, 2022)
84bcfaa Refactor (lewtun, Feb 23, 2022)
fcca7dc Add vision dependency to CI tests (lewtun, Feb 23, 2022)
687d436 Tweak docstring (lewtun, Feb 23, 2022)
e4e3343 Add check on torch minimum version (lewtun, Feb 24, 2022)
8207be1 Merge branch 'master' into vision-onnx-export (lewtun, Feb 24, 2022)
ade513f Replace absolute imports with relative ones (lewtun, Feb 24, 2022)
a7baf9a Apply Sylvain's suggestions from code review (lewtun, Feb 24, 2022)
d898037 Merge remote-tracking branch 'origin/vision-onnx-export' into vision-… (lewtun, Feb 24, 2022)
88d25cf Fix imports (lewtun, Feb 24, 2022)
941689b Refactor OnnxConfig class variables from CONSTANT_NAME to snake_case (lewtun, Feb 24, 2022)
31dd4f9 Fix ViT torch version (lewtun, Feb 24, 2022)
d1f9397 Fix docstring (lewtun, Feb 24, 2022)
49dca94 Fix imports and add logging (lewtun, Feb 24, 2022)
aec42f8 Use relative imports for real this time and use type checking (lewtun, Feb 24, 2022)
750db82 Add check for vision feature extractor (lewtun, Feb 24, 2022)
951df50 Refactor imports for type checking (lewtun, Feb 24, 2022)
48129a7 Skip ONNX test if torch version is incompatible (lewtun, Feb 25, 2022)
b2e618e Revert ImportError vs AssertionError (lewtun, Feb 25, 2022)
1514807 Revert gitignore (lewtun, Feb 28, 2022)
7ac6312 Replace ImportError with warning (lewtun, Mar 9, 2022)
81eedea Add reasonable value for default atol (lewtun, Mar 9, 2022)
Files changed
1 change: 1 addition & 0 deletions docs/source/serialization.mdx
@@ -59,6 +59,7 @@ Ready-made configurations include the following architectures:
- OpenAI GPT-2
- RoBERTa
- T5
- ViT
- XLM-RoBERTa
- XLM-RoBERTa-XL

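With ViT in the list of ready-made configurations, the export entry point can resolve an ONNX config for a ViT checkpoint the same way it does for text models. A minimal sketch of that lookup (the checkpoint name is illustrative, and registration of ViT in `FeaturesManager` is assumed, as the docs entry implies):

```python
from transformers import AutoModel
from transformers.onnx.features import FeaturesManager

# Illustrative checkpoint; any ViT checkpoint on the Hub should resolve the same way.
model = AutoModel.from_pretrained("google/vit-base-patch16-224")

# Look up the ready-made ONNX config class registered for this architecture,
# then instantiate it from the model's config.
model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(model, feature="default")
onnx_config = model_onnx_config(model.config)
```
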
4 changes: 2 additions & 2 deletions src/transformers/models/vit/__init__.py
@@ -21,7 +21,7 @@


_import_structure = {
"configuration_vit": ["VIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTConfig"],
"configuration_vit": ["VIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTConfig", "ViTOnnxConfig"],
}

if is_vision_available():
@@ -50,7 +50,7 @@
]

if TYPE_CHECKING:
from .configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig
from .configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig, ViTOnnxConfig

if is_vision_available():
from .feature_extraction_vit import ViTFeatureExtractor
14 changes: 14 additions & 0 deletions src/transformers/models/vit/configuration_vit.py
@@ -14,7 +14,11 @@
# limitations under the License.
""" ViT model configuration"""

from collections import OrderedDict
from typing import Mapping

from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging


@@ -117,3 +121,13 @@ def __init__(
self.patch_size = patch_size
self.num_channels = num_channels
self.qkv_bias = qkv_bias


class ViTOnnxConfig(OnnxConfig):
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict(
[
("pixel_values", {0: "batch", 1: "sequence"}),
]
)
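
As a quick illustration of what this new configuration exposes (a sketch; the `inputs` value reflects the axes as defined at this point in the PR):

```python
from transformers import ViTConfig
from transformers.models.vit import ViTOnnxConfig

onnx_config = ViTOnnxConfig(ViTConfig())

# Input axes the exporter will mark as dynamic.
print(onnx_config.inputs)
# OrderedDict([('pixel_values', {0: 'batch', 1: 'sequence'})])

# Output axes come from the task-to-outputs mapping in the base OnnxConfig.
print(onnx_config.outputs)
```
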
20 changes: 16 additions & 4 deletions src/transformers/onnx/__main__.py
@@ -15,7 +15,7 @@
from argparse import ArgumentParser
from pathlib import Path

from transformers.models.auto import AutoTokenizer
from transformers.models.auto import AutoFeatureExtractor, AutoTokenizer

from ..utils import logging
from .convert import export, validate_model_outputs
@@ -47,10 +47,16 @@ def main():
args.output.parent.mkdir(parents=True)

# Allocate the model
tokenizer = AutoTokenizer.from_pretrained(args.model)
model = FeaturesManager.get_model_from_feature(args.feature, args.model)
model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(model, feature=args.feature)
onnx_config = model_onnx_config(model.config)
# Check the modality of the inputs and instantiate the appropriate preprocessor
if model.main_input_name == "input_ids":
preprocessor = AutoTokenizer.from_pretrained(args.model)
elif model.main_input_name == "pixel_values":
preprocessor = AutoFeatureExtractor.from_pretrained(args.model)
else:
raise ValueError(f"Unsupported model input name: {model.main_input_name}")

# Ensure the requested opset is sufficient
if args.opset is None:
@@ -62,12 +68,18 @@
f"At least {onnx_config.default_onnx_opset} is required."
)

onnx_inputs, onnx_outputs = export(tokenizer, model, onnx_config, args.opset, args.output)
onnx_inputs, onnx_outputs = export(
preprocessor,
model,
onnx_config,
args.opset,
args.output,
)

if args.atol is None:
args.atol = onnx_config.atol_for_validation

validate_model_outputs(onnx_config, tokenizer, model, args.output, onnx_outputs, args.atol)
validate_model_outputs(onnx_config, preprocessor, model, args.output, onnx_outputs, args.atol)
logger.info(f"All good, model saved at: {args.output.as_posix()}")


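The `main_input_name` dispatch above is the key change: it lets a single CLI serve both modalities. The same idea as a standalone sketch (the helper and checkpoint name are illustrative, not part of the PR):

```python
from transformers import AutoFeatureExtractor, AutoModel, AutoTokenizer

def get_preprocessor(model_name: str):
    # Mirrors the CLI logic: choose the preprocessor from the model's main input.
    model = AutoModel.from_pretrained(model_name)
    if model.main_input_name == "input_ids":
        # Text models consume token ids, so they need a tokenizer.
        return AutoTokenizer.from_pretrained(model_name)
    elif model.main_input_name == "pixel_values":
        # Vision models such as ViT consume pixel values, so they need a feature extractor.
        return AutoFeatureExtractor.from_pretrained(model_name)
    raise ValueError(f"Unsupported model input name: {model.main_input_name}")

preprocessor = get_preprocessor("google/vit-base-patch16-224")
```
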
78 changes: 56 additions & 22 deletions src/transformers/onnx/config.py
@@ -15,11 +15,24 @@
import dataclasses
from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Union

import requests
from transformers import (
PretrainedConfig,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
TensorType,
is_torch_available,
is_vision_available,
)
Review comment (Collaborator): We shouldn't have absolute imports mixed with relative imports, so everything should be relative imports. Also PreTrainedTokenizerFast should not be imported outside of a check for is_tokenizers_available but I don't think you really need it.

Reply (Member, Author): Thanks for the advice. Here I was just following what already existed, but I'll take care of the relative imports in this PR too :)

from ..feature_extraction_utils import FeatureExtractionMixin
from .utils import ParameterFormat, compute_effective_axis_dimension, compute_serialized_parameters_size

from transformers import PretrainedConfig, PreTrainedTokenizer, TensorType, is_torch_available

from .utils import ParameterFormat, compute_effective_axis_dimension, compute_serialized_parameters_size
if is_vision_available():
from PIL import Image


DEFAULT_ONNX_OPSET = 11
@@ -71,6 +84,7 @@ class OnnxConfig(ABC):
"end_logits": {0: "batch", 1: "sequence"},
}
),
"image-classification": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
}

def __init__(self, config: PretrainedConfig, task: str = "default", patching_specs: List[PatchingSpec] = None):
@@ -197,40 +211,60 @@ def use_external_data_format(num_parameters: int) -> bool:

def generate_dummy_inputs(
Review comment (Member, Author): This base method now has a mix of arguments for text and image modalities. I'm not 100% sure if we should split the modalities apart ...

Reply (Member): You split it now right? Just checking to make sure.
self,
tokenizer: PreTrainedTokenizer,
preprocessor: Union[PreTrainedTokenizer, FeatureExtractionMixin],
batch_size: int = -1,
seq_length: int = -1,
is_pair: bool = False,
framework: Optional[TensorType] = None,
tokenizer: PreTrainedTokenizer = None,
) -> Mapping[str, Any]:
"""
Generate inputs to provide to the ONNX exporter for the specific framework

Args:
tokenizer: The tokenizer associated with this model configuration
batch_size: The batch size (int) to export the model for (-1 means dynamic axis)
seq_length: The sequence length (int) to export the model for (-1 means dynamic axis)
is_pair: Indicate if the input is a pair (sentence 1, sentence 2)
framework: The framework (optional) the tokenizer will generate tensor for
preprocessor: ([`PreTrainedTokenizer`] or [`FeatureExtractionMixin`]):
The preprocessor associated with this model configuration.
batch_size (`int`):
The batch size (int) to export the model for (-1 means dynamic axis)
seq_length (`int`):
The sequence length (int) to export the model for (-1 means dynamic axis)
is_pair (`bool`):
Indicate if the input is a pair (sentence 1, sentence 2)
framework (`TensorType`):
The framework (optional) the tokenizer will generate tensor for
tokenizer ([`PreTrainedTokenizer`]):
The tokenizer associated with this model configuration

Returns:
Mapping[str, Tensor] holding the kwargs to provide to the model's forward function
"""
if isinstance(preprocessor, PreTrainedTokenizer) and tokenizer:
raise ValueError("You cannot provide both a tokenizer and a preprocessor to generate dummy inputs.")
if isinstance(preprocessor, PreTrainedTokenizer) or isinstance(preprocessor, PreTrainedTokenizerFast):
if tokenizer:
Review comment (Collaborator): We just tested if the preprocessor is a tokenizer and if tokenizer is not None above and returned an error. This branch will never be reached.

Reply (Member, Author): Good catch! I found it was convenient to re-order the test so that we overwrite the preprocessor arg only if preprocessor=None and tokenizer is not None
preprocessor = tokenizer
# If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
batch_size = compute_effective_axis_dimension(
batch_size, fixed_dimension=OnnxConfig.DEFAULT_FIXED_BATCH, num_token_to_add=0
)

# If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
batch_size = compute_effective_axis_dimension(
batch_size, fixed_dimension=OnnxConfig.DEFAULT_FIXED_BATCH, num_token_to_add=0
)

# If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
token_to_add = tokenizer.num_special_tokens_to_add(is_pair)
seq_length = compute_effective_axis_dimension(
seq_length, fixed_dimension=OnnxConfig.DEFAULT_FIXED_SEQUENCE, num_token_to_add=token_to_add
)
# If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
token_to_add = preprocessor.num_special_tokens_to_add(is_pair)
seq_length = compute_effective_axis_dimension(
seq_length, fixed_dimension=OnnxConfig.DEFAULT_FIXED_SEQUENCE, num_token_to_add=token_to_add
)

# Generate dummy inputs according to compute batch and sequence
dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
return dict(tokenizer(dummy_input, return_tensors=framework))
# Generate dummy inputs according to compute batch and sequence
dummy_input = [" ".join([preprocessor.unk_token]) * seq_length] * batch_size
return dict(preprocessor(dummy_input, return_tensors=framework))
elif isinstance(preprocessor, FeatureExtractionMixin) and is_vision_available():
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
return dict(preprocessor(images=image, return_tensors=framework))
else:
raise ValueError(
"Unable to generate dummy inputs for the model. Please provide a tokenizer or a preprocessor."
)

def patch_ops(self):
for spec in self._patching_specs:
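
To make the new vision branch concrete, here is a sketch of generating dummy inputs from a feature extractor (checkpoint name is illustrative; Pillow and network access are needed because the branch downloads a COCO test image):

```python
from transformers import AutoFeatureExtractor, TensorType, ViTConfig
from transformers.models.vit import ViTOnnxConfig

feature_extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
onnx_config = ViTOnnxConfig(ViTConfig())

# Hits the FeatureExtractionMixin branch: fetches the sample image and returns
# a {"pixel_values": tensor} mapping, (1, 3, 224, 224) for the default ViT extractor.
dummy_inputs = onnx_config.generate_dummy_inputs(feature_extractor, framework=TensorType.PYTORCH)
print(dummy_inputs["pixel_values"].shape)
```
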
56 changes: 38 additions & 18 deletions src/transformers/onnx/convert.py
@@ -25,6 +25,8 @@
from transformers.onnx.config import OnnxConfig
from transformers.utils import logging

from ..feature_extraction_utils import FeatureExtractionMixin
Review comment (Collaborator): Same here. All the imports are a mess of relative and absolute. This should all be relative.

logger = logging.get_logger(__name__) # pylint: disable=invalid-name

@@ -63,26 +65,29 @@ def check_onnxruntime_requirements(minimum_version: Version):


def export_pytorch(
tokenizer: PreTrainedTokenizer,
preprocessor: Union[PreTrainedTokenizer, FeatureExtractionMixin],
model: PreTrainedModel,
config: OnnxConfig,
opset: int,
output: Path,
tokenizer: PreTrainedTokenizer = None,
) -> Tuple[List[str], List[str]]:
"""
Export a PyTorch model to an ONNX Intermediate Representation (IR)

Args:
tokenizer ([`PreTrainedTokenizer`]):
The tokenizer used for encoding the data.
model ([`PreTrainedModel`]):
preprocessor: ([`PreTrainedTokenizer`] or [`FeatureExtractionMixin`]):
The preprocessor used for encoding the data.
model ([`PreTrainedModel`] or [`TFPreTrainedModel`]):
The model to export.
config ([`~onnx.config.OnnxConfig`]):
The ONNX configuration associated with the exported model.
opset (`int`):
The version of the ONNX operator set to use.
output (`Path`):
Directory to store the exported ONNX model.
tokenizer ([`PreTrainedTokenizer`]):
The tokenizer used for encoding the data.

Returns:
`Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from
@@ -106,7 +111,9 @@ def export_pytorch(

# Ensure inputs match
# TODO: Check when exporting QA we provide "is_pair=True"
model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
model_inputs = config.generate_dummy_inputs(
preprocessor, tokenizer=tokenizer, framework=TensorType.PYTORCH
)
inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys())
onnx_outputs = list(config.outputs.keys())

@@ -150,26 +157,29 @@


def export_tensorflow(
tokenizer: PreTrainedTokenizer,
preprocessor: Union[PreTrainedTokenizer, FeatureExtractionMixin],
model: TFPreTrainedModel,
config: OnnxConfig,
opset: int,
output: Path,
tokenizer: PreTrainedTokenizer = None,
) -> Tuple[List[str], List[str]]:
"""
Export a TensorFlow model to an ONNX Intermediate Representation (IR)

Args:
tokenizer ([`PreTrainedTokenizer`]):
The tokenizer used for encoding the data.
model ([`TFPreTrainedModel`]):
preprocessor: ([`PreTrainedTokenizer`] or [`FeatureExtractionMixin`]):
The preprocessor used for encoding the data.
model ([`PreTrainedModel`] or [`TFPreTrainedModel`]):
The model to export.
config ([`~onnx.config.OnnxConfig`]):
The ONNX configuration associated with the exported model.
opset (`int`):
The version of the ONNX operator set to use.
output (`Path`):
Directory to store the exported ONNX model.
tokenizer ([`PreTrainedTokenizer`]):
The tokenizer used for encoding the data.

Returns:
`Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from
@@ -190,7 +200,7 @@ def export_tensorflow(
setattr(model.config, override_config_key, override_config_value)

# Ensure inputs match
model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.TENSORFLOW)
model_inputs = config.generate_dummy_inputs(preprocessor, tokenizer=tokenizer, framework=TensorType.TENSORFLOW)
inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys())
onnx_outputs = list(config.outputs.keys())

@@ -203,18 +213,19 @@


def export(
tokenizer: PreTrainedTokenizer,
preprocessor: Union[PreTrainedTokenizer, FeatureExtractionMixin],
model: Union[PreTrainedModel, TFPreTrainedModel],
config: OnnxConfig,
opset: int,
output: Path,
tokenizer: PreTrainedTokenizer = None,
) -> Tuple[List[str], List[str]]:
"""
Export a Pytorch or TensorFlow model to an ONNX Intermediate Representation (IR)

Args:
tokenizer ([`PreTrainedTokenizer`]):
The tokenizer used for encoding the data.
preprocessor: ([`PreTrainedTokenizer`] or [`FeatureExtractionMixin`]):
The preprocessor used for encoding the data.
model ([`PreTrainedModel`] or [`TFPreTrainedModel`]):
The model to export.
config ([`~onnx.config.OnnxConfig`]):
@@ -223,6 +234,8 @@
The version of the ONNX operator set to use.
output (`Path`):
Directory to store the exported ONNX model.
tokenizer ([`PreTrainedTokenizer`]):
The tokenizer used for encoding the data.

Returns:
`Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from
@@
raise AssertionError(f"Unsupported PyTorch version, minimum required is 1.8.0, got: {torch_version}")

if is_torch_available() and issubclass(type(model), PreTrainedModel):
return export_pytorch(tokenizer, model, config, opset, output)
return export_pytorch(preprocessor, model, config, opset, output, tokenizer=tokenizer)
elif is_tf_available() and issubclass(type(model), TFPreTrainedModel):
return export_tensorflow(tokenizer, model, config, opset, output)
return export_tensorflow(preprocessor, model, config, opset, output, tokenizer=tokenizer)


def validate_model_outputs(
config: OnnxConfig,
tokenizer: PreTrainedTokenizer,
preprocessor: Union[PreTrainedTokenizer, FeatureExtractionMixin],
reference_model: Union[PreTrainedModel, TFPreTrainedModel],
onnx_model: Path,
onnx_named_outputs: List[str],
atol: float,
tokenizer: PreTrainedTokenizer = None,
):
from onnxruntime import InferenceSession, SessionOptions

@@ -261,9 +275,15 @@ def validate_model_outputs(
# TODO: generate inputs with a different batch_size and seq_len that was used for conversion to properly test
# dynamic input shapes.
if issubclass(type(reference_model), PreTrainedModel):
reference_model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
reference_model_inputs = config.generate_dummy_inputs(
preprocessor,
tokenizer=tokenizer,
framework=TensorType.PYTORCH,
)
else:
reference_model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.TENSORFLOW)
reference_model_inputs = config.generate_dummy_inputs(
preprocessor, tokenizer=tokenizer, framework=TensorType.TENSORFLOW
)

# Create ONNX Runtime session
options = SessionOptions()
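
Putting the pieces together, an end-to-end export and validation for ViT with the API introduced in this PR might look like the following sketch (checkpoint, task, and output path are illustrative):

```python
from pathlib import Path

from transformers import AutoFeatureExtractor, ViTForImageClassification
from transformers.models.vit import ViTOnnxConfig
from transformers.onnx import export, validate_model_outputs

model_ckpt = "google/vit-base-patch16-224"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_ckpt)
model = ViTForImageClassification.from_pretrained(model_ckpt)

onnx_config = ViTOnnxConfig(model.config, task="image-classification")
onnx_path = Path("vit.onnx")

# Export with the preprocessor-first signature added in this PR...
onnx_inputs, onnx_outputs = export(feature_extractor, model, onnx_config, onnx_config.default_onnx_opset, onnx_path)

# ...then check that PyTorch and ONNX Runtime outputs agree within the default tolerance.
validate_model_outputs(onnx_config, feature_extractor, model, onnx_path, onnx_outputs, onnx_config.atol_for_validation)
```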