Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ONNX][TOPI] Support select_last_index for argmin/max #8816

Merged
merged 53 commits into from
Aug 31, 2021
Merged
Show file tree
Hide file tree
Changes from 50 commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
53d5ada
support select_last_index for argmin/max
AndrewZhaoLuo Aug 23, 2021
95a6517
reverse conditions which made on accident
AndrewZhaoLuo Aug 23, 2021
5e1f06a
forward args in reduce.py
AndrewZhaoLuo Aug 23, 2021
962e38a
make proper nodes for reduction ops
AndrewZhaoLuo Aug 23, 2021
f92089b
remove complicated nested lambdas
AndrewZhaoLuo Aug 23, 2021
9edc8e6
fix lambda capture for conversion
AndrewZhaoLuo Aug 23, 2021
9e4e69a
forward more arguments
AndrewZhaoLuo Aug 23, 2021
5cf4772
forward more args
AndrewZhaoLuo Aug 23, 2021
4f5a662
enable onnx tests
AndrewZhaoLuo Aug 23, 2021
75cb608
wrapping casts to remove ambiguity
Aug 24, 2021
7a60353
revert changes extraneous
AndrewZhaoLuo Aug 24, 2021
6a9e82f
correct incorrect attrs being used for ops
AndrewZhaoLuo Aug 24, 2021
0fb5db5
change attributes
AndrewZhaoLuo Aug 24, 2021
55a412d
remove old impl
Aug 24, 2021
93173fc
register new attribute node
AndrewZhaoLuo Aug 24, 2021
47b7eed
clean up test
AndrewZhaoLuo Aug 24, 2021
e62513b
reformat
AndrewZhaoLuo Aug 24, 2021
e9ea784
reformat
AndrewZhaoLuo Aug 24, 2021
587e94a
coolio
AndrewZhaoLuo Aug 24, 2021
d048e25
stable comparison
AndrewZhaoLuo Aug 24, 2021
71ab1f3
casts to avoid ambiguity
AndrewZhaoLuo Aug 24, 2021
aecf630
casting more
AndrewZhaoLuo Aug 24, 2021
423d092
correct arg passing
AndrewZhaoLuo Aug 26, 2021
2faf06d
support select_last_index for argmin/max
AndrewZhaoLuo Aug 23, 2021
edbc0f1
reverse conditions which made on accident
AndrewZhaoLuo Aug 23, 2021
ba7f57c
forward args in reduce.py
AndrewZhaoLuo Aug 23, 2021
dbf6dc1
make proper nodes for reduction ops
AndrewZhaoLuo Aug 23, 2021
fa4dd43
remove complicated nested lambdas
AndrewZhaoLuo Aug 23, 2021
78cc734
fix lambda capture for conversion
AndrewZhaoLuo Aug 23, 2021
0979f4d
forward more arguments
AndrewZhaoLuo Aug 23, 2021
647413e
forward more args
AndrewZhaoLuo Aug 23, 2021
f694e58
enable onnx tests
AndrewZhaoLuo Aug 23, 2021
576c56b
wrapping casts to remove ambiguity
Aug 24, 2021
67b5762
revert changes extraneous
AndrewZhaoLuo Aug 24, 2021
6d59d1c
correct incorrect attrs being used for ops
AndrewZhaoLuo Aug 24, 2021
d7a595f
change attributes
AndrewZhaoLuo Aug 24, 2021
6b645de
remove old impl
Aug 24, 2021
0faf5b6
register new attribute node
AndrewZhaoLuo Aug 24, 2021
96d85c2
clean up test
AndrewZhaoLuo Aug 24, 2021
8a6a4bc
reformat
AndrewZhaoLuo Aug 24, 2021
29a2660
reformat
AndrewZhaoLuo Aug 24, 2021
3a2a38d
coolio
AndrewZhaoLuo Aug 24, 2021
296ac2e
stable comparison
AndrewZhaoLuo Aug 24, 2021
12f7213
casts to avoid ambiguity
AndrewZhaoLuo Aug 24, 2021
20cdd36
casting more
AndrewZhaoLuo Aug 24, 2021
49b6322
correct arg passing
AndrewZhaoLuo Aug 26, 2021
fcc420e
Merge branch 'aluo/onnx/argmin_and_argmax' of github.com:AndrewZhaoLu…
AndrewZhaoLuo Aug 27, 2021
8f37f89
fix broken input
AndrewZhaoLuo Aug 27, 2021
2db29ca
OneElementReduceAttrs-->ArgReduceAttrs"
Aug 30, 2021
4055190
reduce boilerplate
Aug 30, 2021
1f56147
change names
Aug 30, 2021
d4cbfcc
remove log statement
Aug 30, 2021
c5f308b
jostle ci
AndrewZhaoLuo Aug 31, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions include/tvm/relay/attrs/reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,42 @@ struct ReduceAttrs : public tvm::AttrsNode<ReduceAttrs> {
}
};

