Skip to content

Commit

Permalink
[onnx] classification models export (mindee#830)
Browse files Browse the repository at this point in the history
* backup

* onnx classification

* fix: Fixed some ResNet architecture imprecisions (mindee#828)

* feat: Added new resnets

* feat: Added ResNet101

* fix: Fixed ResNet31 & ResNet34 wide

* feat: Added new pretrained resnets

* style: Fixed isort

* fix: Fixed ResNet architectures

* refactor: Refactored LinkNet

* feat: Added more LinkNets

* fix: Fixed MAGResNet

* docs: Updated documentation

* refactor: Removed ResNet101

* fix: Fixed warning

* fix: Fixed a few bugs

* test: Updated unittests

* docs: Fixed docstrings

* update with new models

* feat: replace bce by focal loss in linknet loss (mindee#824)

* feat: replace bce by focal loss in linknet loss

* fix: requested changes

* fix: mask reduction

* fix: mask reduction

* fix: loss reduction

* fix: final adjustements

* fix: final changes

* Revert "feat: replace bce by focal loss in linknet loss (mindee#824)"

This reverts commit 6511183.

* Revert "fix: Fixed some ResNet architecture imprecisions (mindee#828)"

This reverts commit 72e5e0d.

* happy codacy

* sapply suggestions

* fix-setup

* remove onnx from test req

* move onnx deps ftm to torch

* up

* up

* revert requirements

* fix

* update docstring

* up

Co-authored-by: F-G Fernandez <[email protected]>
Co-authored-by: Charles Gaillard <[email protected]>
  • Loading branch information
3 people committed Apr 5, 2022
1 parent 32a7498 commit 6b79301
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 1 deletion.
32 changes: 31 additions & 1 deletion doctr/models/utils/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from doctr.utils.data import download_from_url

__all__ = ['load_pretrained_params', 'conv_sequence_pt']
__all__ = ['load_pretrained_params', 'conv_sequence_pt', 'export_classification_model_to_onnx']


def load_pretrained_params(
Expand Down Expand Up @@ -87,3 +87,33 @@ def conv_sequence_pt(
conv_seq.append(nn.ReLU(inplace=True))

return conv_seq


def export_classification_model_to_onnx(model: nn.Module, exp_name: str, dummy_input: torch.Tensor) -> str:
"""Export classification model to ONNX format.
>>> import torch
>>> from doctr.models.classification import resnet18
>>> from doctr.models.utils import export_classification_model_to_onnx
>>> model = resnet18(pretrained=True)
>>> export_classification_model_to_onnx(model, "my_model", dummy_input=torch.randn(1, 3, 32, 32))
Args:
model: the PyTorch model to be exported
exp_name: the name for the exported model
dummy_input: the dummy input to the model
Returns:
the path to the exported model
"""
torch.onnx.export(
model,
dummy_input,
f"{exp_name}.onnx",
input_names=['input'],
output_names=['logits'],
dynamic_axes={'input': {0: 'batch_size'}, 'logits': {0: 'batch_size'}},
export_params=True, opset_version=13, verbose=False
)
logging.info(f"Model exported to {exp_name}.onnx")
return f"{exp_name}.onnx"
10 changes: 10 additions & 0 deletions references/classification/train_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from doctr import transforms as T
from doctr.datasets import VOCABS, CharacterGenerator
from doctr.models import classification
from doctr.models.utils import export_classification_model_to_onnx
from utils import plot_recorder, plot_samples


Expand Down Expand Up @@ -334,6 +335,13 @@ def main(args):
if args.wb:
run.finish()

if args.export_onnx:
print("Exporting model to ONNX...")
dummy_batch = next(iter(val_loader))
dummy_input = dummy_batch[0].cuda() if torch.cuda.is_available() else dummy_batch[0]
model_path = export_classification_model_to_onnx(model, exp_name, dummy_input)
print(f"Exported model saved in {model_path}")


def parse_args():
import argparse
Expand Down Expand Up @@ -378,6 +386,8 @@ def parse_args():
help='Log to Weights & Biases')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
help='Load pretrained parameters before starting the training')
parser.add_argument('--export-onnx', dest='export_onnx', action='store_true',
help='Export the model to ONNX')
parser.add_argument('--sched', type=str, default='cosine', help='scheduler to use')
parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true")
parser.add_argument('--find-lr', action='store_true', help='Gridsearch the optimal LR')
Expand Down
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
"pytest>=5.3.2",
"coverage>=4.5.4",
"hdf5storage>=0.1.18",
"onnxruntime>=1.11.0",
"requests>=2.20.0",
"requirements-parser==0.2.0",
# Quality
Expand Down Expand Up @@ -137,6 +138,7 @@ def deps_list(*pkgs):
"coverage",
"requests",
"hdf5storage",
"onnxruntime",
"requirements-parser",
)

Expand Down
39 changes: 39 additions & 0 deletions tests/pytorch/test_models_classification_pt.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
import os
import tempfile

import cv2
import numpy as np
import onnxruntime
import pytest
import torch

from doctr.models import classification
from doctr.models.classification.predictor import CropOrientationPredictor
from doctr.models.utils import export_classification_model_to_onnx


def _test_classification(model, input_shape, output_size, batch_size=2):
Expand Down Expand Up @@ -98,3 +103,37 @@ def test_crop_orientation_model(mock_text_box):
text_box_270 = np.rot90(text_box_0, 3)
classifier = classification.crop_orientation_predictor("mobilenet_v3_small_orientation", pretrained=True)
assert classifier([text_box_0, text_box_90, text_box_180, text_box_270]) == [0, 1, 2, 3]


@pytest.mark.parametrize(
"arch_name, input_shape, output_size",
[
["vgg16_bn_r", (3, 32, 32), (126,)],
["resnet18", (3, 32, 32), (126,)],
["resnet31", (3, 32, 32), (126,)],
["resnet34", (3, 32, 32), (126,)],
["resnet34_wide", (3, 32, 32), (126,)],
["resnet50", (3, 32, 32), (126,)],
["magc_resnet31", (3, 32, 32), (126,)],
["mobilenet_v3_small", (3, 32, 32), (126,)],
["mobilenet_v3_large", (3, 32, 32), (126,)],
["mobilenet_v3_small_orientation", (3, 128, 128), (4,)],
],
)
def test_models_onnx_export(arch_name, input_shape, output_size):
# Model
batch_size = 2
model = classification.__dict__[arch_name](pretrained=True).eval()
dummy_input = torch.rand((batch_size, *input_shape), dtype=torch.float32)
with tempfile.TemporaryDirectory() as tmpdir:
# Export
model_path = export_classification_model_to_onnx(model,
exp_name=os.path.join(tmpdir, "model"),
dummy_input=dummy_input)
assert os.path.exists(model_path)
# Inference
ort_session = onnxruntime.InferenceSession(os.path.join(tmpdir, "model.onnx"),
providers=["CPUExecutionProvider"])
ort_outs = ort_session.run(['logits'], {'input': dummy_input.numpy()})
assert isinstance(ort_outs, list) and len(ort_outs) == 1
assert ort_outs[0].shape == (batch_size, *output_size)
1 change: 1 addition & 0 deletions tests/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ pytest>=5.3.2
requests>=2.20.0
hdf5storage>=0.1.18
coverage>=4.5.4
onnxruntime>=1.11.0
requirements-parser==0.2.0

0 comments on commit 6b79301

Please sign in to comment.