Skip to content

Commit

Permalink
Fix binary scalar dtype and add bool support (apache#18277)
Browse files Browse the repository at this point in the history
* add boolean support for concatenate (apache#18213)

* fix binary scalar logic dtype (apache#16964)

* common_expr test remove

* binary scalar op support scalar dtype

Co-Authored-By: Wentao Xu <[email protected]>

* test_error_fix

Co-authored-by: Wentao Xu <[email protected]>
  • Loading branch information
Tommliu and cassinixu authored May 22, 2020
1 parent d9fc74e commit 48dea6e
Show file tree
Hide file tree
Showing 18 changed files with 397 additions and 234 deletions.
8 changes: 5 additions & 3 deletions python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -1577,13 +1577,15 @@ def _ufunc_helper(lhs, rhs, fn_array, fn_scalar, lfn_scalar, rfn_scalar=None, ou
if isinstance(rhs, numeric_types):
return fn_scalar(lhs, rhs, out=out)
else:
is_int = isinstance(rhs, integer_types)
if rfn_scalar is None:
# commutative function
return lfn_scalar(rhs, float(lhs), out=out)
return lfn_scalar(rhs, scalar=float(lhs), is_int=is_int, out=out)
else:
return rfn_scalar(rhs, float(lhs), out=out)
return rfn_scalar(rhs, scalar=float(lhs), is_int=is_int, out=out)
elif isinstance(rhs, numeric_types):
return lfn_scalar(lhs, float(rhs), out=out)
is_int = isinstance(rhs, integer_types)
return lfn_scalar(lhs, scalar=float(rhs), is_int=is_int, out=out)
elif isinstance(rhs, Symbol):
return fn_array(lhs, rhs, out=out)
else:
Expand Down
64 changes: 60 additions & 4 deletions src/api/operator/ufunc_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "ufunc_helper.h"
#include "utils.h"
#include "../../imperative/imperative_utils.h"
#include "../../operator/tensor/elemwise_binary_scalar_op.h"

namespace mxnet {

Expand Down Expand Up @@ -51,13 +52,38 @@ void UFuncHelper(NDArray* lhs, NDArray* rhs, NDArray* out,
}
}

// Dispatch an elementwise binary op where the left operand is an NDArray and
// the right operand is a Python integer scalar. The scalar is carried to the
// operator through NumpyBinaryScalarParam with is_int = true so the backend
// can pick an integer-aware result dtype.
void UFuncHelper(NDArray* lhs, int rhs, NDArray* out,
                 runtime::MXNetRetValue* ret, const nnvm::Op* op) {
  using namespace runtime;
  // Package the integer scalar into the op's parsed parameter struct.
  op::NumpyBinaryScalarParam param;
  param.scalar = rhs;
  param.is_int = true;
  nnvm::NodeAttrs attrs;
  attrs.op = op;
  attrs.parsed = param;
  SetAttrDict<op::NumpyBinaryScalarParam>(&attrs);
  // Single tensor input; output buffer is optional.
  NDArray** inputs = &lhs;
  int num_inputs = 1;
  NDArray** outputs = (out == nullptr) ? nullptr : &out;
  int num_outputs = (out != nullptr) ? 1 : 0;
  auto computed = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs);
  if (outputs != nullptr) {
    // Caller provided `out`: the result is that same Python argument (index 2).
    *ret = PythonArg(2);
  } else {
    *ret = reinterpret_cast<NDArray*>(computed[0]);
  }
}

void UFuncHelper(NDArray* lhs, double rhs, NDArray* out,
runtime::MXNetRetValue* ret, const nnvm::Op* op) {
using namespace runtime;
nnvm::NodeAttrs attrs;
op::NumpyBinaryScalarParam param;
param.scalar = rhs;
param.is_int = false;
attrs.op = op;
attrs.parsed = rhs;
SetAttrDict<double>(&attrs);
attrs.parsed = param;
SetAttrDict<op::NumpyBinaryScalarParam>(&attrs);
NDArray** inputs = &lhs;
int num_inputs = 1;
NDArray** outputs = out == nullptr ? nullptr : &out;
Expand All @@ -70,13 +96,38 @@ void UFuncHelper(NDArray* lhs, double rhs, NDArray* out,
}
}

// Dispatch an elementwise binary op where the left operand is a Python integer
// scalar and the right operand is an NDArray. The scalar travels in
// NumpyBinaryScalarParam with is_int = true so dtype promotion can honor it.
void UFuncHelper(int lhs, NDArray* rhs, NDArray* out,
                 runtime::MXNetRetValue* ret, const nnvm::Op* op) {
  using namespace runtime;
  // Package the integer scalar into the op's parsed parameter struct.
  op::NumpyBinaryScalarParam param;
  param.scalar = lhs;
  param.is_int = true;
  nnvm::NodeAttrs attrs;
  attrs.op = op;
  attrs.parsed = param;
  SetAttrDict<op::NumpyBinaryScalarParam>(&attrs);
  // The NDArray operand is the sole tensor input; output buffer is optional.
  NDArray** inputs = &rhs;
  int num_inputs = 1;
  NDArray** outputs = (out == nullptr) ? nullptr : &out;
  int num_outputs = (out != nullptr) ? 1 : 0;
  auto computed = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs);
  if (outputs != nullptr) {
    // Caller provided `out`: the result is that same Python argument (index 2).
    *ret = PythonArg(2);
  } else {
    *ret = reinterpret_cast<NDArray*>(computed[0]);
  }
}

void UFuncHelper(double lhs, NDArray* rhs, NDArray* out,
runtime::MXNetRetValue* ret, const nnvm::Op* op) {
using namespace runtime;
nnvm::NodeAttrs attrs;
op::NumpyBinaryScalarParam param;
param.scalar = lhs;
param.is_int = false;
attrs.op = op;
attrs.parsed = lhs;
SetAttrDict<double>(&attrs);
attrs.parsed = param;
SetAttrDict<op::NumpyBinaryScalarParam>(&attrs);
NDArray** inputs = &rhs;
int num_inputs = 1;
NDArray** outputs = out == nullptr ? nullptr : &out;
Expand All @@ -99,9 +150,14 @@ void UFuncHelper(runtime::MXNetArgs args,
if (args[0].type_code() == kNDArrayHandle) {
if (args[1].type_code() == kNDArrayHandle) {
UFuncHelper(args[0].operator NDArray*(), args[1].operator NDArray*(), out, ret, fn_array);
} else if (args[1].type_code() == kDLInt) {
UFuncHelper(args[0].operator NDArray*(), args[1].operator int(), out, ret, lfn_scalar);
} else {
UFuncHelper(args[0].operator NDArray*(), args[1].operator double(), out, ret, lfn_scalar);
}
} else if (args[0].type_code() == kDLInt) {
UFuncHelper(args[0].operator int(), args[1].operator NDArray*(), out, ret,
rfn_scalar ? rfn_scalar : lfn_scalar);
} else {
UFuncHelper(args[0].operator double(), args[1].operator NDArray*(), out, ret,
rfn_scalar ? rfn_scalar : lfn_scalar);
Expand Down
6 changes: 2 additions & 4 deletions src/operator/contrib/gradient_multiplier_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,7 @@ In forward pass it acts as an identity transform. During backpropagation it
multiplies the gradient from the subsequent level by a scalar factor lambda and passes it to
the preceding layer.
)code" ADD_FILELINE)
.set_attr_parser([](NodeAttrs* attrs) {
attrs->parsed = std::stod(attrs->dict["scalar"]);
})
.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
.set_attr<FInferStorageType>("FInferStorageType", ElemwiseStorageType<1, 1, false, true, true>)
.set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
.set_attr<FComputeEx>("FComputeEx<cpu>", UnaryOp::IdentityComputeEx<cpu>)
Expand All @@ -87,7 +85,7 @@ the preceding layer.
[](const NodeAttrs& attrs){
return std::vector<bool>{true};
})
.add_argument("scalar", "float", "lambda multiplier");
.add_arguments(NumpyBinaryScalarParam::__FIELDS__());

MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_contrib_backward_gradientmultiplier)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
Expand Down
18 changes: 10 additions & 8 deletions src/operator/numpy/np_elemwise_broadcast_logic_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,8 @@ struct TVMBinaryBroadcastScalarCompute {

// scalar param
type_codes[1] = kDLFloat;
values[1].v_float64 = nnvm::get<double>(attrs.parsed);
const NumpyBinaryScalarParam& param = nnvm::get<NumpyBinaryScalarParam>(attrs.parsed);
values[1].v_float64 = param.scalar;

// output tensor
type_codes[2] = kTVMDLTensorHandle;
Expand All @@ -242,9 +243,7 @@ struct TVMBinaryBroadcastScalarCompute {
NNVM_REGISTER_OP(_npi_##name##_scalar) \
.set_num_inputs(1) \
.set_num_outputs(1) \
.set_attr_parser([](NodeAttrs* attrs) { \
attrs->parsed = std::stod(attrs->dict["scalar"]); \
}) \
.set_attr_parser(ParamParser<NumpyBinaryScalarParam>) \
.set_attr<nnvm::FListInputNames>("FListInputNames", \
[](const NodeAttrs& attrs) { \
return std::vector<std::string>{"data"}; \
Expand All @@ -257,7 +256,7 @@ struct TVMBinaryBroadcastScalarCompute {
}) \
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes) \
.add_argument("data", "NDArray-or-Symbol", "First input to the function") \
.add_argument("scalar", "float", "scalar input")
.add_arguments(NumpyBinaryScalarParam::__FIELDS__())

MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(equal);
MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(not_equal);
Expand Down Expand Up @@ -317,9 +316,12 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(logical_xor);

#else

#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_CPU(name) \
NNVM_REGISTER_OP(_npi_##name##_scalar) \
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::ComputeLogic<cpu, mshadow_op::np_##name>)
#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_CPU(name) \
NNVM_REGISTER_OP(_npi_##name##_scalar) \
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::ComputeLogic<cpu, mshadow_op::np_##name>) \
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) { \
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; \
})

#endif // MXNET_USE_TVM_OP

Expand Down
30 changes: 15 additions & 15 deletions src/operator/numpy/np_elemwise_broadcast_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,21 @@
namespace mxnet {
namespace op {

#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(name) \
NNVM_REGISTER_OP(name) \
.set_num_inputs(1) \
.set_num_outputs(1) \
.set_attr_parser([](NodeAttrs* attrs) { \
attrs->parsed = std::stod(attrs->dict["scalar"]); \
}) \
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>) \
.set_attr<nnvm::FInferType>("FInferType", NumpyBinaryScalarType) \
.set_attr<nnvm::FInplaceOption>("FInplaceOption", \
[](const NodeAttrs& attrs){ \
return std::vector<std::pair<int, int> >{{0, 0}}; \
}) \
.add_argument("data", "NDArray-or-Symbol", "source input") \
.add_argument("scalar", "float", "scalar input")
DMLC_REGISTER_PARAMETER(NumpyBinaryScalarParam);

#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(name) \
NNVM_REGISTER_OP(name) \
.set_num_inputs(1) \
.set_num_outputs(1) \
.set_attr_parser(ParamParser<NumpyBinaryScalarParam>) \
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>) \
.set_attr<nnvm::FInferType>("FInferType", NumpyBinaryScalarType) \
.set_attr<FResourceRequest>("FResourceRequest", \
[](const NodeAttrs& attrs) { \
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; \
}) \
.add_argument("data", "NDArray-or-Symbol", "source input") \
.add_arguments(NumpyBinaryScalarParam::__FIELDS__())

bool NumpyBinaryMixedPrecisionType(const nnvm::NodeAttrs& attrs,
std::vector<int>* in_attrs,
Expand Down
10 changes: 0 additions & 10 deletions src/operator/numpy/np_elemwise_broadcast_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,6 @@
namespace mxnet {
namespace op {

// Type-inference function for numpy binary-scalar operators: the single
// output dtype mirrors the single input dtype. Assignment is bidirectional
// (output from input, then input from output) so a dtype known on either
// side propagates to the other.
// Returns true once the input dtype is resolved (-1 denotes "unknown").
inline bool NumpyBinaryScalarType(const nnvm::NodeAttrs& attrs,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
// Exactly one tensor input and one output are expected.
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
return in_attrs->at(0) != -1;
}

inline void PrintErrorMessage(const std::string& op_name, const int dtype1, const int dtype2) {
LOG(FATAL) << "Operator " << op_name << " does not support combination of "
<< mshadow::dtype_string(dtype1) << " with " << mshadow::dtype_string(dtype2)
Expand Down
68 changes: 29 additions & 39 deletions src/operator/numpy/np_elemwise_broadcast_op_extended.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,19 @@
namespace mxnet {
namespace op {

#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(name) \
NNVM_REGISTER_OP(name) \
.set_num_inputs(1) \
.set_num_outputs(1) \
.set_attr_parser([](NodeAttrs* attrs) { \
attrs->parsed = std::stod(attrs->dict["scalar"]); \
}) \
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>) \
.set_attr<nnvm::FInferType>("FInferType", NumpyBinaryScalarType) \
.set_attr<nnvm::FInplaceOption>("FInplaceOption", \
[](const NodeAttrs& attrs){ \
return std::vector<std::pair<int, int> >{{0, 0}}; \
}) \
.add_argument("data", "NDArray-or-Symbol", "source input") \
.add_argument("scalar", "float", "scalar input")
#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(name) \
NNVM_REGISTER_OP(name) \
.set_num_inputs(1) \
.set_num_outputs(1) \
.set_attr_parser(ParamParser<NumpyBinaryScalarParam>) \
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>) \
.set_attr<nnvm::FInferType>("FInferType", NumpyBinaryScalarType) \
.set_attr<FResourceRequest>("FResourceRequest", \
[](const NodeAttrs& attrs) { \
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; \
}) \
.add_argument("data", "NDArray-or-Symbol", "source input") \
.add_arguments(NumpyBinaryScalarParam::__FIELDS__())

MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_copysign)
.describe(R"code()code" ADD_FILELINE)
Expand Down Expand Up @@ -86,9 +84,7 @@ NNVM_REGISTER_OP(_npi_lcm)
NNVM_REGISTER_OP(_npi_lcm_scalar)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser([](NodeAttrs* attrs) {
attrs->parsed = std::stod(attrs->dict["scalar"]);
})
.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseIntType<1, 1>)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
Expand All @@ -97,7 +93,7 @@ NNVM_REGISTER_OP(_npi_lcm_scalar)
})
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.add_argument("data", "NDArray-or-Symbol", "source input")
.add_argument("scalar", "int", "scalar input")
.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::lcm>);

NNVM_REGISTER_OP(_npi_bitwise_and)
Expand All @@ -121,9 +117,7 @@ NNVM_REGISTER_OP(_npi_bitwise_and)
NNVM_REGISTER_OP(_npi_bitwise_and_scalar)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser([](NodeAttrs* attrs) {
attrs->parsed = std::stod(attrs->dict["scalar"]);
})
.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseIntType<1, 1>)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
Expand All @@ -132,7 +126,7 @@ NNVM_REGISTER_OP(_npi_bitwise_and_scalar)
})
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.add_argument("data", "NDArray-or-Symbol", "source input")
.add_argument("scalar", "int", "scalar input")
.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::ComputeInt<cpu, mshadow_op::bitwise_and>);

