Skip to content

Commit

Permalink
[TOPI][CUDA] Fix nms block extent and type mismatch in multibox (#2320)
Browse files Browse the repository at this point in the history
  • Loading branch information
vinx13 authored and tqchen committed Jan 7, 2019
1 parent 57506af commit 9b4b360
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 14 deletions.
10 changes: 4 additions & 6 deletions topi/python/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def sort_pre_ir_data(data, index, sizes_in, data_out, index_out, \
tvm.target.current_target(allow_none=False).max_num_threads)
tx = tvm.thread_axis("threadIdx.x")
bx = tvm.thread_axis("blockIdx.x")
dshape = tvm.max(sizes_in.shape[0], p_index[0])
dshape = axis_mul_before * axis_mul_after
nthread_tx = max_threads
nthread_bx = dshape // max_threads + 1
ib.scope_attr(tx, "thread_extent", nthread_tx)
Expand Down Expand Up @@ -331,9 +331,7 @@ def sort_gpu(data, data_buf, index, index_buf, output_buf, axis, is_descend):
Parameters
----------
data: tvm.Tensor
3-D tensor with shape [batch_size, num_anchors, 6].
The last dimension should be in format of
[class_id, score, box_left, box_top, box_right, box_bottom].
2-D tensor of input boxes' score with shape [batch_size, num_anchors].
data_buf: Buffer
2D Buffer of input boxes' score with shape [batch_size, num_anchors].
Expand Down Expand Up @@ -595,8 +593,8 @@ def nms_gpu(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk
force_suppress = True
nms_topk = -1
out = nms(data, valid_count, nms_threshold, force_suppress, nms_topk)
np_data = np.random.uniform(dshape)
np_valid_count = np.array([4])
np_data = np.random.uniform(size=dshape).astype("float32")
np_valid_count = np.array([4]).astype("int32")
s = topi.generic.schedule_nms(out)
f = tvm.build(s, [data, valid_count, out], "llvm")
ctx = tvm.cpu()
Expand Down
8 changes: 4 additions & 4 deletions topi/python/topi/cuda/ssd/multibox.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,10 +278,10 @@ def transform_loc(loc, loc_base_idx, anchor, anchor_base_idx, clip, vx, vy, vw,
oy = py * vy * ah + ay
ow = tvm.exp(pw * vw) * aw / 2.0
oh = tvm.exp(ph * vh) * ah / 2.0
return tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, ox - ow)), ox - ow), \
tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, oy - oh)), oy - oh), \
tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, ox + ow)), ox + ow), \
tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, oy + oh)), oy + oh)
return tvm.select(clip, tvm.make.Max(0.0, tvm.make.Min(1.0, ox - ow)), ox - ow), \
tvm.select(clip, tvm.make.Max(0.0, tvm.make.Min(1.0, oy - oh)), oy - oh), \
tvm.select(clip, tvm.make.Max(0.0, tvm.make.Min(1.0, ox + ow)), ox + ow), \
tvm.select(clip, tvm.make.Max(0.0, tvm.make.Min(1.0, oy + oh)), oy + oh)

max_threads = int(
tvm.target.current_target(allow_none=False).max_num_threads)
Expand Down
4 changes: 2 additions & 2 deletions topi/python/topi/vision/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,8 @@ def nms(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk=-1)
force_suppress = True
nms_topk = -1
out = nms(data, valid_count, nms_threshold, force_suppress, nms_topk)
np_data = np.random.uniform(dshape)
np_valid_count = np.array([4])
np_data = np.random.uniform(size=dshape).astype("float32")
np_valid_count = np.array([4]).astype("int32")
s = topi.generic.schedule_nms(out)
f = tvm.build(s, [data, valid_count, out], "llvm")
ctx = tvm.cpu()
Expand Down
4 changes: 2 additions & 2 deletions topi/tests/python/test_topi_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def check_device(device):
f(tvm_data, tvm_valid_count, tvm_out)
tvm.testing.assert_allclose(tvm_out.asnumpy(), np_result, rtol=1e-4)

for device in ['llvm', 'opencl']:
for device in ['llvm', 'opencl', 'cuda']:
check_device(device)


Expand Down Expand Up @@ -105,7 +105,7 @@ def check_device(device):
f(tvm_input_data, tvm_out)
tvm.testing.assert_allclose(tvm_out.asnumpy(), np_out, rtol=1e-3)

for device in ['llvm', 'opencl']:
for device in ['llvm', 'opencl', 'cuda']:
check_device(device)


Expand Down

0 comments on commit 9b4b360

Please sign in to comment.