diff --git a/src/op/extern_op.cc b/src/op/extern_op.cc
index e6c6039b610e..7023aebe17ad 100644
--- a/src/op/extern_op.cc
+++ b/src/op/extern_op.cc
@@ -72,7 +72,10 @@ Operation ExternOpNode::make(std::string name,
   CHECK_EQ(inputs.size(), input_placeholders.size());
   for (size_t i = 0; i < inputs.size(); ++i) {
     CHECK_EQ(inputs[i]->dtype, input_placeholders[i]->dtype);
-    CHECK(inputs[i]->shape.same_as(input_placeholders[i]->shape));
+    CHECK_EQ(inputs[i]->shape.size(), input_placeholders[i]->shape.size());
+    for (size_t dim = 0; dim < inputs[i]->shape.size(); ++dim) {
+      CHECK(inputs[i]->shape[dim].same_as(input_placeholders[i]->shape[dim]));
+    }
     CHECK_EQ(input_placeholders[i]->strides.size(), 0U);
   }
   n->inputs = std::move(inputs);
diff --git a/topi/python/topi/cuda/nms.py b/topi/python/topi/cuda/nms.py
index 0c27bd216999..925cf24acd11 100644
--- a/topi/python/topi/cuda/nms.py
+++ b/topi/python/topi/cuda/nms.py
@@ -24,6 +24,7 @@
 from tvm.intrin import if_then_else, log, power
 from topi.vision import non_max_suppression, get_valid_counts
 from .sort import argsort
+from .. import tag
 
 
 def get_valid_counts_pre(data, flag, idx, score_threshold):
@@ -730,7 +731,7 @@ def non_max_suppression_gpu(data, valid_count, max_output_size=-1,
                                       "valid_count_buf", data_alignment=4)
     score_axis = score_index
     score_shape = (batch_size, num_anchors)
-    score_tensor = tvm.compute(score_shape, lambda i, j: data[i, j, score_axis])
+    score_tensor = tvm.compute(score_shape, lambda i, j: data[i, j, score_axis], tag=tag.ELEMWISE)
     sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, is_ascend=False, flag=True)
     sort_tensor_buf = api.decl_buffer(sort_tensor.shape, sort_tensor.dtype,
                                       "sort_tensor_buf", data_alignment=8)
diff --git a/topi/python/topi/cuda/sort.py b/topi/python/topi/cuda/sort.py
index 99ba8527cdfb..678d494dae50 100644
--- a/topi/python/topi/cuda/sort.py
+++ b/topi/python/topi/cuda/sort.py
@@ -20,6 +20,10 @@
 
 from tvm import api
 from topi.sort import argsort
+from topi.math import identity
+from .. import generic
+from .. import tag
+
 def sort_ir(data, output, axis, is_ascend):
     """Low level IR to do nms sorting on the GPU, same usage as
     tvm.contrib.sort.argsort on the CPU.
@@ -104,8 +108,6 @@ def sort_ir(data, output, axis, is_ascend):
     return ib.get()
 
 
-
-
 def sort_nms_ir(data, valid_count, output, axis, is_ascend):
     """Low level IR to do nms sorting on the GPU, same usage as
     tvm.contrib.sort.argsort on the CPU.
