Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fix binary scalar dtype and add bool support #18277

Merged
merged 5 commits into from
May 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -1577,13 +1577,15 @@ def _ufunc_helper(lhs, rhs, fn_array, fn_scalar, lfn_scalar, rfn_scalar=None, ou
if isinstance(rhs, numeric_types):
return fn_scalar(lhs, rhs, out=out)
else:
is_int = isinstance(rhs, integer_types)
if rfn_scalar is None:
# commutative function
return lfn_scalar(rhs, float(lhs), out=out)
return lfn_scalar(rhs, scalar=float(lhs), is_int=is_int, out=out)
else:
return rfn_scalar(rhs, float(lhs), out=out)
return rfn_scalar(rhs, scalar=float(lhs), is_int=is_int, out=out)
elif isinstance(rhs, numeric_types):
return lfn_scalar(lhs, float(rhs), out=out)
is_int = isinstance(rhs, integer_types)
return lfn_scalar(lhs, scalar=float(rhs), is_int=is_int, out=out)
elif isinstance(rhs, Symbol):
return fn_array(lhs, rhs, out=out)
else:
Expand Down
64 changes: 60 additions & 4 deletions src/api/operator/ufunc_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "ufunc_helper.h"
#include "utils.h"
#include "../../imperative/imperative_utils.h"
#include "../../operator/tensor/elemwise_binary_scalar_op.h"

namespace mxnet {

Expand Down Expand Up @@ -51,13 +52,38 @@ void UFuncHelper(NDArray* lhs, NDArray* rhs, NDArray* out,
}
}

void UFuncHelper(NDArray* lhs, int rhs, NDArray* out,
runtime::MXNetRetValue* ret, const nnvm::Op* op) {
using namespace runtime;
nnvm::NodeAttrs attrs;
op::NumpyBinaryScalarParam param;
param.scalar = rhs;
param.is_int = true;
attrs.op = op;
attrs.parsed = param;
SetAttrDict<op::NumpyBinaryScalarParam>(&attrs);
NDArray** inputs = &lhs;
int num_inputs = 1;
NDArray** outputs = out == nullptr ? nullptr : &out;
int num_outputs = out != nullptr;
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs);
if (outputs) {
*ret = PythonArg(2);
} else {
*ret = reinterpret_cast<NDArray*>(ndoutputs[0]);
}
}

void UFuncHelper(NDArray* lhs, double rhs, NDArray* out,
runtime::MXNetRetValue* ret, const nnvm::Op* op) {
using namespace runtime;
nnvm::NodeAttrs attrs;
op::NumpyBinaryScalarParam param;
param.scalar = rhs;
param.is_int = false;
attrs.op = op;
attrs.parsed = rhs;
SetAttrDict<double>(&attrs);
attrs.parsed = param;
SetAttrDict<op::NumpyBinaryScalarParam>(&attrs);
NDArray** inputs = &lhs;
int num_inputs = 1;
NDArray** outputs = out == nullptr ? nullptr : &out;
Expand All @@ -70,13 +96,38 @@ void UFuncHelper(NDArray* lhs, double rhs, NDArray* out,
}
}

void UFuncHelper(int lhs, NDArray* rhs, NDArray* out,
runtime::MXNetRetValue* ret, const nnvm::Op* op) {
using namespace runtime;
nnvm::NodeAttrs attrs;
op::NumpyBinaryScalarParam param;
param.scalar = lhs;
param.is_int = true;
attrs.op = op;
attrs.parsed = param;
SetAttrDict<op::NumpyBinaryScalarParam>(&attrs);
NDArray** inputs = &rhs;
int num_inputs = 1;
NDArray** outputs = out == nullptr ? nullptr : &out;
int num_outputs = out != nullptr;
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs);
if (outputs) {
*ret = PythonArg(2);
} else {
*ret = reinterpret_cast<NDArray*>(ndoutputs[0]);
}
}

