From 60cf692a63a22cd2698273c4945f037b4b22474b Mon Sep 17 00:00:00 2001 From: czh978 <41666381+czh978@users.noreply.github.com> Date: Mon, 19 Sep 2022 13:49:04 +0800 Subject: [PATCH] [Frontend][TFLite] fix detection_postprocess's non_max_suppression_attrs["force_suppress"] (#12593) * [Frontend][TFLite]fix detection_postprocess's non_max_suppression_attrs["force_suppress"] Since tvm only supports operators detection_postprocess use_regular_nms is false, which will suppress boxes that exceed the threshold regardless of the class when implementing NMS in tflite, in order for the results of tvm and tflite to be consistent, we need to set force_suppress to True. * [Frontend][TFLite]fix detection_postprocess's non_max_suppression_attrs[force_suppress] Added a test case that reproduces inconsistent results between tvm and tflite When the force_suppress is false,it will get a good result if you set the force_suppress as true --- python/tvm/relay/frontend/tflite.py | 2 +- tests/python/frontend/tflite/test_forward.py | 37 ++++++++++++++------ 2 files changed, 27 insertions(+), 12 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 6c68230e0ecc..a7e10ad72e55 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -3355,7 +3355,7 @@ def convert_detection_postprocess(self, op): non_max_suppression_attrs = {} non_max_suppression_attrs["return_indices"] = False non_max_suppression_attrs["iou_threshold"] = custom_options["nms_iou_threshold"] - non_max_suppression_attrs["force_suppress"] = False + non_max_suppression_attrs["force_suppress"] = True non_max_suppression_attrs["top_k"] = anchor_boxes non_max_suppression_attrs["max_output_size"] = custom_options["max_detections"] non_max_suppression_attrs["invalid_to_bottom"] = False diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index deaef72e1d7f..7b2bd60d8a20 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -4311,13 +4311,8 @@ def test_forward_matrix_diag(): # ---------------- -def test_detection_postprocess(): - """Detection PostProcess""" - tf_model_file = tf_testing.get_workload_official( - "http://download.tensorflow.org/models/object_detection/" - "ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03.tar.gz", - "ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03/tflite_graph.pb", - ) +def _test_detection_postprocess(tf_model_file, box_encodings_size, class_predictions_size): + """One iteration of detection postProcess with given model and shapes""" converter = tf.lite.TFLiteConverter.from_frozen_graph( tf_model_file, input_arrays=["raw_outputs/box_encodings", "raw_outputs/class_predictions"], @@ -4328,16 +4323,16 @@ def test_detection_postprocess(): "TFLite_Detection_PostProcess:3", ], input_shapes={ - "raw_outputs/box_encodings": (1, 1917, 4), - "raw_outputs/class_predictions": (1, 1917, 91), + "raw_outputs/box_encodings": box_encodings_size, + "raw_outputs/class_predictions": class_predictions_size, }, ) converter.allow_custom_ops = True converter.inference_type = tf.lite.constants.FLOAT tflite_model = converter.convert() np.random.seed(0) - box_encodings = np.random.uniform(size=(1, 1917, 4)).astype("float32") - class_predictions = np.random.uniform(size=(1, 1917, 91)).astype("float32") + box_encodings = np.random.uniform(size=box_encodings_size).astype("float32") + class_predictions = np.random.uniform(size=class_predictions_size).astype("float32") tflite_output = run_tflite_graph(tflite_model, [box_encodings, class_predictions]) tvm_output = run_tvm_graph( tflite_model, @@ -4382,6 +4377,26 @@ def test_detection_postprocess(): ) +def test_detection_postprocess(): + """Detection PostProcess""" + box_encodings_size = (1, 1917, 4) + class_predictions_size = (1, 1917, 91) + tf_model_file = tf_testing.get_workload_official( + "http://download.tensorflow.org/models/object_detection/" + "ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03.tar.gz", + "ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03/tflite_graph.pb", + ) + _test_detection_postprocess(tf_model_file, box_encodings_size, class_predictions_size) + + box_encodings_size = (1, 2034, 4) + class_predictions_size = (1, 2034, 91) + tf_model_file = download_testdata( + "https://github.com/czh978/models_for_tvm_test/raw/main/tflite_graph_with_postprocess.pb", + "tflite_graph_with_postprocess.pb", + ) + _test_detection_postprocess(tf_model_file, box_encodings_size, class_predictions_size) + + ####################################################################### # Custom Converter # ----------------