From d269839233620c0e852fce2de359f468210dc0af Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 31 Dec 2020 15:17:34 +0900 Subject: [PATCH 01/15] import changes from scan branch commit cf0d4fdf3bf8fa6e1d6abf631042de28176923c3 Author: Masahiro Masuda Date: Fri Dec 25 10:12:01 2020 +0900 get valid count test working commit eb142d3ee9bb16ddf8d37fdec10c1bcda209deaa Author: Masahiro Masuda Date: Fri Dec 25 07:22:00 2020 +0900 integrate new cumsum change commit f89684d73dad1f863b4fd291e8804b5c24eae94f Author: Masahiro Masuda Date: Fri Dec 25 06:56:46 2020 +0900 remove ceil_div from nms commit a2ad4dea87d9a637745fb0a40ff9bbdde286194a Author: Masahiro Masuda Date: Sun Dec 20 20:36:34 2020 +0900 add api for returning reduction from ex scan output commit b7f4ef7006b722e365533bec53b1f104aa056da2 Author: Masahiro Masuda Date: Sun Dec 20 19:49:07 2020 +0900 move ceil_div to utils commit a9a57e34317b1f254165c3a88e465e33c7fda01b Author: Masahiro Masuda Date: Sun Dec 20 19:38:15 2020 +0900 rename prefix_scan.py to scan.py commit 03ed43ff550a435a28740ce1fa62cea71b90cf2c Author: Masahiro Masuda Date: Sat Dec 19 06:12:55 2020 +0900 surpress cpplint commit abceac980d8dfd94072acc228108d1fcd94a214c Author: masa Date: Fri Dec 18 20:36:24 2020 +0900 support more data type commit 3e7d1f81821a1e221cbb1322ef5b23f273f51c42 Author: masa Date: Fri Dec 18 20:09:51 2020 +0900 1d thrust scan working commit ac13b407e21a83ca57240cad205c32a5d000f999 Author: masa Date: Fri Dec 18 19:49:25 2020 +0900 adding thrust scan support commit 65634e86c33786541485dc6461a96da833332297 Author: masa Date: Fri Dec 18 19:01:11 2020 +0900 add thrust scan python stub commit 9876c901ee8b406bc9d75ba91c4734d55f85811b Author: masa Date: Fri Dec 18 20:55:14 2020 +0900 introduce prefix_scan.py and move scan ir in nms.py commit 667bdd3b135a03b53937fdb664915e07f1365ee1 Author: masa Date: Fri Dec 18 15:06:18 2020 +0900 make the scan loop exclusive commit 480787bc072bfc59dcc279038c772f8ad2ec03e9 Author: mbrookhart Date: Thu Dec 17 10:01:11 2020 -0700 Parallelize cumsum in get_valid_counts --- python/tvm/topi/cuda/__init__.py | 1 + python/tvm/topi/cuda/nms.py | 151 +-------------- python/tvm/topi/cuda/scan.py | 277 +++++++++++++++++++++++++++ python/tvm/topi/cuda/scatter.py | 6 +- python/tvm/topi/cuda/sparse.py | 6 +- python/tvm/topi/utils.py | 4 + src/runtime/contrib/thrust/thrust.cu | 74 +++++++ 7 files changed, 362 insertions(+), 157 deletions(-) create mode 100644 python/tvm/topi/cuda/scan.py diff --git a/python/tvm/topi/cuda/__init__.py b/python/tvm/topi/cuda/__init__.py index 42bf980bec4c..e0ff5a12a9b2 100644 --- a/python/tvm/topi/cuda/__init__.py +++ b/python/tvm/topi/cuda/__init__.py @@ -56,3 +56,4 @@ from .correlation import * from .sparse import * from .argwhere import * +from .scan import * diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 0c01cc9fbbdf..32691da90ecc 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -22,6 +22,8 @@ from tvm.tir import if_then_else from .sort import argsort, argsort_thrust, is_thrust_available +from .scan import exclusive_scan +from ..utils import ceil_div def cuda_atomic_add_rule(op): @@ -51,10 +53,6 @@ def atomic_add(x, y): return tvm.tir.call_intrin(y.dtype, "tir.atomic_add", x, y) -def ceil_div(a, b): - return tvm.tir.indexdiv(a + b - 1, b) - - def get_valid_boxes_ir(data, valid_boxes, score_threshold, id_index, score_index): """Low level IR to identify bounding boxes given a score threshold. 
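For orientation, the heart of patch 01 (mbrookhart's "Parallelize cumsum in get_valid_counts") is replacing a sequential cumsum with a parallel exclusive scan: a box's exclusive-scan value over the 0/1 validity flags is exactly the output slot it is compacted into, and the per-row total is the valid count. A minimal NumPy reference of that idea (names here are illustrative, not TVM APIs):

import numpy as np

def exclusive_scan_ref(flags):
    # out[i] = sum(flags[:i]); the row total is the "reduction" that the
    # GPU kernel can also hand back from its up-sweep at no extra cost.
    return np.cumsum(flags, axis=-1) - flags, flags.sum(axis=-1)

def get_valid_counts_ref(scores, score_threshold):
    flags = (scores > score_threshold).astype("int32")   # valid_boxes
    indices, valid_count = exclusive_scan_ref(flags)     # scatter positions
    return flags, indices, valid_count

scores = np.array([[0.9, 0.1, 0.8, 0.7]], dtype="float32")
flags, idx, cnt = get_valid_counts_ref(scores, 0.5)
# flags -> [[1, 0, 1, 1]], idx -> [[0, 1, 1, 2]], cnt -> [3]:
# valid box j in batch b is written to output position idx[b, j].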
@@ -123,136 +121,6 @@ def get_valid_boxes_ir(data, valid_boxes, score_threshold, id_index, score_index return ib.get() -def get_valid_indices_ir(valid_boxes, valid_count, valid_indices): - """Low level IR to get the ouput indices of valid boxes - and the count of valid boxes - - Parameters - ---------- - valid_boxes: Buffer - 2D Buffer indicating valid boxes with shape [batch_size, num_anchors]. - - Returns - ------- - valid_count: Buffer - 1D Buffer of number of valid boxes per batch [batch_size]. - - valid_indices: Buffer - 2D Buffer indicating output sorted indcies of valid boxes [batch_size, num_anchors]. - """ - batch_size = valid_boxes.shape[0] - num_anchors = valid_boxes.shape[1] - - ib = tvm.tir.ir_builder.create() - - valid_boxes = ib.buffer_ptr(valid_boxes) - - valid_count = ib.buffer_ptr(valid_count) - valid_indices = ib.buffer_ptr(valid_indices) - - max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) - with ib.if_scope(num_anchors > 0): - # Copy boxes to valid_indices - with ib.new_scope(): - nthread_tx = max_threads - nthread_bx = ceil_div(num_anchors, max_threads) - nthread_by = batch_size - tx = te.thread_axis("threadIdx.x") - bx = te.thread_axis("blockIdx.x") - by = te.thread_axis("blockIdx.y") - ib.scope_attr(tx, "thread_extent", nthread_tx) - ib.scope_attr(bx, "thread_extent", nthread_bx) - ib.scope_attr(by, "thread_extent", nthread_by) - tid = bx * nthread_tx + tx - with ib.if_scope(tid < num_anchors): - valid_indices[by, tid] = valid_boxes[by, tid] - - nthread_tx = max_threads - nthread_bx = ceil_div(num_anchors, max_threads) - nthread_by = batch_size - - ## The following algorithm performs parallel exclusive scan to get - ## a tensor that can later be used to select valid indices - # Up Sweep of exclusive scan - lim = tvm.tir.generic.cast( - tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(num_anchors, "float64"))), "int64" - ) - with ib.for_range(0, lim, dtype="int64") as l2_width: - width = 2 << l2_width - - with ib.new_scope(): - tx = te.thread_axis("threadIdx.x") - bx = te.thread_axis("blockIdx.x") - ib.scope_attr(tx, "thread_extent", nthread_tx) - ib.scope_attr( - bx, - "thread_extent", - tvm.tir.generic.cast(ceil_div(num_anchors, max_threads * width), "int32"), - ) - tid = bx * nthread_tx + tx - - by = te.thread_axis("blockIdx.y") - ib.scope_attr(by, "thread_extent", nthread_by) - start = ib.allocate("int64", (1,), name="start", scope="local") - middle = ib.allocate("int64", (1,), name="middle", scope="local") - end = ib.allocate("int64", (1,), name="end", scope="local") - start[0] = width * tid - with ib.if_scope(start[0] < num_anchors): - middle[0] = start[0] + tvm.tir.indexdiv(width, 2) - end[0] = tvm.te.min(start[0] + width, num_anchors) - with ib.if_scope(middle[0] < num_anchors): - valid_indices[by * num_anchors + end[0] - 1] += valid_indices[ - by * num_anchors + middle[0] - 1 - ] - - # Down Sweep of exclusive scan - with ib.new_scope(): - bx = te.thread_axis("blockIdx.x") - ib.scope_attr(bx, "thread_extent", batch_size) - with ib.if_scope(bx < batch_size): - valid_count[bx] = valid_indices[(bx + 1) * num_anchors - 1] - valid_indices[(bx + 1) * num_anchors - 1] = 0 - - with ib.for_range(0, lim, dtype="int64") as l2_width: - width = 2 << (lim - l2_width - 1) - - with ib.new_scope(): - tx = te.thread_axis("threadIdx.x") - bx = te.thread_axis("blockIdx.x") - ib.scope_attr(tx, "thread_extent", nthread_tx) - ib.scope_attr( - bx, - "thread_extent", - tvm.tir.generic.cast(ceil_div(num_anchors, max_threads * width), "int32"), - ) - tid = 
bx * nthread_tx + tx - - by = te.thread_axis("blockIdx.y") - ib.scope_attr(by, "thread_extent", nthread_by) - start = ib.allocate("int64", (1,), name="start", scope="local") - middle = ib.allocate("int64", (1,), name="middle", scope="local") - end = ib.allocate("int64", (1,), name="end", scope="local") - tmp = ib.allocate("int32", (1,), name="end", scope="local") - start[0] = width * tid - with ib.if_scope(tvm.tir.all(start[0] < num_anchors)): - middle[0] = start[0] + tvm.tir.indexdiv(width, 2) - end[0] = tvm.tir.min(start[0] + width, num_anchors) - with ib.if_scope(middle[0] < num_anchors): - tmp[0] = valid_indices[by * num_anchors + middle[0] - 1] - valid_indices[by * num_anchors + middle[0] - 1] = valid_indices[ - by * num_anchors + end[0] - 1 - ] - valid_indices[by * num_anchors + end[0] - 1] += tmp[0] - with ib.else_scope(): - with ib.new_scope(): - bx = te.thread_axis("blockIdx.x") - ib.scope_attr(bx, "thread_extent", batch_size) - with ib.if_scope(bx < batch_size): - valid_count[bx] = 0 - - return ib.get() - - def get_valid_counts_ir(data, valid_indices, valid_boxes, out, out_indices): """Low level IR to get valid count of bounding boxes given a score threshold. Also prepares to move valid boxes to the @@ -374,19 +242,8 @@ def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1): valid_indices_buf = tvm.tir.decl_buffer( (batch_size, num_anchors), "int32", "valid_indices_buf", data_alignment=8 ) - valid_count_buf = tvm.tir.decl_buffer( - (batch_size,), "int32", "valid_count_buf", data_alignment=8 - ) - valid_count, valid_indices = te.extern( - [(batch_size,), (batch_size, num_anchors)], - [valid_boxes], - lambda ins, outs: get_valid_indices_ir(ins[0], outs[0], outs[1]), - dtype=["int32"], - in_buffers=[valid_boxes_buf], - out_buffers=[valid_count_buf, valid_indices_buf], - name="get_valid_indices", - tag="get_valid_indices_gpu", - ) + + valid_indices, valid_count = exclusive_scan(valid_boxes, axis=1, return_reduction=True) out_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "out_buf", data_alignment=8) out_indices_buf = tvm.tir.decl_buffer( diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py new file mode 100644 index 000000000000..0d1459f043cb --- /dev/null +++ b/python/tvm/topi/cuda/scan.py @@ -0,0 +1,277 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"Scan related operators" +import tvm +from tvm import te +from tvm._ffi import get_global_func +from ..transform import expand_dims, squeeze +from ..utils import ceil_div + + +def exclusive_sum_scan2d_ir(data, output, reduction=None): + """ + TODO + """ + batch_size = data.shape[0] + num_anchors = data.shape[1] + + ib = tvm.tir.ir_builder.create() + + data = ib.buffer_ptr(data) + output = ib.buffer_ptr(output) + + if reduction is not None: + reduction = ib.buffer_ptr(reduction) + + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) + + # Copy boxes to output + with ib.if_scope(num_anchors > 0): + with ib.new_scope(): + nthread_tx = max_threads + nthread_bx = ceil_div(num_anchors, max_threads) + nthread_by = batch_size + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + by = te.thread_axis("blockIdx.y") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + ib.scope_attr(by, "thread_extent", nthread_by) + tid = bx * nthread_tx + tx + with ib.if_scope(tid < num_anchors): + output[by, tid] = data[by, tid] + + nthread_tx = max_threads + nthread_bx = ceil_div(num_anchors, max_threads) + nthread_by = batch_size + + ## The following algorithm performs parallel exclusive scan to get + ## a tensor that can later be used to select valid indices + # Up Sweep of exclusive scan + lim = tvm.tir.generic.cast( + tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(num_anchors, "float64"))), "int64" + ) + with ib.for_range(0, lim, dtype="int64") as l2_width: + width = 2 << l2_width + + with ib.new_scope(): + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr( + bx, + "thread_extent", + tvm.tir.generic.cast(ceil_div(num_anchors, max_threads * width), "int32"), + ) + tid = bx * nthread_tx + tx + + by = te.thread_axis("blockIdx.y") + ib.scope_attr(by, "thread_extent", nthread_by) + start = ib.allocate("int64", (1,), name="start", scope="local") + middle = ib.allocate("int64", (1,), name="middle", scope="local") + end = ib.allocate("int64", (1,), name="end", scope="local") + start[0] = width * tid + with ib.if_scope(start[0] < num_anchors): + middle[0] = start[0] + tvm.tir.indexdiv(width, 2) + end[0] = tvm.te.min(start[0] + width, num_anchors) + with ib.if_scope(middle[0] < num_anchors): + output[by * num_anchors + end[0] - 1] += output[ + by * num_anchors + middle[0] - 1 + ] + + # Down Sweep of exclusive scan + with ib.new_scope(): + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(bx, "thread_extent", batch_size) + with ib.if_scope(bx < batch_size): + if reduction is not None: + reduction[bx] = output[(bx + 1) * num_anchors - 1] + output[(bx + 1) * num_anchors - 1] = 0 + + with ib.for_range(0, lim, dtype="int64") as l2_width: + width = 2 << (lim - l2_width - 1) + + with ib.new_scope(): + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr( + bx, + "thread_extent", + tvm.tir.generic.cast(ceil_div(num_anchors, max_threads * width), "int32"), + ) + tid = bx * nthread_tx + tx + + by = te.thread_axis("blockIdx.y") + ib.scope_attr(by, "thread_extent", nthread_by) + start = ib.allocate("int64", (1,), name="start", scope="local") + middle = ib.allocate("int64", (1,), name="middle", scope="local") + end = ib.allocate("int64", (1,), name="end", scope="local") + tmp = ib.allocate("int32", (1,), name="end", scope="local") + start[0] = width * tid + with 
ib.if_scope(tvm.tir.all(start[0] < num_anchors)): + middle[0] = start[0] + tvm.tir.indexdiv(width, 2) + end[0] = tvm.tir.min(start[0] + width, num_anchors) + with ib.if_scope(middle[0] < num_anchors): + tmp[0] = output[by * num_anchors + middle[0] - 1] + output[by * num_anchors + middle[0] - 1] = output[by * num_anchors + end[0] - 1] + output[by * num_anchors + end[0] - 1] += tmp[0] + with ib.else_scope(): + with ib.new_scope(): + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(bx, "thread_extent", batch_size) + with ib.if_scope(bx < batch_size): + if reduction is not None: + reduction[bx] = 0 + + + return ib.get() + + +def get_reduction_from_exclusive_scan_ir(data, data_ex_scan, reduction): + """TODO""" + batch_size = data.shape[0] + num_anchors = data.shape[1] + + ib = tvm.tir.ir_builder.create() + + data = ib.buffer_ptr(data) + data_ex_scan = ib.buffer_ptr(data_ex_scan) + reduction = ib.buffer_ptr(reduction) + + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) + with ib.new_scope(): + nthread_tx = max_threads + nthread_bx = ceil_div(batch_size, max_threads) + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * max_threads + tx + with ib.if_scope(tid < batch_size): + reduction[tid] = data_ex_scan[tid, num_anchors - 1] + data[tid, num_anchors - 1] + + return ib.get() + + +def get_reduction_from_exclusive_scan(data, ex_scan_output): + """TODO""" + assert len(data.shape) == 2, "Only 2D input supported for now" + data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "valid_indices_buf", data_alignment=8) + ex_scan_output_buf = tvm.tir.decl_buffer( + ex_scan_output.shape, ex_scan_output.dtype, "ex_scan_output_buf", data_alignment=8 + ) + + return te.extern( + [(data.shape[0],)], + [data, ex_scan_output], + lambda ins, outs: get_reduction_from_exclusive_scan_ir(ins[0], ins[1], outs[0]), + dtype=[ex_scan_output.dtype], + in_buffers=[data_buf, ex_scan_output_buf], + name="ex_scan_reduction", + tag="ex_scan_reduction_gpu", + ) + + +def is_thrust_available(): + """ + Test if thrust based scan ops are available. + """ + return get_global_func("tvm.contrib.thrust.sum_scan", allow_missing=True) is not None + + +def scan_thrust(data, exclusive=True, return_reduction=False): + data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) + output_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "output_buf", data_alignment=8) + output = te.extern( + [data.shape], + [data], + lambda ins, outs: tvm.tir.call_packed( + "tvm.contrib.thrust.sum_scan", ins[0], outs[0], exclusive + ), + dtype=[data.dtype], + in_buffers=[data_buf], + out_buffers=[output_buf], + name="exclusive_sum_scan2d", + tag="exclusive_sum_scan2d_gpu", + ) + + if return_reduction: + ndim = len(data.shape) + if ndim == 1: + output = expand_dims(output, axis=0) + reduction = get_reduction_from_exclusive_scan(data, output) + reduction = squeeze(reduction, 0) + else: + reduction = get_reduction_from_exclusive_scan(data, output) + return output, reduction + + return output + + +def exclusive_scan(data, axis=-1, return_reduction=False): + # TODO(masahi): support other binary associative operators + ndim = len(data.shape) + if axis < 0: + axis += ndim + assert axis == ndim - 1, "Only support scan on the inner most axis." 
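The TIR above implements the classic work-efficient (Blelloch) scan: an up-sweep that builds partial sums over an implicit binary tree, a step that records the row total and zeroes the last element, and a down-sweep that pushes prefixes back down. A plain-Python model of the same passes for one row (exposition only; the min() clamps handle lengths that are not powers of two, as in the IR):

import math

def blelloch_exclusive_scan_ref(row):
    n = len(row)
    out = list(row)
    if n == 0:
        return out, 0  # empty rows: nothing to scan
    lim = math.ceil(math.log2(n)) if n > 1 else 0
    for l2_width in range(lim):                 # up-sweep
        width = 2 << l2_width
        for start in range(0, n, width):
            middle, end = start + width // 2, min(start + width, n)
            if middle < n:
                out[end - 1] += out[middle - 1]
    reduction = out[-1]                         # row total (the optional output)
    out[-1] = 0                                 # seed the exclusive scan
    for l2_width in reversed(range(lim)):       # down-sweep
        width = 2 << l2_width
        for start in range(0, n, width):
            middle, end = start + width // 2, min(start + width, n)
            if middle < n:
                out[middle - 1], out[end - 1] = out[end - 1], out[end - 1] + out[middle - 1]
    return out, reduction

assert blelloch_exclusive_scan_ref([1, 0, 1, 1]) == ([0, 1, 1, 2], 3)

Each iteration of the for_range loops above opens a new scope, i.e. one kernel launch per tree level, so a row of length n is scanned in about 2 * log2(n) parallel passes rather than one sequential pass per row.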
+ + target = tvm.target.Target.current() + if target and target.kind.name == "cuda" and is_thrust_available(): + return scan_thrust(data, exclusive=True, return_reduction=return_reduction) + + if ndim == 1: + data = expand_dims(data, axis=0) + + data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) + output_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "output_buf", data_alignment=8) + + if ndim == 2: + if return_reduction: + output, reduction = te.extern( + [data.shape, (data.shape[0],)], + [data], + lambda ins, outs: exclusive_sum_scan2d_ir(ins[0], outs[0], outs[1]), + dtype=[data.dtype, data.dtype], + in_buffers=[data_buf], + name="exclusive_scan", + tag="exclusive_scan_gpu", + ) + else: + output = te.extern( + [data.shape], + [data], + lambda ins, outs: exclusive_sum_scan2d_ir(ins[0], outs[0]), + dtype=[data.dtype], + in_buffers=[data_buf], + out_buffers=[output_buf], + name="exclusive_scan", + tag="exclusive_scan_gpu", + ) + reduction = None + else: + assert False, "Unsupported dimension {}".format(ndim) + + if ndim == 1: + output = squeeze(output, 0) + if return_reduction: + reduction = squeeze(reduction, 0) + + if return_reduction: + return output, reduction + + return output diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py index b34bd1df14e4..444fb25cc34b 100644 --- a/python/tvm/topi/cuda/scatter.py +++ b/python/tvm/topi/cuda/scatter.py @@ -22,11 +22,7 @@ from ..generic import schedule_extern from .nms import atomic_add from .sort import stable_sort_by_key_thrust, is_thrust_available -from ..utils import prod - - -def ceil_div(a, b): - return (a + b - 1) // b +from ..utils import prod, ceil_div def _memcpy_ir(ib, out_ptr, data_ptr, shape): diff --git a/python/tvm/topi/cuda/sparse.py b/python/tvm/topi/cuda/sparse.py index cb61d9686919..3e7606556d9c 100644 --- a/python/tvm/topi/cuda/sparse.py +++ b/python/tvm/topi/cuda/sparse.py @@ -23,7 +23,7 @@ from tvm import relay, te from .. import nn -from ..utils import traverse_inline, get_const_tuple, prod, get_const_int +from ..utils import traverse_inline, get_const_tuple, prod, get_const_int, ceil_div def sparse_dense(data, weight_data, weight_indices, weight_indptr, sparse_lhs=False): @@ -161,10 +161,6 @@ def sparse_dense_tir(data, w_data, w_indices, w_indptr): with either default_function_kernel0 for the transpose or default_function_kernel1 for the multiply. """ - - def ceil_div(a, b): - return (a + (b - 1)) // b - def gen_ir(data, w_data, w_indices, w_indptr, out): # pylint: disable=invalid-name # TODO(tkonolige): use tensorcores for block multiply diff --git a/python/tvm/topi/utils.py b/python/tvm/topi/utils.py index c3e14eff3919..769aa436fb08 100644 --- a/python/tvm/topi/utils.py +++ b/python/tvm/topi/utils.py @@ -487,3 +487,7 @@ def is_empty_shape(shape): Whether input shape is empty or has dimesion with size 0. 
""" return cpp.utils.is_empty_shape(shape) + + +def ceil_div(a, b): + return (a + b - 1) // b diff --git a/src/runtime/contrib/thrust/thrust.cu b/src/runtime/contrib/thrust/thrust.cu index 6a48f1ad876a..fe6cc43fefdc 100644 --- a/src/runtime/contrib/thrust/thrust.cu +++ b/src/runtime/contrib/thrust/thrust.cu @@ -25,6 +25,7 @@ #include #include #include +#include #include #include @@ -264,5 +265,78 @@ TVM_REGISTER_GLOBAL("tvm.contrib.thrust.stable_sort_by_key") } }); +template +void thrust_scan(DLTensor* data, + DLTensor* output, + bool exclusive) { + thrust::device_ptr data_ptr(static_cast(data->data)); + thrust::device_ptr output_ptr(static_cast(output->data)); + const auto scan_size = data->shape[data->ndim - 1]; + + if (data->ndim == 1 || (data->ndim == 2 && data->shape[0] == 1)) { + if (exclusive) { + thrust::exclusive_scan(data_ptr, data_ptr + scan_size, output_ptr); + } else { + thrust::inclusive_scan(data_ptr, data_ptr + scan_size, output_ptr); + } + } else { + // Use thrust segmented scan to compute scan on the inner most axis + // data->shape[0] * data->shape[1] * ... * data->shape[ndim - 2] scans are + // computed in parallel + + // This is for constructing a sequence 0, 0, 0,...,1, 1, 1,...,2, 2, 2,..., + // without materializing the sequence vector + auto counting_iter = thrust::counting_iterator(0); + // Without __host__ annotation, cub crashes + auto linear_index_to_scan_key = [scan_size] __host__ __device__(int64_t i) { + return i / scan_size; + }; // NOLINT(*) + auto key_iter = thrust::make_transform_iterator(counting_iter, linear_index_to_scan_key); + int64_t size = 1; + for (int i = 0; i < data->ndim; ++i) size *= data->shape[i]; + + if (exclusive) { + thrust::exclusive_scan_by_key(key_iter, key_iter + size, data_ptr, output_ptr); + } else { + thrust::inclusive_scan_by_key(key_iter, key_iter + size, data_ptr, output_ptr); + } + } +} + +TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sum_scan") +.set_body([](TVMArgs args, TVMRetValue* ret) { + ICHECK_EQ(args.num_args, 3); + DLTensor* data = args[0]; + DLTensor* output = args[1]; + bool exclusive = args[2]; + + auto in_dtype = DLDataType2String(data->dtype); + auto out_dtype = DLDataType2String(output->dtype); + + if (in_dtype == "int32") { + if (out_dtype == "int32") { + thrust_scan(data, output, exclusive); + } else if (out_dtype == "int64") { + thrust_scan(data, output, exclusive); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } + } else if (in_dtype == "int64") { + if (out_dtype == "int64") { + thrust_scan(data, output, exclusive); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } + } else if (in_dtype == "float32") { + if (out_dtype == "float32") { + thrust_scan(data, output, exclusive); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } + } else { + LOG(FATAL) << "Unsupported input dtype: " << in_dtype; + } +}); + } // namespace contrib } // namespace tvm From a7772a5d965a731207b9a2581fd66cadda53f79e Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 31 Dec 2020 15:25:51 +0900 Subject: [PATCH 02/15] fix for 1d scan --- python/tvm/topi/cuda/scan.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py index 0d1459f043cb..380593a6d5b4 100644 --- a/python/tvm/topi/cuda/scan.py +++ b/python/tvm/topi/cuda/scan.py @@ -223,8 +223,9 @@ def scan_thrust(data, exclusive=True, return_reduction=False): return output -def exclusive_scan(data, axis=-1, return_reduction=False): +def 
exclusive_scan(data, axis=-1, return_reduction=False, output_dtype=None): # TODO(masahi): support other binary associative operators + # TODO: handle output_dtype ndim = len(data.shape) if axis < 0: axis += ndim @@ -240,7 +241,7 @@ def exclusive_scan(data, axis=-1, return_reduction=False): data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) output_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "output_buf", data_alignment=8) - if ndim == 2: + if len(data.shape) == 2: if return_reduction: output, reduction = te.extern( [data.shape, (data.shape[0],)], From 7608e10f4a46d57243b2d9d114a760ae7ad2d391 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 11 Jan 2021 05:58:57 +0900 Subject: [PATCH 03/15] rename --- python/tvm/topi/cuda/scan.py | 48 +++++++++++++++++------------------- 1 file changed, 22 insertions(+), 26 deletions(-) diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py index 380593a6d5b4..da45bee10e82 100644 --- a/python/tvm/topi/cuda/scan.py +++ b/python/tvm/topi/cuda/scan.py @@ -27,7 +27,7 @@ def exclusive_sum_scan2d_ir(data, output, reduction=None): TODO """ batch_size = data.shape[0] - num_anchors = data.shape[1] + scan_axis_size = data.shape[1] ib = tvm.tir.ir_builder.create() @@ -39,11 +39,10 @@ def exclusive_sum_scan2d_ir(data, output, reduction=None): max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) - # Copy boxes to output - with ib.if_scope(num_anchors > 0): + with ib.if_scope(scan_axis_size > 0): with ib.new_scope(): nthread_tx = max_threads - nthread_bx = ceil_div(num_anchors, max_threads) + nthread_bx = ceil_div(scan_axis_size, max_threads) nthread_by = batch_size tx = te.thread_axis("threadIdx.x") bx = te.thread_axis("blockIdx.x") @@ -52,18 +51,17 @@ def exclusive_sum_scan2d_ir(data, output, reduction=None): ib.scope_attr(bx, "thread_extent", nthread_bx) ib.scope_attr(by, "thread_extent", nthread_by) tid = bx * nthread_tx + tx - with ib.if_scope(tid < num_anchors): + with ib.if_scope(tid < scan_axis_size): output[by, tid] = data[by, tid] nthread_tx = max_threads - nthread_bx = ceil_div(num_anchors, max_threads) + nthread_bx = ceil_div(scan_axis_size, max_threads) nthread_by = batch_size - ## The following algorithm performs parallel exclusive scan to get - ## a tensor that can later be used to select valid indices + # The following algorithm performs parallel exclusive scan # Up Sweep of exclusive scan lim = tvm.tir.generic.cast( - tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(num_anchors, "float64"))), "int64" + tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(scan_axis_size, "float64"))), "int64" ) with ib.for_range(0, lim, dtype="int64") as l2_width: width = 2 << l2_width @@ -75,7 +73,7 @@ def exclusive_sum_scan2d_ir(data, output, reduction=None): ib.scope_attr( bx, "thread_extent", - tvm.tir.generic.cast(ceil_div(num_anchors, max_threads * width), "int32"), + tvm.tir.generic.cast(ceil_div(scan_axis_size, max_threads * width), "int32"), ) tid = bx * nthread_tx + tx @@ -85,12 +83,12 @@ def exclusive_sum_scan2d_ir(data, output, reduction=None): middle = ib.allocate("int64", (1,), name="middle", scope="local") end = ib.allocate("int64", (1,), name="end", scope="local") start[0] = width * tid - with ib.if_scope(start[0] < num_anchors): + with ib.if_scope(start[0] < scan_axis_size): middle[0] = start[0] + tvm.tir.indexdiv(width, 2) - end[0] = tvm.te.min(start[0] + width, num_anchors) - with ib.if_scope(middle[0] < num_anchors): - output[by * num_anchors + end[0] - 1] += output[ - 
by * num_anchors + middle[0] - 1 + end[0] = tvm.te.min(start[0] + width, scan_axis_size) + with ib.if_scope(middle[0] < scan_axis_size): + output[by * scan_axis_size + end[0] - 1] += output[ + by * scan_axis_size + middle[0] - 1 ] # Down Sweep of exclusive scan @@ -99,8 +97,8 @@ def exclusive_sum_scan2d_ir(data, output, reduction=None): ib.scope_attr(bx, "thread_extent", batch_size) with ib.if_scope(bx < batch_size): if reduction is not None: - reduction[bx] = output[(bx + 1) * num_anchors - 1] - output[(bx + 1) * num_anchors - 1] = 0 + reduction[bx] = output[(bx + 1) * scan_axis_size - 1] + output[(bx + 1) * scan_axis_size - 1] = 0 with ib.for_range(0, lim, dtype="int64") as l2_width: width = 2 << (lim - l2_width - 1) @@ -112,7 +110,7 @@ def exclusive_sum_scan2d_ir(data, output, reduction=None): ib.scope_attr( bx, "thread_extent", - tvm.tir.generic.cast(ceil_div(num_anchors, max_threads * width), "int32"), + tvm.tir.generic.cast(ceil_div(scan_axis_size, max_threads * width), "int32"), ) tid = bx * nthread_tx + tx @@ -123,13 +121,13 @@ def exclusive_sum_scan2d_ir(data, output, reduction=None): end = ib.allocate("int64", (1,), name="end", scope="local") tmp = ib.allocate("int32", (1,), name="end", scope="local") start[0] = width * tid - with ib.if_scope(tvm.tir.all(start[0] < num_anchors)): + with ib.if_scope(tvm.tir.all(start[0] < scan_axis_size)): middle[0] = start[0] + tvm.tir.indexdiv(width, 2) - end[0] = tvm.tir.min(start[0] + width, num_anchors) - with ib.if_scope(middle[0] < num_anchors): - tmp[0] = output[by * num_anchors + middle[0] - 1] - output[by * num_anchors + middle[0] - 1] = output[by * num_anchors + end[0] - 1] - output[by * num_anchors + end[0] - 1] += tmp[0] + end[0] = tvm.tir.min(start[0] + width, scan_axis_size) + with ib.if_scope(middle[0] < scan_axis_size): + tmp[0] = output[by * scan_axis_size + middle[0] - 1] + output[by * scan_axis_size + middle[0] - 1] = output[by * scan_axis_size + end[0] - 1] + output[by * scan_axis_size + end[0] - 1] += tmp[0] with ib.else_scope(): with ib.new_scope(): bx = te.thread_axis("blockIdx.x") @@ -137,8 +135,6 @@ def exclusive_sum_scan2d_ir(data, output, reduction=None): with ib.if_scope(bx < batch_size): if reduction is not None: reduction[bx] = 0 - - return ib.get() From b72840e52152a6172b3b6b7ed33bcde6f923a1a8 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 11 Jan 2021 06:11:54 +0900 Subject: [PATCH 04/15] cast to out dtype --- python/tvm/topi/cuda/scan.py | 33 +++++++++++++++++++++------------ 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py index da45bee10e82..c48495183da5 100644 --- a/python/tvm/topi/cuda/scan.py +++ b/python/tvm/topi/cuda/scan.py @@ -20,6 +20,7 @@ from tvm._ffi import get_global_func from ..transform import expand_dims, squeeze from ..utils import ceil_div +from ..math import cast def exclusive_sum_scan2d_ir(data, output, reduction=None): @@ -34,6 +35,8 @@ def exclusive_sum_scan2d_ir(data, output, reduction=None): data = ib.buffer_ptr(data) output = ib.buffer_ptr(output) + out_dtype = output.dtype + if reduction is not None: reduction = ib.buffer_ptr(reduction) @@ -98,7 +101,7 @@ def exclusive_sum_scan2d_ir(data, output, reduction=None): with ib.if_scope(bx < batch_size): if reduction is not None: reduction[bx] = output[(bx + 1) * scan_axis_size - 1] - output[(bx + 1) * scan_axis_size - 1] = 0 + output[(bx + 1) * scan_axis_size - 1] = cast(0, out_dtype) with ib.for_range(0, lim, dtype="int64") as l2_width: width = 2 << (lim - 
l2_width - 1) @@ -119,15 +122,18 @@ def exclusive_sum_scan2d_ir(data, output, reduction=None): start = ib.allocate("int64", (1,), name="start", scope="local") middle = ib.allocate("int64", (1,), name="middle", scope="local") end = ib.allocate("int64", (1,), name="end", scope="local") - tmp = ib.allocate("int32", (1,), name="end", scope="local") + tmp = ib.allocate(out_dtype, (1,), name="end", scope="local") start[0] = width * tid with ib.if_scope(tvm.tir.all(start[0] < scan_axis_size)): middle[0] = start[0] + tvm.tir.indexdiv(width, 2) end[0] = tvm.tir.min(start[0] + width, scan_axis_size) with ib.if_scope(middle[0] < scan_axis_size): tmp[0] = output[by * scan_axis_size + middle[0] - 1] - output[by * scan_axis_size + middle[0] - 1] = output[by * scan_axis_size + end[0] - 1] + output[by * scan_axis_size + middle[0] - 1] = output[ + by * scan_axis_size + end[0] - 1 + ] output[by * scan_axis_size + end[0] - 1] += tmp[0] + with ib.else_scope(): with ib.new_scope(): bx = te.thread_axis("blockIdx.x") @@ -190,16 +196,17 @@ def is_thrust_available(): return get_global_func("tvm.contrib.thrust.sum_scan", allow_missing=True) is not None -def scan_thrust(data, exclusive=True, return_reduction=False): +def scan_thrust(data, output_dtype, exclusive=True, return_reduction=False): + """TODO""" data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) - output_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "output_buf", data_alignment=8) + output_buf = tvm.tir.decl_buffer(data.shape, output_dtype, "output_buf", data_alignment=8) output = te.extern( [data.shape], [data], lambda ins, outs: tvm.tir.call_packed( "tvm.contrib.thrust.sum_scan", ins[0], outs[0], exclusive ), - dtype=[data.dtype], + dtype=[output_dtype], in_buffers=[data_buf], out_buffers=[output_buf], name="exclusive_sum_scan2d", @@ -220,22 +227,24 @@ def scan_thrust(data, exclusive=True, return_reduction=False): def exclusive_scan(data, axis=-1, return_reduction=False, output_dtype=None): - # TODO(masahi): support other binary associative operators - # TODO: handle output_dtype + """TODO""" ndim = len(data.shape) if axis < 0: axis += ndim assert axis == ndim - 1, "Only support scan on the inner most axis." 
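With this patch the scan can accumulate into a wider dtype than its input. A usage sketch (shapes and dtypes here are illustrative assumptions; it presumes a Thrust-enabled CUDA build, as in the tests later in this series, and the schedule_scan helper that a later patch adds):

import numpy as np
import tvm
import tvm.testing
from tvm import te
from tvm.topi.cuda.scan import exclusive_scan, schedule_scan

# illustrative shapes/dtypes; requires a Thrust-enabled CUDA build
values = te.placeholder((4, 1024), name="values", dtype="int32")
with tvm.target.Target("cuda"):
    # int32 inputs, int64 running sums: no overflow on long rows
    scan, row_sums = exclusive_scan(values, return_reduction=True, output_dtype="int64")
    s = schedule_scan([scan, row_sums])
f = tvm.build(s, [values, scan, row_sums], "cuda")

ctx = tvm.gpu(0)
values_np = np.random.randint(0, 10, size=(4, 1024)).astype("int32")
out = tvm.nd.array(np.zeros((4, 1024), "int64"), ctx)
sums = tvm.nd.array(np.zeros((4,), "int64"), ctx)
f(tvm.nd.array(values_np, ctx), out, sums)

ref = np.cumsum(values_np, axis=-1, dtype="int64") - values_np
tvm.testing.assert_allclose(out.asnumpy(), ref)
tvm.testing.assert_allclose(sums.asnumpy(), values_np.sum(axis=-1))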
+ if output_dtype is None: + output_dtype = data.dtype + target = tvm.target.Target.current() if target and target.kind.name == "cuda" and is_thrust_available(): - return scan_thrust(data, exclusive=True, return_reduction=return_reduction) + return scan_thrust(data, output_dtype, exclusive=True, return_reduction=return_reduction) if ndim == 1: data = expand_dims(data, axis=0) data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) - output_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "output_buf", data_alignment=8) + output_buf = tvm.tir.decl_buffer(data.shape, output_dtype, "output_buf", data_alignment=8) if len(data.shape) == 2: if return_reduction: @@ -243,7 +252,7 @@ def exclusive_scan(data, axis=-1, return_reduction=False, output_dtype=None): [data.shape, (data.shape[0],)], [data], lambda ins, outs: exclusive_sum_scan2d_ir(ins[0], outs[0], outs[1]), - dtype=[data.dtype, data.dtype], + dtype=[data.dtype, output_dtype], in_buffers=[data_buf], name="exclusive_scan", tag="exclusive_scan_gpu", @@ -253,7 +262,7 @@ def exclusive_scan(data, axis=-1, return_reduction=False, output_dtype=None): [data.shape], [data], lambda ins, outs: exclusive_sum_scan2d_ir(ins[0], outs[0]), - dtype=[data.dtype], + dtype=[output_dtype], in_buffers=[data_buf], out_buffers=[output_buf], name="exclusive_scan", From f2667e397f4beb293911d49bb57d84769eddab01 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 11 Jan 2021 06:19:40 +0900 Subject: [PATCH 05/15] do not run return reduction for inclusive scan --- python/tvm/topi/cuda/scan.py | 1 + python/tvm/topi/utils.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py index c48495183da5..e251f2993060 100644 --- a/python/tvm/topi/cuda/scan.py +++ b/python/tvm/topi/cuda/scan.py @@ -214,6 +214,7 @@ def scan_thrust(data, output_dtype, exclusive=True, return_reduction=False): ) if return_reduction: + assert exclusive, "return_reduction should be False for inclusive scan" ndim = len(data.shape) if ndim == 1: output = expand_dims(output, axis=0) diff --git a/python/tvm/topi/utils.py b/python/tvm/topi/utils.py index 769aa436fb08..c5acf384faf6 100644 --- a/python/tvm/topi/utils.py +++ b/python/tvm/topi/utils.py @@ -490,4 +490,5 @@ def is_empty_shape(shape): def ceil_div(a, b): - return (a + b - 1) // b + """Return ceil division of a by b""" + return (a + (b - 1)) // b From b1bbedff7c5abbcdb09e3ec722cbf919cb3c4e50 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 11 Jan 2021 06:31:25 +0900 Subject: [PATCH 06/15] remove another ceil_div definition --- python/tvm/topi/cuda/sort.py | 5 +---- python/tvm/topi/utils.py | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py index 9b6a18a8b06b..18340385205e 100644 --- a/python/tvm/topi/cuda/sort.py +++ b/python/tvm/topi/cuda/sort.py @@ -23,6 +23,7 @@ from .injective import schedule_injective_from_existing from ..transform import strided_slice, transpose from .. 
import tag +from ..utils import ceil_div def swap(arr, axis): @@ -61,10 +62,6 @@ def traverse(op): return s -def ceil_div(a, b): - return tvm.tir.indexdiv(a + b - 1, b) - - def _sort_init(ib, shape, axis, keys_in, keys_out, values_out=None, value_init_func=None): """Initialize the output buffers by copying from inputs""" axis_mul_before = 1 diff --git a/python/tvm/topi/utils.py b/python/tvm/topi/utils.py index c5acf384faf6..dfc226f0c331 100644 --- a/python/tvm/topi/utils.py +++ b/python/tvm/topi/utils.py @@ -491,4 +491,4 @@ def is_empty_shape(shape): def ceil_div(a, b): """Return ceil division of a by b""" - return (a + (b - 1)) // b + return tvm.tir.indexdiv(a + (b - 1), b) From 9783462927d2f3d92e8e7899bf761b4a8946f8bc Mon Sep 17 00:00:00 2001 From: masa Date: Sun, 17 Jan 2021 23:16:02 +0900 Subject: [PATCH 07/15] adding scan test --- python/tvm/topi/cuda/__init__.py | 1 - tests/python/contrib/test_sort.py | 35 +----------- tests/python/contrib/test_thrust.py | 84 +++++++++++++++++++++++++++++ 3 files changed, 85 insertions(+), 35 deletions(-) create mode 100644 tests/python/contrib/test_thrust.py diff --git a/python/tvm/topi/cuda/__init__.py b/python/tvm/topi/cuda/__init__.py index e0ff5a12a9b2..42bf980bec4c 100644 --- a/python/tvm/topi/cuda/__init__.py +++ b/python/tvm/topi/cuda/__init__.py @@ -56,4 +56,3 @@ from .correlation import * from .sparse import * from .argwhere import * -from .scan import * diff --git a/tests/python/contrib/test_sort.py b/tests/python/contrib/test_sort.py index f338276ca118..a049602ac265 100644 --- a/tests/python/contrib/test_sort.py +++ b/tests/python/contrib/test_sort.py @@ -17,7 +17,7 @@ import tvm import tvm.testing from tvm import te -from tvm.topi.cuda import stable_sort_by_key_thrust, is_thrust_available, sort_by_key +from tvm.topi.cuda import sort_by_key import numpy as np @@ -91,38 +91,6 @@ def test_sort_np(): tvm.testing.assert_allclose(c.asnumpy(), np_out, rtol=1e-5) -def test_thrust_stable_sort_by_key(): - if not is_thrust_available(): - print("skip because thrust is not enabled...") - return - - size = 6 - keys = te.placeholder((size,), name="keys", dtype="int32") - values = te.placeholder((size,), name="values", dtype="int32") - - keys_out, values_out = stable_sort_by_key_thrust(keys, values) - - ctx = tvm.gpu(0) - target = "cuda" - s = te.create_schedule([keys_out.op, values_out.op]) - f = tvm.build(s, [keys, values, keys_out, values_out], target) - - keys_np = np.array([1, 4, 2, 8, 2, 7], np.int32) - values_np = np.random.randint(0, 10, size=(size,)).astype(np.int32) - keys_np_out = np.zeros(keys_np.shape, np.int32) - values_np_out = np.zeros(values_np.shape, np.int32) - keys_in = tvm.nd.array(keys_np, ctx) - values_in = tvm.nd.array(values_np, ctx) - keys_out = tvm.nd.array(keys_np_out, ctx) - values_out = tvm.nd.array(values_np_out, ctx) - f(keys_in, values_in, keys_out, values_out) - - ref_keys_out = np.sort(keys_np) - ref_values_out = np.array([values_np[i] for i in np.argsort(keys_np)]) - tvm.testing.assert_allclose(keys_out.asnumpy(), ref_keys_out, rtol=1e-5) - tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out, rtol=1e-5) - - def test_sort_by_key_gpu(): size = 6 keys = te.placeholder((size,), name="keys", dtype="int32") @@ -158,5 +126,4 @@ def test_sort_by_key_gpu(): if __name__ == "__main__": test_sort() test_sort_np() - test_thrust_stable_sort_by_key() test_sort_by_key_gpu() diff --git a/tests/python/contrib/test_thrust.py b/tests/python/contrib/test_thrust.py new file mode 100644 index 000000000000..ecd9847d6e71 --- 
/dev/null +++ b/tests/python/contrib/test_thrust.py @@ -0,0 +1,84 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +import tvm.testing +from tvm import te +from tvm.topi.cuda import stable_sort_by_key_thrust, is_thrust_available +from tvm.topi.cuda.scan import exclusive_scan +import numpy as np + + +def test_stable_sort_by_key(): + if not is_thrust_available(): + print("skip because thrust is not enabled...") + return + + size = 6 + keys = te.placeholder((size,), name="keys", dtype="int32") + values = te.placeholder((size,), name="values", dtype="int32") + + keys_out, values_out = stable_sort_by_key_thrust(keys, values) + + ctx = tvm.gpu(0) + target = "cuda" + s = te.create_schedule([keys_out.op, values_out.op]) + f = tvm.build(s, [keys, values, keys_out, values_out], target) + + keys_np = np.array([1, 4, 2, 8, 2, 7], np.int32) + values_np = np.random.randint(0, 10, size=(size,)).astype(np.int32) + keys_np_out = np.zeros(keys_np.shape, np.int32) + values_np_out = np.zeros(values_np.shape, np.int32) + keys_in = tvm.nd.array(keys_np, ctx) + values_in = tvm.nd.array(values_np, ctx) + keys_out = tvm.nd.array(keys_np_out, ctx) + values_out = tvm.nd.array(values_np_out, ctx) + f(keys_in, values_in, keys_out, values_out) + + ref_keys_out = np.sort(keys_np) + ref_values_out = np.array([values_np[i] for i in np.argsort(keys_np)]) + tvm.testing.assert_allclose(keys_out.asnumpy(), ref_keys_out, rtol=1e-5) + tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out, rtol=1e-5) + + +def test_scan(): + if not is_thrust_available(): + print("skip because thrust is not enabled...") + return + + for ishape in [(10,), (10, 10)]: + values = te.placeholder(ishape, name="values", dtype="int32") + + with tvm.target.Target("cuda"): + scan = exclusive_scan(values) + s = te.create_schedule([scan.op]) + + ctx = tvm.gpu(0) + f = tvm.build(s, [values, scan], "cuda") + + values_np = np.random.randint(0, 10, size=ishape).astype(np.int32) + values_np_out = np.zeros(values_np.shape, np.int32) + values_in = tvm.nd.array(values_np, ctx) + values_out = tvm.nd.array(values_np_out, ctx) + f(values_in, values_out) + + ref_values_out = np.cumsum(values_np, axis=-1, dtype="int32") - values_np + tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out, rtol=1e-5) + + +if __name__ == "__main__": + test_stable_sort_by_key() + test_scan() From 55df6d40223d06f93e15b96b1a103ef9bba1515a Mon Sep 17 00:00:00 2001 From: masa Date: Sun, 17 Jan 2021 23:28:29 +0900 Subject: [PATCH 08/15] add scheduling for scan op, fixed scan 1d test --- python/tvm/topi/cuda/scan.py | 33 +++++++++++++++++++++++++++++ tests/python/contrib/test_thrust.py | 4 ++-- 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py index 
e251f2993060..70293a43f578 100644 --- a/python/tvm/topi/cuda/scan.py +++ b/python/tvm/topi/cuda/scan.py @@ -21,6 +21,8 @@ from ..transform import expand_dims, squeeze from ..utils import ceil_div from ..math import cast +from .. import tag +from .injective import schedule_injective_from_existing def exclusive_sum_scan2d_ir(data, output, reduction=None): @@ -282,3 +284,34 @@ def exclusive_scan(data, axis=-1, return_reduction=False, output_dtype=None): return output, reduction return output + + +def schedule_scan(outs): + """Schedule for scan operator. + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of scan + in the format of an array of tensors. + + Returns + ------- + s: Schedule + The computation schedule for the op. + """ + outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs + s = te.create_schedule([x.op for x in outs]) + scheduled_ops = [] + + def traverse(op): + if tag.is_injective(op.tag): + schedule_injective_from_existing(s, op.output(0)) + for tensor in op.input_tensors: + if tensor.op.input_tensors and tensor.op not in scheduled_ops: + traverse(tensor.op) + scheduled_ops.append(op) + + for out in outs: + traverse(out.op) + return s diff --git a/tests/python/contrib/test_thrust.py b/tests/python/contrib/test_thrust.py index ecd9847d6e71..686ee9afd0b7 100644 --- a/tests/python/contrib/test_thrust.py +++ b/tests/python/contrib/test_thrust.py @@ -18,7 +18,7 @@ import tvm.testing from tvm import te from tvm.topi.cuda import stable_sort_by_key_thrust, is_thrust_available -from tvm.topi.cuda.scan import exclusive_scan +from tvm.topi.cuda.scan import exclusive_scan, schedule_scan import numpy as np @@ -64,7 +64,7 @@ def test_scan(): with tvm.target.Target("cuda"): scan = exclusive_scan(values) - s = te.create_schedule([scan.op]) + s = schedule_scan([scan]) ctx = tvm.gpu(0) f = tvm.build(s, [values, scan], "cuda") From f74059528418d883cd2b73f384e591b0d0cb452b Mon Sep 17 00:00:00 2001 From: masa Date: Mon, 18 Jan 2021 16:13:43 +0900 Subject: [PATCH 09/15] pylint fix --- python/tvm/topi/cuda/scan.py | 1 + python/tvm/topi/cuda/sparse.py | 1 + 2 files changed, 2 insertions(+) diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py index 70293a43f578..6bc470f8be33 100644 --- a/python/tvm/topi/cuda/scan.py +++ b/python/tvm/topi/cuda/scan.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=invalid-name, too-many-locals, too-many-statements "Scan related operators" import tvm from tvm import te diff --git a/python/tvm/topi/cuda/sparse.py b/python/tvm/topi/cuda/sparse.py index 3e7606556d9c..0b46cf0f9f97 100644 --- a/python/tvm/topi/cuda/sparse.py +++ b/python/tvm/topi/cuda/sparse.py @@ -161,6 +161,7 @@ def sparse_dense_tir(data, w_data, w_indices, w_indptr): with either default_function_kernel0 for the transpose or default_function_kernel1 for the multiply. 
""" + def gen_ir(data, w_data, w_indices, w_indptr, out): # pylint: disable=invalid-name # TODO(tkonolige): use tensorcores for block multiply From b4795ef9502fe2c2a90f10d5911f9c6132d4dcc7 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 18 Jan 2021 19:41:20 +0900 Subject: [PATCH 10/15] add doc string --- python/tvm/topi/cuda/scan.py | 179 ++++++++++++++++++++++++++--------- 1 file changed, 132 insertions(+), 47 deletions(-) diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py index 6bc470f8be33..49a0613f597e 100644 --- a/python/tvm/topi/cuda/scan.py +++ b/python/tvm/topi/cuda/scan.py @@ -27,9 +27,20 @@ def exclusive_sum_scan2d_ir(data, output, reduction=None): + """Low level IR to do exclusive sum scan along rows of 2D input. + + Parameters + ---------- + data : Buffer + Input data. 2-D Buffer with shape [batch_size, scan_axis_size]. + + output: Buffer + A buffer to store the output scan, of the same size as data + + reduction: Buffer, optional + 1D Buffer of size [batch_size], to store the sum of each row. """ - TODO - """ + batch_size = data.shape[0] scan_axis_size = data.shape[1] @@ -45,7 +56,14 @@ def exclusive_sum_scan2d_ir(data, output, reduction=None): max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) - with ib.if_scope(scan_axis_size > 0): + with ib.if_scope(scan_axis_size == 0): + with ib.new_scope(): + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(bx, "thread_extent", batch_size) + with ib.if_scope(bx < batch_size): + if reduction is not None: + reduction[bx] = 0 + with ib.else_scope(): with ib.new_scope(): nthread_tx = max_threads nthread_bx = ceil_div(scan_axis_size, max_threads) @@ -136,71 +154,114 @@ def exclusive_sum_scan2d_ir(data, output, reduction=None): by * scan_axis_size + end[0] - 1 ] output[by * scan_axis_size + end[0] - 1] += tmp[0] - - with ib.else_scope(): - with ib.new_scope(): - bx = te.thread_axis("blockIdx.x") - ib.scope_attr(bx, "thread_extent", batch_size) - with ib.if_scope(bx < batch_size): - if reduction is not None: - reduction[bx] = 0 return ib.get() -def get_reduction_from_exclusive_scan_ir(data, data_ex_scan, reduction): - """TODO""" - batch_size = data.shape[0] - num_anchors = data.shape[1] +def get_reduction_from_exclusive_scan(data, ex_scan_output): + """Return the sum of the last element of data and the exclusive scan output. + The is the reduction of data along each row (for 2-D case). - ib = tvm.tir.ir_builder.create() + Parameters + ---------- + data : tvm.te.Tensor + Input data. 1-D tensor with shape [scan_axis_size], or + 2-D tensor with shape [batch_size, scan_axis_size]. - data = ib.buffer_ptr(data) - data_ex_scan = ib.buffer_ptr(data_ex_scan) - reduction = ib.buffer_ptr(reduction) + ex_scan_output : tvm.te.Tensor + 1-D tensor that is the exclusive scan of the input, or + 2-D tensor storing the exclusive scan of each row. - max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) - with ib.new_scope(): - nthread_tx = max_threads - nthread_bx = ceil_div(batch_size, max_threads) - tx = te.thread_axis("threadIdx.x") - bx = te.thread_axis("blockIdx.x") - ib.scope_attr(tx, "thread_extent", nthread_tx) - ib.scope_attr(bx, "thread_extent", nthread_bx) - tid = bx * max_threads + tx - with ib.if_scope(tid < batch_size): - reduction[tid] = data_ex_scan[tid, num_anchors - 1] + data[tid, num_anchors - 1] + Returns + ------- + reduction : tvm.te.Tensor + 1-D tensor storing the reduction of each row. 
+ """ + ndim = len(data.shape) + if ndim == 1: + data = expand_dims(data, axis=0) + ex_scan_output = expand_dims(ex_scan_output, axis=0) - return ib.get() + def ir(data, data_ex_scan, reduction): + batch_size = data.shape[0] + num_anchors = data.shape[1] + ib = tvm.tir.ir_builder.create() + + data = ib.buffer_ptr(data) + data_ex_scan = ib.buffer_ptr(data_ex_scan) + reduction = ib.buffer_ptr(reduction) + + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) + with ib.new_scope(): + nthread_tx = max_threads + nthread_bx = ceil_div(batch_size, max_threads) + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * max_threads + tx + with ib.if_scope(tid < batch_size): + reduction[tid] = data_ex_scan[tid, num_anchors - 1] + data[tid, num_anchors - 1] + + return ib.get() -def get_reduction_from_exclusive_scan(data, ex_scan_output): - """TODO""" assert len(data.shape) == 2, "Only 2D input supported for now" data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "valid_indices_buf", data_alignment=8) ex_scan_output_buf = tvm.tir.decl_buffer( ex_scan_output.shape, ex_scan_output.dtype, "ex_scan_output_buf", data_alignment=8 ) - return te.extern( + reduction = te.extern( [(data.shape[0],)], [data, ex_scan_output], - lambda ins, outs: get_reduction_from_exclusive_scan_ir(ins[0], ins[1], outs[0]), + lambda ins, outs: ir(ins[0], ins[1], outs[0]), dtype=[ex_scan_output.dtype], in_buffers=[data_buf, ex_scan_output_buf], name="ex_scan_reduction", tag="ex_scan_reduction_gpu", ) + if ndim == 1: + return squeeze(reduction, 0) + + return reduction + def is_thrust_available(): - """ - Test if thrust based scan ops are available. - """ + """Test if thrust based scan ops are available.""" return get_global_func("tvm.contrib.thrust.sum_scan", allow_missing=True) is not None def scan_thrust(data, output_dtype, exclusive=True, return_reduction=False): - """TODO""" + """Do exclusive scan on 1D input or along rows of 2D input, using thrust. + + Parameters + ---------- + data : tvm.te.Tensor + Input data. 1-D tensor with shape [scan_axis_size], or + 2-D tensor with shape [batch_size, scan_axis_size]. + + output_dtype: string + The dtype of the output scan tensor. + + exclusive: bool, optional + Whether or not do exclusive or inclusive scan. + + return_reduction: bool, optional + Whether or not return a 1-D tensor storing the reduction of each row. + Reductions are computed as part of the upsweep pass, so there is no extra cost. + If False, reductions are ignored. + + Returns + ------- + output : tvm.te.Tensor + 1-D tensor that is the exclusive scan of the input, or + 2-D tensor storing the exclusive scan of each row. + + reduction : tvm.te.Tensor, optional + 1-D tensor storing the reduction of each row. + Returned if return_reduction is True. 
+ """ data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) output_buf = tvm.tir.decl_buffer(data.shape, output_dtype, "output_buf", data_alignment=8) output = te.extern( @@ -218,20 +279,43 @@ def scan_thrust(data, output_dtype, exclusive=True, return_reduction=False): if return_reduction: assert exclusive, "return_reduction should be False for inclusive scan" - ndim = len(data.shape) - if ndim == 1: - output = expand_dims(output, axis=0) - reduction = get_reduction_from_exclusive_scan(data, output) - reduction = squeeze(reduction, 0) - else: - reduction = get_reduction_from_exclusive_scan(data, output) + reduction = get_reduction_from_exclusive_scan(data, output) return output, reduction return output def exclusive_scan(data, axis=-1, return_reduction=False, output_dtype=None): - """TODO""" + """Do exclusive scan on 1D input or along rows of 2D input. + + Parameters + ---------- + data : tvm.te.Tensor + Input data. 1-D tensor with shape [scan_axis_size], or + 2-D tensor with shape [batch_size, scan_axis_size]. + + axis: int, optional + The axis to do scan on. For now, only the inner most axis is supported. + + return_reduction: bool, optional + Whether or not return a 1-D tensor storing the reduction of each row. + Reductions are computed as part of the upsweep pass, so there is no extra cost. + If False, reductions are ignored. + + output_dtype: string, optional + The dtype of the output scan tensor. If not provided, the dtype of the input is used. + + Returns + ------- + output : tvm.te.Tensor + 1-D tensor that is the exclusive scan of the input, or + 2-D tensor storing the exclusive scan of each row. + + reduction : tvm.te.Tensor, optional + 1-D tensor storing the reduction of each row. + Returned if return_reduction is True. + """ + # TODO(masahi): Support other binary operators ndim = len(data.shape) if axis < 0: axis += ndim @@ -245,6 +329,7 @@ def exclusive_scan(data, axis=-1, return_reduction=False, output_dtype=None): return scan_thrust(data, output_dtype, exclusive=True, return_reduction=return_reduction) if ndim == 1: + # TIR exclusive scan accepts only 2D inputs. 
data = expand_dims(data, axis=0) data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) From 6c70ed22f9cf282b30ce2f489e6828d3411d0015 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 18 Jan 2021 20:44:25 +0900 Subject: [PATCH 11/15] add more thrust scan test --- tests/python/contrib/test_thrust.py | 53 +++++++++++++++++++++++++---- 1 file changed, 46 insertions(+), 7 deletions(-) diff --git a/tests/python/contrib/test_thrust.py b/tests/python/contrib/test_thrust.py index 686ee9afd0b7..5f66d465bf17 100644 --- a/tests/python/contrib/test_thrust.py +++ b/tests/python/contrib/test_thrust.py @@ -18,7 +18,7 @@ import tvm.testing from tvm import te from tvm.topi.cuda import stable_sort_by_key_thrust, is_thrust_available -from tvm.topi.cuda.scan import exclusive_scan, schedule_scan +from tvm.topi.cuda.scan import exclusive_scan, scan_thrust, schedule_scan import numpy as np @@ -54,31 +54,70 @@ def test_stable_sort_by_key(): tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out, rtol=1e-5) -def test_scan(): +def test_exclusive_scan(): if not is_thrust_available(): print("skip because thrust is not enabled...") return + for ishape in [(1,), (10, 10)]: + values = te.placeholder(ishape, name="values", dtype="int32") + + with tvm.target.Target("cuda"): + scan, reduction = exclusive_scan(values, return_reduction=True) + s = schedule_scan([scan, reduction]) + + ctx = tvm.gpu(0) + f = tvm.build(s, [values, scan, reduction], "cuda") + + values_np = np.random.randint(0, 10, size=ishape).astype(np.int32) + values_np_out = np.zeros(values_np.shape, np.int32) + + if len(ishape) == 1: + reduction_shape = () + else: + reduction_shape = (ishape[0],) + + reduction_np_out = np.zeros(reduction_shape, np.int32) + + values_in = tvm.nd.array(values_np, ctx) + values_out = tvm.nd.array(values_np_out, ctx) + reduction_out = tvm.nd.array(reduction_np_out, ctx) + f(values_in, values_out, reduction_out) + + ref_values_out = np.cumsum(values_np, axis=-1, dtype="int32") - values_np + tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out, rtol=1e-5) + ref_reduction_out = np.sum(values_np, axis=-1) + tvm.testing.assert_allclose(reduction_out.asnumpy(), ref_reduction_out, rtol=1e-5) + + +def test_inclusive_scan(): + if not is_thrust_available(): + print("skip because thrust is not enabled...") + return + + out_dtype = "int64" + for ishape in [(10,), (10, 10)]: values = te.placeholder(ishape, name="values", dtype="int32") with tvm.target.Target("cuda"): - scan = exclusive_scan(values) - s = schedule_scan([scan]) + scan = scan_thrust(values, out_dtype, exclusive=False) + s = tvm.te.create_schedule([scan.op]) ctx = tvm.gpu(0) f = tvm.build(s, [values, scan], "cuda") values_np = np.random.randint(0, 10, size=ishape).astype(np.int32) - values_np_out = np.zeros(values_np.shape, np.int32) + values_np_out = np.zeros(values_np.shape, out_dtype) values_in = tvm.nd.array(values_np, ctx) values_out = tvm.nd.array(values_np_out, ctx) f(values_in, values_out) - ref_values_out = np.cumsum(values_np, axis=-1, dtype="int32") - values_np + ref_values_out = np.cumsum(values_np, axis=-1, dtype=out_dtype) tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out, rtol=1e-5) if __name__ == "__main__": test_stable_sort_by_key() - test_scan() + test_exclusive_scan() + test_inclusive_scan() From a6c740348282c2d13f22883e62c7c910b73ad8c2 Mon Sep 17 00:00:00 2001 From: masa Date: Wed, 20 Jan 2021 11:43:38 +0900 Subject: [PATCH 12/15] add dynamic get valid count test, including empty size 
tensor --- tests/python/relay/test_any.py | 45 ++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index d30e7873dae7..34947d57a406 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -879,6 +879,51 @@ def test_any_topk(): verify_any_topk(any_dims(1), 0, (0,), "float32", ret_type="both") +def verify_any_get_valid_counts(num_anchor_real, dtype, targets=None): + mod = tvm.IRModule() + batch_size = 1 + num_anchor = relay.Any() + data = relay.var("data", shape=(batch_size, num_anchor, 5), dtype=dtype) + np_data = np.random.uniform(size=(batch_size, num_anchor_real, 5)).astype(dtype) + + np_out1 = np.zeros(shape=(batch_size,)) + np_out2 = np.zeros(shape=np_data.shape).astype(dtype) + np_out3 = np.zeros(shape=(batch_size, num_anchor_real)) + score_threshold = 0.95 + + for i in range(batch_size): + np_out1[i] = 0 + inter_idx = 0 + for j in range(num_anchor_real): + score = np_data[i, j, 0] + if score > score_threshold: + for k in range(5): + np_out2[i, inter_idx, k] = np_data[i, j, k] + np_out1[i] += 1 + np_out3[i, inter_idx] = j + inter_idx += 1 + if j >= np_out1[i]: + for k in range(5): + np_out2[i, j, k] = -1.0 + np_out3[i, j] = -1 + + z = relay.vision.get_valid_counts(data, score_threshold, 0, score_index=0) + + mod["main"] = relay.Function([data], z.astuple()) + + check_result([np_data], mod, [np_out1, np_out2, np_out3], targets=targets) + + +@tvm.testing.uses_gpu +def test_any_get_valid_counts(): + verify_any_get_valid_counts(10, "float32") + # opencl seems to have issues with empty size buffer + # Check failed: err_code == CL_SUCCESS == false: OpenCL Error, + # code=-61: CL_INVALID_BUFFER_SIZE + targets = [("llvm", tvm.cpu(0)), ("cuda", tvm.gpu(0)), ("nvptx", tvm.gpu(0))] + verify_any_get_valid_counts(0, "float32", targets=targets) + + @tvm.testing.uses_gpu def test_fused_ops(): x = relay.var("x", shape=(relay.Any(), relay.Any()), dtype="float32") From e2df3c6dee65e63e6a0ad3fb01497730b0e41232 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 20 Jan 2021 12:48:04 +0900 Subject: [PATCH 13/15] fix hard coded gpu targets for cpu only env --- tests/python/relay/test_any.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 34947d57a406..a537782355d2 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -920,7 +920,10 @@ def test_any_get_valid_counts(): # opencl seems to have issues with empty size buffer # Check failed: err_code == CL_SUCCESS == false: OpenCL Error, # code=-61: CL_INVALID_BUFFER_SIZE - targets = [("llvm", tvm.cpu(0)), ("cuda", tvm.gpu(0)), ("nvptx", tvm.gpu(0))] + targets = [] + for tgt, ctx in tvm.testing.enabled_targets(): + if "opencl" not in tgt: + targets.append((tgt, ctx)) verify_any_get_valid_counts(0, "float32", targets=targets) From a88e53cfed996a27f56e49d1ce3a1667f0d1d32f Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 20 Jan 2021 14:48:36 +0900 Subject: [PATCH 14/15] try retunring early if scan_size is 0 --- src/runtime/contrib/thrust/thrust.cu | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/runtime/contrib/thrust/thrust.cu b/src/runtime/contrib/thrust/thrust.cu index fe6cc43fefdc..4e3e3a81af1a 100644 --- a/src/runtime/contrib/thrust/thrust.cu +++ b/src/runtime/contrib/thrust/thrust.cu @@ -273,6 +273,8 @@ void thrust_scan(DLTensor* data, thrust::device_ptr output_ptr(static_cast(output->data)); const 
auto scan_size = data->shape[data->ndim - 1];
 
+  if (scan_size == 0) return;
+
   if (data->ndim == 1 || (data->ndim == 2 && data->shape[0] == 1)) {
     if (exclusive) {
       thrust::exclusive_scan(data_ptr, data_ptr + scan_size, output_ptr);

From 717270b8e2b1575a318bcd0b1b9c939eb40a6a8f Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Wed, 20 Jan 2021 15:16:59 +0900
Subject: [PATCH 15/15] another change for empty tensor and thrust path

---
 python/tvm/topi/cuda/scan.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py
index 49a0613f597e..f19e4a14239a 100644
--- a/python/tvm/topi/cuda/scan.py
+++ b/python/tvm/topi/cuda/scan.py
@@ -201,7 +201,10 @@ def ir(data, data_ex_scan, reduction):
             ib.scope_attr(bx, "thread_extent", nthread_bx)
             tid = bx * max_threads + tx
             with ib.if_scope(tid < batch_size):
-                reduction[tid] = data_ex_scan[tid, num_anchors - 1] + data[tid, num_anchors - 1]
+                with ib.if_scope(num_anchors > 0):
+                    reduction[tid] = data_ex_scan[tid, num_anchors - 1] + data[tid, num_anchors - 1]
+                with ib.else_scope():
+                    reduction[tid] = 0
 
         return ib.get()
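Finally, to make the segmented-scan path in thrust.cu concrete: the linear_index_to_scan_key lambda maps the flat index i to the key i / scan_size, which groups elements by row, so a single scan_by_key over the flattened tensor yields one independent scan per row without ever materializing the key vector. A NumPy model of the trick (reference only):

import numpy as np

def exclusive_sum_scan_by_key_ref(flat, keys):
    # Reference for thrust::exclusive_scan_by_key: the running sum
    # restarts whenever the key changes, i.e. at each row boundary.
    out = np.zeros_like(flat)
    acc = 0
    for i, v in enumerate(flat):
        if i > 0 and keys[i] != keys[i - 1]:
            acc = 0
        out[i] = acc
        acc += v
    return out

data = np.arange(6, dtype="int64").reshape(2, 3)
scan_size = data.shape[-1]
flat = data.reshape(-1)
keys = np.arange(flat.size) // scan_size          # 0, 0, 0, 1, 1, 1
out = exclusive_sum_scan_by_key_ref(flat, keys).reshape(data.shape)
assert (out == np.cumsum(data, axis=-1) - data).all()

With scan_size == 0 the flattened tensor is empty and there is nothing to scan, which is why the early return added in patch 14 and the num_anchors > 0 guard in patch 15 are enough to keep the empty-tensor tests passing.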