[Numpy] Where operator scalar version #17249

Merged 2 commits on Feb 11, 2020
29 changes: 26 additions & 3 deletions python/mxnet/ndarray/numpy/_op.py
@@ -6689,7 +6689,7 @@ def isinf(x, out=None, **kwargs):


@set_module('mxnet.ndarray.numpy')
def where(condition, x=None, y=None):
def where(condition, x=None, y=None): # pylint: disable=too-many-return-statements
"""where(condition, [x, y])
Return elements chosen from `x` or `y` depending on `condition`.

@@ -6719,6 +6719,14 @@ def where(condition, x=None, y=None):
[xv if c else yv
for c, xv, yv in zip(condition, x, y)]

This function differs from the original `numpy.where
<https://docs.scipy.org/doc/numpy/reference/generated/numpy.where.html>`_ in
the following way(s):

- If `condition` is a scalar, this operator returns `x` or `y` directly without broadcasting.
- If `condition` is an ndarray while both `x` and `y` are scalars,
the output dtype will be `float32`.

Examples
--------
>>> a = np.arange(10)
@@ -6749,15 +6757,30 @@ def where(condition, x=None, y=None):
>>> a = np.array([[0, 1, 2],
... [0, 2, 4],
... [0, 3, 6]])
>>> np.where(a < 4, a, np.array(-1)) # -1 is broadcast
>>> np.where(a < 4, a, -1) # -1 is broadcast
array([[ 0., 1., 2.],
[ 0., 2., -1.],
[ 0., 3., -1.]])
"""
if x is None and y is None:
return nonzero(condition)
else:
return _npi.where(condition, x, y, out=None)
if isinstance(condition, numeric_types):
if condition != 0:
return x
else:
return y
else:
if isinstance(x, numeric_types) and isinstance(y, numeric_types):
return _npi.where_scalar2(condition, float(x), float(y), out=None)
elif isinstance(x, NDArray) and isinstance(y, NDArray):
return _npi.where(condition, x, y, out=None)
elif isinstance(y, NDArray):
return _npi.where_lscalar(condition, y, float(x), out=None)
elif isinstance(x, NDArray):
return _npi.where_rscalar(condition, x, float(y), out=None)
else:
raise TypeError('type {0} and {1} not supported'.format(str(type(x)), str(type(y))))
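
For reference, a minimal sketch of how each argument combination dispatches under the new frontend (illustrative only; assumes numpy mode is enabled via `npx.set_np()`):

from mxnet import np, npx
npx.set_np()

cond = np.array([1, 0, 1])
x = np.array([1., 2., 3.])
y = np.array([10., 20., 30.])

np.where(cond, x, y)      # ndarray/ndarray  -> _npi.where
np.where(cond, x, -1.0)   # scalar y         -> _npi.where_rscalar
np.where(cond, -1.0, y)   # scalar x         -> _npi.where_lscalar
np.where(cond, 2.0, 3.0)  # both scalars     -> _npi.where_scalar2, float32 output
np.where(1, x, y)         # scalar condition -> returns x itself, no broadcasting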


@set_module('mxnet.ndarray.numpy')
2 changes: 1 addition & 1 deletion python/mxnet/numpy/multiarray.py
@@ -8828,7 +8828,7 @@ def where(condition, x=None, y=None):
>>> a = np.array([[0, 1, 2],
... [0, 2, 4],
... [0, 3, 6]])
>>> np.where(a < 4, a, np.array(-1)) # -1 is broadcast
>>> np.where(a < 4, a, -1) # -1 is broadcast
array([[ 0., 1., 2.],
[ 0., 2., -1.],
[ 0., 3., -1.]])
17 changes: 16 additions & 1 deletion python/mxnet/symbol/numpy/_symbol.py
@@ -6012,7 +6012,22 @@ def where(condition, x, y):
from `y` elsewhere.

"""
return _npi.where(condition, x, y, out=None)
if isinstance(condition, numeric_types):
if condition != 0:
return x
else:
return y
else:
if isinstance(x, numeric_types) and isinstance(y, numeric_types):
return _npi.where_scalar2(condition, float(x), float(y), out=None)
elif isinstance(x, Symbol) and isinstance(y, Symbol):
return _npi.where(condition, x, y, out=None)
elif isinstance(y, Symbol):
return _npi.where_lscalar(condition, y, float(x), out=None)
elif isinstance(x, Symbol):
return _npi.where_rscalar(condition, x, float(y), out=None)
else:
raise TypeError('type {0} and {1} not supported'.format(str(type(x)), str(type(y))))
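
The symbolic path matters chiefly under hybridization; a hypothetical sketch (the block name and shapes are made up; assumes `F.np` resolves to this module once the block is hybridized):

from mxnet import npx
from mxnet.gluon import HybridBlock
npx.set_np()

class ReluLike(HybridBlock):
    def hybrid_forward(self, F, x):
        # scalar y branch -> _npi.where_rscalar, in both imperative and symbolic mode
        return F.np.where(x > 0, x, 0.0)

net = ReluLike()
net.hybridize()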


@set_module('mxnet.symbol.numpy')
186 changes: 185 additions & 1 deletion src/operator/numpy/np_where_op-inl.h
@@ -42,6 +42,27 @@ namespace op {

using namespace mshadow;

struct NumpyWhereScalarParam : public dmlc::Parameter<NumpyWhereScalarParam> {
double scalar;
DMLC_DECLARE_PARAMETER(NumpyWhereScalarParam) {
DMLC_DECLARE_FIELD(scalar)
.set_default(0.0)
.describe("The scalar value of x/y.");
}
};

struct NumpyWhereScalar2Param : public dmlc::Parameter<NumpyWhereScalar2Param> {
double x, y;
DMLC_DECLARE_PARAMETER(NumpyWhereScalar2Param) {
DMLC_DECLARE_FIELD(x)
.set_default(0.0)
.describe("The scalar value of x.");
DMLC_DECLARE_FIELD(y)
.set_default(0.0)
.describe("The scalar value of y.");
}
};

template<int ndim>
struct numpy_where_kernel {
template<typename CType, typename DType>
@@ -73,6 +94,31 @@ struct numpy_where_backward_kernel {
}
};

template<int ndim, bool is_left>
struct numpy_where_scalar_kernel {
template<typename CType, typename DType>
MSHADOW_XINLINE static void Map(index_t base, OpReqType req, const Shape<ndim> &cstride,
const Shape<ndim> &ystride, const Shape<ndim> &oshape,
CType *datac, DType datax, DType *datay, DType *out) {
Shape<ndim> coord = mxnet_op::unravel(base, oshape);
auto cidx = static_cast<index_t>(mxnet_op::dot(coord, cstride));
auto yidx = static_cast<index_t>(mxnet_op::dot(coord, ystride));
if (is_left) {
KERNEL_ASSIGN(out[base], req, datac[cidx] != CType(0) ? datax : datay[yidx]);
} else {
KERNEL_ASSIGN(out[base], req, datac[cidx] != CType(0) ? datay[yidx] : datax);
}
}
};

struct numpy_where_scalar2_kernel {
template<typename DType, typename CType>
MSHADOW_XINLINE static void Map(index_t i, OpReqType req, DType* out, const CType* cond,
const DType x, const DType y) {
KERNEL_ASSIGN(out[i], req, (CType(0) != cond[i] ? x : y));
}
};
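
As a cross-check, the two scalar kernels above compute the same thing as this NumPy reference (a sketch; `onp` is the plain NumPy package, and `is_left` marks whether the scalar occupies x's slot):

import numpy as onp

def where_scalar_ref(cond, scalar, tensor, is_left):
    # numpy_where_scalar_kernel: scalar replaces x when is_left, else y
    return onp.where(cond != 0, scalar, tensor) if is_left \
           else onp.where(cond != 0, tensor, scalar)

def where_scalar2_ref(cond, x, y):
    # numpy_where_scalar2_kernel: both branches scalar, elementwise over cond,
    # dtype pinned to float32 as the frontend docstring states
    return onp.where(cond != 0, onp.float32(x), onp.float32(y))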

template<typename xpu>
inline void NumpyWhereOpForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
@@ -166,7 +212,7 @@ inline void NumpyWhereOpBackward(const nnvm::NodeAttrs& attrs,
Tensor<xpu, 1, char> largespace;
Tensor<xpu, broadcast::MAX_DIM, DType> workspace;
size_t ws_size = 0;
if (!(ograd.shape_ != dx.shape_) || !(ograd.shape_ != dy.shape_)) {
if (ograd.shape_ != dx.shape_ || ograd.shape_ != dy.shape_) {
size_t ws_size1 = broadcast::ReduceWorkspaceSize<broadcast::MAX_DIM, DType>(
s, expanded_lshape, req[0], expanded_oshape);
size_t ws_size2 = broadcast::ReduceWorkspaceSize<broadcast::MAX_DIM, DType>(
@@ -221,6 +267,144 @@
});
}

template<typename xpu, bool is_left>
inline void NumpyWhereScalarOpForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);
if (outputs[0].shape_.Size() == 0U) return; // zero-size tensor
CHECK_LE(outputs[0].shape_.ndim(), broadcast::MAX_DIM);

const NumpyWhereScalarParam& param = nnvm::get<NumpyWhereScalarParam>(attrs.parsed);
const TBlob& cond = inputs[0];
const TBlob& y = inputs[1];
const TBlob& out = outputs[0];
Stream<xpu> *s = ctx.get_stream<xpu>();
std::vector<Shape<broadcast::MAX_DIM>> in_strides;
in_strides.resize(2);
for (int i = 0; i < 2; ++i) {
TShape expanded_ishape(broadcast::MAX_DIM, 1);
const TShape& ishape = inputs[i].shape_;
const int ndim_delta = expanded_ishape.ndim() - ishape.ndim();
for (int j = 0; j < ishape.ndim(); ++j) {
expanded_ishape[j + ndim_delta] = ishape[j];
}
in_strides[i] = mxnet_op::calc_stride(expanded_ishape.get<broadcast::MAX_DIM>());
}
TShape expanded_oshape(broadcast::MAX_DIM, 1);
const int ndim_delta = expanded_oshape.ndim() - out.shape_.ndim();
for (int j = 0; j < out.shape_.ndim(); ++j) {
expanded_oshape[j + ndim_delta] = (out.shape_)[j];
}
Shape<broadcast::MAX_DIM> oshape = expanded_oshape.get<broadcast::MAX_DIM>();
MSHADOW_TYPE_SWITCH_WITH_BOOL(out.type_flag_, DType, {
MSHADOW_TYPE_SWITCH_WITH_BOOL(cond.type_flag_, CType, {
mxnet_op::Kernel<numpy_where_scalar_kernel<broadcast::MAX_DIM, is_left>, xpu>::Launch(
s, out.Size(), req[0],
in_strides[0], in_strides[1], oshape,
cond.dptr<CType>(), DType(param.scalar),
y.dptr<DType>(), out.dptr<DType>());
});
});
}
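
The forward pass implements broadcasting purely through the stride vectors: `calc_stride` zeroes the stride of every size-1 axis, so repeated reads of the same element stand in for the broadcast copy. A pure-Python sketch of that addressing scheme (illustrative; `calc_stride` mirrors `mxnet_op::calc_stride`, and arrays are assumed C-contiguous):

import numpy as onp

def calc_stride(shape):
    # size-1 axes get stride 0, so they broadcast for free
    stride, cum = [], 1
    for s in reversed(shape):
        stride.append(cum if s > 1 else 0)
        cum *= s
    return tuple(reversed(stride))

def broadcast_read(arr, oshape, base):
    # fetch arr's element that broadcasts to flat output index `base`,
    # matching datac[cidx] / datay[yidx] in the kernel
    eshape = (1,) * (len(oshape) - arr.ndim) + arr.shape  # left-pad, as with MAX_DIM
    coord = onp.unravel_index(base, oshape)
    return arr.ravel()[int(onp.dot(coord, calc_stride(eshape)))]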

template<typename xpu, bool is_left>
inline void NumpyWhereScalarOpBackward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);
CHECK(common::is_float(inputs[0].type_flag_)) << "Backward only supports float types!";
if (inputs[0].shape_.Size() == 0U) return; // zero-size tensor

Stream<xpu> *s = ctx.get_stream<xpu>();
const TBlob& ograd = inputs[0];
const TBlob& cond = inputs[1];
const TBlob& dx = outputs[0];
// get expanded oshape
TShape expanded_oshape(broadcast::MAX_DIM, 1);
int ndim_delta = expanded_oshape.ndim() - ograd.shape_.ndim();
for (int j = 0; j < ograd.shape_.ndim(); ++j) {
expanded_oshape[j + ndim_delta] = (ograd.shape_)[j];
}
Shape<broadcast::MAX_DIM> oshape = expanded_oshape.get<broadcast::MAX_DIM>();
// get cond stride
TShape expanded_cshape(broadcast::MAX_DIM, 1);
ndim_delta = expanded_cshape.ndim() - cond.shape_.ndim();
for (int j = 0; j < cond.shape_.ndim(); ++j) {
expanded_cshape[j + ndim_delta] = (cond.shape_)[j];
}
Shape<broadcast::MAX_DIM> cstride =
mxnet_op::calc_stride(expanded_cshape.get<broadcast::MAX_DIM>());
// get expanded lshape
TShape expanded_lshape(broadcast::MAX_DIM, 1);
ndim_delta = expanded_lshape.ndim() - dx.shape_.ndim();
for (int j = 0; j < dx.shape_.ndim(); ++j) {
expanded_lshape[j + ndim_delta] = (dx.shape_)[j];
}

MSHADOW_TYPE_SWITCH_WITH_BOOL(ograd.type_flag_, DType, {
MSHADOW_TYPE_SWITCH_WITH_BOOL(cond.type_flag_, CType, {
Tensor<xpu, 1, char> largespace;
Tensor<xpu, broadcast::MAX_DIM, DType> workspace;
size_t ws_size = 0;
if (ograd.shape_ != dx.shape_) {
ws_size = broadcast::ReduceWorkspaceSize<broadcast::MAX_DIM, DType>(
s, expanded_lshape, req[0], expanded_oshape);
}
// process left output
if (ograd.shape_ == dx.shape_) {
mxnet_op::Kernel<numpy_where_backward_kernel<broadcast::MAX_DIM, is_left>, xpu>::Launch(
s, ograd.Size(), req[0], cstride, oshape,
cond.dptr<CType>(), ograd.dptr<DType>(), dx.dptr<DType>());
} else {
largespace = ctx.requested[0].get_space_typed<xpu, 1, char>(
Shape1(ograd.shape_.Size() * sizeof(DType) + ws_size), s);
workspace = Tensor<xpu, broadcast::MAX_DIM, DType>(
reinterpret_cast<DType*>(largespace.dptr_ + ws_size),
expanded_oshape.get<broadcast::MAX_DIM>(), s);
mxnet_op::Kernel<numpy_where_backward_kernel<broadcast::MAX_DIM, is_left>, xpu>::Launch(
s, ograd.Size(), req[0], cstride, oshape,
cond.dptr<CType>(), ograd.dptr<DType>(), workspace.dptr_);
if (NeedSafeAcc<true>(dx.type_flag_, dx.type_flag_)) {
ReduceAxesComputeImpl<xpu, mshadow_op::sum, true>(ctx, {TBlob(workspace)}, {req[0]},
{dx.reshape(expanded_lshape)}, expanded_lshape);
} else {
ReduceAxesComputeImpl<xpu, mshadow_op::sum, false>(ctx, {TBlob(workspace)}, {req[0]},
{dx.reshape(expanded_lshape)}, expanded_lshape);
}
}
});
});
}
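
The backward pass is a masked copy of `ograd` followed, when the condition was broadcast, by a sum-reduction down to `dx`'s shape. A NumPy sketch of the same computation (assuming `is_left` means `dx` receives the gradient on the condition-true branch, matching the kernel launch above):

import numpy as onp

def where_scalar_backward_ref(ograd, cond, dx_shape, is_left):
    # gradient reaches dx only where the tensor input was selected
    mask = (cond != 0) if is_left else (cond == 0)
    g = onp.where(onp.broadcast_to(mask, ograd.shape), ograd, onp.zeros_like(ograd))
    # sum out the axes that were broadcast from dx_shape up to ograd.shape
    eshape = (1,) * (ograd.ndim - len(dx_shape)) + tuple(dx_shape)
    axes = tuple(i for i, s in enumerate(eshape) if s != ograd.shape[i])
    return g.sum(axis=axes).reshape(dx_shape)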

template<typename xpu>
inline void NumpyWhereScalar2OpForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
if (outputs[0].shape_.Size() == 0U) return; // zero-size tensor
using namespace mxnet_op;
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
const NumpyWhereScalar2Param& param = nnvm::get<NumpyWhereScalar2Param>(attrs.parsed);
const TBlob& cond = inputs[0];
const TBlob& out = outputs[0];
MSHADOW_TYPE_SWITCH_WITH_BOOL(out.type_flag_, DType, {
MSHADOW_TYPE_SWITCH_WITH_BOOL(cond.type_flag_, CType, {
Kernel<numpy_where_scalar2_kernel, xpu>::Launch(s, out.Size(), req[0],
out.dptr<DType>(), cond.dptr<CType>(), DType(param.x), DType(param.y));
});
});
}

} // namespace op
} // namespace mxnet
