diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index 30b3d2f4d90b..f9164855dfe9 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -38,7 +38,8 @@ 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', - 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory', 'may_share_memory', 'diff', 'resize'] + 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory', 'may_share_memory', 'diff', 'resize', + 'nan_to_num'] @set_module('mxnet.ndarray.numpy') @@ -5208,3 +5209,102 @@ def resize(a, new_shape): [0., 1., 2., 3.]]) """ return _npi.resize_fallback(a, new_shape=new_shape) + + +@set_module('mxnet.ndarray.numpy') +def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None, **kwargs): + """ + Replace NaN with zero and infinity with large finite numbers (default + behaviour) or with the numbers defined by the user using the `nan`, + `posinf` and/or `neginf` keywords. + + If `x` is inexact, NaN is replaced by zero or by the user defined value in + `nan` keyword, infinity is replaced by the largest finite floating point + values representable by ``x.dtype`` or by the user defined value in + `posinf` keyword and -infinity is replaced by the most negative finite + floating point values representable by ``x.dtype`` or by the user defined + value in `neginf` keyword. + + For complex dtypes, the above is applied to each of the real and + imaginary components of `x` separately. + + If `x` is not inexact, then no replacements are made. + + Parameters + ---------- + x : ndarray + Input data. + copy : bool, optional + Whether to create a copy of `x` (True) or to replace values + in-place (False). The in-place operation only occurs if + casting to an array does not require a copy. + Default is True. + nan : int, float, optional + Value to be used to fill NaN values. If no value is passed + then NaN values will be replaced with 0.0. + posinf : int, float, optional + Value to be used to fill positive infinity values. If no value is + passed then positive infinity values will be replaced with a very + large number. + neginf : int, float, optional + Value to be used to fill negative infinity values. If no value is + passed then negative infinity values will be replaced with a very + small (or negative) number. + + .. versionadded:: 1.13 + + Returns + ------- + out : ndarray + `x`, with the non-finite values replaced. If `copy` is False, this may + be `x` itself. + + Notes + ----- + NumPy uses the IEEE Standard for Binary Floating-Point for Arithmetic + (IEEE 754). This means that Not a Number is not equivalent to infinity. + + Examples + -------- + >>> np.nan_to_num(np.inf) + 1.7976931348623157e+308 + >>> np.nan_to_num(-np.inf) + -1.7976931348623157e+308 + >>> np.nan_to_num(np.nan) + 0.0 + >>> x = np.array([np.inf, -np.inf, np.nan, -128, 128]) + >>> np.nan_to_num(x) + array([ 3.4028235e+38, -3.4028235e+38, 0.0000000e+00, -1.2800000e+02, + 1.2800000e+02]) + >>> np.nan_to_num(x, nan=-9999, posinf=33333333, neginf=33333333) + array([ 3.3333332e+07, 3.3333332e+07, -9.9990000e+03, -1.2800000e+02, + 1.2800000e+02]) + >>> y = np.array([[-1, 0, 1],[9999,234,-14222]],dtype="float64")/0 + array([[-inf, nan, inf], + [ inf, inf, -inf]], dtype=float64) + >>> np.nan_to_num(y) + array([[-1.79769313e+308, 0.00000000e+000, 1.79769313e+308], + [ 1.79769313e+308, 1.79769313e+308, -1.79769313e+308]], dtype=float64) + >>> np.nan_to_num(y, nan=111111, posinf=222222) + array([[-1.79769313e+308, 1.11111000e+005, 2.22222000e+005], + [ 2.22222000e+005, 2.22222000e+005, -1.79769313e+308]], dtype=float64) + >>> y + array([[-inf, nan, inf], + [ inf, inf, -inf]], dtype=float64) + >>> np.nan_to_num(y, copy=False, nan=111111, posinf=222222) + array([[-1.79769313e+308, 1.11111000e+005, 2.22222000e+005], + [ 2.22222000e+005, 2.22222000e+005, -1.79769313e+308]], dtype=float64) + >>> y + array([[-1.79769313e+308, 1.11111000e+005, 2.22222000e+005], + [ 2.22222000e+005, 2.22222000e+005, -1.79769313e+308]], dtype=float64) + """ + if isinstance(x, numeric_types): + return _np.nan_to_num(x, copy, nan, posinf, neginf) + elif isinstance(x, NDArray): + if x.dtype in ['int8', 'uint8', 'int32', 'int64']: + return x + if not copy: + return _npi.nan_to_num(x, copy=copy, nan=nan, posinf=posinf, neginf=neginf, out=x) + return _npi.nan_to_num(x, copy=copy, nan=nan, posinf=posinf, neginf=neginf, out=None) + else: + raise TypeError('type {} not supported'.format(str(type(x)))) diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 9439e751f1be..b6816d75a98e 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -57,7 +57,7 @@ 'blackman', 'flip', 'around', 'arctan2', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory', - 'may_share_memory', 'diff', 'resize'] + 'may_share_memory', 'diff', 'resize', 'nan_to_num'] # Return code for dispatching indexing function call _NDARRAY_UNSUPPORTED_INDEXING = -1 @@ -7202,3 +7202,95 @@ def resize(a, new_shape): [0., 1., 2., 3.]]) """ return _mx_nd_np.resize(a, new_shape) + + +@set_module('mxnet.numpy') +def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None, **kwargs): + """ + Replace NaN with zero and infinity with large finite numbers (default + behaviour) or with the numbers defined by the user using the `nan`, + `posinf` and/or `neginf` keywords. + + If `x` is inexact, NaN is replaced by zero or by the user defined value in + `nan` keyword, infinity is replaced by the largest finite floating point + values representable by ``x.dtype`` or by the user defined value in + `posinf` keyword and -infinity is replaced by the most negative finite + floating point values representable by ``x.dtype`` or by the user defined + value in `neginf` keyword. + + For complex dtypes, the above is applied to each of the real and + imaginary components of `x` separately. + + If `x` is not inexact, then no replacements are made. + + Parameters + ---------- + x : scalar + ndarray + Input data. + copy : bool, optional + Whether to create a copy of `x` (True) or to replace values + in-place (False). The in-place operation only occurs if + casting to an array does not require a copy. + Default is True. + Gluon does not support copy = False. + nan : int, float, optional + Value to be used to fill NaN values. If no value is passed + then NaN values will be replaced with 0.0. + posinf : int, float, optional + Value to be used to fill positive infinity values. If no value is + passed then positive infinity values will be replaced with a very + large number. + neginf : int, float, optional + Value to be used to fill negative infinity values. If no value is + passed then negative infinity values will be replaced with a very + small (or negative) number. + + .. versionadded:: 1.13 + + Returns + ------- + out : ndarray + `x`, with the non-finite values replaced. If `copy` is False, this may + be `x` itself. + + Notes + ----- + NumPy uses the IEEE Standard for Binary Floating-Point for Arithmetic + (IEEE 754). This means that Not a Number is not equivalent to infinity. + + Examples + -------- + >>> np.nan_to_num(np.inf) + 1.7976931348623157e+308 + >>> np.nan_to_num(-np.inf) + -1.7976931348623157e+308 + >>> np.nan_to_num(np.nan) + 0.0 + >>> x = np.array([np.inf, -np.inf, np.nan, -128, 128]) + >>> np.nan_to_num(x) + array([ 3.4028235e+38, -3.4028235e+38, 0.0000000e+00, -1.2800000e+02, + 1.2800000e+02]) + >>> np.nan_to_num(x, nan=-9999, posinf=33333333, neginf=33333333) + array([ 3.3333332e+07, 3.3333332e+07, -9.9990000e+03, -1.2800000e+02, + 1.2800000e+02]) + >>> y = np.array([[-1, 0, 1],[9999,234,-14222]],dtype="float64")/0 + array([[-inf, nan, inf], + [ inf, inf, -inf]], dtype=float64) + >>> np.nan_to_num(y) + array([[-1.79769313e+308, 0.00000000e+000, 1.79769313e+308], + [ 1.79769313e+308, 1.79769313e+308, -1.79769313e+308]], dtype=float64) + >>> np.nan_to_num(y, nan=111111, posinf=222222) + array([[-1.79769313e+308, 1.11111000e+005, 2.22222000e+005], + [ 2.22222000e+005, 2.22222000e+005, -1.79769313e+308]], dtype=float64) + >>> y + array([[-inf, nan, inf], + [ inf, inf, -inf]], dtype=float64) + >>> np.nan_to_num(y, copy=False, nan=111111, posinf=222222) + array([[-1.79769313e+308, 1.11111000e+005, 2.22222000e+005], + [ 2.22222000e+005, 2.22222000e+005, -1.79769313e+308]], dtype=float64) + >>> y + array([[-1.79769313e+308, 1.11111000e+005, 2.22222000e+005], + [ 2.22222000e+005, 2.22222000e+005, -1.79769313e+308]], dtype=float64) + """ + return _mx_nd_np.nan_to_num(x, copy=copy, nan=nan, posinf=posinf, neginf=neginf) diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index 1eab6ad0342a..0d7303865b92 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -41,7 +41,7 @@ 'around', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'shares_memory', 'may_share_memory', 'diff', - 'resize'] + 'resize', 'nan_to_num'] def _num_outputs(sym): @@ -4824,4 +4824,69 @@ def resize(a, new_shape): return _npi.resize_fallback(a, new_shape=new_shape) +@set_module('mxnet.symbol.numpy') +def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None, **kwargs): + """ + Replace NaN with zero and infinity with large finite numbers (default + behaviour) or with the numbers defined by the user using the `nan`, + `posinf` and/or `neginf` keywords. + + If `x` is inexact, NaN is replaced by zero or by the user defined value in + `nan` keyword, infinity is replaced by the largest finite floating point + values representable by ``x.dtype`` or by the user defined value in + `posinf` keyword and -infinity is replaced by the most negative finite + floating point values representable by ``x.dtype`` or by the user defined + value in `neginf` keyword. + + For complex dtypes, the above is applied to each of the real and + imaginary components of `x` separately. + + If `x` is not inexact, then no replacements are made. + + Parameters + ---------- + x : Symbol + Input data. + copy : bool, optional + Whether to create a copy of `x` (True) or to replace values + in-place (False). The in-place operation only occurs if + casting to an array does not require a copy. + Default is True. + nan : int, float, optional + Value to be used to fill NaN values. If no value is passed + then NaN values will be replaced with 0.0. + posinf : int, float, optional + Value to be used to fill positive infinity values. If no value is + passed then positive infinity values will be replaced with a very + large number. + neginf : int, float, optional + Value to be used to fill negative infinity values. If no value is + passed then negative infinity values will be replaced with a very + small (or negative) number. + + .. versionadded:: 1.13 + + Returns + ------- + out : ndarray + `x`, with the non-finite values replaced. If `copy` is False, this may + be `x` itself. + + Notes + ----- + NumPy uses the IEEE Standard for Binary Floating-Point for Arithmetic + (IEEE 754). This means that Not a Number is not equivalent to infinity. + + """ + if isinstance(x, numeric_types): + return _np.nan_to_num(x, copy, nan, posinf, neginf) + elif isinstance(x, _Symbol): + if not copy: + return _npi.nan_to_num(x, copy=copy, nan=nan, posinf=posinf, neginf=neginf, out=x) + return _npi.nan_to_num(x, copy=copy, nan=nan, posinf=posinf, neginf=neginf, out=None) + else: + raise TypeError('type {} not supported'.format(str(type(x)))) + + + _set_np_symbol_class(_Symbol) diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h index b8db165675a0..0477d8553898 100644 --- a/src/operator/mshadow_op.h +++ b/src/operator/mshadow_op.h @@ -52,6 +52,7 @@ const float SELU_ALPHA = 1.6732632423543772848170429916717; const float SELU_LAMBDA = 1.0507009873554804934193349852946; const float SQRT_2 = 1.4142135623730950488016887242096; using std::isnan; +using std::isinf; #endif using std::enable_if; using std::is_unsigned; @@ -1012,6 +1013,30 @@ namespace isnan_typed { } }; // namespace isnan_typed +namespace isinf_typed { + template + MSHADOW_XINLINE bool IsInf(volatile DType val) { + return false; + } + template<> + MSHADOW_XINLINE bool IsInf(volatile float val) { + return isinf(val); + } + template<> + MSHADOW_XINLINE bool IsInf(volatile double val) { + return isinf(val); + } + template<> + MSHADOW_XINLINE bool IsInf(volatile long double val) { + return isinf(val); + } + + template<> + MSHADOW_XINLINE bool IsInf(volatile mshadow::half::half_t val) { + return (val.half_ & 0x7fff) >= 0x7c00; + } +}; // namespace isinf_typed + MXNET_UNARY_MATH_OP_NC(relu, isnan_typed::IsNan(a) || (a > DType(0)) ? a : DType(0)); /*! \brief used for computing gradient of relu operator */ diff --git a/src/operator/numpy/np_elemwise_unary_op_basic.cc b/src/operator/numpy/np_elemwise_unary_op_basic.cc index c980dcfaab5d..cad736aab65b 100644 --- a/src/operator/numpy/np_elemwise_unary_op_basic.cc +++ b/src/operator/numpy/np_elemwise_unary_op_basic.cc @@ -419,5 +419,34 @@ NNVM_REGISTER_OP(_npi_around) .add_arguments(AroundParam::__FIELDS__()) .set_attr("FGradient", MakeZeroGradNodes); +DMLC_REGISTER_PARAMETER(NumpyNanToNumParam); + +NNVM_REGISTER_OP(_npi_nan_to_num) +.describe("" ADD_FILELINE) +.set_attr_parser(ParamParser) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"data"}; + }) +.set_attr("FInferShape", ElemwiseShape<1, 1>) +.set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FCompute", NumpyNanToNumOpForward) +.set_attr("FGradient", ElemwiseGradUseIn{"_npi_backward_nan_to_num"}) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs) { + return std::vector >{{0, 0}}; + }) +.add_argument("data", "NDArray-or-Symbol", "Input ndarray") +.add_arguments(NumpyNanToNumParam::__FIELDS__()); + +NNVM_REGISTER_OP(_npi_backward_nan_to_num) +.set_attr_parser(ParamParser) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr("TIsBackward", true) +.set_attr("FCompute", NumpyNanToNumOpBackward); + } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_elemwise_unary_op_basic.cu b/src/operator/numpy/np_elemwise_unary_op_basic.cu index 44743ed94be8..af8834f01664 100644 --- a/src/operator/numpy/np_elemwise_unary_op_basic.cu +++ b/src/operator/numpy/np_elemwise_unary_op_basic.cu @@ -110,5 +110,11 @@ MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_arctanh, mshadow_op::arctanh); NNVM_REGISTER_OP(_npi_around) .set_attr("FCompute", AroundOpForward); +NNVM_REGISTER_OP(_npi_nan_to_num) +.set_attr("FCompute", NumpyNanToNumOpForward); + +NNVM_REGISTER_OP(_npi_backward_nan_to_num) +.set_attr("FCompute", NumpyNanToNumOpBackward); + } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/elemwise_unary_op.h b/src/operator/tensor/elemwise_unary_op.h index 188ccd68a340..27013dfb98ae 100644 --- a/src/operator/tensor/elemwise_unary_op.h +++ b/src/operator/tensor/elemwise_unary_op.h @@ -34,6 +34,7 @@ #include "../mshadow_op.h" #include "../mxnet_op.h" #include "../elemwise_op_common.h" +#include "../../common/utils.h" #include "../../ndarray/ndarray_function.h" #if MSHADOW_USE_MKL == 1 @@ -660,6 +661,134 @@ void AroundOpForward(const nnvm::NodeAttrs& attrs, } } +struct NumpyNanToNumParam : public dmlc::Parameter { + bool copy; + double nan; + dmlc::optional posinf, neginf; + DMLC_DECLARE_PARAMETER(NumpyNanToNumParam) { + DMLC_DECLARE_FIELD(copy) + .set_default(true) + .describe("Whether to create a copy of `x` (True) or to replace values" + "in-place (False). The in-place operation only occurs if" + "casting to an array does not require a copy." + "Default is True."); + DMLC_DECLARE_FIELD(nan) + .set_default(0.0) + .describe("Value to be used to fill NaN values. If no value is passed" + "then NaN values will be replaced with 0.0."); + DMLC_DECLARE_FIELD(posinf) + .set_default(dmlc::optional()) + .describe("Value to be used to fill positive infinity values." + "If no value is passed then positive infinity values will be" + "replaced with a very large number."); + DMLC_DECLARE_FIELD(neginf) + .set_default(dmlc::optional()) + .describe("Value to be used to fill negative infinity values." + "If no value is passed then negative infinity values" + "will be replaced with a very small (or negative) number."); + } +}; + +template +struct nan_to_num_forward { + template + MSHADOW_XINLINE static void Map(int i, + DType* out_data, + const DType* in_data, + const DType nan, + const DType posinf, + const DType neginf) { + DType val = in_data[i]; + if (mshadow_op::isnan_typed::IsNan(val)) val = nan; + if (val > 0 && mshadow_op::isinf_typed::IsInf(val)) val = posinf; + if (val < 0 && mshadow_op::isinf_typed::IsInf(val)) val = neginf; + KERNEL_ASSIGN(out_data[i], req, val); + } +}; + +template +void NumpyNanToNumOpForward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mxnet; + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + mshadow::Stream *s = ctx.get_stream(); + const TBlob& in_data = inputs[0]; + const TBlob& out_data = outputs[0]; + const NumpyNanToNumParam& param = nnvm::get(attrs.parsed); + using namespace mxnet_op; + + if (!common::is_float(in_data.type_flag_) && req[0] == kWriteInplace) return; + if (!common::is_float(in_data.type_flag_)) { + copy(s, out_data, in_data); + return; + } + + MSHADOW_REAL_TYPE_SWITCH(out_data.type_flag_, DType, { + DType defaultnan = static_cast(param.nan); + DType posinf; + DType neginf; + if (param.posinf.has_value()) { + posinf = static_cast(param.posinf.value()); + } else { + posinf = mshadow::red::limits::MaxValue(); + } + if (param.neginf.has_value()) { + neginf = static_cast(param.neginf.value()); + } else { + neginf = mshadow::red::limits::MinValue(); + } + MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + Kernel, xpu>::Launch( + s, out_data.Size(), out_data.dptr(), in_data.dptr(), + defaultnan, posinf, neginf); + }); + }); +} + +template +struct nan_to_num_backward { + template + MSHADOW_XINLINE static void Map(int i, + DType* in_grad, + const DType* out_grad, + const DType* in_data) { + DType val = out_grad[i]; + if (mshadow_op::isnan_typed::IsNan(in_data[i])) val = 0; + if (val > 0 && mshadow_op::isinf_typed::IsInf(in_data[i])) val = 0; + if (val < 0 && mshadow_op::isinf_typed::IsInf(in_data[i])) val = 0; + KERNEL_ASSIGN(in_grad[i], req, val); + } +}; + +template +void NumpyNanToNumOpBackward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + CHECK_NE(req[0], kWriteInplace); + mshadow::Stream *s = ctx.get_stream(); + const TBlob& out_grad = inputs[0]; + const TBlob& in_data = inputs[1]; + const TBlob& in_grad = outputs[0]; + CHECK_EQ(common::is_float(in_data.type_flag_), true); + using namespace mxnet_op; + MSHADOW_TYPE_SWITCH(out_grad.type_flag_, DType, { + MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + Kernel, xpu>::Launch( + s, in_grad.Size(), in_grad.dptr(), out_grad.dptr(), + in_data.dptr()); + }); + }); +} + /*! \brief Unary compute */ #define MXNET_OPERATOR_REGISTER_UNARY(__name$) \ NNVM_REGISTER_OP(__name$) \ diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 32cd5b10717e..35c992b23194 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -4227,6 +4227,109 @@ def hybrid_forward(self, F, x, *args, **kwargs): assert_almost_equal(ret.asnumpy(), expected_ret, atol=1e-5, rtol=1e-5, use_broadcast=False) +@with_seed() +@use_np +def test_np_nan_to_num(): + + def take_ele_grad(ele): + if _np.isinf(ele) or _np.isnan(ele): + return 0 + return 1 + def np_nan_to_num_grad(data): + shape = data.shape + arr = list(map(take_ele_grad,data.flatten())) + return _np.array(arr).reshape(shape) + + class TestNanToNum(HybridBlock): + def __init__(self, copy=True, nan=0.0, posinf=None, neginf=None): + super(TestNanToNum, self).__init__() + self.copy = copy + self.nan = nan + self.posinf = posinf + self.neginf = neginf + # necessary initializations + + def hybrid_forward(self, F, a): + return F.np.nan_to_num(a, self.copy, self.nan, self.posinf, self.neginf) + + src_list = [ + _np.nan, + _np.inf, + -_np.inf, + 1, + [_np.nan], + [_np.inf], + [-_np.inf], + [1], + [1,2,3,4,-1,-2,-3,-4,0], + [_np.nan, _np.inf, -_np.inf], + [_np.nan, _np.inf, -_np.inf, -574, 0, 23425, 24234,-5], + [_np.nan, -1, 0, 1], + [[-433, 0, 456, _np.inf], [-1, -_np.inf, 0, 1]] + ] + + dtype_list = ['float16', 'float32', 'float64'] + # [nan, inf, -inf] + param_list = [[None, None, None], [0, 1000, -100], [0.0, 9999.9, -9999.9]] + copy_list = [True, False] + hybridize_list = [True, False] + atol, rtol = 1e-5, 1e-3 + + src_dtype_comb = list(itertools.product(src_list,dtype_list)) + # check the dtype = int case in both imperative and sympolic expression + src_dtype_comb.append((1,'int32')) + src_dtype_comb.append(([234, 0, -40],'int64')) + + combinations = itertools.product(hybridize_list, src_dtype_comb, copy_list, param_list) + + numpy_version = _np.version.version + for [hybridize, src_dtype, copy, param] in combinations: + src, dtype = src_dtype + # np.nan, np.inf, -np.int are float type + x1 = mx.nd.array(src, dtype=dtype).as_np_ndarray().asnumpy() + x2 = mx.nd.array(src, dtype=dtype).as_np_ndarray() + x3 = mx.nd.array(src, dtype=dtype).as_np_ndarray() + + expected_grad = np_nan_to_num_grad(x1) + x2.attach_grad() + # with optional parameters or without + if param[0] !=None and numpy_version>="1.17": + test_np_nan_to_num = TestNanToNum(copy=copy, nan=param[0], posinf=param[1], neginf=param[2]) + np_out = _np.nan_to_num(x1, copy=copy, nan=param[0], posinf=param[1], neginf=param[2]) + mx_out = np.nan_to_num(x3, copy=copy, nan=param[0], posinf=param[1], neginf=param[2]) + else: + test_np_nan_to_num = TestNanToNum(copy=copy) + np_out = _np.nan_to_num(x1, copy=copy) + mx_out = np.nan_to_num(x3, copy=copy) + + assert_almost_equal(mx_out.asnumpy(), np_out, rtol, atol) + # check the inplace operation when copy = False + # if x1.shape = 0, _np.array will not actually execute copy logic + # only check x3 from np.nan_to_num instead of x2 from gluon + if copy == False and x1.shape!=(): + assert x1.shape == x3.asnumpy().shape + assert x1.dtype == x3.asnumpy().dtype + assert_almost_equal(x1, x3.asnumpy(), rtol=rtol, atol=atol) + # gluon does not support nan_to_num when copy=False + # backward will check int type and if so, throw error + # if not this case, test gluon + if not (hybridize== False and copy == False) and ('float' in dtype): + if hybridize: + test_np_nan_to_num.hybridize() + with mx.autograd.record(): + mx_out_gluon = test_np_nan_to_num(x2) + assert_almost_equal(mx_out_gluon.asnumpy(), np_out, rtol, atol) + mx_out_gluon.backward() + assert_almost_equal(x2.grad.asnumpy(), expected_grad, rtol=1e-3, atol=1e-5) + + # Test imperative once again + # if copy = False, the value of x1 and x2 has changed + if copy == True: + np_out = _np.nan_to_num(x1) + mx_out = np.nan_to_num(x3) + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False) + + if __name__ == '__main__': import nose nose.runmodule()