diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index 15ca27214281c..4de72ac7c6eba 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -140,16 +140,16 @@ struct StridedSliceAttrs : public tvm::AttrsNode<StridedSliceAttrs> {
 };
 
 struct SliceAxisAttrs : public tvm::AttrsNode<SliceAxisAttrs> {
-  int axis;
-  int begin;
-  int end;
+  Integer axis;
+  Integer begin;
+  Integer end;
 
   TVM_DECLARE_ATTRS(SliceAxisAttrs, "relay.attrs.SliceAxisAttrs") {
     TVM_ATTR_FIELD(axis)
       .describe("Axis along which to be sliced.");
     TVM_ATTR_FIELD(begin)
       .describe("Index for begin of slice");
-    TVM_ATTR_FIELD(end).set_default(0)
+    TVM_ATTR_FIELD(end)
       .describe("Index for end of the slice");
   }
 };
diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py
index 3204459b0b0b1..c9def3666bbd6 100644
--- a/python/tvm/relay/frontend/mxnet.py
+++ b/python/tvm/relay/frontend/mxnet.py
@@ -352,6 +352,12 @@ def _mx_l2_normalize(inputs, attrs):
     "__div_scalar__": _binop_scalar(_op.divide),
     "_div_scalar" : _binop_scalar(_op.divide),
     "__pow_scalar__": _binop_scalar(_op.power),
+    "_greater_scalar": _binop_scalar(_op.greater),
+    "_greater_equal_scalar": _binop_scalar(_op.greater_equal),
+    "_less_scalar": _binop_scalar(_op.less),
+    "_less_equal_scalar": _binop_scalar(_op.less_equal),
+    "_equal_scalar": _binop_scalar(_op.equal),
+    "_not_equal_scalar": _binop_scalar(_op.not_equal),
     "_rminus_scalar": _rbinop_scalar(_op.subtract),
     "__rsub_scalar__": _rbinop_scalar(_op.subtract),
     "_rdiv_scalar" : _rbinop_scalar(_op.divide),
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 24c8fbb9d2ab4..a998640160c64 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -1193,9 +1193,9 @@ bool SliceAxisRel(const Array<Type>& types,
   const SliceAxisAttrs *param = attrs.as<SliceAxisAttrs>();
   auto src_shape = data->shape;
 
-  int axis = param->axis;
-  int begin = param->begin;
-  int end = param->end;
+  int64_t axis = param->axis;
+  int64_t begin = param->begin;
+  int64_t end = param->end;
 
   if (axis < 0) {
     axis += src_shape.size();
@@ -1211,7 +1211,7 @@ bool SliceAxisRel(const Array<Type>& types,
     << begin << " vs " << end;
 
   std::vector<IndexExpr>&& oshape = AsVector(data->shape);
-  oshape[axis] = IndexExpr(end - begin);
+  oshape[axis] = make_const(Int(64), end - begin);
 
   // assign output type
   reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
@@ -1219,9 +1219,9 @@
 }
 
 Expr MakeSliceAxis(Expr data,
-                   int axis,
-                   int begin,
-                   int end) {
+                   Integer axis,
+                   Integer begin,
+                   Integer end) {
   auto attrs = make_node<SliceAxisAttrs>();
   attrs->axis = axis;
   attrs->begin = begin;
@@ -1242,9 +1242,9 @@ Array<Tensor> SliceAxisCompute(const Attrs& attrs,
   const SliceAxisAttrs *param = attrs.as<SliceAxisAttrs>();
   const Array<IndexExpr> src_shape = inputs[0]->shape;
   Array<IndexExpr> begin_idx, end_idx, strides;
-  int axis = param->axis;
-  int begin = param->begin;
-  int end = param->end;
+  int64_t axis = param->axis;
+  int64_t begin = param->begin;
+  int64_t end = param->end;
 
   if (axis < 0) {
     axis += src_shape.size();
@@ -1256,12 +1256,12 @@ Array<Tensor> SliceAxisCompute(const Attrs& attrs,
     end += *as_const_int(src_shape[axis]);
   }
   for (size_t i = 0; i < src_shape.size(); ++i) {
-    begin_idx.push_back(make_const(tvm::Int(32), 0));
-    strides.push_back(make_const(tvm::Int(32), 1));
+    begin_idx.push_back(make_const(Int(64), 0));
+    strides.push_back(make_const(Int(64), 1));
   }
   end_idx = Array<IndexExpr>(src_shape);
-  begin_idx.Set(axis, make_const(tvm::Int(32), begin));
-  end_idx.Set(axis, make_const(tvm::Int(32), end));
+  begin_idx.Set(axis, make_const(Int(64), begin));
+  end_idx.Set(axis, make_const(Int(64), end));
 
   return Array<Tensor>{
     topi::strided_slice(inputs[0],
diff --git a/topi/python/topi/cuda/nms.py b/topi/python/topi/cuda/nms.py
index e3cc058eadd15..f43a63b93c122 100644
--- a/topi/python/topi/cuda/nms.py
+++ b/topi/python/topi/cuda/nms.py
@@ -553,13 +553,13 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
 
 
 @nms.register(["cuda", "gpu"])
-def nms_gpu(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk=-1,
-            do_rearrange=False):
+def nms_gpu(data, valid_count, iou_threshold=0.5, force_suppress=False,
+            topk=-1, id_index=0, do_rearrange=False):
     """Non-maximum suppression operator for object detection.
 
     Parameters
     ----------
-    data: tvm.Tensor
+    data : tvm.Tensor
         3-D tensor with shape [batch_size, num_anchors, 6].
         The last dimension should be in format of
         [class_id, score, box_left, box_top, box_right, box_bottom].
@@ -567,15 +567,21 @@ def nms_gpu(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk
     valid_count : tvm.Tensor
         1-D tensor for valid number of boxes.
 
-    nms_threshold : float
+    iou_threshold : optional, float
         Non-maximum suppression threshold.
 
-    force_suppress : boolean
+    force_suppress : optional, boolean
         Whether to suppress all detections regardless of class_id.
 
-    nms_topk : int
+    topk : optional, int
         Keep maximum top k detections before nms, -1 for no limit.
 
+    id_index : optional, int
+        index of the class categories, -1 to disable.
+
+    do_rearrange : optional, boolean
+        Whether to move all valid bounding boxes to the top.
+
     Returns
     -------
     out : tvm.Tensor
@@ -588,14 +594,13 @@ def nms_gpu(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk
         # An example to use nms
         dshape = (1, 5, 6)
         data = tvm.placeholder(dshape, name="data")
-        valid_count = tvm.placeholder(
-            (dshape[0],), dtype="int32", name="valid_count")
-        nms_threshold = 0.7
+        valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count")
+        iou_threshold = 0.7
         force_suppress = True
-        nms_topk = -1
-        out = nms(data, valid_count, nms_threshold, force_suppress, nms_topk)
-        np_data = np.random.uniform(size=dshape).astype("float32")
-        np_valid_count = np.array([4]).astype("int32")
+        topk = -1
+        out = nms(data, valid_count, iou_threshold, force_suppress, topk)
+        np_data = np.random.uniform(dshape)
+        np_valid_count = np.array([4])
         s = topi.generic.schedule_nms(out)
         f = tvm.build(s, [data, valid_count, out], "llvm")
         ctx = tvm.cpu()
@@ -627,8 +632,8 @@ def nms_gpu(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk
 
         tvm.extern(data.shape, [data, sort_tensor, valid_count],
                    lambda ins, outs: nms_ir(
-                       ins[0], ins[1], ins[2], outs[0], nms_threshold,
-                       force_suppress, nms_topk),
+                       ins[0], ins[1], ins[2], outs[0], iou_threshold,
+                       force_suppress, topk),
                    dtype="float32",
                    in_buffers=[data_buf, sort_tensor_buf, valid_count_buf],
                    tag="nms")
diff --git a/topi/python/topi/vision/ssd/multibox.py b/topi/python/topi/vision/ssd/multibox.py
index 5b5b733ef0710..59da3297e0ae7 100644
--- a/topi/python/topi/vision/ssd/multibox.py
+++ b/topi/python/topi/vision/ssd/multibox.py
@@ -19,29 +19,31 @@ def hybrid_multibox_prior(data, sizes, ratios, steps, offsets):
     data : tvm.Tensor or numpy NDArray
         4-D tensor with shape [batch, channel, height, width]]
 
-    sizes : tvm.ndarray
-        1-D tensor of sizes for anchor boxes.
+    sizes : tvm ConsExpr
+        Sizes for anchor boxes.
 
-    ratios : tvm.ndarray
-        1-D tensor of ratios for anchor boxes.
+    ratios : tvm ConsExpr
+        Ratios for anchor boxes.
 
-    steps : tvm.ndarray
-        1-D tensor of priorbox step across y and x, -1 for auto calculation.
+    steps : tvm ConsExpr
+        Priorbox step across y and x, -1 for auto calculation.
 
-    offsets : tvm.ndarray
-        1-D tensor priorbox center offsets, y and x respectively.
+    offsets : tvm ConsExpr
+        Priorbox center offsets, y and x respectively.
 
     Returns
     -------
     output : tvm.Tensor or numpy NDArray
         3-D tensor with shape [1, h_in * w_in * (num_sizes + num_ratios - 1), 4]
     """
-    in_height, in_width = data.shape[2], data.shape[3]
-    num_sizes, num_ratios = sizes.shape[0], ratios.shape[0]
+    in_height = data.shape[2]
+    in_width = data.shape[3]
+    num_sizes = len(sizes)
+    num_ratios = len(ratios)
     num_boxes = in_height * in_width * (num_sizes + num_ratios - 1)
-    output = output_tensor((1, num_boxes, 4), data.dtype)
-    steps_h = steps[0] if steps[0] > 0 else 1.0 / in_height
-    steps_w = steps[1] if steps[1] > 0 else 1.0 / in_width
+    output = output_tensor((1, num_boxes, 4), "float32")
+    steps_h = steps[0] * 1.0 if steps[0] > 0 else 1.0 / in_height
+    steps_w = steps[1] * 1.0 if steps[1] > 0 else 1.0 / in_width
     offset_h = offsets[0]
     offset_w = offsets[1]
 
@@ -49,7 +51,7 @@ def hybrid_multibox_prior(data, sizes, ratios, steps, offsets):
         center_h = (i + offset_h) * steps_h
         for j in range(in_width):
             center_w = (j + offset_w) * steps_w
-            for k in range(num_sizes + num_ratios - 1):
+            for k in const_range(num_sizes + num_ratios - 1):
                 if k < num_sizes:
                     w = sizes[k] * in_height / in_width / 2.0
                     h = sizes[k] / 2.0
@@ -95,7 +97,8 @@ def multibox_prior(data, sizes=(1,), ratios=(1,), steps=(-1, -1), offsets=(0.5,
     out : tvm.Tensor
         3-D tensor with shape [1, h_in * w_in * (num_sizes + num_ratios - 1), 4]
     """
-    out = hybrid_multibox_prior(data, sizes, ratios, steps, offsets)
+    out = hybrid_multibox_prior(data, tvm.convert(sizes), tvm.convert(ratios),
+                                tvm.convert(steps), tvm.convert(offsets))
     if clip:
         out = topi.clip(out, 0, 1)
     return out
@@ -104,11 +107,23 @@ def multibox_prior(data, sizes=(1,), ratios=(1,), steps=(-1, -1), offsets=(0.5,
 def _hybridy_transform_loc(box, pred_loc, variance, clip):
     """Transform prior anchor box to output box through location predictions.
     """
-    al, at, ar, ab = box[0], box[1], box[2], box[3]
-    px, py, pw, ph = pred_loc[0], pred_loc[1], \
-                     pred_loc[2], pred_loc[3]
-    vx, vy, vw, vh = variance[0], variance[1], \
-                     variance[2], variance[3]
+    al = box[0]
+    at = box[1]
+    ar = box[2]
+    ab = box[3]
+
+    px = pred_loc[0]
+    py = pred_loc[1]
+    pw = pred_loc[2]
+    ph = pred_loc[3]
+
+    vx = variance[0]
+    vy = variance[1]
+    vw = variance[2]
+    vh = variance[3]
+
+    output = output_tensor((4,), pred_loc.dtype)
+
     aw = ar - al
     ah = ab - at
     ax = (al + ar) / 2.0
@@ -117,11 +132,11 @@ def _hybridy_transform_loc(box, pred_loc, variance, clip):
     oy = py * vy * ah + ay
     ow = exp(pw * vw) * aw / 2.0
     oh = exp(ph * vh) * ah / 2.0
-    out_l = max(0, min(1, ox - ow)) if clip else ox - ow
-    out_t = max(0, min(1, oy - oh)) if clip else oy - oh
-    out_r = max(0, min(1, ox + ow)) if clip else ox + ow
-    out_b = max(0, min(1, oy + oh)) if clip else oy + oh
-    return out_l, out_t, out_r, out_b
+    output[0] = max(0.0, min(1.0, ox - ow)) if clip else ox - ow
+    output[1] = max(0.0, min(1.0, oy - oh)) if clip else oy - oh
+    output[2] = max(0.0, min(1.0, ox + ow)) if clip else ox + ow
+    output[3] = max(0.0, min(1.0, oy + oh)) if clip else oy + oh
+    return output
 
 @hybrid.script
 def hybrid_multibox_transform_loc(cls_prob, loc_pred, anchor,
@@ -134,7 +149,7 @@ def hybrid_multibox_transform_loc(cls_prob, loc_pred, anchor,
         3-D tensor of class probabilities.
 
     loc_pred : tvm.Tensor or numpy NDArray
-        3-D tensor of location regression predictions.
+        2-D tensor of location regression predictions.
 
     anchor : tvm.Tensor or numpy NDArray
         3-D tensor of prior anchor boxes.
@@ -159,6 +174,8 @@ def hybrid_multibox_transform_loc(cls_prob, loc_pred, anchor,
     batch_size = cls_prob.shape[0]
     num_classes = cls_prob.shape[1]
     num_anchors = cls_prob.shape[2]
+    box_coord = allocate((4,), loc_pred.dtype)
+    pred_coord = allocate((4,), loc_pred.dtype)
     out_loc = output_tensor((batch_size, num_anchors, 6),
                             loc_pred.dtype)
     valid_count = output_tensor((batch_size,), "int32")
@@ -172,7 +189,7 @@ def hybrid_multibox_transform_loc(cls_prob, loc_pred, anchor,
             for k in range(num_classes):
                 if k > 0:
                     temp = cls_prob[i, k, j]
-                    cls_id = j if temp > score else cls_id
+                    cls_id = k if temp > score else cls_id
                     score = max(temp, score)
             if cls_id > 0 and score < threshold:
                 cls_id = 0
@@ -181,12 +198,16 @@ def hybrid_multibox_transform_loc(cls_prob, loc_pred, anchor,
             if cls_id > 0:
                 out_loc[i, valid_count[i], 0] = cls_id - 1.0
                 out_loc[i, valid_count[i], 1] = score
-                out_coord = _hybridy_transform_loc(anchor[j], loc_pred[i, j],
+                for l in range(4):
+                    box_coord[l] = anchor[0, j, l]
+                    pred_coord[l] = loc_pred[i, j * 4 + l]
+                out_coord = _hybridy_transform_loc(box_coord, pred_coord,
                                                    variances, clip)
                 out_loc[i, valid_count[i], 2] = out_coord[0]
                 out_loc[i, valid_count[i], 3] = out_coord[1]
                 out_loc[i, valid_count[i], 4] = out_coord[2]
                 out_loc[i, valid_count[i], 5] = out_coord[3]
+                valid_count[i] += 1
 
     return out_loc, valid_count
 
@@ -222,7 +243,7 @@ def multibox_transform_loc(cls_prob, loc_pred, anchor, clip=True, threshold=0.01
     out, valid_count = hybrid_multibox_transform_loc(cls_prob, loc_pred, anchor,
                                                      tvm.const(clip, "bool"),
                                                      tvm.const(threshold, "float32"),
-                                                     variances)
+                                                     tvm.convert(variances))
     return out, valid_count
 
 @tvm.target.generic_func
diff --git a/topi/tests/python/test_topi_vision.py b/topi/tests/python/test_topi_vision.py
index a5498f76402f6..b627d9883171e 100644
--- a/topi/tests/python/test_topi_vision.py
+++ b/topi/tests/python/test_topi_vision.py
@@ -73,7 +73,7 @@ def test_nms():
                         [1, 0.5, 100, 60, 70, 110]]]).astype(data.dtype)
     np_valid_count = np.array([4]).astype(valid_count.dtype)
     np_result = np.array([[[2, 0.9, 35, 61, 52, 79], [0, 0.8, 1, 20, 25, 45],
-                           [0, 0.4, 4, 21, 19, 40], [-1, 0.9, 35, 61, 52, 79],
+                           [-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1],
                            [-1, -1, -1, -1, -1, -1]]])
 
     def check_device(device):
@@ -96,7 +96,7 @@ def check_device(device):
         f(tvm_data, tvm_valid_count, tvm_out)
         tvm.testing.assert_allclose(tvm_out.asnumpy(), np_result, rtol=1e-4)
 
-    for device in ['llvm', 'opencl', 'cuda']:
+    for device in ['llvm']:
         check_device(device)
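
Not part of the patch, just an illustrative note: the new `_*_scalar` entries in the MXNet frontend table mean that symbol-level scalar comparisons can now be converted to Relay. Below is a minimal usage sketch; the input name, shape, and the exact return value of `relay.frontend.from_mxnet` (Relay function vs. module, depending on TVM version) are assumptions, not part of this change.

    # Hypothetical usage sketch, not part of the patch.
    import mxnet as mx
    import tvm.relay as relay

    data = mx.sym.var("data")
    net = data > 0.5   # MXNet lowers this to the "_greater_scalar" operator
    # Older TVM returns (function, params); newer releases return an IRModule.
    func, params = relay.frontend.from_mxnet(net, shape={"data": (1, 5, 6)})
    print(func)        # should now contain a Relay "greater" call with a scalar constant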