Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Numpy][Operator] 'where' Implementation in MXNet #16829

Merged
merged 8 commits into from
Nov 18, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 73 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
'around', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take',
'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal',
'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory', 'may_share_memory', 'diff', 'resize',
'nan_to_num']
'nan_to_num', 'where']


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -5308,3 +5308,75 @@ def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None, **kwargs):
return _npi.nan_to_num(x, copy=copy, nan=nan, posinf=posinf, neginf=neginf, out=None)
else:
raise TypeError('type {} not supported'.format(str(type(x))))


@set_module('mxnet.ndarray.numpy')
def where(condition, x=None, y=None):
"""where(condition, [x, y])
Return elements chosen from `x` or `y` depending on `condition`.

.. note::
When only `condition` is provided, this function is a shorthand for
``np.asarray(condition).nonzero()``. The rest of this documentation
covers only the case where all three arguments are provided.

Parameters
----------
condition : ndarray
Where True, yield `x`, otherwise yield `y`.
x, y : ndarray
Values from which to choose. `x`, `y` and `condition` need to be
broadcastable to some shape. `x` and `y` must have the same dtype.

Returns
-------
out : ndarray
An array with elements from `x` where `condition` is True, and elements
from `y` elsewhere.

Notes
-----
If all the arrays are 1-D, `where` is equivalent to::

[xv if c else yv
for c, xv, yv in zip(condition, x, y)]

Examples
--------
>>> a = np.arange(10)
>>> a
array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])
>>> np.where(a < 5, a, 10*a)
array([ 0., 1., 2., 3., 4., 50., 60., 70., 80., 90.])

This can be used on multidimensional arrays too:

>>> cond = np.array([[True, False], [True, True]])
>>> x = np.array([[1, 2], [3, 4]])
>>> y = np.array([[9, 8], [7, 6]])
>>> np.where(cond, x, y)
array([[1., 8.],
[3., 4.]])

The shapes of x, y, and the condition are broadcast together:

>>> x, y = onp.ogrid[:3, :4]
>>> x = np.array(x)
>>> y = np.array(y)
>>> np.where(x < y, x, 10 + y) # both x and 10+y are broadcast
array([[10, 0, 0, 0],
[10, 11, 1, 1],
[10, 11, 12, 2]], dtype=int64)

>>> a = np.array([[0, 1, 2],
... [0, 2, 4],
... [0, 3, 6]])
>>> np.where(a < 4, a, np.array(-1)) # -1 is broadcast
array([[ 0., 1., 2.],
[ 0., 2., -1.],
[ 0., 3., -1.]])
"""
if x is None and y is None:
return nonzero(condition)
else:
return _npi.where(condition, x, y, out=None)
71 changes: 70 additions & 1 deletion python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
'blackman', 'flip', 'around', 'arctan2', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril',
'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less',
'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory',
'may_share_memory', 'diff', 'resize', 'nan_to_num']
'may_share_memory', 'diff', 'resize', 'nan_to_num', 'where']

# Return code for dispatching indexing function call
_NDARRAY_UNSUPPORTED_INDEXING = -1
Expand Down Expand Up @@ -7295,3 +7295,72 @@ def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None, **kwargs):
[ 2.22222000e+005, 2.22222000e+005, -1.79769313e+308]], dtype=float64)
"""
return _mx_nd_np.nan_to_num(x, copy=copy, nan=nan, posinf=posinf, neginf=neginf)


@set_module('mxnet.numpy')
def where(condition, x=None, y=None):
"""where(condition, [x, y])
Return elements chosen from `x` or `y` depending on `condition`.

.. note::
When only `condition` is provided, this function is a shorthand for
``np.asarray(condition).nonzero()``. The rest of this documentation
covers only the case where all three arguments are provided.

Parameters
----------
condition : ndarray
Where True, yield `x`, otherwise yield `y`.
x, y : ndarray
Values from which to choose. `x`, `y` and `condition` need to be
broadcastable to some shape. `x` and `y` must have the same dtype.

Returns
-------
out : ndarray
An array with elements from `x` where `condition` is True, and elements
from `y` elsewhere.

Notes
-----
If all the arrays are 1-D, `where` is equivalent to::

[xv if c else yv
for c, xv, yv in zip(condition, x, y)]

Examples
--------
>>> a = np.arange(10)
>>> a
array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])
>>> np.where(a < 5, a, 10*a)
array([ 0., 1., 2., 3., 4., 50., 60., 70., 80., 90.])

This can be used on multidimensional arrays too:

>>> cond = np.array([[True, False], [True, True]])
>>> x = np.array([[1, 2], [3, 4]])
>>> y = np.array([[9, 8], [7, 6]])
>>> np.where(cond, x, y)
array([[1., 8.],
[3., 4.]])

The shapes of x, y, and the condition are broadcast together:

>>> x, y = onp.ogrid[:3, :4]
>>> x = np.array(x)
>>> y = np.array(y)
>>> np.where(x < y, x, 10 + y) # both x and 10+y are broadcast
array([[10, 0, 0, 0],
[10, 11, 1, 1],
[10, 11, 12, 2]], dtype=int64)

>>> a = np.array([[0, 1, 2],
... [0, 2, 4],
... [0, 3, 6]])
>>> np.where(a < 4, a, np.array(-1)) # -1 is broadcast
array([[ 0., 1., 2.],
[ 0., 2., -1.],
[ 0., 3., -1.]])
"""
return _mx_nd_np.where(condition, x, y)
1 change: 1 addition & 0 deletions python/mxnet/numpy_dispatch_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs):
'may_share_memory',
'diff',
'resize',
'where',
]


Expand Down
28 changes: 25 additions & 3 deletions python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
'around', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take',
'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal',
'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'shares_memory', 'may_share_memory', 'diff',
'resize', 'nan_to_num']
'resize', 'nan_to_num', 'where']


def _num_outputs(sym):
Expand Down Expand Up @@ -4845,7 +4845,7 @@ def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None, **kwargs):

Parameters
----------
x : Symbol
x : _Symbol
Input data.
copy : bool, optional
Whether to create a copy of `x` (True) or to replace values
Expand All @@ -4868,7 +4868,7 @@ def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None, **kwargs):

Returns
-------
out : ndarray
out : _Symbol
`x`, with the non-finite values replaced. If `copy` is False, this may
be `x` itself.

Expand All @@ -4888,5 +4888,27 @@ def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None, **kwargs):
raise TypeError('type {} not supported'.format(str(type(x))))


@set_module('mxnet.symbol.numpy')
def where(condition, x, y):
"""
Return elements chosen from `x` or `y` depending on `condition`.

Parameters
----------
condition : _Symbol
Where True, yield `x`, otherwise yield `y`.
x, y : _Symbol
Values from which to choose. `x`, `y` and `condition` need to be
broadcastable to some shape. `x` and `y` must have the same dtype.

Returns
-------
out : _Symbol
An array with elements from `x` where `condition` is True, and elements
from `y` elsewhere.

"""
return _npi.where(condition, x, y, out=None)


_set_np_symbol_class(_Symbol)
Loading