Skip to content

Commit

Permalink
fix binary scalar logic dtype (apache#16964)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tommliu committed May 18, 2020
1 parent 8bfde6a commit df8a3ff
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 8 deletions.
9 changes: 6 additions & 3 deletions src/operator/numpy/np_elemwise_broadcast_logic_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -317,9 +317,12 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(logical_xor);

#else

#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_CPU(name) \
NNVM_REGISTER_OP(_npi_##name##_scalar) \
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::ComputeLogic<cpu, mshadow_op::np_##name>)
#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_CPU(name) \
NNVM_REGISTER_OP(_npi_##name##_scalar) \
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::ComputeLogic<cpu, mshadow_op::np_##name>) \
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) { \
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; \
})

#endif // MXNET_USE_TVM_OP

Expand Down
19 changes: 14 additions & 5 deletions src/operator/tensor/elemwise_binary_scalar_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -276,11 +276,20 @@ class BinaryScalarOp : public UnaryOp {
using namespace mshadow::expr;
Stream<xpu> *s = ctx.get_stream<xpu>();
const double alpha = nnvm::get<double>(attrs.parsed);
MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
mxnet_op::Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(
s, inputs[0].Size(), outputs[0].dptr<bool>(), inputs[0].dptr<DType>(), DType(alpha));
});
TBlob temp_tblob;
if (common::is_int(inputs[0].type_flag_)) {
Tensor<xpu, 1, double> temp_tensor =
ctx.requested[0].get_space_typed<xpu, 1, double>(Shape1(inputs[0].Size()), s);
temp_tblob = TBlob(temp_tensor);
CastCompute<xpu>(attrs, ctx, {inputs[0]}, {kWriteTo}, {temp_tblob});
} else {
temp_tblob = inputs[0];
}
MSHADOW_TYPE_SWITCH_WITH_BOOL(temp_tblob.type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
mxnet_op::Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(
s, inputs[0].Size(), outputs[0].dptr<bool>(), temp_tblob.dptr<DType>(), DType(alpha));
});
});
}

Expand Down

0 comments on commit df8a3ff

Please sign in to comment.