From eeb7e20db3efa4f5d2e0387c779eb44ae45d9c7d Mon Sep 17 00:00:00 2001
From: Minghao Liu <40382964+Tommliu@users.noreply.github.com>
Date: Mon, 18 May 2020 18:11:30 +0800
Subject: [PATCH] binary scalar op: support scalar dtype

Co-Authored-By: Wentao Xu
---
 src/api/operator/ufunc_helper.cc              | 64 ++++++++++++++-
 .../numpy/np_elemwise_broadcast_logic_op.cc   |  6 +-
 .../numpy/np_elemwise_broadcast_op.cc         |  8 +-
 src/operator/numpy/np_elemwise_broadcast_op.h | 14 ----
 .../np_elemwise_broadcast_op_extended.cc      | 46 +++++------
 .../np_elemwise_broadcast_op_extended_sec.cc  | 22 +++---
 .../tensor/elemwise_binary_scalar_op.cuh      |  6 +-
 .../tensor/elemwise_binary_scalar_op.h        | 78 +++++++++++++++----
 .../tensor/elemwise_binary_scalar_op_basic.cc | 20 +++--
 .../elemwise_binary_scalar_op_extended.cc     | 36 +++------
 tests/python/unittest/test_numpy_op.py        |  8 +-
 11 files changed, 187 insertions(+), 121 deletions(-)

diff --git a/src/api/operator/ufunc_helper.cc b/src/api/operator/ufunc_helper.cc
index c29d1b7985be..9f4dcb7992f0 100644
--- a/src/api/operator/ufunc_helper.cc
+++ b/src/api/operator/ufunc_helper.cc
@@ -24,6 +24,7 @@
 #include "ufunc_helper.h"
 #include "utils.h"
 #include "../../imperative/imperative_utils.h"
+#include "../../operator/tensor/elemwise_binary_scalar_op.h"
 
 namespace mxnet {
 
@@ -51,13 +52,38 @@ void UFuncHelper(NDArray* lhs, NDArray* rhs, NDArray* out,
   }
 }
 
+void UFuncHelper(NDArray* lhs, int rhs, NDArray* out,
+                 runtime::MXNetRetValue* ret, const nnvm::Op* op) {
+  using namespace runtime;
+  nnvm::NodeAttrs attrs;
+  op::NumpyBinaryScalarParam param;
+  param.scalar = rhs;
+  param.is_int = true;
+  attrs.op = op;
+  attrs.parsed = param;
+  SetAttrDict<op::NumpyBinaryScalarParam>(&attrs);
+  NDArray** inputs = &lhs;
+  int num_inputs = 1;
+  NDArray** outputs = out == nullptr ? nullptr : &out;
+  int num_outputs = out != nullptr;
+  auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs);
+  if (outputs) {
+    *ret = PythonArg(2);
+  } else {
+    *ret = reinterpret_cast<uint64_t>(ndoutputs[0]);
+  }
+}
+
 void UFuncHelper(NDArray* lhs, double rhs, NDArray* out,
                  runtime::MXNetRetValue* ret, const nnvm::Op* op) {
   using namespace runtime;
   nnvm::NodeAttrs attrs;
+  op::NumpyBinaryScalarParam param;
+  param.scalar = rhs;
+  param.is_int = false;
   attrs.op = op;
-  attrs.parsed = rhs;
-  SetAttrDict<double>(&attrs);
+  attrs.parsed = param;
+  SetAttrDict<op::NumpyBinaryScalarParam>(&attrs);
   NDArray** inputs = &lhs;
   int num_inputs = 1;
   NDArray** outputs = out == nullptr ? nullptr : &out;
@@ -70,13 +96,38 @@ void UFuncHelper(NDArray* lhs, double rhs, NDArray* out,
   }
 }
 
+void UFuncHelper(int lhs, NDArray* rhs, NDArray* out,
+                 runtime::MXNetRetValue* ret, const nnvm::Op* op) {
+  using namespace runtime;
+  nnvm::NodeAttrs attrs;
+  op::NumpyBinaryScalarParam param;
+  param.scalar = lhs;
+  param.is_int = true;
+  attrs.op = op;
+  attrs.parsed = param;
+  SetAttrDict<op::NumpyBinaryScalarParam>(&attrs);
+  NDArray** inputs = &rhs;
+  int num_inputs = 1;
+  NDArray** outputs = out == nullptr ? nullptr : &out;
+  int num_outputs = out != nullptr;
+  auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs);
+  if (outputs) {
+    *ret = PythonArg(2);
+  } else {
+    *ret = reinterpret_cast<uint64_t>(ndoutputs[0]);
+  }
+}
+
 void UFuncHelper(double lhs, NDArray* rhs, NDArray* out,
                  runtime::MXNetRetValue* ret, const nnvm::Op* op) {
   using namespace runtime;
   nnvm::NodeAttrs attrs;
+  op::NumpyBinaryScalarParam param;
+  param.scalar = lhs;
+  param.is_int = false;
   attrs.op = op;
-  attrs.parsed = lhs;
-  SetAttrDict<double>(&attrs);
+  attrs.parsed = param;
+  SetAttrDict<op::NumpyBinaryScalarParam>(&attrs);
   NDArray** inputs = &rhs;
   int num_inputs = 1;
   NDArray** outputs = out == nullptr ? nullptr : &out;
@@ -99,9 +150,14 @@ void UFuncHelper(runtime::MXNetArgs args,
   if (args[0].type_code() == kNDArrayHandle) {
     if (args[1].type_code() == kNDArrayHandle) {
       UFuncHelper(args[0].operator NDArray*(), args[1].operator NDArray*(), out, ret, fn_array);
+    } else if (args[1].type_code() == kDLInt) {
+      UFuncHelper(args[0].operator NDArray*(), args[1].operator int(), out, ret, lfn_scalar);
     } else {
       UFuncHelper(args[0].operator NDArray*(), args[1].operator double(), out, ret, lfn_scalar);
     }
+  } else if (args[0].type_code() == kDLInt) {
+    UFuncHelper(args[0].operator int(), args[1].operator NDArray*(), out, ret,
+                rfn_scalar ? rfn_scalar : lfn_scalar);
   } else {
     UFuncHelper(args[0].operator double(), args[1].operator NDArray*(), out, ret,
                 rfn_scalar ? rfn_scalar : lfn_scalar);
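
The new int overloads above are what carry the Python-side scalar type into the operator: a Python int arrives as kDLInt and sets is_int = true, while a Python float takes the double path and sets is_int = false. NumpyBinaryScalarType (added to elemwise_binary_scalar_op.h later in this patch) then applies numpy-style promotion. A minimal sketch of the intended frontend behavior, assuming the mxnet.numpy frontend routes scalars through these helpers (illustrative, not part of the patch):

    from mxnet import np

    a = np.array([1, 2, 3], dtype='int32')
    print((a + 2).dtype)    # int32: int array + int scalar keeps the array dtype
    print((a + 2.5).dtype)  # float64: int array + float scalar promotes to float64
    b = np.array([1.0, 2.0], dtype='float32')
    print((b + 2).dtype)    # float32: a float array keeps its dtype for either scalar kind
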
diff --git a/src/operator/numpy/np_elemwise_broadcast_logic_op.cc b/src/operator/numpy/np_elemwise_broadcast_logic_op.cc
index 0f5d9923a3e2..b44a22d65856 100644
--- a/src/operator/numpy/np_elemwise_broadcast_logic_op.cc
+++ b/src/operator/numpy/np_elemwise_broadcast_logic_op.cc
@@ -242,9 +242,7 @@ struct TVMBinaryBroadcastScalarCompute {
   NNVM_REGISTER_OP(_npi_##name##_scalar) \
   .set_num_inputs(1) \
   .set_num_outputs(1) \
-  .set_attr_parser([](NodeAttrs* attrs) { \
-      attrs->parsed = std::stod(attrs->dict["scalar"]); \
-    }) \
+  .set_attr_parser(ParamParser<NumpyBinaryScalarParam>) \
   .set_attr<nnvm::FListInputNames>("FListInputNames", \
     [](const NodeAttrs& attrs) { \
       return std::vector<std::string>{"data"}; \
@@ -257,7 +255,7 @@ struct TVMBinaryBroadcastScalarCompute {
     }) \
   .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes) \
   .add_argument("data", "NDArray-or-Symbol", "First input to the function") \
-  .add_argument("scalar", "float", "scalar input")
+  .add_arguments(NumpyBinaryScalarParam::__FIELDS__())
 
 MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(equal);
 MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(not_equal);
diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cc b/src/operator/numpy/np_elemwise_broadcast_op.cc
index 95a8ffcb0946..c2db0fee7512 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op.cc
+++ b/src/operator/numpy/np_elemwise_broadcast_op.cc
@@ -28,13 +28,13 @@
 namespace mxnet {
 namespace op {
 
+DMLC_REGISTER_PARAMETER(NumpyBinaryScalarParam);
+
 #define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(name) \
   NNVM_REGISTER_OP(name) \
   .set_num_inputs(1) \
   .set_num_outputs(1) \
-  .set_attr_parser([](NodeAttrs* attrs) { \
-      attrs->parsed = std::stod(attrs->dict["scalar"]); \
-    }) \
+  .set_attr_parser(ParamParser<NumpyBinaryScalarParam>) \
  .set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>) \
  .set_attr<nnvm::FInferType>("FInferType", NumpyBinaryScalarType) \
  .set_attr<nnvm::FInplaceOption>("FInplaceOption", \
@@ -46,7 +46,7 @@
       return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; \
     }) \
   .add_argument("data", "NDArray-or-Symbol", "source input") \
-  .add_argument("scalar", "float", "scalar input")
+  .add_arguments(NumpyBinaryScalarParam::__FIELDS__())
 
 bool NumpyBinaryMixedPrecisionType(const nnvm::NodeAttrs& attrs,
                                    std::vector<int>* in_attrs,
diff --git a/src/operator/numpy/np_elemwise_broadcast_op.h b/src/operator/numpy/np_elemwise_broadcast_op.h
index 979b18245038..8bf32f58b2aa 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op.h
+++ b/src/operator/numpy/np_elemwise_broadcast_op.h
@@ -35,20 +35,6 @@
 namespace mxnet {
 namespace op {
 
-inline bool NumpyBinaryScalarType(const nnvm::NodeAttrs& attrs,
-                                  std::vector<int>* in_attrs,
-                                  std::vector<int>* out_attrs) {
-  CHECK_EQ(in_attrs->size(), 1U);
-  CHECK_EQ(out_attrs->size(), 1U);
-  if (common::is_int(in_attrs->at(0))) {
-    TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kFloat32);
-  } else {
-    TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
-    TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
-  }
-  return out_attrs->at(0) != -1;
-}
-
 inline void PrintErrorMessage(const std::string& op_name, const int dtype1, const int dtype2) {
   LOG(FATAL) << "Operator " << op_name << " does not support combination of "
              << mshadow::dtype_string(dtype1) << " with " << mshadow::dtype_string(dtype2)
diff --git a/src/operator/numpy/np_elemwise_broadcast_op_extended.cc b/src/operator/numpy/np_elemwise_broadcast_op_extended.cc
index c3f87532c1b5..5ba2d5431bcc 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op_extended.cc
+++ b/src/operator/numpy/np_elemwise_broadcast_op_extended.cc
@@ -33,9 +33,7 @@ namespace op {
   NNVM_REGISTER_OP(name) \
   .set_num_inputs(1) \
   .set_num_outputs(1) \
-  .set_attr_parser([](NodeAttrs* attrs) { \
-      attrs->parsed = std::stod(attrs->dict["scalar"]); \
-    }) \
+  .set_attr_parser(ParamParser<NumpyBinaryScalarParam>) \
   .set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>) \
   .set_attr<nnvm::FInferType>("FInferType", NumpyBinaryScalarType) \
   .set_attr<nnvm::FInplaceOption>("FInplaceOption", \
@@ -47,7 +45,7 @@ namespace op {
       return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; \
     }) \
   .add_argument("data", "NDArray-or-Symbol", "source input") \
-  .add_argument("scalar", "float", "scalar input")
+  .add_arguments(NumpyBinaryScalarParam::__FIELDS__())
 
 MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_copysign)
 .describe(R"code()code" ADD_FILELINE)
@@ -90,9 +88,7 @@ NNVM_REGISTER_OP(_npi_lcm)
 NNVM_REGISTER_OP(_npi_lcm_scalar)
 .set_num_inputs(1)
 .set_num_outputs(1)
-.set_attr_parser([](NodeAttrs* attrs) {
-    attrs->parsed = std::stod(attrs->dict["scalar"]);
-  })
+.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
 .set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseIntType<1, 1>)
 .set_attr<nnvm::FInplaceOption>("FInplaceOption",
@@ -105,7 +101,7 @@ NNVM_REGISTER_OP(_npi_lcm_scalar)
   })
 .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
 .add_argument("data", "NDArray-or-Symbol", "source input")
-.add_argument("scalar", "int", "scalar input")
+.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
 .set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::lcm>);
 
 NNVM_REGISTER_OP(_npi_bitwise_and)
@@ -129,9 +125,7 @@ NNVM_REGISTER_OP(_npi_bitwise_and)
 NNVM_REGISTER_OP(_npi_bitwise_and_scalar)
 .set_num_inputs(1)
 .set_num_outputs(1)
-.set_attr_parser([](NodeAttrs* attrs) {
-    attrs->parsed = std::stod(attrs->dict["scalar"]);
-  })
+.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
 .set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseIntType<1, 1>)
 .set_attr<nnvm::FInplaceOption>("FInplaceOption",
@@ -140,7 +134,7 @@ NNVM_REGISTER_OP(_npi_bitwise_and_scalar)
   })
 .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
 .add_argument("data", "NDArray-or-Symbol", "source input")
-.add_argument("scalar", "int", "scalar input")
+.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
 .set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::ComputeInt<cpu, mshadow_op::bitwise_and>);
 
 NNVM_REGISTER_OP(_npi_bitwise_xor)
@@ -182,9 +176,7 @@ NNVM_REGISTER_OP(_npi_bitwise_or)
 NNVM_REGISTER_OP(_npi_bitwise_xor_scalar)
 .set_num_inputs(1)
 .set_num_outputs(1)
-.set_attr_parser([](NodeAttrs* attrs) {
-    attrs->parsed = std::stod(attrs->dict["scalar"]);
-  })
+.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
 .set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseIntType<1, 1>)
 .set_attr<nnvm::FInplaceOption>("FInplaceOption",
@@ -193,15 +185,13 @@ NNVM_REGISTER_OP(_npi_bitwise_xor_scalar)
   })
 .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
 .add_argument("data", "NDArray-or-Symbol", "source input")
-.add_argument("scalar", "int", "scalar input")
+.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
 .set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::ComputeInt<cpu, mshadow_op::bitwise_xor>);
 
 NNVM_REGISTER_OP(_npi_bitwise_or_scalar)
 .set_num_inputs(1)
 .set_num_outputs(1)
-.set_attr_parser([](NodeAttrs* attrs) {
-    attrs->parsed = std::stod(attrs->dict["scalar"]);
-  })
+.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
 .set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseIntType<1, 1>)
 .set_attr<nnvm::FInplaceOption>("FInplaceOption",
@@ -210,7 +200,7 @@ NNVM_REGISTER_OP(_npi_bitwise_or_scalar)
   })
 .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
 .add_argument("data", "NDArray-or-Symbol", "source input")
-.add_argument("scalar", "int", "scalar input")
+.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
 .set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::ComputeInt<cpu, mshadow_op::bitwise_or>);
 
 MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_copysign_scalar)
@@ -282,14 +272,14 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rarctan2_scalar)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_rarctan2_scalar"});
 
 MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_arctan2_scalar)
-.add_argument("scalar", "float", "scalar value")
-.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); })
+.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
+.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
 .set_attr<FCompute>("FCompute<cpu>",
                     BinaryScalarOp::Backward<cpu, mshadow_op::arctan2_grad>);
 
 MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_rarctan2_scalar)
-.add_argument("scalar", "float", "scalar value")
-.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); })
+.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
+.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
 .set_attr<FCompute>("FCompute<cpu>",
                     BinaryScalarOp::Backward<cpu, mshadow_op::rarctan2_grad>);
 
@@ -370,13 +360,13 @@ NNVM_REGISTER_OP(_backward_npi_ldexp)
                                                               mshadow_op::ldexp_rgrad>);
 
 MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_ldexp_scalar)
-.add_argument("scalar", "float", "scalar value")
-.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); })
+.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
+.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
 .set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Backward<cpu, mshadow_op::ldexp_grad>);
 
 MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_rldexp_scalar)
-.add_argument("scalar", "float", "scalar value")
-.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); })
+.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
+.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
 .set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Backward<cpu, mshadow_op::rldexp_grad>);
 
 }  // namespace op
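
The integer-only scalar ops in this file (_npi_lcm_scalar and the bitwise family) keep ElemwiseIntType for type inference, so only the parser changes: the scalar now arrives as a NumpyBinaryScalarParam instead of a bare double, and the output dtype stays integral. An illustrative check of the expected behavior from the mxnet.numpy frontend (a sketch, not part of the patch):

    from mxnet import np

    a = np.array([4, 6, 8], dtype='int32')
    print(np.lcm(a, 3).dtype)          # int32: ElemwiseIntType preserves the integer dtype
    print(np.bitwise_and(a, 3).dtype)  # int32
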
diff --git a/src/operator/numpy/np_elemwise_broadcast_op_extended_sec.cc b/src/operator/numpy/np_elemwise_broadcast_op_extended_sec.cc
index 05806baf95c0..3ae3ab93e0bb 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op_extended_sec.cc
+++ b/src/operator/numpy/np_elemwise_broadcast_op_extended_sec.cc
@@ -33,9 +33,7 @@ namespace op {
   NNVM_REGISTER_OP(name) \
   .set_num_inputs(1) \
   .set_num_outputs(1) \
-  .set_attr_parser([](NodeAttrs* attrs) { \
-      attrs->parsed = std::stod(attrs->dict["scalar"]); \
-    }) \
+  .set_attr_parser(ParamParser<NumpyBinaryScalarParam>) \
   .set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>) \
   .set_attr<nnvm::FInferType>("FInferType", NumpyBinaryScalarType) \
   .set_attr<nnvm::FInplaceOption>("FInplaceOption", \
@@ -47,7 +45,7 @@ namespace op {
       return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; \
     }) \
   .add_argument("data", "NDArray-or-Symbol", "source input") \
-  .add_argument("scalar", "float", "scalar input")
+  .add_arguments(NumpyBinaryScalarParam::__FIELDS__())
 
 MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_fmax)
 .set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::fmax>)
@@ -73,8 +71,8 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_fmax_scalar)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_fmax_scalar"});
 
 MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_fmax_scalar)
-.add_argument("scalar", "float", "scalar value")
-.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); })
+.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
+.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
 .set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Backward<cpu, mshadow_op::fmax_grad>);
 
 MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_fmin)
@@ -101,8 +99,8 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_fmin_scalar)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_fmin_scalar"});
 
 MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_fmin_scalar)
-.add_argument("scalar", "float", "scalar value")
-.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); })
+.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
+.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
 .set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Backward<cpu, mshadow_op::fmin_grad>);
 
 MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_fmod)
@@ -129,8 +127,8 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_fmod_scalar)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_fmod_scalar"});
 
 MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_fmod_scalar)
-.add_argument("scalar", "float", "scalar value")
-.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); })
+.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
+.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
 .set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Backward<cpu, mshadow_op::fmod_grad>);
 
 MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rfmod_scalar)
@@ -138,8 +136,8 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rfmod_scalar)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_rfmod_scalar"});
 
 MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_rfmod_scalar)
-.add_argument("scalar", "float", "scalar value")
-.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); })
+.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
+.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
 .set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Backward<cpu, mshadow_op::rfmod_grad>);
 
 }  // namespace op
diff --git a/src/operator/tensor/elemwise_binary_scalar_op.cuh b/src/operator/tensor/elemwise_binary_scalar_op.cuh
index 062c18767ac6..2795a38f5f70 100644
--- a/src/operator/tensor/elemwise_binary_scalar_op.cuh
+++ b/src/operator/tensor/elemwise_binary_scalar_op.cuh
@@ -155,7 +155,8 @@ void BinaryScalarOp::Compute_(const nnvm::NodeAttrs &attrs,
   if (req[0] == kNullOp) return;
   CHECK_EQ(inputs.size(), 1U);
   CHECK_EQ(outputs.size(), 1U);
-  const double alpha = nnvm::get<double>(attrs.parsed);
+  const NumpyBinaryScalarParam& param = nnvm::get<NumpyBinaryScalarParam>(attrs.parsed);
+  const double alpha = param.scalar;
   MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
     MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
       using LType = uint4;
@@ -182,7 +183,8 @@ void BinaryScalarOp::Backward_(const nnvm::NodeAttrs &attrs,
   if (req[0] == kNullOp) return;
   CHECK_EQ(inputs.size(), 2U);
   CHECK_EQ(outputs.size(), 1U);
-  const double alpha = nnvm::get<double>(attrs.parsed);
+  const NumpyBinaryScalarParam& param = nnvm::get<NumpyBinaryScalarParam>(attrs.parsed);
+  const double alpha = param.scalar;
   MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
     MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
       using LType = uint4;
diff --git a/src/operator/tensor/elemwise_binary_scalar_op.h b/src/operator/tensor/elemwise_binary_scalar_op.h
index 4419ef17864b..2af68873b530 100644
--- a/src/operator/tensor/elemwise_binary_scalar_op.h
+++ b/src/operator/tensor/elemwise_binary_scalar_op.h
@@ -35,6 +35,46 @@
 namespace mxnet {
 namespace op {
 
+struct NumpyBinaryScalarParam : public dmlc::Parameter<NumpyBinaryScalarParam> {
+  double scalar;
+  bool is_int;
+  DMLC_DECLARE_PARAMETER(NumpyBinaryScalarParam) {
+    DMLC_DECLARE_FIELD(scalar)
+    .set_default(1)
+    .describe("Scalar input value");
+    DMLC_DECLARE_FIELD(is_int)
+    .set_default(false)
+    .describe("Indicate whether scalar input is int type");
+  }
+
+  void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
+    std::ostringstream scalar_s, is_int_s;
+    scalar_s << scalar;
+    is_int_s << is_int;
+    (*dict)["scalar"] = scalar_s.str();
+    (*dict)["is_int"] = is_int_s.str();
+  }
+};
+
+inline bool NumpyBinaryScalarType(const nnvm::NodeAttrs& attrs,
+                                  std::vector<int>* in_attrs,
+                                  std::vector<int>* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  const NumpyBinaryScalarParam& param = nnvm::get<NumpyBinaryScalarParam>(attrs.parsed);
+  bool scalar_is_int = param.is_int;
+  if (common::is_int(in_attrs->at(0)) && !scalar_is_int) {
+    TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kFloat64);
+  } else if (in_attrs->at(0) == mshadow::kBool) {
+    TYPE_ASSIGN_CHECK(*out_attrs, 0, scalar_is_int ? mshadow::kInt64 : mshadow::kFloat64);
+  }
+  else {
+    TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
+    TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
+  }
+  return out_attrs->at(0) != -1;
+}
+
 class BinaryScalarOp : public UnaryOp {
   /*! \brief Tensor operation against a scalar with a dense result */
   template<typename OP, typename DType, typename IType>
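
NumpyBinaryScalarType is the heart of the change: an int array with a float scalar now promotes to float64 (the old rule forced float32), a bool array promotes to int64 or float64 depending on the scalar kind, and everything else keeps the input dtype. A quick illustration of the bool branch, the case the old rule could not express, sketched against the mxnet.numpy frontend (illustrative only):

    from mxnet import np

    x = np.array([True, False])
    print((x + 1).dtype)    # int64: bool array + int scalar
    print((x + 1.5).dtype)  # float64: bool array + float scalar
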
@@ -44,7 +84,8 @@ class BinaryScalarOp : public UnaryOp {
                              const NDArray &input,
                              const OpReqType req,
                              const NDArray &output) {
-    const double alpha = nnvm::get<double>(attrs.parsed);
+    const NumpyBinaryScalarParam& param = nnvm::get<NumpyBinaryScalarParam>(attrs.parsed);
+    const double alpha = param.scalar;
     CHECK_EQ(output.shape(), input.shape());
     const int64_t row_count = output.shape()[0];
     const int64_t items_per_row = output.shape().Size() / row_count;
@@ -136,7 +177,8 @@ class BinaryScalarOp : public UnaryOp {
                               const NDArray &output) {
     CHECK_EQ(output.shape(), input.shape());
 
-    const double alpha = nnvm::get<double>(attrs.parsed);
+    const NumpyBinaryScalarParam& param = nnvm::get<NumpyBinaryScalarParam>(attrs.parsed);
+    const double alpha = param.scalar;
     const DType dense_fill_val = OP::Map(DType(0), DType(alpha));
     const TBlob column_indexes = input.aux_data(csr::kIdx);
     const size_t item_count = column_indexes.Size();
@@ -234,7 +276,8 @@ class BinaryScalarOp : public UnaryOp {
     DCHECK_EQ(outputs.size(), 1);
     using namespace mshadow;
     using namespace mshadow::expr;
-    const double alpha = nnvm::get<double>(attrs.parsed);
+    const NumpyBinaryScalarParam& param = nnvm::get<NumpyBinaryScalarParam>(attrs.parsed);
+    const double alpha = param.scalar;
     MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
       MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
         mxnet_op::Kernel<mxnet_op::op_with_req<OP, Req>, cpu>::Launch(
@@ -262,8 +305,11 @@ class BinaryScalarOp : public UnaryOp {
     using namespace mshadow;
     using namespace mshadow::expr;
     TBlob temp_tblob;
+    const NumpyBinaryScalarParam& param = nnvm::get<NumpyBinaryScalarParam>(attrs.parsed);
+    bool scalar_is_int = param.is_int;
     MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
-      if (common::is_int(inputs[0].type_flag_)) {
+      if ((common::is_int(inputs[0].type_flag_) && !scalar_is_int) ||
+          (inputs[0].type_flag_ == kBool)) {
         Tensor<xpu, 1, DType> temp_tensor =
           ctx.requested[0].get_space_typed<xpu, 1, DType>(Shape1(inputs[0].Size()), s);
         temp_tblob = TBlob(temp_tensor);
@@ -287,7 +333,8 @@ class BinaryScalarOp : public UnaryOp {
     using namespace mshadow;
     using namespace mshadow::expr;
     Stream<xpu> *s = ctx.get_stream<xpu>();
-    const double alpha = nnvm::get<double>(attrs.parsed);
+    const NumpyBinaryScalarParam& param = nnvm::get<NumpyBinaryScalarParam>(attrs.parsed);
+    const double alpha = param.scalar;
     MXNET_INT_TYPE_SWITCH(outputs[0].type_flag_, DType, {
       MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
         mxnet_op::Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(
@@ -307,11 +354,13 @@ class BinaryScalarOp : public UnaryOp {
     using namespace mshadow;
     using namespace mshadow::expr;
     Stream<xpu> *s = ctx.get_stream<xpu>();
-    const double alpha = nnvm::get<double>(attrs.parsed);
+    const NumpyBinaryScalarParam& param = nnvm::get<NumpyBinaryScalarParam>(attrs.parsed);
+    bool scalar_is_int = param.is_int;
+    const double alpha = param.scalar;
     TBlob temp_tblob;
-    if (common::is_int(inputs[0].type_flag_)) {
-      Tensor<xpu, 1, float> temp_tensor =
-        ctx.requested[0].get_space_typed<xpu, 1, float>(Shape1(inputs[0].Size()), s);
+    if (common::is_int(inputs[0].type_flag_) && !scalar_is_int) {
+      Tensor<xpu, 1, double> temp_tensor =
+        ctx.requested[0].get_space_typed<xpu, 1, double>(Shape1(inputs[0].Size()), s);
       temp_tblob = TBlob(temp_tensor);
       CastCompute<xpu>(attrs, ctx, {inputs[0]}, {kWriteTo}, {temp_tblob});
     } else {
@@ -384,7 +433,8 @@ class BinaryScalarOp : public UnaryOp {
                        const std::vector<TBlob> &outputs) {
     using namespace mshadow;
     using namespace mshadow::expr;
-    const double alpha = nnvm::get<double>(attrs.parsed);
+    const NumpyBinaryScalarParam& param = nnvm::get<NumpyBinaryScalarParam>(attrs.parsed);
+    const double alpha = param.scalar;
     MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
       MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
         mxnet::op::mxnet_op::Kernel<mxnet_op::op_with_req<mxnet_op::backward_grad_tuned<OP>, Req>, cpu>::Launch(
@@ -422,11 +472,9 @@ class BinaryScalarOp : public UnaryOp {
   NNVM_REGISTER_OP(name) \
   .set_num_inputs(1) \
   .set_num_outputs(1) \
-  .set_attr_parser([](NodeAttrs* attrs) { \
-      attrs->parsed = std::stod(attrs->dict["scalar"]); \
-    }) \
+  .set_attr_parser(ParamParser<NumpyBinaryScalarParam>) \
   .set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>) \
-  .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>) \
+  .set_attr<nnvm::FInferType>("FInferType", NumpyBinaryScalarType) \
   .set_attr<nnvm::FInplaceOption>("FInplaceOption", \
     [](const NodeAttrs& attrs){ \
       return std::vector<std::pair<int, int> >{{0, 0}}; \
@@ -436,7 +484,7 @@ class BinaryScalarOp : public UnaryOp {
       return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; \
     }) \
   .add_argument("data", "NDArray-or-Symbol", "source input") \
-  .add_argument("scalar", "float", "scalar input")
+  .add_arguments(NumpyBinaryScalarParam::__FIELDS__())
 
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/tensor/elemwise_binary_scalar_op_basic.cc b/src/operator/tensor/elemwise_binary_scalar_op_basic.cc
index 00d2b9cdd0e8..cca71550b7bf 100644
--- a/src/operator/tensor/elemwise_binary_scalar_op_basic.cc
+++ b/src/operator/tensor/elemwise_binary_scalar_op_basic.cc
@@ -30,11 +30,9 @@
   NNVM_REGISTER_OP(name) \
   .set_num_inputs(1) \
   .set_num_outputs(1) \
-  .set_attr_parser([](NodeAttrs* attrs) { \
-      attrs->parsed = std::stod(attrs->dict["scalar"]); \
-    }) \
+  .set_attr_parser(ParamParser<NumpyBinaryScalarParam>) \
   .set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>) \
-  .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>) \
+  .set_attr<nnvm::FInferType>("FInferType", NumpyBinaryScalarType) \
   .set_attr<FInferStorageType>("FInferStorageType", \
                                BinaryScalarStorageTypeWithDenseResultStorageType) \
   .set_attr<nnvm::FInplaceOption>("FInplaceOption", \
@@ -46,7 +44,7 @@
       return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; \
     }) \
   .add_argument("data", "NDArray-or-Symbol", "source input") \
-  .add_argument("scalar", "float", "scalar input")
+  .add_arguments(NumpyBinaryScalarParam::__FIELDS__())
 
 namespace mxnet {
 namespace op {
@@ -192,8 +190,8 @@ MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_rdiv_scalar)
 .add_alias("_RDivScalar");
 
 MXNET_OPERATOR_REGISTER_BINARY(_backward_rdiv_scalar)
-.add_argument("scalar", "float", "scalar value")
-.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); })
+.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
+.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
 .set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Backward<
   cpu, mshadow_op::rdiv_grad>);
 
@@ -203,8 +201,8 @@ MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_mod_scalar)
 .add_alias("_ModScalar");
 
 MXNET_OPERATOR_REGISTER_BINARY(_backward_mod_scalar)
-.add_argument("scalar", "float", "scalar value")
-.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); })
+.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
+.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
 .set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Backward<
   cpu, mshadow_op::mod_grad>);
 
@@ -214,8 +212,8 @@ MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_rmod_scalar)
 .add_alias("_RModScalar");
 
 MXNET_OPERATOR_REGISTER_BINARY(_backward_rmod_scalar)
-.add_argument("scalar", "float", "scalar value")
-.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); })
+.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
+.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
 .set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Backward<
   cpu, mshadow_op::rmod_grad>);
 
diff --git a/src/operator/tensor/elemwise_binary_scalar_op_extended.cc b/src/operator/tensor/elemwise_binary_scalar_op_extended.cc
index 31cfe59ca9ed..3949d20552c1 100644
--- a/src/operator/tensor/elemwise_binary_scalar_op_extended.cc
+++ b/src/operator/tensor/elemwise_binary_scalar_op_extended.cc
@@ -35,8 +35,8 @@ MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_maximum_scalar)
 .add_alias("_npi_maximum_scalar");
 
 MXNET_OPERATOR_REGISTER_BINARY(_backward_maximum_scalar)
-.add_argument("scalar", "float", "scalar value")
-.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); })
+.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
+.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
 .set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Backward<cpu, mshadow_op::ge>);
 
 MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_minimum_scalar)
@@ -46,8 +46,8 @@ MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_minimum_scalar)
 .add_alias("_npi_minimum_scalar");
 
 MXNET_OPERATOR_REGISTER_BINARY(_backward_minimum_scalar)
-.add_argument("scalar", "float", "scalar value")
-.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); })
+.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
+.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
 .set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Backward<cpu, mshadow_op::le>);
 
 MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_power_scalar)
@@ -56,8 +56,8 @@ MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_power_scalar)
 .add_alias("_PowerScalar");
 
 MXNET_OPERATOR_REGISTER_BINARY(_backward_power_scalar)
-.add_argument("scalar", "float", "scalar value")
-.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); })
+.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
+.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
 .set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Backward<
   cpu, mshadow_op::power_grad>);
 
@@ -68,8 +68,8 @@ MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_rpower_scalar)
 .add_alias("_RPowerScalar");
 
 MXNET_OPERATOR_REGISTER_BINARY(_backward_rpower_scalar)
-.add_argument("scalar", "float", "scalar value")
-.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); })
+.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
+.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
 .set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Backward<
   cpu, mshadow_op::rpower_grad>);
 
@@ -81,8 +81,8 @@ MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_hypot_scalar)
 .add_alias("_HypotScalar");
 
 MXNET_OPERATOR_REGISTER_BINARY(_backward_hypot_scalar)
-.add_argument("scalar", "float", "scalar value")
-.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); })
+.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
+.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
 .set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Backward<
   cpu, mshadow_op::hypot_grad_left>);
 
@@ -108,13 +108,7 @@ Example::
 )code" ADD_FILELINE)
 .set_num_inputs(1)
 .set_num_outputs(1)
-.set_attr_parser([](NodeAttrs* attrs) {
-    if (attrs->dict.find("scalar") != attrs->dict.end()) {
-      attrs->parsed = std::stod(attrs->dict["scalar"]);
-    } else {
-      attrs->parsed = 1.0;
-    }
-  })
+.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
 .set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
 .set_attr<nnvm::FInplaceOption>("FInplaceOption",
@@ -131,13 +125,7 @@ Example::
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{ "_backward_smooth_l1" });
 
 MXNET_OPERATOR_REGISTER_BINARY(_backward_smooth_l1)
-  .set_attr_parser([](NodeAttrs *attrs) {
-    if (attrs->dict.find("scalar") != attrs->dict.end()) {
-      attrs->parsed = std::stod(attrs->dict["scalar"]);
-    } else {
-      attrs->parsed = 1.0;
-    }
-})
+.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
 .set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Backward<cpu, mshadow_op::smooth_l1_gradient>);
 
 }  // namespace op
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 09404b5f150a..cadeaa539a09 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -2665,7 +2665,7 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
 @use_np
 def test_np_binary_scalar_funcs():
     itypes = [np.int8, np.int32, np.int64]
-    def check_binary_scalar_func(func, low, high, lshape, lgrad, ltype):
+    def check_binary_scalar_func(func, low, high, lshape, lgrad, ltype, scalar_is_int):
         class TestBinaryScalar(HybridBlock):
             def __init__(self, func, scalar):
                 super(TestBinaryScalar, self).__init__()
@@ -2676,7 +2676,7 @@ def hybrid_forward(self, F, a, *args, **kwargs):
                 return getattr(F.np, self._func)(a, self._scalar)
 
         np_test_x1 = _np.random.uniform(low, high, lshape).astype(ltype)
-        np_test_x2 = _np.random.uniform(low, high)
+        np_test_x2 = int(_np.random.uniform(low, high)) if scalar_is_int else _np.random.uniform(low, high)
         mx_test_x1 = np.array(np_test_x1, dtype=ltype)
         mx_test_x2 = np_test_x2
         np_func = getattr(_np, func)
@@ -2701,6 +2701,7 @@ def hybrid_forward(self, F, a, *args, **kwargs):
             np_out = getattr(_np, func)(np_test_x1, np_test_x2)
             mx_out = getattr(mx.np, func)(mx_test_x1, mx_test_x2)
             assert mx_out.shape == np_out.shape
+            assert mx_out.asnumpy().dtype == np_out.dtype
             assert_almost_equal(mx_out.asnumpy(), np_out.astype(mx_out.dtype), rtol=rtol, atol=atol)
 
     funcs = {
@@ -2716,7 +2717,8 @@ def hybrid_forward(self, F, a, *args, **kwargs):
         low, high, lgrad = func_data
         for shape in shapes:
            for ltype in ltypes:
-                check_binary_scalar_func(func, low, high, shape, lgrad, ltype)
+                for is_int in [True, False]:
+                    check_binary_scalar_func(func, low, high, shape, lgrad, ltype, is_int)
 
 
 @with_seed()
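
The added dtype assertion pins the frontend to numpy's promotion for both scalar kinds, and the doubled loop exercises every tested function with one int scalar and one float scalar. A condensed, standalone version of what the test now verifies, using maximum as a stand-in for any tested function (illustrative sketch, not part of the patch):

    import numpy as _np
    from mxnet import np as mx_np

    x = _np.random.uniform(1, 5, (3, 4)).astype(_np.int32)
    for scalar in (3, 3.5):  # one int scalar, one float scalar
        expected = _np.maximum(x, scalar)
        got = mx_np.maximum(mx_np.array(x, dtype='int32'), scalar)
        # With the new NumpyBinaryScalarType rules, the result dtype should
        # match numpy's: int32 for the int scalar, float64 for the float scalar.
        assert got.asnumpy().dtype == expected.dtype
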