Skip to content

Commit

Permalink
[Numpy] Where operator scalar version (apache#17249)
Browse files Browse the repository at this point in the history
* init

* fix
  • Loading branch information
hgt312 authored and Ubuntu committed Feb 19, 2020
1 parent 31fade3 commit 1cd53d8
Show file tree
Hide file tree
Showing 6 changed files with 389 additions and 6 deletions.
29 changes: 26 additions & 3 deletions python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -6728,7 +6728,7 @@ def isinf(x, out=None, **kwargs):


@set_module('mxnet.ndarray.numpy')
def where(condition, x=None, y=None):
def where(condition, x=None, y=None): # pylint: disable=too-many-return-statements
"""where(condition, [x, y])
Return elements chosen from `x` or `y` depending on `condition`.
Expand Down Expand Up @@ -6758,6 +6758,14 @@ def where(condition, x=None, y=None):
[xv if c else yv
for c, xv, yv in zip(condition, x, y)]
This function differs from the original `numpy.where
<https://docs.scipy.org/doc/numpy/reference/generated/numpy.where.html>`_ in
the following way(s):
- If `condition` is a scalar, this operator returns x or y directly without broadcasting.
- If `condition` is ndarray, while both `x` and `y` are scalars,
the output dtype will be `float32`.
Examples
--------
>>> a = np.arange(10)
Expand Down Expand Up @@ -6788,15 +6796,30 @@ def where(condition, x=None, y=None):
>>> a = np.array([[0, 1, 2],
... [0, 2, 4],
... [0, 3, 6]])
>>> np.where(a < 4, a, np.array(-1)) # -1 is broadcast
>>> np.where(a < 4, a, -1) # -1 is broadcast
array([[ 0., 1., 2.],
[ 0., 2., -1.],
[ 0., 3., -1.]])
"""
if x is None and y is None:
return nonzero(condition)
else:
return _npi.where(condition, x, y, out=None)
if isinstance(condition, numeric_types):
if condition != 0:
return x
else:
return y
else:
if isinstance(x, numeric_types) and isinstance(y, numeric_types):
return _npi.where_scalar2(condition, float(x), float(y), out=None)
elif isinstance(x, NDArray) and isinstance(y, NDArray):
return _npi.where(condition, x, y, out=None)
elif isinstance(y, NDArray):
return _npi.where_lscalar(condition, y, float(x), out=None)
elif isinstance(x, NDArray):
return _npi.where_rscalar(condition, x, float(y), out=None)
else:
raise TypeError('type {0} and {1} not supported'.format(str(type(x)), str(type(y))))


@set_module('mxnet.ndarray.numpy')
Expand Down
2 changes: 1 addition & 1 deletion python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -8867,7 +8867,7 @@ def where(condition, x=None, y=None):
>>> a = np.array([[0, 1, 2],
... [0, 2, 4],
... [0, 3, 6]])
>>> np.where(a < 4, a, np.array(-1)) # -1 is broadcast
>>> np.where(a < 4, a, -1) # -1 is broadcast
array([[ 0., 1., 2.],
[ 0., 2., -1.],
[ 0., 3., -1.]])
Expand Down
17 changes: 16 additions & 1 deletion python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -6051,7 +6051,22 @@ def where(condition, x, y):
from `y` elsewhere.
"""
return _npi.where(condition, x, y, out=None)
if isinstance(condition, numeric_types):
if condition != 0:
return x
else:
return y
else:
if isinstance(x, numeric_types) and isinstance(y, numeric_types):
return _npi.where_scalar2(condition, float(x), float(y), out=None)
elif isinstance(x, Symbol) and isinstance(y, Symbol):
return _npi.where(condition, x, y, out=None)
elif isinstance(y, Symbol):
return _npi.where_lscalar(condition, y, float(x), out=None)
elif isinstance(x, Symbol):
return _npi.where_rscalar(condition, x, float(y), out=None)
else:
raise TypeError('type {0} and {1} not supported'.format(str(type(x)), str(type(y))))


@set_module('mxnet.symbol.numpy')
Expand Down
186 changes: 185 additions & 1 deletion src/operator/numpy/np_where_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,27 @@ namespace op {

using namespace mshadow;

struct NumpyWhereScalarParam : public dmlc::Parameter<NumpyWhereScalarParam> {
double scalar;
DMLC_DECLARE_PARAMETER(NumpyWhereScalarParam) {
DMLC_DECLARE_FIELD(scalar)
.set_default(0.0)
.describe("The scalar value of x/y.");
}
};

struct NumpyWhereScalar2Param : public dmlc::Parameter<NumpyWhereScalar2Param> {
double x, y;
DMLC_DECLARE_PARAMETER(NumpyWhereScalar2Param) {
DMLC_DECLARE_FIELD(x)
.set_default(0.0)
.describe("The scalar value of x.");
DMLC_DECLARE_FIELD(y)
.set_default(0.0)
.describe("The scalar value of y.");
}
};

template<int ndim>
struct numpy_where_kernel {
template<typename CType, typename DType>
Expand Down Expand Up @@ -73,6 +94,31 @@ struct numpy_where_backward_kernel {
}
};

template<int ndim, bool is_left>
struct numpy_where_scalar_kernel {
template<typename CType, typename DType>
MSHADOW_XINLINE static void Map(index_t base, OpReqType req, const Shape<ndim> &cstride,
const Shape<ndim> &ystride, const Shape<ndim> &oshape,
CType *datac, DType datax, DType *datay, DType *out) {
Shape<ndim> coord = mxnet_op::unravel(base, oshape);
auto cidx = static_cast<index_t>(mxnet_op::dot(coord, cstride));
auto yidx = static_cast<index_t>(mxnet_op::dot(coord, ystride));
if (is_left) {
KERNEL_ASSIGN(out[base], req, datac[cidx] != CType(0) ? datax : datay[yidx]);
} else {
KERNEL_ASSIGN(out[base], req, datac[cidx] != CType(0) ? datay[yidx] : datax);
}
}
};

struct numpy_where_scalar2_kernel {
template<typename DType, typename CType>
MSHADOW_XINLINE static void Map(index_t i, OpReqType req, DType* out, const CType* cond,
const DType x, const DType y) {
KERNEL_ASSIGN(out[i], req, (CType(0) != cond[i]? x : y));
}
};

template<typename xpu>
inline void NumpyWhereOpForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
Expand Down Expand Up @@ -166,7 +212,7 @@ inline void NumpyWhereOpBackward(const nnvm::NodeAttrs& attrs,
Tensor<xpu, 1, char> largespace;
Tensor<xpu, broadcast::MAX_DIM, DType> workspace;
size_t ws_size = 0;
if (!(ograd.shape_ != dx.shape_) || !(ograd.shape_ != dy.shape_)) {
if (ograd.shape_ != dx.shape_ || ograd.shape_ != dy.shape_) {
size_t ws_size1 = broadcast::ReduceWorkspaceSize<broadcast::MAX_DIM, DType>(
s, expanded_lshape, req[0], expanded_oshape);
size_t ws_size2 = broadcast::ReduceWorkspaceSize<broadcast::MAX_DIM, DType>(
Expand Down Expand Up @@ -221,6 +267,144 @@ inline void NumpyWhereOpBackward(const nnvm::NodeAttrs& attrs,
});
}

template<typename xpu, bool is_left>
inline void NumpyWhereScalarOpForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);
if (outputs[0].shape_.Size() == 0U) return; // zero-size tensor
CHECK_LE(outputs[0].shape_.ndim(), broadcast::MAX_DIM);

const NumpyWhereScalarParam& param = nnvm::get<NumpyWhereScalarParam>(attrs.parsed);
const TBlob& cond = inputs[0];
const TBlob& y = inputs[1];
const TBlob& out = outputs[0];
Stream<xpu> *s = ctx.get_stream<xpu>();
std::vector<Shape<broadcast::MAX_DIM>> in_strides;
in_strides.resize(2);
for (int i = 0; i < 2; ++i) {
TShape expanded_ishape(broadcast::MAX_DIM, 1);
const TShape& ishape = inputs[i].shape_;
const int ndim_delta = expanded_ishape.ndim() - ishape.ndim();
for (int j = 0; j < ishape.ndim(); ++j) {
expanded_ishape[j + ndim_delta] = ishape[j];
}
in_strides[i] = mxnet_op::calc_stride(expanded_ishape.get<broadcast::MAX_DIM>());
}
TShape expanded_oshape(broadcast::MAX_DIM, 1);
const int ndim_delta = expanded_oshape.ndim() - out.shape_.ndim();
for (int j = 0; j < out.shape_.ndim(); ++j) {
expanded_oshape[j + ndim_delta] = (out.shape_)[j];
}
Shape<broadcast::MAX_DIM> oshape = expanded_oshape.get<broadcast::MAX_DIM>();
MSHADOW_TYPE_SWITCH_WITH_BOOL(out.type_flag_, DType, {
MSHADOW_TYPE_SWITCH_WITH_BOOL(cond.type_flag_, CType, {
mxnet_op::Kernel<numpy_where_scalar_kernel<broadcast::MAX_DIM, is_left>, xpu>::Launch(
s, out.Size(), req[0],
in_strides[0], in_strides[1], oshape,
cond.dptr<CType>(), DType(param.scalar),
y.dptr<DType>(), out.dptr<DType>());
});
});
}

template<typename xpu, bool is_left>
inline void NumpyWhereScalarOpBackward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);
CHECK(common::is_float(inputs[0].type_flag_)) << "Backward only supports float types!";
if (inputs[0].shape_.Size() == 0U) return; // zero-size tensor

Stream<xpu> *s = ctx.get_stream<xpu>();
const TBlob& ograd = inputs[0];
const TBlob& cond = inputs[1];
const TBlob& dx = outputs[0];
// get expanded oshape
TShape expanded_oshape(broadcast::MAX_DIM, 1);
int ndim_delta = expanded_oshape.ndim() - ograd.shape_.ndim();
for (int j = 0; j < ograd.shape_.ndim(); ++j) {
expanded_oshape[j + ndim_delta] = (ograd.shape_)[j];
}
Shape<broadcast::MAX_DIM> oshape = expanded_oshape.get<broadcast::MAX_DIM>();
// get cond stride
TShape expanded_cshape(broadcast::MAX_DIM, 1);
ndim_delta = expanded_cshape.ndim() - cond.shape_.ndim();
for (int j = 0; j < cond.shape_.ndim(); ++j) {
expanded_cshape[j + ndim_delta] = (cond.shape_)[j];
}
Shape<broadcast::MAX_DIM> cstride =
mxnet_op::calc_stride(expanded_cshape.get<broadcast::MAX_DIM>());
// get expanded lshape
TShape expanded_lshape(broadcast::MAX_DIM, 1);
ndim_delta = expanded_lshape.ndim() - dx.shape_.ndim();
for (int j = 0; j < dx.shape_.ndim(); ++j) {
expanded_lshape[j + ndim_delta] = (dx.shape_)[j];
}

MSHADOW_TYPE_SWITCH_WITH_BOOL(ograd.type_flag_, DType, {
MSHADOW_TYPE_SWITCH_WITH_BOOL(cond.type_flag_, CType, {
Tensor<xpu, 1, char> largespace;
Tensor<xpu, broadcast::MAX_DIM, DType> workspace;
size_t ws_size = 0;
if (ograd.shape_ != dx.shape_) {
ws_size = broadcast::ReduceWorkspaceSize<broadcast::MAX_DIM, DType>(
s, expanded_lshape, req[0], expanded_oshape);
}
// process left output
if (ograd.shape_ == dx.shape_) {
mxnet_op::Kernel<numpy_where_backward_kernel<broadcast::MAX_DIM, is_left>, xpu>::Launch(
s, ograd.Size(), req[0], cstride, oshape,
cond.dptr<CType>(), ograd.dptr<DType>(), dx.dptr<DType>());
} else {
largespace = ctx.requested[0].get_space_typed<xpu, 1, char>(
Shape1(ograd.shape_.Size() * sizeof(DType) + ws_size), s);
workspace = Tensor<xpu, broadcast::MAX_DIM, DType>(
reinterpret_cast<DType*>(largespace.dptr_ + ws_size),
expanded_oshape.get<broadcast::MAX_DIM>(), s);
mxnet_op::Kernel<numpy_where_backward_kernel<broadcast::MAX_DIM, true>, xpu>::Launch(
s, ograd.Size(), req[0], cstride, oshape,
cond.dptr<CType>(), ograd.dptr<DType>(), workspace.dptr_);
if (NeedSafeAcc<true>(dx.type_flag_, dx.type_flag_)) {
ReduceAxesComputeImpl<xpu, mshadow_op::sum, true>(ctx, {TBlob(workspace)}, {req[0]},
{dx.reshape(expanded_lshape)}, expanded_lshape);
} else {
ReduceAxesComputeImpl<xpu, mshadow_op::sum, false>(ctx, {TBlob(workspace)}, {req[0]},
{dx.reshape(expanded_lshape)}, expanded_lshape);
}
}
});
});
}

template<typename xpu>
inline void NumpyWhereScalar2OpForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
if (outputs[0].shape_.Size() == 0U) return; // zero-size tensor
using namespace mxnet_op;
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
const NumpyWhereScalar2Param& param = nnvm::get<NumpyWhereScalar2Param>(attrs.parsed);
const TBlob& cond = inputs[0];
const TBlob& out = outputs[0];
MSHADOW_TYPE_SWITCH_WITH_BOOL(out.type_flag_, DType, {
MSHADOW_TYPE_SWITCH_WITH_BOOL(cond.type_flag_, CType, {
Kernel<numpy_where_scalar2_kernel, xpu>::Launch(s, out.Size(), req[0],
out.dptr<DType>(), cond.dptr<CType>(), DType(param.x), DType(param.y));
});
});
}

} // namespace op
} // namespace mxnet

Expand Down
Loading

0 comments on commit 1cd53d8

Please sign in to comment.