diff --git a/doctr/models/utils/pytorch.py b/doctr/models/utils/pytorch.py index e0895bcbba..bf1f798c0c 100644 --- a/doctr/models/utils/pytorch.py +++ b/doctr/models/utils/pytorch.py @@ -11,7 +11,7 @@ from doctr.utils.data import download_from_url -__all__ = ['load_pretrained_params', 'conv_sequence_pt'] +__all__ = ['load_pretrained_params', 'conv_sequence_pt', 'export_classification_model_to_onnx'] def load_pretrained_params( @@ -80,3 +80,33 @@ def conv_sequence_pt( conv_seq.append(nn.ReLU(inplace=True)) return conv_seq + + +def export_classification_model_to_onnx(model: nn.Module, exp_name: str, dummy_input: torch.Tensor) -> str: + """Export classification model to ONNX format. + + >>> import torch + >>> from doctr.models.classification import resnet18 + >>> from doctr.models.utils import export_classification_model_to_onnx + >>> model = resnet18(pretrained=True) + >>> export_classification_model_to_onnx(model, "my_model", dummy_input=torch.randn(1, 3, 32, 32)) + + Args: + model: the PyTorch model to be exported + exp_name: the name for the exported model + dummy_input: the dummy input to the model + + Returns: + the path to the exported model + """ + torch.onnx.export( + model, + dummy_input, + f"{exp_name}.onnx", + input_names=['input'], + output_names=['logits'], + dynamic_axes={'input': {0: 'batch_size'}, 'logits': {0: 'batch_size'}}, + export_params=True, opset_version=13, verbose=False + ) + logging.info(f"Model exported to {exp_name}.onnx") + return f"{exp_name}.onnx" diff --git a/references/classification/train_pytorch.py b/references/classification/train_pytorch.py index 9336cdcd16..84d198fb77 100644 --- a/references/classification/train_pytorch.py +++ b/references/classification/train_pytorch.py @@ -25,6 +25,7 @@ from doctr import transforms as T from doctr.datasets import VOCABS, CharacterGenerator from doctr.models import classification +from doctr.models.utils import export_classification_model_to_onnx from utils import plot_recorder, plot_samples @@ -334,6 +335,13 @@ def main(args): if args.wb: run.finish() + if args.export_onnx: + print("Exporting model to ONNX...") + dummy_batch = next(iter(val_loader)) + dummy_input = dummy_batch[0].cuda() if torch.cuda.is_available() else dummy_batch[0] + model_path = export_classification_model_to_onnx(model, exp_name, dummy_input) + print(f"Exported model saved in {model_path}") + def parse_args(): import argparse @@ -378,6 +386,8 @@ def parse_args(): help='Log to Weights & Biases') parser.add_argument('--pretrained', dest='pretrained', action='store_true', help='Load pretrained parameters before starting the training') + parser.add_argument('--export-onnx', dest='export_onnx', action='store_true', + help='Export the model to ONNX') parser.add_argument('--sched', type=str, default='cosine', help='scheduler to use') parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true") parser.add_argument('--find-lr', action='store_true', help='Gridsearch the optimal LR') diff --git a/setup.py b/setup.py index fae8bc7228..4e64ad0d88 100644 --- a/setup.py +++ b/setup.py @@ -65,6 +65,7 @@ "pytest>=5.3.2", "coverage>=4.5.4", "hdf5storage>=0.1.18", + "onnxruntime>=1.11.0", "requests>=2.20.0", "requirements-parser==0.2.0", # Quality @@ -137,6 +138,7 @@ def deps_list(*pkgs): "coverage", "requests", "hdf5storage", + "onnxruntime", "requirements-parser", ) diff --git a/tests/pytorch/test_models_classification_pt.py b/tests/pytorch/test_models_classification_pt.py index 76d05d4bd5..7ff798d635 100644 --- a/tests/pytorch/test_models_classification_pt.py +++ b/tests/pytorch/test_models_classification_pt.py @@ -1,10 +1,15 @@ +import os +import tempfile + import cv2 import numpy as np +import onnxruntime import pytest import torch from doctr.models import classification from doctr.models.classification.predictor import CropOrientationPredictor +from doctr.models.utils import export_classification_model_to_onnx @pytest.mark.parametrize( @@ -92,3 +97,37 @@ def test_crop_orientation_model(mock_text_box): text_box_270 = np.rot90(text_box_0, 3) classifier = classification.crop_orientation_predictor("mobilenet_v3_small_orientation", pretrained=True) assert classifier([text_box_0, text_box_90, text_box_180, text_box_270]) == [0, 1, 2, 3] + + +@pytest.mark.parametrize( + "arch_name, input_shape, output_size", + [ + ["vgg16_bn_r", (3, 32, 32), (126,)], + ["resnet18", (3, 32, 32), (126,)], + ["resnet31", (3, 32, 32), (126,)], + ["resnet34", (3, 32, 32), (126,)], + ["resnet34_wide", (3, 32, 32), (126,)], + ["resnet50", (3, 32, 32), (126,)], + ["magc_resnet31", (3, 32, 32), (126,)], + ["mobilenet_v3_small", (3, 32, 32), (126,)], + ["mobilenet_v3_large", (3, 32, 32), (126,)], + ["mobilenet_v3_small_orientation", (3, 128, 128), (4,)], + ], +) +def test_models_onnx_export(arch_name, input_shape, output_size): + # Model + batch_size = 2 + model = classification.__dict__[arch_name](pretrained=True).eval() + dummy_input = torch.rand((batch_size, *input_shape), dtype=torch.float32) + with tempfile.TemporaryDirectory() as tmpdir: + # Export + model_path = export_classification_model_to_onnx(model, + exp_name=os.path.join(tmpdir, "model"), + dummy_input=dummy_input) + assert os.path.exists(model_path) + # Inference + ort_session = onnxruntime.InferenceSession(os.path.join(tmpdir, "model.onnx"), + providers=["CPUExecutionProvider"]) + ort_outs = ort_session.run(['logits'], {'input': dummy_input.numpy()}) + assert isinstance(ort_outs, list) and len(ort_outs) == 1 + assert ort_outs[0].shape == (batch_size, *output_size) diff --git a/tests/requirements.txt b/tests/requirements.txt index 99025de2e0..0303f097fc 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -2,4 +2,5 @@ pytest>=5.3.2 requests>=2.20.0 hdf5storage>=0.1.18 coverage>=4.5.4 +onnxruntime>=1.11.0 requirements-parser==0.2.0