diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cc b/src/operator/numpy/np_elemwise_broadcast_op.cc index 6ec880e0ba8b..95a8ffcb0946 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.cc +++ b/src/operator/numpy/np_elemwise_broadcast_op.cc @@ -28,20 +28,24 @@ namespace mxnet { namespace op { -#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(name) \ - NNVM_REGISTER_OP(name) \ - .set_num_inputs(1) \ - .set_num_outputs(1) \ - .set_attr_parser([](NodeAttrs* attrs) { \ - attrs->parsed = std::stod(attrs->dict["scalar"]); \ - }) \ - .set_attr("FInferShape", ElemwiseShape<1, 1>) \ - .set_attr("FInferType", NumpyBinaryScalarType) \ - .set_attr("FInplaceOption", \ - [](const NodeAttrs& attrs){ \ - return std::vector >{{0, 0}}; \ - }) \ - .add_argument("data", "NDArray-or-Symbol", "source input") \ +#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(name) \ + NNVM_REGISTER_OP(name) \ + .set_num_inputs(1) \ + .set_num_outputs(1) \ + .set_attr_parser([](NodeAttrs* attrs) { \ + attrs->parsed = std::stod(attrs->dict["scalar"]); \ + }) \ + .set_attr("FInferShape", ElemwiseShape<1, 1>) \ + .set_attr("FInferType", NumpyBinaryScalarType) \ + .set_attr("FInplaceOption", \ + [](const NodeAttrs& attrs){ \ + return std::vector >{{0, 0}}; \ + }) \ + .set_attr("FResourceRequest", \ + [](const NodeAttrs& attrs) { \ + return std::vector{ResourceRequest::kTempSpace}; \ + }) \ + .add_argument("data", "NDArray-or-Symbol", "source input") \ .add_argument("scalar", "float", "scalar input") bool NumpyBinaryMixedPrecisionType(const nnvm::NodeAttrs& attrs, diff --git a/src/operator/numpy/np_elemwise_broadcast_op.h b/src/operator/numpy/np_elemwise_broadcast_op.h index a0e204318839..323cb05375f8 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.h +++ b/src/operator/numpy/np_elemwise_broadcast_op.h @@ -40,9 +40,13 @@ inline bool NumpyBinaryScalarType(const nnvm::NodeAttrs& attrs, std::vector* out_attrs) { CHECK_EQ(in_attrs->size(), 1U); CHECK_EQ(out_attrs->size(), 1U); - TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); - TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0)); - return in_attrs->at(0) != -1; + if (common::is_int(in_attrs->at(0))) { + TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kFloat64); + } else { + TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); + TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0)); + } + return out_attrs->at(0) != -1; } inline void PrintErrorMessage(const std::string& op_name, const int dtype1, const int dtype2) { diff --git a/src/operator/tensor/elemwise_binary_scalar_op.cuh b/src/operator/tensor/elemwise_binary_scalar_op.cuh index 062c18767ac6..d4343af2efea 100644 --- a/src/operator/tensor/elemwise_binary_scalar_op.cuh +++ b/src/operator/tensor/elemwise_binary_scalar_op.cuh @@ -147,6 +147,7 @@ class VectorizedBinaryScalarBwd { template void BinaryScalarOp::Compute_(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, mshadow::Stream* s, const std::vector &inputs, const std::vector &req, diff --git a/src/operator/tensor/elemwise_binary_scalar_op.h b/src/operator/tensor/elemwise_binary_scalar_op.h index f974332252d8..4d9a671829cd 100644 --- a/src/operator/tensor/elemwise_binary_scalar_op.h +++ b/src/operator/tensor/elemwise_binary_scalar_op.h @@ -226,7 +226,8 @@ class BinaryScalarOp : public UnaryOp { public: template static void Compute_(const nnvm::NodeAttrs &attrs, - mshadow::Stream* s, + const OpContext &ctx, + mshadow::Stream* s, const std::vector &inputs, const std::vector &req, const std::vector &outputs) { @@ -235,7 +236,7 @@ class BinaryScalarOp : public UnaryOp { using namespace mshadow; using namespace mshadow::expr; const double alpha = nnvm::get(attrs.parsed); - MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { + MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { mxnet_op::Kernel, cpu>::Launch( s, inputs[0].Size(), outputs[0].dptr(), inputs[0].dptr(), DType(alpha)); @@ -246,6 +247,7 @@ class BinaryScalarOp : public UnaryOp { #if MXNET_USE_CUDA template static void Compute_(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, mshadow::Stream* s, const std::vector &inputs, const std::vector &req, @@ -259,7 +261,21 @@ class BinaryScalarOp : public UnaryOp { const std::vector &req, const std::vector &outputs) { mshadow::Stream *s = ctx.get_stream(); - Compute_(attrs, s, inputs, req, outputs); + using namespace mshadow; + using namespace mshadow::expr; + TBlob temp_tblob; + MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { + if (common::is_int(inputs[0].type_flag_)) { + Tensor temp_tensor = + ctx.requested[0].get_space_typed(Shape1(inputs[0].Size()), s); + temp_tblob = TBlob(temp_tensor); + CastCompute(attrs, ctx, {inputs[0]}, {kWriteTo}, {temp_tblob}); + } else { + temp_tblob = inputs[0]; + } + }); + std::vector input{temp_tblob}; + Compute_(attrs, ctx, s, input, req, outputs); } template diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index bb07a57c85e6..0a9082fe5906 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -2661,6 +2661,64 @@ def hybrid_forward(self, F, a, b, *args, **kwargs): check_mixed_precision_binary_func(func, low, high, lshape, rshape, lgrad, rgrad, type1, type2) +@with_seed() +@use_np +def test_np_binary_scalar_funcs(): + itypes = [np.int8, np.int32, np.int64] + def check_binary_scalar_func(func, low, high, lshape, lgrad, ltype): + class TestBinaryScalar(HybridBlock): + def __init__(self, func, scalar): + super(TestBinaryScalar, self).__init__() + self._func = func + self._scalar = scalar + + def hybrid_forward(self, F, a, *args, **kwargs): + return getattr(F.np, self._func)(a, self._scalar) + + np_test_x1 = _np.random.uniform(low, high, lshape).astype(ltype) + np_test_x2 = _np.random.uniform(low, high) + mx_test_x1 = np.array(np_test_x1, dtype=ltype) + mx_test_x2 = np_test_x2 + np_func = getattr(_np, func) + mx_func = TestBinaryScalar(func, np_test_x2) + rtol = 1e-2 if ltype is np.float16 else 1e-3 + atol = 1e-3 if ltype is np.float16 else 1e-5 + if ltype not in itypes: + if lgrad: + mx_test_x1.attach_grad() + np_out = np_func(np_test_x1, np_test_x2) + with mx.autograd.record(): + y = mx_func(mx_test_x1, mx_test_x2) + assert y.shape == np_out.shape + assert_almost_equal(y.asnumpy(), np_out.astype(y.dtype), rtol=rtol, atol=atol) + if lgrad: + y.backward() + assert_almost_equal(mx_test_x1.grad.asnumpy(), + collapse_sum_like(lgrad(y.asnumpy(), np_test_x1, np_test_x2), mx_test_x1.shape), + rtol=1e-1, atol=1e-2, equal_nan=True, use_broadcast=False) + + # Test imperative + np_out = getattr(_np, func)(np_test_x1, np_test_x2) + mx_out = getattr(mx.np, func)(mx_test_x1, mx_test_x2) + assert mx_out.shape == np_out.shape + assert_almost_equal(mx_out.asnumpy(), np_out.astype(mx_out.dtype), rtol=rtol, atol=atol) + + funcs = { + 'add': (-1.0, 1.0, None), + 'subtract': (-1.0, 1.0, None), + 'multiply': (-1.0, 1.0, lambda y, x1, x2: _np.broadcast_to(x2, y.shape)), + 'power': (1.0, 5.0, lambda y, x1, x2: _np.power(x1, x2 - 1.0) * x2), + } + + shapes = [(3, 2), (3, 0), (3, 1), (0, 2), (2, 3, 4)] + ltypes = [np.int32, np.int64, np.float16, np.float32, np.float64] + for func, func_data in funcs.items(): + low, high, lgrad = func_data + for shape in shapes: + for ltype in ltypes: + check_binary_scalar_func(func, low, high, shape, lgrad, ltype) + + @with_seed() @use_np def test_np_boolean_binary_funcs():