Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
binary scalar op support scalar dtype
Browse files Browse the repository at this point in the history
Co-Authored-By: Wentao Xu <[email protected]>
  • Loading branch information
Tommliu and cassinixu committed May 19, 2020
1 parent 8eee524 commit d970243
Show file tree
Hide file tree
Showing 11 changed files with 323 additions and 167 deletions.
8 changes: 5 additions & 3 deletions python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -1571,13 +1571,15 @@ def _ufunc_helper(lhs, rhs, fn_array, fn_scalar, lfn_scalar, rfn_scalar=None, ou
if isinstance(rhs, numeric_types):
return fn_scalar(lhs, rhs, out=out)
else:
is_int = isinstance(rhs, integer_types)
if rfn_scalar is None:
# commutative function
return lfn_scalar(rhs, float(lhs), out=out)
return lfn_scalar(rhs, float(lhs), is_int=is_int, out=out)
else:
return rfn_scalar(rhs, float(lhs), out=out)
return rfn_scalar(rhs, float(lhs), is_int=is_int, out=out)
elif isinstance(rhs, numeric_types):
return lfn_scalar(lhs, float(rhs), out=out)
is_int = isinstance(rhs, integer_types)
return lfn_scalar(lhs, float(rhs), is_int=is_int, out=out)
elif isinstance(rhs, Symbol):
return fn_array(lhs, rhs, out=out)
else:
Expand Down
64 changes: 60 additions & 4 deletions src/api/operator/ufunc_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "ufunc_helper.h"
#include "utils.h"
#include "../../imperative/imperative_utils.h"
#include "../../operator/tensor/elemwise_binary_scalar_op.h"

namespace mxnet {

Expand Down Expand Up @@ -51,13 +52,38 @@ void UFuncHelper(NDArray* lhs, NDArray* rhs, NDArray* out,
}
}

void UFuncHelper(NDArray* lhs, int rhs, NDArray* out,
runtime::MXNetRetValue* ret, const nnvm::Op* op) {
using namespace runtime;
nnvm::NodeAttrs attrs;
op::NumpyBinaryScalarParam param;
param.scalar = rhs;
param.is_int = true;
attrs.op = op;
attrs.parsed = param;
SetAttrDict<op::NumpyBinaryScalarParam>(&attrs);
NDArray** inputs = &lhs;
int num_inputs = 1;
NDArray** outputs = out == nullptr ? nullptr : &out;
int num_outputs = out != nullptr;
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs);
if (outputs) {
*ret = PythonArg(2);
} else {
*ret = reinterpret_cast<NDArray*>(ndoutputs[0]);
}
}

void UFuncHelper(NDArray* lhs, double rhs, NDArray* out,
runtime::MXNetRetValue* ret, const nnvm::Op* op) {
using namespace runtime;
nnvm::NodeAttrs attrs;
op::NumpyBinaryScalarParam param;
param.scalar = rhs;
param.is_int = false;
attrs.op = op;
attrs.parsed = rhs;
SetAttrDict<double>(&attrs);
attrs.parsed = param;
SetAttrDict<op::NumpyBinaryScalarParam>(&attrs);
NDArray** inputs = &lhs;
int num_inputs = 1;
NDArray** outputs = out == nullptr ? nullptr : &out;
Expand All @@ -70,13 +96,38 @@ void UFuncHelper(NDArray* lhs, double rhs, NDArray* out,
}
}

void UFuncHelper(int lhs, NDArray* rhs, NDArray* out,
runtime::MXNetRetValue* ret, const nnvm::Op* op) {
using namespace runtime;
nnvm::NodeAttrs attrs;
op::NumpyBinaryScalarParam param;
param.scalar = lhs;
param.is_int = true;
attrs.op = op;
attrs.parsed = param;
SetAttrDict<op::NumpyBinaryScalarParam>(&attrs);
NDArray** inputs = &rhs;
int num_inputs = 1;
NDArray** outputs = out == nullptr ? nullptr : &out;
int num_outputs = out != nullptr;
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs);
if (outputs) {
*ret = PythonArg(2);
} else {
*ret = reinterpret_cast<NDArray*>(ndoutputs[0]);
}
}

