Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relax] Add NonZero op #17453

Merged
merged 1 commit into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion python/tvm/relax/frontend/onnx/onnx_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2394,6 +2394,14 @@ def _impl_v11(cls, bb, inputs, attr, params):
return relax.op.unique(data, sorted=sorted, axis=axis)


class NonZero(OnnxOpConverter):
    """Converts an onnx NonZero node into an equivalent Relax expression.

    ONNX NonZero returns the indices of the elements that are non-zero,
    which maps directly onto the relax.op.nonzero operator.
    """

    @classmethod
    def _impl_v9(cls, bb, inputs, attr, params):
        data = inputs[0]
        return relax.op.nonzero(data)


class HardSigmoid(OnnxOpConverter):
"""Converts an onnx HardSigmoid node into an equivalent Relax expression."""

Expand Down Expand Up @@ -2779,7 +2787,7 @@ def _get_convert_map():
"Range": Range,
"OneHot": OneHot,
"Unique": Unique,
# "NonZero": NonZero,
"NonZero": NonZero,
# "If": If,
# "LRN": LRN,
# "MaxRoiPool": MaxRoiPool,
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relax/op/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@
from .qdq import dequantize, quantize
from .sampling import multinomial_from_uniform
from .search import argmax, argmin, where
from .set import unique
from .set import nonzero, unique
from .sorting import argsort, sort, topk
from .statistical import cumprod, cumsum, max, mean, min, prod, std, sum, variance
from .ternary import ewise_fma
Expand Down
37 changes: 37 additions & 0 deletions python/tvm/relax/op/set.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,40 @@ def numpy_unique(
return tvm.nd.array(output_sorted_numpy)
output_numpy = np.take(x_numpy, builtins.sorted(indices), axis=axis)
return tvm.nd.array(output_numpy)


def nonzero(x: Expr) -> Expr:
    """Find the indices of elements of a tensor that are non-zero.

    Parameters
    ----------
    x : relax.Expr
        The input data tensor.

    Returns
    -------
    result : relax.Expr
        A tensor containing the indices of the non-zero elements, with one
        row of indices per input dimension.

    Note
    ----
    This function is equivalent to `onnx.nonzero`.

    Examples
    --------

    .. code-block:: python

        x = [[0, 1],
             [2, 0]]
        nonzero(x) = [[0, 1],
                      [1, 0]]

    """
    # Dispatch to the C++ implementation registered as "relax.op.nonzero".
    result = _ffi_api.nonzero(x)  # type: ignore
    return result


@tvm.register_func("relax.run.nonzero")
def numpy_nonzero(x: tvm.nd.array) -> tvm.nd.array:
    """Reference runtime implementation of relax.nonzero backed by numpy.

    Scalar (0-d) inputs are promoted to 1-D before the lookup; the
    per-dimension index arrays from numpy are stacked into one result.
    """
    arr = np.atleast_1d(x.numpy())
    index_arrays = arr.nonzero()
    stacked = np.stack(index_arrays, axis=0)
    return tvm.nd.array(stacked)
23 changes: 23 additions & 0 deletions src/relax/op/tensor/set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

#include "set.h"

#include <algorithm>
#include <utility>
#include <vector>

Expand Down Expand Up @@ -137,5 +138,27 @@ TVM_REGISTER_OP("relax.unique")
.set_attr<FCallPacked>("FCallPacked", "relax.run.unique")
.set_attr<Bool>("FPurity", Bool(true));

/* relax.nonzero */
// Build a call node invoking the "relax.nonzero" operator on the input.
Expr nonzero(Expr x) {
  static const Op& op = Op::Get("relax.nonzero");
  return Call(op, {std::move(x)});
}

TVM_REGISTER_GLOBAL("relax.op.nonzero").set_body_typed(nonzero);

// Struct-info inference for relax.nonzero: the result is always an int64
// tensor; its declared rank is input rank + 1 (unknown rank propagates).
StructInfo InferStructInfoNonzero(const Call& call, const BlockBuilder& ctx) {
  TensorStructInfo data_sinfo = GetInputTensorStructInfo(call, 0, ctx);
  // Cheat zero dim scalar as 1-dim.
  // NOTE(review): the registered runtime ("relax.run.nonzero") stacks the
  // per-dimension index arrays into a 2-D result regardless of input rank,
  // while this declares rank = ndim + 1 — confirm the intended contract for
  // inputs with ndim > 1.
  int dim = data_sinfo->IsUnknownNdim() ? kUnknownNDim : std::max(1, data_sinfo->ndim) + 1;
  return TensorStructInfo(DataType::Int(64), dim, data_sinfo->vdevice);
}

// Operator registration for relax.nonzero.
TVM_REGISTER_OP("relax.nonzero")
    .set_num_inputs(1)
    .add_argument("x", "Tensor", "The input tensor")
    // Shape/dtype inference hook (int64 output, rank derived from input).
    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoNonzero)
    // Evaluated at runtime through the packed func "relax.run.nonzero".
    .set_attr<FCallPacked>("FCallPacked", "relax.run.nonzero")
    // Pure: no side effects, so calls may be deduplicated or reordered.
    .set_attr<Bool>("FPurity", Bool(true));

} // namespace relax
} // namespace tvm
28 changes: 28 additions & 0 deletions src/relax/op/tensor/set.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,36 @@
namespace tvm {
namespace relax {

/*!
* \brief Find the unique elements in a given tensor.
* In addition, it optionally returns
* - the indices of the input tensor that give the unique values;
* - the indices of the unique tensor that reconstruct the input tensor;
* - the number of times each unique value comes up in the input tensor.
* \param x The input tensor.
* \param sorted Whether to sort the unique elements in ascending order before
* returning as output.
* \param return_index Whether to return an additional tensor with indices for where elements in
* the unique tensor come from the original input.
* \param return_inverse Whether to return an additional tensor with indices for where elements in
* the original input ended up in the returned unique list.
* \param return_counts Whether to return an additional tensor with counts of each unique elements.
* \param axis The dimension to apply unique.
* If not specified, the unique values of the flattened input are returned.
* \return The unique elements of the array. The returned array will be sorted if `sorted` is True.
* Additional return values depend on `return_index`, `return_inverse`, and `return_counts`.
*/
Expr unique(Expr x, PrimValue sorted, PrimValue return_index, PrimValue return_inverse,
PrimValue return_counts, Optional<PrimValue> axis);

/*!
* \brief Returns the indices of the non-zero elements of the input tensor.
* \param x The input tensor.
 * \return A tensor containing the indices of the non-zero elements of the input,
 * with one row of indices per input dimension.
 * \note This function behaves similarly to numpy.nonzero(), but returns a multi-dimensional
 * array instead of a tuple of 1-D arrays.
*/
Expr nonzero(Expr x);
} // namespace relax
} // namespace tvm

