From df8a3ff8540d48c5bb7b8c987461869fefa1d739 Mon Sep 17 00:00:00 2001
From: Minghao Liu
Date: Mon, 11 May 2020 08:09:23 +0000
Subject: [PATCH] fix binary scalar logic dtype (#16964)

---
 .../numpy/np_elemwise_broadcast_logic_op.cc |  9 ++++++---
 .../tensor/elemwise_binary_scalar_op.h      | 19 ++++++++++++++-----
 2 files changed, 20 insertions(+), 8 deletions(-)

diff --git a/src/operator/numpy/np_elemwise_broadcast_logic_op.cc b/src/operator/numpy/np_elemwise_broadcast_logic_op.cc
index 970c5cd44e4d..0f5d9923a3e2 100644
--- a/src/operator/numpy/np_elemwise_broadcast_logic_op.cc
+++ b/src/operator/numpy/np_elemwise_broadcast_logic_op.cc
@@ -317,9 +317,12 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(logical_xor);
 
 #else
 
-#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_CPU(name)                                  \
-  NNVM_REGISTER_OP(_npi_##name##_scalar)                                                          \
-  .set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::ComputeLogic<cpu, mshadow_op::np_##name>)
+#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_CPU(name)                                  \
+  NNVM_REGISTER_OP(_npi_##name##_scalar)                                                          \
+  .set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::ComputeLogic<cpu, mshadow_op::np_##name>)  \
+  .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {                        \
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};                             \
+  })
 
 #endif  // MXNET_USE_TVM_OP
 
diff --git a/src/operator/tensor/elemwise_binary_scalar_op.h b/src/operator/tensor/elemwise_binary_scalar_op.h
index 3e8702813a7c..ad15593031b9 100644
--- a/src/operator/tensor/elemwise_binary_scalar_op.h
+++ b/src/operator/tensor/elemwise_binary_scalar_op.h
@@ -276,11 +276,20 @@ class BinaryScalarOp : public UnaryOp {
     using namespace mshadow::expr;
     Stream<xpu> *s = ctx.get_stream<xpu>();
     const double alpha = nnvm::get<double>(attrs.parsed);
-    MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, {
-      MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
-        mxnet_op::Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(
-            s, inputs[0].Size(), outputs[0].dptr<bool>(), inputs[0].dptr<DType>(), DType(alpha));
-      });
+    TBlob temp_tblob;
+    if (common::is_int(inputs[0].type_flag_)) {
+      Tensor<xpu, 1, double> temp_tensor =
+          ctx.requested[0].get_space_typed<xpu, 1, double>(Shape1(inputs[0].Size()), s);
+      temp_tblob = TBlob(temp_tensor);
+      CastCompute<xpu>(attrs, ctx, {inputs[0]}, {kWriteTo}, {temp_tblob});
+    } else {
+      temp_tblob = inputs[0];
+    }
+    MSHADOW_TYPE_SWITCH_WITH_BOOL(temp_tblob.type_flag_, DType, {
+      MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+        mxnet_op::Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(
+            s, inputs[0].Size(), outputs[0].dptr<bool>(), temp_tblob.dptr<DType>(), DType(alpha));
+      });
     });
   }
 
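Reviewer note, not part of the patch: the change casts integer inputs to double through a temporary workspace (hence the new kTempSpace resource request) before launching the logic kernel, so a fractional scalar is no longer truncated by DType(alpha) when DType is an integer type. The standalone C++ sketch below is only an illustration of that dtype issue under that reading of the diff; the values are hypothetical and no MXNet code is used.

// Minimal, self-contained illustration of the dtype bug the patch addresses.
// (Hypothetical example values; not taken from the MXNet sources.)
#include <cstdint>
#include <cstdio>

int main() {
  const int32_t x = 1;       // integer tensor element
  const double alpha = 1.5;  // scalar operand parsed as double

  // Old behaviour: the scalar is converted to the input's integer dtype,
  // so DType(alpha) truncates 1.5 to 1 and the equality test is wrong.
  const bool old_result = (x == static_cast<int32_t>(alpha));  // true (incorrect)

  // New behaviour: the integer input is cast to double first (the temp_tblob
  // path), and the comparison happens in floating point.
  const bool new_result = (static_cast<double>(x) == alpha);   // false (correct)

  std::printf("old: %d, new: %d\n", old_result, new_result);
  return 0;
}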