From d970243f0a3a2a5df719d2527fabcaa355f485ee Mon Sep 17 00:00:00 2001 From: Minghao Liu <40382964+Tommliu@users.noreply.github.com> Date: Mon, 18 May 2020 18:11:30 +0800 Subject: [PATCH] binary scalar op support scalar dtype Co-Authored-By: Wentao Xu --- python/mxnet/symbol/numpy/_symbol.py | 8 +- src/api/operator/ufunc_helper.cc | 64 ++++++++++- .../numpy/np_elemwise_broadcast_logic_op.cc | 6 +- .../numpy/np_elemwise_broadcast_op.cc | 34 +++--- src/operator/numpy/np_elemwise_broadcast_op.h | 10 -- .../np_elemwise_broadcast_op_extended.cc | 72 ++++++------ .../np_elemwise_broadcast_op_extended_sec.cc | 48 ++++---- .../tensor/elemwise_binary_scalar_op.h | 106 ++++++++++++++---- .../tensor/elemwise_binary_scalar_op_basic.cc | 46 ++++---- .../elemwise_binary_scalar_op_extended.cc | 36 ++---- tests/python/unittest/test_numpy_op.py | 60 ++++++++++ 11 files changed, 323 insertions(+), 167 deletions(-) diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index 4a74e5f0f9f3..6869c1642460 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -1571,13 +1571,15 @@ def _ufunc_helper(lhs, rhs, fn_array, fn_scalar, lfn_scalar, rfn_scalar=None, ou if isinstance(rhs, numeric_types): return fn_scalar(lhs, rhs, out=out) else: + is_int = isinstance(rhs, integer_types) if rfn_scalar is None: # commutative function - return lfn_scalar(rhs, float(lhs), out=out) + return lfn_scalar(rhs, float(lhs), is_int=is_int, out=out) else: - return rfn_scalar(rhs, float(lhs), out=out) + return rfn_scalar(rhs, float(lhs), is_int=is_int, out=out) elif isinstance(rhs, numeric_types): - return lfn_scalar(lhs, float(rhs), out=out) + is_int = isinstance(rhs, integer_types) + return lfn_scalar(lhs, float(rhs), is_int=is_int, out=out) elif isinstance(rhs, Symbol): return fn_array(lhs, rhs, out=out) else: diff --git a/src/api/operator/ufunc_helper.cc b/src/api/operator/ufunc_helper.cc index c29d1b7985be..9f4dcb7992f0 100644 --- a/src/api/operator/ufunc_helper.cc +++ b/src/api/operator/ufunc_helper.cc @@ -24,6 +24,7 @@ #include "ufunc_helper.h" #include "utils.h" #include "../../imperative/imperative_utils.h" +#include "../../operator/tensor/elemwise_binary_scalar_op.h" namespace mxnet { @@ -51,13 +52,38 @@ void UFuncHelper(NDArray* lhs, NDArray* rhs, NDArray* out, } } +void UFuncHelper(NDArray* lhs, int rhs, NDArray* out, + runtime::MXNetRetValue* ret, const nnvm::Op* op) { + using namespace runtime; + nnvm::NodeAttrs attrs; + op::NumpyBinaryScalarParam param; + param.scalar = rhs; + param.is_int = true; + attrs.op = op; + attrs.parsed = param; + SetAttrDict(&attrs); + NDArray** inputs = &lhs; + int num_inputs = 1; + NDArray** outputs = out == nullptr ? nullptr : &out; + int num_outputs = out != nullptr; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); + if (outputs) { + *ret = PythonArg(2); + } else { + *ret = reinterpret_cast(ndoutputs[0]); + } +} + void UFuncHelper(NDArray* lhs, double rhs, NDArray* out, runtime::MXNetRetValue* ret, const nnvm::Op* op) { using namespace runtime; nnvm::NodeAttrs attrs; + op::NumpyBinaryScalarParam param; + param.scalar = rhs; + param.is_int = false; attrs.op = op; - attrs.parsed = rhs; - SetAttrDict(&attrs); + attrs.parsed = param; + SetAttrDict(&attrs); NDArray** inputs = &lhs; int num_inputs = 1; NDArray** outputs = out == nullptr ? nullptr : &out; @@ -70,13 +96,38 @@ void UFuncHelper(NDArray* lhs, double rhs, NDArray* out, } } +void UFuncHelper(int lhs, NDArray* rhs, NDArray* out, + runtime::MXNetRetValue* ret, const nnvm::Op* op) { + using namespace runtime; + nnvm::NodeAttrs attrs; + op::NumpyBinaryScalarParam param; + param.scalar = lhs; + param.is_int = true; + attrs.op = op; + attrs.parsed = param; + SetAttrDict(&attrs); + NDArray** inputs = &rhs; + int num_inputs = 1; + NDArray** outputs = out == nullptr ? nullptr : &out; + int num_outputs = out != nullptr; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); + if (outputs) { + *ret = PythonArg(2); + } else { + *ret = reinterpret_cast(ndoutputs[0]); + } +} + void UFuncHelper(double lhs, NDArray* rhs, NDArray* out, runtime::MXNetRetValue* ret, const nnvm::Op* op) { using namespace runtime; nnvm::NodeAttrs attrs; + op::NumpyBinaryScalarParam param; + param.scalar = lhs; + param.is_int = false; attrs.op = op; - attrs.parsed = lhs; - SetAttrDict(&attrs); + attrs.parsed = param; + SetAttrDict(&attrs); NDArray** inputs = &rhs; int num_inputs = 1; NDArray** outputs = out == nullptr ? nullptr : &out; @@ -99,9 +150,14 @@ void UFuncHelper(runtime::MXNetArgs args, if (args[0].type_code() == kNDArrayHandle) { if (args[1].type_code() == kNDArrayHandle) { UFuncHelper(args[0].operator NDArray*(), args[1].operator NDArray*(), out, ret, fn_array); + } else if (args[1].type_code() == kDLInt){ + UFuncHelper(args[0].operator NDArray*(), args[1].operator int(), out, ret, lfn_scalar); } else { UFuncHelper(args[0].operator NDArray*(), args[1].operator double(), out, ret, lfn_scalar); } + } else if (args[0].type_code() == kDLInt) { + UFuncHelper(args[0].operator int(), args[1].operator NDArray*(), out, ret, + rfn_scalar ? rfn_scalar : lfn_scalar); } else { UFuncHelper(args[0].operator double(), args[1].operator NDArray*(), out, ret, rfn_scalar ? rfn_scalar : lfn_scalar); diff --git a/src/operator/numpy/np_elemwise_broadcast_logic_op.cc b/src/operator/numpy/np_elemwise_broadcast_logic_op.cc index 0f5d9923a3e2..b44a22d65856 100644 --- a/src/operator/numpy/np_elemwise_broadcast_logic_op.cc +++ b/src/operator/numpy/np_elemwise_broadcast_logic_op.cc @@ -242,9 +242,7 @@ struct TVMBinaryBroadcastScalarCompute { NNVM_REGISTER_OP(_npi_##name##_scalar) \ .set_num_inputs(1) \ .set_num_outputs(1) \ - .set_attr_parser([](NodeAttrs* attrs) { \ - attrs->parsed = std::stod(attrs->dict["scalar"]); \ - }) \ + .set_attr_parser(ParamParser) \ .set_attr("FListInputNames", \ [](const NodeAttrs& attrs) { \ return std::vector{"data"}; \ @@ -257,7 +255,7 @@ struct TVMBinaryBroadcastScalarCompute { }) \ .set_attr("FGradient", MakeZeroGradNodes) \ .add_argument("data", "NDArray-or-Symbol", "First input to the function") \ - .add_argument("scalar", "float", "scalar input") + .add_arguments(NumpyBinaryScalarParam::__FIELDS__()) MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(equal); MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(not_equal); diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cc b/src/operator/numpy/np_elemwise_broadcast_op.cc index 6ec880e0ba8b..c2db0fee7512 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.cc +++ b/src/operator/numpy/np_elemwise_broadcast_op.cc @@ -28,21 +28,25 @@ namespace mxnet { namespace op { -#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(name) \ - NNVM_REGISTER_OP(name) \ - .set_num_inputs(1) \ - .set_num_outputs(1) \ - .set_attr_parser([](NodeAttrs* attrs) { \ - attrs->parsed = std::stod(attrs->dict["scalar"]); \ - }) \ - .set_attr("FInferShape", ElemwiseShape<1, 1>) \ - .set_attr("FInferType", NumpyBinaryScalarType) \ - .set_attr("FInplaceOption", \ - [](const NodeAttrs& attrs){ \ - return std::vector >{{0, 0}}; \ - }) \ - .add_argument("data", "NDArray-or-Symbol", "source input") \ - .add_argument("scalar", "float", "scalar input") +DMLC_REGISTER_PARAMETER(NumpyBinaryScalarParam); + +#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(name) \ + NNVM_REGISTER_OP(name) \ + .set_num_inputs(1) \ + .set_num_outputs(1) \ + .set_attr_parser(ParamParser) \ + .set_attr("FInferShape", ElemwiseShape<1, 1>) \ + .set_attr("FInferType", NumpyBinaryScalarType) \ + .set_attr("FInplaceOption", \ + [](const NodeAttrs& attrs){ \ + return std::vector >{{0, 0}}; \ + }) \ + .set_attr("FResourceRequest", \ + [](const NodeAttrs& attrs) { \ + return std::vector{ResourceRequest::kTempSpace}; \ + }) \ + .add_argument("data", "NDArray-or-Symbol", "source input") \ + .add_arguments(NumpyBinaryScalarParam::__FIELDS__()) bool NumpyBinaryMixedPrecisionType(const nnvm::NodeAttrs& attrs, std::vector* in_attrs, diff --git a/src/operator/numpy/np_elemwise_broadcast_op.h b/src/operator/numpy/np_elemwise_broadcast_op.h index a0e204318839..8bf32f58b2aa 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.h +++ b/src/operator/numpy/np_elemwise_broadcast_op.h @@ -35,16 +35,6 @@ namespace mxnet { namespace op { -inline bool NumpyBinaryScalarType(const nnvm::NodeAttrs& attrs, - std::vector* in_attrs, - std::vector* out_attrs) { - CHECK_EQ(in_attrs->size(), 1U); - CHECK_EQ(out_attrs->size(), 1U); - TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); - TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0)); - return in_attrs->at(0) != -1; -} - inline void PrintErrorMessage(const std::string& op_name, const int dtype1, const int dtype2) { LOG(FATAL) << "Operator " << op_name << " does not support combination of " << mshadow::dtype_string(dtype1) << " with " << mshadow::dtype_string(dtype2) diff --git a/src/operator/numpy/np_elemwise_broadcast_op_extended.cc b/src/operator/numpy/np_elemwise_broadcast_op_extended.cc index 70233a596dc7..96e27789cf11 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op_extended.cc +++ b/src/operator/numpy/np_elemwise_broadcast_op_extended.cc @@ -29,21 +29,23 @@ namespace mxnet { namespace op { -#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(name) \ - NNVM_REGISTER_OP(name) \ - .set_num_inputs(1) \ - .set_num_outputs(1) \ - .set_attr_parser([](NodeAttrs* attrs) { \ - attrs->parsed = std::stod(attrs->dict["scalar"]); \ - }) \ - .set_attr("FInferShape", ElemwiseShape<1, 1>) \ - .set_attr("FInferType", NumpyBinaryScalarType) \ - .set_attr("FInplaceOption", \ - [](const NodeAttrs& attrs){ \ - return std::vector >{{0, 0}}; \ - }) \ - .add_argument("data", "NDArray-or-Symbol", "source input") \ - .add_argument("scalar", "float", "scalar input") +#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(name) \ + NNVM_REGISTER_OP(name) \ + .set_num_inputs(1) \ + .set_num_outputs(1) \ + .set_attr_parser(ParamParser) \ + .set_attr("FInferShape", ElemwiseShape<1, 1>) \ + .set_attr("FInferType", NumpyBinaryScalarType) \ + .set_attr("FInplaceOption", \ + [](const NodeAttrs& attrs){ \ + return std::vector >{{0, 0}}; \ + }) \ + .set_attr("FResourceRequest", \ + [](const NodeAttrs& attrs) { \ + return std::vector{ResourceRequest::kTempSpace}; \ + }) \ + .add_argument("data", "NDArray-or-Symbol", "source input") \ + .add_arguments(NumpyBinaryScalarParam::__FIELDS__()) MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_copysign) .describe(R"code()code" ADD_FILELINE) @@ -86,9 +88,7 @@ NNVM_REGISTER_OP(_npi_lcm) NNVM_REGISTER_OP(_npi_lcm_scalar) .set_num_inputs(1) .set_num_outputs(1) -.set_attr_parser([](NodeAttrs* attrs) { - attrs->parsed = std::stod(attrs->dict["scalar"]); - }) +.set_attr_parser(ParamParser) .set_attr("FInferShape", ElemwiseShape<1, 1>) .set_attr("FInferType", ElemwiseIntType<1, 1>) .set_attr("FInplaceOption", @@ -97,7 +97,7 @@ NNVM_REGISTER_OP(_npi_lcm_scalar) }) .set_attr("FGradient", MakeZeroGradNodes) .add_argument("data", "NDArray-or-Symbol", "source input") -.add_argument("scalar", "int", "scalar input") +.add_arguments(NumpyBinaryScalarParam::__FIELDS__()) .set_attr("FCompute", BinaryScalarOp::Compute); NNVM_REGISTER_OP(_npi_bitwise_and) @@ -121,9 +121,7 @@ NNVM_REGISTER_OP(_npi_bitwise_and) NNVM_REGISTER_OP(_npi_bitwise_and_scalar) .set_num_inputs(1) .set_num_outputs(1) -.set_attr_parser([](NodeAttrs* attrs) { - attrs->parsed = std::stod(attrs->dict["scalar"]); - }) +.set_attr_parser(ParamParser) .set_attr("FInferShape", ElemwiseShape<1, 1>) .set_attr("FInferType", ElemwiseIntType<1, 1>) .set_attr("FInplaceOption", @@ -132,7 +130,7 @@ NNVM_REGISTER_OP(_npi_bitwise_and_scalar) }) .set_attr("FGradient", MakeZeroGradNodes) .add_argument("data", "NDArray-or-Symbol", "source input") -.add_argument("scalar", "int", "scalar input") +.add_arguments(NumpyBinaryScalarParam::__FIELDS__()) .set_attr("FCompute", BinaryScalarOp::ComputeInt); NNVM_REGISTER_OP(_npi_bitwise_xor) @@ -174,9 +172,7 @@ NNVM_REGISTER_OP(_npi_bitwise_or) NNVM_REGISTER_OP(_npi_bitwise_xor_scalar) .set_num_inputs(1) .set_num_outputs(1) -.set_attr_parser([](NodeAttrs* attrs) { - attrs->parsed = std::stod(attrs->dict["scalar"]); - }) +.set_attr_parser(ParamParser) .set_attr("FInferShape", ElemwiseShape<1, 1>) .set_attr("FInferType", ElemwiseIntType<1, 1>) .set_attr("FInplaceOption", @@ -185,15 +181,13 @@ NNVM_REGISTER_OP(_npi_bitwise_xor_scalar) }) .set_attr("FGradient", MakeZeroGradNodes) .add_argument("data", "NDArray-or-Symbol", "source input") -.add_argument("scalar", "int", "scalar input") +.add_arguments(NumpyBinaryScalarParam::__FIELDS__()) .set_attr("FCompute", BinaryScalarOp::ComputeInt); NNVM_REGISTER_OP(_npi_bitwise_or_scalar) .set_num_inputs(1) .set_num_outputs(1) -.set_attr_parser([](NodeAttrs* attrs) { - attrs->parsed = std::stod(attrs->dict["scalar"]); - }) +.set_attr_parser(ParamParser) .set_attr("FInferShape", ElemwiseShape<1, 1>) .set_attr("FInferType", ElemwiseIntType<1, 1>) .set_attr("FInplaceOption", @@ -202,7 +196,7 @@ NNVM_REGISTER_OP(_npi_bitwise_or_scalar) }) .set_attr("FGradient", MakeZeroGradNodes) .add_argument("data", "NDArray-or-Symbol", "source input") -.add_argument("scalar", "int", "scalar input") +.add_arguments(NumpyBinaryScalarParam::__FIELDS__()) .set_attr("FCompute", BinaryScalarOp::ComputeInt); MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_copysign_scalar) @@ -274,14 +268,14 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rarctan2_scalar) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_rarctan2_scalar"}); MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_arctan2_scalar) -.add_argument("scalar", "float", "scalar value") -.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); }) +.add_arguments(NumpyBinaryScalarParam::__FIELDS__()) +.set_attr_parser(ParamParser) .set_attr("FCompute", BinaryScalarOp::Backward); MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_rarctan2_scalar) -.add_argument("scalar", "float", "scalar value") -.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); }) +.add_arguments(NumpyBinaryScalarParam::__FIELDS__()) +.set_attr_parser(ParamParser) .set_attr("FCompute", BinaryScalarOp::Backward); @@ -362,13 +356,13 @@ NNVM_REGISTER_OP(_backward_npi_ldexp) mshadow_op::ldexp_rgrad>); MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_ldexp_scalar) -.add_argument("scalar", "float", "scalar value") -.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); }) +.add_arguments(NumpyBinaryScalarParam::__FIELDS__()) +.set_attr_parser(ParamParser) .set_attr("FCompute", BinaryScalarOp::Backward); MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_rldexp_scalar) -.add_argument("scalar", "float", "scalar value") -.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); }) +.add_arguments(NumpyBinaryScalarParam::__FIELDS__()) +.set_attr_parser(ParamParser) .set_attr("FCompute", BinaryScalarOp::Backward); } // namespace op diff --git a/src/operator/numpy/np_elemwise_broadcast_op_extended_sec.cc b/src/operator/numpy/np_elemwise_broadcast_op_extended_sec.cc index 7455da139a14..3ae3ab93e0bb 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op_extended_sec.cc +++ b/src/operator/numpy/np_elemwise_broadcast_op_extended_sec.cc @@ -29,21 +29,23 @@ namespace mxnet { namespace op { -#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(name) \ - NNVM_REGISTER_OP(name) \ - .set_num_inputs(1) \ - .set_num_outputs(1) \ - .set_attr_parser([](NodeAttrs* attrs) { \ - attrs->parsed = std::stod(attrs->dict["scalar"]); \ - }) \ - .set_attr("FInferShape", ElemwiseShape<1, 1>) \ - .set_attr("FInferType", NumpyBinaryScalarType) \ - .set_attr("FInplaceOption", \ - [](const NodeAttrs& attrs){ \ - return std::vector >{{0, 0}}; \ - }) \ - .add_argument("data", "NDArray-or-Symbol", "source input") \ - .add_argument("scalar", "float", "scalar input") +#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(name) \ + NNVM_REGISTER_OP(name) \ + .set_num_inputs(1) \ + .set_num_outputs(1) \ + .set_attr_parser(ParamParser) \ + .set_attr("FInferShape", ElemwiseShape<1, 1>) \ + .set_attr("FInferType", NumpyBinaryScalarType) \ + .set_attr("FInplaceOption", \ + [](const NodeAttrs& attrs){ \ + return std::vector >{{0, 0}}; \ + }) \ + .set_attr("FResourceRequest", \ + [](const NodeAttrs& attrs) { \ + return std::vector{ResourceRequest::kTempSpace}; \ + }) \ + .add_argument("data", "NDArray-or-Symbol", "source input") \ + .add_arguments(NumpyBinaryScalarParam::__FIELDS__()) MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_fmax) .set_attr("FCompute", BinaryBroadcastCompute) @@ -69,8 +71,8 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_fmax_scalar) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_fmax_scalar"}); MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_fmax_scalar) -.add_argument("scalar", "float", "scalar value") -.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); }) +.add_arguments(NumpyBinaryScalarParam::__FIELDS__()) +.set_attr_parser(ParamParser) .set_attr("FCompute", BinaryScalarOp::Backward); MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_fmin) @@ -97,8 +99,8 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_fmin_scalar) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_fmin_scalar"}); MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_fmin_scalar) -.add_argument("scalar", "float", "scalar value") -.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); }) +.add_arguments(NumpyBinaryScalarParam::__FIELDS__()) +.set_attr_parser(ParamParser) .set_attr("FCompute", BinaryScalarOp::Backward); MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_fmod) @@ -125,8 +127,8 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_fmod_scalar) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_fmod_scalar"}); MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_fmod_scalar) -.add_argument("scalar", "float", "scalar value") -.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); }) +.add_arguments(NumpyBinaryScalarParam::__FIELDS__()) +.set_attr_parser(ParamParser) .set_attr("FCompute", BinaryScalarOp::Backward); MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rfmod_scalar) @@ -134,8 +136,8 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rfmod_scalar) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_rfmod_scalar"}); MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_rfmod_scalar) -.add_argument("scalar", "float", "scalar value") -.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); }) +.add_arguments(NumpyBinaryScalarParam::__FIELDS__()) +.set_attr_parser(ParamParser) .set_attr("FCompute", BinaryScalarOp::Backward); } // namespace op diff --git a/src/operator/tensor/elemwise_binary_scalar_op.h b/src/operator/tensor/elemwise_binary_scalar_op.h index ad15593031b9..8e766f45e2d2 100644 --- a/src/operator/tensor/elemwise_binary_scalar_op.h +++ b/src/operator/tensor/elemwise_binary_scalar_op.h @@ -35,6 +35,46 @@ namespace mxnet { namespace op { +struct NumpyBinaryScalarParam : public dmlc::Parameter { + double scalar; + bool is_int; + DMLC_DECLARE_PARAMETER(NumpyBinaryScalarParam) { + DMLC_DECLARE_FIELD(scalar) + .set_default(1) + .describe("Scalar input value"); + DMLC_DECLARE_FIELD(is_int) + .set_default(false) + .describe("Indicate whether scalar input is int type"); + } + + void SetAttrDict(std::unordered_map* dict) { + std::ostringstream scalar_s, is_int_s; + scalar_s << scalar; + is_int_s << is_int; + (*dict)["scalar"] = scalar_s.str(); + (*dict)["is_int"] = is_int_s.str(); + } +}; + +inline bool NumpyBinaryScalarType(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + const NumpyBinaryScalarParam& param = nnvm::get(attrs.parsed); + bool scalar_is_int = param.is_int; + if (common::is_int(in_attrs->at(0)) && !scalar_is_int) { + TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kFloat64); + } else if (in_attrs->at(0) == mshadow::kBool) { + TYPE_ASSIGN_CHECK(*out_attrs, 0, scalar_is_int ? mshadow::kInt64 : mshadow::kFloat64); + } + else { + TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); + TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0)); + } + return out_attrs->at(0) != -1; +} + class BinaryScalarOp : public UnaryOp { /*! \brief Tensor operation against a scalar with a dense result */ template @@ -44,7 +84,8 @@ class BinaryScalarOp : public UnaryOp { const NDArray &input, const OpReqType req, const NDArray &output) { - const double alpha = nnvm::get(attrs.parsed); + const NumpyBinaryScalarParam& param = nnvm::get(attrs.parsed); + const double alpha = param.scalar; CHECK_EQ(output.shape(), input.shape()); const int64_t row_count = output.shape()[0]; const int64_t items_per_row = output.shape().Size() / row_count; @@ -136,7 +177,8 @@ class BinaryScalarOp : public UnaryOp { const NDArray &output) { CHECK_EQ(output.shape(), input.shape()); - const double alpha = nnvm::get(attrs.parsed); + const NumpyBinaryScalarParam& param = nnvm::get(attrs.parsed); + const double alpha = param.scalar; const DType dense_fill_val = OP::Map(DType(0), DType(alpha)); const TBlob column_indexes = input.aux_data(csr::kIdx); const size_t item_count = column_indexes.Size(); @@ -235,11 +277,23 @@ class BinaryScalarOp : public UnaryOp { using namespace mshadow; using namespace mshadow::expr; Stream *s = ctx.get_stream(); - const double alpha = nnvm::get(attrs.parsed); + TBlob temp_tblob; + const NumpyBinaryScalarParam& param = nnvm::get(attrs.parsed); + bool scalar_is_int = param.is_int; + const double alpha = param.scalar; MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { + if ((common::is_int(inputs[0].type_flag_) && !scalar_is_int) || + (inputs[0].type_flag_ == kBool)) { + Tensor temp_tensor = + ctx.requested[0].get_space_typed(Shape1(inputs[0].Size()), s); + temp_tblob = TBlob(temp_tensor); + CastCompute(attrs, ctx, {inputs[0]}, {kWriteTo}, {temp_tblob}); + } else { + temp_tblob = inputs[0]; + } MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { mxnet_op::Kernel, xpu>::Launch( - s, inputs[0].Size(), outputs[0].dptr(), inputs[0].dptr(), DType(alpha)); + s, inputs[0].Size(), outputs[0].dptr(), temp_tblob.dptr(), DType(alpha)); }); }); } @@ -255,7 +309,8 @@ class BinaryScalarOp : public UnaryOp { using namespace mshadow; using namespace mshadow::expr; Stream *s = ctx.get_stream(); - const double alpha = nnvm::get(attrs.parsed); + const NumpyBinaryScalarParam& param = nnvm::get(attrs.parsed); + const double alpha = param.scalar; MXNET_INT_TYPE_SWITCH(outputs[0].type_flag_, DType, { MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { mxnet_op::Kernel, xpu>::Launch( @@ -275,9 +330,11 @@ class BinaryScalarOp : public UnaryOp { using namespace mshadow; using namespace mshadow::expr; Stream *s = ctx.get_stream(); - const double alpha = nnvm::get(attrs.parsed); + const NumpyBinaryScalarParam& param = nnvm::get(attrs.parsed); + bool scalar_is_int = param.is_int; + const double alpha = param.scalar; TBlob temp_tblob; - if (common::is_int(inputs[0].type_flag_)) { + if (common::is_int(inputs[0].type_flag_) && !scalar_is_int) { Tensor temp_tensor = ctx.requested[0].get_space_typed(Shape1(inputs[0].Size()), s); temp_tblob = TBlob(temp_tensor); @@ -353,7 +410,8 @@ class BinaryScalarOp : public UnaryOp { using namespace mshadow; using namespace mshadow::expr; Stream *s = ctx.get_stream(); - const double alpha = nnvm::get(attrs.parsed); + const NumpyBinaryScalarParam& param = nnvm::get(attrs.parsed); + const double alpha = param.scalar; MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { mxnet::op::mxnet_op::Kernelparsed = std::stod(attrs->dict["scalar"]); \ - }) \ - .set_attr("FInferShape", ElemwiseShape<1, 1>) \ - .set_attr("FInferType", ElemwiseType<1, 1>) \ - .set_attr("FInplaceOption", \ - [](const NodeAttrs& attrs){ \ - return std::vector >{{0, 0}}; \ - }) \ - .add_argument("data", "NDArray-or-Symbol", "source input") \ - .add_argument("scalar", "float", "scalar input") +#define MXNET_OPERATOR_REGISTER_BINARY_SCALAR(name) \ + NNVM_REGISTER_OP(name) \ + .set_num_inputs(1) \ + .set_num_outputs(1) \ + .set_attr_parser(ParamParser) \ + .set_attr("FInferShape", ElemwiseShape<1, 1>) \ + .set_attr("FInferType", NumpyBinaryScalarType) \ + .set_attr("FInplaceOption", \ + [](const NodeAttrs& attrs){ \ + return std::vector >{{0, 0}}; \ + }) \ + .set_attr("FResourceRequest", \ + [](const NodeAttrs& attrs) { \ + return std::vector{ResourceRequest::kTempSpace}; \ + }) \ + .add_argument("data", "NDArray-or-Symbol", "source input") \ + .add_arguments(NumpyBinaryScalarParam::__FIELDS__()) } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/elemwise_binary_scalar_op_basic.cc b/src/operator/tensor/elemwise_binary_scalar_op_basic.cc index ae356deff0a1..6f96a55d2d4e 100644 --- a/src/operator/tensor/elemwise_binary_scalar_op_basic.cc +++ b/src/operator/tensor/elemwise_binary_scalar_op_basic.cc @@ -27,22 +27,24 @@ #include "./elemwise_binary_scalar_op.h" #define MXNET_OPERATOR_REGISTER_BINARY_WITH_SCALAR_SUPPORT_WITH_DENSE_RESULT(name) \ - NNVM_REGISTER_OP(name) \ - .set_num_inputs(1) \ - .set_num_outputs(1) \ - .set_attr_parser([](NodeAttrs* attrs) { \ - attrs->parsed = std::stod(attrs->dict["scalar"]); \ - }) \ - .set_attr("FInferShape", ElemwiseShape<1, 1>) \ - .set_attr("FInferType", ElemwiseType<1, 1>) \ - .set_attr("FInferStorageType", \ - BinaryScalarStorageTypeWithDenseResultStorageType) \ - .set_attr("FInplaceOption", \ - [](const NodeAttrs& attrs){ \ - return std::vector >{{0, 0}}; \ - }) \ - .add_argument("data", "NDArray-or-Symbol", "source input") \ - .add_argument("scalar", "float", "scalar input") + NNVM_REGISTER_OP(name) \ + .set_num_inputs(1) \ + .set_num_outputs(1) \ + .set_attr_parser(ParamParser) \ + .set_attr("FInferShape", ElemwiseShape<1, 1>) \ + .set_attr("FInferType", NumpyBinaryScalarType) \ + .set_attr("FInferStorageType", \ + BinaryScalarStorageTypeWithDenseResultStorageType) \ + .set_attr("FInplaceOption", \ + [](const NodeAttrs& attrs){ \ + return std::vector >{{0, 0}}; \ + }) \ + .set_attr("FResourceRequest", \ + [](const NodeAttrs& attrs) { \ + return std::vector{ResourceRequest::kTempSpace}; \ + }) \ + .add_argument("data", "NDArray-or-Symbol", "source input") \ + .add_arguments(NumpyBinaryScalarParam::__FIELDS__()) namespace mxnet { namespace op { @@ -188,8 +190,8 @@ MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_rdiv_scalar) .add_alias("_RDivScalar"); MXNET_OPERATOR_REGISTER_BINARY(_backward_rdiv_scalar) -.add_argument("scalar", "float", "scalar value") -.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); }) +.add_arguments(NumpyBinaryScalarParam::__FIELDS__()) +.set_attr_parser(ParamParser) .set_attr("FCompute", BinaryScalarOp::Backward< cpu, mshadow_op::rdiv_grad>); @@ -199,8 +201,8 @@ MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_mod_scalar) .add_alias("_ModScalar"); MXNET_OPERATOR_REGISTER_BINARY(_backward_mod_scalar) -.add_argument("scalar", "float", "scalar value") -.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); }) +.add_arguments(NumpyBinaryScalarParam::__FIELDS__()) +.set_attr_parser(ParamParser) .set_attr("FCompute", BinaryScalarOp::Backward< cpu, mshadow_op::mod_grad>); @@ -210,8 +212,8 @@ MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_rmod_scalar) .add_alias("_RModScalar"); MXNET_OPERATOR_REGISTER_BINARY(_backward_rmod_scalar) -.add_argument("scalar", "float", "scalar value") -.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); }) +.add_arguments(NumpyBinaryScalarParam::__FIELDS__()) +.set_attr_parser(ParamParser) .set_attr("FCompute", BinaryScalarOp::Backward< cpu, mshadow_op::rmod_grad>); diff --git a/src/operator/tensor/elemwise_binary_scalar_op_extended.cc b/src/operator/tensor/elemwise_binary_scalar_op_extended.cc index 4ada2f036f7d..26b572a8ee99 100644 --- a/src/operator/tensor/elemwise_binary_scalar_op_extended.cc +++ b/src/operator/tensor/elemwise_binary_scalar_op_extended.cc @@ -35,8 +35,8 @@ MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_maximum_scalar) .add_alias("_npi_maximum_scalar"); MXNET_OPERATOR_REGISTER_BINARY(_backward_maximum_scalar) -.add_argument("scalar", "float", "scalar value") -.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); }) +.add_arguments(NumpyBinaryScalarParam::__FIELDS__()) +.set_attr_parser(ParamParser) .set_attr("FCompute", BinaryScalarOp::Backward); MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_minimum_scalar) @@ -46,8 +46,8 @@ MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_minimum_scalar) .add_alias("_npi_minimum_scalar"); MXNET_OPERATOR_REGISTER_BINARY(_backward_minimum_scalar) -.add_argument("scalar", "float", "scalar value") -.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); }) +.add_arguments(NumpyBinaryScalarParam::__FIELDS__()) +.set_attr_parser(ParamParser) .set_attr("FCompute", BinaryScalarOp::Backward); MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_power_scalar) @@ -56,8 +56,8 @@ MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_power_scalar) .add_alias("_PowerScalar"); MXNET_OPERATOR_REGISTER_BINARY(_backward_power_scalar) -.add_argument("scalar", "float", "scalar value") -.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); }) +.add_arguments(NumpyBinaryScalarParam::__FIELDS__()) +.set_attr_parser(ParamParser) .set_attr("FCompute", BinaryScalarOp::Backward< cpu, mshadow_op::power_grad>); @@ -68,8 +68,8 @@ MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_rpower_scalar) .add_alias("_RPowerScalar"); MXNET_OPERATOR_REGISTER_BINARY(_backward_rpower_scalar) -.add_argument("scalar", "float", "scalar value") -.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); }) +.add_arguments(NumpyBinaryScalarParam::__FIELDS__()) +.set_attr_parser(ParamParser) .set_attr("FCompute", BinaryScalarOp::Backward< cpu, mshadow_op::rpower_grad>); @@ -81,8 +81,8 @@ MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_hypot_scalar) .add_alias("_HypotScalar"); MXNET_OPERATOR_REGISTER_BINARY(_backward_hypot_scalar) -.add_argument("scalar", "float", "scalar value") -.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); }) +.add_arguments(NumpyBinaryScalarParam::__FIELDS__()) +.set_attr_parser(ParamParser) .set_attr("FCompute", BinaryScalarOp::Backward< cpu, mshadow_op::hypot_grad_left>); @@ -108,13 +108,7 @@ Example:: )code" ADD_FILELINE) .set_num_inputs(1) .set_num_outputs(1) -.set_attr_parser([](NodeAttrs* attrs) { - if (attrs->dict.find("scalar") != attrs->dict.end()) { - attrs->parsed = std::stod(attrs->dict["scalar"]); - } else { - attrs->parsed = 1.0; - } - }) +.set_attr_parser(ParamParser) .set_attr("FInferShape", ElemwiseShape<1, 1>) .set_attr("FInferType", ElemwiseType<1, 1>) .set_attr("FInplaceOption", @@ -127,13 +121,7 @@ Example:: .set_attr("FGradient", ElemwiseGradUseIn{ "_backward_smooth_l1" }); MXNET_OPERATOR_REGISTER_BINARY(_backward_smooth_l1) - .set_attr_parser([](NodeAttrs *attrs) { - if (attrs->dict.find("scalar") != attrs->dict.end()) { - attrs->parsed = std::stod(attrs->dict["scalar"]); - } else { - attrs->parsed = 1.0; - } -}) +.set_attr_parser(ParamParser) .set_attr("FCompute", BinaryScalarOp::Backward); diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index d1167c4e180b..35cbc94a3cd4 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -2660,6 +2660,66 @@ def hybrid_forward(self, F, a, b, *args, **kwargs): check_mixed_precision_binary_func(func, low, high, lshape, rshape, lgrad, rgrad, type1, type2) +@with_seed() +@use_np +def test_np_binary_scalar_funcs(): + itypes = [np.int8, np.int32, np.int64] + def check_binary_scalar_func(func, low, high, lshape, lgrad, ltype, scalar_is_int): + class TestBinaryScalar(HybridBlock): + def __init__(self, func, scalar): + super(TestBinaryScalar, self).__init__() + self._func = func + self._scalar = scalar + + def hybrid_forward(self, F, a, *args, **kwargs): + return getattr(F.np, self._func)(a, self._scalar) + + np_test_x1 = _np.random.uniform(low, high, lshape).astype(ltype) + np_test_x2 = int(_np.random.uniform(low, high)) if scalar_is_int else _np.random.uniform(low, high) + mx_test_x1 = np.array(np_test_x1, dtype=ltype) + mx_test_x2 = np_test_x2 + np_func = getattr(_np, func) + mx_func = TestBinaryScalar(func, np_test_x2) + rtol = 1e-2 if ltype is np.float16 else 1e-3 + atol = 1e-3 if ltype is np.float16 else 1e-5 + if ltype not in itypes: + if lgrad: + mx_test_x1.attach_grad() + np_out = np_func(np_test_x1, np_test_x2) + with mx.autograd.record(): + y = mx_func(mx_test_x1, mx_test_x2) + assert y.shape == np_out.shape + assert_almost_equal(y.asnumpy(), np_out.astype(y.dtype), rtol=rtol, atol=atol) + if lgrad: + y.backward() + assert_almost_equal(mx_test_x1.grad.asnumpy(), + collapse_sum_like(lgrad(y.asnumpy(), np_test_x1, np_test_x2), mx_test_x1.shape), + rtol=1e-1, atol=1e-2, equal_nan=True, use_broadcast=False) + + # Test imperative + np_out = getattr(_np, func)(np_test_x1, np_test_x2) + mx_out = getattr(mx.np, func)(mx_test_x1, mx_test_x2) + assert mx_out.shape == np_out.shape + assert mx_out.asnumpy().dtype == np_out.dtype + assert_almost_equal(mx_out.asnumpy(), np_out.astype(mx_out.dtype), rtol=rtol, atol=atol) + + funcs = { + 'add': (-1.0, 1.0, None), + 'subtract': (-1.0, 1.0, None), + 'multiply': (-1.0, 1.0, lambda y, x1, x2: _np.broadcast_to(x2, y.shape)), + 'power': (1.0, 5.0, lambda y, x1, x2: _np.power(x1, x2 - 1.0) * x2), + } + + shapes = [(3, 2), (3, 0), (3, 1), (0, 2), (2, 3, 4)] + ltypes = [np.int32, np.int64, np.float16, np.float32, np.float64] + for func, func_data in funcs.items(): + low, high, lgrad = func_data + for shape in shapes: + for ltype in ltypes: + for is_int in [True, False]: + check_binary_scalar_func(func, low, high, shape, lgrad, ltype, is_int) + + @with_seed() @use_np def test_np_boolean_binary_funcs():