Skip to content

Commit

Permalink
elem length made to a variable
Browse files Browse the repository at this point in the history
  • Loading branch information
Laurawly committed Mar 15, 2019
1 parent abf3c34 commit d0045e0
Showing 1 changed file with 12 additions and 10 deletions.
22 changes: 12 additions & 10 deletions topi/python/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def get_valid_counts_pre(data, flag, idx, score_threshold):
"""
batch_size = data.shape[0]
num_anchors = data.shape[1]
box_data_length = data.shape[2]

ib = tvm.ir_builder.create()

Expand All @@ -55,8 +56,8 @@ def get_valid_counts_pre(data, flag, idx, score_threshold):
with ib.if_scope(tid < batch_size * num_anchors):
i = tid / num_anchors # number of batches
j = tid % num_anchors # number of anchors
base_idx = i * num_anchors * 6
with ib.if_scope(data[base_idx + j * 6 + 1] > score_threshold):
base_idx = i * num_anchors * box_data_length
with ib.if_scope(data[base_idx + j * box_data_length + 1] > score_threshold):
flag[tid] = 1
idx[tid] = 1
with ib.else_scope():
Expand Down Expand Up @@ -127,10 +128,10 @@ def get_valid_counts_ir(data, flag, idx, valid_count, out):
with ib.if_scope(tid < batch_size * num_anchors):
i = tid / num_anchors # number of batches
j = tid % num_anchors # number of anchors
base_idx = i * num_anchors * 6
base_idx = i * num_anchors * elem_length
with ib.if_scope(flag[tid] > 0):
with ib.for_range(0, elem_length) as k:
out[base_idx + (idx[tid] - 1) * 6 + k] = data[base_idx + j * 6 + k]
out[base_idx + (idx[tid] - 1) * elem_length + k] = data[base_idx + j * elem_length + k]
valid_count[i] = idx[i * num_anchors + num_anchors - 1]

return ib.get()
Expand Down Expand Up @@ -416,6 +417,7 @@ def invalid_to_bottom_pre(data, flag, idx):
"""
batch_size = data.shape[0]
num_anchors = data.shape[1]
elem_length = data.shape[2]

ib = tvm.ir_builder.create()

Expand All @@ -434,9 +436,9 @@ def invalid_to_bottom_pre(data, flag, idx):
j = bx * max_threads + tx

with ib.for_range(0, batch_size, for_type="unroll") as i:
base_idx = i * num_anchors * 6
base_idx = i * num_anchors * elem_length
with ib.if_scope(j < num_anchors):
with ib.if_scope(data[base_idx + j * 6] >= 0):
with ib.if_scope(data[base_idx + j * elem_length] >= 0):
flag[i * num_anchors + j] = 1
idx[i * num_anchors + j] = 1
with ib.else_scope():
Expand Down Expand Up @@ -494,14 +496,14 @@ def invalid_to_bottom_ir(data, flag, idx, out):
j = bx * max_threads + tx

with ib.for_range(0, batch_size, for_type="unroll") as i:
base_idx = i * num_anchors * 6
base_idx = i * num_anchors * elem_length
with ib.if_scope(j < num_anchors):
with ib.for_range(0, elem_length) as k:
out[base_idx + j * 6 + k] = -1.0
out[base_idx + j * elem_length + k] = -1.0
with ib.if_scope(flag[i * num_anchors + j] > 0):
with ib.for_range(0, elem_length) as k:
out[base_idx + (idx[i * num_anchors + j] - 1) * 6 + k] \
= data[base_idx + j * 6 + k]
out[base_idx + (idx[i * num_anchors + j] - 1) * elem_length + k] \
= data[base_idx + j * elem_length + k]
return ib.get()


Expand Down

0 comments on commit d0045e0

Please sign in to comment.