From a2ad4dea87d9a637745fb0a40ff9bbdde286194a Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 20 Dec 2020 20:36:34 +0900 Subject: [PATCH] add api for returning reduction from ex scan output --- python/tvm/topi/cuda/nms.py | 50 +--------------- python/tvm/topi/cuda/scan.py | 108 ++++++++++++++++++++++++++--------- 2 files changed, 83 insertions(+), 75 deletions(-) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 888448973f11..f08389faf77c 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -124,40 +124,6 @@ def get_valid_boxes_ir(data, valid_boxes, score_threshold, id_index, score_index return ib.get() -def get_num_valid_boxes_ir(valid_boxes, valid_boxes_ex_scan, valid_count): - """TODO""" - batch_size = valid_boxes.shape[0] - num_anchors = valid_boxes.shape[1] - - ib = tvm.tir.ir_builder.create() - - valid_boxes = ib.buffer_ptr(valid_boxes) - valid_boxes_ex_scan = ib.buffer_ptr(valid_boxes_ex_scan) - valid_count = ib.buffer_ptr(valid_count) - - max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) - - def ceil_div(a, b): - return tvm.tir.indexdiv(a + b - 1, b) - - ## Write Sum to valid_count - max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) - with ib.new_scope(): - nthread_tx = max_threads - nthread_bx = ceil_div(batch_size, max_threads) - tx = te.thread_axis("threadIdx.x") - bx = te.thread_axis("blockIdx.x") - ib.scope_attr(tx, "thread_extent", nthread_tx) - ib.scope_attr(bx, "thread_extent", nthread_bx) - tid = bx * max_threads + tx - with ib.if_scope(tid < batch_size): - valid_count[tid] = ( - valid_boxes_ex_scan[tid, num_anchors - 1] + valid_boxes[tid, num_anchors - 1] - ) - - return ib.get() - - def get_valid_counts_ir(data, valid_indices, valid_boxes, out, out_indices): """Low level IR to get valid count of bounding boxes given a score threshold. Also prepares to move valid boxes to the @@ -284,22 +250,8 @@ def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1): valid_indices_buf = tvm.tir.decl_buffer( (batch_size, num_anchors), "int32", "valid_indices_buf", data_alignment=8 ) - valid_count_buf = tvm.tir.decl_buffer( - (batch_size,), "int32", "valid_count_buf", data_alignment=8 - ) - valid_indices = exclusive_scan(valid_boxes, axis=1) - - valid_count = te.extern( - [(batch_size,)], - [valid_boxes, valid_indices], - lambda ins, outs: get_num_valid_boxes_ir(ins[0], ins[1], outs[0]), - dtype=["int32"], - in_buffers=[valid_boxes_buf, valid_indices_buf], - out_buffers=[valid_count_buf], - name="get_valid_indices_sum", - tag="get_valid_indices_sum_gpu", - ) + valid_indices, valid_count = exclusive_scan(valid_boxes, axis=1, return_reduction=True) out_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "out_buf", data_alignment=8) out_indices_buf = tvm.tir.decl_buffer( diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py index 4c7c838f3a6b..8d058d783b9a 100644 --- a/python/tvm/topi/cuda/scan.py +++ b/python/tvm/topi/cuda/scan.py @@ -19,6 +19,7 @@ from tvm import te from tvm._ffi import get_global_func from ..transform import expand_dims, squeeze +from ..utils import ceil_div def exclusive_sum_scan2d_ir(data, output): @@ -35,9 +36,6 @@ def exclusive_sum_scan2d_ir(data, output): max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) - def ceil_div(a, b): - return tvm.tir.indexdiv(a + b - 1, b) - with ib.new_scope(): nthread_tx = max_threads nthread_bx = ceil_div(scan_size, max_threads) @@ -142,7 +140,52 @@ def scan_thrust(data, exclusive=True): ) -def exclusive_scan(data, axis=-1): +def get_reduction_from_exclusive_scan_ir(data, data_ex_scan, reduction): + """TODO""" + batch_size = data.shape[0] + num_anchors = data.shape[1] + + ib = tvm.tir.ir_builder.create() + + data = ib.buffer_ptr(data) + data_ex_scan = ib.buffer_ptr(data_ex_scan) + reduction = ib.buffer_ptr(reduction) + + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) + with ib.new_scope(): + nthread_tx = max_threads + nthread_bx = ceil_div(batch_size, max_threads) + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * max_threads + tx + with ib.if_scope(tid < batch_size): + reduction[tid] = data_ex_scan[tid, num_anchors - 1] + data[tid, num_anchors - 1] + + return ib.get() + + +def get_reduction_from_exclusive_scan(data, ex_scan_output): + """TODO""" + assert len(data.shape) == 2, "Only 2D input supported for now" + data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "valid_indices_buf", data_alignment=8) + ex_scan_output_buf = tvm.tir.decl_buffer( + ex_scan_output.shape, ex_scan_output.dtype, "ex_scan_output_buf", data_alignment=8 + ) + + return te.extern( + [(data.shape[0],)], + [data, ex_scan_output], + lambda ins, outs: get_reduction_from_exclusive_scan_ir(ins[0], ins[1], outs[0]), + dtype=[ex_scan_output.dtype], + in_buffers=[data_buf, ex_scan_output_buf], + name="ex_scan_reduction", + tag="ex_scan_reduction_gpu", + ) + + +def exclusive_scan(data, axis=-1, return_reduction=False): # TODO(masahi): support other binary associative operators ndim = len(data.shape) if axis < 0: @@ -151,27 +194,40 @@ def exclusive_scan(data, axis=-1): target = tvm.target.Target.current() if target and target.kind.name == "cuda" and is_thrust_available(): - return scan_thrust(data, exclusive=True) - - if ndim == 1: - data = expand_dims(data, axis=0) + output = scan_thrust(data, exclusive=True) + if ndim == 1 and return_reduction: + output = expand_dims(data, axis=0) + else: + if ndim == 1: + data = expand_dims(data, axis=0) + + data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) + output_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "output_buf", data_alignment=8) + + if ndim == 2: + output = te.extern( + [data.shape], + [data], + lambda ins, outs: exclusive_sum_scan2d_ir(ins[0], outs[0]), + dtype=[data.dtype], + in_buffers=[data_buf], + out_buffers=[output_buf], + name="exclusive_scan", + tag="exclusive_scan_gpu", + ) + else: + assert False, "Unsupported dimension {}".format(ndim) - data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) - output_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "output_buf", data_alignment=8) + if return_reduction: + reduction = get_reduction_from_exclusive_scan(data, output) - if ndim == 2: - output = te.extern( - [data.shape], - [data], - lambda ins, outs: exclusive_sum_scan2d_ir(ins[0], outs[0]), - dtype=[data.dtype], - in_buffers=[data_buf], - out_buffers=[output_buf], - name="exclusive_scan", - tag="exclusive_scan_gpu", - ) - if ndim == 1: - return squeeze(output, 0) - return output - else: - assert False, "Unsupported dimension {}".format(ndim) + if ndim == 1: + output = squeeze(output, 0) + if return_reduction: + reduction = squeeze(reduction, 0) + return output, reduction + return reduction + + if return_reduction: + return output, reduction + return output