void UFuncHelper(double lhs, NDArray* rhs, NDArray* out,
runtime::MXNetRetValue* ret, const nnvm::Op* op) {
using namespace runtime;
nnvm::NodeAttrs attrs;
op::NumpyBinaryScalarParam param;
param.scalar = lhs;
param.is_int = false;
attrs.op = op;
attrs.parsed = lhs;
SetAttrDict<double>(&attrs);
attrs.parsed = param;
SetAttrDict<op::NumpyBinaryScalarParam>(&attrs);
NDArray** inputs = &rhs;
int num_inputs = 1;
NDArray** outputs = out == nullptr ? nullptr : &out;
Expand All @@ -99,9 +150,14 @@ void UFuncHelper(runtime::MXNetArgs args,
if (args[0].type_code() == kNDArrayHandle) {
if (args[1].type_code() == kNDArrayHandle) {
UFuncHelper(args[0].operator NDArray*(), args[1].operator NDArray*(), out, ret, fn_array);
} else if (args[1].type_code() == kDLInt) {
UFuncHelper(args[0].operator NDArray*(), args[1].operator int(), out, ret, lfn_scalar);
} else {
UFuncHelper(args[0].operator NDArray*(), args[1].operator double(), out, ret, lfn_scalar);
}
} else if (args[0].type_code() == kDLInt) {
UFuncHelper(args[0].operator int(), args[1].operator NDArray*(), out, ret,
rfn_scalar ? rfn_scalar : lfn_scalar);
} else {
UFuncHelper(args[0].operator double(), args[1].operator NDArray*(), out, ret,
rfn_scalar ? rfn_scalar : lfn_scalar);
Expand Down
6 changes: 2 additions & 4 deletions src/operator/contrib/gradient_multiplier_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,7 @@ In forward pass it acts as an identity transform. During backpropagation it
multiplies the gradient from the subsequent level by a scalar factor lambda and passes it to
the preceding layer.
)code" ADD_FILELINE)
.set_attr_parser([](NodeAttrs* attrs) {
attrs->parsed = std::stod(attrs->dict["scalar"]);
})
.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
.set_attr<FInferStorageType>("FInferStorageType", ElemwiseStorageType<1, 1, false, true, true>)
.set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
.set_attr<FComputeEx>("FComputeEx<cpu>", UnaryOp::IdentityComputeEx<cpu>)
Expand All @@ -87,7 +85,7 @@ the preceding layer.
[](const NodeAttrs& attrs){
return std::vector<bool>{true};
})
.add_argument("scalar", "float", "lambda multiplier");
.add_arguments(NumpyBinaryScalarParam::__FIELDS__());

MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_contrib_backward_gradientmultiplier)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
Expand Down
18 changes: 10 additions & 8 deletions src/operator/numpy/np_elemwise_broadcast_logic_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,8 @@ struct TVMBinaryBroadcastScalarCompute {

// scalar param
type_codes[1] = kDLFloat;
values[1].v_float64 = nnvm::get<double>(attrs.parsed);
const NumpyBinaryScalarParam& param = nnvm::get<NumpyBinaryScalarParam>(attrs.parsed);
values[1].v_float64 = param.scalar;

// output tensor
type_codes[2] = kTVMDLTensorHandle;
Expand All @@ -242,9 +243,7 @@ struct TVMBinaryBroadcastScalarCompute {
NNVM_REGISTER_OP(_npi_##name##_scalar) \
.set_num_inputs(1) \
.set_num_outputs(1) \
.set_attr_parser([](NodeAttrs* attrs) { \
attrs->parsed = std::stod(attrs->dict["scalar"]); \
}) \
.set_attr_parser(ParamParser<NumpyBinaryScalarParam>) \
.set_attr<nnvm::FListInputNames>("FListInputNames", \
[](const NodeAttrs& attrs) { \
return std::vector<std::string>{"data"}; \
Expand All @@ -257,7 +256,7 @@ struct TVMBinaryBroadcastScalarCompute {
}) \
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes) \
.add_argument("data", "NDArray-or-Symbol", "First input to the function") \
.add_argument("scalar", "float", "scalar input")
.add_arguments(NumpyBinaryScalarParam::__FIELDS__())

MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(equal);
MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(not_equal);
Expand Down Expand Up @@ -317,9 +316,12 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(logical_xor);

#else

#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_CPU(name) \
NNVM_REGISTER_OP(_npi_##name##_scalar) \
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::ComputeLogic<cpu, mshadow_op::np_##name>)
#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_CPU(name) \
NNVM_REGISTER_OP(_npi_##name##_scalar) \
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::ComputeLogic<cpu, mshadow_op::np_##name>) \
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) { \
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; \
})

#endif // MXNET_USE_TVM_OP

Expand Down
30 changes: 15 additions & 15 deletions src/operator/numpy/np_elemwise_broadcast_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,21 @@
namespace mxnet {
namespace op {

#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(name) \
NNVM_REGISTER_OP(name) \
.set_num_inputs(1) \
.set_num_outputs(1) \
.set_attr_parser([](NodeAttrs* attrs) { \
attrs->parsed = std::stod(attrs->dict["scalar"]); \
}) \
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>) \
.set_attr<nnvm::FInferType>("FInferType", NumpyBinaryScalarType) \
.set_attr<nnvm::FInplaceOption>("FInplaceOption", \
[](const NodeAttrs& attrs){ \
return std::vector<std::pair<int, int> >{{0, 0}}; \
}) \
.add_argument("data", "NDArray-or-Symbol", "source input") \
.add_argument("scalar", "float", "scalar input")
DMLC_REGISTER_PARAMETER(NumpyBinaryScalarParam);

#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(name) \
NNVM_REGISTER_OP(name) \
.set_num_inputs(1) \
.set_num_outputs(1) \
.set_attr_parser(ParamParser<NumpyBinaryScalarParam>) \
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>) \
.set_attr<nnvm::FInferType>("FInferType", NumpyBinaryScalarType) \
.set_attr<FResourceRequest>("FResourceRequest", \
[](const NodeAttrs& attrs) { \
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; \
}) \
.add_argument("data", "NDArray-or-Symbol", "source input") \
.add_arguments(NumpyBinaryScalarParam::__FIELDS__())

bool NumpyBinaryMixedPrecisionType(const nnvm::NodeAttrs& attrs,
std::vector<int>* in_attrs,
Expand Down
10 changes: 0 additions & 10 deletions src/operator/numpy/np_elemwise_broadcast_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,6 @@
namespace mxnet {
namespace op {

inline bool NumpyBinaryScalarType(const nnvm::NodeAttrs& attrs,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
return in_attrs->at(0) != -1;
}

inline void PrintErrorMessage(const std::string& op_name, const int dtype1, const int dtype2) {
LOG(FATAL) << "Operator " << op_name << " does not support combination of "
<< mshadow::dtype_string(dtype1) << " with " << mshadow::dtype_string(dtype2)
Expand Down
68 changes: 29 additions & 39 deletions src/operator/numpy/np_elemwise_broadcast_op_extended.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,19 @@
namespace mxnet {
namespace op {

#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(name) \
NNVM_REGISTER_OP(name) \
.set_num_inputs(1) \
.set_num_outputs(1) \
.set_attr_parser([](NodeAttrs* attrs) { \
attrs->parsed = std::stod(attrs->dict["scalar"]); \
}) \
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>) \
.set_attr<nnvm::FInferType>("FInferType", NumpyBinaryScalarType) \
.set_attr<nnvm::FInplaceOption>("FInplaceOption", \
[](const NodeAttrs& attrs){ \
return std::vector<std::pair<int, int> >{{0, 0}}; \
}) \
.add_argument("data", "NDArray-or-Symbol", "source input") \
.add_argument("scalar", "float", "scalar input")
#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(name) \
NNVM_REGISTER_OP(name) \
.set_num_inputs(1) \
.set_num_outputs(1) \
.set_attr_parser(ParamParser<NumpyBinaryScalarParam>) \
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>) \
.set_attr<nnvm::FInferType>("FInferType", NumpyBinaryScalarType) \
.set_attr<FResourceRequest>("FResourceRequest", \
[](const NodeAttrs& attrs) { \
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; \
}) \
.add_argument("data", "NDArray-or-Symbol", "source input") \
.add_arguments(NumpyBinaryScalarParam::__FIELDS__())

MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_copysign)
.describe(R"code()code" ADD_FILELINE)
Expand Down Expand Up @@ -86,9 +84,7 @@ NNVM_REGISTER_OP(_npi_lcm)
NNVM_REGISTER_OP(_npi_lcm_scalar)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser([](NodeAttrs* attrs) {
attrs->parsed = std::stod(attrs->dict["scalar"]);
})
.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseIntType<1, 1>)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
Expand All @@ -97,7 +93,7 @@ NNVM_REGISTER_OP(_npi_lcm_scalar)
})
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.add_argument("data", "NDArray-or-Symbol", "source input")
.add_argument("scalar", "int", "scalar input")
.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::lcm>);

NNVM_REGISTER_OP(_npi_bitwise_and)
Expand All @@ -121,9 +117,7 @@ NNVM_REGISTER_OP(_npi_bitwise_and)
NNVM_REGISTER_OP(_npi_bitwise_and_scalar)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser([](NodeAttrs* attrs) {
attrs->parsed = std::stod(attrs->dict["scalar"]);
})
.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseIntType<1, 1>)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
Expand All @@ -132,7 +126,7 @@ NNVM_REGISTER_OP(_npi_bitwise_and_scalar)
})
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.add_argument("data", "NDArray-or-Symbol", "source input")
.add_argument("scalar", "int", "scalar input")
.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::ComputeInt<cpu, mshadow_op::bitwise_and>);

