From 6348f4dc3f17bb8234b1d2851f598ad8a2d37ad1 Mon Sep 17 00:00:00 2001 From: hasesh Date: Wed, 14 Aug 2019 17:01:48 -0700 Subject: [PATCH 1/3] More changes --- .../non_max_suppression_test.cc | 38 ++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/onnxruntime/test/providers/cpu/object_detection/non_max_suppression_test.cc b/onnxruntime/test/providers/cpu/object_detection/non_max_suppression_test.cc index 9675612b7e12e..0fa2b99d5504f 100644 --- a/onnxruntime/test/providers/cpu/object_detection/non_max_suppression_test.cc +++ b/onnxruntime/test/providers/cpu/object_detection/non_max_suppression_test.cc @@ -73,7 +73,7 @@ TEST(NonMaxSuppressionOpTest, TwoClasses) { test.Run(); } -TEST(NonMaxSuppressionOpTest, TwoBathes) { +TEST(NonMaxSuppressionOpTest, TwoBatches_SingleClass) { OpTester test("NonMaxSuppression", 10, kOnnxDomain); test.AddInput("boxes", {2, 6, 4}, {0.0f, 0.0f, 1.0f, 1.0f, @@ -94,6 +94,7 @@ TEST(NonMaxSuppressionOpTest, TwoBathes) { 0.9f, 0.75f, 0.6f, 0.95f, 0.5f, 0.3f}); test.AddInput("max_output_boxes_per_class", {}, {2L}); test.AddInput("iou_threshold", {}, {0.5f}); + test.AddInput("score_threshold", {}, {0.0f}); test.AddOutput("selected_indices", {4, 3}, {0L, 0L, 3L, @@ -103,6 +104,41 @@ TEST(NonMaxSuppressionOpTest, TwoBathes) { test.Run(); } +TEST(NonMaxSuppressionOpTest, TwoBatches_TwoClasses) { + OpTester test("NonMaxSuppression", 10, kOnnxDomain); + test.AddInput("boxes", {2, 5, 4}, + {0.0f, 0.0f, 0.3f, 0.3f, + 0.0f, 0.0f, 0.4f, 0.4f, + 0.0f, 0.0f, 0.5f, 0.5f, + 0.5f, 0.5f, 0.9f, 0.9f, + 0.5f, 0.5f, 1.0f, 1.0f, + + 0.0f, 0.0f, 0.3f, 0.3f, + 0.0f, 0.0f, 0.4f, 0.4f, + 0.0f, 0.0f, 0.5f, 0.5f, + 0.5f, 0.5f, 0.9f, 0.9f, + 0.5f, 0.5f, 1.0f, 1.0f}); + test.AddInput("scores", {2, 2, 5}, + {0.1f, 0.2f, 0.6f, 0.3f, 0.9f, + 0.1f, 0.2f, 0.6f, 0.3f, 0.9f, + + 0.1f, 0.2f, 0.6f, 0.3f, 0.9f, + 0.1f, 0.2f, 0.6f, 0.3f, 0.9f}); + test.AddInput("max_output_boxes_per_class", {}, {2L}); + test.AddInput("iou_threshold", {}, {0.5f}); + test.AddOutput("selected_indices", {8, 3}, + {0L, 0L, 4L, + 0L, 0L, 2L, + 0L, 1L, 4L, + 0L, 1L, 2L, + + 1L, 0L, 4L, + 1L, 0L, 2L, + 1L, 1L, 4L, + 1L, 1L, 2L}); + test.Run(); +} + TEST(NonMaxSuppressionOpTest, WithScoreThreshold) { OpTester test("NonMaxSuppression", 10, kOnnxDomain); test.AddInput("boxes", {1, 6, 4}, From b60c8318c4a7cf9d1a13b9087483925442ebc17c Mon Sep 17 00:00:00 2001 From: hariharans29 Date: Wed, 14 Aug 2019 19:16:31 -0700 Subject: [PATCH 2/3] Fix NMS --- .../cpu/object_detection/non_max_suppression.cc | 2 +- .../cpu/object_detection/non_max_suppression_test.cc | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/providers/cpu/object_detection/non_max_suppression.cc b/onnxruntime/core/providers/cpu/object_detection/non_max_suppression.cc index 66084547810ad..bdb57248c4e34 100644 --- a/onnxruntime/core/providers/cpu/object_detection/non_max_suppression.cc +++ b/onnxruntime/core/providers/cpu/object_detection/non_max_suppression.cc @@ -141,7 +141,7 @@ Status NonMaxSuppression::Compute(OpKernelContext* ctx) const { for (int64_t batch_index = 0; batch_index < pc.num_batches_; ++batch_index) { for (int64_t class_index = 0; class_index < pc.num_classes_; ++class_index) { int64_t box_score_offset = (batch_index * pc.num_classes_ + class_index) * pc.num_boxes_; - int64_t box_offset = batch_index * pc.num_classes_ * pc.num_boxes_ * 4; + int64_t box_offset = batch_index * pc.num_boxes_ * 4; // Filter by score_threshold_ std::priority_queue> sorted_scores_with_index; const auto* class_scores = scores_data + box_score_offset; diff --git a/onnxruntime/test/providers/cpu/object_detection/non_max_suppression_test.cc b/onnxruntime/test/providers/cpu/object_detection/non_max_suppression_test.cc index 0fa2b99d5504f..721d04a83d4e8 100644 --- a/onnxruntime/test/providers/cpu/object_detection/non_max_suppression_test.cc +++ b/onnxruntime/test/providers/cpu/object_detection/non_max_suppression_test.cc @@ -115,8 +115,8 @@ TEST(NonMaxSuppressionOpTest, TwoBatches_TwoClasses) { 0.0f, 0.0f, 0.3f, 0.3f, 0.0f, 0.0f, 0.4f, 0.4f, - 0.0f, 0.0f, 0.5f, 0.5f, - 0.5f, 0.5f, 0.9f, 0.9f, + 0.5f, 0.5f, 0.95f, 0.95f, + 0.5f, 0.5f, 0.96f, 0.96f, 0.5f, 0.5f, 1.0f, 1.0f}); test.AddInput("scores", {2, 2, 5}, {0.1f, 0.2f, 0.6f, 0.3f, 0.9f, @@ -125,7 +125,7 @@ TEST(NonMaxSuppressionOpTest, TwoBatches_TwoClasses) { 0.1f, 0.2f, 0.6f, 0.3f, 0.9f, 0.1f, 0.2f, 0.6f, 0.3f, 0.9f}); test.AddInput("max_output_boxes_per_class", {}, {2L}); - test.AddInput("iou_threshold", {}, {0.5f}); + test.AddInput("iou_threshold", {}, {0.8f}); test.AddOutput("selected_indices", {8, 3}, {0L, 0L, 4L, 0L, 0L, 2L, @@ -133,9 +133,9 @@ TEST(NonMaxSuppressionOpTest, TwoBatches_TwoClasses) { 0L, 1L, 2L, 1L, 0L, 4L, - 1L, 0L, 2L, + 1L, 0L, 1L, 1L, 1L, 4L, - 1L, 1L, 2L}); + 1L, 1L, 1L}); test.Run(); } From fcdf5e1e9598a8d2ef3c9ee3956a364676ed2429 Mon Sep 17 00:00:00 2001 From: hariharans29 Date: Wed, 14 Aug 2019 19:24:38 -0700 Subject: [PATCH 3/3] nits --- .../providers/cpu/object_detection/non_max_suppression_test.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/onnxruntime/test/providers/cpu/object_detection/non_max_suppression_test.cc b/onnxruntime/test/providers/cpu/object_detection/non_max_suppression_test.cc index 721d04a83d4e8..45f537bc89046 100644 --- a/onnxruntime/test/providers/cpu/object_detection/non_max_suppression_test.cc +++ b/onnxruntime/test/providers/cpu/object_detection/non_max_suppression_test.cc @@ -73,7 +73,7 @@ TEST(NonMaxSuppressionOpTest, TwoClasses) { test.Run(); } -TEST(NonMaxSuppressionOpTest, TwoBatches_SingleClass) { +TEST(NonMaxSuppressionOpTest, TwoBatches_OneClass) { OpTester test("NonMaxSuppression", 10, kOnnxDomain); test.AddInput("boxes", {2, 6, 4}, {0.0f, 0.0f, 1.0f, 1.0f, @@ -94,7 +94,6 @@ TEST(NonMaxSuppressionOpTest, TwoBatches_SingleClass) { 0.9f, 0.75f, 0.6f, 0.95f, 0.5f, 0.3f}); test.AddInput("max_output_boxes_per_class", {}, {2L}); test.AddInput("iou_threshold", {}, {0.5f}); - test.AddInput("score_threshold", {}, {0.0f}); test.AddOutput("selected_indices", {4, 3}, {0L, 0L, 3L,