From e369e388bb55c5ddab7303dada64629097837f0c Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Wed, 27 Sep 2023 12:14:03 +0300 Subject: [PATCH] Fixed bug of no saving simplified ONNX file (#1489) (cherry picked from commit 7ab603b7da9f7a367a275481c9d3662927d678bb) --- .../module_interfaces/exportable_detector.py | 14 +++++++++++--- tests/unit_tests/export_detection_model_test.py | 1 - 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/super_gradients/module_interfaces/exportable_detector.py b/src/super_gradients/module_interfaces/exportable_detector.py index bddde3266f..2ca8990138 100644 --- a/src/super_gradients/module_interfaces/exportable_detector.py +++ b/src/super_gradients/module_interfaces/exportable_detector.py @@ -6,6 +6,7 @@ from typing import Union, Optional, List, Tuple import numpy as np +import onnx import onnxsim import torch from torch import nn, Tensor @@ -495,9 +496,12 @@ def export( if onnx_simplify: # If TRT engine is used, we need to run onnxsim.simplify BEFORE attaching NMS, # because EfficientNMS_TRT is not supported by onnxsim and would lead to a runtime error. - onnxsim.simplify(output) + model_opt, simplify_successful = onnxsim.simplify(output) + if not simplify_successful: + raise RuntimeError(f"Failed to simplify ONNX model {output} with onnxsim. Please check the logs for details.") + onnx.save(model_opt, output) logger.debug(f"Ran onnxsim.simplify on model {output}") - # Disable onnx_simplify to avoid running it twice. + # Disable onnx_simplify to avoid running it second time. onnx_simplify = False nms_attach_method = attach_tensorrt_nms @@ -528,7 +532,11 @@ def export( ) if onnx_simplify: - onnxsim.simplify(output) + model_opt, simplify_successful = onnxsim.simplify(output) + if not simplify_successful: + raise RuntimeError(f"Failed to simplify ONNX model {output} with onnxsim. Please check the logs for details.") + onnx.save(model_opt, output) + logger.debug(f"Ran onnxsim.simplify on {output}") finally: if quantization_mode == ExportQuantizationMode.INT8: diff --git a/tests/unit_tests/export_detection_model_test.py b/tests/unit_tests/export_detection_model_test.py index 9b2d924ac9..924b40c003 100644 --- a/tests/unit_tests/export_detection_model_test.py +++ b/tests/unit_tests/export_detection_model_test.py @@ -303,7 +303,6 @@ def test_export_with_fp16_quantization(self): max_predictions_per_image = 300 with tempfile.TemporaryDirectory() as tmpdirname: - tmpdirname = "." out_path = os.path.join(tmpdirname, "ppyoloe_s_with_fp16_quantization.onnx") ppyolo_e: ExportableObjectDetectionModel = models.get(Models.PP_YOLOE_S, pretrained_weights="coco")