void UFuncHelper(double lhs, NDArray* rhs, NDArray* out,
runtime::MXNetRetValue* ret, const nnvm::Op* op) {
using namespace runtime;
nnvm::NodeAttrs attrs;
op::NumpyBinaryScalarParam param;
param.scalar = lhs;
param.is_int = false;
attrs.op = op;
attrs.parsed = lhs;
SetAttrDict<double>(&attrs);
attrs.parsed = param;
SetAttrDict<op::NumpyBinaryScalarParam>(&attrs);
NDArray** inputs = &rhs;
int num_inputs = 1;
NDArray** outputs = out == nullptr ? nullptr : &out;
Expand All @@ -99,9 +150,14 @@ void UFuncHelper(runtime::MXNetArgs args,
if (args[0].type_code() == kNDArrayHandle) {
if (args[1].type_code() == kNDArrayHandle) {
UFuncHelper(args[0].operator NDArray*(), args[1].operator NDArray*(), out, ret, fn_array);
} else if (args[1].type_code() == kDLInt){
UFuncHelper(args[0].operator NDArray*(), args[1].operator int(), out, ret, lfn_scalar);
} else {
UFuncHelper(args[0].operator NDArray*(), args[1].operator double(), out, ret, lfn_scalar);
}
} else if (args[0].type_code() == kDLInt) {
UFuncHelper(args[0].operator int(), args[1].operator NDArray*(), out, ret,
rfn_scalar ? rfn_scalar : lfn_scalar);
} else {
UFuncHelper(args[0].operator double(), args[1].operator NDArray*(), out, ret,
rfn_scalar ? rfn_scalar : lfn_scalar);
Expand Down
6 changes: 2 additions & 4 deletions src/operator/numpy/np_elemwise_broadcast_logic_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -242,9 +242,7 @@ struct TVMBinaryBroadcastScalarCompute {
NNVM_REGISTER_OP(_npi_##name##_scalar) \
.set_num_inputs(1) \
.set_num_outputs(1) \
.set_attr_parser([](NodeAttrs* attrs) { \
attrs->parsed = std::stod(attrs->dict["scalar"]); \
}) \
.set_attr_parser(ParamParser<NumpyBinaryScalarParam>) \
.set_attr<nnvm::FListInputNames>("FListInputNames", \
[](const NodeAttrs& attrs) { \
return std::vector<std::string>{"data"}; \
Expand All @@ -257,7 +255,7 @@ struct TVMBinaryBroadcastScalarCompute {
}) \
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes) \
.add_argument("data", "NDArray-or-Symbol", "First input to the function") \
.add_argument("scalar", "float", "scalar input")
.add_arguments(NumpyBinaryScalarParam::__FIELDS__())

MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(equal);
MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(not_equal);
Expand Down
34 changes: 19 additions & 15 deletions src/operator/numpy/np_elemwise_broadcast_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,25 @@
namespace mxnet {
namespace op {

#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(name) \
NNVM_REGISTER_OP(name) \
.set_num_inputs(1) \
.set_num_outputs(1) \
.set_attr_parser([](NodeAttrs* attrs) { \
attrs->parsed = std::stod(attrs->dict["scalar"]); \
}) \
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>) \
.set_attr<nnvm::FInferType>("FInferType", NumpyBinaryScalarType) \
.set_attr<nnvm::FInplaceOption>("FInplaceOption", \
[](const NodeAttrs& attrs){ \
return std::vector<std::pair<int, int> >{{0, 0}}; \
}) \
.add_argument("data", "NDArray-or-Symbol", "source input") \
.add_argument("scalar", "float", "scalar input")
DMLC_REGISTER_PARAMETER(NumpyBinaryScalarParam);

#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(name) \
NNVM_REGISTER_OP(name) \
.set_num_inputs(1) \
.set_num_outputs(1) \
.set_attr_parser(ParamParser<NumpyBinaryScalarParam>) \
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>) \
.set_attr<nnvm::FInferType>("FInferType", NumpyBinaryScalarType) \
.set_attr<nnvm::FInplaceOption>("FInplaceOption", \
[](const NodeAttrs& attrs){ \
return std::vector<std::pair<int, int> >{{0, 0}}; \
}) \
.set_attr<FResourceRequest>("FResourceRequest", \
[](const NodeAttrs& attrs) { \
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; \
}) \
.add_argument("data", "NDArray-or-Symbol", "source input") \
.add_arguments(NumpyBinaryScalarParam::__FIELDS__())

bool NumpyBinaryMixedPrecisionType(const nnvm::NodeAttrs& attrs,
std::vector<int>* in_attrs,
Expand Down
10 changes: 0 additions & 10 deletions src/operator/numpy/np_elemwise_broadcast_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,6 @@
namespace mxnet {
namespace op {

inline bool NumpyBinaryScalarType(const nnvm::NodeAttrs& attrs,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
return in_attrs->at(0) != -1;
}

inline void PrintErrorMessage(const std::string& op_name, const int dtype1, const int dtype2) {
LOG(FATAL) << "Operator " << op_name << " does not support combination of "
<< mshadow::dtype_string(dtype1) << " with " << mshadow::dtype_string(dtype2)
Expand Down
72 changes: 33 additions & 39 deletions src/operator/numpy/np_elemwise_broadcast_op_extended.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,23 @@
namespace mxnet {
namespace op {

#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(name) \
NNVM_REGISTER_OP(name) \
.set_num_inputs(1) \
.set_num_outputs(1) \
.set_attr_parser([](NodeAttrs* attrs) { \
attrs->parsed = std::stod(attrs->dict["scalar"]); \
}) \
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>) \
.set_attr<nnvm::FInferType>("FInferType", NumpyBinaryScalarType) \
.set_attr<nnvm::FInplaceOption>("FInplaceOption", \
[](const NodeAttrs& attrs){ \
return std::vector<std::pair<int, int> >{{0, 0}}; \
}) \
.add_argument("data", "NDArray-or-Symbol", "source input") \
.add_argument("scalar", "float", "scalar input")
#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(name) \
NNVM_REGISTER_OP(name) \
.set_num_inputs(1) \
.set_num_outputs(1) \
.set_attr_parser(ParamParser<NumpyBinaryScalarParam>) \
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>) \
.set_attr<nnvm::FInferType>("FInferType", NumpyBinaryScalarType) \
.set_attr<nnvm::FInplaceOption>("FInplaceOption", \
[](const NodeAttrs& attrs){ \
return std::vector<std::pair<int, int> >{{0, 0}}; \
}) \
.set_attr<FResourceRequest>("FResourceRequest", \
[](const NodeAttrs& attrs) { \
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; \
}) \
.add_argument("data", "NDArray-or-Symbol", "source input") \
.add_arguments(NumpyBinaryScalarParam::__FIELDS__())

MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_copysign)
.describe(R"code()code" ADD_FILELINE)
Expand Down Expand Up @@ -86,9 +88,7 @@ NNVM_REGISTER_OP(_npi_lcm)
NNVM_REGISTER_OP(_npi_lcm_scalar)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser([](NodeAttrs* attrs) {
attrs->parsed = std::stod(attrs->dict["scalar"]);
})
.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseIntType<1, 1>)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
Expand All @@ -97,7 +97,7 @@ NNVM_REGISTER_OP(_npi_lcm_scalar)
})
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.add_argument("data", "NDArray-or-Symbol", "source input")
.add_argument("scalar", "int", "scalar input")
.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::lcm>);

NNVM_REGISTER_OP(_npi_bitwise_and)
Expand All @@ -121,9 +121,7 @@ NNVM_REGISTER_OP(_npi_bitwise_and)
NNVM_REGISTER_OP(_npi_bitwise_and_scalar)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser([](NodeAttrs* attrs) {
attrs->parsed = std::stod(attrs->dict["scalar"]);
})
.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseIntType<1, 1>)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
Expand All @@ -132,7 +130,7 @@ NNVM_REGISTER_OP(_npi_bitwise_and_scalar)
})
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.add_argument("data", "NDArray-or-Symbol", "source input")
.add_argument("scalar", "int", "scalar input")
.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::ComputeInt<cpu, mshadow_op::bitwise_and>);

