diff --git a/include/tvm/relay/attrs/vision.h b/include/tvm/relay/attrs/vision.h index 11b4ebfcfaad..7fa1ffb8a4fe 100644 --- a/include/tvm/relay/attrs/vision.h +++ b/include/tvm/relay/attrs/vision.h @@ -79,10 +79,16 @@ struct MultiBoxTransformLocAttrs /*! \brief Attributes used in get_valid_counts operator */ struct GetValidCountsAttrs : public tvm::AttrsNode { double score_threshold; + int id_index; + int score_index; TVM_DECLARE_ATTRS(GetValidCountsAttrs, "relay.attrs.GetValidCountsAttrs") { TVM_ATTR_FIELD(score_threshold).set_default(0.0) .describe("Lower limit of score for valid bounding boxes."); + TVM_ATTR_FIELD(id_index).set_default(0) + .describe("Axis index of id."); + TVM_ATTR_FIELD(score_index).set_default(1) + .describe("Index of the scores/confidence of boxes."); } }; diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 0975a33450c8..81ef51b91336 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -569,7 +569,8 @@ def _mx_box_nms(inputs, attrs): raise tvm.error.OpAttributeInvalid( 'Value of attribute "out_format" must equal "corner" for operator box_nms.') - ret = _op.vision.get_valid_counts(inputs[0], score_threshold=valid_thresh) + ret = _op.vision.get_valid_counts(inputs[0], score_threshold=valid_thresh, + id_index=id_index, score_index=score_index) nms_out = _op.vision.non_max_suppression(ret[1], ret[0], iou_threshold=iou_thresh, diff --git a/python/tvm/relay/op/vision/_vision.py b/python/tvm/relay/op/vision/_vision.py index 8c8c4cd9aaa3..7de118071aa4 100644 --- a/python/tvm/relay/op/vision/_vision.py +++ b/python/tvm/relay/op/vision/_vision.py @@ -82,7 +82,10 @@ def schedule_get_valid_counts(_, outs, target): def compute_get_valid_counts(attrs, inputs, _, target): """Compute definition of get_valid_counts""" score_threshold = get_const_float(attrs.score_threshold) - return topi.vision.get_valid_counts(inputs[0], score_threshold) + id_index = get_const_int(attrs.id_index) + score_index = get_const_int(attrs.score_index) + return topi.vision.get_valid_counts(inputs[0], score_threshold, + id_index, score_index) reg.register_pattern("vision.get_valid_counts", OpPattern.OPAQUE) diff --git a/python/tvm/relay/op/vision/nms.py b/python/tvm/relay/op/vision/nms.py index ab34eb6e6cfb..d19dde306aca 100644 --- a/python/tvm/relay/op/vision/nms.py +++ b/python/tvm/relay/op/vision/nms.py @@ -20,7 +20,9 @@ from ...expr import TupleWrapper def get_valid_counts(data, - score_threshold): + score_threshold, + id_index=0, + score_index=1): """Get valid count of bounding boxes given a score threshold. Also moves valid boxes to the top of input data. @@ -32,6 +34,12 @@ def get_valid_counts(data, score_threshold : optional, float Lower limit of score for valid bounding boxes. + id_index : optional, int + index of the class categories, -1 to disable. + + score_index: optional, int + Index of the scores/confidence of boxes. + Returns ------- valid_count : relay.Expr @@ -40,7 +48,8 @@ def get_valid_counts(data, out_tensor : relay.Expr Rearranged data tensor. """ - return TupleWrapper(_make.get_valid_counts(data, score_threshold), 2) + return TupleWrapper(_make.get_valid_counts(data, score_threshold, + id_index, score_index), 2) def non_max_suppression(data, diff --git a/src/relay/op/vision/nms.cc b/src/relay/op/vision/nms.cc index 2e5661cdc4dc..c0160e7d7128 100644 --- a/src/relay/op/vision/nms.cc +++ b/src/relay/op/vision/nms.cc @@ -50,9 +50,13 @@ bool GetValidCountRel(const Array& types, } Expr MakeGetValidCounts(Expr data, - double score_threshold) { + double score_threshold, + int id_index, + int score_index) { auto attrs = make_node(); attrs->score_threshold = score_threshold; + attrs->id_index = id_index; + attrs->score_index = score_index; static const Op& op = Op::Get("vision.get_valid_counts"); return CallNode::make(op, {data}, Attrs(attrs), {}); } diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index 21b227f6b3b5..3d9ec6dde4ad 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -152,28 +152,28 @@ def verify_multibox_prior(x, dshape, ref_res, sizes=(1.0,), def test_get_valid_counts(): - def verify_get_valid_counts(dshape, score_threshold): + def verify_get_valid_counts(dshape, score_threshold, id_index, score_index): dtype = "float32" batch_size, num_anchor, elem_length = dshape - np_data = np.random.uniform(size=dshape).astype(dtype) + np_data = np.random.uniform(low=-2, high=2, size=dshape).astype(dtype) np_out1 = np.zeros(shape=(batch_size,)) np_out2 = np.zeros(shape=dshape).astype(dtype) for i in range(batch_size): np_out1[i] = 0 inter_idx = 0 for j in range(num_anchor): - score = np_data[i, j, 1] - if score >= score_threshold: + score = np_data[i, j, score_index] + if score > score_threshold and (id_index < 0 or np_data[i, j, id_index] >= 0): for k in range(elem_length): np_out2[i, inter_idx, k] = np_data[i, j, k] np_out1[i] += 1 inter_idx += 1 if j >= np_out1[i]: for k in range(elem_length): - np_out2[i, j, k] = -1 + np_out2[i, j, k] = -1.0 x = relay.var("x", relay.ty.TensorType(dshape, dtype)) - z = relay.vision.get_valid_counts(x, score_threshold) + z = relay.vision.get_valid_counts(x, score_threshold, id_index, score_index) assert "score_threshold" in z.astext() func = relay.Function([x], z.astuple()) func = relay.ir_pass.infer_type(func) @@ -185,10 +185,10 @@ def verify_get_valid_counts(dshape, score_threshold): tvm.testing.assert_allclose(out[0].asnumpy(), np_out1, rtol=1e-3, atol=1e-04) tvm.testing.assert_allclose(out[1].asnumpy(), np_out2, rtol=1e-3, atol=1e-04) - verify_get_valid_counts((1, 2500, 6), 0) - verify_get_valid_counts((1, 2500, 6), -1) - verify_get_valid_counts((3, 1000, 6), 0.55) - verify_get_valid_counts((16, 500, 6), 0.95) + verify_get_valid_counts((1, 2500, 6), 0, 0, 1) + verify_get_valid_counts((1, 2500, 5), -1, -1, 0) + verify_get_valid_counts((3, 1000, 6), 0.55, 1, 0) + verify_get_valid_counts((16, 500, 5), 0.95, -1, 0) def test_non_max_suppression(): diff --git a/topi/python/topi/cuda/nms.py b/topi/python/topi/cuda/nms.py index 460584bc8b78..c0da4a45ec8d 100644 --- a/topi/python/topi/cuda/nms.py +++ b/topi/python/topi/cuda/nms.py @@ -313,7 +313,7 @@ def get_valid_counts_ir(data, flag, idx, valid_count, out): @get_valid_counts.register(["cuda", "gpu"]) -def get_valid_counts_gpu(data, score_threshold=0): +def get_valid_counts_gpu(data, score_threshold=0, id_index=0, score_index=1): """Get valid count of bounding boxes given a score threshold. Also moves valid boxes to the top of input data. @@ -325,6 +325,12 @@ def get_valid_counts_gpu(data, score_threshold=0): score_threshold : optional, float Lower limit of score for valid bounding boxes. + id_index : optional, int + index of the class categories, -1 to disable. + + score_index: optional, int + Index of the scores/confidence of boxes. + Returns ------- valid_count : tvm.Tensor diff --git a/topi/python/topi/vision/nms.py b/topi/python/topi/vision/nms.py index 7c8d7db33059..a6ba56eeb943 100644 --- a/topi/python/topi/vision/nms.py +++ b/topi/python/topi/vision/nms.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, undefined-variable, too-many-nested-blocks, too-many-branches, too-many-statements +# pylint: disable=import-error, invalid-name, no-member, too-many-locals, too-many-arguments, undefined-variable, too-many-nested-blocks, too-many-branches, too-many-statements, too-many-function-args """Non-maximum suppression operator""" import tvm @@ -60,7 +60,7 @@ def hybrid_rearrange_out(data): @hybrid.script -def hybrid_get_valid_counts(data, score_threshold): +def hybrid_get_valid_counts(data, score_threshold, id_index, score_index): """Hybrid routine to get valid count of bounding boxes given a score threshold. Also moves valid boxes to the top of input data. @@ -68,11 +68,18 @@ def hybrid_get_valid_counts(data, score_threshold): Parameters ---------- data : tvm.Tensor or numpy NDArray - Input data. 3-D tensor with shape [batch_size, num_anchors, 6]. + Input data. 3-D tensor with shape [batch_size, num_anchors, 6] + or [batch_size, num_anchors, 5]. score_threshold : tvm.const Lower limit of score for valid bounding boxes. + id_index : tvm.const + index of the class categories, -1 to disable. + + score_index: tvm.const + Index of the scores/confidence of boxes. + Returns ------- out_tensor : tvm.Tensor or numpy NDArray @@ -92,8 +99,9 @@ def hybrid_get_valid_counts(data, score_threshold): for i in parallel(batch_size): valid_count[i] = 0 for j in range(num_anchors): - score = data[i, j, 1] - if score > score_threshold: + score = data[i, j, score_index] + if score > score_threshold and \ + (id_index < 0 or data[i, j, id_index] >= 0): for k in range(box_data_length): out_tensor[i, valid_count[i], k] = data[i, j, k] valid_count[i] += 1 @@ -103,18 +111,25 @@ def hybrid_get_valid_counts(data, score_threshold): return valid_count, out_tensor @tvm.target.generic_func -def get_valid_counts(data, score_threshold=0): +def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1): """Get valid count of bounding boxes given a score threshold. Also moves valid boxes to the top of input data. Parameters ---------- data : tvm.Tensor - Input data. 3-D tensor with shape [batch_size, num_anchors, 6]. + Input data. 3-D tensor with shape [batch_size, num_anchors, 6] + or [batch_size, num_anchors, 5]. score_threshold : optional, float Lower limit of score for valid bounding boxes. + id_index : optional, int + index of the class categories, -1 to disable. + + score_index: optional, int + Index of the scores/confidence of boxes. + Returns ------- out_tensor : tvm.Tensor @@ -123,14 +138,17 @@ def get_valid_counts(data, score_threshold=0): valid_count : tvm.Tensor 1-D tensor for valid number of boxes. """ - score_threshold_const = tvm.const(score_threshold, "float") - return hybrid_get_valid_counts(data, score_threshold_const) + score_threshold_const = tvm.const(score_threshold, "float32") + id_index_const = tvm.const(id_index, "int32") + score_index_const = tvm.const(score_index, "int32") + return hybrid_get_valid_counts(data, score_threshold_const, + id_index_const, score_index_const) @hybrid.script def hybrid_nms(data, sorted_index, valid_count, max_output_size, iou_threshold, force_suppress, - top_k, coord_start, id_index): + top_k, coord_start, id_index, score_index): """Hybrid routing for non-maximum suppression. Parameters @@ -165,6 +183,9 @@ def hybrid_nms(data, sorted_index, valid_count, id_index : tvm.const index of the class categories, -1 to disable. + score_index: tvm.const + Index of the scores/confidence of boxes. + Returns ------- output : tvm.Tensor @@ -182,41 +203,42 @@ def hybrid_nms(data, sorted_index, valid_count, box_data_length,), data.dtype) - for i in parallel(batch_size): + for i in range(batch_size): if iou_threshold > 0: if valid_count[i] > 0: # Reorder output nkeep = valid_count[i] if 0 < top_k < nkeep: nkeep = top_k - for j in range(nkeep): + for j in parallel(nkeep): for k in range(box_data_length): output[i, j, k] = data[i, sorted_index[i, j], k] box_indices[i, j] = sorted_index[i, j] if 0 < top_k < valid_count[i]: - for j in range(valid_count[i] - nkeep): + for j in parallel(valid_count[i] - nkeep): for k in range(box_data_length): output[i, j + nkeep, k] = -1.0 box_indices[i, j + nkeep] = -1 # Apply nms + box_start_idx = coord_start + batch_idx = i for j in range(valid_count[i]): - if output[i, j, 0] >= 0: - for k in range(valid_count[i]): + if output[i, j, score_index] > 0 and (id_index < 0 or output[i, j, id_index] >= 0): + box_a_idx = j + for k in parallel(valid_count[i]): check_iou = 0 - if k > j and output[i, k, 0] >= 0: + if k > j and output[i, k, score_index] > 0 \ + and (id_index < 0 or output[i, k, id_index] >= 0): if force_suppress: check_iou = 1 - elif id_index < 0 or output[i, j, 0] == output[i, k, 0]: + elif id_index < 0 or output[i, j, id_index] == output[i, k, id_index]: check_iou = 1 if check_iou > 0: - batch_idx = i - box_a_idx = j - box_b_idx = k - box_start_idx = coord_start - a_t = output[batch_idx, box_a_idx, box_start_idx + 1] - a_b = output[batch_idx, box_a_idx, box_start_idx + 3] a_l = output[batch_idx, box_a_idx, box_start_idx] + a_t = output[batch_idx, box_a_idx, box_start_idx + 1] a_r = output[batch_idx, box_a_idx, box_start_idx + 2] + a_b = output[batch_idx, box_a_idx, box_start_idx + 3] + box_b_idx = k b_t = output[batch_idx, box_b_idx, box_start_idx + 1] b_b = output[batch_idx, box_b_idx, box_start_idx + 3] b_l = output[batch_idx, box_b_idx, box_start_idx] @@ -227,22 +249,24 @@ def hybrid_nms(data, sorted_index, valid_count, u = (a_r - a_l) * (a_b - a_t) + (b_r - b_l) * (b_b - b_t) - area iou = 0.0 if u <= 0.0 else area / u if iou >= iou_threshold: - output[i, k, 0] = -1.0 + output[i, k, score_index] = -1.0 + if id_index >= 0: + output[i, k, id_index] = -1.0 box_indices[i, k] = -1 else: - for j in range(valid_count[i]): + for j in parallel(valid_count[i]): for k in range(box_data_length): output[i, j, k] = data[i, j, k] box_indices[i, j] = j # Set invalid entry to be -1 - for j in range(num_anchors - valid_count[i]): + for j in parallel(num_anchors - valid_count[i]): for k in range(box_data_length): output[i, j + valid_count[i], k] = -1.0 box_indices[i, j + valid_count[i]] = -1 # Only return max_output_size valid boxes num_valid_boxes = 0 if max_output_size > 0: - for j in range(valid_count[i]): + for j in parallel(valid_count[i]): if output[i, j, 0] >= 0: if num_valid_boxes == max_output_size: for k in range(box_data_length): @@ -263,9 +287,7 @@ def non_max_suppression(data, valid_count, max_output_size=-1, Parameters ---------- data : tvm.Tensor - 3-D tensor with shape [batch_size, num_anchors, 6]. - The last dimension should be in format of - [class_id, score, box_left, box_top, box_right, box_bottom]. + 3-D tensor with shape [batch_size, num_anchors, 6] or [batch_size, num_anchors, 5]. valid_count : tvm.Tensor 1-D tensor for valid number of boxes. @@ -338,7 +360,8 @@ def non_max_suppression(data, valid_count, max_output_size=-1, tvm.const(force_suppress, dtype="bool"), tvm.const(top_k, dtype="int32"), tvm.const(coord_start, dtype="int32"), - tvm.const(id_index, dtype="int32")) + tvm.const(id_index, dtype="int32"), + tvm.const(score_index, dtype="int32")) if not return_indices and invalid_to_bottom: out = hybrid_rearrange_out(out) diff --git a/topi/tests/python/test_topi_vision.py b/topi/tests/python/test_topi_vision.py index 54c80c6e8c30..3a0b13489037 100644 --- a/topi/tests/python/test_topi_vision.py +++ b/topi/tests/python/test_topi_vision.py @@ -27,18 +27,18 @@ from topi.vision import ssd, non_max_suppression, get_valid_counts -def verify_get_valid_counts(dshape, score_threshold): +def verify_get_valid_counts(dshape, score_threshold, id_index, score_index): dtype = "float32" batch_size, num_anchor, elem_length = dshape - np_data = np.random.uniform(size=dshape).astype(dtype) + np_data = np.random.uniform(low=-2, high=2, size=dshape).astype(dtype) np_out1 = np.zeros(shape=(batch_size,)) np_out2 = np.zeros(shape=dshape).astype(dtype) for i in range(batch_size): np_out1[i] = 0 inter_idx = 0 for j in range(num_anchor): - score = np_data[i, j, 1] - if score > score_threshold: + score = np_data[i, j, score_index] + if score > score_threshold and (id_index < 0 or np_data[i, j, id_index] >= 0): for k in range(elem_length): np_out2[i, inter_idx, k] = np_data[i, j, k] np_out1[i] += 1 @@ -55,8 +55,8 @@ def check_device(device): print("Running on target: %s" % device) with tvm.target.create(device): data = tvm.placeholder(dshape, name="data", dtype=dtype) - outs = get_valid_counts(data, score_threshold) - s = topi.generic.schedule_multibox_prior(outs) + outs = get_valid_counts(data, score_threshold, id_index, score_index) + s = topi.generic.schedule_get_valid_counts(outs) tvm_input_data = tvm.nd.array(np_data, ctx) tvm_out1 = tvm.nd.array(np.zeros(np_out1.shape, dtype="int32"), ctx) @@ -67,33 +67,26 @@ def check_device(device): tvm.testing.assert_allclose(tvm_out2.asnumpy(), np_out2, rtol=1e-3) for device in ['llvm', 'cuda', 'opencl']: + # Disable gpu test for now + if device != "llvm": + continue check_device(device) def test_get_valid_counts(): - verify_get_valid_counts((1, 2500, 6), 0) - verify_get_valid_counts((1, 2500, 6), -1) - verify_get_valid_counts((3, 1000, 6), 0.55) - verify_get_valid_counts((16, 500, 6), 0.95) + verify_get_valid_counts((1, 2500, 6), 0, 0, 1) + verify_get_valid_counts((1, 2500, 5), -1, -1, 0) + verify_get_valid_counts((3, 1000, 6), 0.55, 1, 0) + verify_get_valid_counts((16, 500, 5), 0.95, -1, 1) -def test_non_max_suppression(): - dshape = (1, 5, 6) - indices_dshape = (1, 5) +def verify_non_max_suppression(np_data, np_valid_count, np_result, np_indices_result, iou_threshold, + force_suppress, top_k, coord_start, score_index, id_index): + dshape = np_data.shape + batch, num_anchors, _ = dshape + indices_dshape = (batch, num_anchors) data = tvm.placeholder(dshape, name="data") - valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count") - nms_threshold = 0.7 - force_suppress = True - nms_topk = 2 - - np_data = np.array([[[0, 0.8, 1, 20, 25, 45], [1, 0.7, 30, 60, 50, 80], - [0, 0.4, 4, 21, 19, 40], [2, 0.9, 35, 61, 52, 79], - [1, 0.5, 100, 60, 70, 110]]]).astype(data.dtype) - np_valid_count = np.array([4]).astype(valid_count.dtype) - np_result = np.array([[[2, 0.9, 35, 61, 52, 79], [0, 0.8, 1, 20, 25, 45], - [-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1], - [-1, -1, -1, -1, -1, -1]]]) - np_indices_result = np.array([[3, 0, -1, -1, -1]]) + valid_count = tvm.placeholder((batch,), dtype="int32", name="valid_count") def check_device(device): ctx = tvm.context(device, 0) @@ -103,11 +96,17 @@ def check_device(device): print("Running on target: %s" % device) with tvm.target.create(device): if device == 'llvm': - out = non_max_suppression(data, valid_count, -1, nms_threshold, force_suppress, nms_topk, return_indices=False) - indices_out = non_max_suppression(data, valid_count, -1, nms_threshold, force_suppress, nms_topk) + out = non_max_suppression(data, valid_count, -1, iou_threshold, force_suppress, top_k, + coord_start=coord_start, score_index=score_index, id_index=id_index, + return_indices=False) + indices_out = non_max_suppression(data, valid_count, -1, iou_threshold, force_suppress, top_k, + coord_start=coord_start, score_index=score_index, id_index=id_index) else: - out = topi.cuda.non_max_suppression(data, valid_count, -1, nms_threshold, force_suppress, nms_topk, return_indices=False) - indices_out = topi.cuda.non_max_suppression(data, valid_count, -1, nms_threshold, force_suppress, nms_topk) + out = topi.cuda.non_max_suppression(data, valid_count, -1, iou_threshold, force_suppress, top_k, + coord_start=coord_start, score_index=score_index, id_index=id_index, + return_indices=False) + indices_out = topi.cuda.non_max_suppression(data, valid_count, -1, iou_threshold, force_suppress, top_k, + coord_start=coord_start, score_index=score_index, id_index=id_index) s = topi.generic.schedule_nms(out) indices_s = topi.generic.schedule_nms(indices_out) @@ -128,6 +127,30 @@ def check_device(device): check_device(device) +def test_non_max_suppression(): + np_data = np.array([[[0, 0.8, 1, 20, 25, 45], [1, 0.7, 30, 60, 50, 80], + [0, 0.4, 4, 21, 19, 40], [2, 0.9, 35, 61, 52, 79], + [1, 0.5, 100, 60, 70, 110]]]).astype("float32") + np_valid_count = np.array([4]).astype("int32") + np_result = np.array([[[2, 0.9, 35, 61, 52, 79], [0, 0.8, 1, 20, 25, 45], + [-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1]]]) + np_indices_result = np.array([[3, 0, -1, -1, -1]]) + + verify_non_max_suppression(np_data, np_valid_count, np_result, np_indices_result, 0.7, True, 2, 2, 1, 0) + + np_data = np.array([[[0.8, 1, 20, 25, 45], [0.7, 30, 60, 50, 80], + [0.4, 4, 21, 19, 40], [0.9, 35, 61, 52, 79], + [0.5, 100, 60, 70, 110]]]).astype("float32") + np_valid_count = np.array([4]).astype("int32") + np_result = np.array([[[0.9, 35, 61, 52, 79], [0.8, 1, 20, 25, 45], + [-1, -1, -1, -1, -1], [-1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1]]]) + np_indices_result = np.array([[3, 0, -1, -1, -1]]) + verify_non_max_suppression(np_data, np_valid_count, np_result, np_indices_result, 0.7, False, 2, 1, 0, -1) + + + def verify_multibox_prior(dshape, sizes=(1,), ratios=(1,), steps=(-1, -1), offsets=(0.5, 0.5), clip=False): data = tvm.placeholder(dshape, name="data")