Skip to content

Commit

Permalink
add api for returning reduction from ex scan output
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Dec 24, 2020
1 parent b7f4ef7 commit a2ad4de
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 75 deletions.
50 changes: 1 addition & 49 deletions python/tvm/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,40 +124,6 @@ def get_valid_boxes_ir(data, valid_boxes, score_threshold, id_index, score_index
return ib.get()


def get_num_valid_boxes_ir(valid_boxes, valid_boxes_ex_scan, valid_count):
"""TODO"""
batch_size = valid_boxes.shape[0]
num_anchors = valid_boxes.shape[1]

ib = tvm.tir.ir_builder.create()

valid_boxes = ib.buffer_ptr(valid_boxes)
valid_boxes_ex_scan = ib.buffer_ptr(valid_boxes_ex_scan)
valid_count = ib.buffer_ptr(valid_count)

max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)

def ceil_div(a, b):
return tvm.tir.indexdiv(a + b - 1, b)

## Write Sum to valid_count
max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
with ib.new_scope():
nthread_tx = max_threads
nthread_bx = ceil_div(batch_size, max_threads)
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx
with ib.if_scope(tid < batch_size):
valid_count[tid] = (
valid_boxes_ex_scan[tid, num_anchors - 1] + valid_boxes[tid, num_anchors - 1]
)

return ib.get()


def get_valid_counts_ir(data, valid_indices, valid_boxes, out, out_indices):
"""Low level IR to get valid count of bounding boxes
given a score threshold. Also prepares to move valid boxes to the
Expand Down Expand Up @@ -284,22 +250,8 @@ def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1):
valid_indices_buf = tvm.tir.decl_buffer(
(batch_size, num_anchors), "int32", "valid_indices_buf", data_alignment=8
)
valid_count_buf = tvm.tir.decl_buffer(
(batch_size,), "int32", "valid_count_buf", data_alignment=8
)

valid_indices = exclusive_scan(valid_boxes, axis=1)

valid_count = te.extern(
[(batch_size,)],
[valid_boxes, valid_indices],
lambda ins, outs: get_num_valid_boxes_ir(ins[0], ins[1], outs[0]),
dtype=["int32"],
in_buffers=[valid_boxes_buf, valid_indices_buf],
out_buffers=[valid_count_buf],
name="get_valid_indices_sum",
tag="get_valid_indices_sum_gpu",
)
valid_indices, valid_count = exclusive_scan(valid_boxes, axis=1, return_reduction=True)

out_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "out_buf", data_alignment=8)
out_indices_buf = tvm.tir.decl_buffer(
Expand Down
108 changes: 82 additions & 26 deletions python/tvm/topi/cuda/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from tvm import te
from tvm._ffi import get_global_func
from ..transform import expand_dims, squeeze
from ..utils import ceil_div


def exclusive_sum_scan2d_ir(data, output):
Expand All @@ -35,9 +36,6 @@ def exclusive_sum_scan2d_ir(data, output):

max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)

def ceil_div(a, b):
return tvm.tir.indexdiv(a + b - 1, b)

with ib.new_scope():
nthread_tx = max_threads
nthread_bx = ceil_div(scan_size, max_threads)
Expand Down Expand Up @@ -142,7 +140,52 @@ def scan_thrust(data, exclusive=True):
)


def exclusive_scan(data, axis=-1):
def get_reduction_from_exclusive_scan_ir(data, data_ex_scan, reduction):
"""TODO"""
batch_size = data.shape[0]
num_anchors = data.shape[1]

ib = tvm.tir.ir_builder.create()

data = ib.buffer_ptr(data)
data_ex_scan = ib.buffer_ptr(data_ex_scan)
reduction = ib.buffer_ptr(reduction)

max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
with ib.new_scope():
nthread_tx = max_threads
nthread_bx = ceil_div(batch_size, max_threads)
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx
with ib.if_scope(tid < batch_size):
reduction[tid] = data_ex_scan[tid, num_anchors - 1] + data[tid, num_anchors - 1]

return ib.get()


def get_reduction_from_exclusive_scan(data, ex_scan_output):
"""TODO"""
assert len(data.shape) == 2, "Only 2D input supported for now"
data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "valid_indices_buf", data_alignment=8)
ex_scan_output_buf = tvm.tir.decl_buffer(
ex_scan_output.shape, ex_scan_output.dtype, "ex_scan_output_buf", data_alignment=8
)

return te.extern(
[(data.shape[0],)],
[data, ex_scan_output],
lambda ins, outs: get_reduction_from_exclusive_scan_ir(ins[0], ins[1], outs[0]),
dtype=[ex_scan_output.dtype],
in_buffers=[data_buf, ex_scan_output_buf],
name="ex_scan_reduction",
tag="ex_scan_reduction_gpu",
)


def exclusive_scan(data, axis=-1, return_reduction=False):
# TODO(masahi): support other binary associative operators
ndim = len(data.shape)
if axis < 0:
Expand All @@ -151,27 +194,40 @@ def exclusive_scan(data, axis=-1):

target = tvm.target.Target.current()
if target and target.kind.name == "cuda" and is_thrust_available():
return scan_thrust(data, exclusive=True)

if ndim == 1:
data = expand_dims(data, axis=0)
output = scan_thrust(data, exclusive=True)
if ndim == 1 and return_reduction:
output = expand_dims(data, axis=0)
else:
if ndim == 1:
data = expand_dims(data, axis=0)

data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
output_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "output_buf", data_alignment=8)

if ndim == 2:
output = te.extern(
[data.shape],
[data],
lambda ins, outs: exclusive_sum_scan2d_ir(ins[0], outs[0]),
dtype=[data.dtype],
in_buffers=[data_buf],
out_buffers=[output_buf],
name="exclusive_scan",
tag="exclusive_scan_gpu",
)
else:
assert False, "Unsupported dimension {}".format(ndim)

data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
output_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "output_buf", data_alignment=8)
if return_reduction:
reduction = get_reduction_from_exclusive_scan(data, output)

if ndim == 2:
output = te.extern(
[data.shape],
[data],
lambda ins, outs: exclusive_sum_scan2d_ir(ins[0], outs[0]),
dtype=[data.dtype],
in_buffers=[data_buf],
out_buffers=[output_buf],
name="exclusive_scan",
tag="exclusive_scan_gpu",
)
if ndim == 1:
return squeeze(output, 0)
return output
else:
assert False, "Unsupported dimension {}".format(ndim)
if ndim == 1:
output = squeeze(output, 0)
if return_reduction:
reduction = squeeze(reduction, 0)
return output, reduction
return reduction

if return_reduction:
return output, reduction
return output

0 comments on commit a2ad4de

Please sign in to comment.