Support multibox op with hybrid script
Wang committed Jan 11, 2019
1 parent 73e0b0e commit 7ef52f3
Showing 6 changed files with 96 additions and 64 deletions.
8 changes: 4 additions & 4 deletions include/tvm/relay/attrs/transform.h
@@ -140,16 +140,16 @@ struct StridedSliceAttrs : public tvm::AttrsNode<StridedSliceAttrs> {
};

struct SliceAxisAttrs : public tvm::AttrsNode<SliceAxisAttrs> {
int axis;
int begin;
int end;
Integer axis;
Integer begin;
Integer end;

TVM_DECLARE_ATTRS(SliceAxisAttrs, "relay.attrs.SliceAxisAttrs") {
TVM_ATTR_FIELD(axis)
.describe("Axis along which to be sliced.");
TVM_ATTR_FIELD(begin)
.describe("Index for begin of slice");
TVM_ATTR_FIELD(end).set_default(0)
TVM_ATTR_FIELD(end)
.describe("Index for end of the slice");
}
};
6 changes: 6 additions & 0 deletions python/tvm/relay/frontend/mxnet.py
@@ -352,6 +352,12 @@ def _mx_l2_normalize(inputs, attrs):
"__div_scalar__": _binop_scalar(_op.divide),
"_div_scalar" : _binop_scalar(_op.divide),
"__pow_scalar__": _binop_scalar(_op.power),
"_greater_scalar": _binop_scalar(_op.greater),
"_greater_equal_scalar": _binop_scalar(_op.greater_equal),
"_less_scalar": _binop_scalar(_op.less),
"_less_equal_scalar": _binop_scalar(_op.less_equal),
"_equal_scalar": _binop_scalar(_op.equal),
"_not_equal_scalar": _binop_scalar(_op.not_equal),
"_rminus_scalar": _rbinop_scalar(_op.subtract),
"__rsub_scalar__": _rbinop_scalar(_op.subtract),
"_rdiv_scalar" : _rbinop_scalar(_op.divide),
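The scalar comparison entries added above reuse the frontend's existing _binop_scalar factory, so each MXNet *_scalar op becomes an elementwise Relay comparison against a lifted constant. A minimal, self-contained sketch of the Relay expression such a conversion produces (the variable name and the 0.5 constant below are illustrative, not taken from the frontend):

from tvm import relay

# Hedged sketch: what "_greater_scalar" lowers to -- compare a tensor against
# a scalar lifted into a Relay constant, matching _binop_scalar(_op.greater).
x = relay.var("x", shape=(1, 4), dtype="float32")
scalar = relay.const(0.5, dtype="float32")   # the MXNet "scalar" attribute, lifted
y = relay.greater(x, scalar)                 # elementwise boolean result
print(relay.Function([x], y))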
28 changes: 14 additions & 14 deletions src/relay/op/tensor/transform.cc
@@ -1193,9 +1193,9 @@ bool SliceAxisRel(const Array<Type>& types,
const SliceAxisAttrs *param = attrs.as<SliceAxisAttrs>();

auto src_shape = data->shape;
int axis = param->axis;
int begin = param->begin;
int end = param->end;
int64_t axis = param->axis;
int64_t begin = param->begin;
int64_t end = param->end;

if (axis < 0) {
axis += src_shape.size();
@@ -1211,17 +1211,17 @@
<< begin << " vs " << end;

std::vector<IndexExpr>&& oshape = AsVector(data->shape);
oshape[axis] = IndexExpr(end - begin);
oshape[axis] = make_const(Int(64), end - begin);

// assign output type
reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
return true;
}

Expr MakeSliceAxis(Expr data,
int axis,
int begin,
int end) {
Integer axis,
Integer begin,
Integer end) {
auto attrs = make_node<SliceAxisAttrs>();
attrs->axis = axis;
attrs->begin = begin;
@@ -1242,9 +1242,9 @@ Array<Tensor> SliceAxisCompute(const Attrs& attrs,
const SliceAxisAttrs *param = attrs.as<SliceAxisAttrs>();
const Array<IndexExpr> src_shape = inputs[0]->shape;
Array<IndexExpr> begin_idx, end_idx, strides;
int axis = param->axis;
int begin = param->begin;
int end = param->end;
int64_t axis = param->axis;
int64_t begin = param->begin;
int64_t end = param->end;

if (axis < 0) {
axis += src_shape.size();
@@ -1256,12 +1256,12 @@
end += *as_const_int(src_shape[axis]);
}
for (size_t i = 0; i < src_shape.size(); ++i) {
begin_idx.push_back(make_const(tvm::Int(32), 0));
strides.push_back(make_const(tvm::Int(32), 1));
begin_idx.push_back(make_const(Int(64), 0));
strides.push_back(make_const(Int(64), 1));
}
end_idx = Array<IndexExpr>(src_shape);
begin_idx.Set(axis, make_const(tvm::Int(32), begin));
end_idx.Set(axis, make_const(tvm::Int(32), end));
begin_idx.Set(axis, make_const(Int(64), begin));
end_idx.Set(axis, make_const(Int(64), end));

return Array<Tensor>{
topi::strided_slice(inputs[0],
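At the TOPI level, SliceAxisCompute above reduces slice_axis to a single-axis strided slice after normalizing negative indices. A NumPy reference of those semantics, offered as an illustrative sketch rather than the TVM code path (the exact guard on a non-positive end sits outside the visible hunk and is assumed here):

import numpy as np

# Hedged NumPy sketch of slice_axis semantics: negative axis/begin wrap around,
# end <= 0 counts from the axis length, and the output keeps the input rank.
def slice_axis_ref(data, axis, begin, end):
    if axis < 0:
        axis += data.ndim
    dim = data.shape[axis]
    if begin < 0:
        begin += dim
    if end <= 0:                      # assumed: matches the hidden guard above
        end += dim
    assert begin < end, "begin should be smaller than end"
    index = [slice(None)] * data.ndim
    index[axis] = slice(begin, end)   # stride 1, like the strided_slice call above
    return data[tuple(index)]

# slice_axis_ref(np.arange(12).reshape(3, 4), axis=1, begin=1, end=-1).shape == (3, 2)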
35 changes: 20 additions & 15 deletions topi/python/topi/cuda/nms.py
@@ -553,29 +553,35 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):


@nms.register(["cuda", "gpu"])
def nms_gpu(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk=-1,
do_rearrange=False):
def nms_gpu(data, valid_count, iou_threshold=0.5, force_suppress=False,
topk=-1, id_index=0, do_rearrange=False):
"""Non-maximum suppression operator for object detection.
Parameters
----------
data: tvm.Tensor
data : tvm.Tensor
3-D tensor with shape [batch_size, num_anchors, 6].
The last dimension should be in format of
[class_id, score, box_left, box_top, box_right, box_bottom].
valid_count : tvm.Tensor
1-D tensor for valid number of boxes.
nms_threshold : float
iou_threshold : optional, float
Non-maximum suppression threshold.
force_suppress : boolean
force_suppress : optional, boolean
Whether to suppress all detections regardless of class_id.
nms_topk : int
topk : optional, int
Keep maximum top k detections before nms, -1 for no limit.
id_index : optional, int
Index of the class categories, -1 to disable.
do_rearrange : optional, boolean
Whether to move all valid bounding boxes to the top.
Returns
-------
out : tvm.Tensor
@@ -588,14 +594,13 @@ def nms_gpu(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk
# An example to use nms
dshape = (1, 5, 6)
data = tvm.placeholder(dshape, name="data")
valid_count = tvm.placeholder(
(dshape[0],), dtype="int32", name="valid_count")
nms_threshold = 0.7
valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count")
iou_threshold = 0.7
force_suppress = True
nms_topk = -1
out = nms(data, valid_count, nms_threshold, force_suppress, nms_topk)
np_data = np.random.uniform(size=dshape).astype("float32")
np_valid_count = np.array([4]).astype("int32")
topk = -1
out = nms(data, valid_count, iou_threshold, force_suppress, topk)
np_data = np.random.uniform(size=dshape).astype("float32")
np_valid_count = np.array([4]).astype("int32")
s = topi.generic.schedule_nms(out)
f = tvm.build(s, [data, valid_count, out], "llvm")
ctx = tvm.cpu()
@@ -627,8 +632,8 @@ def nms_gpu(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk
tvm.extern(data.shape,
[data, sort_tensor, valid_count],
lambda ins, outs: nms_ir(
ins[0], ins[1], ins[2], outs[0], nms_threshold,
force_suppress, nms_topk),
ins[0], ins[1], ins[2], outs[0], iou_threshold,
force_suppress, topk),
dtype="float32",
in_buffers=[data_buf, sort_tensor_buf, valid_count_buf],
tag="nms")
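For orientation, the renamed arguments (iou_threshold, topk, id_index) keep the usual non-maximum-suppression behaviour. A plain NumPy reference of that behaviour on one batch element, as a hedged illustration of the semantics rather than the nms_ir kernel above (the exact tie-breaking and id_index handling in the kernel may differ):

import numpy as np

def iou_ref(a, b):
    # Intersection-over-union of two boxes given as [l, t, r, b].
    w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = w * h
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0

def nms_ref(boxes, valid_count, iou_threshold=0.5, force_suppress=False, topk=-1):
    # boxes: (num_anchors, 6) rows of [class_id, score, l, t, r, b].
    order = np.argsort(-boxes[:valid_count, 1])          # highest score first
    if topk > 0:
        order = order[:topk]
    keep = []
    for i in order:
        if boxes[i, 0] < 0:                              # already marked invalid
            continue
        suppressed = False
        for j in keep:
            same_class = force_suppress or boxes[i, 0] == boxes[j, 0]
            if same_class and iou_ref(boxes[i, 2:], boxes[j, 2:]) >= iou_threshold:
                suppressed = True
                break
        if not suppressed:
            keep.append(i)
    out = np.full_like(boxes, -1.0)                      # suppressed rows become all -1
    out[:len(keep)] = boxes[keep]
    return out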
79 changes: 50 additions & 29 deletions topi/python/topi/vision/ssd/multibox.py
@@ -19,37 +19,39 @@ def hybrid_multibox_prior(data, sizes, ratios, steps, offsets):
data : tvm.Tensor or numpy NDArray
4-D tensor with shape [batch, channel, height, width]
sizes : tvm.ndarray
1-D tensor of sizes for anchor boxes.
sizes : tvm ConstExpr
Sizes for anchor boxes.
ratios : tvm.ndarray
1-D tensor of ratios for anchor boxes.
ratios : tvm ConstExpr
Ratios for anchor boxes.
steps : tvm.ndarray
1-D tensor of priorbox step across y and x, -1 for auto calculation.
steps : tvm ConstExpr
Priorbox step across y and x, -1 for auto calculation.
offsets : tvm.ndarray
1-D tensor priorbox center offsets, y and x respectively.
offsets : tvm ConstExpr
Priorbox center offsets, y and x respectively.
Returns
-------
output : tvm.Tensor or numpy NDArray
3-D tensor with shape [1, h_in * w_in * (num_sizes + num_ratios - 1), 4]
"""
in_height, in_width = data.shape[2], data.shape[3]
num_sizes, num_ratios = sizes.shape[0], ratios.shape[0]
in_height = data.shape[2]
in_width = data.shape[3]
num_sizes = len(sizes)
num_ratios = len(ratios)
num_boxes = in_height * in_width * (num_sizes + num_ratios - 1)
output = output_tensor((1, num_boxes, 4), data.dtype)
steps_h = steps[0] if steps[0] > 0 else 1.0 / in_height
steps_w = steps[1] if steps[1] > 0 else 1.0 / in_width
output = output_tensor((1, num_boxes, 4), "float32")
steps_h = steps[0] * 1.0 if steps[0] > 0 else 1.0 / in_height
steps_w = steps[1] * 1.0 if steps[1] > 0 else 1.0 / in_width
offset_h = offsets[0]
offset_w = offsets[1]

for i in parallel(in_height):
center_h = (i + offset_h) * steps_h
for j in range(in_width):
center_w = (j + offset_w) * steps_w
for k in range(num_sizes + num_ratios - 1):
for k in const_range(num_sizes + num_ratios - 1):
if k < num_sizes:
w = sizes[k] * in_height / in_width / 2.0
h = sizes[k] / 2.0
@@ -95,7 +97,8 @@ def multibox_prior(data, sizes=(1,), ratios=(1,), steps=(-1, -1), offsets=(0.5,
out : tvm.Tensor
3-D tensor with shape [1, h_in * w_in * (num_sizes + num_ratios - 1), 4]
"""
out = hybrid_multibox_prior(data, sizes, ratios, steps, offsets)
out = hybrid_multibox_prior(data, tvm.convert(sizes), tvm.convert(ratios),
tvm.convert(steps), tvm.convert(offsets))
if clip:
out = topi.clip(out, 0, 1)
return out
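The hybrid kernel above emits, for every feature-map cell, one box per size plus one per additional ratio, centered at (index + offset) * step. A NumPy sketch of that layout for reference; the ratio branch is outside the visible hunk, so the sqrt(ratio) form below follows the usual SSD convention and is an assumption:

import numpy as np

# Hedged NumPy sketch of multibox prior generation. Output layout matches the
# docstring: (1, h * w * (num_sizes + num_ratios - 1), 4), boxes as [l, t, r, b].
def multibox_prior_ref(in_height, in_width, sizes=(1.0,), ratios=(1.0,),
                       steps=(-1.0, -1.0), offsets=(0.5, 0.5)):
    steps_h = steps[0] if steps[0] > 0 else 1.0 / in_height
    steps_w = steps[1] if steps[1] > 0 else 1.0 / in_width
    boxes = []
    for i in range(in_height):
        center_h = (i + offsets[0]) * steps_h
        for j in range(in_width):
            center_w = (j + offsets[1]) * steps_w
            for k in range(len(sizes) + len(ratios) - 1):
                if k < len(sizes):
                    w = sizes[k] * in_height / in_width / 2.0
                    h = sizes[k] / 2.0
                else:                                   # assumed SSD-style ratio branch
                    r = np.sqrt(ratios[k - len(sizes) + 1])
                    w = sizes[0] * in_height / in_width * r / 2.0
                    h = sizes[0] / r / 2.0
                boxes.append([center_w - w, center_h - h, center_w + w, center_h + h])
    return np.asarray(boxes, dtype="float32").reshape(1, -1, 4)

# multibox_prior_ref(2, 2, sizes=(0.5, 0.25), ratios=(1, 2)).shape == (1, 12, 4)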
@@ -104,11 +107,23 @@ def multibox_prior(data, sizes=(1,), ratios=(1,), steps=(-1, -1), offsets=(0.5,
def _hybridy_transform_loc(box, pred_loc, variance, clip):
"""Transform prior anchor box to output box through location predictions.
"""
al, at, ar, ab = box[0], box[1], box[2], box[3]
px, py, pw, ph = pred_loc[0], pred_loc[1], \
pred_loc[2], pred_loc[3]
vx, vy, vw, vh = variance[0], variance[1], \
variance[2], variance[3]
al = box[0]
at = box[1]
ar = box[2]
ab = box[3]

px = pred_loc[0]
py = pred_loc[1]
pw = pred_loc[2]
ph = pred_loc[3]

vx = variance[0]
vy = variance[1]
vw = variance[2]
vh = variance[3]

output = output_tensor((4,), pred_loc.dtype)

aw = ar - al
ah = ab - at
ax = (al + ar) / 2.0
@@ -117,11 +132,11 @@ def _hybridy_transform_loc(box, pred_loc, variance, clip):
oy = py * vy * ah + ay
ow = exp(pw * vw) * aw / 2.0
oh = exp(ph * vh) * ah / 2.0
out_l = max(0, min(1, ox - ow)) if clip else ox - ow
out_t = max(0, min(1, oy - oh)) if clip else oy - oh
out_r = max(0, min(1, ox + ow)) if clip else ox + ow
out_b = max(0, min(1, oy + oh)) if clip else oy + oh
return out_l, out_t, out_r, out_b
output[0] = max(0.0, min(1.0, ox - ow)) if clip else ox - ow
output[1] = max(0.0, min(1.0, oy - oh)) if clip else oy - oh
output[2] = max(0.0, min(1.0, ox + ow)) if clip else ox + ow
output[3] = max(0.0, min(1.0, oy + oh)) if clip else oy + oh
return output

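The helper above decodes one location prediction against its anchor: corners are converted to center/size form, shifted and scaled by the variance-weighted prediction, and written back as (optionally clipped) corners. The same arithmetic in NumPy, as a hedged reference for one box:

import numpy as np

# Hedged NumPy sketch of the anchor decoding performed by _hybridy_transform_loc.
def transform_loc_ref(box, pred_loc, variances=(0.1, 0.1, 0.2, 0.2), clip=True):
    al, at, ar, ab = box                        # anchor corners [l, t, r, b]
    px, py, pw, ph = pred_loc                   # predicted offsets
    vx, vy, vw, vh = variances
    aw, ah = ar - al, ab - at                   # anchor width / height
    ax, ay = (al + ar) / 2.0, (at + ab) / 2.0   # anchor center
    ox = px * vx * aw + ax                      # decoded center
    oy = py * vy * ah + ay
    ow = np.exp(pw * vw) * aw / 2.0             # decoded half-extents
    oh = np.exp(ph * vh) * ah / 2.0
    out = np.array([ox - ow, oy - oh, ox + ow, oy + oh])
    return np.clip(out, 0.0, 1.0) if clip else out

# transform_loc_ref([0.1, 0.1, 0.5, 0.5], [0.2, 0.0, 0.0, 0.0]) keeps the box size
# and shifts the center right by 0.2 * 0.1 * 0.4 = 0.008.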
@hybrid.script
def hybrid_multibox_transform_loc(cls_prob, loc_pred, anchor,
@@ -134,7 +149,7 @@ def hybrid_multibox_transform_loc(cls_prob, loc_pred, anchor,
3-D tensor of class probabilities.
loc_pred : tvm.Tensor or numpy NDArray
3-D tensor of location regression predictions.
2-D tensor of location regression predictions.
anchor : tvm.Tensor or numpy NDArray
3-D tensor of prior anchor boxes.
@@ -159,6 +174,8 @@ def hybrid_multibox_transform_loc(cls_prob, loc_pred, anchor,
batch_size = cls_prob.shape[0]
num_classes = cls_prob.shape[1]
num_anchors = cls_prob.shape[2]
box_coord = allocate((4,), loc_pred.dtype)
pred_coord = allocate((4,), loc_pred.dtype)
out_loc = output_tensor((batch_size, num_anchors, 6),
loc_pred.dtype)
valid_count = output_tensor((batch_size,), "int32")
@@ -172,7 +189,7 @@ def hybrid_multibox_transform_loc(cls_prob, loc_pred, anchor,
for k in range(num_classes):
if k > 0:
temp = cls_prob[i, k, j]
cls_id = j if temp > score else cls_id
cls_id = k if temp > score else cls_id
score = max(temp, score)
if cls_id > 0 and score < threshold:
cls_id = 0
@@ -181,12 +198,16 @@ def hybrid_multibox_transform_loc(cls_prob, loc_pred, anchor,
if cls_id > 0:
out_loc[i, valid_count[i], 0] = cls_id - 1.0
out_loc[i, valid_count[i], 1] = score
out_coord = _hybridy_transform_loc(anchor[j], loc_pred[i, j],
for l in range(4):
box_coord[l] = anchor[0, j, l]
pred_coord[l] = loc_pred[i, j * 4 + l]
out_coord = _hybridy_transform_loc(box_coord, pred_coord,
variances, clip)
out_loc[i, valid_count[i], 2] = out_coord[0]
out_loc[i, valid_count[i], 3] = out_coord[1]
out_loc[i, valid_count[i], 4] = out_coord[2]
out_loc[i, valid_count[i], 5] = out_coord[3]
valid_count[i] += 1

return out_loc, valid_count

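Per anchor, the kernel above takes the best-scoring non-background class, demotes it to background when the score falls below threshold, and appends one [class_id - 1, score, box] row for each surviving anchor while counting them in valid_count. A NumPy sketch of that selection for a single batch element; it reuses transform_loc_ref from the earlier decode sketch, so the variance/clip handling is assumed to match:

import numpy as np

# Hedged NumPy sketch of the per-anchor selection and filtering above.
# cls_prob: (num_classes, num_anchors) with class 0 = background;
# loc_pred: flattened (num_anchors * 4,); anchors: (1, num_anchors, 4).
def transform_loc_batch_ref(cls_prob, loc_pred, anchors, threshold=0.01,
                            variances=(0.1, 0.1, 0.2, 0.2), clip=True):
    num_anchors = cls_prob.shape[1]
    rows = []
    for j in range(num_anchors):
        scores = cls_prob[1:, j]                         # skip the background class
        cls_id = int(np.argmax(scores)) + 1
        score = float(scores[cls_id - 1])
        if score < threshold:                            # weak detection -> background
            cls_id = 0
        if cls_id > 0:
            box = transform_loc_ref(anchors[0, j], loc_pred[j * 4:(j + 1) * 4],
                                    variances, clip)     # see the decode sketch above
            rows.append([cls_id - 1.0, score] + list(box))
    valid_count = len(rows)
    return np.asarray(rows, dtype="float32"), valid_count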
@@ -222,7 +243,7 @@ def multibox_transform_loc(cls_prob, loc_pred, anchor, clip=True, threshold=0.01
out, valid_count = hybrid_multibox_transform_loc(cls_prob, loc_pred, anchor,
tvm.const(clip, "bool"),
tvm.const(threshold, "float32"),
variances)
tvm.convert(variances))
return out, valid_count

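Taken together, these ops form the SSD post-processing path: multibox_transform_loc decodes anchors into [class_id, score, box] rows plus a valid count, and that pair feeds nms. A hedged sketch of wiring them at the TOPI level; the shapes are illustrative, the import paths are assumed from this repository's file layout, and scheduling/building is omitted:

import tvm
from topi.vision.ssd.multibox import multibox_transform_loc
from topi.vision.nms import nms

batch, num_classes, num_anchors = 1, 21, 8732
cls_prob = tvm.placeholder((batch, num_classes, num_anchors), name="cls_prob")
loc_pred = tvm.placeholder((batch, num_anchors * 4), name="loc_pred")   # flattened layout
anchors = tvm.placeholder((1, num_anchors, 4), name="anchors")

# Decode anchors against the location predictions...
out_loc, valid_count = multibox_transform_loc(cls_prob, loc_pred, anchors,
                                              clip=True, threshold=0.01,
                                              variances=(0.1, 0.1, 0.2, 0.2))
# ...then suppress overlapping detections with the renamed nms arguments.
det = nms(out_loc, valid_count, iou_threshold=0.5, force_suppress=False, topk=-1)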
@tvm.target.generic_func
4 changes: 2 additions & 2 deletions topi/tests/python/test_topi_vision.py
@@ -73,7 +73,7 @@ def test_nms():
[1, 0.5, 100, 60, 70, 110]]]).astype(data.dtype)
np_valid_count = np.array([4]).astype(valid_count.dtype)
np_result = np.array([[[2, 0.9, 35, 61, 52, 79], [0, 0.8, 1, 20, 25, 45],
[0, 0.4, 4, 21, 19, 40], [-1, 0.9, 35, 61, 52, 79],
[-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1]]])

def check_device(device):
Expand All @@ -96,7 +96,7 @@ def check_device(device):
f(tvm_data, tvm_valid_count, tvm_out)
tvm.testing.assert_allclose(tvm_out.asnumpy(), np_result, rtol=1e-4)

for device in ['llvm', 'opencl', 'cuda']:
for device in ['llvm']:
check_device(device)


