From ba278faa5a738bd17b27d49c7cc7825708d4ce7f Mon Sep 17 00:00:00 2001 From: Wang Jiajun Date: Fri, 19 Apr 2019 18:34:15 +0000 Subject: [PATCH 1/4] fix min max of zero-sized ndarray --- src/operator/tensor/broadcast_reduce_op.h | 27 +++++++++++++++++++ .../tensor/broadcast_reduce_op_value.cc | 4 +-- 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h index fc51d8af0f01..9fec6cd1255a 100644 --- a/src/operator/tensor/broadcast_reduce_op.h +++ b/src/operator/tensor/broadcast_reduce_op.h @@ -312,6 +312,23 @@ inline bool ReduceAxesShape(const nnvm::NodeAttrs& attrs, return true; } +inline bool ReduceMinMaxAxesShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector *in_attrs, + mxnet::ShapeVector *out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + if (!shape_is_known((*in_attrs)[0])) return false; + CHECK_GT((*in_attrs)[0].Size(), 0U) + << "Reduction input's size should > 0 " + << (*in_attrs)[0]; + const ReduceAxesParam& param = nnvm::get(attrs.parsed); + SHAPE_ASSIGN_CHECK(*out_attrs, 0, + ReduceAxesShapeImpl((*in_attrs)[0], param.axis, + param.keepdims, param.exclude)); + return true; +} + + inline bool NormType(const nnvm::NodeAttrs& attrs, std::vector *in_attrs, std::vector *out_attrs) { @@ -1488,6 +1505,16 @@ void PickOpBackward(const nnvm::NodeAttrs& attrs, .add_argument("data", "NDArray-or-Symbol", "The input") \ .add_arguments(ReduceAxesParam::__FIELDS__()) +#define MXNET_OPERATOR_REGISTER_MINMAX_REDUCE(name) \ + NNVM_REGISTER_OP(name) \ + .set_num_inputs(1) \ + .set_num_outputs(1) \ + .set_attr_parser(AxesParamParser) \ + .set_attr("FInferShape", ReduceMinMaxAxesShape) \ + .set_attr("FInferType", ElemwiseType<1, 1>) \ + .add_argument("data", "NDArray-or-Symbol", "The input") \ + .add_arguments(ReduceAxesParam::__FIELDS__()) + #define MXNET_OPERATOR_REGISTER_REDUCE_BACKWARD(name) \ NNVM_REGISTER_OP(name) \ .set_num_outputs(1) \ diff --git a/src/operator/tensor/broadcast_reduce_op_value.cc b/src/operator/tensor/broadcast_reduce_op_value.cc index f4231917e90d..f890963c2cf1 100644 --- a/src/operator/tensor/broadcast_reduce_op_value.cc +++ b/src/operator/tensor/broadcast_reduce_op_value.cc @@ -186,7 +186,7 @@ MXNET_OPERATOR_REGISTER_REDUCE_BACKWARD(_backward_nanprod) .set_num_inputs(3) .set_attr("FCompute", ReduceAxesBackwardUseInOut); -MXNET_OPERATOR_REGISTER_REDUCE(max) +MXNET_OPERATOR_REGISTER_MINMAX_REDUCE(max) .add_alias("max_axis") .describe(get_reduce_axes_description("max", __LINE__)) .set_attr("FCompute", ReduceAxesCompute) @@ -200,7 +200,7 @@ MXNET_OPERATOR_REGISTER_REDUCE_BACKWARD(_backward_max) .set_num_inputs(3) .set_attr("FCompute", ReduceAxesBackwardUseInOut); -MXNET_OPERATOR_REGISTER_REDUCE(min) +MXNET_OPERATOR_REGISTER_MINMAX_REDUCE(min) .add_alias("min_axis") .describe(get_reduce_axes_description("min", __LINE__)) .set_attr("FCompute", ReduceAxesCompute) From 5a99737ddcd5dad81bb05e0185e0cce40a3102d5 Mon Sep 17 00:00:00 2001 From: Wang Jiajun Date: Fri, 19 Apr 2019 18:38:44 +0000 Subject: [PATCH 2/4] add test --- tests/python/unittest/test_operator.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index f2d8a1b2524f..78c8760bfb73 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -6990,6 +6990,20 @@ def test_float16_min_max(): assert np.finfo('float16').max == mx.nd.max(a).asscalar() +@with_seed() +def test_zero_sized_min_max(): + def min(): + a = mx.nd.zeros(shape=(5, 0)) + a.min() + + def max(): + a = mx.nd.zeros(shape=(5, 0)) + a.max() + + assert_raises(MXNetError, min) + assert_raises(MXNetError, max) + + @with_seed() def test_squeeze_op(): def check_squeeze_op(shape, axis=None): From 31e0048bb5fd23a696d54c6e19f8e849a409c504 Mon Sep 17 00:00:00 2001 From: Wang Jiajun Date: Sat, 20 Apr 2019 03:59:49 +0000 Subject: [PATCH 3/4] turn on numpy mode --- tests/python/unittest/test_operator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 78c8760bfb73..2a8e7d6e6698 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -6991,7 +6991,8 @@ def test_float16_min_max(): @with_seed() -def test_zero_sized_min_max(): +@mx.use_np_compat +def test_zero_size_min_max(): def min(): a = mx.nd.zeros(shape=(5, 0)) a.min() From 92f993f31369d5476ec34dfeb5f5e63482049de1 Mon Sep 17 00:00:00 2001 From: Wang Jiajun Date: Tue, 23 Apr 2019 14:30:16 +0800 Subject: [PATCH 4/4] trigger CI