diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 896e8af99921..49c744edcf3f 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2955,6 +2955,39 @@ def _impl_v11(cls, inputs, attr, params): return out +class Unique(OnnxOpConverter): + """Operator converter for unique""" + + @classmethod + def _impl_v11(cls, inputs, attr, params): + if len(inputs) != 1: + raise ValueError("Unique expects 1 input") + + data = inputs[0] + axis = attr.get("axis", None) + if axis is None: # If axis is None, flatten the input before calling unique + data = _op.reshape(data, _op.const([-1])) + else: + data_shape = infer_shape(data) + if len(data_shape) != 1: + raise ValueError("TVM only supports 1D Unique operator.") + is_sorted = attr.get("sorted", 1) # sorted is 0 or 1, 1 by default + + # ONNX documentation lists return_counts as optional but there is no input to specify + # whether it is returned. Therefore we'll just always return it. + unique = _op.unique(data, is_sorted=(is_sorted == 1), return_counts=True) + num_unique = unique[3] + + trim_unique_lambda = lambda input: _op.strided_slice(input, _op.const([0]), num_unique) + + unique_vals = trim_unique_lambda(unique[0]) + indices = trim_unique_lambda(unique[1]) + inverse_indices = unique[2] + counts = trim_unique_lambda(unique[4]) + # ONNX unique returns unique, indices, inverse_indices, (optional) counts + return _expr.TupleWrapper(_expr.Tuple([unique_vals, indices, inverse_indices, counts]), 4) + + # compatible operators that do NOT require any conversion. 
_identity_list = [] @@ -3118,6 +3151,7 @@ def _get_convert_map(opset): "NonZero": NonZero.get_converter(opset), "Range": Range.get_converter(opset), "CumSum": CumSum.get_converter(opset), + "Unique": Unique.get_converter(opset), # defs/control_flow "Loop": Loop.get_converter(opset), "If": If.get_converter(opset), @@ -3306,6 +3340,12 @@ def from_onnx(self, graph, opset, get_output_expr=False): outputs_num = 1 else: outputs_num = len(op) + + if outputs_num == 1: + op = fold_constant(op) + else: + op = _expr.TupleWrapper(fold_constant(op.astuple()), len(op)) + if outputs_num > 1: # ONNX supports optional outputs for some nodes. # This block searches for missing outputs in the ONNX graph @@ -3327,8 +3367,8 @@ def from_onnx(self, graph, opset, get_output_expr=False): # Create the new op with valid outputs if len(outputs) == 1: op = outputs[0] - else: - op = _expr.TupleWrapper(outputs, len(outputs)) + elif len(outputs) != outputs_num: + op = _expr.TupleWrapper(_expr.Tuple(outputs), len(outputs)) # Drop invalid outputs for the onnx node outputs_num = len(outputs) node_output = [output for output in node_output if output != ""] @@ -3337,10 +3377,10 @@ def from_onnx(self, graph, opset, get_output_expr=False): ), "Number of output mismatch {} vs {} in {}.".format( len(node_output), outputs_num, op_name ) + if outputs_num == 1: - self._nodes[node_output[0]] = fold_constant(op) + self._nodes[node_output[0]] = op else: - op = _expr.TupleWrapper(fold_constant(op.astuple()), len(op)) for k, i in zip(list(node_output), range(len(node_output))): self._nodes[k] = op[i] diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index b5cfcf5e3bac..f0ba99291727 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2294,16 +2294,18 @@ def unique(self, inputs, input_types): logging.warning("TVM always assumes sorted=True for torch.unique") is_sorted = True if return_counts: - [unique, indices, num_uniq, counts] = 
_op.unique( + [unique, indices, inverse_indices, num_uniq, counts] = _op.unique( data, is_sorted=is_sorted, return_counts=True ) unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size") counts_sliced = _op.strided_slice(counts, begin=[0], end=num_uniq, slice_mode="size") - return (unique_sliced, indices, counts_sliced) + return (unique_sliced, inverse_indices, counts_sliced) else: - [unique, indices, num_uniq] = _op.unique(data, is_sorted=is_sorted, return_counts=False) + [unique, indices, inverse_indices, num_uniq] = _op.unique( + data, is_sorted=is_sorted, return_counts=False + ) unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size") - return (unique_sliced, indices) + return (unique_sliced, inverse_indices) # Operator mappings def create_convert_map(self): diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 4af73702ad9c..040f8384dbe0 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -2702,19 +2702,21 @@ def _impl(inputs, attr, params, mod): assert len(inputs) == 1 data = inputs[0] if return_counts: - [unique, indices, num_uniq, counts] = _op.unique( + [unique, _, inverse_indices, num_uniq, counts] = _op.unique( data, is_sorted=False, return_counts=True ) unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size") counts_sliced = _op.strided_slice(counts, begin=[0], end=num_uniq, slice_mode="size") return _expr.TupleWrapper( - _expr.Tuple([unique_sliced, indices, counts_sliced]), + _expr.Tuple([unique_sliced, inverse_indices, counts_sliced]), 3, ) - [unique, indices, num_uniq] = _op.unique(data, is_sorted=False, return_counts=False) + [unique, _, inverse_indices, num_uniq] = _op.unique( + data, is_sorted=False, return_counts=False + ) unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size") return _expr.TupleWrapper( - _expr.Tuple([unique_sliced, 
indices]), + _expr.Tuple([unique_sliced, inverse_indices]), 2, ) diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 94c413b6df6c..fee3eacf1aec 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -1045,24 +1045,28 @@ def ensure_tensor(tensor): def _unique_shape(data_shape): unique_shape = output_tensor((1,), "int64") indices_shape = output_tensor((1,), "int64") + inverse_indices_shape = output_tensor((1,), "int64") num_unique_shape = output_tensor((1,), "int64") unique_shape[0] = data_shape[0] indices_shape[0] = data_shape[0] + inverse_indices_shape[0] = data_shape[0] num_unique_shape[0] = int64(1) - return (unique_shape, indices_shape, num_unique_shape) + return (unique_shape, indices_shape, inverse_indices_shape, num_unique_shape) @script def _unique_with_counts_shape(data_shape): unique_shape = output_tensor((1,), "int64") indices_shape = output_tensor((1,), "int64") + inverse_indices_shape = output_tensor((1,), "int64") num_unique_shape = output_tensor((1,), "int64") counts_shape = output_tensor((1,), "int64") unique_shape[0] = data_shape[0] indices_shape[0] = data_shape[0] + inverse_indices_shape[0] = data_shape[0] num_unique_shape[0] = int64(1) counts_shape[0] = data_shape[0] - return (unique_shape, indices_shape, num_unique_shape, counts_shape) + return (unique_shape, indices_shape, inverse_indices_shape, num_unique_shape, counts_shape) @_reg.register_shape_func("unique", False) diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 74fb44fc2232..71ecc0076285 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -1658,7 +1658,7 @@ def unique(data, is_sorted=True, return_counts=False): data : relay.Expr A 1-D tensor of integers. - sorted : bool + is_sorted : bool Whether to sort the unique elements in ascending order before returning as output. 
return_counts : bool @@ -1666,12 +1666,16 @@ def unique(data, is_sorted=True, return_counts=False): Returns ------- - output : relay.Expr + unique : relay.Expr A 1-D tensor containing the unique elements of the input data tensor. indices : relay.Expr A 1-D tensor containing the index of each data element in the output tensor. + inverse_indices : relay.Expr + A 1-D tensor. For each entry in data, it contains the index of that data element in the + unique array. + num_unique : relay.Expr A 1-D tensor with size=1 containing the number of unique elements in the input data tensor. @@ -1698,5 +1702,5 @@ def unique(data, is_sorted=True, return_counts=False): num_unique = [5] """ if return_counts: - return TupleWrapper(_make.unique(data, is_sorted, return_counts), 4) - return TupleWrapper(_make.unique(data, is_sorted, return_counts), 3) + return TupleWrapper(_make.unique(data, is_sorted, return_counts), 5) + return TupleWrapper(_make.unique(data, is_sorted, return_counts), 4) diff --git a/python/tvm/topi/cuda/unique.py b/python/tvm/topi/cuda/unique.py index 2bca3c447c4c..911ee71a0057 100644 --- a/python/tvm/topi/cuda/unique.py +++ b/python/tvm/topi/cuda/unique.py @@ -119,7 +119,7 @@ def _calc_num_unique(inc_scan): def _calc_unique_ir( - data, argsorted_indices, inc_scan, index_converter, unique_elements, indices, counts + data, argsorted_indices, inc_scan, index_converter, unique_elements, inverse_indices, counts ): """Low level IR to calculate unique elements, inverse indices, and counts (optional) of unique elements of 1-D array. @@ -143,7 +143,7 @@ def _calc_unique_ir( unique_elements : Buffer A buffer that stores the unique elements. - indices : Buffer + inverse_indices : Buffer A buffer that stores the the index of each input data element in the unique element array. 
counts (optional) : Buffer @@ -154,7 +154,7 @@ def _calc_unique_ir( argsorted_indices_ptr = ib.buffer_ptr(argsorted_indices) inc_scan_ptr = ib.buffer_ptr(inc_scan) unique_elements_ptr = ib.buffer_ptr(unique_elements) - indices_ptr = ib.buffer_ptr(indices) + inverse_indices_ptr = ib.buffer_ptr(inverse_indices) index_converter_ptr = None if isinstance(index_converter, tir.Buffer): @@ -163,7 +163,7 @@ def _calc_unique_ir( if isinstance(counts, tir.Buffer): counts_ptr = ib.buffer_ptr(counts) # use indices_ptr as a tmp buffer to store tids with inc_scan[tid] != inc_scan[tid-1] - unique_seq_indices_ptr = ib.buffer_ptr(indices) + unique_seq_indices_ptr = ib.buffer_ptr(inverse_indices) batch_size = data.shape[0] max_threads = _get_max_threads(batch_size) @@ -218,7 +218,7 @@ def _calc_unique_ir( if not index_converter_ptr else index_converter_ptr[inc_scan_ptr[tid]] ) - indices_ptr[data_idx] = unique_idx + inverse_indices_ptr[data_idx] = unique_idx with ib.if_scope(tid == 0): unique_elements_ptr[unique_idx] = data_ptr[data_idx] with ib.else_scope(): @@ -293,11 +293,20 @@ def unique(data, is_sorted=True, return_counts=False): Returns ------- - output : tvm.te.Tensor - A 1-D tensor containing the unique elements of the input data tensor. + unique : tvm.te.Tensor + A 1-D tensor containing the unique elements of the input data tensor. The same size as + the input data. If there are fewer unique elements than input data, the end of the tensor + is padded with zeros. indices : tvm.te.Tensor - A 1-D tensor containing the index of each data element in the output tensor. + A 1-D tensor. The same size as output. For each entry in output, it contains + the index of its first occurrence in the input data. The end of the tensor is padded + with the length of the input data. + + inverse_indices : tvm.te.Tensor + A 1-D tensor. For each entry in data, it contains the index of that data element in the + unique array.
(Note that inverse_indices is very similar to indices if output is not + sorted) num_unique : tvm.te.Tensor A 1-D tensor with size=1 containing the number of unique elements in the input data tensor. @@ -309,20 +318,23 @@ def unique(data, is_sorted=True, return_counts=False): -------- .. code-block:: python [output, indices, num_unique] = unique([4, 5, 1, 2, 3, 3, 4, 5], False, False) - output = [4, 5, 1, 2, 3, ?, ?, ?] - indices = [0, 1, 2, 3, 4, 4, 0, 1] - num_unique = [5] + output = [4, 5, 1, 2, 3, ?, ?, ?] + indices = [0, 1, 2, 3, 4, ?, ?, ?] + inverse_indices = [0, 1, 2, 3, 4, 4, 0, 1] + num_unique = [5] [output, indices, num_unique, counts] = unique([4, 5, 1, 2, 3, 3, 4, 5], False, True) - output = [4, 5, 1, 2, 3, ?, ?, ?] - indices = [0, 1, 2, 3, 4, 4, 0, 1] - num_unique = [5] - counts = [2, 2, 1, 1, 2, ?, ?, ?] + output = [4, 5, 1, 2, 3, ?, ?, ?] + indices = [0, 1, 2, 3, 4, ?, ?, ?] + inverse_indices = [0, 1, 2, 3, 4, 4, 0, 1] + num_unique = [5] + counts = [2, 2, 1, 1, 2, ?, ?, ?] [output, indices, num_unique] = unique([4, 5, 1, 2, 3, 3, 4, 5], True) - output = [1, 2, 3, 4, 5, ?, ?, ?] - indices = [3, 4, 0, 1, 2, 2, 3, 4] - num_unique = [5] + output = [1, 2, 3, 4, 5, ?, ?, ?] + indices = [2, 3, 4, 0, 1, ?, ?, ?] 
+ inverse_indices = [3, 4, 0, 1, 2, 2, 3, 4] + num_unique = [5] """ sorted_data = sort(data) argsorted_indices = argsort(data, dtype="int32") @@ -355,6 +367,20 @@ def unique(data, is_sorted=True, return_counts=False): out_buffers = [unique_elements_buf, inverse_indices_buf] out_dtypes = [data.dtype, "int32"] # prepare inputs and fcompute + # calculate first occurence + first_occurence_buf = tir.decl_buffer( + data.shape, "int32", "first_occurence_buf", data_alignment=8 + ) + first_occurence = te.extern( + [data.shape], + [argsorted_indices, inc_scan], + lambda ins, outs: _calc_first_occurence_ir(ins[0], ins[1], outs[0]), + dtype=["int32"], + in_buffers=[argsorted_indices_buf, inc_scan_buf], + out_buffers=[first_occurence_buf], + name="_calc_first_occurence", + tag="_calc_first_occurence_gpu", + ) if is_sorted: in_data = [data, argsorted_indices, inc_scan] in_buffers = [data_buf, argsorted_indices_buf, inc_scan_buf] @@ -362,22 +388,8 @@ def unique(data, is_sorted=True, return_counts=False): fcompute = lambda ins, outs: _calc_unique_ir(*ins, None, *outs) else: fcompute = lambda ins, outs: _calc_unique_ir(*ins, None, *outs, None) + indices = first_occurence else: - # calculate the index converter if the unique elements should not be sorted - # calculate first occurence - first_occurence_buf = tir.decl_buffer( - data.shape, "int32", "first_occurence_buf", data_alignment=8 - ) - first_occurence = te.extern( - [data.shape], - [argsorted_indices, inc_scan], - lambda ins, outs: _calc_first_occurence_ir(ins[0], ins[1], outs[0]), - dtype=["int32"], - in_buffers=[argsorted_indices_buf, inc_scan_buf], - out_buffers=[first_occurence_buf], - name="_calc_first_occurence", - tag="_calc_first_occurence_gpu", - ) # calculate index converter by sorting unique elements by their first occurence argsorted_first_occurence = argsort(first_occurence, dtype="int32") index_converter = argsort(argsorted_first_occurence, dtype="int32") @@ -390,6 +402,7 @@ def unique(data, is_sorted=True, 
return_counts=False): fcompute = lambda ins, outs: _calc_unique_ir(*ins, *outs) else: fcompute = lambda ins, outs: _calc_unique_ir(*ins, *outs, None) + indices = sort(first_occurence) outs = te.extern( out_data_shape, in_data, @@ -401,5 +414,5 @@ def unique(data, is_sorted=True, return_counts=False): tag="_calc_unique_gpu", ) if return_counts: - return [outs[0], outs[1], num_unique_elements, outs[2]] - return [*outs, num_unique_elements] + return [outs[0], indices, outs[1], num_unique_elements, outs[2]] + return [outs[0], indices, outs[1], num_unique_elements] diff --git a/python/tvm/topi/unique.py b/python/tvm/topi/unique.py index e7256551d7b6..49869c2ecda4 100644 --- a/python/tvm/topi/unique.py +++ b/python/tvm/topi/unique.py @@ -93,7 +93,7 @@ def _calc_num_unique(inc_scan): def _calc_unique_ir( - data, argsorted_indices, inc_scan, index_converter, unique_elements, indices, counts + data, argsorted_indices, inc_scan, index_converter, unique_elements, inverse_indices, counts ): """Low level IR to calculate unique elements, inverse indices, and counts (optional) of unique elements of 1-D array. @@ -117,7 +117,7 @@ def _calc_unique_ir( unique_elements : Buffer A buffer that stores the unique elements. - indices : Buffer + inverse_indices : Buffer A buffer that stores the the index of each input data element in the unique element array. 
counts (optional) : Buffer @@ -128,7 +128,7 @@ def _calc_unique_ir( argsorted_indices_ptr = ib.buffer_ptr(argsorted_indices) inc_scan_ptr = ib.buffer_ptr(inc_scan) unique_elements_ptr = ib.buffer_ptr(unique_elements) - indices_ptr = ib.buffer_ptr(indices) + inverse_indices_ptr = ib.buffer_ptr(inverse_indices) index_converter_ptr = None if isinstance(index_converter, tir.Buffer): @@ -137,7 +137,7 @@ def _calc_unique_ir( if isinstance(counts, tir.Buffer): counts_ptr = ib.buffer_ptr(counts) # use indices_ptr as a tmp buffer to store tids with inc_scan[tid] != inc_scan[tid-1] - unique_seq_indices_ptr = ib.buffer_ptr(indices) + unique_seq_indices_ptr = ib.buffer_ptr(inverse_indices) data_length = data.shape[0] @@ -167,7 +167,7 @@ def _calc_unique_ir( unique_idx = ( inc_scan_ptr[i] if not index_converter_ptr else index_converter_ptr[inc_scan_ptr[i]] ) - indices_ptr[data_idx] = unique_idx + inverse_indices_ptr[data_idx] = unique_idx with ib.if_scope(i == 0): unique_elements_ptr[unique_idx] = data_ptr[data_idx] with ib.else_scope(): @@ -219,11 +219,20 @@ def unique(data, is_sorted=True, return_counts=False): Returns ------- - output : tvm.te.Tensor - A 1-D tensor containing the unique elements of the input data tensor. + unique : tvm.te.Tensor + A 1-D tensor containing the unique elements of the input data tensor. The same size as + the input data. If there are fewer unique elements than input data, the end of the tensor + is padded with zeros. indices : tvm.te.Tensor - A 1-D tensor containing the index of each data element in the output tensor. + A 1-D tensor. The same size as output. For each entry in output, it contains + the index of its first occurrence in the input data. The end of the tensor is padded + with the length of the input data. + + inverse_indices : tvm.te.Tensor + A 1-D tensor. For each entry in data, it contains the index of that data element in + the unique array. (Note that inverse_indices is very similar to indices if output is not + sorted.)
num_unique : tvm.te.Tensor A 1-D tensor with size=1 containing the number of unique elements in the input data tensor. @@ -235,20 +244,23 @@ def unique(data, is_sorted=True, return_counts=False): -------- .. code-block:: python [output, indices, num_unique] = unique([4, 5, 1, 2, 3, 3, 4, 5], False, False) - output = [4, 5, 1, 2, 3, ?, ?, ?] - indices = [0, 1, 2, 3, 4, 4, 0, 1] - num_unique = [5] + output = [4, 5, 1, 2, 3, ?, ?, ?] + indices = [0, 1, 2, 3, 4, ?, ?, ?] + inverse_indices = [0, 1, 2, 3, 4, 4, 0, 1] + num_unique = [5] [output, indices, num_unique, counts] = unique([4, 5, 1, 2, 3, 3, 4, 5], False, True) - output = [4, 5, 1, 2, 3, ?, ?, ?] - indices = [0, 1, 2, 3, 4, 4, 0, 1] - num_unique = [5] - counts = [2, 2, 1, 1, 2, ?, ?, ?] + output = [4, 5, 1, 2, 3, ?, ?, ?] + indices = [0, 1, 2, 3, 4, ?, ?, ?] + inverse_indices = [0, 1, 2, 3, 4, 4, 0, 1] + num_unique = [5] + counts = [2, 2, 1, 1, 2, ?, ?, ?] [output, indices, num_unique] = unique([4, 5, 1, 2, 3, 3, 4, 5], True) - output = [1, 2, 3, 4, 5, ?, ?, ?] - indices = [3, 4, 0, 1, 2, 2, 3, 4] - num_unique = [5] + output = [1, 2, 3, 4, 5, ?, ?, ?] + indices = [2, 3, 4, 0, 1, ?, ?, ?] 
+ inverse_indices = [3, 4, 0, 1, 2, 2, 3, 4] + num_unique = [5] """ sorted_data = sort(data) argsorted_indices = argsort(data, dtype="int32") @@ -266,16 +278,17 @@ def unique(data, is_sorted=True, return_counts=False): out_data_shape = [data.shape] * 2 out_dtypes = [data.dtype, "int32"] # prepare inputs and fcompute + + first_occurence = _calc_first_occurence(argsorted_indices, inc_scan) if is_sorted: in_data = [data, argsorted_indices, inc_scan] if return_counts: fcompute = lambda ins, outs: _calc_unique_ir(*ins, None, *outs) else: fcompute = lambda ins, outs: _calc_unique_ir(*ins, None, *outs, None) + + indices = first_occurence else: - # calculate the index converter if the unique elements should not be sorted - # calculate first occurence - first_occurence = _calc_first_occurence(argsorted_indices, inc_scan) # calculate index converter by sorting unique elements by their first occurence argsorted_first_occurence = argsort(first_occurence, dtype="int32") index_converter = argsort(argsorted_first_occurence, dtype="int32") @@ -284,6 +297,10 @@ def unique(data, is_sorted=True, return_counts=False): fcompute = lambda ins, outs: _calc_unique_ir(*ins, *outs) else: fcompute = lambda ins, outs: _calc_unique_ir(*ins, *outs, None) + # First occurrence is in order of sorted unique output; if we sort the first_occurence array + # we get indices ordered by first appearance, matching the unsorted unique output + indices = sort(first_occurence) + outs = te.extern( out_data_shape, in_data, @@ -293,5 +310,5 @@ def unique(data, is_sorted=True, return_counts=False): tag="_calc_unique_cpu", ) if return_counts: - return [outs[0], outs[1], num_unique_elements, outs[2]] - return [*outs, num_unique_elements] + return [outs[0], indices, outs[1], num_unique_elements, outs[2]] + return [outs[0], indices, outs[1], num_unique_elements] diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 10fe5e543dfc..d6caefbb4e2c 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -3976,10
+3976,11 @@ bool UniqueRel(const Array& types, int num_inputs, const Attrs& attrs, } const int ndim = static_cast(data->shape.size()); ICHECK_EQ(ndim, 1) << "Unique: input must be 1-D tensor"; - ICHECK_EQ(data->dtype.is_int(), true) << "Unique: input must have int32 or int64 dtype"; + std::vector fields; fields.push_back(TensorType(data->shape, data->dtype)); // unique fields.push_back(TensorType(data->shape, DataType::Int(32))); // indices + fields.push_back(TensorType(data->shape, DataType::Int(32))); // inverse_indices fields.push_back(TensorType(Array{1}, DataType::Int(32))); // num_unique const auto* param = attrs.as(); if (param->return_counts) { diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index d1ecfc5559a4..4ac7ff2a81f3 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -4277,11 +4277,9 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"): "test_tfidfvectorizer_tf_onlybigrams_levelempty/", "test_tfidfvectorizer_tf_onlybigrams_skip5/", "test_tfidfvectorizer_tf_uniandbigrams_skip5/", - "test_unique_not_sorted_without_axis/", "test_unique_sorted_with_axis/", "test_unique_sorted_with_axis_3d/", "test_unique_sorted_with_negative_axis/", - "test_unique_sorted_without_axis/", "test_upsample_nearest/", ] diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 07955943e341..fc67f0b90295 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -1959,7 +1959,14 @@ def calc_numpy_unique(data, is_sorted=False): uniq = uniq[order].astype(data.dtype) inverse = np.array([reverse_order[i] for i in inverse]).astype("int32") counts = counts[order].astype("int32") - return [uniq.astype(data.dtype), inverse.astype("int32"), counts, num_uniq] + index = np.sort(index) # In unsorted case, need to sort the index of first occurrence + return [ + uniq.astype(data.dtype), +
index.astype("int32"), + inverse.astype("int32"), + num_uniq, + counts, + ] def verify_unique(n, dtype, is_dyn=False, is_sorted=False, return_counts=False): if is_dyn: @@ -1980,18 +1987,26 @@ def verify_unique(n, dtype, is_dyn=False, is_sorted=False, return_counts=False): for kind in backends: mod = tvm.ir.IRModule.from_expr(func) intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - tvm_res = intrp.evaluate()(x_data) - np_res = calc_numpy_unique(x_data, is_sorted) + tvm_res = intrp.evaluate()( + x_data + ) # unique, indices, inverse_indices, num_unique, (counts) + np_res = calc_numpy_unique( + x_data, is_sorted + ) # unique, indices, inverse_indices, num_unique, counts num_unique = np_res[3][0] - assert num_unique == tvm_res[2].numpy()[0] + + # num_unique + assert num_unique == tvm_res[3].numpy()[0] # unique tvm.testing.assert_allclose(tvm_res[0].numpy()[:num_unique], np_res[0], rtol=1e-5) + # indices + tvm.testing.assert_allclose(tvm_res[1].numpy()[:num_unique], np_res[1], rtol=1e-5) # inverse_indices - tvm.testing.assert_allclose(tvm_res[1].numpy(), np_res[1], rtol=1e-5) + tvm.testing.assert_allclose(tvm_res[2].numpy(), np_res[2], rtol=1e-5) # counts if return_counts: tvm.testing.assert_allclose( - tvm_res[3].numpy()[:num_unique], np_res[2], rtol=1e-5 + tvm_res[4].numpy()[:num_unique], np_res[4], rtol=1e-5 ) for dtype in ["int32", "int64"]: diff --git a/tests/python/topi/python/test_topi_unique.py b/tests/python/topi/python/test_topi_unique.py index 032b4db73918..3e26241cea94 100644 --- a/tests/python/topi/python/test_topi_unique.py +++ b/tests/python/topi/python/test_topi_unique.py @@ -30,15 +30,24 @@ def calc_numpy_unique(data, is_sorted=False): num_uniq = np.array([len(uniq)]).astype("int32") if not is_sorted: order = np.argsort(index) + index = np.sort(index) reverse_order = np.argsort(order) uniq = uniq[order].astype(data.dtype) inverse = np.array([reverse_order[i] for i in inverse]).astype("int32") counts = counts[order].astype("int32") 
- return [uniq.astype(data.dtype), inverse.astype("int32"), counts, num_uniq] + return [ + uniq.astype(data.dtype), + index.astype("int32"), + inverse.astype("int32"), + counts, + num_uniq, + ] - def check_unique(data, is_sorted=False): + def check_unique(data, is_sorted=False, with_counts=False): # numpy reference - np_unique, np_indices, np_counts, np_num_unique = calc_numpy_unique(data, is_sorted) + np_unique, np_indices, np_inverse_indices, np_counts, np_num_unique = calc_numpy_unique( + data, is_sorted + ) num_unique = np_num_unique[0] implementations = { @@ -59,44 +68,54 @@ def check_unique(data, is_sorted=False): tvm_data = tvm.nd.array(data, device=dev) tvm_unique = tvm.nd.array(np.zeros(data.shape).astype(data.dtype), device=dev) tvm_indices = tvm.nd.array(np.zeros(data.shape).astype("int32"), device=dev) + tvm_inverse_indices = tvm.nd.array(np.zeros(data.shape).astype("int32"), device=dev) tvm_num_unique = tvm.nd.array(np.zeros([1]).astype("int32"), device=dev) - # without counts with tvm.target.Target(target): te_input = tvm.te.placeholder(shape=data.shape, dtype=str(data.dtype)) - outs = fcompute(te_input, False) + outs = fcompute(te_input, with_counts) s = fschedule(outs) func = tvm.build(s, [te_input, *outs]) - func(tvm_data, tvm_unique, tvm_indices, tvm_num_unique) - assert tvm_num_unique.numpy()[0] == np_num_unique - np.testing.assert_allclose(tvm_unique.numpy()[:num_unique], np_unique, atol=1e-5, rtol=1e-5) - np.testing.assert_allclose(tvm_indices.numpy(), np_indices, atol=1e-5, rtol=1e-5) + if with_counts: + tvm_counts = tvm.nd.array(np.zeros(data.shape).astype("int32"), device=dev) + func( + tvm_data, + tvm_unique, + tvm_indices, + tvm_inverse_indices, + tvm_num_unique, + tvm_counts, + ) + else: + func(tvm_data, tvm_unique, tvm_indices, tvm_inverse_indices, tvm_num_unique) - # with counts - tvm_counts = tvm.nd.array(np.zeros(data.shape).astype("int32"), device=dev) - with tvm.target.Target(target): - te_input = 
tvm.te.placeholder(shape=data.shape, dtype=str(data.dtype)) - outs = fcompute(te_input, True) - s = fschedule(outs) - func = tvm.build(s, [te_input, *outs]) - func(tvm_data, tvm_unique, tvm_indices, tvm_num_unique, tvm_counts) - - np_unique, np_indices, _, np_num_unique = calc_numpy_unique(data, is_sorted) num_unique = np_num_unique[0] assert tvm_num_unique.numpy()[0] == np_num_unique + np.testing.assert_allclose(tvm_unique.numpy()[:num_unique], np_unique, atol=1e-5, rtol=1e-5) - np.testing.assert_allclose(tvm_indices.numpy(), np_indices, atol=1e-5, rtol=1e-5) - np.testing.assert_allclose(tvm_counts.numpy()[:num_unique], np_counts, atol=1e-5, rtol=1e-5) + np.testing.assert_allclose( + tvm_indices.numpy()[:num_unique], np_indices, atol=1e-5, rtol=1e-5 + ) + + np.testing.assert_allclose( + tvm_inverse_indices.numpy(), np_inverse_indices, atol=1e-5, rtol=1e-5 + ) + + if with_counts: + np.testing.assert_allclose( + tvm_counts.numpy()[:num_unique], np_counts, atol=1e-5, rtol=1e-5 + ) for in_dtype in ["int32", "int64"]: for is_sorted in [True, False]: - data = np.random.randint(0, 100, size=(1)).astype(in_dtype) - check_unique(data, is_sorted) - data = np.random.randint(0, 10, size=(10)).astype(in_dtype) - check_unique(data, is_sorted) - data = np.random.randint(0, 100, size=(10000)).astype(in_dtype) - check_unique(data, is_sorted) + for with_counts in [True, False]: + data = np.random.randint(0, 100, size=(1)).astype(in_dtype) + check_unique(data, is_sorted, with_counts) + data = np.random.randint(0, 10, size=(10)).astype(in_dtype) + check_unique(data, is_sorted, with_counts) + data = np.random.randint(0, 100, size=(10000)).astype(in_dtype) + check_unique(data, is_sorted, with_counts) if __name__ == "__main__":