Expand Down
5 changes: 5 additions & 0 deletions tests/python/relax/test_frontend_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2126,6 +2126,11 @@ def test_unique(axis: Optional[int], sorted: int):
check_correctness(model)


# NonZero consumes a bool tensor and emits int64 indices of the True elements;
# shapes cover scalar, 1-D, 2-D and 3-D inputs.
@pytest.mark.parametrize("shape", [(), (1,), (2, 3), (4, 5, 6)])
def test_nonzero(shape):
    verify_unary("NonZero", shape, input_dtype=TensorProto.BOOL, output_dtype=TensorProto.INT64)


@pytest.mark.parametrize("mode", ["DCR", "CRD"])
def test_depth_to_space(mode: Literal["DCR", "CRD"]):
in_shape = [1, 8, 2, 3]
Expand Down
34 changes: 34 additions & 0 deletions tests/python/relax/test_op_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -867,5 +867,39 @@ def test_unique_infer_struct_info_wrong_input_dtype():
bb.normalize(relax.op.unique(x1))


@pytest.mark.parametrize("shape", [(1,), (2, 3), (4, 5, 6)])
def test_nonzero_infer_struct_info(shape):
    """Inferred struct info of nonzero is an int64 tensor of rank ndim + 1."""
    bb = relax.BlockBuilder()
    data = relax.Var("x", R.Tensor(shape, "bool"))
    expected = relax.TensorStructInfo(ndim=len(shape) + 1, dtype="int64")

    _check_inference(bb, relax.op.nonzero(data), expected)


def test_nonzero_infer_struct_info_ndim_zero():
    """A 0-d (scalar) input is treated as 1-D, so the result rank is 2."""
    bb = relax.BlockBuilder()
    scalar = relax.Var("x", R.Tensor((), "bool"))
    expected = relax.TensorStructInfo(ndim=2, dtype="int64")

    _check_inference(bb, relax.op.nonzero(scalar), expected)


def test_nonzero_infer_struct_info_wrong_input_dtype():
    """Normalization rejects nonzero applied to non-tensor struct infos."""
    bb = relax.BlockBuilder()
    shape_input = relax.Var("x", relax.ShapeStructInfo((2, 3, 4)))
    func_input = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4), "float32")))

    for bad_input in (shape_input, func_input):
        with pytest.raises(TVMError):
            bb.normalize(relax.op.nonzero(bad_input))


if __name__ == "__main__":
tvm.testing.main()
Loading