From eb142d3ee9bb16ddf8d37fdec10c1bcda209deaa Mon Sep 17 00:00:00 2001
From: Masahiro Masuda <masahi129@gmail.com>
Date: Fri, 25 Dec 2020 07:22:00 +0900
Subject: [PATCH] integrate new cumsum change

Rework the CUDA exclusive scan: exclusive_sum_scan2d_ir now does a
Blelloch-style up-sweep/down-sweep over the copied input and can optionally
write out the sum of each row, and scan_thrust / exclusive_scan gain a
return_reduction option so callers get the scan output and the reduction in
one pass.

---
 python/tvm/topi/cuda/__init__.py |   1 +
 python/tvm/topi/cuda/scan.py     | 168 +++++++++++++++++++------------
 2 files changed, 103 insertions(+), 66 deletions(-)

diff --git a/python/tvm/topi/cuda/__init__.py b/python/tvm/topi/cuda/__init__.py
index 23c625ae7ff7..f407e885d3e8 100644
--- a/python/tvm/topi/cuda/__init__.py
+++ b/python/tvm/topi/cuda/__init__.py
@@ -55,3 +55,4 @@
 from .correlation import *
 from .sparse import *
 from .argwhere import *
+from .scan import *
diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py
index 8d058d783b9a..566aad4d9957 100644
--- a/python/tvm/topi/cuda/scan.py
+++ b/python/tvm/topi/cuda/scan.py
@@ -22,24 +22,28 @@
 from ..utils import ceil_div
 
 
-def exclusive_sum_scan2d_ir(data, output):
+def exclusive_sum_scan2d_ir(data, output, reduction=None):
     """
     TODO
     """
-    num_rows = data.shape[0]
-    scan_size = data.shape[1]
+    batch_size = data.shape[0]
+    num_anchors = data.shape[1]
 
     ib = tvm.tir.ir_builder.create()
 
     data = ib.buffer_ptr(data)
     output = ib.buffer_ptr(output)
 
+    if reduction is not None:
+        reduction = ib.buffer_ptr(reduction)
+
     max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
 
+    # Copy the input data to the output buffer
     with ib.new_scope():
         nthread_tx = max_threads
-        nthread_bx = ceil_div(scan_size, max_threads)
-        nthread_by = num_rows
+        nthread_bx = ceil_div(num_anchors, max_threads)
+        nthread_by = batch_size
         tx = te.thread_axis("threadIdx.x")
         bx = te.thread_axis("blockIdx.x")
         by = te.thread_axis("blockIdx.y")
@@ -47,19 +51,18 @@ def exclusive_sum_scan2d_ir(data, output):
         ib.scope_attr(bx, "thread_extent", nthread_bx)
         ib.scope_attr(by, "thread_extent", nthread_by)
         tid = bx * nthread_tx + tx
-        with ib.if_scope(tid == 0):
-            output[by, 0] = 0
-        with ib.else_scope():
-            with ib.if_scope(tid < scan_size):
-                output[by, tid] = data[by, tid - 1]
+        with ib.if_scope(tid < num_anchors):
+            output[by, tid] = data[by, tid]
 
     nthread_tx = max_threads
-    nthread_bx = ceil_div(scan_size, max_threads)
-    nthread_by = num_rows
+    nthread_bx = ceil_div(num_anchors, max_threads)
+    nthread_by = batch_size
 
-    # Up Sweep of prefix sum
+    # The following performs a work-efficient parallel exclusive scan
+    # (an up sweep followed by a down sweep) over each row of the output.
+    # Up Sweep of exclusive scan
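+    # For example, a row [1, 2, 3, 4] holds [1, 3, 3, 10] after the up sweep;
+    # the row total is then saved off and the last element cleared, and the
+    # down sweep below turns the row into the exclusive scan [0, 1, 3, 6].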
     lim = tvm.tir.generic.cast(
-        tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(scan_size, "float64"))), "int64"
+        tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(num_anchors, "float64"))), "int64"
     )
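+    # lim is the number of sweep levels: ceil(log2(num_anchors)).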
     with ib.for_range(0, lim, dtype="int64") as l2_width:
         width = 2 << l2_width
@@ -71,7 +74,7 @@ def exclusive_sum_scan2d_ir(data, output):
             ib.scope_attr(
                 bx,
                 "thread_extent",
-                tvm.tir.generic.cast(ceil_div(scan_size, max_threads * width), "int32"),
+                tvm.tir.generic.cast(ceil_div(num_anchors, max_threads * width), "int32"),
             )
             tid = bx * nthread_tx + tx
 
@@ -81,15 +84,25 @@ def exclusive_sum_scan2d_ir(data, output):
             middle = ib.allocate("int64", (1,), name="middle", scope="local")
             end = ib.allocate("int64", (1,), name="end", scope="local")
             start[0] = width * tid
-            with ib.if_scope(start[0] < scan_size):
+            with ib.if_scope(start[0] < num_anchors):
                 middle[0] = start[0] + tvm.tir.indexdiv(width, 2)
-                end[0] = tvm.te.min(start[0] + width, scan_size)
-                with ib.if_scope(middle[0] < scan_size):
-                    output[by * scan_size + end[0] - 1] += output[by * scan_size + middle[0] - 1]
+                end[0] = tvm.te.min(start[0] + width, num_anchors)
+                with ib.if_scope(middle[0] < num_anchors):
+                    output[by * num_anchors + end[0] - 1] += output[
+                        by * num_anchors + middle[0] - 1
+                    ]
+
+    # Down Sweep of exclusive scan
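+    # Before sweeping down, save each row's total (the last element produced
+    # by the up sweep) into `reduction` if requested, then reset it to zero.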
+    with ib.new_scope():
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(bx, "thread_extent", batch_size)
+        with ib.if_scope(bx < batch_size):
+            if reduction is not None:
+                reduction[bx] = output[(bx + 1) * num_anchors - 1]
+            output[(bx + 1) * num_anchors - 1] = 0
 
-    # Down Sweep of prefix sum
-    with ib.for_range(0, lim - 1, dtype="int64") as l2_width:
-        width = 2 << (lim - l2_width - 2)
+    with ib.for_range(0, lim, dtype="int64") as l2_width:
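+        # Strides shrink from the widest block back down to pairs,
+        # mirroring the up sweep in reverse.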
+        width = 2 << (lim - l2_width - 1)
 
         with ib.new_scope():
             tx = te.thread_axis("threadIdx.x")
@@ -98,7 +111,7 @@ def exclusive_sum_scan2d_ir(data, output):
             ib.scope_attr(
                 bx,
                 "thread_extent",
-                tvm.tir.generic.cast(ceil_div(scan_size, max_threads * width), "int32"),
+                tvm.tir.generic.cast(ceil_div(num_anchors, max_threads * width), "int32"),
             )
             tid = bx * nthread_tx + tx
 
@@ -107,39 +120,19 @@ def exclusive_sum_scan2d_ir(data, output):
             start = ib.allocate("int64", (1,), name="start", scope="local")
             middle = ib.allocate("int64", (1,), name="middle", scope="local")
             end = ib.allocate("int64", (1,), name="end", scope="local")
+            tmp = ib.allocate("int32", (1,), name="tmp", scope="local")
             start[0] = width * tid
-            with ib.if_scope(tvm.tir.all(start[0] > 0, start[0] < scan_size)):
+            with ib.if_scope(start[0] < num_anchors):
                 middle[0] = start[0] + tvm.tir.indexdiv(width, 2)
-                with ib.if_scope(middle[0] < scan_size):
-                    output[by * scan_size + middle[0] - 1] += output[by * scan_size + start[0] - 1]
+                end[0] = tvm.tir.min(start[0] + width, num_anchors)
+                with ib.if_scope(middle[0] < num_anchors):
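+                    # Down-sweep step: the left half takes the value passed
+                    # down from the parent, and the right half gets the
+                    # parent value plus the left half's partial sum.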
+                    tmp[0] = output[by * num_anchors + middle[0] - 1]
+                    output[by * num_anchors + middle[0] - 1] = output[by * num_anchors + end[0] - 1]
+                    output[by * num_anchors + end[0] - 1] += tmp[0]
 
     return ib.get()
 
 
-def is_thrust_available():
-    """
-    Test if thrust based scan ops are available.
-    """
-    return get_global_func("tvm.contrib.thrust.sum_scan", allow_missing=True) is not None
-
-
-def scan_thrust(data, exclusive=True):
-    data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
-    output_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "output_buf", data_alignment=8)
-    return te.extern(
-        [data.shape],
-        [data],
-        lambda ins, outs: tvm.tir.call_packed(
-            "tvm.contrib.thrust.sum_scan", ins[0], outs[0], exclusive
-        ),
-        dtype=[data.dtype],
-        in_buffers=[data_buf],
-        out_buffers=[output_buf],
-        name="exclusive_sum_scan2d",
-        tag="exclusive_sum_scan2d_gpu",
-    )
-
-
 def get_reduction_from_exclusive_scan_ir(data, data_ex_scan, reduction):
     """TODO"""
     batch_size = data.shape[0]
@@ -185,6 +178,42 @@ def get_reduction_from_exclusive_scan(data, ex_scan_output):
     )
 
 
+def is_thrust_available():
+    """
+    Test if thrust based scan ops are available.
+    """
+    return get_global_func("tvm.contrib.thrust.sum_scan", allow_missing=True) is not None
+
+
+def scan_thrust(data, exclusive=True, return_reduction=False):
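+    """Do an exclusive (or inclusive) sum scan using the Thrust sum_scan kernel.
+
+    When `return_reduction` is True, the sum along the scan axis is also
+    returned as a second tensor.
+    """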
+    data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
+    output_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "output_buf", data_alignment=8)
+    output = te.extern(
+        [data.shape],
+        [data],
+        lambda ins, outs: tvm.tir.call_packed(
+            "tvm.contrib.thrust.sum_scan", ins[0], outs[0], exclusive
+        ),
+        dtype=[data.dtype],
+        in_buffers=[data_buf],
+        out_buffers=[output_buf],
+        name="exclusive_sum_scan2d",
+        tag="exclusive_sum_scan2d_gpu",
+    )
+
+    if return_reduction:
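+        # get_reduction_from_exclusive_scan expects 2D input, so a 1D input is
+        # temporarily given a batch axis and the reduction squeezed back below.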
+        ndim = len(data.shape)
+        if ndim == 1:
+            output = expand_dims(output, axis=0)
+            reduction = get_reduction_from_exclusive_scan(data, output)
+            reduction = squeeze(reduction, 0)
+        else:
+            reduction = get_reduction_from_exclusive_scan(data, output)
+        return output, reduction
+
+    return output
+
+
 def exclusive_scan(data, axis=-1, return_reduction=False):
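+    """Do an exclusive sum scan on 1D or 2D input along the last axis.
+
+    When `return_reduction` is True, the sum along the scan axis is also
+    returned.
+    """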
     # TODO(masahi): support other binary associative operators
     ndim = len(data.shape)
@@ -194,17 +223,27 @@ def exclusive_scan(data, axis=-1, return_reduction=False):
 
     target = tvm.target.Target.current()
     if target and target.kind.name == "cuda" and is_thrust_available():
-        output = scan_thrust(data, exclusive=True)
-        if ndim == 1 and return_reduction:
-            output = expand_dims(data, axis=0)
-    else:
-        if ndim == 1:
-            data = expand_dims(data, axis=0)
+        return scan_thrust(data, exclusive=True, return_reduction=return_reduction)
 
-        data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
-        output_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "output_buf", data_alignment=8)
+    if ndim == 1:
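+        # The IR below only handles 2D input, so scan a 1D input as a single
+        # row and squeeze the result back at the end.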
+        data = expand_dims(data, axis=0)
+
+    data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
+    output_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "output_buf", data_alignment=8)
 
-        if ndim == 2:
+    if ndim == 2:
+        if return_reduction:
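+            # This extern op has two outputs, the exclusive scan and the
+            # per-row sums, both written by exclusive_sum_scan2d_ir.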
+            reduction_buf = tvm.tir.decl_buffer(
+                (data.shape[0],), data.dtype, "reduction_buf", data_alignment=8
+            )
+            output, reduction = te.extern(
+                [data.shape, (data.shape[0],)],
+                [data],
+                lambda ins, outs: exclusive_sum_scan2d_ir(ins[0], outs[0], outs[1]),
+                dtype=[data.dtype, data.dtype],
+                in_buffers=[data_buf],
+                out_buffers=[output_buf, reduction_buf],
+                name="exclusive_scan",
+                tag="exclusive_scan_gpu",
+            )
+        else:
             output = te.extern(
                 [data.shape],
                 [data],
@@ -215,19 +254,16 @@ def exclusive_scan(data, axis=-1, return_reduction=False):
                 name="exclusive_scan",
                 tag="exclusive_scan_gpu",
             )
-        else:
-            assert False, "Unsupported dimension {}".format(ndim)
-
-    if return_reduction:
-        reduction = get_reduction_from_exclusive_scan(data, output)
+            reduction = None
+    else:
+        assert False, "Unsupported dimension {}".format(ndim)
 
     if ndim == 1:
         output = squeeze(output, 0)
         if return_reduction:
             reduction = squeeze(reduction, 0)
-            return output, reduction
-        return reduction
 
     if return_reduction:
         return output, reduction
+
     return output