From 480787bc072bfc59dcc279038c772f8ad2ec03e9 Mon Sep 17 00:00:00 2001
From: mbrookhart <mbrookhart@octoml.ai>
Date: Thu, 17 Dec 2020 10:01:11 -0700
Subject: [PATCH] Parallelize cumsum in get_valid_counts

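Previously each batch element computed its valid_count and valid_indices
with a serial loop over num_anchors, so only batch_size threads did any
work. Replace that loop with a work-efficient parallel prefix sum over
each batch row: an up sweep builds a reduction tree and a down sweep
propagates the partial sums, finishing in O(log(num_anchors)) kernel
launches.

For reference, a rough single-threaded NumPy sketch of the scan the
kernels implement (the inclusive_scan helper is illustrative only and
is not part of this patch):

    import numpy as np

    def inclusive_scan(row):
        # One batch row of 0/1 validity flags -> running counts.
        out = row.astype(np.int64).copy()
        n = len(out)
        lim = int(np.ceil(np.log2(n)))
        # Up sweep: the last element of every `width`-wide block
        # accumulates the sum of that block.
        for l2_width in range(lim):
            width = 2 << l2_width
            for start in range(0, n, width):
                middle = start + width // 2
                end = min(start + width, n)
                if middle < n:
                    out[end - 1] += out[middle - 1]
        # Down sweep: add the running total at each block boundary
        # into the midpoint of the block that follows it.
        for l2_width in range(lim - 1):
            width = 2 << (lim - l2_width - 2)
            for start in range(width, n, width):
                middle = start + width // 2
                if middle < n:
                    out[middle - 1] += out[start - 1]
        return out

    flags = np.array([1, 0, 1, 1, 0, 1])
    scan = inclusive_scan(flags)  # [1, 1, 2, 3, 3, 4]
    valid_count = scan[-1]  # total number of valid boxes: 4
    # A valid box's output slot is its scan value minus one; invalid
    # boxes are marked -1, matching the final kernel in this patch.
    valid_indices = np.where(flags == 1, scan - 1, -1)  # [0 -1 1 2 -1 3]

Since all targets now produce all three outputs, the OpenCL-specific
early return in test_op_level5.py is removed.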
---
 python/tvm/topi/cuda/nms.py          | 123 ++++++++++++++++++++++++---
 tests/python/relay/test_op_level5.py |   4 +-
 2 files changed, 111 insertions(+), 16 deletions(-)

diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py
index 020cf9b5bc63..6ba1a5704ee2 100644
--- a/python/tvm/topi/cuda/nms.py
+++ b/python/tvm/topi/cuda/nms.py
@@ -151,27 +151,124 @@ def get_valid_indices_ir(valid_boxes, valid_count, valid_indices):
     valid_indices = ib.buffer_ptr(valid_indices)
 
     max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+
+    def ceil_div(a, b):
+        return tvm.tir.indexdiv(a + b - 1, b)
+
+    # Copy boxes to valid_indices
     with ib.new_scope():
         nthread_tx = max_threads
-        nthread_bx = batch_size // max_threads + 1
+        nthread_bx = ceil_div(batch_size * num_anchors, max_threads)
+        tx = te.thread_axis("threadIdx.x")
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(tx, "thread_extent", nthread_tx)
+        ib.scope_attr(bx, "thread_extent", nthread_bx)
+        tid = bx * max_threads + tx
+        with ib.if_scope(tid < batch_size * num_anchors):
+            valid_indices[tid] = valid_boxes[tid]
+
+    nthread_tx = max_threads
+    nthread_bx = ceil_div(num_anchors, max_threads)
+    nthread_by = batch_size
+
+    ## The following kernels perform a parallel inclusive prefix sum over
+    ## each batch row; the running counts are later used to select valid indices
+    # Up Sweep of prefix sum
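+    # Each of the ceil(log2(num_anchors)) up-sweep levels adds the sum of
+    # the first half of every `width`-wide block into the block's last
+    # element, building a reduction tree over each batch row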
+    lim = tvm.tir.generic.cast(
+        tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(num_anchors, "float64"))), "int64"
+    )
+    with ib.for_range(0, lim, dtype="int64") as l2_width:
+        width = 2 << l2_width
+
+        with ib.new_scope():
+            tx = te.thread_axis("threadIdx.x")
+            bx = te.thread_axis("blockIdx.x")
+            ib.scope_attr(tx, "thread_extent", nthread_tx)
+            ib.scope_attr(
+                bx,
+                "thread_extent",
+                tvm.tir.generic.cast(ceil_div(num_anchors, max_threads * width), "int32"),
+            )
+            tid = bx * nthread_tx + tx
+
+            by = te.thread_axis("blockIdx.y")
+            ib.scope_attr(by, "thread_extent", nthread_by)
+            start = ib.allocate("int64", (1,), name="start", scope="local")
+            middle = ib.allocate("int64", (1,), name="middle", scope="local")
+            end = ib.allocate("int64", (1,), name="end", scope="local")
+            start[0] = width * tid
+            with ib.if_scope(start[0] < num_anchors):
+                middle[0] = start[0] + tvm.tir.indexdiv(width, 2)
+                end[0] = tvm.te.min(start[0] + width, num_anchors)
+                with ib.if_scope(middle[0] < num_anchors):
+                    valid_indices[by * num_anchors + end[0] - 1] += valid_indices[
+                        by * num_anchors + middle[0] - 1
+                    ]
+
+    # Down Sweep of prefix sum
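+    # Walk back down the tree, adding the running total at each block
+    # boundary into the midpoint of the block that follows it, which
+    # completes the inclusive scan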
+    with ib.for_range(0, lim - 1, dtype="int64") as l2_width:
+        width = 2 << (lim - l2_width - 2)
+
+        with ib.new_scope():
+            tx = te.thread_axis("threadIdx.x")
+            bx = te.thread_axis("blockIdx.x")
+            ib.scope_attr(tx, "thread_extent", nthread_tx)
+            ib.scope_attr(
+                bx,
+                "thread_extent",
+                tvm.tir.generic.cast(ceil_div(num_anchors, max_threads * width), "int32"),
+            )
+            tid = bx * nthread_tx + tx
+
+            by = te.thread_axis("blockIdx.y")
+            ib.scope_attr(by, "thread_extent", nthread_by)
+            start = ib.allocate("int64", (1,), name="start", scope="local")
+            middle = ib.allocate("int64", (1,), name="middle", scope="local")
+            end = ib.allocate("int64", (1,), name="end", scope="local")
+            start[0] = width * tid
+            with ib.if_scope(tvm.tir.all(start[0] > 0, start[0] < num_anchors)):
+                middle[0] = start[0] + tvm.tir.indexdiv(width, 2)
+                with ib.if_scope(middle[0] < num_anchors):
+                    valid_indices[by * num_anchors + middle[0] - 1] += valid_indices[
+                        by * num_anchors + start[0] - 1
+                    ]
+
+    ## Write the per-row total (the last element of each scan) to valid_count
+    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+    with ib.new_scope():
+        nthread_tx = max_threads
+        nthread_bx = ceil_div(batch_size, max_threads)
         tx = te.thread_axis("threadIdx.x")
         bx = te.thread_axis("blockIdx.x")
         ib.scope_attr(tx, "thread_extent", nthread_tx)
         ib.scope_attr(bx, "thread_extent", nthread_bx)
         tid = bx * max_threads + tx
-        # TODO(mbrookhart): Parallelize the sum and cumsum here
-        current_index = ib.allocate("int32", (1,), name="current_index", scope="local")
         with ib.if_scope(tid < batch_size):
-            current_index[0] = 0
-            valid_count[tid] = 0
-            with ib.for_range(0, num_anchors) as j:
-                idx = tid * num_anchors + j
-                valid_count[tid] = valid_count[tid] + valid_boxes[idx]
-                with ib.if_scope(valid_boxes[idx] == 1):
-                    valid_indices[idx] = current_index[0]
-                    current_index[0] = current_index[0] + 1
-                with ib.else_scope():
-                    valid_indices[idx] = -1
+            valid_count[tid] = valid_indices[tid * num_anchors + num_anchors - 1]
+
+    ## Remove invalid indices
+    with ib.new_scope():
+        nthread_tx = max_threads
+        nthread_bx = ceil_div(batch_size * num_anchors, max_threads)
+        tx = te.thread_axis("threadIdx.x")
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(tx, "thread_extent", nthread_tx)
+        ib.scope_attr(bx, "thread_extent", nthread_bx)
+        tid = bx * max_threads + tx
+        with ib.if_scope(tid < batch_size * num_anchors):
+            with ib.if_scope(valid_boxes[tid] < 1):
+                # if this is an invalid box, mark -1
+                valid_indices[tid] = -1
+            with ib.else_scope():
+                # if this is a valid box, subtract 1 to get 0-based indexing
+                valid_indices[tid] -= 1
+
     return ib.get()
 
 
diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py
index 1ce8a182f034..cdf3b240507b 100644
--- a/tests/python/relay/test_op_level5.py
+++ b/tests/python/relay/test_op_level5.py
@@ -313,10 +313,8 @@ def verify_get_valid_counts(dshape, score_threshold, id_index, score_index):
         for target, ctx in tvm.testing.enabled_targets():
             intrp = relay.create_executor("debug", ctx=ctx, target=target)
             out = intrp.evaluate(func)(np_data)
+
             tvm.testing.assert_allclose(out[0].asnumpy(), np_out1, rtol=1e-3, atol=1e-04)
-            # get_valid_count for opencl doesn't do data rearrangement
-            if target in ["opencl"]:
-                return
             tvm.testing.assert_allclose(out[1].asnumpy(), np_out2, rtol=1e-3, atol=1e-04)
             tvm.testing.assert_allclose(out[2].asnumpy(), np_out3, rtol=1e-3, atol=1e-04)