From d0045e09d9ae4e9cd6885619d435b8c1ebe992c5 Mon Sep 17 00:00:00 2001 From: Leyuan Wang Date: Fri, 15 Mar 2019 13:58:19 -0700 Subject: [PATCH] elem length made to a variable --- topi/python/topi/cuda/nms.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/topi/python/topi/cuda/nms.py b/topi/python/topi/cuda/nms.py index b8cdcac3c8440..7617aff3ffadd 100644 --- a/topi/python/topi/cuda/nms.py +++ b/topi/python/topi/cuda/nms.py @@ -35,6 +35,7 @@ def get_valid_counts_pre(data, flag, idx, score_threshold): """ batch_size = data.shape[0] num_anchors = data.shape[1] + box_data_length = data.shape[2] ib = tvm.ir_builder.create() @@ -55,8 +56,8 @@ def get_valid_counts_pre(data, flag, idx, score_threshold): with ib.if_scope(tid < batch_size * num_anchors): i = tid / num_anchors # number of batches j = tid % num_anchors # number of anchors - base_idx = i * num_anchors * 6 - with ib.if_scope(data[base_idx + j * 6 + 1] > score_threshold): + base_idx = i * num_anchors * box_data_length + with ib.if_scope(data[base_idx + j * box_data_length + 1] > score_threshold): flag[tid] = 1 idx[tid] = 1 with ib.else_scope(): @@ -127,10 +128,10 @@ def get_valid_counts_ir(data, flag, idx, valid_count, out): with ib.if_scope(tid < batch_size * num_anchors): i = tid / num_anchors # number of batches j = tid % num_anchors # number of anchors - base_idx = i * num_anchors * 6 + base_idx = i * num_anchors * elem_length with ib.if_scope(flag[tid] > 0): with ib.for_range(0, elem_length) as k: - out[base_idx + (idx[tid] - 1) * 6 + k] = data[base_idx + j * 6 + k] + out[base_idx + (idx[tid] - 1) * elem_length + k] = data[base_idx + j * elem_length + k] valid_count[i] = idx[i * num_anchors + num_anchors - 1] return ib.get() @@ -416,6 +417,7 @@ def invalid_to_bottom_pre(data, flag, idx): """ batch_size = data.shape[0] num_anchors = data.shape[1] + elem_length = data.shape[2] ib = tvm.ir_builder.create() @@ -434,9 +436,9 @@ def invalid_to_bottom_pre(data, flag, idx): j = bx * max_threads + tx with ib.for_range(0, batch_size, for_type="unroll") as i: - base_idx = i * num_anchors * 6 + base_idx = i * num_anchors * elem_length with ib.if_scope(j < num_anchors): - with ib.if_scope(data[base_idx + j * 6] >= 0): + with ib.if_scope(data[base_idx + j * elem_length] >= 0): flag[i * num_anchors + j] = 1 idx[i * num_anchors + j] = 1 with ib.else_scope(): @@ -494,14 +496,14 @@ def invalid_to_bottom_ir(data, flag, idx, out): j = bx * max_threads + tx with ib.for_range(0, batch_size, for_type="unroll") as i: - base_idx = i * num_anchors * 6 + base_idx = i * num_anchors * elem_length with ib.if_scope(j < num_anchors): with ib.for_range(0, elem_length) as k: - out[base_idx + j * 6 + k] = -1.0 + out[base_idx + j * elem_length + k] = -1.0 with ib.if_scope(flag[i * num_anchors + j] > 0): with ib.for_range(0, elem_length) as k: - out[base_idx + (idx[i * num_anchors + j] - 1) * 6 + k] \ - = data[base_idx + j * 6 + k] + out[base_idx + (idx[i * num_anchors + j] - 1) * elem_length + k] \ + = data[base_idx + j * elem_length + k] return ib.get()