diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py index 1f7b1d3aed1b..04107128cf4b 100644 --- a/python/mxnet/optimizer.py +++ b/python/mxnet/optimizer.py @@ -4,6 +4,7 @@ import logging from .ndarray import NDArray, zeros, clip, sqrt, sign from .ndarray import sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update +from .sparse_ndarray import zeros as sparse_zeros from .random import normal @@ -332,7 +333,8 @@ def create_state(self, index, weight): if self.momentum == 0.0: return None else: - return zeros(weight.shape, weight.context, dtype=weight.dtype) + return sparse_zeros(weight.storage_type, weight.shape, + weight.context, dtype=weight.dtype) def update(self, index, weight, grad, state): assert(isinstance(weight, NDArray)) diff --git a/python/mxnet/sparse_ndarray.py b/python/mxnet/sparse_ndarray.py index 79351b1eb371..bc06fc1d1113 100644 --- a/python/mxnet/sparse_ndarray.py +++ b/python/mxnet/sparse_ndarray.py @@ -571,7 +571,7 @@ def to_dense(source): """ return ndarray.cast_storage(source, storage_type='default') -def zeros(storage_type, shape, ctx=None, dtype=None, aux_types=None): +def zeros(storage_type, shape, ctx=None, dtype=None, aux_types=None, **kwargs): """Return a new array of given shape and type, filled with zeros. Parameters @@ -599,6 +599,8 @@ def zeros(storage_type, shape, ctx=None, dtype=None, aux_types=None): >>> mx.sparse_nd.zeros('row_sparse', (1,2), mx.gpu(0), 'float16').asnumpy() array([[ 0., 0.]], dtype=float16) """ + if storage_type == 'default': + return ndarray.zeros(shape, ctx, dtype, **kwargs) if ctx is None: ctx = Context.default_ctx dtype = mx_real_t if dtype is None else dtype @@ -609,7 +611,7 @@ def zeros(storage_type, shape, ctx=None, dtype=None, aux_types=None): raise Exception("unknown storage type") assert(len(aux_types) == len(_STORAGE_AUX_TYPES[storage_type])) out = _ndarray_cls(_new_alloc_handle(storage_type, shape, ctx, True, dtype, aux_types)) - return _internal._zeros(shape=shape, ctx=ctx, dtype=dtype, out=out) + return _internal._zeros(shape=shape, ctx=ctx, dtype=dtype, out=out, **kwargs) def _ndarray_cls(handle, writable=True): stype = _storage_type(handle) diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h index ecfb9c76acb3..a6d78c2558be 100755 --- a/src/operator/operator_common.h +++ b/src/operator/operator_common.h @@ -110,6 +110,19 @@ inline std::string type_string(const int& x) { return "unknown"; } +/*! \brief get string representation of storage_type */ +inline std::string stype_string(const int& x) { + switch (x) { + case kDefaultStorage: + return "default"; + case kCSRStorage: + return "csr"; + case kRowSparseStorage: + return "row_sparse"; + } + return "unknown"; +} + /*! * \brief Assign x to y. Checks for compatiblity when y is not empty. * Allow missing dim in both x and y (as 0). @@ -186,6 +199,24 @@ inline bool type_assign(int *y, const int& x) { } \ } +/*! 
+ * \brief macro assign type to out if out is unknown (-1) otherwise check consistency + * Use macro so we can see the error file more clearly + * \param type_array the storage type array to store the result + * \param index the index of in the array + * \param type the inferred storage type + */ +#define STORAGE_TYPE_ASSIGN_CHECK(type_array, index, type) \ + { \ + if (!type_assign(&(type_array)[index], type)) { \ + std::ostringstream os; \ + os << "Storage type inconsistent, Provided=" \ + << stype_string((type_array)[index]) << ',' \ + << " inferred storage type=" << stype_string(type); \ + throw ::mxnet::op::InferTypeError(os.str(), index); \ + } \ + } + // helper macro to implement bind dispatch #if MXNET_USE_CUDA #define DO_BIND_DISPATCH(Method, ...) \ diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h index 83a4a9cfccbb..d6d8ccc37c53 100755 --- a/src/operator/optimizer_op-inl.h +++ b/src/operator/optimizer_op-inl.h @@ -112,32 +112,31 @@ struct SGDDnsRspKernel { template inline void SGDUpdateDnsRspImpl(const SGDParam& param, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { + const OpContext &ctx, + const TBlob& weight, + const NDArray& grad, + const OpReqType& req, + TBlob *out) { using namespace mshadow; using namespace mshadow::expr; using namespace mshadow_op; + using namespace mxnet_op; Stream* s = ctx.get_stream(); - auto &weight = inputs[0]; - auto &grad = inputs[1]; - auto &out = outputs[0]; - CHECK_EQ(weight.storage_type(), kDefaultStorage); CHECK_EQ(grad.storage_type(), kRowSparseStorage); - if (!grad.storage_initialized()) return; + // if gradients are zeros, no weights are updated + if (!grad.storage_initialized() || req == kNullOp) return; + CHECK_GT(weight.shape_.Size(), 0); - MSHADOW_REAL_TYPE_SWITCH(weight.dtype(), DType, { + MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, { MSHADOW_INT_TYPE_SWITCH(grad.aux_type(rowsparse::kIdx), IType, { - MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { - auto weight_data = weight.data().FlatTo2D(s); - auto grad_idx = grad.aux_data(rowsparse::kIdx).FlatTo1D(s); - auto grad_val = grad.data().FlatTo2D(s); - auto out_data = out.data().FlatTo2D(s); + MXNET_ASSIGN_REQ_SWITCH(req, req_type, { + auto weight_data = weight.dptr(); + auto grad_idx = grad.aux_data(rowsparse::kIdx).dptr(); + auto grad_val = grad.data().dptr(); auto num_rows = grad.aux_shape(rowsparse::kIdx)[0]; - auto width = weight.shape().ProdShape(1, weight.shape().ndim()); - mxnet_op::Kernel, xpu>::Launch(s, num_rows, width, - out_data.dptr_, weight_data.dptr_, grad_idx.dptr_, grad_val.dptr_, + auto width = weight.shape_.ProdShape(1, weight.ndim()); + Kernel, xpu>::Launch(s, num_rows, width, + out->dptr(), weight_data, grad_idx, grad_val, static_cast(param.clip_gradient), static_cast(param.lr), static_cast(param.wd), static_cast(param.rescale_grad)); @@ -146,6 +145,29 @@ inline void SGDUpdateDnsRspImpl(const SGDParam& param, }); } +template +inline void SGDUpdateRspRspImpl(const SGDParam& param, + const OpContext& ctx, + const NDArray& weight, + const NDArray& grad, + const OpReqType& req, + NDArray *out) { + if (weight.storage_shape()[0] == weight.shape()[0] && + out->storage_shape()[0] == out->shape()[0]) { + // TODO(haibin) this is a temporary solution, due to the fact that imperative_invoke only + // feed in kWriteTo as req for all operators. 
+ // For sgd we don't want to assign zeros to the output values when req == kWriteTo + auto out_req = req; + if (out_req == kWriteTo) out_req = kWriteInplace; + // reuse dns rsp implementation when storage_shape == shape + TBlob out_blob = out->data(); + SGDUpdateDnsRspImpl(param, ctx, weight.data(), grad, out_req, &out_blob); + } else { + LOG(FATAL) << "SGDUpdate for RowSparse weights is only implemented when " + << "weights.values.shape == weights.shape"; + } +} + template inline void SGDUpdateEx(const nnvm::NodeAttrs& attrs, const OpContext &ctx, @@ -159,7 +181,11 @@ inline void SGDUpdateEx(const nnvm::NodeAttrs& attrs, auto weight_stype = inputs[0].storage_type(); auto grad_stype = inputs[1].storage_type(); if (weight_stype == kDefaultStorage && grad_stype == kRowSparseStorage) { - SGDUpdateDnsRspImpl(param, ctx, inputs, req, outputs); + TBlob out = outputs[0].data(); + SGDUpdateDnsRspImpl(param, ctx, inputs[0].data(), inputs[1], req[0], &out); + } else if (weight_stype == kRowSparseStorage && grad_stype == kRowSparseStorage) { + NDArray out = outputs[0]; + SGDUpdateRspRspImpl(param, ctx, inputs[0], inputs[1], req[0], &out); } else if (weight_stype == kDefaultStorage && grad_stype == kDefaultStorage) { FCompExFallback(attrs, ctx, inputs, req, outputs, SGDUpdate, "SGDUpdate"); } @@ -262,30 +288,31 @@ struct SGDMomDnsRspDnsKernel { template inline void SGDMomUpdateDnsRspDnsImpl(const SGDMomParam& param, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { + const OpContext& ctx, + const TBlob& weight, + const NDArray& grad, + const TBlob& mom, + const OpReqType& req, + TBlob *out) { using namespace mxnet_op; + using namespace rowsparse; Stream* s = ctx.get_stream(); - auto &weight = inputs[0]; - auto &grad = inputs[1]; - auto &mom = inputs[2]; - auto &out = outputs[0]; - if (!grad.storage_initialized()) return; + if (!grad.storage_initialized() || req == kNullOp) return; + CHECK_GT(weight.shape_.Size(), 0); + CHECK_GT(mom.shape_.Size(), 0); - MSHADOW_REAL_TYPE_SWITCH(weight.dtype(), DType, { - MSHADOW_INT_TYPE_SWITCH(grad.aux_type(rowsparse::kIdx), IType, { - MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { - auto weight_data = weight.data().FlatTo2D(s); - auto grad_idx = grad.aux_data(rowsparse::kIdx).FlatTo1D(s); - auto grad_val = grad.data().FlatTo2D(s); - auto mom_data = mom.data().FlatTo2D(s); - auto out_data = out.data().FlatTo2D(s); - auto num_rows = grad.aux_shape(rowsparse::kIdx)[0]; - auto width = weight.shape().ProdShape(1, weight.shape().ndim()); + MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, { + MSHADOW_INT_TYPE_SWITCH(grad.aux_type(kIdx), IType, { + MXNET_ASSIGN_REQ_SWITCH(req, req_type, { + auto weight_data = weight.dptr(); + auto grad_idx = grad.aux_data(kIdx).dptr(); + auto grad_val = grad.data().dptr(); + auto mom_data = mom.dptr(); + auto out_data = out->dptr(); + auto num_rows = grad.aux_shape(kIdx)[0]; + auto width = weight.shape_.ProdShape(1, weight.ndim()); Kernel, xpu>::Launch(s, num_rows, width, - out_data.dptr_, mom_data.dptr_, weight_data.dptr_, grad_idx.dptr_, grad_val.dptr_, + out_data, mom_data, weight_data, grad_idx, grad_val, static_cast(param.clip_gradient), static_cast(param.momentum), static_cast(param.lr), static_cast(param.wd), static_cast(param.rescale_grad)); @@ -294,6 +321,50 @@ inline void SGDMomUpdateDnsRspDnsImpl(const SGDMomParam& param, }); } +template +inline void SGDMomUpdateRspRspRspImpl(const SGDMomParam& param, + const OpContext& ctx, + const NDArray& weight, + const NDArray& 
grad, + const NDArray& mom, + const OpReqType& req, + NDArray *out) { + using namespace mshadow; + using namespace mshadow::expr; + using namespace mxnet_op; + using namespace rowsparse; + if (weight.storage_shape()[0] == weight.shape()[0] && + out->storage_shape()[0] == out->shape()[0]) { + Stream* s = ctx.get_stream(); + // fill mom with zero values in order to reuse the sgd mom dns impl + if (!mom.storage_initialized()) { + MSHADOW_REAL_TYPE_SWITCH(mom.dtype(), DType, { + MSHADOW_INT_TYPE_SWITCH(mom.aux_type(kIdx), IType, { + auto num_rows = mom.shape()[0]; + mom.CheckAndAlloc({Shape1(num_rows)}); + auto mom_idx = mom.aux_data(kIdx).FlatTo1D(s); + auto mom_val = mom.data(); + // TODO(haibin) this is single-thread execution + Kernel::Launch(s, mom_val.Size(), mom_val.dptr()); + ASSIGN_DISPATCH(mom_idx, kWriteTo, range(0, num_rows, 1, 1)) + }); + }); + } + // TODO(haibin) this is a temporary solution, due to the fact that imperative_invoke only + // feed in kWriteTo as req for all operators. + // For sgd we don't want to assign zeros to the output values when req == kWriteTo + auto out_req = req; + if (out_req == kWriteTo) out_req = kWriteInplace; + TBlob out_blob = out->data(); + // reuse dns rsp implementation when storage_shape == shape + SGDMomUpdateDnsRspDnsImpl(param, ctx, weight.data(), grad, + mom.data(), out_req, &out_blob); + } else { + LOG(FATAL) << "SGDUpdate for RowSparse weights is only implemented when " + << "weights.values.shape == weights.shape"; + } +} + template inline void SGDMomUpdateEx(const nnvm::NodeAttrs& attrs, const OpContext &ctx, @@ -305,10 +376,16 @@ inline void SGDMomUpdateEx(const nnvm::NodeAttrs& attrs, auto weight_stype = inputs[0].storage_type(); auto grad_stype = inputs[1].storage_type(); auto mom_stype = inputs[2].storage_type(); - if (weight_stype == kDefaultStorage && grad_stype == kRowSparseStorage && mom_stype == kDefaultStorage) { - SGDMomUpdateDnsRspDnsImpl(param, ctx, inputs, req, outputs); + TBlob out = outputs[0].data(); + SGDMomUpdateDnsRspDnsImpl(param, ctx, inputs[0].data(), inputs[1], + inputs[2].data(), req[0], &out); + } else if (weight_stype == kRowSparseStorage && grad_stype == kRowSparseStorage && + mom_stype == kRowSparseStorage) { + NDArray out = outputs[0]; + SGDMomUpdateRspRspRspImpl(param, ctx, inputs[0], inputs[1], + inputs[2], req[0], &out); } else if (weight_stype == kDefaultStorage && grad_stype == kDefaultStorage && mom_stype == kDefaultStorage) { FCompExFallback(attrs, ctx, inputs, req, outputs, diff --git a/src/operator/tensor/elemwise_unary_op.h b/src/operator/tensor/elemwise_unary_op.h index 996a25d5a647..64b7c34359b9 100644 --- a/src/operator/tensor/elemwise_unary_op.h +++ b/src/operator/tensor/elemwise_unary_op.h @@ -324,10 +324,8 @@ inline void CastStorageDnsRspImpl(mshadow::Stream* s, const TBlob& dns, NDA struct CastStorageRspDnsKernel { template MSHADOW_XINLINE static void Map(int i, const index_t width, const IType* idx, const DType *data, - DType* dns, const index_t invalid_rid) { + DType* dns) { auto rid = idx[i]; - // skip invalid rows - if (rid == invalid_rid) return; auto dns_offset = rid * width; auto rsp_offset = i * width; for (size_t col = 0; col < width; col++) { @@ -356,10 +354,9 @@ void CastStorageRspDnsImpl(mshadow::Stream* s, const NDArray& rsp, TBlob* d auto out_data = dns->FlatTo2D(s).dptr_; auto num_rows = rsp.aux_shape(rowsparse::kIdx).Size(); auto rsp_shape = rsp.shape(); - auto invalid_rid = rsp_shape[0]; auto width = rsp_shape.ProdShape(1, rsp_shape.ndim()); - mxnet_op::Kernel::Launch(s, 
num_rows, width, in_idx, in_data, - out_data, invalid_rid); + mxnet_op::Kernel::Launch(s, num_rows, width, in_idx, + in_data, out_data); } }); }); diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc index 8cf00c0eb7b4..da20cf49f1a0 100644 --- a/src/operator/tensor/indexing_op.cc +++ b/src/operator/tensor/indexing_op.cc @@ -87,8 +87,17 @@ NNVM_REGISTER_OP(_backward_Embedding) .set_attr("FCompute", EmbeddingOpBackward); NNVM_REGISTER_OP(SparseEmbedding) -.describe(R"code(Maps integer indices to vector representations (embeddings) with sparse weight update -)code" ADD_FILELINE) +.describe(R"doc(Represents words or other sparse inputs by dense continuous vectors. +It assumes that the input is in one-hot form. E.g., for a vocabulary size of 10,000, + each input vector is expected to have dimension 10,000. +The index of the non-zero entry is the index of the word or item it represents. + +The corresponding embedding vectors are stored as rows of a matrix. +Hence, mapping an input word to its embedding is implemented as a matrix product. + +The gradient of an embedding matrix has the form of gradient vectors that are only + non-zero for words seen in a minibatch. +)doc" ADD_FILELINE) .set_num_inputs(2) .set_num_outputs(1) .set_attr_parser(ParamParser) @@ -96,19 +105,21 @@ NNVM_REGISTER_OP(SparseEmbedding) [](const NodeAttrs& attrs) { return std::vector{"data", "weight"}; }) -.set_attr("FInferShape", EmbeddingOpShape) +.set_attr("FInferShape", SparseEmbeddingShape) .set_attr("FInferType", EmbeddingOpType) +.set_attr("FInferStorageType", SparseEmbeddingForwardStorageType) .set_attr("FResourceRequest", [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) -.set_attr("FCompute", EmbeddingOpForward) +.set_attr(FCOMP_EX_CPU, SparseEmbeddingForwardEx) .set_attr("FGradient", [](const nnvm::NodePtr& n, const std::vector& ograds) { return MakeNonlossGradNode("_backward_SparseEmbedding", n, ograds, {n->inputs[0]}, n->attrs.dict); }) -.add_argument("data", "NDArray-or-Symbol", "The input array to the embedding operator.") +.add_argument("data", "NDArray-or-Symbol", + "The input array to the sparse embedding operator.") .add_argument("weight", "NDArray-or-Symbol", "The embedding weight matrix.") .add_arguments(EmbeddingParam::__FIELDS__()); @@ -116,10 +127,7 @@ NNVM_REGISTER_OP(_backward_SparseEmbedding) .set_num_inputs(2) .set_num_outputs(2) .set_attr("TIsBackward", true) -.set_attr("FInferStorageType", SparseEmbeddingBackwardStorageType) -.set_attr("FComputeEx", SparseEmbeddingOpBackwardEx); -// TODO(haibin) handle dense case -// .set_attr("FCompute", EmbeddingOpBackward); +.set_attr("FComputeEx", SparseEmbeddingBackwardEx); NNVM_REGISTER_OP(take) .describe(R"code(Takes elements from an input array along the given axis. 
diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h index 81b219f7c2c9..7387b7dc79f1 100644 --- a/src/operator/tensor/indexing_op.h +++ b/src/operator/tensor/indexing_op.h @@ -9,7 +9,6 @@ #include #include -#include #include #include #include @@ -23,6 +22,7 @@ #include "../elemwise_op_common.h" #include "../mxnet_op.h" #include "./sort_op.h" +#include "./matrix_op-inl.h" namespace mxnet { namespace op { @@ -204,6 +204,82 @@ void EmbeddingOpForward(const nnvm::NodeAttrs& attrs, }); } +template +void SparseEmbeddingForwardRspImpl(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const NDArray& data, + const NDArray& weight, + const OpReqType req, + NDArray *out) { + if (weight.storage_shape()[0] == weight.shape()[0]) { + TBlob out_blob = out->data(); + // forward to dns implementation when storage_shape equals shape + bool transpose_a = false; + DotCsrRspDnsImpl(ctx, data, weight, req, transpose_a, &out_blob); + } else { + LOG(FATAL) << "SparseEmbedding for RowSparse weights is only implemented when " + << "weights.values.shape == weights.shape"; + } +} + +template +void SparseEmbeddingForwardEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(req[embedding::kOut], kWriteTo); + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + + NDArray output = outputs[embedding::kOut]; + auto data_stype = inputs[embedding::kData].storage_type(); + auto weight_stype = inputs[embedding::kWeight].storage_type(); + auto out_stype = outputs[embedding::kOut].storage_type(); + if (data_stype == kCSRStorage && weight_stype == kRowSparseStorage && + out_stype == kDefaultStorage) { + NDArray ret = outputs[embedding::kOut]; + SparseEmbeddingForwardRspImpl(attrs, ctx, inputs[embedding::kData], + inputs[embedding::kWeight], + req[embedding::kOut], &ret); + } else { + LOG(FATAL) << "Not supported SparseEmbedding operation for data.storage_type = " + << data_stype << ", weight.storage_type = " << weight_stype + << ", out.storage_type = " << out_stype; + } +} + +inline bool SparseEmbeddingForwardStorageType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + STORAGE_TYPE_ASSIGN_CHECK(*in_attrs, embedding::kData, kCSRStorage); + STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, embedding::kOut, kDefaultStorage); + // override the default storage type generated in nnvm + in_attrs->at(embedding::kWeight) = kRowSparseStorage; + return true; +} + +inline bool SparseEmbeddingShape(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + using namespace mshadow; + const EmbeddingParam& param = nnvm::get(attrs.parsed); + const TShape &dshape = (*in_attrs)[embedding::kData]; + CHECK_EQ(dshape.ndim(), 2) + << "SparseEmbedding shape error: data is expected to be 2D."; + SHAPE_ASSIGN_CHECK(*in_attrs, embedding::kWeight, + Shape2(param.input_dim, param.output_dim)); + out_attrs->clear(); + std::vector buf(2); + buf[0] = dshape[0]; + buf[1] = param.output_dim; + out_attrs->emplace_back(buf.begin(), buf.end()); + return true; +} + // Returns integer log2(a) rounded up inline int ilog2(unsigned int a) { int k = 1; @@ -316,130 +392,28 @@ void EmbeddingOpBackward(const nnvm::NodeAttrs& attrs, }); } -template -struct EmbeddingBackwardRsp { - template - // each thread i is responsible for target gradient row ids in [segment_start, segment_end) - 
MSHADOW_XINLINE static void Map(int i, const size_t width, IType* dst_idx, DType* dst_val, - const IType* idx, const size_t num_idx, const DType* src, - const size_t segment_len, const size_t num_rows) { - auto req_type = req; - size_t segment_start = i * segment_len; - size_t segment_end = (i + 1) * segment_len; - for (size_t y = 0; y < num_idx; y++) { - size_t j = idx[y]; - if (j >= num_rows) j = num_rows - 1; - if (j < segment_start || j >= segment_end) continue; - dst_idx[j] = j; - for (size_t k = 0; k < width; k++) { - if (req_type == kWriteTo) req_type = kAddTo; - KERNEL_ASSIGN(dst_val[j * width + k], req_type, src[y * width + k]); - } - } - } -}; - -/* - * for sparse embedding, the storage type for weight gradient is row_sparse. - * we don't care about the storage type for data gradient, since it is not - * differentiable. - */ -inline bool SparseEmbeddingBackwardStorageType(const nnvm::NodeAttrs& attrs, - std::vector *in_attrs, - std::vector *out_attrs) { - CHECK_EQ((*in_attrs)[0], kDefaultStorage); - CHECK_EQ((*in_attrs)[1], kDefaultStorage); - (*out_attrs)[0] = kRowSparseStorage; - (*out_attrs)[1] = kRowSparseStorage; - return true; -} - template -void SparseEmbeddingOpBackwardDnsDnsRsp(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - using namespace mshadow; - using namespace mxnet_op; - using namespace mshadow::expr; +void SparseEmbeddingBackwardEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { CHECK_EQ(inputs.size(), 2U); CHECK_EQ(outputs.size(), 2U); - if (req[1] == kNullOp) return; - // check storage types - auto idx = inputs[1]; // idx shape (d1, d2 .. dk) - auto grad = inputs[0]; // grad shape (d1, d2, .. 
dk, out_dim) - auto output = outputs[1]; // weight shape (in_dim, out_dim) - CHECK_EQ(idx.storage_type(), kDefaultStorage); - CHECK_EQ(grad.storage_type(), kDefaultStorage); - CHECK_EQ(output.dtype(), grad.dtype()); - CHECK_EQ(idx.dtype(), output.aux_type(rowsparse::kIdx)) << "Index type doesn't match"; + CHECK_EQ(req.size(), 2U); // CHECK_EQ(req[embedding::kData], kNullOp) - // << "Embedding layer doesn't support calculate data gradient" << req[embedding::kData]; + // << "Embedding layer doesn't support calculate data gradient" << req[0] << " " << req[1]; + // CHECK_NE(req[1], kWriteInplace) << "DotBackwardEx does not support WriteInplace"; - const TShape& ishape = idx.shape(); - const TShape& oshape = grad.shape(); - - Stream *s = ctx.get_stream(); - CHECK_EQ(idx.dtype(), output.aux_type(rowsparse::kIdx)) - << "embedding input index and gradient row sparse type doesn't match!"; - // Alloc dense output - unsigned int num_rows = output.shape()[0]; - output.CheckAndAlloc({mshadow::Shape1(num_rows)}); - MSHADOW_TYPE_SWITCH(output.dtype(), DType, { - MSHADOW_INT_TYPE_SWITCH(idx.dtype(), IType, { - MXNET_ASSIGN_REQ_SWITCH(req[1], req_type, { - // input embedding indice, each idx in [0, input_dim) - auto idx_data = idx.data().FlatTo1D(s); - auto grad_data = grad.data().get_with_shape( - Shape2(oshape.ProdShape(0, oshape.ndim()-1), oshape[oshape.ndim()-1]), s); - auto output_idx = output.aux_data(rowsparse::kIdx).FlatTo1D(s); - auto output_val = output.data().FlatTo2D(s); - int num_threads = omp_get_num_threads(); - size_t width = output.shape()[1]; - size_t segment_len = (num_rows + num_threads - 1) / num_threads; - // fill indices with invalid row ids - Kernel::Launch(s, num_rows, output_idx.dptr_, - static_cast(num_rows)); - // fill zeros if needed - if (req_type == kWriteTo) { - Kernel::Launch(s, output_val.shape_.Size(), output_val.dptr_); - } - Kernel, xpu>::Launch(s, num_threads, width, - output_idx.dptr_, - output_val.dptr_, idx_data.dptr_, - ishape.Size(), grad_data.dptr_, - segment_len, num_rows); - }); - }); - }); -} - -// todo replace xpu with cpu -template -void SparseEmbeddingOpBackwardEx(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - using namespace mshadow; - using namespace mxnet_op; - using namespace mshadow::expr; - CHECK_EQ(inputs.size(), 2U); - CHECK_EQ(outputs.size(), 2U); - // CHECK_EQ(req[embedding::kData], kNullOp) - // << "Embedding layer doesn't support calculate data gradient" << req[0] << " " << req[1]; - // idx shape (d1, d2 .. dk) - auto idx_stype = inputs[1].storage_type(); - // grad shape (d1, d2, .. 
dk, out_dim) + auto data_stype = inputs[1].storage_type(); auto grad_stype = inputs[0].storage_type(); - // weight shape (in_dim, out_dim) auto output_stype = outputs[1].storage_type(); - if (idx_stype == kDefaultStorage && grad_stype == kDefaultStorage && - output_stype == kRowSparseStorage) { - SparseEmbeddingOpBackwardDnsDnsRsp(attrs, ctx, inputs, req, outputs); + if (data_stype == kCSRStorage && grad_stype == kDefaultStorage && + output_stype == kDefaultStorage) { + TBlob ret = outputs[1].data(); + DotCsrDnsDnsImpl(ctx, inputs[1], inputs[0].data(), req[1], true, &ret); } else { - LOG(FATAL) << "Not implemented"; + LOG(FATAL) << "Not supported dot backward for sparse input(s) with sparse gradients"; } } diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 05fba76d0ff3..f01d6428b0d4 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -643,22 +643,20 @@ struct DotCsrTransDnsDnsByRowBlocks { template void DotCsrDnsDnsImpl(const OpContext& ctx, const NDArray& lhs, - const NDArray& rhs, + const TBlob& rhs, const OpReqType req, const bool trans_lhs, - NDArray* ret) { + TBlob* ret) { if (kNullOp == req) return; CHECK_EQ(lhs.storage_type(), kCSRStorage); - CHECK_EQ(rhs.storage_type(), kDefaultStorage); - CHECK_EQ(ret->storage_type(), kDefaultStorage); if (!lhs.storage_initialized()) return; mshadow::Stream *s = ctx.get_stream(); const TBlob data_l = lhs.data(); const TBlob indptr_l = lhs.aux_data(csr::kIndPtr); const TBlob col_idx_l = lhs.aux_data(csr::kIdx); - const TBlob data_r = rhs.data(); - const TBlob data_out = ret->data(); + const TBlob& data_r = rhs; + const TBlob data_out = *ret; MSHADOW_TYPE_SWITCH(data_l.type_flag_, DType, { // data type MSHADOW_INT_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type @@ -693,7 +691,7 @@ void DotCsrDnsDnsImpl(const OpContext& ctx, MXNET_ASSIGN_REQ_SWITCH(req, ReqType, { mxnet_op::Kernel, xpu>::Launch(s, data_out.Size(), data_out.dptr(), data_l.dptr(), indptr_l.dptr(), - col_idx_l.dptr(), data_r.dptr(), rhs.shape()[1]); + col_idx_l.dptr(), data_r.dptr(), rhs.shape_[1]); }); } } @@ -702,6 +700,21 @@ void DotCsrDnsDnsImpl(const OpContext& ctx, }); } +template +void DotCsrRspDnsImpl(const OpContext& ctx, + const NDArray& lhs, + const NDArray& rhs, + const OpReqType req, + const bool trans_lhs, + TBlob* ret) { + if (rhs.storage_shape()[0] == rhs.shape()[0]) { + // reuse csr dns implementation when storage_shape == shape for rhs + DotCsrDnsDnsImpl(ctx, lhs, rhs.data(), req, trans_lhs, ret); + } else { + LOG(FATAL) << "Dot for RowSparse rhs is only implemented for rhs.values.shape == rhs.shape"; + } +} + template void DotBackwardCsrDnsDns(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -709,8 +722,25 @@ void DotBackwardCsrDnsDns(const nnvm::NodeAttrs& attrs, const std::vector& req, const std::vector& outputs) { const DotParam& param = nnvm::get(attrs.parsed); - NDArray ret = outputs[1]; - DotCsrDnsDnsImpl(ctx, inputs[1], inputs[0], req[1], !param.transpose_a, &ret); + TBlob ret = outputs[1].data(); + DotCsrDnsDnsImpl(ctx, inputs[1], inputs[0].data(), req[1], !param.transpose_a, &ret); +} + +template +void DotBackwardCsrRspDns(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + const auto& rhs = inputs[2]; + if (rhs.storage_shape()[0] == rhs.shape()[0]) { + // reuse csr dns implementation when storage_shape == shape for rhs + const DotParam& param = nnvm::get(attrs.parsed); + 
TBlob ret = outputs[1].data(); + DotCsrDnsDnsImpl(ctx, inputs[1], inputs[0].data(), req[1], !param.transpose_a, &ret); + } else { + LOG(FATAL) << "Dot for RowSparse rhs is only implemented for rhs.values.shape == rhs.shape"; + } } inline bool DotShape(const nnvm::NodeAttrs& attrs, @@ -767,12 +797,16 @@ void DotForwardEx(const nnvm::NodeAttrs& attrs, CHECK_EQ(req.size(), 1U); const DotParam& param = nnvm::get(attrs.parsed); CHECK(!param.transpose_b) << "tranposing rhs of the op dot is not supported"; - - NDArray ret = outputs[0]; // get rid of the const qualifier - if (inputs[0].storage_type() == kCSRStorage - && inputs[1].storage_type() == kDefaultStorage - && outputs[0].storage_type() == kDefaultStorage) { - DotCsrDnsDnsImpl(ctx, inputs[0], inputs[1], req[0], param.transpose_a, &ret); + auto lhs_stype = inputs[0].storage_type(); + auto rhs_stype = inputs[1].storage_type(); + auto out_stype = outputs[0].storage_type(); + if (lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage && out_stype == kDefaultStorage) { + TBlob ret = outputs[0].data(); + DotCsrDnsDnsImpl(ctx, inputs[0], inputs[1].data(), req[0], param.transpose_a, &ret); + } else if (lhs_stype == kCSRStorage && rhs_stype == kRowSparseStorage && + out_stype == kDefaultStorage) { + TBlob ret = outputs[0].data(); + DotCsrRspDnsImpl(ctx, inputs[0], inputs[1], req[0], param.transpose_a, &ret); } else { // TODO(junwu): add fallback LOG(FATAL) << "Not supported dot operation for lhs.storage_type = " << inputs[0].storage_type() << ", rhs.storage_type = " << inputs[1].storage_type() @@ -796,12 +830,19 @@ void DotBackwardEx(const nnvm::NodeAttrs& attrs, // TODO(junwu): check whether this CHECK is reasonable const DotParam& param = nnvm::get(attrs.parsed); CHECK(!param.transpose_b) << "sparse dot only supports dot(A, X) and dot(A.T(), X)"; - if (inputs[0].storage_type() == kDefaultStorage // ograd dns format - // dns, csr, dns => *, dns - && inputs[1].storage_type() == kCSRStorage // csr input lhs of the op - && inputs[2].storage_type() == kDefaultStorage // dns input rhs of the op + auto ograd_stype = inputs[0].storage_type(); + auto lhs_stype = inputs[1].storage_type(); + auto rhs_stype = inputs[2].storage_type(); + if (ograd_stype == kDefaultStorage // ograd dns format + && lhs_stype == kCSRStorage // csr input lhs of the op + && rhs_stype == kDefaultStorage // dns input rhs of the op && outputs[1].storage_type() == kDefaultStorage) { // grad(rhs) dns format + // dns, csr, dns => *, dns DotBackwardCsrDnsDns(attrs, ctx, inputs, req, outputs); + } else if (ograd_stype == kDefaultStorage && lhs_stype == kCSRStorage && + rhs_stype == kRowSparseStorage && outputs[1].storage_type() == kDefaultStorage) { + // dns, csr, rsp => *, dns + DotBackwardCsrRspDns(attrs, ctx, inputs, req, outputs); } else { LOG(FATAL) << "Not supported dot backward for sparse input(s) with sparse gradients"; } diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py index 6f69828ed9b1..80632c262a8e 100644 --- a/tests/python/unittest/test_optimizer.py +++ b/tests/python/unittest/test_optimizer.py @@ -35,8 +35,8 @@ def compare_optimizer(opt1, opt2, shape, w_stype='default', g_stype='default'): w2 = mx.random.uniform(shape=shape, ctx=default_context()) w1 = w2.copyto(default_context()) elif w_stype == 'row_sparse': - w2 = rand_ndarray(shape, w_stype) - w1 = rand_ndarray(shape, w_stype).to_dense() + w2 = rand_ndarray(shape, w_stype, density=1) + w1 = w2.copyto(default_context()).to_dense() else: raise Exception("type not supported 
yet") if g_stype == 'default': @@ -51,14 +51,20 @@ def compare_optimizer(opt1, opt2, shape, w_stype='default', g_stype='default'): state1 = opt1.create_state(0, w1) state2 = opt2.create_state(0, w2) if state1 is not None and state2 is not None: - for s1, s2, in zip(state1, state2): - assert(same(s1.asnumpy(), s2.asnumpy())) + if isinstance(state1, tuple): + for s1, s2, in zip(state1, state2): + assert(same(s1.asnumpy(), s2.asnumpy())) + else: + assert_almost_equal(state1.asnumpy(), state2.asnumpy()) opt1.update(0, w1, g1, state1) opt2.update(0, w2, g2, state2) if state1 is not None and state2 is not None: - for s1, s2, in zip(state1, state2): - assert_almost_equal(s1.asnumpy(), s2.asnumpy(), rtol=1e-4, atol=1e-5) + if isinstance(state1, tuple): + for s1, s2, in zip(state1, state2): + assert_almost_equal(s1.asnumpy(), s2.asnumpy(), rtol=1e-4, atol=1e-5) + else: + assert_almost_equal(state1.asnumpy(), state2.asnumpy()) assert_almost_equal(w1.asnumpy(), w2.asnumpy(), rtol=1e-4, atol=1e-5) # SGD @@ -230,7 +236,7 @@ def test_sparse_sgd(): {'clip_gradient': 0.4, 'rescale_grad': 0.14, 'wd': 0.03, 'momentum': 0.9}, {'rescale_grad': 0.8, 'wd': 0.05, 'momentum': 0.9}] for kwarg in kwargs: - compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, w_stype='default', g_stype='row_sparse') + compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, w_stype='row_sparse', g_stype='row_sparse') # ADAM diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py index fc27b80f4530..d46a5f7c81a2 100644 --- a/tests/python/unittest/test_sparse_ndarray.py +++ b/tests/python/unittest/test_sparse_ndarray.py @@ -62,6 +62,7 @@ def check_sparse_nd_zeros(stype, shape): shape = rand_shape_2d() check_sparse_nd_zeros('row_sparse', shape) check_sparse_nd_zeros('csr', shape) + check_sparse_nd_zeros('default', shape) def test_sparse_nd_copy(): diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py index d625dfa7906b..ac7be4b41c80 100644 --- a/tests/python/unittest/test_sparse_operator.py +++ b/tests/python/unittest/test_sparse_operator.py @@ -65,8 +65,7 @@ def test_elemwise_add_ex_multiple_stages(): exec_test.backward(out_grads=exec_test.outputs) assert_almost_equal(arr_grads[0].asnumpy(), arr_grads[1].asnumpy()) - -# TODO(haibin) also add test for backward pass. Check if exception is thrown +# TODO(haibin) also add test for backward pass. 
def test_cast_storage_ex(): def test_rsp_to_dns(shape): rsp, (data, row_idx) = rand_sparse_ndarray(shape, 'row_sparse') @@ -102,52 +101,56 @@ def test_dns_to_csr(dns_in): test_dns_to_csr([[0, 1, 0], [0, 2, 0], [3, 0, 0], [0, 0, 4], [5, 6, 0], [0, 0, 7]]) def test_sparse_dot(): - def test_dot_csr_dns(csr_shape, dns_shape, trans_csr): - dns1 = rand_ndarray(csr_shape, 'default') - dns2 = rand_ndarray(dns_shape, 'default') - csr = mx.nd.cast_storage(dns1, storage_type='csr') - out = mx.nd.dot(csr, dns2, transpose_a=trans_csr) + def test_dot_csr(lhs_shape, rhs_shape, rhs_stype, trans_lhs): + lhs_dns = rand_ndarray(lhs_shape, 'default') + lhs_nd = mx.nd.cast_storage(lhs_dns, storage_type='csr') + rhs_nd = rand_ndarray(rhs_shape, rhs_stype, density=1) + rhs_dns = rhs_nd if rhs_stype == 'default' else rhs_nd.to_dense() + out = mx.nd.dot(lhs_nd, rhs_dns, transpose_a=trans_lhs) assert out.storage_type == 'default' - out_expected = mx.nd.dot(dns1, dns2, transpose_a=trans_csr) + out_expected = mx.nd.dot(lhs_dns, rhs_dns, transpose_a=trans_lhs) out_np = out_expected.asnumpy() - backward_trans = not trans_csr - rhs_backward_grad = mx.nd.dot(dns1, out_expected, transpose_a=backward_trans).asnumpy() + backward_trans = not trans_lhs + rhs_backward_grad = mx.nd.dot(lhs_dns, out_expected, transpose_a=backward_trans).asnumpy() assert_almost_equal(out.asnumpy(), out_np, rtol=1e-4, atol=1e-5) # test symbolic forward lhs = mx.symbol.Variable('lhs', storage_type='csr') - rhs = mx.symbol.Variable('rhs', storage_type='default') - test = mx.symbol.dot(lhs, rhs, transpose_a=trans_csr) - location = {'lhs': csr, 'rhs': dns2} + rhs = mx.symbol.Variable('rhs', storage_type=rhs_stype) + test = mx.symbol.dot(lhs, rhs, transpose_a=trans_lhs) + location = {'lhs': lhs_nd, 'rhs': rhs_nd} expected = {'rhs': rhs_backward_grad} - # dot(lhs, rhs) - check_symbolic_forward(test, location, [out_expected.asnumpy()], rtol=1e-3, atol=1e-4) + check_symbolic_forward(test, location, [out_np], rtol=1e-3, atol=1e-4) + # test symbolic backward check_symbolic_backward(test, location, [out_np], expected, grad_req={'lhs': 'null', 'rhs': 'write'}, rtol=1e-3, atol=1e-4) lhs_shape = rand_shape_2d() - test_dot_csr_dns(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), False) - test_dot_csr_dns(lhs_shape, (lhs_shape[0], rnd.randint(1, 10)), True) - + test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), 'default', False) + test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(1, 10)), 'default', True) + test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), 'row_sparse', False) + test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(1, 10)), 'row_sparse', True) def test_sparse_embedding(): in_dim = 10 out_dim = 4 batch = 24 - data = mx.sym.Variable("data", dtype=np.int32) + data = mx.sym.Variable("data", storage_type='csr') embed = mx.sym.SparseEmbedding(data=data, input_dim=in_dim, output_dim=out_dim, name="embed") exe_test = embed.simple_bind(default_context(), grad_req={'data': 'null', 'embed_weight': 'write'}, - data=(batch,)) + data=(batch, in_dim)) + arg_map = dict(zip(embed.list_arguments(), exe_test.arg_arrays)) grad_map = dict(zip(embed.list_arguments(), exe_test.grad_arrays)) np_data = np.random.randint(low=0, high=in_dim, size=batch) np_weight = np.random.uniform(-0.01, 0.01, arg_map["embed_weight"].shape) np_onehot = np.zeros((batch, in_dim)) np_onehot[np.arange(batch), np_data] = 1.0 + nd_onehot = mx.nd.array(np_onehot).to_csr() # forward - arg_map["data"][:] = np_data + arg_map["data"][:] = nd_onehot arg_map["embed_weight"][:] = 
np_weight exe_test.forward(is_train=True) assert_almost_equal(exe_test.outputs[0].asnumpy(), np.dot(np_onehot, np_weight)) @@ -197,7 +200,6 @@ def test_sparse_retain(): sym = mx.sym.sparse_retain(data=data, indices=idx) check_numeric_gradient(sym, [rsp, indices], grad_nodes=['data'], grad_stype_dict={'data': 'row_sparse'}) - if __name__ == '__main__': import nose nose.runmodule()
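
Usage sketch (not part of the patch): the snippet below illustrates the code paths this diff enables, written against the experimental sparse-branch API exercised by the tests above (mx.sparse_nd, storage_type=, cast_storage, SparseEmbedding). Any name or argument not shown in this diff is an assumption for illustration, not the released MXNet interface.

    import mxnet as mx

    # sparse_nd.zeros: 'default' storage now falls back to dense ndarray.zeros
    dense_zeros = mx.sparse_nd.zeros('default', (2, 3))     # plain dense NDArray of zeros
    rsp_zeros = mx.sparse_nd.zeros('row_sparse', (2, 3))    # row_sparse zeros, as before

    # SGD with row_sparse weight, row_sparse gradient and row_sparse momentum state
    shape = (10, 4)
    weight = mx.nd.cast_storage(mx.random.uniform(shape=shape), storage_type='row_sparse')
    grad = mx.nd.cast_storage(mx.random.uniform(shape=shape), storage_type='row_sparse')

    opt = mx.optimizer.SGD(momentum=0.9, rescale_grad=0.1, wd=0.05)
    state = opt.create_state(0, weight)   # momentum allocated via sparse_zeros(weight.storage_type, ...)
    opt.update(0, weight, grad, state)    # dispatches to the new SGDMomUpdateRspRspRspImpl path

    # SparseEmbedding forward: one-hot rows stored as csr, weight stored as row_sparse
    data = mx.symbol.Variable('data', storage_type='csr')
    embed = mx.symbol.SparseEmbedding(data=data, input_dim=10, output_dim=4, name='embed')

Note that the RowSparse implementations added here (SGDUpdateRspRspImpl, SGDMomUpdateRspRspRspImpl, DotCsrRspDnsImpl) only handle the case where the row_sparse array's value storage covers every row, i.e. weights.values.shape == weights.shape; that is why the updated test builds the row_sparse weight with density=1 before comparing against the dense optimizer.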