/*! \brief Attributes for Reduce operators which reduce by finding a single element. E.g. argmin */
struct ArgReduceAttrs : public tvm::AttrsNode<ArgReduceAttrs> {
Array<Integer> axis;
bool keepdims;
bool select_last_index;
bool exclude;

TVM_DECLARE_ATTRS(ArgReduceAttrs, "relay.attrs.ArgReduceAttrs") {
TVM_ATTR_FIELD(axis)
.set_default(NullValue<Array<Integer>>())
.describe(R"code(The axis or axes along which to perform the reduction.

The default, `axis=()`, will compute over all elements into a
scalar array with shape `(1,)`.

If `axis` is int, a reduction is performed on a particular axis.

If `axis` is a tuple of ints, a reduction is performed on all the axes
specified in the tuple.

If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.)code");

TVM_ATTR_FIELD(keepdims).set_default(false).describe(
"If this is set to `True`, the reduced axes are left "
"in the result as dimension with size one.");
TVM_ATTR_FIELD(select_last_index)
.set_default(false)
.describe(
"Whether to select the last index if the target element appears multiple times, else "
"select the first index which the target element appears");
TVM_ATTR_FIELD(exclude).set_default(false).describe(
"Whether to perform reduction on axis that are NOT in axis instead.");
}
};

