Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix according to reviews
Browse files Browse the repository at this point in the history
  • Loading branch information
hgt312 committed Nov 16, 2019
1 parent 3df4094 commit bb28967
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 78 deletions.
14 changes: 11 additions & 3 deletions python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -5311,10 +5311,15 @@ def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None, **kwargs):


@set_module('mxnet.ndarray.numpy')
def where(condition, x, y):
"""
def where(condition, x=None, y=None):
"""where(condition, [x, y])
Return elements chosen from `x` or `y` depending on `condition`.
.. note::
When only `condition` is provided, this function is a shorthand for
``np.asarray(condition).nonzero()``. The rest of this documentation
covers only the case where all three arguments are provided.
Parameters
----------
condition : ndarray
Expand Down Expand Up @@ -5371,4 +5376,7 @@ def where(condition, x, y):
[ 0., 2., -1.],
[ 0., 3., -1.]])
"""
return _npi.where(condition, x, y, out=None)
if x is None and y is None:
return nonzero(condition)
else:
return _npi.where(condition, x, y, out=None)
9 changes: 7 additions & 2 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -7298,10 +7298,15 @@ def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None, **kwargs):


@set_module('mxnet.numpy')
def where(condition, x, y):
"""
def where(condition, x=None, y=None):
"""where(condition, [x, y])
Return elements chosen from `x` or `y` depending on `condition`.
.. note::
When only `condition` is provided, this function is a shorthand for
``np.asarray(condition).nonzero()``. The rest of this documentation
covers only the case where all three arguments are provided.
Parameters
----------
condition : ndarray
Expand Down
98 changes: 25 additions & 73 deletions src/operator/numpy/np_where_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,6 @@
namespace mxnet {
namespace op {

#define NUMPY_WHERE_MAX_DIM 5

using namespace mshadow;

template<int ndim>
Expand Down Expand Up @@ -75,52 +73,6 @@ struct numpy_where_backward_kernel {
}
};

inline bool NumpyWhereOpShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector* in_attrs,
mxnet::ShapeVector* out_attrs) {
CHECK_EQ(in_attrs->size(), 3U);
CHECK_EQ(out_attrs->size(), 1U);
mxnet::TShape& operand1 = (*in_attrs)[0];
mxnet::TShape& operand2 = (*in_attrs)[1];
mxnet::TShape& operand3 = (*in_attrs)[2];

if (operand1 == operand2 && operand2 == operand3) {
SHAPE_ASSIGN_CHECK(*out_attrs, 0, operand1);
return shape_is_known(out_attrs->at(0));
}
mxnet::TShape out(std::max({operand1.ndim(), operand2.ndim(), operand3.ndim()}), -1);
const int b1 = out.ndim() - operand1.ndim();
const int b2 = out.ndim() - operand2.ndim();
const int b3 = out.ndim() - operand3.ndim();
for (int i = 0; i < out.ndim(); ++i) {
int s1 = 1, s2 = 1, s3 = 1;
if (i >= b1) s1 = operand1[i-b1];
if (i >= b2) s2 = operand2[i-b2];
if (i >= b3) s3 = operand3[i-b3];
if (!(s1 == s2 && s2 == s3)) {
CHECK((s1 == 1 && s2 == 1) || (s1 == 1 && s3 == 1) || (s2 == 1 && s3 == 1) ||
(s1 == 1 && s2 == s3) || (s2 == 1 && s1 == s3) || (s3 == 1 && s1 == s2))
<< "Operands could not be broadcast together.";
out[i] = std::max({s1, s2, s3});
} else {
out[i] = s1;
}
}
SHAPE_ASSIGN_CHECK(*out_attrs, 0, out);
return shape_is_known(out);
}

inline bool NumpyWhereOpType(const nnvm::NodeAttrs& attrs,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
CHECK_EQ(in_attrs->size(), 3U)
<< "where operator takes 3 arguments (" << in_attrs->size() << " given)";
CHECK_EQ(out_attrs->size(), 1U);
CHECK_EQ(in_attrs->at(1), in_attrs->at(2));
TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(1));
return (out_attrs->at(0) != -1);
}

template<typename xpu>
inline void NumpyWhereOpForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
Expand All @@ -130,29 +82,29 @@ inline void NumpyWhereOpForward(const nnvm::NodeAttrs& attrs,
CHECK_EQ(inputs.size(), 3U);
CHECK_EQ(outputs.size(), 1U);
if (outputs[0].shape_.Size() == 0U) return; // zero-size tensor
CHECK_LE(outputs[0].shape_.ndim(), NUMPY_WHERE_MAX_DIM);
CHECK_LE(outputs[0].shape_.ndim(), broadcast::MAX_DIM);

Stream<xpu> *s = ctx.get_stream<xpu>();
std::vector<Shape<NUMPY_WHERE_MAX_DIM>> in_strides;
std::vector<Shape<broadcast::MAX_DIM>> in_strides;
in_strides.resize(3);
for (int i = 0; i < 3; ++i) {
TShape expanded_ishape(NUMPY_WHERE_MAX_DIM, 1);
TShape expanded_ishape(broadcast::MAX_DIM, 1);
const TShape& ishape = inputs[i].shape_;
const int ndim_delta = expanded_ishape.ndim() - ishape.ndim();
for (int j = 0; j < ishape.ndim(); ++j) {
expanded_ishape[j + ndim_delta] = ishape[j];
}
in_strides[i] = mxnet_op::calc_stride(expanded_ishape.get<NUMPY_WHERE_MAX_DIM>());
in_strides[i] = mxnet_op::calc_stride(expanded_ishape.get<broadcast::MAX_DIM>());
}
TShape expanded_oshape(NUMPY_WHERE_MAX_DIM, 1);
TShape expanded_oshape(broadcast::MAX_DIM, 1);
const int ndim_delta = expanded_oshape.ndim() - outputs[0].shape_.ndim();
for (int j = 0; j < outputs[0].shape_.ndim(); ++j) {
expanded_oshape[j + ndim_delta] = (outputs[0].shape_)[j];
}
Shape<NUMPY_WHERE_MAX_DIM> oshape = expanded_oshape.get<NUMPY_WHERE_MAX_DIM>();
Shape<broadcast::MAX_DIM> oshape = expanded_oshape.get<broadcast::MAX_DIM>();
MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, DType, {
MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, CType, {
mxnet_op::Kernel<numpy_where_kernel<NUMPY_WHERE_MAX_DIM>, xpu>::Launch(
mxnet_op::Kernel<numpy_where_kernel<broadcast::MAX_DIM>, xpu>::Launch(
s, outputs[0].Size(), req[0],
in_strides[0], in_strides[1], in_strides[2], oshape,
inputs[0].dptr<CType>(), inputs[1].dptr<DType>(),
Expand All @@ -173,28 +125,28 @@ inline void NumpyWhereOpBackward(const nnvm::NodeAttrs& attrs,
if (inputs[0].shape_.Size() == 0U) return; // zero-size tensor
Stream<xpu> *s = ctx.get_stream<xpu>();
// get expanded oshape
TShape expanded_oshape(NUMPY_WHERE_MAX_DIM, 1);
TShape expanded_oshape(broadcast::MAX_DIM, 1);
int ndim_delta = expanded_oshape.ndim() - inputs[0].shape_.ndim();
for (int j = 0; j < inputs[0].shape_.ndim(); ++j) {
expanded_oshape[j + ndim_delta] = (inputs[0].shape_)[j];
}
Shape<NUMPY_WHERE_MAX_DIM> oshape = expanded_oshape.get<NUMPY_WHERE_MAX_DIM>();
Shape<broadcast::MAX_DIM> oshape = expanded_oshape.get<broadcast::MAX_DIM>();
// get cond stride
TShape expanded_cshape(NUMPY_WHERE_MAX_DIM, 1);
TShape expanded_cshape(broadcast::MAX_DIM, 1);
ndim_delta = expanded_cshape.ndim() - inputs[1].shape_.ndim();
for (int j = 0; j < inputs[1].shape_.ndim(); ++j) {
expanded_cshape[j + ndim_delta] = (inputs[1].shape_)[j];
}
Shape<NUMPY_WHERE_MAX_DIM> cstride =
mxnet_op::calc_stride(expanded_cshape.get<NUMPY_WHERE_MAX_DIM>());
Shape<broadcast::MAX_DIM> cstride =
mxnet_op::calc_stride(expanded_cshape.get<broadcast::MAX_DIM>());
// get expanded lshape
TShape expanded_lshape(NUMPY_WHERE_MAX_DIM, 1);
TShape expanded_lshape(broadcast::MAX_DIM, 1);
ndim_delta = expanded_lshape.ndim() - outputs[0].shape_.ndim();
for (int j = 0; j < outputs[0].shape_.ndim(); ++j) {
expanded_lshape[j + ndim_delta] = (outputs[0].shape_)[j];
}
// get expanded rshape
TShape expanded_rshape(NUMPY_WHERE_MAX_DIM, 1);
TShape expanded_rshape(broadcast::MAX_DIM, 1);
ndim_delta = expanded_rshape.ndim() - outputs[1].shape_.ndim();
for (int j = 0; j < outputs[1].shape_.ndim(); ++j) {
expanded_rshape[j + ndim_delta] = (outputs[1].shape_)[j];
Expand All @@ -203,27 +155,27 @@ inline void NumpyWhereOpBackward(const nnvm::NodeAttrs& attrs,
MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, {
MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[1].type_flag_, CType, {
Tensor<xpu, 1, char> largespace;
Tensor<xpu, NUMPY_WHERE_MAX_DIM, DType> workspace;
Tensor<xpu, broadcast::MAX_DIM, DType> workspace;
size_t ws_size = 0;
if (!(inputs[0].shape_ != outputs[0].shape_) || !(inputs[0].shape_ != outputs[1].shape_)) {
size_t ws_size1 = broadcast::ReduceWorkspaceSize<NUMPY_WHERE_MAX_DIM, DType>(
size_t ws_size1 = broadcast::ReduceWorkspaceSize<broadcast::MAX_DIM, DType>(
s, expanded_lshape, req[0], expanded_oshape);
size_t ws_size2 = broadcast::ReduceWorkspaceSize<NUMPY_WHERE_MAX_DIM, DType>(
size_t ws_size2 = broadcast::ReduceWorkspaceSize<broadcast::MAX_DIM, DType>(
s, expanded_rshape, req[1], expanded_oshape);
ws_size = std::max(ws_size1, ws_size2);
}
// process left output
if (inputs[0].shape_ == outputs[0].shape_) {
mxnet_op::Kernel<numpy_where_backward_kernel<NUMPY_WHERE_MAX_DIM, true>, xpu>::Launch(
mxnet_op::Kernel<numpy_where_backward_kernel<broadcast::MAX_DIM, true>, xpu>::Launch(
s, inputs[0].Size(), req[0], cstride, oshape,
inputs[1].dptr<CType>(), inputs[0].dptr<DType>(), outputs[0].dptr<DType>());
} else {
largespace = ctx.requested[0].get_space_typed<xpu, 1, char>(
Shape1(inputs[0].shape_.Size() * sizeof(DType) + ws_size), s);
workspace = Tensor<xpu, NUMPY_WHERE_MAX_DIM, DType>(
workspace = Tensor<xpu, broadcast::MAX_DIM, DType>(
reinterpret_cast<DType*>(largespace.dptr_ + ws_size),
expanded_oshape.get<NUMPY_WHERE_MAX_DIM>(), s);
mxnet_op::Kernel<numpy_where_backward_kernel<NUMPY_WHERE_MAX_DIM, true>, xpu>::Launch(
expanded_oshape.get<broadcast::MAX_DIM>(), s);
mxnet_op::Kernel<numpy_where_backward_kernel<broadcast::MAX_DIM, true>, xpu>::Launch(
s, inputs[0].Size(), req[0], cstride, oshape,
inputs[1].dptr<CType>(), inputs[0].dptr<DType>(), workspace.dptr_);
if (NeedSafeAcc<true>(outputs[0].type_flag_, outputs[0].type_flag_)) {
Expand All @@ -236,16 +188,16 @@ inline void NumpyWhereOpBackward(const nnvm::NodeAttrs& attrs,
}
// process right output
if (inputs[0].shape_ == outputs[1].shape_) {
mxnet_op::Kernel<numpy_where_backward_kernel<NUMPY_WHERE_MAX_DIM, false>, xpu>::Launch(
mxnet_op::Kernel<numpy_where_backward_kernel<broadcast::MAX_DIM, false>, xpu>::Launch(
s, inputs[0].Size(), req[1], cstride, oshape,
inputs[1].dptr<CType>(), inputs[0].dptr<DType>(), outputs[1].dptr<DType>());
} else {
largespace = ctx.requested[0].get_space_typed<xpu, 1, char>(
Shape1(inputs[0].shape_.Size() * sizeof(DType) + ws_size), s);
workspace = Tensor<xpu, NUMPY_WHERE_MAX_DIM, DType>(
workspace = Tensor<xpu, broadcast::MAX_DIM, DType>(
reinterpret_cast<DType*>(largespace.dptr_ + ws_size),
expanded_oshape.get<NUMPY_WHERE_MAX_DIM>(), s);
mxnet_op::Kernel<numpy_where_backward_kernel<NUMPY_WHERE_MAX_DIM, false>, xpu>::Launch(
expanded_oshape.get<broadcast::MAX_DIM>(), s);
mxnet_op::Kernel<numpy_where_backward_kernel<broadcast::MAX_DIM, false>, xpu>::Launch(
s, inputs[0].Size(), req[1], cstride, oshape,
inputs[1].dptr<CType>(), inputs[0].dptr<DType>(), workspace.dptr_);
if (NeedSafeAcc<true>(outputs[1].type_flag_, outputs[1].type_flag_)) {
Expand Down
46 changes: 46 additions & 0 deletions src/operator/numpy/np_where_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,52 @@
namespace mxnet {
namespace op {

inline bool NumpyWhereOpShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector* in_attrs,
mxnet::ShapeVector* out_attrs) {
CHECK_EQ(in_attrs->size(), 3U);
CHECK_EQ(out_attrs->size(), 1U);
mxnet::TShape& operand1 = (*in_attrs)[0];
mxnet::TShape& operand2 = (*in_attrs)[1];
mxnet::TShape& operand3 = (*in_attrs)[2];

if (operand1 == operand2 && operand2 == operand3) {
SHAPE_ASSIGN_CHECK(*out_attrs, 0, operand1);
return shape_is_known(out_attrs->at(0));
}
mxnet::TShape out(std::max({operand1.ndim(), operand2.ndim(), operand3.ndim()}), -1);
const int b1 = out.ndim() - operand1.ndim();
const int b2 = out.ndim() - operand2.ndim();
const int b3 = out.ndim() - operand3.ndim();
for (int i = 0; i < out.ndim(); ++i) {
int s1 = 1, s2 = 1, s3 = 1;
if (i >= b1) s1 = operand1[i-b1];
if (i >= b2) s2 = operand2[i-b2];
if (i >= b3) s3 = operand3[i-b3];
if (!(s1 == s2 && s2 == s3)) {
CHECK((s1 == 1 && s2 == 1) || (s1 == 1 && s3 == 1) || (s2 == 1 && s3 == 1) ||
(s1 == 1 && s2 == s3) || (s2 == 1 && s1 == s3) || (s3 == 1 && s1 == s2))
<< "Operands could not be broadcast together.";
out[i] = std::max({s1, s2, s3});
} else {
out[i] = s1;
}
}
SHAPE_ASSIGN_CHECK(*out_attrs, 0, out);
return shape_is_known(out);
}

inline bool NumpyWhereOpType(const nnvm::NodeAttrs& attrs,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
CHECK_EQ(in_attrs->size(), 3U)
<< "where operator takes 3 arguments (" << in_attrs->size() << " given)";
CHECK_EQ(out_attrs->size(), 1U);
std::vector<int> sub_in_attrs(in_attrs->begin() + 1, in_attrs->end());
bool flag = ElemwiseType<2, 1>(attrs, &sub_in_attrs, out_attrs);
return flag && (in_attrs->at(0) != -1);
}

NNVM_REGISTER_OP(_npi_where)
.set_num_inputs(3)
.set_num_outputs(1)
Expand Down

0 comments on commit bb28967

Please sign in to comment.