[onnx] classification models export #830

Merged 25 commits on Apr 5, 2022
Commits (25)
81c313e
backup
felixdittrich92 Jan 11, 2022
50574b5
Merge branch 'mindee:main' into main
felixdittrich92 Jan 11, 2022
5a6ed54
Merge branch 'mindee:main' into main
felixdittrich92 Jan 18, 2022
b9958a7
Merge branch 'mindee:main' into main
felixdittrich92 Jan 20, 2022
14c4651
Merge branch 'mindee:main' into main
felixdittrich92 Feb 16, 2022
779731f
Merge branch 'mindee:main' into main
felixdittrich92 Feb 18, 2022
ce2cdda
Merge branch 'mindee:main' into main
felixdittrich92 Feb 22, 2022
76b9652
onnx classification
felixdittrich92 Feb 22, 2022
72e5e0d
fix: Fixed some ResNet architecture imprecisions (#828)
fg-mindee Feb 23, 2022
1633679
update with new models
felixdittrich92 Feb 23, 2022
6511183
feat: replace bce by focal loss in linknet loss (#824)
charlesmindee Feb 23, 2022
5e77167
Revert "feat: replace bce by focal loss in linknet loss (#824)"
felixdittrich92 Feb 24, 2022
5c0a514
Revert "fix: Fixed some ResNet architecture imprecisions (#828)"
felixdittrich92 Feb 24, 2022
21218aa
happy codacy
felixdittrich92 Feb 24, 2022
8ddb081
sapply suggestions
felixdittrich92 Apr 4, 2022
191aa20
fix-setup
felixdittrich92 Apr 4, 2022
fe2bcbe
Merge branch 'main' into onnx-classification
felixdittrich92 Apr 4, 2022
d0b1efc
remove onnx from test req
felixdittrich92 Apr 4, 2022
be55545
move onnx deps ftm to torch
felixdittrich92 Apr 4, 2022
b97512b
up
felixdittrich92 Apr 4, 2022
6b3971d
up
felixdittrich92 Apr 4, 2022
d31f626
revert requirements
felixdittrich92 Apr 4, 2022
ee7a3c2
fix
felixdittrich92 Apr 4, 2022
f52fa83
update docstring
felixdittrich92 Apr 4, 2022
6976419
up
felixdittrich92 Apr 4, 2022
32 changes: 31 additions & 1 deletion doctr/models/utils/pytorch.py
@@ -11,7 +11,7 @@

from doctr.utils.data import download_from_url

__all__ = ['load_pretrained_params', 'conv_sequence_pt']
__all__ = ['load_pretrained_params', 'conv_sequence_pt', 'export_classification_model_to_onnx']


def load_pretrained_params(
@@ -80,3 +80,33 @@ def conv_sequence_pt(
        conv_seq.append(nn.ReLU(inplace=True))

    return conv_seq


def export_classification_model_to_onnx(model: nn.Module, exp_name: str, dummy_input: torch.Tensor) -> str:
    """Export classification model to ONNX format.

    >>> import torch
    >>> from doctr.models.classification import resnet18
    >>> from doctr.models.utils import export_classification_model_to_onnx
    >>> model = resnet18(pretrained=True)
    >>> export_classification_model_to_onnx(model, "my_model", dummy_input=torch.randn(1, 3, 32, 32))

    Args:
        model: the PyTorch model to be exported
        exp_name: the name for the exported model
        dummy_input: the dummy input to the model

    Returns:
        the path to the exported model
    """
    torch.onnx.export(
        model,
        dummy_input,
        f"{exp_name}.onnx",
        input_names=['input'],
        output_names=['logits'],
        dynamic_axes={'input': {0: 'batch_size'}, 'logits': {0: 'batch_size'}},
        export_params=True, opset_version=13, verbose=False
    )
    logging.info(f"Model exported to {exp_name}.onnx")
    return f"{exp_name}.onnx"
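
For context, a minimal sketch of consuming the exported file with ONNX Runtime. The model and input size mirror the docstring above; the output path, the batch of 4 and the random input are illustrative, and the dynamic batch axis declared in the export means the inference batch size does not have to match the dummy input.

import numpy as np
import onnxruntime
import torch

from doctr.models.classification import resnet18
from doctr.models.utils import export_classification_model_to_onnx

model = resnet18(pretrained=True).eval()
# Export with a single-sample dummy input
model_path = export_classification_model_to_onnx(model, "resnet18", dummy_input=torch.randn(1, 3, 32, 32))

# Run the exported graph on a batch of 4 thanks to the dynamic batch axis
ort_session = onnxruntime.InferenceSession(model_path, providers=["CPUExecutionProvider"])
batch = np.random.rand(4, 3, 32, 32).astype(np.float32)
logits = ort_session.run(["logits"], {"input": batch})[0]
print(logits.shape)  # e.g. (4, 126): batch size x number of classes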
10 changes: 10 additions & 0 deletions references/classification/train_pytorch.py
@@ -25,6 +25,7 @@
from doctr import transforms as T
from doctr.datasets import VOCABS, CharacterGenerator
from doctr.models import classification
from doctr.models.utils import export_classification_model_to_onnx
from utils import plot_recorder, plot_samples


@@ -334,6 +335,13 @@ def main(args):
    if args.wb:
        run.finish()

    if args.export_onnx:
        print("Exporting model to ONNX...")
        dummy_batch = next(iter(val_loader))
        dummy_input = dummy_batch[0].cuda() if torch.cuda.is_available() else dummy_batch[0]
        model_path = export_classification_model_to_onnx(model, exp_name, dummy_input)
        print(f"Exported model saved in {model_path}")


def parse_args():
    import argparse
@@ -378,6 +386,8 @@ def parse_args():
                        help='Log to Weights & Biases')
    parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                        help='Load pretrained parameters before starting the training')
    parser.add_argument('--export-onnx', dest='export_onnx', action='store_true',
                        help='Export the model to ONNX')
    parser.add_argument('--sched', type=str, default='cosine', help='scheduler to use')
    parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true")
    parser.add_argument('--find-lr', action='store_true', help='Gridsearch the optimal LR')
2 changes: 2 additions & 0 deletions setup.py
@@ -65,6 +65,7 @@
    "pytest>=5.3.2",
    "coverage>=4.5.4",
    "hdf5storage>=0.1.18",
    "onnxruntime>=1.11.0",
    "requests>=2.20.0",
    "requirements-parser==0.2.0",
    # Quality
@@ -137,6 +138,7 @@ def deps_list(*pkgs):
    "coverage",
    "requests",
    "hdf5storage",
    "onnxruntime",
    "requirements-parser",
)
39 changes: 39 additions & 0 deletions tests/pytorch/test_models_classification_pt.py
@@ -1,10 +1,15 @@
import os
import tempfile

import cv2
import numpy as np
import onnxruntime
import pytest
import torch

from doctr.models import classification
from doctr.models.classification.predictor import CropOrientationPredictor
from doctr.models.utils import export_classification_model_to_onnx


@pytest.mark.parametrize(
@@ -92,3 +97,37 @@ def test_crop_orientation_model(mock_text_box):
    text_box_270 = np.rot90(text_box_0, 3)
    classifier = classification.crop_orientation_predictor("mobilenet_v3_small_orientation", pretrained=True)
    assert classifier([text_box_0, text_box_90, text_box_180, text_box_270]) == [0, 1, 2, 3]


@pytest.mark.parametrize(
    "arch_name, input_shape, output_size",
    [
        ["vgg16_bn_r", (3, 32, 32), (126,)],
        ["resnet18", (3, 32, 32), (126,)],
        ["resnet31", (3, 32, 32), (126,)],
        ["resnet34", (3, 32, 32), (126,)],
        ["resnet34_wide", (3, 32, 32), (126,)],
        ["resnet50", (3, 32, 32), (126,)],
        ["magc_resnet31", (3, 32, 32), (126,)],
        ["mobilenet_v3_small", (3, 32, 32), (126,)],
        ["mobilenet_v3_large", (3, 32, 32), (126,)],
        ["mobilenet_v3_small_orientation", (3, 128, 128), (4,)],
    ],
)
def test_models_onnx_export(arch_name, input_shape, output_size):
    # Model
    batch_size = 2
    model = classification.__dict__[arch_name](pretrained=True).eval()
    dummy_input = torch.rand((batch_size, *input_shape), dtype=torch.float32)
    with tempfile.TemporaryDirectory() as tmpdir:
        # Export
        model_path = export_classification_model_to_onnx(model,
                                                         exp_name=os.path.join(tmpdir, "model"),
                                                         dummy_input=dummy_input)
        assert os.path.exists(model_path)
        # Inference
        ort_session = onnxruntime.InferenceSession(os.path.join(tmpdir, "model.onnx"),
                                                   providers=["CPUExecutionProvider"])
        ort_outs = ort_session.run(['logits'], {'input': dummy_input.numpy()})
        assert isinstance(ort_outs, list) and len(ort_outs) == 1
        assert ort_outs[0].shape == (batch_size, *output_size)
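
The test above checks the output shape only. A possible complement, not part of this PR (helper name and tolerance are illustrative), is a numerical parity check between the eager PyTorch forward pass and the ONNX Runtime output:

import numpy as np
import torch

def assert_onnx_matches_pytorch(model, ort_session, dummy_input, atol=1e-4):
    # Reference logits from the eager PyTorch model
    with torch.no_grad():
        torch_logits = model(dummy_input).numpy()
    # Logits from the exported ONNX graph
    ort_logits = ort_session.run(["logits"], {"input": dummy_input.numpy()})[0]
    assert np.allclose(torch_logits, ort_logits, atol=atol)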
1 change: 1 addition & 0 deletions tests/requirements.txt
@@ -2,4 +2,5 @@ pytest>=5.3.2
requests>=2.20.0
hdf5storage>=0.1.18
coverage>=4.5.4
onnxruntime>=1.11.0
requirements-parser==0.2.0