struct VarianceAttrs : public tvm::AttrsNode<VarianceAttrs> {
Array<Integer> axis;
bool keepdims;
Expand Down
100 changes: 77 additions & 23 deletions include/tvm/topi/reduction.h
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,45 @@ inline Tensor max(const Tensor& data, const Array<Integer>& axis, bool keepdims
return CommReduce(data, axis, MaxOp, keepdims, atleast1d);
}

inline FCommReduce MakeArgminReducer(bool select_last_index = false) {
// Create a Commutative Reducer with a comparison operation, and method to get the initial value.
auto fcombine = [=](Array<Var> lhs, Array<Var> rhs) {
Array<PrimExpr> result;

// Casting to avoid operator ambiguity
PrimExpr lhs_idx = static_cast<PrimExpr>(lhs[0]);
PrimExpr rhs_idx = static_cast<PrimExpr>(rhs[0]);
PrimExpr lhs_val = static_cast<PrimExpr>(lhs[1]);
PrimExpr rhs_val = static_cast<PrimExpr>(rhs[1]);

// These variables compare the actual values of the array
auto is_smaller = lhs_val < rhs_val;
auto is_same = lhs_val == rhs_val;

// This checks if the indices are correct for the reduction. E.g. for select_last_index
// it gives precedence for later indices of the same element and precedence for sooner
// indices if not select_last_index;
PrimExpr proper_index;
if (select_last_index) {
proper_index = lhs_idx > rhs_idx;
} else {
proper_index = lhs_idx < rhs_idx;
}

PrimExpr update_index = is_smaller || (is_same && proper_index);
result.push_back(tvm::tir::Select(update_index, lhs[0], rhs[0])); // idx
result.push_back(tvm::tir::Select(is_smaller, lhs[1], rhs[1])); // val
return result;
};
auto fidentity = [&](std::vector<DataType> types) {
Array<PrimExpr> result;
result.push_back(tvm::tir::make_const(types[0], -1)); // idx
result.push_back(tvm::max_value(types[1])); // val
return result;
};
return MakeCommReducer(fcombine, fidentity, "argmin");
}

/*!
* \brief Creates an operation that finds the indices of the minimum
* values over a given axis.
Expand All @@ -442,35 +481,49 @@ inline Tensor max(const Tensor& data, const Array<Integer>& axis, bool keepdims
* left in the result as dimensions with size one. This enables the result
* to broadcast correctly against the input array.
* \param atleast1d Whether the output need to be atleast1d.
* \param select_last_index Whether to select the last index if the minimum element
* appears multiple times, else select the first index.
*
* \return A Tensor whose op member is the argmin operation
*/
inline Tensor argmin(const Tensor& data, const Array<Integer>& axis, bool keepdims = false,
bool atleast1d = false) {
auto fcombine = [](Array<Var> lhs, Array<Var> rhs) {
Array<PrimExpr> result;
result.push_back(tvm::tir::Select(lhs[1] <= rhs[1], lhs[0], rhs[0])); // idx
result.push_back(tvm::tir::Select(lhs[1] <= rhs[1], lhs[1], rhs[1])); // val
return result;
};
auto fidentity = [](std::vector<DataType> types) {
Array<PrimExpr> result;
result.push_back(tvm::tir::make_const(types[0], -1)); // idx
result.push_back(tvm::max_value(types[1])); // val
return result;
};
auto func = MakeCommReducer(fcombine, fidentity, "argmin");
return CommReduceIdx(data, axis, func, keepdims, atleast1d);
bool atleast1d = false, bool select_last_index = false) {
auto reducer = MakeArgminReducer(select_last_index);
return CommReduceIdx(data, axis, reducer, keepdims, atleast1d);
}

inline FCommReduce MakeArgmaxReducer() {
auto fcombine = [](Array<Var> lhs, Array<Var> rhs) {
inline FCommReduce MakeArgmaxReducer(bool select_last_index = false) {
// Create a Commutative Reducer with a comparison operation, and method to get the initial value.
auto fcombine = [=](Array<Var> lhs, Array<Var> rhs) {
Array<PrimExpr> result;
result.push_back(tvm::tir::Select(lhs[1] >= rhs[1], lhs[0], rhs[0])); // idx
result.push_back(tvm::tir::Select(lhs[1] >= rhs[1], lhs[1], rhs[1])); // val

// Casting to avoid operator ambiguity
PrimExpr lhs_idx = static_cast<PrimExpr>(lhs[0]);
PrimExpr rhs_idx = static_cast<PrimExpr>(rhs[0]);
PrimExpr lhs_val = static_cast<PrimExpr>(lhs[1]);
PrimExpr rhs_val = static_cast<PrimExpr>(rhs[1]);

// These variables compare the actual values of the array
auto is_bigger = lhs_val > rhs_val;
auto is_same = lhs_val == rhs_val;

// This checks if the indices are correct for the reduction. E.g. for select_last_index
// it gives precedence for later indices of the same element and precedence for sooner
// indices if not select_last_index;
PrimExpr proper_index;
if (select_last_index) {
proper_index = lhs_idx > rhs_idx;
} else {
proper_index = lhs_idx < rhs_idx;
}

PrimExpr update_index = is_bigger || (is_same && proper_index);
result.push_back(tvm::tir::Select(update_index, lhs[0], rhs[0])); // idx
result.push_back(tvm::tir::Select(is_bigger, lhs[1], rhs[1])); // val
LOG(WARNING) << result;
mbrookhart marked this conversation as resolved.
Show resolved Hide resolved
return result;
};
auto fidentity = [](std::vector<DataType> types) {
auto fidentity = [&](std::vector<DataType> types) {
Array<PrimExpr> result;
result.push_back(tvm::tir::make_const(types[0], -1)); // idx
result.push_back(tvm::min_value(types[1])); // val
Expand All @@ -490,12 +543,13 @@ inline FCommReduce MakeArgmaxReducer() {
* left in the result as dimensions with size one. This enables the result
* to broadcast correctly against the input array.
* \param atleast1d Whether the output need to be atleast1d.
*
* \param select_last_index Whether to select the last index if the maximum element
* appears multiple times, else select the first index.
* \return A Tensor whose op member is the argmax operation
*/
inline Tensor argmax(const Tensor& data, const Array<Integer>& axis, bool keepdims = false,
bool atleast1d = false) {
auto reducer = MakeArgmaxReducer();
bool atleast1d = false, bool select_last_index = false) {
auto reducer = MakeArgmaxReducer(select_last_index);
return CommReduceIdx(data, axis, reducer, keepdims, atleast1d);
}

Expand Down
20 changes: 9 additions & 11 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,23 +32,23 @@
from .. import loops as _loops
from .. import op as _op
from .. import qnn as _qnn
from .. import random as _random
from .. import ty as _ty
from .. import vision as _vision
from .. import random as _random
from .common import (
AttrCvt,
Renamer,
fold_constant,
get_name,
get_relay_op,
gru_cell,
infer_channels,
infer_shape,
infer_type,
infer_value,
lstm_cell,
new_var,
unbind,
gru_cell,
lstm_cell,
)

__all__ = ["from_onnx"]
Expand Down Expand Up @@ -1824,25 +1824,23 @@ class ArgMax(OnnxOpConverter):
"""Operator converter for ArgMax."""

@classmethod
def _impl_v1(cls, inputs, attr, params):
if "select_last_index" in attr:
raise NotImplementedError("select_last_index not supported in ArgMax")
def _impl_v13(cls, inputs, attr, params):
axis = attr.get("axis", 0)
keepdims = attr.get("keepdims", True)
attr = {"axis": axis, "keepdims": keepdims}
select_last_index = attr.get("select_last_index", False)
attr = {"axis": axis, "keepdims": keepdims, "select_last_index": select_last_index}
return _op.cast(AttrCvt("argmax")(inputs, attr), "int64")


class ArgMin(OnnxOpConverter):
"""Operator converter for ArgMin."""

@classmethod
def _impl_v1(cls, inputs, attr, params):
if "select_last_index" in attr:
raise NotImplementedError("select_last_index not supported in ArgMin")
def _impl_v13(cls, inputs, attr, params):
axis = attr.get("axis", 0)
keepdims = attr.get("keepdims", True)
attr = {"axis": axis, "keepdims": keepdims}
select_last_index = attr.get("select_last_index", False)
attr = {"axis": axis, "keepdims": keepdims, "select_last_index": select_last_index}
return _op.cast(AttrCvt("argmin")(inputs, attr), "int64")


Expand Down
20 changes: 14 additions & 6 deletions python/tvm/relay/op/reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@
"""Reduce operators."""
# pylint: disable=redefined-builtin

from ..expr import Tuple, TupleWrapper
from . import _make
from .tensor import sqrt, log, exp
from .tensor import exp, log, sqrt
from .transform import squeeze
from ..expr import Tuple, TupleWrapper


def argmax(data, axis=None, keepdims=False, exclude=False):
def argmax(data, axis=None, keepdims=False, exclude=False, select_last_index=False):
"""Returns the indices of the maximum values along an axis.

Parameters
Expand All @@ -45,16 +45,20 @@ def argmax(data, axis=None, keepdims=False, exclude=False):
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.

select_last_index : bool
Whether to select the last index or the first index if the max element appears in
multiple indices, default is False (first index).

Returns
-------
result : relay.Expr
The computed result.
"""
axis = [axis] if isinstance(axis, int) else axis
return _make.argmax(data, axis, keepdims, exclude)
return _make.argmax(data, axis, keepdims, exclude, select_last_index)


def argmin(data, axis=None, keepdims=False, exclude=False):
def argmin(data, axis=None, keepdims=False, exclude=False, select_last_index=False):
"""Returns the indices of the minimum values along an axis.

Parameters
Expand All @@ -76,13 +80,17 @@ def argmin(data, axis=None, keepdims=False, exclude=False):
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.

select_last_index : bool
Whether to select the last index or the first index if the min element appears in
multiple indices, default is False (first index).

Returns
-------
result : relay.Expr
The computed result.
"""
axis = [axis] if isinstance(axis, int) else axis
return _make.argmin(data, axis, keepdims, exclude)
return _make.argmin(data, axis, keepdims, exclude, select_last_index)


def sum(data, axis=None, keepdims=False, exclude=False):
Expand Down
16 changes: 12 additions & 4 deletions python/tvm/topi/reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def min(data, axis=None, keepdims=False):
return cpp.min(data, axis, keepdims)


def argmax(data, axis=None, keepdims=False):
def argmax(data, axis=None, keepdims=False, select_last_index=False):
"""Returns the indices of the maximum values along an axis.

Parameters
Expand All @@ -185,14 +185,18 @@ def argmax(data, axis=None, keepdims=False):
with size one.
With this option, the result will broadcast correctly against the input array.

select_last_index: bool
Whether to select the last index if the maximum element appears multiple times, else
select the first index.

Returns
-------
ret : tvm.te.Tensor
"""
return cpp.argmax(data, axis, keepdims)
return cpp.argmax(data, axis, keepdims, select_last_index)


def argmin(data, axis=None, keepdims=False):
def argmin(data, axis=None, keepdims=False, select_last_index=False):
"""Returns the indices of the minimum values along an axis.

Parameters
Expand All @@ -210,11 +214,15 @@ def argmin(data, axis=None, keepdims=False):
with size one.
With this option, the result will broadcast correctly against the input array.

select_last_index: bool
Whether to select the last index if the minimum element appears multiple times, else
select the first index.

Returns
-------
ret : tvm.te.Tensor
"""
return cpp.argmin(data, axis, keepdims)
return cpp.argmin(data, axis, keepdims, select_last_index)


def prod(data, axis=None, keepdims=False):
Expand Down
Loading