fix binary scalar casting issue apache#16653
Tommliu committed May 11, 2020
1 parent de51058 commit 5b4ac20
Showing 5 changed files with 103 additions and 20 deletions.
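
For context, a minimal sketch of the casting problem this commit addresses (not part of the diff; it assumes the mxnet.numpy front end and the promotion rule introduced below, where an integer-dtype array combined with a float scalar is promoted to float64 instead of truncating the scalar):

import numpy as _np
from mxnet import np, npx
npx.set_np()

a = np.array([1, 2, 3], dtype=_np.int32)
out = np.add(a, 1.5)   # integer array op float scalar
# Expected with this change (matching _np.add): dtype float64, values [2.5, 3.5, 4.5]
print(out.dtype, out)
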
32 changes: 18 additions & 14 deletions src/operator/numpy/np_elemwise_broadcast_op.cc
@@ -28,20 +28,24 @@
 namespace mxnet {
 namespace op {
 
-#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(name)                   \
-  NNVM_REGISTER_OP(name)                                                 \
-  .set_num_inputs(1)                                                     \
-  .set_num_outputs(1)                                                    \
-  .set_attr_parser([](NodeAttrs* attrs) {                                \
-      attrs->parsed = std::stod(attrs->dict["scalar"]);                  \
-    })                                                                   \
-  .set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)      \
-  .set_attr<nnvm::FInferType>("FInferType", NumpyBinaryScalarType)       \
-  .set_attr<nnvm::FInplaceOption>("FInplaceOption",                      \
-    [](const NodeAttrs& attrs){                                          \
-      return std::vector<std::pair<int, int> >{{0, 0}};                  \
-    })                                                                   \
-  .add_argument("data", "NDArray-or-Symbol", "source input")             \
+#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(name)                   \
+  NNVM_REGISTER_OP(name)                                                 \
+  .set_num_inputs(1)                                                     \
+  .set_num_outputs(1)                                                    \
+  .set_attr_parser([](NodeAttrs* attrs) {                                \
+      attrs->parsed = std::stod(attrs->dict["scalar"]);                  \
+    })                                                                   \
+  .set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)      \
+  .set_attr<nnvm::FInferType>("FInferType", NumpyBinaryScalarType)       \
+  .set_attr<nnvm::FInplaceOption>("FInplaceOption",                      \
+    [](const NodeAttrs& attrs){                                          \
+      return std::vector<std::pair<int, int> >{{0, 0}};                  \
+    })                                                                   \
+  .set_attr<FResourceRequest>("FResourceRequest",                        \
+    [](const NodeAttrs& attrs) {                                         \
+      return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};  \
+    })                                                                   \
+  .add_argument("data", "NDArray-or-Symbol", "source input")             \
   .add_argument("scalar", "float", "scalar input")
 
 bool NumpyBinaryMixedPrecisionType(const nnvm::NodeAttrs& attrs,
10 changes: 7 additions & 3 deletions src/operator/numpy/np_elemwise_broadcast_op.h
@@ -40,9 +40,13 @@ inline bool NumpyBinaryScalarType(const nnvm::NodeAttrs& attrs,
                                   std::vector<int>* out_attrs) {
   CHECK_EQ(in_attrs->size(), 1U);
   CHECK_EQ(out_attrs->size(), 1U);
-  TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
-  TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
-  return in_attrs->at(0) != -1;
+  if (common::is_int(in_attrs->at(0))) {
+    TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kFloat64);
+  } else {
+    TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
+    TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
+  }
+  return out_attrs->at(0) != -1;
 }
 
 inline void PrintErrorMessage(const std::string& op_name, const int dtype1, const int dtype2) {
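
A small, hypothetical check of the inference rule above, written against the mxnet.numpy API rather than the C++ pass (assumptions: np is mxnet.numpy, _np is numpy, and numpy semantics are enabled via npx.set_np()): integer input dtypes are promoted to float64, while floating-point input dtypes pass through unchanged.

assert (np.ones((2,), dtype=_np.int64) * 2.0).dtype == _np.float64    # int input -> float64 output
assert (np.ones((2,), dtype=_np.float16) * 2.0).dtype == _np.float16  # float dtype preserved
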
1 change: 1 addition & 0 deletions src/operator/tensor/elemwise_binary_scalar_op.cuh
@@ -147,6 +147,7 @@ class VectorizedBinaryScalarBwd {

 template <typename OP>
 void BinaryScalarOp::Compute_(const nnvm::NodeAttrs &attrs,
+                              const OpContext &ctx,
                               mshadow::Stream<gpu>* s,
                               const std::vector<TBlob> &inputs,
                               const std::vector<OpReqType> &req,
22 changes: 19 additions & 3 deletions src/operator/tensor/elemwise_binary_scalar_op.h
@@ -226,7 +226,8 @@ class BinaryScalarOp : public UnaryOp {
  public:
   template<typename OP>
   static void Compute_(const nnvm::NodeAttrs &attrs,
-                       mshadow::Stream<cpu>* s,
+                       const OpContext &ctx,
+                       mshadow::Stream<cpu>* s,
                        const std::vector<TBlob> &inputs,
                        const std::vector<OpReqType> &req,
                        const std::vector<TBlob> &outputs) {
@@ -235,7 +236,7 @@ class BinaryScalarOp : public UnaryOp {
     using namespace mshadow;
     using namespace mshadow::expr;
     const double alpha = nnvm::get<double>(attrs.parsed);
-    MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+    MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
       MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
         mxnet_op::Kernel<mxnet_op::op_with_req<OP, Req>, cpu>::Launch(
             s, inputs[0].Size(), outputs[0].dptr<DType>(), inputs[0].dptr<DType>(), DType(alpha));
@@ -246,6 +247,7 @@
 #if MXNET_USE_CUDA
   template<typename OP>
   static void Compute_(const nnvm::NodeAttrs &attrs,
+                       const OpContext &ctx,
                        mshadow::Stream<gpu>* s,
                        const std::vector<TBlob> &inputs,
                        const std::vector<OpReqType> &req,
@@ -259,7 +261,21 @@
                       const std::vector<OpReqType> &req,
                       const std::vector<TBlob> &outputs) {
     mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
-    Compute_<OP>(attrs, s, inputs, req, outputs);
+    using namespace mshadow;
+    using namespace mshadow::expr;
+    TBlob temp_tblob;
+    MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+      if (common::is_int(inputs[0].type_flag_)) {
+        Tensor<xpu, 1, DType> temp_tensor =
+            ctx.requested[0].get_space_typed<xpu, 1, DType>(Shape1(inputs[0].Size()), s);
+        temp_tblob = TBlob(temp_tensor);
+        CastCompute<xpu>(attrs, ctx, {inputs[0]}, {kWriteTo}, {temp_tblob});
+      } else {
+        temp_tblob = inputs[0];
+      }
+    });
+    std::vector<TBlob> input{temp_tblob};
+    Compute_<OP>(attrs, ctx, s, input, req, outputs);
   }
 
   template<typename xpu, typename OP>
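
Conceptually, the reworked forward path above behaves like the following NumPy analogue (illustrative only; binary_scalar_forward is a hypothetical name, not MXNet API): integer inputs are first cast into the temporary buffer obtained from the kTempSpace resource registered in np_elemwise_broadcast_op.cc, and the scalar kernel then runs on the cast, floating-point data.

import numpy as _np

def binary_scalar_forward(x, scalar, op):
    # Mirror of the CastCompute-into-temp-space step: integer inputs are
    # cast to the float64 output dtype before the scalar kernel is applied.
    if _np.issubdtype(x.dtype, _np.integer):
        x = x.astype(_np.float64)
    return op(x, scalar)

# binary_scalar_forward(_np.arange(3, dtype=_np.int32), 1.5, _np.add) -> float64 array([1.5, 2.5, 3.5])
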
58 changes: 58 additions & 0 deletions tests/python/unittest/test_numpy_op.py
@@ -2661,6 +2661,64 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
                 check_mixed_precision_binary_func(func, low, high, lshape, rshape, lgrad, rgrad, type1, type2)
 
 
+@with_seed()
+@use_np
+def test_np_binary_scalar_funcs():
+    itypes = [np.int8, np.int32, np.int64]
+    def check_binary_scalar_func(func, low, high, lshape, lgrad, ltype):
+        class TestBinaryScalar(HybridBlock):
+            def __init__(self, func, scalar):
+                super(TestBinaryScalar, self).__init__()
+                self._func = func
+                self._scalar = scalar
+
+            def hybrid_forward(self, F, a, *args, **kwargs):
+                return getattr(F.np, self._func)(a, self._scalar)
+
+        np_test_x1 = _np.random.uniform(low, high, lshape).astype(ltype)
+        np_test_x2 = _np.random.uniform(low, high)
+        mx_test_x1 = np.array(np_test_x1, dtype=ltype)
+        mx_test_x2 = np_test_x2
+        np_func = getattr(_np, func)
+        mx_func = TestBinaryScalar(func, np_test_x2)
+        rtol = 1e-2 if ltype is np.float16 else 1e-3
+        atol = 1e-3 if ltype is np.float16 else 1e-5
+        if ltype not in itypes:
+            if lgrad:
+                mx_test_x1.attach_grad()
+            np_out = np_func(np_test_x1, np_test_x2)
+            with mx.autograd.record():
+                y = mx_func(mx_test_x1, mx_test_x2)
+            assert y.shape == np_out.shape
+            assert_almost_equal(y.asnumpy(), np_out.astype(y.dtype), rtol=rtol, atol=atol)
+            if lgrad:
+                y.backward()
+                assert_almost_equal(mx_test_x1.grad.asnumpy(),
+                                    collapse_sum_like(lgrad(y.asnumpy(), np_test_x1, np_test_x2), mx_test_x1.shape),
+                                    rtol=1e-1, atol=1e-2, equal_nan=True, use_broadcast=False)
+
+        # Test imperative
+        np_out = getattr(_np, func)(np_test_x1, np_test_x2)
+        mx_out = getattr(mx.np, func)(mx_test_x1, mx_test_x2)
+        assert mx_out.shape == np_out.shape
+        assert_almost_equal(mx_out.asnumpy(), np_out.astype(mx_out.dtype), rtol=rtol, atol=atol)
+
+    funcs = {
+        'add': (-1.0, 1.0, None),
+        'subtract': (-1.0, 1.0, None),
+        'multiply': (-1.0, 1.0, lambda y, x1, x2: _np.broadcast_to(x2, y.shape)),
+        'power': (1.0, 5.0, lambda y, x1, x2: _np.power(x1, x2 - 1.0) * x2),
+    }
+
+    shapes = [(3, 2), (3, 0), (3, 1), (0, 2), (2, 3, 4)]
+    ltypes = [np.int32, np.int64, np.float16, np.float32, np.float64]
+    for func, func_data in funcs.items():
+        low, high, lgrad = func_data
+        for shape in shapes:
+            for ltype in ltypes:
+                check_binary_scalar_func(func, low, high, shape, lgrad, ltype)
+
+
 @with_seed()
 @use_np
 def test_np_boolean_binary_funcs():
