From 628e9fc7cf724f07c6716b3698664fbf71a05547 Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Mon, 9 Dec 2019 04:52:30 +0000 Subject: [PATCH] fix axis=-1 bug --- src/operator/numpy/np_broadcast_reduce_op.h | 2 +- tests/python/unittest/test_numpy_op.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/operator/numpy/np_broadcast_reduce_op.h b/src/operator/numpy/np_broadcast_reduce_op.h index 3566323f1eb3..a87e2c58bf9e 100644 --- a/src/operator/numpy/np_broadcast_reduce_op.h +++ b/src/operator/numpy/np_broadcast_reduce_op.h @@ -187,7 +187,7 @@ inline bool NumpyReduceAxesNoDTypeShape(const nnvm::NodeAttrs& attrs, if (param.axis.has_value()) { const mxnet::Tuple& axes = param.axis.value(); for (int i = 0; i < axes.ndim(); ++i) { - if (ishape[axes[i]] == 0) { + if ((axes[i] >= 0) && (ishape[axes[i]] == 0)) { is_all_reducded_axes_not_zero = false; break; } diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 38633760c400..3f4fb0645fc3 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -531,7 +531,7 @@ def get_grad(axis, func_name): elif axis == 2: temp[:,:,index,:] = 1 return temp - elif axis == 3: + elif (axis == 3 or axis == -1): temp[:,:,:,index] = 1 return temp elif not axis: @@ -549,7 +549,7 @@ def _test_np_exception(func, shape, dim): for func in ['max', 'min']: for hybridize in [False, True]: for keepdims in [True, False]: - for axis in ([i for i in range(in_data_dim)] + [(), None]): + for axis in ([i for i in range(in_data_dim)] + [(), None] + [-1]): for itype in ['float16', 'float32', 'float64', 'int']: # test gluon if func == 'max':