NNVM_REGISTER_OP(_npi_bitwise_xor)
Expand Down Expand Up @@ -174,9 +168,7 @@ NNVM_REGISTER_OP(_npi_bitwise_or)
NNVM_REGISTER_OP(_npi_bitwise_xor_scalar)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser([](NodeAttrs* attrs) {
attrs->parsed = std::stod(attrs->dict["scalar"]);
})
.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseIntType<1, 1>)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
Expand All @@ -185,15 +177,13 @@ NNVM_REGISTER_OP(_npi_bitwise_xor_scalar)
})
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.add_argument("data", "NDArray-or-Symbol", "source input")
.add_argument("scalar", "int", "scalar input")
.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::ComputeInt<cpu, mshadow_op::bitwise_xor>);

NNVM_REGISTER_OP(_npi_bitwise_or_scalar)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser([](NodeAttrs* attrs) {
attrs->parsed = std::stod(attrs->dict["scalar"]);
})
.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseIntType<1, 1>)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
Expand All @@ -202,7 +192,7 @@ NNVM_REGISTER_OP(_npi_bitwise_or_scalar)
})
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.add_argument("data", "NDArray-or-Symbol", "source input")
.add_argument("scalar", "int", "scalar input")
.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::ComputeInt<cpu, mshadow_op::bitwise_or>);

MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_copysign_scalar)
Expand Down Expand Up @@ -274,14 +264,14 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rarctan2_scalar)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_rarctan2_scalar"});

MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_arctan2_scalar)
.add_argument("scalar", "float", "scalar value")
.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); })
.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
.set_attr<FCompute>("FCompute<cpu>",
BinaryScalarOp::Backward<cpu, mshadow_op::arctan2_grad>);

MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_rarctan2_scalar)
.add_argument("scalar", "float", "scalar value")
.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); })
.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
.set_attr<FCompute>("FCompute<cpu>",
BinaryScalarOp::Backward<cpu, mshadow_op::arctan2_rgrad>);

Expand Down Expand Up @@ -362,13 +352,13 @@ NNVM_REGISTER_OP(_backward_npi_ldexp)
mshadow_op::ldexp_rgrad>);

MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_ldexp_scalar)
.add_argument("scalar", "float", "scalar value")
.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); })
.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Backward<cpu, mshadow_op::ldexp_grad>);

MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_rldexp_scalar)
.add_argument("scalar", "float", "scalar value")
.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); })
.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Backward<cpu, mshadow_op::rldexp_grad>);

} // namespace op
Expand Down
Loading