@@ -221,29 +223,60 @@ def argsort_gpu(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0
     out : tvm.Tensor
         The output of this function.
     """
-    data_buf = api.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
+    sorted_data_buf = api.decl_buffer(data.shape, data.dtype, "sorted_data_buf", data_alignment=8)
+    sorted_data = identity(data)
     if flag:
         valid_count_buf = api.decl_buffer(valid_count.shape, valid_count.dtype,
                                           "valid_count_buf", data_alignment=4)
         out_buf = api.decl_buffer(data.shape, "int32", "out_buf", data_alignment=4)
         out = tvm.extern([data.shape],
-                         [data, valid_count],
+                         [sorted_data, valid_count],
                          lambda ins, outs: sort_nms_ir(
                              ins[0], ins[1], outs[0], axis, is_ascend),
                          dtype="int32",
-                         in_buffers=[data_buf, valid_count_buf],
+                         in_buffers=[sorted_data_buf, valid_count_buf],
                          out_buffers=[out_buf],
                          name="argsort_nms_gpu",
                          tag="argsort_nms_gpu")
     else:
         out_buf = api.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8)
         out = tvm.extern([data.shape],
-                         [data],
+                         [sorted_data],
                          lambda ins, outs: sort_ir(
                              ins[0], outs[0], axis, is_ascend),
                          dtype=dtype,
-                         in_buffers=[data_buf],
+                         in_buffers=[sorted_data_buf],
                          out_buffers=[out_buf],
                          name="argsort_gpu",
                          tag="argsort_gpu")
     return out
+
+@generic.schedule_argsort.register(["cuda", "gpu"])
+def schedule_argsort(outs):
+    """Schedule for argsort operator.
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+      The computation graph description of argsort
+      in the format of an array of tensors.
+
+    Returns
+    -------
+    s: Schedule
+      The computation schedule for the op.
+    """
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
+    s = tvm.create_schedule([x.op for x in outs])
+    scheduled_ops = []
+    from .injective import _schedule_injective
+    def traverse(op):
+        if tag.is_broadcast(op.tag):
+            _schedule_injective(op, s)
+        for tensor in op.input_tensors:
+            if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+                traverse(tensor.op)
+        scheduled_ops.append(op)
+    traverse(outs[0].op)
+
+    return s
diff --git a/topi/python/topi/cuda/vision.py b/topi/python/topi/cuda/vision.py
index 78f5c1f51ec6..968e554ac81d 100644
--- a/topi/python/topi/cuda/vision.py
+++ b/topi/python/topi/cuda/vision.py
@@ -25,41 +25,17 @@
 
 def _default_schedule(outs):
     """Default schedule for gpu."""
-    target = tvm.target.current_target()
     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     s = tvm.create_schedule([x.op for x in outs])
     scheduled_ops = []
-
+    from .injective import _schedule_injective
     def traverse(op):
-        """inline all one-to-one-mapping operators except the last stage (output)"""
-        if op.tag in ["nms", "invalid_to_bottom"]:
-            if op.tag == "nms":
-                sort = op.input_tensors[1]
-            else:
-                out = op.input_tensors[0]
-                sort = s[out].op.input_tensors[1]
-            score = s[sort].op.input_tensors[0]
-            fused = s[score].fuse(*s[score].op.axis)
-            num_thread = int(tvm.target.current_target(allow_none=False).max_num_threads)
-            bx, tx = s[score].split(fused, factor=num_thread)
-            s[score].bind(bx, tvm.thread_axis("blockIdx.x"))
-            s[score].bind(tx, tvm.thread_axis("threadIdx.x"))
-        if tag.is_broadcast(op.tag):
-            if op not in s.outputs:
-                s[op].compute_inline()
-            else:
-                x = op.output(0)
-                fused = s[x].fuse(*s[x].op.axis)
-                num_thread = tvm.target.current_target(allow_none=False).max_num_threads
-                bx, tx = s[x].split(fused, factor=num_thread)
-                s[x].bind(bx, tvm.thread_axis("blockIdx.x"))
-                s[x].bind(tx, tvm.thread_axis("threadIdx.x"))
-        for tensor in op.input_tensors:
-            if tensor.op.input_tensors and tensor.op not in scheduled_ops:
-                traverse(tensor.op)
-
+        if tag.is_broadcast(op.tag) or op.tag in ['bbox_score', 'sorted_bbox']:
+            _schedule_injective(op, s)
+        for tensor in op.input_tensors:
+            if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+                traverse(tensor.op)
         scheduled_ops.append(op)
-
     traverse(outs[0].op)
     return s
 
@@ -173,19 +149,7 @@ def schedule_proposal(outs):
     s: Schedule
       The computation schedule for the op.
     """
-    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
-    s = tvm.create_schedule([x.op for x in outs])
-    scheduled_ops = []
-    from .injective import _schedule_injective
-    def traverse(op):
-        if op.tag in ['bbox_score', 'sorted_bbox']:
-            _schedule_injective(op, s)
-        for tensor in op.input_tensors:
-            if tensor.op.input_tensors and tensor.op not in scheduled_ops:
-                traverse(tensor.op)
-        scheduled_ops.append(op)
-    traverse(outs[0].op)
-    return s
+    return _default_schedule(outs)
 
 @generic.schedule_get_valid_counts.register(["cuda", "gpu"])
 def schedule_get_valid_counts(outs):
@@ -203,30 +167,3 @@ def schedule_get_valid_counts(outs):
       The computation schedule for the op.
     """
     return _default_schedule(outs)
-
-@generic.schedule_argsort.register(["cuda", "gpu"])
-def schedule_argsort(outs):
-    """Schedule for argsort operator.
-
-    Parameters
-    ----------
-    outs: Array of Tensor
-      The computation graph description of argsort
-      in the format of an array of tensors.
-
-    Returns
-    -------
-    s: Schedule
-      The computation schedule for the op.
-    """
-    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
-    s = tvm.create_schedule([x.op for x in outs])
-    scheduled_ops = []
-    from .injective import _schedule_injective
-    def traverse(op):
-        for tensor in op.input_tensors:
-            if tensor.op.input_tensors and tensor.op not in scheduled_ops:
-                traverse(tensor.op)
-        scheduled_ops.append(op)
-    traverse(outs[0].op)
-    return s