diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index 58025e26818b..120745cd9e52 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -6350,7 +6350,7 @@ def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None, **kwargs): @set_module('mxnet.ndarray.numpy') -def where(condition, x=None, y=None): +def where(condition, x=None, y=None): # pylint: disable=too-many-return-statements """where(condition, [x, y]) Return elements chosen from `x` or `y` depending on `condition`. @@ -6380,6 +6380,14 @@ def where(condition, x=None, y=None): [xv if c else yv for c, xv, yv in zip(condition, x, y)] + This function differs from the original `numpy.where + `_ in + the following way(s): + + - If `condition` is a scalar, this operator returns x or y directly without broadcasting. + - If `condition` is ndarray, while both `x` and `y` are scalars, + the output dtype will be `float32`. + Examples -------- >>> a = np.arange(10) @@ -6410,7 +6418,7 @@ def where(condition, x=None, y=None): >>> a = np.array([[0, 1, 2], ... [0, 2, 4], ... [0, 3, 6]]) - >>> np.where(a < 4, a, np.array(-1)) # -1 is broadcast + >>> np.where(a < 4, a, -1) # -1 is broadcast array([[ 0., 1., 2.], [ 0., 2., -1.], [ 0., 3., -1.]]) @@ -6418,7 +6426,22 @@ def where(condition, x=None, y=None): if x is None and y is None: return nonzero(condition) else: - return _npi.where(condition, x, y, out=None) + if isinstance(condition, numeric_types): + if condition != 0: + return x + else: + return y + else: + if isinstance(x, numeric_types) and isinstance(y, numeric_types): + return _npi.where_scalar2(condition, float(x), float(y), out=None) + elif isinstance(x, NDArray) and isinstance(y, NDArray): + return _npi.where(condition, x, y, out=None) + elif isinstance(y, NDArray): + return _npi.where_lscalar(condition, y, float(x), out=None) + elif isinstance(x, NDArray): + return _npi.where_rscalar(condition, x, float(y), out=None) + else: + raise TypeError('type {0} and {1} not supported'.format(str(type(x)), str(type(y)))) @set_module('mxnet.ndarray.numpy') diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 761697135cd1..88ecddfeaf09 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -8495,7 +8495,7 @@ def where(condition, x=None, y=None): >>> a = np.array([[0, 1, 2], ... [0, 2, 4], ... [0, 3, 6]]) - >>> np.where(a < 4, a, np.array(-1)) # -1 is broadcast + >>> np.where(a < 4, a, -1) # -1 is broadcast array([[ 0., 1., 2.], [ 0., 2., -1.], [ 0., 3., -1.]]) diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index 0e4571968c5d..e608062ac028 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -5772,7 +5772,22 @@ def where(condition, x, y): from `y` elsewhere. """ - return _npi.where(condition, x, y, out=None) + if isinstance(condition, numeric_types): + if condition != 0: + return x + else: + return y + else: + if isinstance(x, numeric_types) and isinstance(y, numeric_types): + return _npi.where_scalar2(condition, float(x), float(y), out=None) + elif isinstance(x, Symbol) and isinstance(y, Symbol): + return _npi.where(condition, x, y, out=None) + elif isinstance(y, Symbol): + return _npi.where_lscalar(condition, y, float(x), out=None) + elif isinstance(x, Symbol): + return _npi.where_rscalar(condition, x, float(y), out=None) + else: + raise TypeError('type {0} and {1} not supported'.format(str(type(x)), str(type(y)))) @set_module('mxnet.symbol.numpy') diff --git a/src/operator/numpy/np_where_op-inl.h b/src/operator/numpy/np_where_op-inl.h index c019a523fddd..872ff18bfd02 100644 --- a/src/operator/numpy/np_where_op-inl.h +++ b/src/operator/numpy/np_where_op-inl.h @@ -42,6 +42,27 @@ namespace op { using namespace mshadow; +struct NumpyWhereScalarParam : public dmlc::Parameter { + double scalar; + DMLC_DECLARE_PARAMETER(NumpyWhereScalarParam) { + DMLC_DECLARE_FIELD(scalar) + .set_default(0.0) + .describe("The scalar value of x/y."); + } +}; + +struct NumpyWhereScalar2Param : public dmlc::Parameter { + double x, y; + DMLC_DECLARE_PARAMETER(NumpyWhereScalar2Param) { + DMLC_DECLARE_FIELD(x) + .set_default(0.0) + .describe("The scalar value of x."); + DMLC_DECLARE_FIELD(y) + .set_default(0.0) + .describe("The scalar value of y."); + } +}; + template struct numpy_where_kernel { template @@ -73,6 +94,31 @@ struct numpy_where_backward_kernel { } }; +template +struct numpy_where_scalar_kernel { + template + MSHADOW_XINLINE static void Map(index_t base, OpReqType req, const Shape &cstride, + const Shape &ystride, const Shape &oshape, + CType *datac, DType datax, DType *datay, DType *out) { + Shape coord = mxnet_op::unravel(base, oshape); + auto cidx = static_cast(mxnet_op::dot(coord, cstride)); + auto yidx = static_cast(mxnet_op::dot(coord, ystride)); + if (is_left) { + KERNEL_ASSIGN(out[base], req, datac[cidx] != CType(0) ? datax : datay[yidx]); + } else { + KERNEL_ASSIGN(out[base], req, datac[cidx] != CType(0) ? datay[yidx] : datax); + } + } +}; + +struct numpy_where_scalar2_kernel { + template + MSHADOW_XINLINE static void Map(index_t i, OpReqType req, DType* out, const CType* cond, + const DType x, const DType y) { + KERNEL_ASSIGN(out[i], req, (CType(0) != cond[i]? x : y)); + } +}; + template inline void NumpyWhereOpForward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -166,7 +212,7 @@ inline void NumpyWhereOpBackward(const nnvm::NodeAttrs& attrs, Tensor largespace; Tensor workspace; size_t ws_size = 0; - if (!(ograd.shape_ != dx.shape_) || !(ograd.shape_ != dy.shape_)) { + if (ograd.shape_ != dx.shape_ || ograd.shape_ != dy.shape_) { size_t ws_size1 = broadcast::ReduceWorkspaceSize( s, expanded_lshape, req[0], expanded_oshape); size_t ws_size2 = broadcast::ReduceWorkspaceSize( @@ -221,6 +267,144 @@ inline void NumpyWhereOpBackward(const nnvm::NodeAttrs& attrs, }); } +template +inline void NumpyWhereScalarOpForward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + if (outputs[0].shape_.Size() == 0U) return; // zero-size tensor + CHECK_LE(outputs[0].shape_.ndim(), broadcast::MAX_DIM); + + const NumpyWhereScalarParam& param = nnvm::get(attrs.parsed); + const TBlob& cond = inputs[0]; + const TBlob& y = inputs[1]; + const TBlob& out = outputs[0]; + Stream *s = ctx.get_stream(); + std::vector> in_strides; + in_strides.resize(2); + for (int i = 0; i < 2; ++i) { + TShape expanded_ishape(broadcast::MAX_DIM, 1); + const TShape& ishape = inputs[i].shape_; + const int ndim_delta = expanded_ishape.ndim() - ishape.ndim(); + for (int j = 0; j < ishape.ndim(); ++j) { + expanded_ishape[j + ndim_delta] = ishape[j]; + } + in_strides[i] = mxnet_op::calc_stride(expanded_ishape.get()); + } + TShape expanded_oshape(broadcast::MAX_DIM, 1); + const int ndim_delta = expanded_oshape.ndim() - out.shape_.ndim(); + for (int j = 0; j < out.shape_.ndim(); ++j) { + expanded_oshape[j + ndim_delta] = (out.shape_)[j]; + } + Shape oshape = expanded_oshape.get(); + MSHADOW_TYPE_SWITCH_WITH_BOOL(out.type_flag_, DType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(cond.type_flag_, CType, { + mxnet_op::Kernel, xpu>::Launch( + s, out.Size(), req[0], + in_strides[0], in_strides[1], oshape, + cond.dptr(), DType(param.scalar), + y.dptr(), out.dptr()); + }); + }); +} + +template +inline void NumpyWhereScalarOpBackward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + CHECK(common::is_float(inputs[0].type_flag_)) << "Backward only supports float types!"; + if (inputs[0].shape_.Size() == 0U) return; // zero-size tensor + + Stream *s = ctx.get_stream(); + const TBlob& ograd = inputs[0]; + const TBlob& cond = inputs[1]; + const TBlob& dx = outputs[0]; + // get expanded oshape + TShape expanded_oshape(broadcast::MAX_DIM, 1); + int ndim_delta = expanded_oshape.ndim() - ograd.shape_.ndim(); + for (int j = 0; j < ograd.shape_.ndim(); ++j) { + expanded_oshape[j + ndim_delta] = (ograd.shape_)[j]; + } + Shape oshape = expanded_oshape.get(); + // get cond stride + TShape expanded_cshape(broadcast::MAX_DIM, 1); + ndim_delta = expanded_cshape.ndim() - cond.shape_.ndim(); + for (int j = 0; j < cond.shape_.ndim(); ++j) { + expanded_cshape[j + ndim_delta] = (cond.shape_)[j]; + } + Shape cstride = + mxnet_op::calc_stride(expanded_cshape.get()); + // get expanded lshape + TShape expanded_lshape(broadcast::MAX_DIM, 1); + ndim_delta = expanded_lshape.ndim() - dx.shape_.ndim(); + for (int j = 0; j < dx.shape_.ndim(); ++j) { + expanded_lshape[j + ndim_delta] = (dx.shape_)[j]; + } + + MSHADOW_TYPE_SWITCH_WITH_BOOL(ograd.type_flag_, DType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(cond.type_flag_, CType, { + Tensor largespace; + Tensor workspace; + size_t ws_size = 0; + if (ograd.shape_ != dx.shape_) { + ws_size = broadcast::ReduceWorkspaceSize( + s, expanded_lshape, req[0], expanded_oshape); + } + // process left output + if (ograd.shape_ == dx.shape_) { + mxnet_op::Kernel, xpu>::Launch( + s, ograd.Size(), req[0], cstride, oshape, + cond.dptr(), ograd.dptr(), dx.dptr()); + } else { + largespace = ctx.requested[0].get_space_typed( + Shape1(ograd.shape_.Size() * sizeof(DType) + ws_size), s); + workspace = Tensor( + reinterpret_cast(largespace.dptr_ + ws_size), + expanded_oshape.get(), s); + mxnet_op::Kernel, xpu>::Launch( + s, ograd.Size(), req[0], cstride, oshape, + cond.dptr(), ograd.dptr(), workspace.dptr_); + if (NeedSafeAcc(dx.type_flag_, dx.type_flag_)) { + ReduceAxesComputeImpl(ctx, {TBlob(workspace)}, {req[0]}, + {dx.reshape(expanded_lshape)}, expanded_lshape); + } else { + ReduceAxesComputeImpl(ctx, {TBlob(workspace)}, {req[0]}, + {dx.reshape(expanded_lshape)}, expanded_lshape); + } + } + }); + }); +} + +template +inline void NumpyWhereScalar2OpForward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 1U); + if (outputs[0].shape_.Size() == 0U) return; // zero-size tensor + using namespace mxnet_op; + mshadow::Stream *s = ctx.get_stream(); + const NumpyWhereScalar2Param& param = nnvm::get(attrs.parsed); + const TBlob& cond = inputs[0]; + const TBlob& out = outputs[0]; + MSHADOW_TYPE_SWITCH_WITH_BOOL(out.type_flag_, DType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(cond.type_flag_, CType, { + Kernel::Launch(s, out.Size(), req[0], + out.dptr(), cond.dptr(), DType(param.x), DType(param.y)); + }); + }); +} + } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_where_op.cc b/src/operator/numpy/np_where_op.cc index 6cca0c5fd985..f703ee6ddf1a 100644 --- a/src/operator/numpy/np_where_op.cc +++ b/src/operator/numpy/np_where_op.cc @@ -24,6 +24,7 @@ */ #include "np_where_op-inl.h" +#include "../tensor/elemwise_binary_broadcast_op.h" namespace mxnet { namespace op { @@ -74,6 +75,19 @@ inline bool NumpyWhereOpType(const nnvm::NodeAttrs& attrs, return flag && (in_attrs->at(0) != -1); } +inline bool NumpyWhereScalarOpType(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + std::vector sub_in_attrs(in_attrs->begin() + 1, in_attrs->end()); + bool flag = ElemwiseType<1, 1>(attrs, &sub_in_attrs, out_attrs); + return flag && (in_attrs->at(0) != -1); +} + +DMLC_REGISTER_PARAMETER(NumpyWhereScalarParam); +DMLC_REGISTER_PARAMETER(NumpyWhereScalar2Param); + NNVM_REGISTER_OP(_npi_where) .set_num_inputs(3) .set_num_outputs(1) @@ -129,5 +143,137 @@ NNVM_REGISTER_OP(_backward_np_where) return std::vector{ResourceRequest::kTempSpace}; }); +NNVM_REGISTER_OP(_npi_where_lscalar) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"condition", "x"}; + }) +.set_attr("FInferShape", BinaryBroadcastShape) +.set_attr("FInferType", NumpyWhereScalarOpType) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs){ + return std::vector >{{1, 0}}; + }) +.set_attr("FCompute", NumpyWhereScalarOpForward) +.set_attr("FGradient", + // Use the following lambda function instead of ElemwiseGradUseIn + // for best efficiency. grad[condition] = 0; to calculate grad[x] or grad[y] + // we need only condition from input. + [](const nnvm::NodePtr& n, const std::vector& ograds) { + std::vector ret; + // make zero grad node for grad[condition] + auto p = MakeNode("zeros_like", n->attrs.name + "_cond_backward", + {n->inputs[0]}, nullptr, &n); + ret.emplace_back(p); + + // make grad nodes for grad[x] and grad[y] + std::vector heads(ograds.begin(), ograds.end()); + heads.push_back(n->inputs[0]); // only need condition to calculate gradients + p = nnvm::Node::Create(); + p->attrs.op = nnvm::Op::Get("_backward_np_where_lscalar"); + p->attrs.name = n->attrs.name + "_backward"; + p->attrs.dict = n->attrs.dict; + if (p->op()->attr_parser != nullptr) { + p->op()->attr_parser(&(p->attrs)); + } + p->control_deps.emplace_back(n); + p->inputs = std::move(heads); + ret.emplace_back(p, 0, 0); + return ret; + }) +.add_argument("condition", "NDArray-or-Symbol", "condition array") +.add_argument("x", "NDArray-or-Symbol", "input x") +.add_arguments(NumpyWhereScalarParam::__FIELDS__()); + +NNVM_REGISTER_OP(_npi_where_rscalar) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"condition", "y"}; + }) +.set_attr("FInferShape", BinaryBroadcastShape) +.set_attr("FInferType", NumpyWhereScalarOpType) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs){ + return std::vector >{{1, 0}}; + }) +.set_attr("FCompute", NumpyWhereScalarOpForward) +.set_attr("FGradient", + // Use the following lambda function instead of ElemwiseGradUseIn + // for best efficiency. grad[condition] = 0; to calculate grad[x] or grad[y] + // we need only condition from input. + [](const nnvm::NodePtr& n, const std::vector& ograds) { + std::vector ret; + // make zero grad node for grad[condition] + auto p = MakeNode("zeros_like", n->attrs.name + "_cond_backward", + {n->inputs[0]}, nullptr, &n); + ret.emplace_back(p); + + // make grad nodes for grad[x] and grad[y] + std::vector heads(ograds.begin(), ograds.end()); + heads.push_back(n->inputs[0]); // only need condition to calculate gradients + p = nnvm::Node::Create(); + p->attrs.op = nnvm::Op::Get("_backward_np_where_rscalar"); + p->attrs.name = n->attrs.name + "_backward"; + p->attrs.dict = n->attrs.dict; + if (p->op()->attr_parser != nullptr) { + p->op()->attr_parser(&(p->attrs)); + } + p->control_deps.emplace_back(n); + p->inputs = std::move(heads); + ret.emplace_back(p, 0, 0); + return ret; + }) +.add_argument("condition", "NDArray-or-Symbol", "condition array") +.add_argument("y", "NDArray-or-Symbol", "input y") +.add_arguments(NumpyWhereScalarParam::__FIELDS__()); + +NNVM_REGISTER_OP(_backward_np_where_lscalar) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr("TIsBackward", true) +.set_attr("FCompute", NumpyWhereScalarOpBackward) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }); + +NNVM_REGISTER_OP(_backward_np_where_rscalar) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr("TIsBackward", true) +.set_attr("FCompute", NumpyWhereScalarOpBackward) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }); + +NNVM_REGISTER_OP(_npi_where_scalar2) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"condition"}; + }) +.set_attr("FInferShape", ElemwiseShape<1, 1>) +.set_attr("FInferType", [](const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs){ + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kFloat32); + return in_attrs->at(0) != -1; + }) +.set_attr("FCompute", NumpyWhereScalar2OpForward) +.set_attr("FGradient", MakeZeroGradNodes) +.add_argument("condition", "NDArray-or-Symbol", "condition array") +.add_arguments(NumpyWhereScalar2Param::__FIELDS__()); + } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_where_op.cu b/src/operator/numpy/np_where_op.cu index 6d3da4477112..f914292cdbcc 100644 --- a/src/operator/numpy/np_where_op.cu +++ b/src/operator/numpy/np_where_op.cu @@ -34,5 +34,20 @@ NNVM_REGISTER_OP(_npi_where) NNVM_REGISTER_OP(_backward_np_where) .set_attr("FCompute", NumpyWhereOpBackward); +NNVM_REGISTER_OP(_npi_where_lscalar) +.set_attr("FCompute", NumpyWhereScalarOpForward); + +NNVM_REGISTER_OP(_npi_where_rscalar) +.set_attr("FCompute", NumpyWhereScalarOpForward); + +NNVM_REGISTER_OP(_backward_np_where_lscalar) +.set_attr("FCompute", NumpyWhereScalarOpBackward); + +NNVM_REGISTER_OP(_backward_np_where_rscalar) +.set_attr("FCompute", NumpyWhereScalarOpBackward); + +NNVM_REGISTER_OP(_npi_where_scalar2) +.set_attr("FCompute", NumpyWhereScalar2OpForward); + } // namespace op } // namespace mxnet