Square sum backward support one more case (apache#161)
reminisce authored and eric-haibin-lin committed Aug 10, 2017
1 parent 17bfa4e commit a44afed
Showing 1 changed file with 17 additions and 18 deletions.
35 changes: 17 additions & 18 deletions src/operator/tensor/square_sum-inl.h
@@ -196,30 +196,26 @@ struct SquareSumRspGradKernel<req, 1> {
 };
 
 /*!
- * This kernel assumes that the ograd and in_data
- * are all rsp and have equal row_idx array.
- * TODO(junwu): make the kernel general to support
- * the cases when ograd and in_data have different
- * row_idx arrays.
+ * Note: This kernel assumes that the ograd and in_data
+ * are all rsp and have equal row_idx array, or
+ * in_data is a full rsp.
 */
 template<int req>
 struct SquareSumRspGradKernel<req, 1, kRowSparseStorage> {
   /*!
-   * \param i index of out_grad_row_idx
+   * \param i index of igrad.data()
    * \param in_grad_row_idx row_idx of the gradient of the op's input
    * \param in_grad gradient of the op's input
    * \param out_grad_row_idx row_idx of the gradient of the op's output
    * \param out_grad gradient of the op's output
-   * \param in_row_idx row idx of the op's input
    * \param in_data op's input
    */
   template<typename IType, typename DType>
   MSHADOW_XINLINE static void Map(int i, IType* in_grad_row_idx, DType* in_grad,
                                   const IType* out_grad_row_idx, const DType* out_grad,
-                                  const IType* in_row_idx, const DType* in_data,
-                                  const int64_t num_cols) {
+                                  const DType* in_data, const int64_t num_cols) {
     const int64_t row = i / num_cols;
-    in_grad_row_idx[row] = in_row_idx[row];
+    in_grad_row_idx[row] = out_grad_row_idx[row];
     KERNEL_ASSIGN(in_grad[i], req, 2*in_data[i]*out_grad[row]);
   }
 };
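For orientation: the arithmetic inside Map is the chain rule for a per-row square sum along axis 1. The forward pass computes y[r] = sum_j x[r][j]^2, so the input gradient is dx[r][j] = 2 * x[r][j] * dy[r], which is exactly the KERNEL_ASSIGN line above. Below is a minimal standalone dense sketch of that computation, outside mshadow's Kernel/KERNEL_ASSIGN machinery; the function name and the use of std::vector are illustrative assumptions, not code from this file.

// Dense sketch of the backward arithmetic (illustrative, not MXNet code).
#include <cstdint>
#include <cstdio>
#include <vector>

void square_sum_backward_dense(const std::vector<double>& in_data,
                               const std::vector<double>& out_grad,
                               int64_t num_cols,
                               std::vector<double>* in_grad) {
  for (int64_t i = 0; i < static_cast<int64_t>(in_data.size()); ++i) {
    const int64_t row = i / num_cols;  // same row computation as the kernel
    (*in_grad)[i] = 2 * in_data[i] * out_grad[row];
  }
}

int main() {
  // 2 x 3 input; out_grad holds one value per row of the forward output.
  std::vector<double> x = {1, 2, 3, 4, 5, 6};
  std::vector<double> dy = {10, 20};
  std::vector<double> dx(x.size());
  square_sum_backward_dense(x, dy, 3, &dx);
  for (double v : dx) std::printf("%g ", v);  // prints: 20 40 60 160 200 240
  std::printf("\n");
  return 0;
}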
@@ -341,7 +337,7 @@ void SquareSumRspGradImpl(const nnvm::NodeAttrs& attrs,
   const TBlob& igrad_data = igrad->data();
   const TBlob igrad_row_idx = igrad->aux_data(rowsparse::kIdx);
   const TBlob& ograd_data = ograd.data();
-  const TBlob in_data = input.data();
+  const TBlob& in_data = input.data();
   const TBlob in_row_idx = input.aux_data(rowsparse::kIdx);
   if (ograd.storage_type() == kDefaultStorage) {
     if (0 == param.axis[0]) {  // forward is sum per column
@@ -372,16 +368,20 @@ void SquareSumRspGradImpl(const nnvm::NodeAttrs& attrs,
         " when ograd_stype = kRowSparseStorage";
     CHECK_EQ(ograd.shape().ndim(), 2U);
     const TBlob ograd_row_idx = ograd.aux_data(rowsparse::kIdx);
-    CHECK_EQ(ograd_row_idx.Size(), in_row_idx.Size());
+    CHECK(ograd_row_idx.Size() == in_row_idx.Size() || in_row_idx.Size() == in_data.shape_[0]);
     MSHADOW_IDX_TYPE_SWITCH(igrad_row_idx.type_flag_, IType, {
       if (std::is_same<xpu, cpu>::value) {
         const IType* first1 = ograd_row_idx.dptr<IType>();
         const IType* last1 = first1 + ograd_row_idx.Size();
         const IType* first2 = in_row_idx.dptr<IType>();
-        CHECK(std::equal(first1, last1, first2)) << "SquareSumRspGradImpl only supports"
-                                                    " equal ograd_row_idx and input_row_idx"
-                                                    " when ograd and input are both"
-                                                    " row-sparse";
+        // when ograd_row_idx and in_row_idx have the same size and input is not a full rsp
+        // ograd_row_idx and in_row_idx are expected to have the same elements
+        if (ograd_row_idx.Size() == in_row_idx.Size() && in_row_idx.Size() != in_data.shape_[0]) {
+          CHECK(std::equal(first1, last1, first2)) << "SquareSumRspGradImpl only supports"
+                                                      " equal ograd_row_idx and input_row_idx"
+                                                      " when ograd and input are both"
+                                                      " row-sparse";
+        }
       } else {
         LOG(FATAL) << "SquareSumRspGradImpl has not implemented GPU version when"
                       " ograd and input are both row-sparse";
@@ -391,8 +391,7 @@ void SquareSumRspGradImpl(const nnvm::NodeAttrs& attrs,
           Kernel<SquareSumRspGradKernel<req_type, 1, kRowSparseStorage>, xpu>::Launch(
               s, igrad_data.Size(), igrad_row_idx.dptr<IType>(),
               igrad_data.dptr<DType>(), ograd_row_idx.dptr<IType>(),
-              ograd_data.dptr<DType>(), in_row_idx.dptr<IType>(),
-              in_data.dptr<DType>(), num_cols);
+              ograd_data.dptr<DType>(), in_data.dptr<DType>(), num_cols);
         })
       })
     })
