Add ONNX export for ViT #15658

Merged Mar 9, 2022 · 38 commits

Commits
8fce819
Add ONNX support for ViT
lewtun Feb 11, 2022
5e15830
Refactor to use generic preprocessor
lewtun Feb 15, 2022
be90f25
Refactor
lewtun Feb 15, 2022
81287ec
Fix ONNX conversion for models with fast tokenizers
lewtun Feb 15, 2022
8103587
Fix copies
lewtun Feb 15, 2022
bcbdbd9
Add vision to tests
lewtun Feb 15, 2022
12a5306
Remove fixed ViT outputs
lewtun Feb 15, 2022
ba0a7b0
Extend ONNX slow tests to ViT
lewtun Feb 15, 2022
0ebbcae
Add dummy image generator
lewtun Feb 15, 2022
d179861
Use model_type to determine modality
lewtun Feb 15, 2022
b1b4f61
Add deprecation warnings for tokenizer argument
lewtun Feb 15, 2022
b0491e8
Add warning when overwriting the preprocessor
lewtun Feb 15, 2022
7f03d43
Add optional args to docstrings
lewtun Feb 15, 2022
26dcdde
Add TODO
lewtun Feb 15, 2022
99dd9a7
Add minimum PyTorch version to OnnxConfig
lewtun Feb 23, 2022
41fa7e0
Merge branch 'master' into vision-onnx-export
lewtun Feb 23, 2022
5ce5801
Fix minimum torch version
lewtun Feb 23, 2022
84bcfaa
Refactor
lewtun Feb 23, 2022
fcca7dc
Add vision dependency to CI tests
lewtun Feb 23, 2022
687d436
Tweak docstring
lewtun Feb 23, 2022
e4e3343
Add check on torch minimum version
lewtun Feb 24, 2022
8207be1
Merge branch 'master' into vision-onnx-export
lewtun Feb 24, 2022
ade513f
Replace absolute imports with relative ones
lewtun Feb 24, 2022
a7baf9a
Apply Sylvain's suggestions from code review
lewtun Feb 24, 2022
d898037
Merge remote-tracking branch 'origin/vision-onnx-export' into vision-…
lewtun Feb 24, 2022
88d25cf
Fix imports
lewtun Feb 24, 2022
941689b
Refactor OnnxConfig class variables from CONSTANT_NAME to snake_case
lewtun Feb 24, 2022
31dd4f9
Fix ViT torch version
lewtun Feb 24, 2022
d1f9397
Fix docstring
lewtun Feb 24, 2022
49dca94
Fix imports and add logging
lewtun Feb 24, 2022
aec42f8
Use relative imports for real this time and use type checking
lewtun Feb 24, 2022
750db82
Add check for vision feature extractor
lewtun Feb 24, 2022
951df50
Refactor imports for type checking
lewtun Feb 24, 2022
48129a7
Skip ONNX test if torch version is incompatible
lewtun Feb 25, 2022
b2e618e
Revert ImportError vs AssertionError
lewtun Feb 25, 2022
1514807
Revert gitignore
lewtun Feb 28, 2022
7ac6312
Replace ImportError with warning
lewtun Mar 9, 2022
81eedea
Add reasonable value for default atol
lewtun Mar 9, 2022
4 changes: 2 additions & 2 deletions .circleci/config.yml

```diff
@@ -783,7 +783,7 @@ jobs:
             - v0.4-torch-{{ checksum "setup.py" }}
             - v0.4-{{ checksum "setup.py" }}
       - run: pip install --upgrade pip
-      - run: pip install .[torch,testing,sentencepiece,onnxruntime]
+      - run: pip install .[torch,testing,sentencepiece,onnxruntime,vision]
      - save_cache:
          key: v0.4-onnx-{{ checksum "setup.py" }}
          paths:
```

**lewtun** (Member, Author) commented: I've added the vision dependency here to the CI tests. I couldn't see any other files in .github/workflows where I needed to add it, but please let me know otherwise.

```diff
@@ -816,7 +816,7 @@ jobs:
             - v0.4-torch-{{ checksum "setup.py" }}
             - v0.4-{{ checksum "setup.py" }}
       - run: pip install --upgrade pip
-      - run: pip install .[torch,testing,sentencepiece,onnxruntime]
+      - run: pip install .[torch,testing,sentencepiece,onnxruntime,vision]
      - save_cache:
          key: v0.4-onnx-{{ checksum "setup.py" }}
          paths:
```
1 change: 1 addition & 0 deletions docs/source/serialization.mdx

```diff
@@ -60,6 +60,7 @@ Ready-made configurations include the following architectures:
 - PLBart
 - RoBERTa
 - T5
+- ViT
 - XLM-RoBERTa
 - XLM-RoBERTa-XL
```
4 changes: 2 additions & 2 deletions src/transformers/models/bart/configuration_bart.py

```diff
@@ -358,13 +358,13 @@ def _generate_dummy_inputs_for_sequence_classification_and_question_answering(
         # Did not use super(OnnxConfigWithPast, self).generate_dummy_inputs for code clarity.
         # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
         batch_size = compute_effective_axis_dimension(
-            batch_size, fixed_dimension=OnnxConfig.DEFAULT_FIXED_BATCH, num_token_to_add=0
+            batch_size, fixed_dimension=OnnxConfig.default_fixed_batch, num_token_to_add=0
         )
 
         # If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
         token_to_add = tokenizer.num_special_tokens_to_add(is_pair)
         seq_length = compute_effective_axis_dimension(
-            seq_length, fixed_dimension=OnnxConfig.DEFAULT_FIXED_SEQUENCE, num_token_to_add=token_to_add
+            seq_length, fixed_dimension=OnnxConfig.default_fixed_sequence, num_token_to_add=token_to_add
         )
 
         # Generate dummy inputs according to compute batch and sequence
```
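The renamed `default_fixed_batch` / `default_fixed_sequence` attributes feed `compute_effective_axis_dimension`, which substitutes a small concrete size for a dynamic axis (flagged as -1) so that ONNX tracing does not specialize the exported graph to that dimension. A minimal standalone sketch of that logic — a hedged reconstruction inferred from how it is called in the diff above, not the library source:

```python
def compute_effective_axis_dimension(dimension: int, fixed_dimension: int, num_token_to_add: int = 0) -> int:
    # A dynamic axis is flagged with a non-positive size (-1); replace it
    # with the small fixed default (2 samples / 8 tokens) so tracing sees
    # a concrete shape that ONNX cannot constant-fold away.
    if dimension <= 0:
        dimension = fixed_dimension
    # Leave room for the special tokens the tokenizer will insert.
    dimension -= num_token_to_add
    return dimension


# Dynamic batch axis -> default fixed batch of 2
print(compute_effective_axis_dimension(-1, fixed_dimension=2))  # 2
# Dynamic sequence axis with 2 special tokens to add -> 8 - 2 = 6
print(compute_effective_axis_dimension(-1, fixed_dimension=8, num_token_to_add=2))  # 6
```

The identical rename appears in the Marian and mBART configurations below, since all three copy the same dummy-input generation code.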
4 changes: 2 additions & 2 deletions src/transformers/models/marian/configuration_marian.py

```diff
@@ -346,13 +346,13 @@ def _generate_dummy_inputs_for_encoder_and_decoder(
         # Did not use super(OnnxConfigWithPast, self).generate_dummy_inputs for code clarity.
         # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
         batch_size = compute_effective_axis_dimension(
-            batch_size, fixed_dimension=OnnxConfig.DEFAULT_FIXED_BATCH, num_token_to_add=0
+            batch_size, fixed_dimension=OnnxConfig.default_fixed_batch, num_token_to_add=0
         )
 
         # If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
         token_to_add = tokenizer.num_special_tokens_to_add(is_pair)
         seq_length = compute_effective_axis_dimension(
-            seq_length, fixed_dimension=OnnxConfig.DEFAULT_FIXED_SEQUENCE, num_token_to_add=token_to_add
+            seq_length, fixed_dimension=OnnxConfig.default_fixed_sequence, num_token_to_add=token_to_add
         )
 
         # Generate dummy inputs according to compute batch and sequence
```
4 changes: 2 additions & 2 deletions src/transformers/models/mbart/configuration_mbart.py

```diff
@@ -343,13 +343,13 @@ def _generate_dummy_inputs_for_sequence_classification_and_question_answering(
         # Did not use super(OnnxConfigWithPast, self).generate_dummy_inputs for code clarity.
         # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
         batch_size = compute_effective_axis_dimension(
-            batch_size, fixed_dimension=OnnxConfig.DEFAULT_FIXED_BATCH, num_token_to_add=0
+            batch_size, fixed_dimension=OnnxConfig.default_fixed_batch, num_token_to_add=0
         )
 
         # If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
         token_to_add = tokenizer.num_special_tokens_to_add(is_pair)
         seq_length = compute_effective_axis_dimension(
-            seq_length, fixed_dimension=OnnxConfig.DEFAULT_FIXED_SEQUENCE, num_token_to_add=token_to_add
+            seq_length, fixed_dimension=OnnxConfig.default_fixed_sequence, num_token_to_add=token_to_add
        )
 
         # Generate dummy inputs according to compute batch and sequence
```
4 changes: 2 additions & 2 deletions src/transformers/models/vit/__init__.py

```diff
@@ -21,7 +21,7 @@
 
 
 _import_structure = {
-    "configuration_vit": ["VIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTConfig"],
+    "configuration_vit": ["VIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTConfig", "ViTOnnxConfig"],
 }
 
 if is_vision_available():
@@ -51,7 +51,7 @@
 ]
 
 if TYPE_CHECKING:
-    from .configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig
+    from .configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig, ViTOnnxConfig
 
 if is_vision_available():
     from .feature_extraction_vit import ViTFeatureExtractor
```
23 changes: 23 additions & 0 deletions src/transformers/models/vit/configuration_vit.py

```diff
@@ -14,7 +14,13 @@
 # limitations under the License.
 """ ViT model configuration"""
 
+from collections import OrderedDict
+from typing import Mapping
+
+from packaging import version
+
 from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
 from ...utils import logging
 
 
@@ -119,3 +125,20 @@ def __init__(
         self.num_channels = num_channels
         self.qkv_bias = qkv_bias
         self.encoder_stride = encoder_stride
+
+
+class ViTOnnxConfig(OnnxConfig):
+
+    torch_onnx_minimum_version = version.parse("1.11")
+
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        return OrderedDict(
+            [
+                ("pixel_values", {0: "batch", 1: "sequence"}),
+            ]
+        )
+
+    @property
+    def atol_for_validation(self) -> float:
+        return 1e-4
```
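The new `ViTOnnxConfig` pins a minimum torch version of 1.11, declares a single `pixel_values` input with dynamic axes, and loosens the validation tolerance to 1e-4 (the "reasonable value for default atol" commit). A dependency-free sketch of the same pattern, where `OnnxConfigSketch` is a hypothetical stand-in for the real `OnnxConfig` base class (which additionally handles dummy-input generation and opset defaults):

```python
from collections import OrderedDict
from typing import Mapping


class OnnxConfigSketch:
    """Hypothetical stand-in for transformers' OnnxConfig base class."""

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        raise NotImplementedError

    @property
    def atol_for_validation(self) -> float:
        # Text models typically validate against a tighter tolerance.
        return 1e-5


class ViTOnnxConfigSketch(OnnxConfigSketch):
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        # Vision models take pixel_values instead of input_ids; axes 0 and 1
        # are named so they are exported as dynamic rather than fixed sizes.
        return OrderedDict([("pixel_values", {0: "batch", 1: "sequence"})])

    @property
    def atol_for_validation(self) -> float:
        # Looser tolerance for ViT, per the PR's default-atol commit.
        return 1e-4


cfg = ViTOnnxConfigSketch()
print(list(cfg.inputs), cfg.atol_for_validation)  # ['pixel_values'] 0.0001
```

Once the config is registered, export can then be driven from the command line, e.g. `python -m transformers.onnx --model=google/vit-base-patch16-224 onnx/` (model checkpoint name used for illustration).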
26 changes: 21 additions & 5 deletions src/transformers/onnx/__main__.py

```diff
@@ -15,8 +15,9 @@
 from argparse import ArgumentParser
 from pathlib import Path
 
-from transformers.models.auto import AutoTokenizer
-
+from ..models.auto import AutoConfig, AutoFeatureExtractor, AutoTokenizer
+from ..models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING_NAMES
+from ..models.auto.tokenization_auto import TOKENIZER_MAPPING_NAMES
 from ..utils import logging
 from .convert import export, validate_model_outputs
 from .features import FeaturesManager
@@ -46,8 +47,17 @@ def main():
     if not args.output.parent.exists():
         args.output.parent.mkdir(parents=True)
 
+    # Check the modality of the inputs and instantiate the appropriate preprocessor
+    # TODO(lewtun): Refactor this as a function if we need to check modalities elsewhere as well
+    config = AutoConfig.from_pretrained(args.model)
+    if config.model_type in TOKENIZER_MAPPING_NAMES:
+        preprocessor = AutoTokenizer.from_pretrained(args.model)
+    elif config.model_type in FEATURE_EXTRACTOR_MAPPING_NAMES:
+        preprocessor = AutoFeatureExtractor.from_pretrained(args.model)
+    else:
+        raise ValueError(f"Unsupported model type: {config.model_type}")
+
     # Allocate the model
-    tokenizer = AutoTokenizer.from_pretrained(args.model)
     model = FeaturesManager.get_model_from_feature(args.feature, args.model)
     model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(model, feature=args.feature)
     onnx_config = model_onnx_config(model.config)
@@ -62,12 +72,18 @@ def main():
             f"At least {onnx_config.default_onnx_opset} is required."
         )
 
-    onnx_inputs, onnx_outputs = export(tokenizer, model, onnx_config, args.opset, args.output)
+    onnx_inputs, onnx_outputs = export(
+        preprocessor,
+        model,
+        onnx_config,
+        args.opset,
+        args.output,
+    )
 
     if args.atol is None:
         args.atol = onnx_config.atol_for_validation
 
-    validate_model_outputs(onnx_config, tokenizer, model, args.output, onnx_outputs, args.atol)
+    validate_model_outputs(onnx_config, preprocessor, model, args.output, onnx_outputs, args.atol)
     logger.info(f"All good, model saved at: {args.output.as_posix()}")
```
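The modality check above dispatches on `config.model_type`: types registered in the tokenizer mapping get an `AutoTokenizer`, vision types registered in the feature-extractor mapping get an `AutoFeatureExtractor`, and anything else raises. A toy sketch of that dispatch using stand-in mapping tables (the real `TOKENIZER_MAPPING_NAMES` / `FEATURE_EXTRACTOR_MAPPING_NAMES` are large mappings maintained inside transformers):

```python
# Stand-in tables with a few illustrative model types; the real mappings
# in transformers cover every supported architecture.
TOKENIZER_MAPPING_NAMES = {"bert", "roberta", "t5"}
FEATURE_EXTRACTOR_MAPPING_NAMES = {"vit", "beit"}


def select_preprocessor_kind(model_type: str) -> str:
    # Text models are preprocessed by a tokenizer, vision models by a
    # feature extractor; unknown types fail loudly, mirroring __main__.py.
    if model_type in TOKENIZER_MAPPING_NAMES:
        return "tokenizer"
    elif model_type in FEATURE_EXTRACTOR_MAPPING_NAMES:
        return "feature_extractor"
    raise ValueError(f"Unsupported model type: {model_type}")


print(select_preprocessor_kind("vit"))   # feature_extractor
print(select_preprocessor_kind("bert"))  # tokenizer
```

Because `export` and `validate_model_outputs` now take the generic `preprocessor` argument, the same CLI path serves both text and vision models without special-casing downstream.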