This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

add op npx.index_update #18545

Merged: 7 commits, Jun 16, 2020
72 changes: 72 additions & 0 deletions python/mxnet/_numpy_op_doc.py
@@ -630,6 +630,7 @@ def _npx_index_add(a, ind, val):
"""
Add values to the input at the given indexes.
If there are repeated positions to be updated, the update values are accumulated.

Parameters
----------
a : ndarray
@@ -643,10 +644,12 @@ def _npx_index_add(a, ind, val):
- ind.dtype should be 'int32' or 'int64'
val : ndarray
Input data. The values to add to the input 'a'.

Returns
-------
out : ndarray
The output array.

Examples
--------
>>> a = np.zeros((2, 3, 4))
@@ -699,6 +702,75 @@ def _npx_index_add(a, ind, val):
pass
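# A minimal NumPy sketch of the accumulation behaviour documented above, assuming each
# row of 'ind' indexes one leading axis of 'a' (illustrative analogy only, not the
# operator's implementation):
#
#   >>> import numpy as np
#   >>> a = np.zeros((2, 3, 4))
#   >>> ind = np.array([[0, 0], [0, 0], [0, 1]])
#   >>> val = np.array([1., 2.])
#   >>> np.add.at(a, tuple(ind), val)   # repeated positions would accumulate here as well
#   >>> a[0, 0, :2]
#   array([1., 2.])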


def _npx_index_update(a, ind, val):
"""
Update the input's values at the given indexes.
If multiple indices refer to the same location, it is undefined which update is chosen;
the order of updates may be arbitrary and nondeterministic (e.g., due to concurrent
updates on some hardware platforms). Avoid using repeated positions.

Parameters
----------
a : ndarray
Input data. The array to be updated.
Support dtype: 'float32', 'float64', 'int32', 'int64'.
ind : ndarray
Indexes indicating the positions to update.
For example, array([[0, 1], [2, 3], [4, 5]]) indicates that there are two positions to
be updated, namely (0, 2, 4) and (1, 3, 5).
Note: - 'ind' cannot be the empty array '[]'; in that case, please use the operator 'add' instead.
- 0 <= ind.ndim <= 2.
- ind.dtype should be 'int32' or 'int64'
val : ndarray
Input data. The values used to update the input 'a'.
Support dtype: 'float32', 'float64', 'int32', 'int64'.

Returns
-------
out : ndarray
The output array.

Examples
--------
>>> a = np.zeros((2, 3, 4))
>>> ind = np.array([[0, 0], [0, 0], [0, 1]], dtype='int32')
>>> val = np.arange(2).reshape(2) + 1
>>> b = npx.index_update(a, ind, val)
>>> b
array([[[1., 2., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]],

[[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]]])

>>> ind = np.array([[0, 0], [0, 1]], dtype='int32')
>>> val = np.arange(8).reshape(2, 4)
>>> b = npx.index_update(a, ind, val)
>>> b
array([[[0., 1., 2., 3.],
[4., 5., 6., 7.],
[0., 0., 0., 0.]],

[[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]]])

>>> val = np.arange(4).reshape(4)  # broadcast 'val'
>>> b = npx.index_update(a, ind, val)
>>> b
array([[[0., 1., 2., 3.],
[0., 1., 2., 3.],
[0., 0., 0., 0.]],

[[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]]])
"""
pass
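# A minimal NumPy sketch of the update behaviour documented above, assuming each row of
# 'ind' indexes one leading axis of 'a' so that plain fancy-index assignment reproduces
# the documented examples (illustrative analogy only, not the operator's implementation):
#
#   >>> import numpy as np
#   >>> a = np.zeros((2, 3, 4))
#   >>> ind = np.array([[0, 0], [0, 1]])
#   >>> val = np.arange(8).reshape(2, 4)
#   >>> out = a.copy()                   # index_update returns a new array; 'a' is unchanged
#   >>> out[tuple(ind)] = val            # out[0, 0, :] = val[0]; out[0, 1, :] = val[1]
#   >>> out[0]
#   array([[0., 1., 2., 3.],
#          [4., 5., 6., 7.],
#          [0., 0., 0., 0.]])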


def _np_diag(array, k=0):
"""
Extracts a diagonal or constructs a diagonal array.
2 changes: 1 addition & 1 deletion src/operator/tensor/index_add-inl.h
@@ -52,7 +52,7 @@ inline bool IndexModifyOpType(const nnvm::NodeAttrs& attrs,
CHECK_NE((*in_attrs)[1], -1);
CHECK_NE((*in_attrs)[2], -1);
CHECK_EQ((*in_attrs)[0], (*in_attrs)[2])
<< "index_add(a, ind, val) only support a.dtype == val.dtype";
<< "index_add/index_update(a, ind, val) only supports a.dtype == val.dtype";
CHECK((*in_attrs)[1] == mshadow::kInt64 ||
(*in_attrs)[1] == mshadow::kInt32)
<< "'ind' only supports int dtype.";
18 changes: 9 additions & 9 deletions src/operator/tensor/index_add_backward.cc
@@ -67,15 +67,15 @@ void IndexAddBackwardValCPUCompute(DType* grad_val,

template<>
void IndexAddOpBackwardValImpl<cpu>(const OpContext& ctx,
const TBlob& grad_val,
const TBlob& ograd,
const TBlob& t_ind,
const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> ograd_tail_shape,
const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> ograd_pre_stride,
const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_stride,
const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_shape,
const int tail_size, const int ind_num, const int ind_ndim,
const int ndim) {
const TBlob& grad_val,
const TBlob& ograd,
const TBlob& t_ind,
const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> ograd_tail_shape,
const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> ograd_pre_stride,
const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_stride,
const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_shape,
const int tail_size, const int ind_num, const int ind_ndim,
const int ndim) {
using namespace mshadow;
using namespace mxnet_op;
int seg = MXNET_SPECIAL_MAX_NDIM - ndim;
264 changes: 264 additions & 0 deletions src/operator/tensor/index_update-inl.h
@@ -0,0 +1,264 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file index_update-inl.h
* \brief Function definition of index_update operator
*/
#ifndef MXNET_OPERATOR_TENSOR_INDEX_UPDATE_INL_H_
#define MXNET_OPERATOR_TENSOR_INDEX_UPDATE_INL_H_

#include <mxnet/operator_util.h>
#include <vector>
#include <algorithm>
#include "./index_add-inl.h"
#include "./sort_op.h"
#include "../mxnet_op.h"
#include "../operator_common.h"
#include "../elemwise_op_common.h"

namespace mxnet {
namespace op {

template<typename xpu, typename DType>
void IndexUpdateForwardCalc(mshadow::Stream<xpu> *s,
const int ind_num, DType* out,
const DType* val,
const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> a_tail_shape,
const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> a_pre_stride,
const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_stride,
const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_shape,
const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> a_shape,
const int a_tail_size,
const int ind_ndim, const int* ind,
const int a_ndim);

template<typename xpu>
void IndexUpdateOpForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mxnet_op;
using namespace mshadow;
CHECK_EQ(inputs.size(), 3U);
CHECK_EQ(outputs.size(), 1U);
Stream<xpu> *s = ctx.get_stream<xpu>();
const TBlob a = inputs[0];
TBlob ind = inputs[1];
TBlob val = inputs[2];
TBlob out = outputs[0];
CHECK_GT(a.shape_.ndim(), 0) << "The first input is a scalar, please use '=' instead.";
int a_ndim = a.shape_.ndim();
CHECK_LE(a_ndim, MXNET_SPECIAL_MAX_NDIM)
<< "ndim should be less than or equal to " << MXNET_SPECIAL_MAX_NDIM
<< " but got " << a_ndim << "\n";
int val_ndim = val.shape_.ndim();
if (val_ndim == 0) {
val.shape_ = Shape1(1);
val_ndim = 1;
}
// ind=np.array([]), ind.shape_.ndim() = 1
// ind=np.array(1), ind.shape_.ndim() = 0
// ind=np.array([[0,0],[0,1]]), ind.shape_.ndim() = 2
CHECK_NE(ind.shape_.Size(), 0) << "Param 'ind' is []. Please just use op 'add' instead.\n";
CHECK_LE(ind.shape_.ndim(), 2) << "'ind' array allows at most 2 dimensions.";
if (ind.shape_.ndim() == 0) {
ind.shape_ = Shape2(1, 1);
} else if (ind.shape_.ndim() == 1) {
ind.shape_ = Shape2(1, ind.shape_[0]);
}
int ind_ndim = ind.shape_[0];
int ind_num = ind.shape_[1];
CHECK_LE(ind_ndim, a_ndim) << "IndexError: too many indices for array.";

// check 'val' broadcast legality
CHECK_LE(val_ndim, a_ndim - ind_ndim + 1)
<< "The ndim of param 'val' is " << val_ndim
<< ", but it should be less than or equal to " << a_ndim - ind_ndim + 1;
for (int i = a_ndim - 1, j = val_ndim - 1; j >= 0 ; --i, --j) {
if ((j == 0) && (val_ndim == a_ndim - ind_ndim + 1)) {
// val_ndim == a_ndim - ind_ndim + 1, check the first dim of input 'val'
CHECK(val.shape_[j] == ind_num || val.shape_[j] == 1)
<< "can not broadcast from " << val.shape_[j] << " to " << ind_num;
} else {
CHECK(val.shape_[j] == a.shape_[i] || val.shape_[j] == 1)
<< "can not broadcast from " << val.shape_[j] << " to " << a.shape_[i]
<< " in axis " << i;
}
}
int a_tail_size = static_cast<int>(a.shape_.ProdShape(ind_ndim, a_ndim));
mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> a_shape, val_shape;
for (int i = MXNET_SPECIAL_MAX_NDIM - 1, j = a_ndim - 1; i >= 0; --i, --j) {
a_shape[i] = (j >= 0) ? a.shape_[j] : 1;
}
mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> a_pre_shape(a_shape);
mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> a_tail_shape(a_shape);

int seg = MXNET_SPECIAL_MAX_NDIM - a_ndim;
for (int i = seg; i < ind_ndim + seg; ++i) {
a_tail_shape[i] = 1;
}
for (int i = ind_ndim + seg; i < a_ndim + seg; ++i) {
a_pre_shape[i] = 1;
}
for (int i = MXNET_SPECIAL_MAX_NDIM - 1, j = val_ndim - 1; i >= 0; --i, --j) {
val_shape[i] = (j >= 0) ? val.shape_[j] : 1;
}
mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> a_pre_stride = calc_stride(a_pre_shape);
mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_stride = calc_stride(val_shape);
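// Forward strategy: copy 'a' into the output unchanged, then let the type-switched
// kernel below overwrite the indexed positions with 'val' (broadcast along the tail axes).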
mxnet_op::copy(s, out, a);
TBlob t_ind = TBlob(ctx.requested[0].get_space_typed<xpu, 1, int>
(Shape1(ind.shape_.Size()), s));
mxnet_op::copy(s, t_ind, ind);
MSHADOW_TYPE_SWITCH(a.type_flag_, DType, {
IndexUpdateForwardCalc<xpu, DType>(s, ind_num,
out.dptr<DType>(), val.dptr<DType>(),
a_tail_shape, a_pre_stride,
val_stride, val_shape, a_shape,
a_tail_size, ind_ndim,
t_ind.dptr<int>(), a_ndim);
});
}

template<typename xpu>
void IndexUpdateOpBackwardValImpl(const OpContext& ctx,
const TBlob& grad_val,
const TBlob& ograd,
const TBlob& t_ind,
const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> ograd_tail_shape,
const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> ograd_pre_stride,
const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_stride,
const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_shape,
const int tail_size, const int ind_num, const int ind_ndim,
const int ndim);

template<typename xpu>
inline void IndexUpdateOpBackwardVal(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mxnet_op;
if (req[0] == kNullOp) {
return;
}
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);
const TBlob& ograd = inputs[0];
TBlob ind = inputs[1];
const TBlob& grad_val = outputs[0];
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
// get the number of 'ind' index
if (ind.shape_.ndim() == 0) {
ind.shape_ = Shape2(1, 1);
} else if (ind.shape_.ndim() == 1) {
ind.shape_ = Shape2(1, ind.shape_[0]);
}
int ind_ndim = ind.shape_[0];
int ind_num = ind.shape_[1];
int out_ndim = ograd.shape_.ndim();
int tail_size = static_cast<int>(ograd.shape_.ProdShape(ind_ndim, out_ndim));
mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> ograd_shape, val_shape;
for (int i = MXNET_SPECIAL_MAX_NDIM - 1, j = out_ndim - 1; i >= 0; --i, --j) {
ograd_shape[i] = (j >= 0) ? ograd.shape_[j] : 1;
}
mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> ograd_pre_shape(ograd_shape);
mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> ograd_tail_shape(ograd_shape);
TBlob t_ind = TBlob(ctx.requested[0].get_space_typed<xpu, 1, int>
(Shape1(ind.shape_.Size()), s));
mxnet_op::copy(s, t_ind, ind);
int seg = MXNET_SPECIAL_MAX_NDIM - out_ndim;
for (int i = seg; i < seg + ind_ndim; ++i) {
ograd_tail_shape[i] = 1;
}
for (int i = seg + ind_ndim; i < seg + out_ndim; ++i) {
ograd_pre_shape[i] = 1;
}
for (int i = seg + out_ndim - 1, j = grad_val.shape_.ndim() - 1; i >= seg; --i, --j) {
val_shape[i] = (j >= 0) ? grad_val.shape_[j] : 1;
}
mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> ograd_pre_stride = mxnet_op::calc_stride(ograd_pre_shape);
mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_stride = mxnet_op::calc_stride(val_shape);
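// Gradient w.r.t. 'val' (conceptually): gather 'ograd' at the updated positions and
// reduce over any axes along which 'val' was broadcast in the forward pass.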
IndexUpdateOpBackwardValImpl<xpu>(ctx, grad_val, ograd, t_ind, ograd_tail_shape, ograd_pre_stride,
val_stride, val_shape, tail_size, ind_num, ind_ndim, out_ndim);
}

template<typename DType>
struct ReqCopy {
MSHADOW_XINLINE static void Map(size_t i, DType* dest, const DType* src, const int req) {
KERNEL_ASSIGN(dest[i], req, src[i]);
}
};

template<typename xpu>
void IndexUpdateOpBackwardAImpl(const OpContext& ctx,
const TBlob& grad_a,
const TBlob& ograd,
const TBlob& t_ind,
const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> grada_pre_stride,
const int tail_size, const int ind_num, const int ind_ndim,
const int seg, const int req);

template<typename xpu>
inline void IndexUpdateOpBackwardA(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mxnet_op;
if (req[0] == kNullOp) {
return;
}
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);
TBlob ograd = inputs[0];
TBlob ind = inputs[1];
const TBlob& grad_a = outputs[0];
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
// get the number of 'ind' index
if (ind.shape_.ndim() == 0) {
ind.shape_ = Shape2(1, 1);
} else if (ind.shape_.ndim() == 1) {
ind.shape_ = Shape2(1, ind.shape_[0]);
}
int ind_ndim = ind.shape_[0];
int ind_num = ind.shape_[1];
int out_ndim = ograd.shape_.ndim();
int tail_size = static_cast<int>(ograd.shape_.ProdShape(ind_ndim, out_ndim));
mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> grada_shape;
for (int i = MXNET_SPECIAL_MAX_NDIM - 1, j = out_ndim - 1; i >= 0; --i, --j) {
grada_shape[i] = (j >= 0) ? grad_a.shape_[j] : 1;
}
mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> grada_pre_shape(grada_shape);
int seg = MXNET_SPECIAL_MAX_NDIM - out_ndim;
for (int i = seg + ind_ndim; i < seg + out_ndim; ++i) {
grada_pre_shape[i] = 1;
}
mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> grada_pre_stride = mxnet_op::calc_stride(grada_pre_shape);
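// Gradient w.r.t. 'a' (conceptually): pass 'ograd' through unchanged, except that the
// updated positions receive zero gradient since their forward values came from 'val'.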
IndexUpdateOpBackwardAImpl<xpu>(ctx, grad_a, ograd, ind, grada_pre_stride,
tail_size, ind_num, ind_ndim, seg, req[0]);
}

} // namespace op
} // namespace mxnet

#endif // MXNET_OPERATOR_TENSOR_INDEX_UPDATE_INL_H_