From 4721a46feb9f7523f1be24aed7a6e2348941ee07 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Tue, 8 Oct 2024 16:41:22 +0800 Subject: [PATCH] [Relax] Add NonZero op this PR adds the NonZero op to Relax, together with ONNX frontend support --- .../tvm/relax/frontend/onnx/onnx_frontend.py | 10 ++++- python/tvm/relax/op/__init__.py | 2 +- python/tvm/relax/op/set.py | 37 +++++++++++++++++++ src/relax/op/tensor/set.cc | 23 ++++++++++++ src/relax/op/tensor/set.h | 28 ++++++++++++++ tests/python/relax/test_frontend_onnx.py | 5 +++ tests/python/relax/test_op_set.py | 34 +++++++++++++++++ 7 files changed, 137 insertions(+), 2 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 36a7823f8655..dcdfa413f2bb 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -2394,6 +2394,14 @@ def _impl_v11(cls, bb, inputs, attr, params): return relax.op.unique(data, sorted=sorted, axis=axis) +class NonZero(OnnxOpConverter): + """Converts an onnx NonZero node into an equivalent Relax expression.""" + + @classmethod + def _impl_v9(cls, bb, inputs, attr, params): + return relax.op.nonzero(inputs[0]) + + class HardSigmoid(OnnxOpConverter): """Converts an onnx HardSigmoid node into an equivalent Relax expression.""" @@ -2779,7 +2787,7 @@ def _get_convert_map(): "Range": Range, "OneHot": OneHot, "Unique": Unique, - # "NonZero": NonZero, + "NonZero": NonZero, # "If": If, # "LRN": LRN, # "MaxRoiPool": MaxRoiPool, diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index 4581defa1a77..58d9fa08f043 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -99,7 +99,7 @@ from .qdq import dequantize, quantize from .sampling import multinomial_from_uniform from .search import argmax, argmin, where -from .set import unique +from .set import nonzero, unique from .sorting import argsort, sort, topk from 
.statistical import cumprod, cumsum, max, mean, min, prod, std, sum, variance from .ternary import ewise_fma diff --git a/python/tvm/relax/op/set.py b/python/tvm/relax/op/set.py index 0b86e19ce53f..c5db852ddd5d 100644 --- a/python/tvm/relax/op/set.py +++ b/python/tvm/relax/op/set.py @@ -110,3 +110,40 @@ def numpy_unique( return tvm.nd.array(output_sorted_numpy) output_numpy = np.take(x_numpy, builtins.sorted(indices), axis=axis) return tvm.nd.array(output_numpy) + + +def nonzero(x: Expr) -> Expr: + """Find the indices of elements of a tensor that are non-zero. + + Parameters + ---------- + x : relax.Expr + The input data tensor. + + Returns + ------- + result : relax.Expr + A (n+1)-D tensor containing indices of non-zero elements. + + Note + ---- + This function is equivalent to `onnx.nonzero`. + + Examples + -------- + + .. code-block:: python + + x = [[0, 1], + [2, 0]] + nonzero(x) = [[0, 1], + [1, 0]] + + """ + return _ffi_api.nonzero(x) # type: ignore + + +@tvm.register_func("relax.run.nonzero") +def numpy_nonzero(x: tvm.nd.array) -> tvm.nd.array: + np_result = np.atleast_1d(x.numpy()).nonzero() + return tvm.nd.array(np.stack(np_result, axis=0)) diff --git a/src/relax/op/tensor/set.cc b/src/relax/op/tensor/set.cc index 29d9d52c6077..c659a49afd12 100644 --- a/src/relax/op/tensor/set.cc +++ b/src/relax/op/tensor/set.cc @@ -24,6 +24,7 @@ #include "set.h" +#include #include #include @@ -137,5 +138,27 @@ TVM_REGISTER_OP("relax.unique") .set_attr("FCallPacked", "relax.run.unique") .set_attr("FPurity", Bool(true)); +/* relax.nonzero */ +Expr nonzero(Expr x) { + static const Op& op = Op::Get("relax.nonzero"); + return Call(op, {std::move(x)}); +} + +TVM_REGISTER_GLOBAL("relax.op.nonzero").set_body_typed(nonzero); + +StructInfo InferStructInfoNonzero(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo data_sinfo = GetInputTensorStructInfo(call, 0, ctx); + // Cheat zero dim scalar as 1-dim. + int dim = data_sinfo->IsUnknownNdim() ? 
kUnknownNDim : std::max(1, data_sinfo->ndim) + 1; + return TensorStructInfo(DataType::Int(64), dim, data_sinfo->vdevice); +} + +TVM_REGISTER_OP("relax.nonzero") + .set_num_inputs(1) + .add_argument("x", "Tensor", "The input tensor") + .set_attr("FInferStructInfo", InferStructInfoNonzero) + .set_attr("FCallPacked", "relax.run.nonzero") + .set_attr("FPurity", Bool(true)); + } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/set.h b/src/relax/op/tensor/set.h index a5c7ee85bfb2..251dd1975e9f 100644 --- a/src/relax/op/tensor/set.h +++ b/src/relax/op/tensor/set.h @@ -29,8 +29,36 @@ namespace tvm { namespace relax { +/*! + * \brief Find the unique elements in a given tensor. + * In addition, it optionally returns + * - the indices of the input tensor that give the unique values; + * - the indices of the unique tensor that reconstruct the input tensor; + * - the number of times each unique value comes up in the input tensor. + * \param x The input tensor. + * \param sorted Whether to sort the unique elements in ascending order before + * returning as output. + * \param return_index Whether to return an additional tensor with indices for where elements in + * the unique tensor come from the original input. + * \param return_inverse Whether to return an additional tensor with indices for where elements in + * the original input ended up in the returned unique list. + * \param return_counts Whether to return an additional tensor with counts of each unique elements. + * \param axis The dimension to apply unique. + * If not specified, the unique values of the flattened input are returned. + * \return The unique elements of the array. The returned array will be sorted if `sorted` is True. + * Additional return values depend on `return_index`, `return_inverse`, and `return_counts`. + */ Expr unique(Expr x, PrimValue sorted, PrimValue return_index, PrimValue return_inverse, PrimValue return_counts, Optional axis); + +/*! 
+ * \brief Returns the indices of the non-zero elements of the input tensor. + * \param x The input tensor. + * \return A tensor containing the indices of the non-zero elements, with one row of indices per input dimension. + * \note This function behaves similarly to numpy.nonzero(), but returns a multi-dimensional array + * instead of a tuple of 1-D arrays. + */ +Expr nonzero(Expr x); } // namespace relax } // namespace tvm diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index f2bbd3f3f585..cf107a63b4d8 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -2126,6 +2126,11 @@ def test_unique(axis: Optional[int], sorted: int): check_correctness(model) +@pytest.mark.parametrize("shape", [(), (1,), (2, 3), (4, 5, 6)]) +def test_nonzero(shape): + verify_unary("NonZero", shape, input_dtype=TensorProto.BOOL, output_dtype=TensorProto.INT64) + + @pytest.mark.parametrize("mode", ["DCR", "CRD"]) def test_depth_to_space(mode: Literal["DCR", "CRD"]): in_shape = [1, 8, 2, 3]
relax.Var("x", relax.ShapeStructInfo((2, 3, 4))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nonzero(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nonzero(x1)) + + if __name__ == "__main__": tvm.testing.main()