Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add np.product
Browse files Browse the repository at this point in the history
  • Loading branch information
haojin2 committed Feb 18, 2020
1 parent deeb9f9 commit 8dd333a
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 1 deletion.
13 changes: 12 additions & 1 deletion python/mxnet/_numpy_op_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def _np_all(a, axis=None, keepdims=False, out=None):
>>> o=np.array(False)
>>> z=np.all([-1, 4, 5], out=o)
>>> id(z), id(o), z
(28293632, 28293632, array(True)) # may vary
(28293632, 28293632, array(True)) # may vary
"""
pass

Expand Down Expand Up @@ -1171,6 +1171,17 @@ def _np_prod(a, axis=None, dtype=None, out=None, keepdims=False):
pass


def _np_product(a, axis=None, dtype=None, out=None, keepdims=False):
"""
Return the product of array elements over a given axis.
See Also
--------
prod : equivalent function; see for details.
"""
pass


def _np_moveaxis(a, source, destination):
"""Move axes of an array to new positions.
Other axes remain in their original order.
Expand Down
1 change: 1 addition & 0 deletions python/mxnet/numpy_dispatch_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs):
'atleast_2d',
'atleast_3d',
'prod',
'product',
'ravel',
'repeat',
'reshape',
Expand Down
1 change: 1 addition & 0 deletions src/operator/numpy/np_broadcast_reduce_op_value.cc
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ NNVM_REGISTER_OP(_backward_np_min)
.set_attr<FCompute>("FCompute<cpu>", NumpyReduceAxesNoDTypeBackward<cpu, mshadow_op::eq>);

NNVM_REGISTER_OP(_np_prod)
.add_alias("_np_product")
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser(ParamParser<NumpyReduceAxesParam>)
Expand Down
5 changes: 5 additions & 0 deletions tests/python/unittest/test_numpy_interoperability.py
Original file line number Diff line number Diff line change
Expand Up @@ -1021,6 +1021,10 @@ def _add_workload_prod(array_pool):
OpArgMngr.add_workload('prod', array_pool['4x1'])


def _add_workload_product(array_pool):
OpArgMngr.add_workload('product', array_pool['4x1'])


def _add_workload_repeat(array_pool):
OpArgMngr.add_workload('repeat', array_pool['4x1'], 3)
OpArgMngr.add_workload('repeat', np.array(_np.arange(12).reshape(4, 3)[:, 2]), 3)
Expand Down Expand Up @@ -2015,6 +2019,7 @@ def _prepare_workloads():
_add_workload_ones_like(array_pool)
_add_workload_atleast_nd()
_add_workload_prod(array_pool)
_add_workload_product(array_pool)
_add_workload_repeat(array_pool)
_add_workload_reshape()
_add_workload_rint(array_pool)
Expand Down

0 comments on commit 8dd333a

Please sign in to comment.