Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
support np.argsort
Browse files Browse the repository at this point in the history
  • Loading branch information
haojin2 committed Nov 30, 2019
1 parent 5fb2916 commit 2319a3e
Show file tree
Hide file tree
Showing 8 changed files with 250 additions and 9 deletions.
78 changes: 77 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
'arctan2', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs',
'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2',
'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram', 'eye',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'argsort', 'tensordot', 'histogram', 'eye',
'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit', 'concatenate', 'append',
'stack', 'vstack', 'column_stack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax',
'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'unravel_index', 'hanning', 'hamming', 'blackman',
Expand Down Expand Up @@ -796,6 +796,82 @@ def power(x1, x2, out=None, **kwargs):
return _ufunc_helper(x1, x2, _npi.power, _np.power, _npi.power_scalar, _npi.rpower_scalar, out)


@set_module('mxnet.ndarray.numpy')
def argsort(a, axis=-1, kind=None, order=None):
"""
Returns the indices that would sort an array.
Perform an indirect sort along the given axis using the algorithm specified
by the `kind` keyword. It returns an array of indices of the same shape as
`a` that index data along the given axis in sorted order.
Parameters
----------
a : ndarray
Array to sort.
axis : int or None, optional
Axis along which to sort. The default is -1 (the last axis). If None,
the flattened array is used.
kind : string, optional
This argument can take any string, but it does not have any effect on the
final result.
order : str or list of str, optional
Not supported yet, will raise NotImplementedError if not None.
Returns
-------
index_array : ndarray, int
Array of indices that sort `a` along the specified `axis`.
If `a` is one-dimensional, ``a[index_array]`` yields a sorted `a`.
More generally, ``np.take_along_axis(a, index_array, axis=axis)``
always yields the sorted `a`, irrespective of dimensionality.
Notes
-----
This operator does not support different sorting algorithms.
Examples
--------
One dimensional array:
>>> x = np.array([3, 1, 2])
>>> np.argsort(x)
array([1, 2, 0])
Two-dimensional array:
>>> x = np.array([[0, 3], [2, 2]])
>>> x
array([[0, 3],
[2, 2]])
>>> ind = np.argsort(x, axis=0) # sorts along first axis (down)
>>> ind
array([[0, 1],
[1, 0]])
>>> np.take_along_axis(x, ind, axis=0) # same as np.sort(x, axis=0)
array([[0, 2],
[2, 3]])
>>> ind = np.argsort(x, axis=1) # sorts along last axis (across)
>>> ind
array([[0, 1],
[0, 1]])
>>> np.take_along_axis(x, ind, axis=1) # same as np.sort(x, axis=1)
array([[0, 3],
[2, 2]])
Indices of the sorted elements of a N-dimensional array:
>>> ind = np.unravel_index(np.argsort(x, axis=None), x.shape)
>>> ind
(array([0, 1, 1, 0]), array([0, 0, 1, 1]))
>>> x[ind] # same as np.sort(x, axis=None)
array([0, 2, 2, 3])
"""
if order is not None:
raise NotImplementedError("order not supported here")

return _npi.argsort(data=a, axis=axis, is_ascend=True, dtype='int64')


@set_module('mxnet.ndarray.numpy')
def tensordot(a, b, axes=2):
r"""
Expand Down
79 changes: 76 additions & 3 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
'mod', 'remainder', 'power', 'arctan2', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10',
'sqrt', 'cbrt', 'abs', 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log',
'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative',
'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'append',
'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'append', 'argsort',
'tensordot', 'histogram', 'eye', 'linspace', 'logspace', 'expand_dims', 'tile', 'arange',
'split', 'vsplit', 'concatenate', 'stack', 'vstack', 'column_stack', 'dstack', 'mean', 'maximum', 'minimum',
'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'unravel_index',
Expand Down Expand Up @@ -1369,13 +1369,13 @@ def topk(self, *args, **kwargs):
"""
raise AttributeError('mxnet.numpy.ndarray object has no attribute topk')

def argsort(self, *args, **kwargs):
def argsort(self, axis=-1, kind=None, order=None):
"""Convenience fluent method for :py:func:`argsort`.
The arguments are the same as for :py:func:`argsort`, with
this array as data.
"""
raise NotImplementedError
raise argsort(self, axis=axis, kind=kind, order=order)

