Skip to content

Commit

Permalink
Add ONNX export for BeiT (#16498)
Browse files Browse the repository at this point in the history
* Add beit onnx conversion support

* Updated docs

* Added cross reference to ViT ONNX config
  • Loading branch information
akuma12 authored Apr 1, 2022
1 parent bfeff6c commit 9de70f2
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 5 deletions.
1 change: 1 addition & 0 deletions docs/source/serialization.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ Ready-made configurations include the following architectures:

- ALBERT
- BART
- BEiT
- BERT
- Blenderbot
- BlenderbotSmall
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/beit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@


_import_structure = {
"configuration_beit": ["BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "BeitConfig"],
"configuration_beit": ["BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "BeitConfig", "BeitOnnxConfig"],
}

if is_vision_available():
Expand All @@ -48,7 +48,7 @@
]

if TYPE_CHECKING:
from .configuration_beit import BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, BeitConfig
from .configuration_beit import BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, BeitConfig, BeitOnnxConfig

if is_vision_available():
from .feature_extraction_beit import BeitFeatureExtractor
Expand Down
23 changes: 23 additions & 0 deletions src/transformers/models/beit/configuration_beit.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" BEiT model configuration"""
from collections import OrderedDict
from typing import Mapping

from packaging import version

from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging


Expand Down Expand Up @@ -176,3 +181,21 @@ def __init__(
self.auxiliary_num_convs = auxiliary_num_convs
self.auxiliary_concat_input = auxiliary_concat_input
self.semantic_loss_ignore_index = semantic_loss_ignore_index


# Copied from transformers.models.vit.configuration_vit.ViTOnnxConfig
class BeitOnnxConfig(OnnxConfig):

torch_onnx_minimum_version = version.parse("1.11")

@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict(
[
("pixel_values", {0: "batch", 1: "sequence"}),
]
)

@property
def atol_for_validation(self) -> float:
return 1e-4
2 changes: 2 additions & 0 deletions src/transformers/onnx/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .. import PretrainedConfig, PreTrainedModel, TFPreTrainedModel, is_tf_available, is_torch_available
from ..models.albert import AlbertOnnxConfig
from ..models.bart import BartOnnxConfig
from ..models.beit import BeitOnnxConfig
from ..models.bert import BertOnnxConfig
from ..models.blenderbot import BlenderbotOnnxConfig
from ..models.blenderbot_small import BlenderbotSmallOnnxConfig
Expand Down Expand Up @@ -270,6 +271,7 @@ class FeaturesManager:
onnx_config_cls=ElectraOnnxConfig,
),
"vit": supported_features_mapping("default", "image-classification", onnx_config_cls=ViTOnnxConfig),
"beit": supported_features_mapping("default", "image-classification", onnx_config_cls=BeitOnnxConfig),
"blenderbot": supported_features_mapping(
"default",
"default-with-past",
Expand Down
6 changes: 3 additions & 3 deletions tests/onnx/test_onnx_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,13 @@
export,
validate_model_outputs,
)
from transformers.onnx.utils import compute_effective_axis_dimension, compute_serialized_parameters_size
from transformers.testing_utils import require_onnx, require_tf, require_torch, require_vision, slow


if is_torch_available() or is_tf_available():
from transformers.onnx.features import FeaturesManager

from transformers.onnx.utils import compute_effective_axis_dimension, compute_serialized_parameters_size
from transformers.testing_utils import require_onnx, require_tf, require_torch, require_vision, slow


@require_onnx
class OnnxUtilsTestCaseV2(TestCase):
Expand Down Expand Up @@ -181,6 +180,7 @@ def test_values_override(self):
("xlm-roberta", "xlm-roberta-base"),
("layoutlm", "microsoft/layoutlm-base-uncased"),
("vit", "google/vit-base-patch16-224"),
("beit", "microsoft/beit-base-patch16-224"),
}

PYTORCH_EXPORT_WITH_PAST_MODELS = {
Expand Down

0 comments on commit 9de70f2

Please sign in to comment.