NNVM_REGISTER_OP(_npi_bitwise_xor)
Expand Down Expand Up @@ -174,9 +168,7 @@ NNVM_REGISTER_OP(_npi_bitwise_or)
NNVM_REGISTER_OP(_npi_bitwise_xor_scalar)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser([](NodeAttrs* attrs) {
attrs->parsed = std::stod(attrs->dict["scalar"]);
})
.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseIntType<1, 1>)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
Expand All @@ -185,15 +177,13 @@ NNVM_REGISTER_OP(_npi_bitwise_xor_scalar)
})
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.add_argument("data", "NDArray-or-Symbol", "source input")
.add_argument("scalar", "int", "scalar input")
.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::ComputeInt<cpu, mshadow_op::bitwise_xor>);

NNVM_REGISTER_OP(_npi_bitwise_or_scalar)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser([](NodeAttrs* attrs) {
attrs->parsed = std::stod(attrs->dict["scalar"]);
})
.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseIntType<1, 1>)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
Expand All @@ -202,7 +192,7 @@ NNVM_REGISTER_OP(_npi_bitwise_or_scalar)
})
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.add_argument("data", "NDArray-or-Symbol", "source input")
.add_argument("scalar", "int", "scalar input")
.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::ComputeInt<cpu, mshadow_op::bitwise_or>);

MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_copysign_scalar)
Expand Down Expand Up @@ -274,14 +264,14 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rarctan2_scalar)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_rarctan2_scalar"});

MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_arctan2_scalar)
.add_argument("scalar", "float", "scalar value")
.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); })
.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
.set_attr<FCompute>("FCompute<cpu>",
BinaryScalarOp::Backward<cpu, mshadow_op::arctan2_grad>);

MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_rarctan2_scalar)
.add_argument("scalar", "float", "scalar value")
.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); })
.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
.set_attr<FCompute>("FCompute<cpu>",
BinaryScalarOp::Backward<cpu, mshadow_op::arctan2_rgrad>);

Expand Down Expand Up @@ -362,13 +352,13 @@ NNVM_REGISTER_OP(_backward_npi_ldexp)
mshadow_op::ldexp_rgrad>);

MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_ldexp_scalar)
.add_argument("scalar", "float", "scalar value")
.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); })
.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Backward<cpu, mshadow_op::ldexp_grad>);

MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_rldexp_scalar)
.add_argument("scalar", "float", "scalar value")
.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); })
.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Backward<cpu, mshadow_op::rldexp_grad>);

} // namespace op
Expand Down
Loading

0 comments on commit 48dea6e

Please sign in to comment.