def argmax_channel(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`argmax_channel`.
Expand Down Expand Up @@ -4200,6 +4200,79 @@ def arctanh(x, out=None, **kwargs):
return _mx_nd_np.arctanh(x, out=out, **kwargs)


@set_module('mxnet.numpy')
def argsort(a, axis=-1, kind=None, order=None):
"""
Returns the indices that would sort an array.
Perform an indirect sort along the given axis using the algorithm specified
by the `kind` keyword. It returns an array of indices of the same shape as
`a` that index data along the given axis in sorted order.
Parameters
----------
a : ndarray
Array to sort.
axis : int or None, optional
Axis along which to sort. The default is -1 (the last axis). If None,
the flattened array is used.
kind : string, optional
This argument can take any string, but it does not have any effect on the
final result.
order : str or list of str, optional
Not supported yet, will raise NotImplementedError if not None.
Returns
-------
index_array : ndarray, int
Array of indices that sort `a` along the specified `axis`.
If `a` is one-dimensional, ``a[index_array]`` yields a sorted `a`.
More generally, ``np.take_along_axis(a, index_array, axis=axis)``
always yields the sorted `a`, irrespective of dimensionality.
Notes
-----
This operator does not support different sorting algorithms.
Examples
--------
One dimensional array:
>>> x = np.array([3, 1, 2])
>>> np.argsort(x)
array([1, 2, 0])
Two-dimensional array:
>>> x = np.array([[0, 3], [2, 2]])
>>> x
array([[0, 3],
[2, 2]])
>>> ind = np.argsort(x, axis=0) # sorts along first axis (down)
>>> ind
array([[0, 1],
[1, 0]])
>>> np.take_along_axis(x, ind, axis=0) # same as np.sort(x, axis=0)
array([[0, 2],
[2, 3]])
>>> ind = np.argsort(x, axis=1) # sorts along last axis (across)
>>> ind
array([[0, 1],
[0, 1]])
>>> np.take_along_axis(x, ind, axis=1) # same as np.sort(x, axis=1)
array([[0, 3],
[2, 2]])
Indices of the sorted elements of a N-dimensional array:
>>> ind = np.unravel_index(np.argsort(x, axis=None), x.shape)
>>> ind
(array([0, 1, 1, 0]), array([0, 0, 1, 1]))
>>> x[ind] # same as np.sort(x, axis=None)
array([0, 2, 2, 3])
"""
return _mx_nd_np.argsort(a, axis=axis, kind=kind, order=order)


@set_module('mxnet.numpy')
def tensordot(a, b, axes=2):
r"""
Expand Down
1 change: 1 addition & 0 deletions python/mxnet/numpy_dispatch_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs):
'argmin',
'argmax',
'around',
'argsort',
'append',
'broadcast_arrays',
'broadcast_to',
Expand Down
45 changes: 42 additions & 3 deletions python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs', 'absolute', 'exp',
'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p',
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram', 'eye',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'argsort', 'tensordot', 'histogram', 'eye',
'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit', 'concatenate', 'append',
'stack', 'vstack', 'column_stack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax',
'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'unravel_index', 'hanning', 'hamming', 'blackman',
Expand Down Expand Up @@ -484,13 +484,13 @@ def topk(self, *args, **kwargs):
"""
raise AttributeError('_Symbol object has no attribute topk')

def argsort(self, *args, **kwargs):
def argsort(self, axis=-1, kind=None, order=None):
"""Convenience fluent method for :py:func:`argsort`.
The arguments are the same as for :py:func:`argsort`, with
this array as data.
"""
raise NotImplementedError
raise argsort(self, axis=axis, kind=kind, order=order)

def argmax_channel(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`argmax_channel`.
Expand Down Expand Up @@ -1325,6 +1325,45 @@ def lcm(x1, x2, out=None, **kwargs):
return _ufunc_helper(x1, x2, _npi.lcm, _np.lcm, _npi.lcm_scalar, None, out)


@set_module('mxnet.symbol.numpy')
def argsort(a, axis=-1, kind=None, order=None):
"""
Returns the indices that would sort an array.
Perform an indirect sort along the given axis using the algorithm specified
by the `kind` keyword. It returns an array of indices of the same shape as
`a` that index data along the given axis in sorted order.
Parameters
----------
a : _Symbol
Array to sort.
axis : int or None, optional
Axis along which to sort. The default is -1 (the last axis). If None,
the flattened array is used.
kind : string, optional
This argument can take any string, but it does not have any effect on the
final result.
order : str or list of str, optional
Not supported yet, will raise NotImplementedError if not None.
Returns
-------
index_array : _Symbol, int
Array of indices that sort `a` along the specified `axis`.
If `a` is one-dimensional, ``a[index_array]`` yields a sorted `a`.
More generally, ``np.take_along_axis(a, index_array, axis=axis)``
always yields the sorted `a`, irrespective of dimensionality.
Notes
-----
This operator does not support different sorting algorithms.
"""
if order is not None:
raise NotImplementedError("order is not supported yet...")

return _npi.argsort(data=a, axis=axis, is_ascend=True, dtype='int64')


@set_module('mxnet.symbol.numpy')
def tensordot(a, b, axes=2):
r"""
Expand Down
8 changes: 6 additions & 2 deletions src/operator/tensor/ordering_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,9 @@ inline void ParseTopKParam(const TShape& src_shape,
CHECK(*axis >= 0 && *axis < static_cast<int>(src_shape.ndim()))
<< "Invalid axis! axis should be between 0 and "
<< src_shape.ndim() << ", found axis=" << *axis;
*batch_size = src_shape.Size() / src_shape[*axis];
if (src_shape[*axis] != 0) {
*batch_size = src_shape.Size() / src_shape[*axis];
}
*element_num = src_shape[*axis];
if (*axis != src_shape.ndim() - 1) {
*do_transpose = true;
Expand All @@ -180,7 +182,7 @@ inline void ParseTopKParam(const TShape& src_shape,
(*target_shape)[*axis] = *k;
}
}
CHECK(*k >= 1 && *k <= *element_num) << "k must be smaller than "
CHECK(*k >= 0 && *k <= *element_num) << "k must be smaller than "
<< *element_num << ", get k = " << *k;
}

Expand Down Expand Up @@ -391,6 +393,8 @@ void TopKImpl(const RunContext &ctx,
const TopKParam& param) {
using namespace mshadow;
using namespace mshadow::expr;
// 0. If input shape is 0-shape, directly return
if (src.Size() == 0) return;
// 1. Parse and initialize information
Stream<xpu> *s = ctx.get_stream<xpu>();
Tensor<xpu, 1, char> workspace;
Expand Down
1 change: 1 addition & 0 deletions src/operator/tensor/ordering_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ Examples::
.add_arguments(SortParam::__FIELDS__());

NNVM_REGISTER_OP(argsort)
.add_alias("_npi_argsort")
.describe(R"code(Returns the indices that would sort an input array along the given axis.
This function performs sorting along the given axis and returns an array of indices having same shape
Expand Down
12 changes: 12 additions & 0 deletions tests/python/unittest/test_numpy_interoperability.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,17 @@ def _add_workload_around():
OpArgMngr.add_workload('around', np.array([1.56, 72.54, 6.35, 3.25]), decimals=1)


def _add_workload_argsort():
for dtype in [np.int32, np.float32]:
a = np.arange(101, dtype=dtype)
OpArgMngr.add_workload('argsort', a)
OpArgMngr.add_workload('argsort', np.array([[3, 2], [1, 0]]), 1)
OpArgMngr.add_workload('argsort', np.array([[3, 2], [1, 0]]), 0)
a = np.ones((3, 2, 1, 0))
for axis in range(-a.ndim, a.ndim):
OpArgMngr.add_workload('argsort', a, axis)


def _add_workload_broadcast_arrays(array_pool):
OpArgMngr.add_workload('broadcast_arrays', array_pool['4x1'], array_pool['1x2'])

Expand Down Expand Up @@ -1311,6 +1322,7 @@ def _prepare_workloads():
_add_workload_argmin()
_add_workload_argmax()
_add_workload_around()
_add_workload_argsort()
_add_workload_append()
_add_workload_broadcast_arrays(array_pool)
_add_workload_broadcast_to()
Expand Down
35 changes: 35 additions & 0 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1143,6 +1143,41 @@ def hybrid_forward(self, F, a):
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)


@with_seed()
@use_np
def test_np_argsort():
class TestArgsort(HybridBlock):
def __init__(self, axis):
super(TestArgsort, self).__init__()
self._axis = axis

def hybrid_forward(self, F, x):
return F.np.argsort(x, axis=self._axis)

shapes = [
(),
(2, 3),
(1, 0, 2),
]

for shape in shapes:
data = np.random.uniform(size=shape)
np_data = data.asnumpy()

for axis in [None] + [i for i in range(-len(shape), len(shape))]:
np_out = _np.argsort(np_data, axis)

test_argsort = TestArgsort(axis)
for hybrid in [False, True]:
if hybrid:
test_argsort.hybridize()
mx_out = test_argsort(data)
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-5, atol=1e-6, use_broadcast=False)

mx_out = np.argsort(data, axis)
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-5, atol=1e-6, use_broadcast=False)


@with_seed()
@use_np
def test_np_squeeze():
Expand Down

0 comments on commit 2319a3e

Please sign in to comment.