Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TOPI] Make cumsum IR reusable, add thrust scan #7303

Merged
merged 15 commits into from
Jan 20, 2021
151 changes: 4 additions & 147 deletions python/tvm/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

from tvm.tir import if_then_else
from .sort import argsort, argsort_thrust, is_thrust_available
from .scan import exclusive_scan
from ..utils import ceil_div


def cuda_atomic_add_rule(op):
Expand Down Expand Up @@ -51,10 +53,6 @@ def atomic_add(x, y):
return tvm.tir.call_intrin(y.dtype, "tir.atomic_add", x, y)


def ceil_div(a, b):
    """Return ceil(a / b) as a TIR index expression (used for grid sizing)."""
    numerator = a + b - 1
    return tvm.tir.indexdiv(numerator, b)

def get_valid_boxes_ir(data, valid_boxes, score_threshold, id_index, score_index):
"""Low level IR to identify bounding boxes given a score threshold.

Expand Down Expand Up @@ -123,136 +121,6 @@ def get_valid_boxes_ir(data, valid_boxes, score_threshold, id_index, score_index
return ib.get()


def get_valid_indices_ir(valid_boxes, valid_count, valid_indices):
    """Low level IR to get the output indices of valid boxes
    and the count of valid boxes.

    Performs a work-efficient (Blelloch-style) parallel exclusive scan over
    each row of ``valid_boxes``: an up-sweep accumulates partial sums in
    place, the row total is moved into ``valid_count`` and the last element
    is zeroed, then a down-sweep distributes the prefix sums. The result in
    ``valid_indices`` gives, for each anchor, the output slot of the valid
    boxes that precede it.

    Parameters
    ----------
    valid_boxes: Buffer
        2D Buffer indicating valid boxes with shape [batch_size, num_anchors].

    Returns
    -------
    valid_count: Buffer
        1D Buffer of number of valid boxes per batch [batch_size].

    valid_indices: Buffer
        2D Buffer indicating output sorted indices of valid boxes [batch_size, num_anchors].
    """
    batch_size = valid_boxes.shape[0]
    num_anchors = valid_boxes.shape[1]

    ib = tvm.tir.ir_builder.create()

    # Rebind the declared buffers to IR-builder pointers for element access.
    valid_boxes = ib.buffer_ptr(valid_boxes)

    valid_count = ib.buffer_ptr(valid_count)
    valid_indices = ib.buffer_ptr(valid_indices)

    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
    with ib.if_scope(num_anchors > 0):
        # Copy boxes to valid_indices
        with ib.new_scope():
            nthread_tx = max_threads
            nthread_bx = ceil_div(num_anchors, max_threads)
            nthread_by = batch_size
            tx = te.thread_axis("threadIdx.x")
            bx = te.thread_axis("blockIdx.x")
            by = te.thread_axis("blockIdx.y")
            ib.scope_attr(tx, "thread_extent", nthread_tx)
            ib.scope_attr(bx, "thread_extent", nthread_bx)
            ib.scope_attr(by, "thread_extent", nthread_by)
            tid = bx * nthread_tx + tx
            with ib.if_scope(tid < num_anchors):
                valid_indices[by, tid] = valid_boxes[by, tid]

        nthread_tx = max_threads
        nthread_bx = ceil_div(num_anchors, max_threads)
        nthread_by = batch_size

        ## The following algorithm performs parallel exclusive scan to get
        ## a tensor that can later be used to select valid indices
        # Up Sweep of exclusive scan
        # lim = ceil(log2(num_anchors)): number of sweep levels needed.
        lim = tvm.tir.generic.cast(
            tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(num_anchors, "float64"))), "int64"
        )
        with ib.for_range(0, lim, dtype="int64") as l2_width:
            # Segment width doubles each level: 2, 4, 8, ...
            width = 2 << l2_width

            with ib.new_scope():
                tx = te.thread_axis("threadIdx.x")
                bx = te.thread_axis("blockIdx.x")
                ib.scope_attr(tx, "thread_extent", nthread_tx)
                # One thread per `width`-sized segment of the row.
                ib.scope_attr(
                    bx,
                    "thread_extent",
                    tvm.tir.generic.cast(ceil_div(num_anchors, max_threads * width), "int32"),
                )
                tid = bx * nthread_tx + tx

                by = te.thread_axis("blockIdx.y")
                ib.scope_attr(by, "thread_extent", nthread_by)
                start = ib.allocate("int64", (1,), name="start", scope="local")
                middle = ib.allocate("int64", (1,), name="middle", scope="local")
                end = ib.allocate("int64", (1,), name="end", scope="local")
                start[0] = width * tid
                with ib.if_scope(start[0] < num_anchors):
                    middle[0] = start[0] + tvm.tir.indexdiv(width, 2)
                    end[0] = tvm.te.min(start[0] + width, num_anchors)
                    with ib.if_scope(middle[0] < num_anchors):
                        # Fold the left half-segment's sum into the segment's last slot.
                        # NOTE: row `by` is addressed with a flattened index here,
                        # while the copy kernel above uses 2D indexing.
                        valid_indices[by * num_anchors + end[0] - 1] += valid_indices[
                            by * num_anchors + middle[0] - 1
                        ]

        # Down Sweep of exclusive scan
        # First move each row's total into valid_count and zero the last
        # element, turning the inclusive up-sweep result into an exclusive scan.
        with ib.new_scope():
            bx = te.thread_axis("blockIdx.x")
            ib.scope_attr(bx, "thread_extent", batch_size)
            with ib.if_scope(bx < batch_size):
                valid_count[bx] = valid_indices[(bx + 1) * num_anchors - 1]
                valid_indices[(bx + 1) * num_anchors - 1] = 0

        with ib.for_range(0, lim, dtype="int64") as l2_width:
            # Traverse levels in reverse: width halves each level.
            width = 2 << (lim - l2_width - 1)

            with ib.new_scope():
                tx = te.thread_axis("threadIdx.x")
                bx = te.thread_axis("blockIdx.x")
                ib.scope_attr(tx, "thread_extent", nthread_tx)
                ib.scope_attr(
                    bx,
                    "thread_extent",
                    tvm.tir.generic.cast(ceil_div(num_anchors, max_threads * width), "int32"),
                )
                tid = bx * nthread_tx + tx

                by = te.thread_axis("blockIdx.y")
                ib.scope_attr(by, "thread_extent", nthread_by)
                start = ib.allocate("int64", (1,), name="start", scope="local")
                middle = ib.allocate("int64", (1,), name="middle", scope="local")
                end = ib.allocate("int64", (1,), name="end", scope="local")
                # NOTE(review): name="end" below looks like a copy-paste slip —
                # this buffer is `tmp` and presumably should be named "tmp".
                tmp = ib.allocate("int32", (1,), name="end", scope="local")
                start[0] = width * tid
                with ib.if_scope(tvm.tir.all(start[0] < num_anchors)):
                    middle[0] = start[0] + tvm.tir.indexdiv(width, 2)
                    end[0] = tvm.tir.min(start[0] + width, num_anchors)
                    with ib.if_scope(middle[0] < num_anchors):
                        # Classic down-sweep swap-and-add: left child gets the
                        # parent's prefix, right child gets prefix + left sum.
                        tmp[0] = valid_indices[by * num_anchors + middle[0] - 1]
                        valid_indices[by * num_anchors + middle[0] - 1] = valid_indices[
                            by * num_anchors + end[0] - 1
                        ]
                        valid_indices[by * num_anchors + end[0] - 1] += tmp[0]
    with ib.else_scope():
        # No anchors: every batch has zero valid boxes.
        with ib.new_scope():
            bx = te.thread_axis("blockIdx.x")
            ib.scope_attr(bx, "thread_extent", batch_size)
            with ib.if_scope(bx < batch_size):
                valid_count[bx] = 0

    return ib.get()


def get_valid_counts_ir(data, valid_indices, valid_boxes, out, out_indices):
"""Low level IR to get valid count of bounding boxes
given a score threshold. Also prepares to move valid boxes to the
Expand Down Expand Up @@ -374,19 +242,8 @@ def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1):
valid_indices_buf = tvm.tir.decl_buffer(
(batch_size, num_anchors), "int32", "valid_indices_buf", data_alignment=8
)
valid_count_buf = tvm.tir.decl_buffer(
(batch_size,), "int32", "valid_count_buf", data_alignment=8
)
valid_count, valid_indices = te.extern(
[(batch_size,), (batch_size, num_anchors)],
[valid_boxes],
lambda ins, outs: get_valid_indices_ir(ins[0], outs[0], outs[1]),
dtype=["int32"],
in_buffers=[valid_boxes_buf],
out_buffers=[valid_count_buf, valid_indices_buf],
name="get_valid_indices",
tag="get_valid_indices_gpu",
)

valid_indices, valid_count = exclusive_scan(valid_boxes, axis=1, return_reduction=True)

out_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "out_buf", data_alignment=8)
out_indices_buf = tvm.tir.decl_buffer(
Expand Down
Loading