From bb2896720cd7648f115b8928b9ed06ba4efd70af Mon Sep 17 00:00:00 2001 From: "Huang, Guangtai" Date: Sat, 16 Nov 2019 13:41:26 +0800 Subject: [PATCH] fix according to reviews --- python/mxnet/ndarray/numpy/_op.py | 14 +++- python/mxnet/numpy/multiarray.py | 9 ++- src/operator/numpy/np_where_op-inl.h | 98 +++++++--------------------- src/operator/numpy/np_where_op.cc | 46 +++++++++++++ 4 files changed, 89 insertions(+), 78 deletions(-) diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index 5b516efd4218..67e5d21d1c84 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -5311,10 +5311,15 @@ def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None, **kwargs): @set_module('mxnet.ndarray.numpy') -def where(condition, x, y): - """ +def where(condition, x=None, y=None): + """where(condition, [x, y]) Return elements chosen from `x` or `y` depending on `condition`. + .. note:: + When only `condition` is provided, this function is a shorthand for + ``np.asarray(condition).nonzero()``. The rest of this documentation + covers only the case where all three arguments are provided. + Parameters ---------- condition : ndarray @@ -5371,4 +5376,7 @@ def where(condition, x, y): [ 0., 2., -1.], [ 0., 3., -1.]]) """ - return _npi.where(condition, x, y, out=None) + if x is None and y is None: + return nonzero(condition) + else: + return _npi.where(condition, x, y, out=None) diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 3e68c7a025e7..3969225b4ed8 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -7298,10 +7298,15 @@ def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None, **kwargs): @set_module('mxnet.numpy') -def where(condition, x, y): - """ +def where(condition, x=None, y=None): + """where(condition, [x, y]) Return elements chosen from `x` or `y` depending on `condition`. + .. note:: + When only `condition` is provided, this function is a shorthand for + ``np.asarray(condition).nonzero()``. The rest of this documentation + covers only the case where all three arguments are provided. + Parameters ---------- condition : ndarray diff --git a/src/operator/numpy/np_where_op-inl.h b/src/operator/numpy/np_where_op-inl.h index f9ac783724a7..84e6baa98f8d 100644 --- a/src/operator/numpy/np_where_op-inl.h +++ b/src/operator/numpy/np_where_op-inl.h @@ -40,8 +40,6 @@ namespace mxnet { namespace op { -#define NUMPY_WHERE_MAX_DIM 5 - using namespace mshadow; template @@ -75,52 +73,6 @@ struct numpy_where_backward_kernel { } }; -inline bool NumpyWhereOpShape(const nnvm::NodeAttrs& attrs, - mxnet::ShapeVector* in_attrs, - mxnet::ShapeVector* out_attrs) { - CHECK_EQ(in_attrs->size(), 3U); - CHECK_EQ(out_attrs->size(), 1U); - mxnet::TShape& operand1 = (*in_attrs)[0]; - mxnet::TShape& operand2 = (*in_attrs)[1]; - mxnet::TShape& operand3 = (*in_attrs)[2]; - - if (operand1 == operand2 && operand2 == operand3) { - SHAPE_ASSIGN_CHECK(*out_attrs, 0, operand1); - return shape_is_known(out_attrs->at(0)); - } - mxnet::TShape out(std::max({operand1.ndim(), operand2.ndim(), operand3.ndim()}), -1); - const int b1 = out.ndim() - operand1.ndim(); - const int b2 = out.ndim() - operand2.ndim(); - const int b3 = out.ndim() - operand3.ndim(); - for (int i = 0; i < out.ndim(); ++i) { - int s1 = 1, s2 = 1, s3 = 1; - if (i >= b1) s1 = operand1[i-b1]; - if (i >= b2) s2 = operand2[i-b2]; - if (i >= b3) s3 = operand3[i-b3]; - if (!(s1 == s2 && s2 == s3)) { - CHECK((s1 == 1 && s2 == 1) || (s1 == 1 && s3 == 1) || (s2 == 1 && s3 == 1) || - (s1 == 1 && s2 == s3) || (s2 == 1 && s1 == s3) || (s3 == 1 && s1 == s2)) - << "Operands could not be broadcast together."; - out[i] = std::max({s1, s2, s3}); - } else { - out[i] = s1; - } - } - SHAPE_ASSIGN_CHECK(*out_attrs, 0, out); - return shape_is_known(out); -} - -inline bool NumpyWhereOpType(const nnvm::NodeAttrs& attrs, - std::vector* in_attrs, - std::vector* out_attrs) { - CHECK_EQ(in_attrs->size(), 3U) - << "where operator takes 3 arguments (" << in_attrs->size() << " given)"; - CHECK_EQ(out_attrs->size(), 1U); - CHECK_EQ(in_attrs->at(1), in_attrs->at(2)); - TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(1)); - return (out_attrs->at(0) != -1); -} - template inline void NumpyWhereOpForward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -130,29 +82,29 @@ inline void NumpyWhereOpForward(const nnvm::NodeAttrs& attrs, CHECK_EQ(inputs.size(), 3U); CHECK_EQ(outputs.size(), 1U); if (outputs[0].shape_.Size() == 0U) return; // zero-size tensor - CHECK_LE(outputs[0].shape_.ndim(), NUMPY_WHERE_MAX_DIM); + CHECK_LE(outputs[0].shape_.ndim(), broadcast::MAX_DIM); Stream *s = ctx.get_stream(); - std::vector> in_strides; + std::vector> in_strides; in_strides.resize(3); for (int i = 0; i < 3; ++i) { - TShape expanded_ishape(NUMPY_WHERE_MAX_DIM, 1); + TShape expanded_ishape(broadcast::MAX_DIM, 1); const TShape& ishape = inputs[i].shape_; const int ndim_delta = expanded_ishape.ndim() - ishape.ndim(); for (int j = 0; j < ishape.ndim(); ++j) { expanded_ishape[j + ndim_delta] = ishape[j]; } - in_strides[i] = mxnet_op::calc_stride(expanded_ishape.get()); + in_strides[i] = mxnet_op::calc_stride(expanded_ishape.get()); } - TShape expanded_oshape(NUMPY_WHERE_MAX_DIM, 1); + TShape expanded_oshape(broadcast::MAX_DIM, 1); const int ndim_delta = expanded_oshape.ndim() - outputs[0].shape_.ndim(); for (int j = 0; j < outputs[0].shape_.ndim(); ++j) { expanded_oshape[j + ndim_delta] = (outputs[0].shape_)[j]; } - Shape oshape = expanded_oshape.get(); + Shape oshape = expanded_oshape.get(); MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, DType, { MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, CType, { - mxnet_op::Kernel, xpu>::Launch( + mxnet_op::Kernel, xpu>::Launch( s, outputs[0].Size(), req[0], in_strides[0], in_strides[1], in_strides[2], oshape, inputs[0].dptr(), inputs[1].dptr(), @@ -173,28 +125,28 @@ inline void NumpyWhereOpBackward(const nnvm::NodeAttrs& attrs, if (inputs[0].shape_.Size() == 0U) return; // zero-size tensor Stream *s = ctx.get_stream(); // get expanded oshape - TShape expanded_oshape(NUMPY_WHERE_MAX_DIM, 1); + TShape expanded_oshape(broadcast::MAX_DIM, 1); int ndim_delta = expanded_oshape.ndim() - inputs[0].shape_.ndim(); for (int j = 0; j < inputs[0].shape_.ndim(); ++j) { expanded_oshape[j + ndim_delta] = (inputs[0].shape_)[j]; } - Shape oshape = expanded_oshape.get(); + Shape oshape = expanded_oshape.get(); // get cond stride - TShape expanded_cshape(NUMPY_WHERE_MAX_DIM, 1); + TShape expanded_cshape(broadcast::MAX_DIM, 1); ndim_delta = expanded_cshape.ndim() - inputs[1].shape_.ndim(); for (int j = 0; j < inputs[1].shape_.ndim(); ++j) { expanded_cshape[j + ndim_delta] = (inputs[1].shape_)[j]; } - Shape cstride = - mxnet_op::calc_stride(expanded_cshape.get()); + Shape cstride = + mxnet_op::calc_stride(expanded_cshape.get()); // get expanded lshape - TShape expanded_lshape(NUMPY_WHERE_MAX_DIM, 1); + TShape expanded_lshape(broadcast::MAX_DIM, 1); ndim_delta = expanded_lshape.ndim() - outputs[0].shape_.ndim(); for (int j = 0; j < outputs[0].shape_.ndim(); ++j) { expanded_lshape[j + ndim_delta] = (outputs[0].shape_)[j]; } // get expanded rshape - TShape expanded_rshape(NUMPY_WHERE_MAX_DIM, 1); + TShape expanded_rshape(broadcast::MAX_DIM, 1); ndim_delta = expanded_rshape.ndim() - outputs[1].shape_.ndim(); for (int j = 0; j < outputs[1].shape_.ndim(); ++j) { expanded_rshape[j + ndim_delta] = (outputs[1].shape_)[j]; @@ -203,27 +155,27 @@ inline void NumpyWhereOpBackward(const nnvm::NodeAttrs& attrs, MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, { MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[1].type_flag_, CType, { Tensor largespace; - Tensor workspace; + Tensor workspace; size_t ws_size = 0; if (!(inputs[0].shape_ != outputs[0].shape_) || !(inputs[0].shape_ != outputs[1].shape_)) { - size_t ws_size1 = broadcast::ReduceWorkspaceSize( + size_t ws_size1 = broadcast::ReduceWorkspaceSize( s, expanded_lshape, req[0], expanded_oshape); - size_t ws_size2 = broadcast::ReduceWorkspaceSize( + size_t ws_size2 = broadcast::ReduceWorkspaceSize( s, expanded_rshape, req[1], expanded_oshape); ws_size = std::max(ws_size1, ws_size2); } // process left output if (inputs[0].shape_ == outputs[0].shape_) { - mxnet_op::Kernel, xpu>::Launch( + mxnet_op::Kernel, xpu>::Launch( s, inputs[0].Size(), req[0], cstride, oshape, inputs[1].dptr(), inputs[0].dptr(), outputs[0].dptr()); } else { largespace = ctx.requested[0].get_space_typed( Shape1(inputs[0].shape_.Size() * sizeof(DType) + ws_size), s); - workspace = Tensor( + workspace = Tensor( reinterpret_cast(largespace.dptr_ + ws_size), - expanded_oshape.get(), s); - mxnet_op::Kernel, xpu>::Launch( + expanded_oshape.get(), s); + mxnet_op::Kernel, xpu>::Launch( s, inputs[0].Size(), req[0], cstride, oshape, inputs[1].dptr(), inputs[0].dptr(), workspace.dptr_); if (NeedSafeAcc(outputs[0].type_flag_, outputs[0].type_flag_)) { @@ -236,16 +188,16 @@ inline void NumpyWhereOpBackward(const nnvm::NodeAttrs& attrs, } // process right output if (inputs[0].shape_ == outputs[1].shape_) { - mxnet_op::Kernel, xpu>::Launch( + mxnet_op::Kernel, xpu>::Launch( s, inputs[0].Size(), req[1], cstride, oshape, inputs[1].dptr(), inputs[0].dptr(), outputs[1].dptr()); } else { largespace = ctx.requested[0].get_space_typed( Shape1(inputs[0].shape_.Size() * sizeof(DType) + ws_size), s); - workspace = Tensor( + workspace = Tensor( reinterpret_cast(largespace.dptr_ + ws_size), - expanded_oshape.get(), s); - mxnet_op::Kernel, xpu>::Launch( + expanded_oshape.get(), s); + mxnet_op::Kernel, xpu>::Launch( s, inputs[0].Size(), req[1], cstride, oshape, inputs[1].dptr(), inputs[0].dptr(), workspace.dptr_); if (NeedSafeAcc(outputs[1].type_flag_, outputs[1].type_flag_)) { diff --git a/src/operator/numpy/np_where_op.cc b/src/operator/numpy/np_where_op.cc index 1cd04dfc2ee6..6cca0c5fd985 100644 --- a/src/operator/numpy/np_where_op.cc +++ b/src/operator/numpy/np_where_op.cc @@ -28,6 +28,52 @@ namespace mxnet { namespace op { +inline bool NumpyWhereOpShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector* in_attrs, + mxnet::ShapeVector* out_attrs) { + CHECK_EQ(in_attrs->size(), 3U); + CHECK_EQ(out_attrs->size(), 1U); + mxnet::TShape& operand1 = (*in_attrs)[0]; + mxnet::TShape& operand2 = (*in_attrs)[1]; + mxnet::TShape& operand3 = (*in_attrs)[2]; + + if (operand1 == operand2 && operand2 == operand3) { + SHAPE_ASSIGN_CHECK(*out_attrs, 0, operand1); + return shape_is_known(out_attrs->at(0)); + } + mxnet::TShape out(std::max({operand1.ndim(), operand2.ndim(), operand3.ndim()}), -1); + const int b1 = out.ndim() - operand1.ndim(); + const int b2 = out.ndim() - operand2.ndim(); + const int b3 = out.ndim() - operand3.ndim(); + for (int i = 0; i < out.ndim(); ++i) { + int s1 = 1, s2 = 1, s3 = 1; + if (i >= b1) s1 = operand1[i-b1]; + if (i >= b2) s2 = operand2[i-b2]; + if (i >= b3) s3 = operand3[i-b3]; + if (!(s1 == s2 && s2 == s3)) { + CHECK((s1 == 1 && s2 == 1) || (s1 == 1 && s3 == 1) || (s2 == 1 && s3 == 1) || + (s1 == 1 && s2 == s3) || (s2 == 1 && s1 == s3) || (s3 == 1 && s1 == s2)) + << "Operands could not be broadcast together."; + out[i] = std::max({s1, s2, s3}); + } else { + out[i] = s1; + } + } + SHAPE_ASSIGN_CHECK(*out_attrs, 0, out); + return shape_is_known(out); +} + +inline bool NumpyWhereOpType(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 3U) + << "where operator takes 3 arguments (" << in_attrs->size() << " given)"; + CHECK_EQ(out_attrs->size(), 1U); + std::vector sub_in_attrs(in_attrs->begin() + 1, in_attrs->end()); + bool flag = ElemwiseType<2, 1>(attrs, &sub_in_attrs, out_attrs); + return flag && (in_attrs->at(0) != -1); +} + NNVM_REGISTER_OP(_npi_where) .set_num_inputs(3) .set_num_outputs(1)