
[models] Ensure all PyTorch models are ONNX exportable #789

Closed · 10 tasks done · Tracked by #791
fg-mindee opened this issue Jan 10, 2022 · 16 comments

Labels: critical (High priority) · framework: pytorch (Related to PyTorch backend) · module: models (Related to doctr.models) · topic: onnx (ONNX-related)
Milestone: 0.6.0

@fg-mindee (Contributor) commented Jan 10, 2022

Most users of the library are more interested in using the existing pretrained models for inference than in training their own. For this reason, it's important to ensure we can easily export those trained models.

fg-mindee added the critical (High priority), module: models, framework: pytorch and topic: onnx labels on Jan 10, 2022
fg-mindee added this to the 0.6.0 milestone on Jan 10, 2022
fg-mindee self-assigned this on Jan 10, 2022
fg-mindee mentioned this issue on Jan 10, 2022
@felixdittrich92 (Contributor)

@charlesmindee classification complete 👍

@frytoli commented May 2, 2022

I'm relatively green when it comes to ML engineering, but I've been looking into exporting the DBNet models to ONNX. I think we're relatively close already and just need to convert the inputs/outputs to a data structure compatible with ONNX (specifically, not numpy.ndarray). Here's some sample code and the resulting error. I'm slowly working on solving this on my end, but I wanted to post it here in case it's a really easy problem for someone more familiar with the project.

import torch
from doctr.models.detection import db_resnet50_rotation

model = db_resnet50_rotation(pretrained=True)
torch.onnx.export(model, torch.randn(1, 3, 1024, 1024), "model.onnx", input_names=['input'], output_names=['output'], opset_version=13)
RuntimeError: Only tuples, lists and Variables are supported as JIT inputs/outputs. Dictionaries and strings are also accepted, but their usage is not recommended. Here, received an input of unsupported type: numpy.ndarray

@felixdittrich92 (Contributor)

Hi @frytoli 👋,

thanks for working on this 🤗

To clarify where the problem is: if you run the following with verbose=True, you get a more detailed output:

import torch
from doctr.models.detection import db_resnet50_rotation

model = db_resnet50_rotation(pretrained=True)
dummy_input = torch.rand((1, 3, 1024, 1024), dtype=torch.float32)
torch.onnx.export(
        model,
        dummy_input,
        "test.onnx",
        input_names=['input'],
        output_names=['logits'],
        dynamic_axes={'input': {0: 'batch_size'}, 'logits': {0: 'batch_size'}},
        export_params=True, opset_version=13, verbose=True
    )

The problem is this expression in the forward pass: preds[0] for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())

Currently the model's forward pass includes the post-processing call.
In my opinion this is not optimal. To avoid this kind of issue, the model should only return specific data, i.e. the logits vector and, if ground truth is provided, the loss.
So the first step would be to make these changes for all detection and recognition models.
However, doing so would break the current end-to-end pipeline, so we would have to see how it affects runtime/latency.
Furthermore, we would have to extend these changes to the TensorFlow side as well, both to keep the two implementations aligned and because I think it would then also let us export the models more easily to TensorFlow's SavedModel format.
If you are interested, please open a pull request and we can work on it together if you like.
@charlesmindee @frgfm what is your opinion about this?
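
To make the failure mode concrete, here is a minimal, self-contained sketch (toy modules, not doctr's actual classes) of why a numpy conversion inside forward breaks the trace-based export, while stopping at the raw tensor works:

import torch
import torch.nn as nn

class WithNumpyPostproc(nn.Module):
    # mimics a forward pass that ends in a numpy-based postprocessor
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 1, 3, padding=1)

    def forward(self, x):
        prob_map = torch.sigmoid(self.conv(x))
        # leaving the graph here: the tracer cannot follow numpy, so the export fails
        return prob_map.detach().cpu().numpy()

class LogitsOnly(nn.Module):
    # same computation, but it stays in torch tensors end to end
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 1, 3, padding=1)

    def forward(self, x):
        return self.conv(x)

dummy = torch.rand(1, 3, 64, 64)
torch.onnx.export(LogitsOnly(), dummy, "ok.onnx", opset_version=13)  # exports fine
# torch.onnx.export(WithNumpyPostproc(), dummy, "ko.onnx", opset_version=13)  # raises the numpy.ndarray error above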

@charlesmindee (Collaborator) commented May 3, 2022

Hi @felixdittrich92 and @frytoli,
Thanks for bringing this to the table. To avoid breaking the whole pipeline, as Felix mentioned, we may want to add a flag ("exportable" or "raw" or something like that) that could be passed at model initialization: db_resnet50(pretrained=True, exportable=True). The model call would then return the raw output without the post-processing, or with an adapted (ONNX-compatible) post-processing. What do you think?
@MaximeChurin
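
A rough sketch of how such a flag could look in a model's forward pass (the names and structure below are illustrative, not doctr's actual implementation):

import torch
import torch.nn as nn

class DetectionModelSketch(nn.Module):
    def __init__(self, backbone: nn.Module, postprocessor, exportable: bool = False):
        super().__init__()
        self.backbone = backbone            # regular torch layers, traceable
        self.postprocessor = postprocessor  # numpy/cv2-based, not traceable
        self.exportable = exportable

    def forward(self, x: torch.Tensor):
        logits = self.backbone(x)
        if self.exportable:
            # stop at the raw tensor so torch.onnx.export can trace the whole graph
            return {"logits": logits}
        prob_map = torch.sigmoid(logits)
        preds = self.postprocessor(prob_map.detach().cpu().permute(0, 2, 3, 1).numpy())
        return {"logits": logits, "preds": preds}

With something like this, the exported graph stops at the logits, and users who need boxes keep applying the postprocessor on the session output.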

@felixdittrich92 (Contributor)

@charlesmindee @frytoli Sounds good to me, but we have to take care of the increased complexity and the differences between the TF and PT implementations. So I would suggest trying it on one model, but in both implementations, so that we can provide an equivalent implementation that solves the ONNX export problem on the PT side and the SavedModel export on the TF side. Wdyt?

@charlesmindee (Collaborator)

Yes, it would be nice to do a POC on one model first, but we may want to make it exportable to ONNX for both PyTorch and TensorFlow.

@felixdittrich92 (Contributor) commented May 4, 2022

So if we can handle both with ONNX, do we really need TF's SavedModel as well? Otherwise I would say let's modify this issue and close the counterpart issue for TF SavedModel 👍 (I will test tf2onnx this week with the classification models)
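
Not tested yet, but the tf2onnx route should look roughly like this (the model name and input shape below are illustrative and would need to be adjusted to the actual classification archs):

import tensorflow as tf
import tf2onnx

from doctr.models import classification

# illustrative pick: any Keras-based doctr classification model
model = classification.mobilenet_v3_small(pretrained=True)

# tf2onnx traces the Keras model from an input signature
input_signature = [tf.TensorSpec((None, 32, 32, 3), tf.float32, name="input")]
model_proto, _ = tf2onnx.convert.from_keras(
    model, input_signature=input_signature, opset=13, output_path="classifier.onnx"
)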

@felixdittrich92 (Contributor)

@charlesmindee wdyt ?

@frgfm (Collaborator) commented May 7, 2022

Hi all 👋

So, yes, we're very aware of the JIT incompatibility with some data structures & numpy operations. Having conditional execution in call methods could become tricky to maintain later on.

This is a big topic, hence this issue. Here is a suggestion:

  • all models' "core computation" should be compatible with JIT export
  • we separate out the post-processing
  • with ONNX export, the post-processing will still need to be applied afterwards (see the sketch below).

What do you think?
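
Concretely, the last bullet would look something like this (a sketch: the postprocess helper is a stand-in for the library's existing numpy/cv2-based post-processing, and the input name assumes the export call shown earlier in this thread):

import numpy as np
import onnxruntime

def postprocess(prob_map: np.ndarray) -> np.ndarray:
    # placeholder for the existing numpy/cv2-based box extraction
    return (prob_map > 0.3).astype(np.uint8)

# 1) run the exported "core computation" with onnxruntime
session = onnxruntime.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
dummy = np.random.rand(1, 3, 1024, 1024).astype(np.float32)
logits = session.run(None, {"input": dummy})[0]

# 2) the post-processing stays outside the ONNX graph and is applied afterwards
prob_map = 1 / (1 + np.exp(-logits))  # sigmoid on the raw logits
boxes = postprocess(prob_map)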

@felixdittrich92 (Contributor) commented May 7, 2022

@frgfm @charlesmindee
Hi again 👋 ,
I see two options now:

  1. We make each model exportable by stopping at the logits vector and doing the post-processing afterwards. @frgfm
  2. We try to make the post-processing step JIT compatible. @charlesmindee

Keep in mind that, as a user, I think most people will also want to use these exported models inside docTR (a quick sanity check for this is sketched at the end of this comment).
This may be a bit off-topic and not directly related to this issue, but why not provide the training part of this library in PT and TF and the inference pipeline with onnxruntime? Some points I see:

  1. It would be much easier to maintain.
  2. If we really have problems with a model in a specific framework, you could still use the working one through a pure ONNX inference pipeline, no problem.
  3. With the next release we provide a model sharing option, so we have to guarantee that it stays backward compatible (I am really a bit scared about this part 😅); this would be no problem with ONNX.
  4. We could focus more on optimizing inference and have a clear split between training and inference.

I know this would be a lot of refactoring, but maybe it could be a good match with #530.

What do you think? 😄
I think at this point we have to agree on a good approach, which will involve breaking changes, before we can go ahead!
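
Coming back to the point about using the exported models inside docTR, a quick sanity check would be to compare the PyTorch logits against the onnxruntime output on the same input (a sketch, assuming the exportable flag discussed above and the input/output names used in the export examples):

import numpy as np
import onnxruntime
import torch

from doctr.models.detection import db_resnet50

# assumes an 'exportable' flag that makes the forward return the raw logits
model = db_resnet50(pretrained=True, exportable=True).eval()
dummy = torch.rand(1, 3, 1024, 1024, dtype=torch.float32)

with torch.no_grad():
    torch_logits = model(dummy)["logits"].numpy()

# "model.onnx" is the file produced by the export snippets above
session = onnxruntime.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
ort_logits = session.run(None, {"input": dummy.numpy()})[0]

# small numerical drift between backends is expected
np.testing.assert_allclose(torch_logits, ort_logits, rtol=1e-3, atol=1e-5)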

@frgfm (Collaborator) commented May 16, 2022

My bad, I should clarify my previous thought:

  • if we can make all the steps JIT compatible, then yes, sure, we should do it. From my previous attempts, I remember that this would require diving into obscure cv2 functions (which aren't compatible as-is)
  • if it's not possible or quite hard, I suggest keeping a post-processing step after the ONNX part

About the training & alignment with TF, this is a serious problem. Moving between TF & ONNX is much more troublesome than between PyTorch & ONNX. I fully agree that we should do our best to move in the right direction though 😅

@felixdittrich92 (Contributor)

Short update:
Both SAR implementations are exportable with ONNX 💯 (only a small flaw on the TF side: we have to provide a constant batch size. Why? If we pass the dummy's batch size as None, it is not possible to compute the initial LSTMCell states.)

Left:
classification: MAGC (TF only)
recognition: MASTER (TF and PT)

@felixdittrich92 (Contributor) commented Jun 9, 2022

Update: with #941 PyTorch is done ❤️
All models can be exported; only MAGC and MASTER on the TF side are left.

One ref: microsoft/onnxruntime#10994. The MASTER (PT) model is exportable but doesn't work with onnxruntime until that ticket is closed:

arch_name = 'master', input_shape = (3, 32, 128)
    @pytest.mark.parametrize(
        "arch_name, input_shape",
        [
            ["crnn_vgg16_bn", (3, 32, 128)],
            ["crnn_mobilenet_v3_small", (3, 32, 128)],
            ["crnn_mobilenet_v3_large", (3, 32, 128)],
            ["sar_resnet[31](https://github.com/mindee/doctr/runs/6819799280?check_suite_focus=true#step:6:32)", (3, [32](https://github.com/mindee/doctr/runs/6819799280?check_suite_focus=true#step:6:33), 128)],
            ["master", (3, 32, 128)],
        ],
    )
    def test_models_onnx_export(arch_name, input_shape):
        # Model
        batch_size = 2
        model = recognition.__dict__[arch_name](pretrained=True, exportable=True).eval()
        dummy_input = torch.rand((batch_size, *input_shape), dtype=torch.float32)
        with tempfile.TemporaryDirectory() as tmpdir:
            # Export
            model_path = export_model_to_onnx(model, model_name=os.path.join(tmpdir, "model"), dummy_input=dummy_input)
            assert os.path.exists(model_path)
            # Inference
            ort_session = onnxruntime.InferenceSession(os.path.join(tmpdir, "model.onnx"),
>                                                      providers=["CPUExecutionProvider"])
tests/pytorch/test_models_recognition_pt.py:114: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
/opt/hostedtoolcache/Python/3.7.13/x64/lib/python3.7/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py:335: in __init__
    self._create_inference_session(providers, provider_options, disabled_optimizers)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
self = <onnxruntime.capi.onnxruntime_inference_collection.InferenceSession object at 0x7fcae7bdca90>
providers = ['CPUExecutionProvider'], provider_options = [{}]
disabled_optimizers = set()
    def _create_inference_session(self, providers, provider_options, disabled_optimizers=None):
        available_providers = C.get_available_providers()
        # Tensorrt can fall back to CUDA. All others fall back to CPU.
        if 'TensorrtExecutionProvider' in available_providers:
            self._fallback_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        elif 'MIGraphXExecutionProvider' in available_providers:
            self._fallback_providers = ['ROCMExecutionProvider', 'CPUExecutionProvider']
        else:
            self._fallback_providers = ['CPUExecutionProvider']
        # validate providers and provider_options before other initialization
        providers, provider_options = check_and_normalize_provider_args(providers,
                                                                        provider_options,
                                                                        available_providers)
        if providers == [] and len(available_providers) > 1:
            self.disable_fallback()
            raise ValueError("This ORT build has {} enabled. ".format(available_providers) +
                             "Since ORT 1.9, you are required to explicitly set " +
                             "the providers parameter when instantiating InferenceSession. For example, "
                             "onnxruntime.InferenceSession(..., providers={}, ...)".format(available_providers))
        session_options = self._sess_options if self._sess_options else C.get_default_session_options()
        if self._model_path:
            sess = C.InferenceSession(session_options, self._model_path, True, self._read_config_from_model)
        else:
            sess = C.InferenceSession(session_options, self._model_bytes, False, self._read_config_from_model)
        if disabled_optimizers is None:
            disabled_optimizers = set()
        elif not isinstance(disabled_optimizers, set):
            # convert to set. assumes iterable
            disabled_optimizers = set(disabled_optimizers)
        # initialize the C++ InferenceSession
>       sess.initialize_session(providers, provider_options, disabled_optimizers)
E       onnxruntime.capi.onnxruntime_pybind11_state.NotImplemented: [ONNXRuntimeError] : 9 : NOT_IMPLEMENTED : Could not find an implementation for Trilu(14) node with name 'Trilu_422'

@felixdittrich92 (Contributor)

@frgfm @charlesmindee I think this issue is only about the export; I would open another issue for the loading part, wdyt? :)

@frgfm (Collaborator) commented Jul 20, 2022

@felixdittrich92 yes precisely 👍

@felixdittrich92 (Contributor)

@frgfm @charlesmindee I think it is OK if we close this: all models (without post-processing) can be exported.