NNVM_REGISTER_OP(_npi_bitwise_xor)
Expand Down Expand Up @@ -174,9 +172,7 @@ NNVM_REGISTER_OP(_npi_bitwise_or)
NNVM_REGISTER_OP(_npi_bitwise_xor_scalar)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser([](NodeAttrs* attrs) {
attrs->parsed = std::stod(attrs->dict["scalar"]);
})
.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseIntType<1, 1>)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
Expand All @@ -185,15 +181,13 @@ NNVM_REGISTER_OP(_npi_bitwise_xor_scalar)
})
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.add_argument("data", "NDArray-or-Symbol", "source input")
.add_argument("scalar", "int", "scalar input")
.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::ComputeInt<cpu, mshadow_op::bitwise_xor>);

NNVM_REGISTER_OP(_npi_bitwise_or_scalar)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser([](NodeAttrs* attrs) {
attrs->parsed = std::stod(attrs->dict["scalar"]);
})
.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseIntType<1, 1>)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
Expand All @@ -202,7 +196,7 @@ NNVM_REGISTER_OP(_npi_bitwise_or_scalar)
})
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.add_argument("data", "NDArray-or-Symbol", "source input")
.add_argument("scalar", "int", "scalar input")
.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::ComputeInt<cpu, mshadow_op::bitwise_or>);

MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_copysign_scalar)
Expand Down Expand Up @@ -274,14 +268,14 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rarctan2_scalar)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_rarctan2_scalar"});

MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_arctan2_scalar)
.add_argument("scalar", "float", "scalar value")
.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); })
.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
.set_attr<FCompute>("FCompute<cpu>",
BinaryScalarOp::Backward<cpu, mshadow_op::arctan2_grad>);

MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_rarctan2_scalar)
.add_argument("scalar", "float", "scalar value")
.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); })
.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
.set_attr<FCompute>("FCompute<cpu>",
BinaryScalarOp::Backward<cpu, mshadow_op::arctan2_rgrad>);

Expand Down Expand Up @@ -362,13 +356,13 @@ NNVM_REGISTER_OP(_backward_npi_ldexp)
mshadow_op::ldexp_rgrad>);

MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_ldexp_scalar)
.add_argument("scalar", "float", "scalar value")
.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); })
.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Backward<cpu, mshadow_op::ldexp_grad>);

MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_rldexp_scalar)
.add_argument("scalar", "float", "scalar value")
.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); })
.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Backward<cpu, mshadow_op::rldexp_grad>);

} // namespace op
Expand Down
Loading

0 comments on commit d970243

Please sign in to comment.