diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h
index fc51d8af0f01..9fec6cd1255a 100644
--- a/src/operator/tensor/broadcast_reduce_op.h
+++ b/src/operator/tensor/broadcast_reduce_op.h
@@ -312,6 +312,23 @@ inline bool ReduceAxesShape(const nnvm::NodeAttrs& attrs,
   return true;
 }
 
+inline bool ReduceMinMaxAxesShape(const nnvm::NodeAttrs& attrs,
+                                  mxnet::ShapeVector *in_attrs,
+                                  mxnet::ShapeVector *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  if (!shape_is_known((*in_attrs)[0])) return false;
+  CHECK_GT((*in_attrs)[0].Size(), 0U)
+    << "Reduction input's size should > 0 "
+    << (*in_attrs)[0];
+  const ReduceAxesParam& param = nnvm::get<ReduceAxesParam>(attrs.parsed);
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0,
+                     ReduceAxesShapeImpl((*in_attrs)[0], param.axis,
+                                         param.keepdims, param.exclude));
+  return true;
+}
+
+
 inline bool NormType(const nnvm::NodeAttrs& attrs,
                      std::vector<int> *in_attrs,
                      std::vector<int> *out_attrs) {
@@ -1488,6 +1505,16 @@ void PickOpBackward(const nnvm::NodeAttrs& attrs,
   .add_argument("data", "NDArray-or-Symbol", "The input")               \
   .add_arguments(ReduceAxesParam::__FIELDS__())
 
+#define MXNET_OPERATOR_REGISTER_MINMAX_REDUCE(name)                    \
+  NNVM_REGISTER_OP(name)                                               \
+  .set_num_inputs(1)                                                   \
+  .set_num_outputs(1)                                                  \
+  .set_attr_parser(AxesParamParser<ReduceAxesParam>)                   \
+  .set_attr<mxnet::FInferShape>("FInferShape", ReduceMinMaxAxesShape)  \
+  .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)        \
+  .add_argument("data", "NDArray-or-Symbol", "The input")              \
+  .add_arguments(ReduceAxesParam::__FIELDS__())
+
 #define MXNET_OPERATOR_REGISTER_REDUCE_BACKWARD(name)                  \
   NNVM_REGISTER_OP(name)                                               \
   .set_num_outputs(1)                                                  \
diff --git a/src/operator/tensor/broadcast_reduce_op_value.cc b/src/operator/tensor/broadcast_reduce_op_value.cc
index f4231917e90d..f890963c2cf1 100644
--- a/src/operator/tensor/broadcast_reduce_op_value.cc
+++ b/src/operator/tensor/broadcast_reduce_op_value.cc
@@ -186,7 +186,7 @@ MXNET_OPERATOR_REGISTER_REDUCE_BACKWARD(_backward_nanprod)
 .set_num_inputs(3)
 .set_attr<FCompute>("FCompute", ReduceAxesBackwardUseInOut<cpu, mshadow_op::nanprod_grad>);
 
-MXNET_OPERATOR_REGISTER_REDUCE(max)
+MXNET_OPERATOR_REGISTER_MINMAX_REDUCE(max)
 .add_alias("max_axis")
 .describe(get_reduce_axes_description("max", __LINE__))
 .set_attr<FCompute>("FCompute", ReduceAxesCompute<cpu, mshadow::red::maximum>)
@@ -200,7 +200,7 @@ MXNET_OPERATOR_REGISTER_REDUCE_BACKWARD(_backward_max)
 .set_num_inputs(3)
 .set_attr<FCompute>("FCompute", ReduceAxesBackwardUseInOut<cpu, mshadow_op::eq>);
 
-MXNET_OPERATOR_REGISTER_REDUCE(min)
+MXNET_OPERATOR_REGISTER_MINMAX_REDUCE(min)
 .add_alias("min_axis")
 .describe(get_reduce_axes_description("min", __LINE__))
 .set_attr<FCompute>("FCompute", ReduceAxesCompute<cpu, mshadow::red::minimum>)
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index f2d8a1b2524f..2a8e7d6e6698 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -6990,6 +6990,21 @@ def test_float16_min_max():
     assert np.finfo('float16').max == mx.nd.max(a).asscalar()
 
 
+@with_seed()
+@mx.use_np_compat
+def test_zero_size_min_max():
+    def min():
+        a = mx.nd.zeros(shape=(5, 0))
+        a.min()
+
+    def max():
+        a = mx.nd.zeros(shape=(5, 0))
+        a.max()
+
+    assert_raises(MXNetError, min)
+    assert_raises(MXNetError, max)
+
+
 @with_seed()
 def test_squeeze_op():
     def check_squeeze_op(shape, axis=None):