From 3dc6410e1642e8056937ace83664a47be62a483d Mon Sep 17 00:00:00 2001 From: reminisce Date: Thu, 29 Jun 2017 13:43:47 -0700 Subject: [PATCH 1/9] Initial checkin Add dot(csr.T, rsp)=rsp2 Add infer storage for dot(csr, rsp)=dns and dot(csr.T, rsp)=rsp2 --- src/operator/tensor/matrix_op-inl.h | 284 +++++++++++++++--- tests/python/unittest/test_sparse_operator.py | 6 +- 2 files changed, 244 insertions(+), 46 deletions(-) diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index c684c7ad6057..8391ae9b8f57 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -484,8 +484,8 @@ inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs, CHECK_EQ(in_attrs->size(), 2U); CHECK_EQ(out_attrs->size(), 1U); const DotParam& param = nnvm::get(attrs.parsed); - if (param.transpose_a && kCSRStorage == (*in_attrs)[0] - && kDefaultStorage == (*in_attrs)[1]) { + // csr has many zero columns, so the result of dot(csr.T, matrix) should be rsp + if (param.transpose_a && kCSRStorage == (*in_attrs)[0]) { STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 0, kRowSparseStorage); } else { STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 0, kDefaultStorage); @@ -501,8 +501,7 @@ inline bool DotBackwardInferStorageType(const nnvm::NodeAttrs& attrs, CHECK_EQ(out_attrs->size(), 2U); const DotParam& param = nnvm::get(attrs.parsed); STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 0, kDefaultStorage); - if (!param.transpose_a && kDefaultStorage == (*in_attrs)[0] - && kCSRStorage == (*in_attrs)[1] && kDefaultStorage == (*in_attrs)[2]) { + if (!param.transpose_a && kCSRStorage == (*in_attrs)[1]) { STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 1, kRowSparseStorage); } else { STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 1, kDefaultStorage); @@ -654,7 +653,7 @@ struct DotCsrTransDnsDnsByRowBlocks { } }; -/*! +/* * \brief Kernel of dot(csr.T(), dns) = rsp * Parallelization by row blocks. * This kernel fills up the row_idx array @@ -693,6 +692,91 @@ struct DotCsrTransDnsRspByRowBlocks { } }; +/*! + * \brief Kernel of dot(csr, rsp) = dns + * Parallelization by row blocks + */ +struct DotCsrRspDnsByRowBlocks { + /*! + * \brief + * \param i the i-th thread + * \nnr_r storage_shape[0] of the rsp + * \num_rows dns.shape[0] + * \num_cols dns.shape[1] + */ + template + MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data_l, + const IType* indptr_l, const CType* col_idx_l, + const DType* data_r, const RType* row_idx_r, + const size_t nnr_r, const size_t num_rows, + const size_t num_cols, const size_t seg_len) { + const size_t seg_start = i * seg_len; + if (seg_start >= num_rows) return; + const size_t seg_end = (seg_start+seg_len < num_rows? seg_start+seg_len : num_rows); + for (size_t j = seg_start; j < seg_end; ++j) { + if (indptr_l[j] == indptr_l[j+1]) continue; + const size_t offset_out = j * num_cols; + const RType* row_idx_ptr = std::lower_bound(row_idx_r, row_idx_r+nnr_r, + col_idx_l[indptr_l[j]]); + if (row_idx_ptr == row_idx_r+nnr_r || *row_idx_ptr> col_idx_l[indptr_l[j+1]-1]) continue; + for (auto k = indptr_l[j]; k < indptr_l[j+1] && row_idx_ptr != row_idx_r+nnr_r;) { + if (col_idx_l[k] == *row_idx_ptr) { + const size_t offset_r = (row_idx_ptr - row_idx_r) * num_cols; + for (size_t l = 0; l < num_cols; ++l) { + out[offset_out+l] += data_l[k] * data_r[offset_r+l]; + } + ++k; + ++row_idx_ptr; + } else if (col_idx_l[k] < *row_idx_ptr) { + ++k; + } else { + ++row_idx_ptr; + } + } + } + } +}; + +/*! 
+ * \brief Kernel of dot(csr.T(), rsp) = dns with row_idx marked for non-zero rows + * Parallelization by row blocks + */ +struct DotCsrTransRspRspByRowBlocks { + /*! + * \brief + * \param i the i-th thread + * \param num_rows_l number of rows of lhs matrix + * \param nnr_r number of non-zero rows of rhs matrix + * \param num_rows number of rows of out matrix + * \param num_cols number of cols of out matrix + */ + template + MSHADOW_XINLINE static void Map(int i, DType* out, RType* row_idx_out, + const DType* data_l, const IType* indptr_l, + const CType* col_idx_l, const DType* data_r, + const RType* row_idx_r, const size_t num_rows_l, + const size_t nnr_r, const size_t num_rows, + const size_t num_cols, const size_t seg_len) { + const size_t seg_start = i * seg_len; + if (seg_start >= num_rows) return; + const size_t seg_end = (i + 1) * seg_len; + for (size_t rid = 0; rid < nnr_r; ++rid) { + const auto j = row_idx_r[rid]; + if (indptr_l[j] == indptr_l[j+1]) continue; + const size_t offset_r = rid * num_cols; + for (auto k = indptr_l[j]; k < indptr_l[j+1]; ++k) { + const auto col_idx = col_idx_l[k]; + if (col_idx < seg_start || col_idx >= seg_end) continue; + row_idx_out[col_idx] = 1; // mark nonzero row as 1 + const size_t offset_out = col_idx * num_cols; + for (size_t l = 0; l < num_cols; ++l) { + out[offset_out+l] += data_r[offset_r+l] * data_l[k]; + } + } + } + } +}; + template void DotCsrDnsDnsImpl(const OpContext& ctx, const NDArray& lhs, @@ -753,6 +837,9 @@ void DotCsrDnsDnsImpl(const OpContext& ctx, }); } +/*! + * \brief Impl of dot(csr, rsp) + */ template void DotCsrDnsRspImpl(const OpContext& ctx, const NDArray& lhs, @@ -829,34 +916,141 @@ void DotCsrRspDnsImpl(const OpContext& ctx, const OpReqType req, const bool trans_lhs, TBlob* ret) { - CHECK_RSP_ALL_ROWS_NON_ZERO(rhs, "Dot", "rhs"); // reuse csr dns implementation when storage_shape == shape for rhs - DotCsrDnsDnsImpl(ctx, lhs, rhs.data(), req, trans_lhs, ret); -} + if (rhs.storage_shape()[0] == rhs.shape()[0]) { // if rsp is actually dense + DotCsrDnsDnsImpl(ctx, lhs, rhs.data(), req, trans_lhs, ret); + return; + } -template -void DotBackwardCsrDnsDns(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - const DotParam& param = nnvm::get(attrs.parsed); - TBlob ret = outputs[1].data(); - DotCsrDnsDnsImpl(ctx, inputs[1], inputs[0].data(), req[1], !param.transpose_a, &ret); + if (kNullOp == req) return; + CHECK_EQ(lhs.storage_type(), kCSRStorage); + CHECK_EQ(rhs.storage_type(), kRowSparseStorage); + mshadow::Stream *s = ctx.get_stream(); + if (!lhs.storage_initialized() || !rhs.storage_initialized()) { + if (kWriteTo == req) { + MSHADOW_TYPE_SWITCH(ret->type_flag_, DType, { // data type + mxnet_op::Kernel::Launch( + s, ret->Size(), ret->dptr()); + }); + } + return; + } + + const TBlob data_l = lhs.data(); + const TBlob indptr_l = lhs.aux_data(csr::kIndPtr); + const TBlob col_idx_l = lhs.aux_data(csr::kIdx); + const TBlob data_r = rhs.data(); + const TBlob row_idx_r = rhs.aux_data(rowsparse::kIdx); + + MSHADOW_TYPE_SWITCH(data_l.type_flag_, DType, { // data type + MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type + MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, { // col idx type + MSHADOW_IDX_TYPE_SWITCH(row_idx_r.type_flag_, RType, { // col idx type + if (std::is_same::value) { // cpu parallelization by row blocks + if (kWriteTo == req) { + mxnet_op::Kernel::Launch( + s, ret->Size(), ret->dptr()); + } + int 
num_threads = mxnet_op::get_num_threads(ret->shape_[0]); + size_t seg_len = (ret->shape_[0] + num_threads - 1) / num_threads; + if (trans_lhs) { + LOG(FATAL) << "DotCsrRspDnsImpl has not implemented dot(csr.T, rsp) = dns yet"; + } else { + mxnet_op::Kernel::Launch(s, num_threads, + ret->dptr(), data_l.dptr(), + indptr_l.dptr(), col_idx_l.dptr(), data_r.dptr(), + row_idx_r.dptr(), rhs.storage_shape()[0], + ret->shape_[0], ret->shape_[1], seg_len); + } + } else { + LOG(FATAL) << "DotCsrRspDnsImpl has not implemented GPU version yet"; + } + }); + }); + }); + }); } +/*! + * \brief Impl of dot(csr.T, rsp) = rsp2 + */ template -void DotBackwardCsrRspDns(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - const auto& rhs = inputs[2]; - CHECK_RSP_ALL_ROWS_NON_ZERO(rhs, "Dot", "rhs"); +void DotCsrRspRspImpl(const OpContext& ctx, + const NDArray& lhs, + const NDArray& rhs, + const OpReqType req, + const bool trans_lhs, + NDArray* ret) { // reuse csr dns implementation when storage_shape == shape for rhs - const DotParam& param = nnvm::get(attrs.parsed); - TBlob ret = outputs[1].data(); - DotCsrDnsDnsImpl(ctx, inputs[1], inputs[0].data(), req[1], !param.transpose_a, &ret); + if (rhs.storage_shape()[0] == rhs.shape()[0]) { // if rsp is actually dense + DotCsrDnsRspImpl(ctx, lhs, rhs.data(), req, trans_lhs, ret); + return; + } + + if (kNullOp == req) return; + CHECK_EQ(lhs.storage_type(), kCSRStorage); + CHECK_EQ(rhs.storage_type(), kRowSparseStorage); + CHECK_EQ(ret->storage_type(), kRowSparseStorage); + if (!lhs.storage_initialized() || !rhs.storage_initialized()) return; + + mshadow::Stream *s = ctx.get_stream(); + const TBlob data_l = lhs.data(); + const TBlob indptr_l = lhs.aux_data(csr::kIndPtr); + const TBlob col_idx_l = lhs.aux_data(csr::kIdx); + const TBlob data_r = rhs.data(); + const TBlob row_idx_r = rhs.aux_data(rowsparse::kIdx); + + // pre-allocate spaces for ret using the dense dimension size + if (ret->storage_type() == kRowSparseStorage) { + ret->CheckAndAlloc({mshadow::Shape1(lhs.shape()[1])}); + } + const TBlob data_out = ret->data(); + const TBlob row_idx_out = ret->aux_data(rowsparse::kIdx); + + MSHADOW_TYPE_SWITCH(data_l.type_flag_, DType, { // data type + MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type + MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, { // col idx type + MSHADOW_IDX_TYPE_SWITCH(row_idx_r.type_flag_, RType, { // col idx type + if (std::is_same::value) { // cpu parallelization by row blocks + if (kWriteTo == req) { + mxnet_op::Kernel::Launch( + s, data_out.Size(), data_out.dptr()); + } + int num_threads = mxnet_op::get_num_threads(data_out.shape_[0]); + size_t seg_len = (data_out.shape_[0] + num_threads - 1) / num_threads; + if (trans_lhs) { + RType* row_idx = row_idx_out.dptr(); + mxnet_op::Kernel::Launch( + s, row_idx_out.Size(), row_idx); + mxnet_op::Kernel::Launch(s, num_threads, + data_out.dptr(), row_idx, data_l.dptr(), + indptr_l.dptr(), col_idx_l.dptr(), data_r.dptr(), + row_idx_r.dptr(), lhs.shape()[0], rhs.storage_shape()[0], + ret->shape()[0], ret->shape()[1], seg_len); + index_t nnr = 0; + nnr = mxnet::common::ParallelAccumulate(row_idx, ret->shape()[0], nnr); + ret->set_aux_shape(rowsparse::kIdx, mshadow::Shape1(nnr)); + ret->set_storage_shape(mshadow::Shape2(nnr, ret->shape()[1])); + if (0 == nnr) return; + mshadow::Tensor rsp_data = data_out.FlatTo2D(s); + size_t idx = 0; + for (index_t i = 0; i < ret->shape()[0]; ++i) { + if 
(row_idx[i] > 0) { + row_idx[idx] = i; + mshadow::Copy(rsp_data[idx], rsp_data[i], s); + ++idx; + } + } + } else { + LOG(FATAL) << "DotCsrRspRspImpl has not implemented dot(csr.T, rsp) = rsp2 yet"; + } + } else { + LOG(FATAL) << "DotCsrRspRspImpl has not implemented GPU version yet"; + } + }); + }); + }); + }); } inline bool DotShape(const nnvm::NodeAttrs& attrs, @@ -919,14 +1113,18 @@ void DotForwardEx(const nnvm::NodeAttrs& attrs, if (lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage && out_stype == kDefaultStorage) { TBlob ret = outputs[0].data(); DotCsrDnsDnsImpl(ctx, inputs[0], inputs[1].data(), req[0], param.transpose_a, &ret); - } else if (lhs_stype == kCSRStorage && rhs_stype == kRowSparseStorage && - out_stype == kDefaultStorage) { + } else if (lhs_stype == kCSRStorage && rhs_stype == kRowSparseStorage + && out_stype == kDefaultStorage) { TBlob ret = outputs[0].data(); DotCsrRspDnsImpl(ctx, inputs[0], inputs[1], req[0], param.transpose_a, &ret); } else if (lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage && out_stype == kRowSparseStorage) { NDArray out = outputs[0]; DotCsrDnsRspImpl(ctx, inputs[0], inputs[1].data(), req[0], param.transpose_a, &out); + } else if (lhs_stype == kCSRStorage && rhs_stype == kRowSparseStorage + && out_stype == kRowSparseStorage) { + NDArray ret = outputs[0]; + DotCsrRspRspImpl(ctx, inputs[0], inputs[1], req[0], param.transpose_a, &ret); } else { FCompExFallback(attrs, ctx, inputs, req, outputs, DotForward_, "DotForward_"); } @@ -947,23 +1145,21 @@ void DotBackwardEx(const nnvm::NodeAttrs& attrs, const DotParam& param = nnvm::get(attrs.parsed); CHECK(!param.transpose_b) << "sparse dot only supports dot(A, X) and dot(A.T(), X)"; - auto ograd_stype = inputs[0].storage_type(); - auto lhs_stype = inputs[1].storage_type(); - auto rhs_stype = inputs[2].storage_type(); + const auto ograd_stype = inputs[0].storage_type(); + const auto lhs_stype = inputs[1].storage_type(); + const auto rhs_stype = inputs[2].storage_type(); + const auto grad_rhs_stype = outputs[1].storage_type(); + if (ograd_stype == kDefaultStorage // ograd dns format && lhs_stype == kCSRStorage // csr input lhs of the op - && rhs_stype == kDefaultStorage // dns input rhs of the op - && outputs[1].storage_type() == kDefaultStorage) { // grad(rhs) dns format - // dns, csr, dns => *, dns - DotBackwardCsrDnsDns(attrs, ctx, inputs, req, outputs); - } else if (ograd_stype == kDefaultStorage && lhs_stype == kCSRStorage && - rhs_stype == kRowSparseStorage && outputs[1].storage_type() == kDefaultStorage) { - // dns, csr, rsp => *, dns - DotBackwardCsrRspDns(attrs, ctx, inputs, req, outputs); - } else if (ograd_stype == kDefaultStorage && lhs_stype == kCSRStorage && - rhs_stype == kDefaultStorage && outputs[1].storage_type() == kRowSparseStorage) { - NDArray grad_rhs = outputs[1]; - DotCsrDnsRspImpl(ctx, inputs[1], inputs[2].data(), req[1], !param.transpose_a, &grad_rhs); + && grad_rhs_stype == kDefaultStorage) { // grad(rhs) dns format + TBlob ret = outputs[1].data(); + DotCsrDnsDnsImpl(ctx, inputs[1], inputs[0].data(), req[1], !param.transpose_a, &ret); + } else if (ograd_stype == kDefaultStorage + && lhs_stype == kCSRStorage + && grad_rhs_stype == kRowSparseStorage) { + NDArray ret = outputs[1]; + DotCsrDnsRspImpl(ctx, inputs[1], inputs[0].data(), req[1], !param.transpose_a, &ret); } else { FCompExFallback(attrs, ctx, inputs, req, outputs, DotBackward_, "DotBackward_"); } diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py index 
4d2debe5f9d2..8379e567fd59 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -102,10 +102,10 @@ def test_dns_to_csr(dns_in):
 
 
 def test_sparse_dot():
-    def test_dot_csr(lhs_shape, rhs_shape, rhs_stype, trans_lhs):
+    def test_dot_csr(lhs_shape, rhs_shape, rhs_stype, trans_lhs, density=1):
         lhs_dns = rand_ndarray(lhs_shape, 'default')
         lhs_nd = mx.nd.cast_storage(lhs_dns, storage_type='csr')
-        rhs_nd = rand_ndarray(rhs_shape, rhs_stype, density=1)
+        rhs_nd = rand_ndarray(rhs_shape, rhs_stype, density=density)
         rhs_dns = rhs_nd if rhs_stype == 'default' else rhs_nd.todense()
         out = mx.nd.dot(lhs_nd, rhs_dns, transpose_a=trans_lhs)
         if trans_lhs:
@@ -135,6 +135,8 @@ def test_dot_csr(lhs_shape, rhs_shape, rhs_stype, trans_lhs):
     test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(1, 10)), 'default', True)
     test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), 'row_sparse', False)
     test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(1, 10)), 'row_sparse', True)
+    test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), 'row_sparse', False, 0.05)
+    test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(1, 10)), 'row_sparse', True, 0.05)
 
 
 def test_sparse_embedding():

From 528d1eb01a2077c7ff571b8aef09d5dad7aaeaae Mon Sep 17 00:00:00 2001
From: reminisce
Date: Thu, 29 Jun 2017 23:36:15 -0700
Subject: [PATCH 2/9] Fix comments

---
 src/operator/tensor/matrix_op-inl.h | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h
index 8391ae9b8f57..3dbe3f8a1f1a 100644
--- a/src/operator/tensor/matrix_op-inl.h
+++ b/src/operator/tensor/matrix_op-inl.h
@@ -653,7 +653,7 @@ struct DotCsrTransDnsDnsByRowBlocks {
   }
 };
 
-/*
+/*!
  * \brief Kernel of dot(csr.T(), dns) = rsp
  * Parallelization by row blocks.
  * This kernel fills up the row_idx array
@@ -700,9 +700,9 @@ struct DotCsrRspDnsByRowBlocks {
   /*!
* \brief * \param i the i-th thread - * \nnr_r storage_shape[0] of the rsp - * \num_rows dns.shape[0] - * \num_cols dns.shape[1] + * \param nnr_r storage_shape[0] of the rsp + * \param num_rows dns.shape[0] + * \param num_cols dns.shape[1] */ template MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data_l, From 3d9e2013777168d7c3b42acb62dfbc6aa17d7040 Mon Sep 17 00:00:00 2001 From: reminisce Date: Fri, 30 Jun 2017 11:47:34 -0700 Subject: [PATCH 3/9] Replace std::lower_bound with own impl for gpu use too --- src/operator/tensor/matrix_op-inl.h | 21 +++++++++++++++++-- tests/python/unittest/test_sparse_operator.py | 2 +- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 3dbe3f8a1f1a..4ddb6bb55491 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -716,8 +716,25 @@ struct DotCsrRspDnsByRowBlocks { for (size_t j = seg_start; j < seg_end; ++j) { if (indptr_l[j] == indptr_l[j+1]) continue; const size_t offset_out = j * num_cols; - const RType* row_idx_ptr = std::lower_bound(row_idx_r, row_idx_r+nnr_r, - col_idx_l[indptr_l[j]]); + // Use binary search to find the lower_bound of val in row_idx array + const RType* first = row_idx_r; + const RType* last = row_idx_r + nnr_r; + const auto val = col_idx_l[indptr_l[j]]; + const RType* it; + int count = last - first, step; + while (count > 0) { + it = first; + step = count / 2; + it += step; + if (*it < val) { + first = ++it; + count -= step + 1; + } else { + count = step; + } + } + const RType* row_idx_ptr = first; + // end of binary search if (row_idx_ptr == row_idx_r+nnr_r || *row_idx_ptr> col_idx_l[indptr_l[j+1]-1]) continue; for (auto k = indptr_l[j]; k < indptr_l[j+1] && row_idx_ptr != row_idx_r+nnr_r;) { if (col_idx_l[k] == *row_idx_ptr) { diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py index 8379e567fd59..1fc64a7149ea 100644 --- a/tests/python/unittest/test_sparse_operator.py +++ b/tests/python/unittest/test_sparse_operator.py @@ -130,7 +130,7 @@ def test_dot_csr(lhs_shape, rhs_shape, rhs_stype, trans_lhs, density=1): grad_req={'lhs': 'null', 'rhs': 'write'}, rtol=1e-3, atol=1e-4) - lhs_shape = rand_shape_2d() + lhs_shape = rand_shape_2d(50, 200) test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), 'default', False) test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(1, 10)), 'default', True) test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), 'row_sparse', False) From 8f5bb982867731df0305148b1b150b05661f8529 Mon Sep 17 00:00:00 2001 From: reminisce Date: Fri, 30 Jun 2017 19:15:12 -0700 Subject: [PATCH 4/9] Add time profiling --- src/operator/tensor/matrix_op-inl.h | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 4ddb6bb55491..403d606acd5a 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -7,6 +7,7 @@ #define MXNET_OPERATOR_TENSOR_MATRIX_OP_INL_H_ #include +#include #include #include #include @@ -683,9 +684,8 @@ struct DotCsrTransDnsRspByRowBlocks { if (col_idx < seg_start || col_idx >= seg_end) continue; const size_t offset_out = col_idx * num_cols; row_idx[col_idx] = 1; - const auto val = data_l[k]; for (size_t l = 0; l < num_cols; ++l) { - out[offset_out+l] += data_r[offset_r+l] * val; + out[offset_out+l] += data_r[offset_r+l] * data_l[k]; } } } @@ -903,6 +903,7 
@@ void DotCsrDnsRspImpl(const OpContext& ctx, nnr = mxnet::common::ParallelAccumulate(row_idx, ret->shape()[0], nnr); ret->set_aux_shape(rowsparse::kIdx, mshadow::Shape1(nnr)); ret->set_storage_shape(mshadow::Shape2(nnr, ret->shape()[1])); + LOG(INFO) << "DotCsrDnsRspImpl: output storage shape = " << ret->storage_shape(); if (0 == nnr) return; mshadow::Tensor rsp_data = data_out.FlatTo2D(s); size_t idx = 0; @@ -942,6 +943,7 @@ void DotCsrRspDnsImpl(const OpContext& ctx, if (kNullOp == req) return; CHECK_EQ(lhs.storage_type(), kCSRStorage); CHECK_EQ(rhs.storage_type(), kRowSparseStorage); + LOG(INFO) << "DotCsrRspDnsImpl: rhs storage shape = " << rhs.storage_shape(); mshadow::Stream *s = ctx.get_stream(); if (!lhs.storage_initialized() || !rhs.storage_initialized()) { if (kWriteTo == req) { @@ -1007,6 +1009,7 @@ void DotCsrRspRspImpl(const OpContext& ctx, if (kNullOp == req) return; CHECK_EQ(lhs.storage_type(), kCSRStorage); CHECK_EQ(rhs.storage_type(), kRowSparseStorage); + LOG(INFO) << "DotCsrRspRspImpl: rhs storage shape = " << rhs.storage_shape(); CHECK_EQ(ret->storage_type(), kRowSparseStorage); if (!lhs.storage_initialized() || !rhs.storage_initialized()) return; @@ -1048,6 +1051,7 @@ void DotCsrRspRspImpl(const OpContext& ctx, nnr = mxnet::common::ParallelAccumulate(row_idx, ret->shape()[0], nnr); ret->set_aux_shape(rowsparse::kIdx, mshadow::Shape1(nnr)); ret->set_storage_shape(mshadow::Shape2(nnr, ret->shape()[1])); + LOG(INFO) << "DotCsrRspRspImpl: output storage shape = " << ret->storage_shape(); if (0 == nnr) return; mshadow::Tensor rsp_data = data_out.FlatTo2D(s); size_t idx = 0; @@ -1129,19 +1133,31 @@ void DotForwardEx(const nnvm::NodeAttrs& attrs, auto out_stype = outputs[0].storage_type(); if (lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage && out_stype == kDefaultStorage) { TBlob ret = outputs[0].data(); + double start = dmlc::GetTime(); DotCsrDnsDnsImpl(ctx, inputs[0], inputs[1].data(), req[0], param.transpose_a, &ret); + double elapse = dmlc::GetTime() - start; + LOG(INFO) << "DotCsrDnsDnsImpl: trans_lhs = " << param.transpose_a << ", time cost: " << elapse * 1000 << " ms"; } else if (lhs_stype == kCSRStorage && rhs_stype == kRowSparseStorage && out_stype == kDefaultStorage) { TBlob ret = outputs[0].data(); + double start = dmlc::GetTime(); DotCsrRspDnsImpl(ctx, inputs[0], inputs[1], req[0], param.transpose_a, &ret); + double elapse = dmlc::GetTime() - start; + LOG(INFO) << "DotCsrRspDnsImpl: trans_lhs = " << param.transpose_a << ", time cost: " << elapse * 1000 << " ms"; } else if (lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage && out_stype == kRowSparseStorage) { NDArray out = outputs[0]; + double start = dmlc::GetTime(); DotCsrDnsRspImpl(ctx, inputs[0], inputs[1].data(), req[0], param.transpose_a, &out); + double elapse = dmlc::GetTime() - start; + LOG(INFO) << "DotCsrDnsRspImpl: trans_lhs = " << param.transpose_a << ", time cost: " << elapse * 1000 << " ms"; } else if (lhs_stype == kCSRStorage && rhs_stype == kRowSparseStorage && out_stype == kRowSparseStorage) { NDArray ret = outputs[0]; + double start = dmlc::GetTime(); DotCsrRspRspImpl(ctx, inputs[0], inputs[1], req[0], param.transpose_a, &ret); + double elapse = dmlc::GetTime() - start; + LOG(INFO) << "DotCsrRspRspImpl: trans_lhs = " << param.transpose_a << ", time cost: " << elapse * 1000 << " ms"; } else { FCompExFallback(attrs, ctx, inputs, req, outputs, DotForward_, "DotForward_"); } From d75eab43ae480a8d8b22cc047cc3d711485cc0d9 Mon Sep 17 00:00:00 2001 From: reminisce Date: Fri, 30 Jun 
2017 20:30:51 -0700 Subject: [PATCH 5/9] Revert "Add time profiling" This reverts commit 8f5bb982867731df0305148b1b150b05661f8529. --- src/operator/tensor/matrix_op-inl.h | 20 ++------------------ 1 file changed, 2 insertions(+), 18 deletions(-) diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 403d606acd5a..4ddb6bb55491 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -7,7 +7,6 @@ #define MXNET_OPERATOR_TENSOR_MATRIX_OP_INL_H_ #include -#include #include #include #include @@ -684,8 +683,9 @@ struct DotCsrTransDnsRspByRowBlocks { if (col_idx < seg_start || col_idx >= seg_end) continue; const size_t offset_out = col_idx * num_cols; row_idx[col_idx] = 1; + const auto val = data_l[k]; for (size_t l = 0; l < num_cols; ++l) { - out[offset_out+l] += data_r[offset_r+l] * data_l[k]; + out[offset_out+l] += data_r[offset_r+l] * val; } } } @@ -903,7 +903,6 @@ void DotCsrDnsRspImpl(const OpContext& ctx, nnr = mxnet::common::ParallelAccumulate(row_idx, ret->shape()[0], nnr); ret->set_aux_shape(rowsparse::kIdx, mshadow::Shape1(nnr)); ret->set_storage_shape(mshadow::Shape2(nnr, ret->shape()[1])); - LOG(INFO) << "DotCsrDnsRspImpl: output storage shape = " << ret->storage_shape(); if (0 == nnr) return; mshadow::Tensor rsp_data = data_out.FlatTo2D(s); size_t idx = 0; @@ -943,7 +942,6 @@ void DotCsrRspDnsImpl(const OpContext& ctx, if (kNullOp == req) return; CHECK_EQ(lhs.storage_type(), kCSRStorage); CHECK_EQ(rhs.storage_type(), kRowSparseStorage); - LOG(INFO) << "DotCsrRspDnsImpl: rhs storage shape = " << rhs.storage_shape(); mshadow::Stream *s = ctx.get_stream(); if (!lhs.storage_initialized() || !rhs.storage_initialized()) { if (kWriteTo == req) { @@ -1009,7 +1007,6 @@ void DotCsrRspRspImpl(const OpContext& ctx, if (kNullOp == req) return; CHECK_EQ(lhs.storage_type(), kCSRStorage); CHECK_EQ(rhs.storage_type(), kRowSparseStorage); - LOG(INFO) << "DotCsrRspRspImpl: rhs storage shape = " << rhs.storage_shape(); CHECK_EQ(ret->storage_type(), kRowSparseStorage); if (!lhs.storage_initialized() || !rhs.storage_initialized()) return; @@ -1051,7 +1048,6 @@ void DotCsrRspRspImpl(const OpContext& ctx, nnr = mxnet::common::ParallelAccumulate(row_idx, ret->shape()[0], nnr); ret->set_aux_shape(rowsparse::kIdx, mshadow::Shape1(nnr)); ret->set_storage_shape(mshadow::Shape2(nnr, ret->shape()[1])); - LOG(INFO) << "DotCsrRspRspImpl: output storage shape = " << ret->storage_shape(); if (0 == nnr) return; mshadow::Tensor rsp_data = data_out.FlatTo2D(s); size_t idx = 0; @@ -1133,31 +1129,19 @@ void DotForwardEx(const nnvm::NodeAttrs& attrs, auto out_stype = outputs[0].storage_type(); if (lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage && out_stype == kDefaultStorage) { TBlob ret = outputs[0].data(); - double start = dmlc::GetTime(); DotCsrDnsDnsImpl(ctx, inputs[0], inputs[1].data(), req[0], param.transpose_a, &ret); - double elapse = dmlc::GetTime() - start; - LOG(INFO) << "DotCsrDnsDnsImpl: trans_lhs = " << param.transpose_a << ", time cost: " << elapse * 1000 << " ms"; } else if (lhs_stype == kCSRStorage && rhs_stype == kRowSparseStorage && out_stype == kDefaultStorage) { TBlob ret = outputs[0].data(); - double start = dmlc::GetTime(); DotCsrRspDnsImpl(ctx, inputs[0], inputs[1], req[0], param.transpose_a, &ret); - double elapse = dmlc::GetTime() - start; - LOG(INFO) << "DotCsrRspDnsImpl: trans_lhs = " << param.transpose_a << ", time cost: " << elapse * 1000 << " ms"; } else if (lhs_stype == kCSRStorage && rhs_stype == 
kDefaultStorage && out_stype == kRowSparseStorage) { NDArray out = outputs[0]; - double start = dmlc::GetTime(); DotCsrDnsRspImpl(ctx, inputs[0], inputs[1].data(), req[0], param.transpose_a, &out); - double elapse = dmlc::GetTime() - start; - LOG(INFO) << "DotCsrDnsRspImpl: trans_lhs = " << param.transpose_a << ", time cost: " << elapse * 1000 << " ms"; } else if (lhs_stype == kCSRStorage && rhs_stype == kRowSparseStorage && out_stype == kRowSparseStorage) { NDArray ret = outputs[0]; - double start = dmlc::GetTime(); DotCsrRspRspImpl(ctx, inputs[0], inputs[1], req[0], param.transpose_a, &ret); - double elapse = dmlc::GetTime() - start; - LOG(INFO) << "DotCsrRspRspImpl: trans_lhs = " << param.transpose_a << ", time cost: " << elapse * 1000 << " ms"; } else { FCompExFallback(attrs, ctx, inputs, req, outputs, DotForward_, "DotForward_"); } From 35f06426ade0945fb3c4dbf46e54cb79933696fa Mon Sep 17 00:00:00 2001 From: reminisce Date: Fri, 30 Jun 2017 22:07:18 -0700 Subject: [PATCH 6/9] Move dot and batch_dot to a single file --- src/operator/nn/matrix_dot-inl.h | 1037 +++++++++++++++++++++++++++ src/operator/nn/matrix_dot.cc | 114 +++ src/operator/nn/matrix_dot.cu | 27 + src/operator/tensor/indexing_op.h | 2 +- src/operator/tensor/matrix_op-inl.h | 1012 -------------------------- src/operator/tensor/matrix_op.cc | 101 --- src/operator/tensor/matrix_op.cu | 15 - 7 files changed, 1179 insertions(+), 1129 deletions(-) create mode 100644 src/operator/nn/matrix_dot-inl.h create mode 100644 src/operator/nn/matrix_dot.cc create mode 100644 src/operator/nn/matrix_dot.cu diff --git a/src/operator/nn/matrix_dot-inl.h b/src/operator/nn/matrix_dot-inl.h new file mode 100644 index 000000000000..c32453faf127 --- /dev/null +++ b/src/operator/nn/matrix_dot-inl.h @@ -0,0 +1,1037 @@ +/*! 
+ * Copyright (c) 2017 by Contributors + * \file matrix_dot-inl.h + * \brief Function definition of matrix dot operator + */ + +#ifndef MXNET_OPERATOR_NN_MATRIX_DOT_INL_H_ +#define MXNET_OPERATOR_NN_MATRIX_DOT_INL_H_ + +#include +#include +#include +#include +#include +#include "../mshadow_op.h" +#include "../elemwise_op_common.h" +#include "../mxnet_op.h" + +namespace mxnet { +namespace op { + +struct DotParam : public dmlc::Parameter { + bool transpose_a; + bool transpose_b; + DMLC_DECLARE_PARAMETER(DotParam) { + DMLC_DECLARE_FIELD(transpose_a) + .describe("If true then transpose the first input before dot.") + .set_default(false); + DMLC_DECLARE_FIELD(transpose_b) + .describe("If true then transpose the second input before dot.") + .set_default(false); + } +}; + +template +void DotForward_(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + const DotParam& param = nnvm::get(attrs.parsed); + Stream *s = ctx.get_stream(); + CHECK_EQ(outputs[0].type_flag_, inputs[0].type_flag_) + << "Binary function only support input/output with the same type"; + CHECK_EQ(outputs[0].type_flag_, inputs[1].type_flag_) + << "Binary function only support input/output with the same type"; + CHECK_EQ(outputs[0].type_flag_, kFloat32) + << "dot only support 32 bit float so far"; + + if (inputs[0].ndim() == 1 && inputs[1].ndim() == 1) { + CHECK_NE(req[0], kAddTo) << "AddTo not yet suported"; + Tensor out = outputs[0].get(s); + VectorDot(out, + inputs[0].get(s), + inputs[1].get(s)); + } else { + int ma, na, mb, nb, m, n; + if (param.transpose_a) { + ma = inputs[0].size(0); + na = inputs[0].Size()/ma; + m = na; + } else { + na = inputs[0].size(inputs[0].ndim()-1); + ma = inputs[0].Size()/na; + m = ma; + } + if (param.transpose_b) { + nb = inputs[1].size(inputs[1].ndim()-1); + mb = inputs[1].Size()/nb; + n = mb; + } else { + mb = inputs[1].size(0); + nb = inputs[1].Size()/mb; + n = nb; + } + + Tensor input0 = + inputs[0].get_with_shape(Shape2(ma, na), s); + Tensor input1 = + inputs[1].get_with_shape(Shape2(mb, nb), s); + Tensor out = + outputs[0].get_with_shape(Shape2(m, n), s); + if (param.transpose_a && param.transpose_b) { + ASSIGN_DISPATCH(out, req[0], dot(input0.T(), input1.T())); + } else if (!param.transpose_a && param.transpose_b) { + ASSIGN_DISPATCH(out, req[0], dot(input0, input1.T())); + } else if (param.transpose_a && !param.transpose_b) { + ASSIGN_DISPATCH(out, req[0], dot(input0.T(), input1)); + } else { + ASSIGN_DISPATCH(out, req[0], dot(input0, input1)); + } + } +} + +template +void DotBackward_(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + const DotParam& param = nnvm::get(attrs.parsed); + Stream *s = ctx.get_stream(); + CHECK_NE(req[0], kWriteInplace); + CHECK_NE(req[1], kWriteInplace); + + if (inputs[1].ndim() == 1 && inputs[2].ndim() == 1) { + Tensor mout_grad = inputs[0].get(s); + Tensor mlhs_data = inputs[1].get(s); + Tensor mrhs_data = inputs[2].get(s); + Tensor mlhs_grad = outputs[0].get(s); + Tensor mrhs_grad = outputs[1].get(s); + ASSIGN_DISPATCH(mrhs_grad, req[1], + broadcast_scalar(mout_grad, mlhs_data.shape_) * mlhs_data); + ASSIGN_DISPATCH(mlhs_grad, req[0], + broadcast_scalar(mout_grad, mlhs_data.shape_) * mrhs_data); + } else { + int ma, na, mb, nb, m, n; + if (param.transpose_a) { + 
ma = outputs[0].size(0); + na = outputs[0].Size()/ma; + m = na; + } else { + na = outputs[0].size(outputs[0].ndim()-1); + ma = outputs[0].Size()/na; + m = ma; + } + if (param.transpose_b) { + nb = outputs[1].size(outputs[1].ndim()-1); + mb = outputs[1].Size()/nb; + n = mb; + } else { + mb = outputs[1].size(0); + nb = outputs[1].Size()/mb; + n = nb; + } + + Tensor mout_grad = + inputs[0].get_with_shape(Shape2(m, n), s); + Tensor mlhs_data = + inputs[1].get_with_shape(Shape2(ma, na), s); + Tensor mrhs_data = + inputs[2].get_with_shape(Shape2(mb, nb), s); + Tensor mlhs_grad = + outputs[0].get_with_shape(Shape2(ma, na), s); + Tensor mrhs_grad = + outputs[1].get_with_shape(Shape2(mb, nb), s); + if (param.transpose_a && param.transpose_b) { + // Gradient of z = dot(x.T, y.T) + // dy = dot(x, dz).T = dot(dz.T, x.T) + // dx = dot(dz, y).T = dot(y.T, dz.T) + ASSIGN_DISPATCH(mrhs_grad, req[1], dot(mout_grad.T(), mlhs_data.T())); + ASSIGN_DISPATCH(mlhs_grad, req[0], dot(mrhs_data.T(), mout_grad.T())); + } else if (!param.transpose_a && param.transpose_b) { + // Gradient of z = dot(x, y.T) + // dy = dot(x.T, dz).T = dot(dz.T, x) + // dx = dot(dz, y) + ASSIGN_DISPATCH(mrhs_grad, req[1], dot(mout_grad.T(), mlhs_data)); + ASSIGN_DISPATCH(mlhs_grad, req[0], dot(mout_grad, mrhs_data)); + } else if (param.transpose_a && !param.transpose_b) { + // Gradient of z = dot(x.T, y) + // dy = dot(x, dz) + // dx = dot(dz, y.T).T = dot(y, dz.T) + ASSIGN_DISPATCH(mrhs_grad, req[1], dot(mlhs_data, mout_grad)); + ASSIGN_DISPATCH(mlhs_grad, req[0], dot(mrhs_data, mout_grad.T())); + } else { + // Gradient of z = dot(x, y) + // dy = dot(x.T, dz) + // dx = dot(dz, y.T) + ASSIGN_DISPATCH(mrhs_grad, req[1], dot(mlhs_data.T(), mout_grad)); + ASSIGN_DISPATCH(mlhs_grad, req[0], dot(mout_grad, mrhs_data.T())); + } + } +} + +inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs, + const Context& ctx, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + const DotParam& param = nnvm::get(attrs.parsed); + // csr has many zero columns, so the result of dot(csr.T, matrix) should be rsp + if (param.transpose_a && kCSRStorage == (*in_attrs)[0]) { + STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 0, kRowSparseStorage); + } else { + STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 0, kDefaultStorage); + } + return true; +} + +inline bool DotBackwardInferStorageType(const nnvm::NodeAttrs& attrs, + const Context& ctx, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 3U); + CHECK_EQ(out_attrs->size(), 2U); + const DotParam& param = nnvm::get(attrs.parsed); + STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 0, kDefaultStorage); + if (!param.transpose_a && kCSRStorage == (*in_attrs)[1]) { + STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 1, kRowSparseStorage); + } else { + STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 1, kDefaultStorage); + } + return true; +} + +/*! + * \brief Kernel of dot(csr, dns1) = dns2 + * Parallelization by output matrix elements + */ +template +struct DotCsrDnsDns { + /*! + * \brief This function represents performing an inner product between a row of lhs + * and a column of rhs and then assigning the value to out[i]. 
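+   * For example, if row irow of the CSR lhs stores the two entries
+   * (col_idx 0, value 2.0) and (col_idx 3, value 1.5), then for output column icol
+   * this computes out[irow*num_cols + icol] =
+   * 2.0 * data_r[0*num_cols + icol] + 1.5 * data_r[3*num_cols + icol]
+   * (illustrative values only; the general loop is below).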
+ * \param i i-th element in out 1D view + * \param out output matrix + * \param data_l csr values of lhs + * \param indptr_l csr indptr of lhs + * \param col_idx_l csr col_idx of lhs + * \param data_r dense data of rhs + * \param num_cols number of columns of output + */ + template + MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data_l, const IType* indptr_l, + const CType* col_idx_l, const DType* data_r, + const int num_cols) { + const int irow = i / num_cols; // row id of the lhs + const int icol = i % num_cols; // col id of the rhs + DType sum = 0; + for (IType j = indptr_l[irow]; j < indptr_l[irow+1]; ++j) { + const CType cur_col = col_idx_l[j]; // corresponding row id of the rhs + sum += data_l[j] * data_r[cur_col*num_cols+icol]; + } + KERNEL_ASSIGN(out[i], req, sum); + } +}; + +/*! + * \brief Kernel of dot(csr.T(), dns1) = dns2 + * Parallelization by output matrix elements + */ +template +struct DotCsrTransDnsDns { + /*! + * \brief This function represents performing an inner product between a column of lhs + * and a column of rhs and then assigning the value to out[i]. + * \param i i-th element in out 1D view + * \param out output matrix + * \param data_l csr values of lhs + * \param indptr_l csr indptr of lhs + * \param col_idx_l csr col_idx of lhs + * \param data_r dense data of rhs + * \param num_rows_l number of rows of lhs + * \param num_cols number of columns of outputs + */ + template + MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data_l, const IType* indptr_l, + const CType* col_idx_l, const DType* data_r, const int num_rows_l, + const int num_cols) { + const int irow = i / num_cols; // col id of the lhs + const int icol = i % num_cols; // col id of the rhs + DType sum = 0; + for (int k = 0; k < num_rows_l; ++k) { + const IType low = indptr_l[k]; + const IType high = indptr_l[k+1]; + if (low == high || irow < col_idx_l[low] || irow > col_idx_l[high-1]) continue; + int j = -1, l = low, r = high - 1; + while (l <= r) { + int m = l + (r - l) / 2; + if (col_idx_l[m] == irow) { + j = m; break; + } + if (col_idx_l[m] < irow) { + l = m + 1; + } else { + r = m - 1; + } + } + if (j >= 0) { + sum += data_l[j] * data_r[k*num_cols+icol]; + } + } + KERNEL_ASSIGN(out[i], req, sum); + } +}; + +/*! + * \brief Kernel of dot(csr, dns1) = dns2 + * Parallelization by row blocks + */ +struct DotCsrDnsDnsByRowBlocks { + /*! + * \brief + * \param i the i-th thread + */ + template + MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data_l, const IType* indptr_l, + const CType* col_idx_l, const DType* data_r, const size_t seg_len, + const size_t num_rows, const size_t num_cols) { + const size_t seg_start = i * seg_len; + if (seg_start >= num_rows) return; + const size_t seg_end = (seg_start+seg_len < num_rows? seg_start+seg_len : num_rows); + for (size_t j = seg_start; j < seg_end; ++j) { + if (indptr_l[j] == indptr_l[j+1]) continue; + const size_t offset_out = j * num_cols; + for (auto k = indptr_l[j]; k < indptr_l[j+1]; ++k) { + const auto val = data_l[k]; + const size_t offset_r = col_idx_l[k] * num_cols; + for (size_t l = 0; l < num_cols; ++l) { + out[offset_out+l] += data_r[offset_r+l] * val; + } + } + } + } +}; + +/*! + * \brief Kernel of dot(csr.T(), dns1) = dns2 + * Parallelization by row blocks + */ +struct DotCsrTransDnsDnsByRowBlocks { + /*! 
+ * \brief + * \param i the i-th thread + */ + template + MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data_l, const IType* indptr_l, + const CType* col_idx_l, const DType* data_r, const size_t seg_len, + const size_t num_rows_l, const size_t num_rows, + const size_t num_cols) { + const size_t seg_start = i * seg_len; + if (seg_start >= num_rows) return; + const size_t seg_end = (i + 1) * seg_len; + for (size_t j = 0; j < num_rows_l; ++j) { + if (indptr_l[j] == indptr_l[j+1]) continue; + const size_t offset_r = j * num_cols; + for (auto k = indptr_l[j]; k < indptr_l[j+1]; ++k) { + const auto col_idx = col_idx_l[k]; + if (col_idx < seg_start || col_idx >= seg_end) continue; + const size_t offset_out = col_idx * num_cols; + const auto val = data_l[k]; + for (size_t l = 0; l < num_cols; ++l) { + out[offset_out+l] += data_r[offset_r+l] * val; + } + } + } + } +}; + +/*! + * \brief Kernel of dot(csr.T(), dns) = rsp + * Parallelization by row blocks. + * This kernel fills up the row_idx array + * of the rsp with 1 for nonzero rows and 0 + * for zero rows. + * The matrix will be compacted after this kernel call. + */ +struct DotCsrTransDnsRspByRowBlocks { + /*! + * \brief + * \param i the i-th thread + */ + template + MSHADOW_XINLINE static void Map(int i, DType* out, RType* row_idx, const DType* data_l, + const IType* indptr_l, const CType* col_idx_l, + const DType* data_r, const size_t seg_len, + const size_t num_rows_l, const size_t num_rows, + const size_t num_cols) { + const size_t seg_start = i * seg_len; + if (seg_start >= num_rows) return; + const size_t seg_end = (i + 1) * seg_len; + for (size_t j = 0; j < num_rows_l; ++j) { + if (indptr_l[j] == indptr_l[j+1]) continue; + const size_t offset_r = j * num_cols; + for (auto k = indptr_l[j]; k < indptr_l[j+1]; ++k) { + const auto col_idx = col_idx_l[k]; + if (col_idx < seg_start || col_idx >= seg_end) continue; + const size_t offset_out = col_idx * num_cols; + row_idx[col_idx] = 1; + const auto val = data_l[k]; + for (size_t l = 0; l < num_cols; ++l) { + out[offset_out+l] += data_r[offset_r+l] * val; + } + } + } + } +}; + +/*! + * \brief Kernel of dot(csr, rsp) = dns + * Parallelization by row blocks + */ +struct DotCsrRspDnsByRowBlocks { + /*! + * \brief + * \param i the i-th thread + * \param nnr_r storage_shape[0] of the rsp + * \param num_rows dns.shape[0] + * \param num_cols dns.shape[1] + */ + template + MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data_l, + const IType* indptr_l, const CType* col_idx_l, + const DType* data_r, const RType* row_idx_r, + const size_t nnr_r, const size_t num_rows, + const size_t num_cols, const size_t seg_len) { + const size_t seg_start = i * seg_len; + if (seg_start >= num_rows) return; + const size_t seg_end = (seg_start+seg_len < num_rows? 
seg_start+seg_len : num_rows); + for (size_t j = seg_start; j < seg_end; ++j) { + if (indptr_l[j] == indptr_l[j+1]) continue; + const size_t offset_out = j * num_cols; + // Use binary search to find the lower_bound of val in row_idx array + const RType* first = row_idx_r; + const RType* last = row_idx_r + nnr_r; + const auto val = col_idx_l[indptr_l[j]]; + const RType* it; + int count = last - first, step; + while (count > 0) { + it = first; + step = count / 2; + it += step; + if (*it < val) { + first = ++it; + count -= step + 1; + } else { + count = step; + } + } + const RType* row_idx_ptr = first; + // end of binary search + if (row_idx_ptr == row_idx_r+nnr_r || *row_idx_ptr> col_idx_l[indptr_l[j+1]-1]) continue; + for (auto k = indptr_l[j]; k < indptr_l[j+1] && row_idx_ptr != row_idx_r+nnr_r;) { + if (col_idx_l[k] == *row_idx_ptr) { + const size_t offset_r = (row_idx_ptr - row_idx_r) * num_cols; + for (size_t l = 0; l < num_cols; ++l) { + out[offset_out+l] += data_l[k] * data_r[offset_r+l]; + } + ++k; + ++row_idx_ptr; + } else if (col_idx_l[k] < *row_idx_ptr) { + ++k; + } else { + ++row_idx_ptr; + } + } + } + } +}; + +/*! + * \brief Kernel of dot(csr.T(), rsp) = dns with row_idx marked for non-zero rows + * Parallelization by row blocks + */ +struct DotCsrTransRspRspByRowBlocks { + /*! + * \brief + * \param i the i-th thread + * \param num_rows_l number of rows of lhs matrix + * \param nnr_r number of non-zero rows of rhs matrix + * \param num_rows number of rows of out matrix + * \param num_cols number of cols of out matrix + */ + template + MSHADOW_XINLINE static void Map(int i, DType* out, RType* row_idx_out, + const DType* data_l, const IType* indptr_l, + const CType* col_idx_l, const DType* data_r, + const RType* row_idx_r, const size_t num_rows_l, + const size_t nnr_r, const size_t num_rows, + const size_t num_cols, const size_t seg_len) { + const size_t seg_start = i * seg_len; + if (seg_start >= num_rows) return; + const size_t seg_end = (i + 1) * seg_len; + for (size_t rid = 0; rid < nnr_r; ++rid) { + const auto j = row_idx_r[rid]; + if (indptr_l[j] == indptr_l[j+1]) continue; + const size_t offset_r = rid * num_cols; + for (auto k = indptr_l[j]; k < indptr_l[j+1]; ++k) { + const auto col_idx = col_idx_l[k]; + if (col_idx < seg_start || col_idx >= seg_end) continue; + row_idx_out[col_idx] = 1; // mark nonzero row as 1 + const size_t offset_out = col_idx * num_cols; + for (size_t l = 0; l < num_cols; ++l) { + out[offset_out+l] += data_r[offset_r+l] * data_l[k]; + } + } + } + } +}; + +template +void DotCsrDnsDnsImpl(const OpContext& ctx, + const NDArray& lhs, + const TBlob& rhs, + const OpReqType req, + const bool trans_lhs, + TBlob* ret) { + if (kNullOp == req) return; + CHECK_EQ(lhs.storage_type(), kCSRStorage); + if (!lhs.storage_initialized()) return; + + mshadow::Stream *s = ctx.get_stream(); + const TBlob data_l = lhs.data(); + const TBlob indptr_l = lhs.aux_data(csr::kIndPtr); + const TBlob col_idx_l = lhs.aux_data(csr::kIdx); + const TBlob& data_r = rhs; + const TBlob data_out = *ret; + + MSHADOW_TYPE_SWITCH(data_l.type_flag_, DType, { // data type + MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type + MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, { // col idx type + if (std::is_same::value) { // cpu parallelization by row blocks + if (kWriteTo == req) { + mxnet_op::Kernel::Launch( + s, data_out.Size(), data_out.dptr()); + } + int num_threads = mxnet_op::get_num_threads(data_out.shape_[0]); + size_t seg_len = (data_out.shape_[0] + num_threads - 1) 
/ num_threads; + if (trans_lhs) { + mxnet_op::Kernel::Launch(s, num_threads, + data_out.dptr(), data_l.dptr(), indptr_l.dptr(), + col_idx_l.dptr(), data_r.dptr(), seg_len, + lhs.shape()[0], data_out.shape_[0], data_out.shape_[1]); + } else { + mxnet_op::Kernel::Launch(s, num_threads, + data_out.dptr(), data_l.dptr(), indptr_l.dptr(), + col_idx_l.dptr(), data_r.dptr(), seg_len, + data_out.shape_[0], data_out.shape_[1]); + } + } else { // gpu parallelization by output elements + if (trans_lhs) { + MXNET_ASSIGN_REQ_SWITCH(req, ReqType, { + mxnet_op::Kernel, xpu>::Launch(s, data_out.Size(), + data_out.dptr(), data_l.dptr(), indptr_l.dptr(), + col_idx_l.dptr(), data_r.dptr(), lhs.shape()[0], + data_out.shape_[1]); + }); + } else { + MXNET_ASSIGN_REQ_SWITCH(req, ReqType, { + mxnet_op::Kernel, xpu>::Launch(s, data_out.Size(), + data_out.dptr(), data_l.dptr(), indptr_l.dptr(), + col_idx_l.dptr(), data_r.dptr(), rhs.shape_[1]); + }); + } + } + }); + }); + }); +} + +/*! + * \brief Impl of dot(csr, rsp) + */ +template +void DotCsrDnsRspImpl(const OpContext& ctx, + const NDArray& lhs, + const TBlob& rhs, + const OpReqType req, + const bool trans_lhs, + NDArray* ret) { + if (kNullOp == req) return; + CHECK_EQ(lhs.storage_type(), kCSRStorage); + CHECK_EQ(ret->storage_type(), kRowSparseStorage); + if (!lhs.storage_initialized()) return; + + mshadow::Stream *s = ctx.get_stream(); + const TBlob data_l = lhs.data(); + const TBlob indptr_l = lhs.aux_data(csr::kIndPtr); + const TBlob col_idx_l = lhs.aux_data(csr::kIdx); + const TBlob& data_r = rhs; + + // pre-allocate spaces for ret using the dense dimension size + ret->CheckAndAlloc({mshadow::Shape1(lhs.shape()[1])}); + const TBlob data_out = ret->data(); + const TBlob row_idx_out = ret->aux_data(rowsparse::kIdx); + + MSHADOW_TYPE_SWITCH(data_l.type_flag_, DType, { // data type + MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type + MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, { // col idx type + MSHADOW_IDX_TYPE_SWITCH(row_idx_out.type_flag_, RType, { // col idx type + if (std::is_same::value) { // cpu parallelization by row blocks + if (kWriteTo == req) { + mxnet_op::Kernel::Launch( + s, data_out.Size(), data_out.dptr()); + } + RType* row_idx = row_idx_out.dptr(); + mxnet_op::Kernel::Launch( + s, row_idx_out.Size(), row_idx); + int num_threads = mxnet_op::get_num_threads(data_out.shape_[0]); + size_t seg_len = (data_out.shape_[0] + num_threads - 1) / num_threads; + if (trans_lhs) { + mxnet_op::Kernel::Launch(s, num_threads, + data_out.dptr(), row_idx, data_l.dptr(), + indptr_l.dptr(), col_idx_l.dptr(), data_r.dptr(), + seg_len, lhs.shape()[0], data_out.shape_[0], data_out.shape_[1]); + index_t nnr = 0; + nnr = mxnet::common::ParallelAccumulate(row_idx, ret->shape()[0], nnr); + ret->set_aux_shape(rowsparse::kIdx, mshadow::Shape1(nnr)); + ret->set_storage_shape(mshadow::Shape2(nnr, ret->shape()[1])); + if (0 == nnr) return; + mshadow::Tensor rsp_data = data_out.FlatTo2D(s); + size_t idx = 0; + for (index_t i = 0; i < ret->shape()[0]; ++i) { + if (row_idx[i] > 0) { + row_idx[idx] = i; + mshadow::Copy(rsp_data[idx], rsp_data[i], s); + ++idx; + } + } + } else { + LOG(FATAL) << "DotCsrDnsRspImpl has not implemented dot(csr, dns)=rsp yet." 
+ " Only the cpu version of dot(csr.T, dns)=rsp is supported now"; + } + } else { + LOG(FATAL) << "DotCsrDnsRspImpl has not implemented GPU version yet."; + } + }); + }); + }); + }); +} + +template +void DotCsrRspDnsImpl(const OpContext& ctx, + const NDArray& lhs, + const NDArray& rhs, + const OpReqType req, + const bool trans_lhs, + TBlob* ret) { + // reuse csr dns implementation when storage_shape == shape for rhs + if (rhs.storage_shape()[0] == rhs.shape()[0]) { // if rsp is actually dense + DotCsrDnsDnsImpl(ctx, lhs, rhs.data(), req, trans_lhs, ret); + return; + } + + if (kNullOp == req) return; + CHECK_EQ(lhs.storage_type(), kCSRStorage); + CHECK_EQ(rhs.storage_type(), kRowSparseStorage); + mshadow::Stream *s = ctx.get_stream(); + if (!lhs.storage_initialized() || !rhs.storage_initialized()) { + if (kWriteTo == req) { + MSHADOW_TYPE_SWITCH(ret->type_flag_, DType, { // data type + mxnet_op::Kernel::Launch( + s, ret->Size(), ret->dptr()); + }); + } + return; + } + + const TBlob data_l = lhs.data(); + const TBlob indptr_l = lhs.aux_data(csr::kIndPtr); + const TBlob col_idx_l = lhs.aux_data(csr::kIdx); + const TBlob data_r = rhs.data(); + const TBlob row_idx_r = rhs.aux_data(rowsparse::kIdx); + + MSHADOW_TYPE_SWITCH(data_l.type_flag_, DType, { // data type + MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type + MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, { // col idx type + MSHADOW_IDX_TYPE_SWITCH(row_idx_r.type_flag_, RType, { // col idx type + if (std::is_same::value) { // cpu parallelization by row blocks + if (kWriteTo == req) { + mxnet_op::Kernel::Launch( + s, ret->Size(), ret->dptr()); + } + int num_threads = mxnet_op::get_num_threads(ret->shape_[0]); + size_t seg_len = (ret->shape_[0] + num_threads - 1) / num_threads; + if (trans_lhs) { + LOG(FATAL) << "DotCsrRspDnsImpl has not implemented dot(csr.T, rsp) = dns yet"; + } else { + mxnet_op::Kernel::Launch(s, num_threads, + ret->dptr(), data_l.dptr(), + indptr_l.dptr(), col_idx_l.dptr(), data_r.dptr(), + row_idx_r.dptr(), rhs.storage_shape()[0], + ret->shape_[0], ret->shape_[1], seg_len); + } + } else { + LOG(FATAL) << "DotCsrRspDnsImpl has not implemented GPU version yet"; + } + }); + }); + }); + }); +} + +/*! 
+ * \brief Impl of dot(csr.T, rsp) = rsp2 + */ +template +void DotCsrRspRspImpl(const OpContext& ctx, + const NDArray& lhs, + const NDArray& rhs, + const OpReqType req, + const bool trans_lhs, + NDArray* ret) { + // reuse csr dns implementation when storage_shape == shape for rhs + if (rhs.storage_shape()[0] == rhs.shape()[0]) { // if rsp is actually dense + DotCsrDnsRspImpl(ctx, lhs, rhs.data(), req, trans_lhs, ret); + return; + } + + if (kNullOp == req) return; + CHECK_EQ(lhs.storage_type(), kCSRStorage); + CHECK_EQ(rhs.storage_type(), kRowSparseStorage); + CHECK_EQ(ret->storage_type(), kRowSparseStorage); + if (!lhs.storage_initialized() || !rhs.storage_initialized()) return; + + mshadow::Stream *s = ctx.get_stream(); + const TBlob data_l = lhs.data(); + const TBlob indptr_l = lhs.aux_data(csr::kIndPtr); + const TBlob col_idx_l = lhs.aux_data(csr::kIdx); + const TBlob data_r = rhs.data(); + const TBlob row_idx_r = rhs.aux_data(rowsparse::kIdx); + + // pre-allocate spaces for ret using the dense dimension size + if (ret->storage_type() == kRowSparseStorage) { + ret->CheckAndAlloc({mshadow::Shape1(lhs.shape()[1])}); + } + const TBlob data_out = ret->data(); + const TBlob row_idx_out = ret->aux_data(rowsparse::kIdx); + + MSHADOW_TYPE_SWITCH(data_l.type_flag_, DType, { // data type + MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type + MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, { // col idx type + MSHADOW_IDX_TYPE_SWITCH(row_idx_r.type_flag_, RType, { // col idx type + if (std::is_same::value) { // cpu parallelization by row blocks + if (kWriteTo == req) { + mxnet_op::Kernel::Launch( + s, data_out.Size(), data_out.dptr()); + } + int num_threads = mxnet_op::get_num_threads(data_out.shape_[0]); + size_t seg_len = (data_out.shape_[0] + num_threads - 1) / num_threads; + if (trans_lhs) { + RType* row_idx = row_idx_out.dptr(); + mxnet_op::Kernel::Launch( + s, row_idx_out.Size(), row_idx); + mxnet_op::Kernel::Launch(s, num_threads, + data_out.dptr(), row_idx, data_l.dptr(), + indptr_l.dptr(), col_idx_l.dptr(), data_r.dptr(), + row_idx_r.dptr(), lhs.shape()[0], rhs.storage_shape()[0], + ret->shape()[0], ret->shape()[1], seg_len); + index_t nnr = 0; + nnr = mxnet::common::ParallelAccumulate(row_idx, ret->shape()[0], nnr); + ret->set_aux_shape(rowsparse::kIdx, mshadow::Shape1(nnr)); + ret->set_storage_shape(mshadow::Shape2(nnr, ret->shape()[1])); + if (0 == nnr) return; + mshadow::Tensor rsp_data = data_out.FlatTo2D(s); + size_t idx = 0; + for (index_t i = 0; i < ret->shape()[0]; ++i) { + if (row_idx[i] > 0) { + row_idx[idx] = i; + mshadow::Copy(rsp_data[idx], rsp_data[i], s); + ++idx; + } + } + } else { + LOG(FATAL) << "DotCsrRspRspImpl has not implemented dot(csr.T, rsp) = rsp2 yet"; + } + } else { + LOG(FATAL) << "DotCsrRspRspImpl has not implemented GPU version yet"; + } + }); + }); + }); + }); +} + +inline bool DotShape(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + const DotParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + TShape& lshape = (*in_attrs)[0]; + TShape& rshape = (*in_attrs)[1]; + if (lshape.ndim() == 1 && rshape.ndim() == 1) { + CHECK(!param.transpose_a && !param.transpose_b) << "Cannot transpose vectors"; + CHECK_EQ(lshape[0], rshape[0]) << "dot shape error: " << lshape << " X " << rshape; + SHAPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::Shape1(1)); + } else { + bool Ta = param.transpose_a, Tb = param.transpose_b; + TShape L[2], R[2]; + if (Ta) { + L[0] = 
mshadow::Shape1(lshape[0]); + L[1] = lshape.ndim() > 1 ? TShape(&lshape[1], &lshape[lshape.ndim()]) : TShape(1); + } else { + L[0] = lshape.ndim() > 1 ? TShape(&lshape[0], &lshape[lshape.ndim()-1]) : TShape(1); + L[1] = mshadow::Shape1(lshape[lshape.ndim()-1]); + } + if (Tb) { + R[0] = rshape.ndim() > 1 ? TShape(&rshape[0], &rshape[rshape.ndim()-1]) : TShape(1); + R[1] = mshadow::Shape1(rshape[rshape.ndim()-1]); + } else { + R[0] = mshadow::Shape1(rshape[0]); + R[1] = rshape.ndim() > 1 ? TShape(&rshape[1], &rshape[rshape.ndim()]) : TShape(1); + } + + if (L[!Ta].Size() != 0 && R[Tb].Size() != 0) { + CHECK_EQ(L[!Ta].Size(), R[Tb].Size()) + << "dot shape error: " << lshape << " X " << rshape; + } + std::vector buf; + if (lshape.ndim() > 1) buf.insert(buf.end(), &L[Ta][0], &L[Ta][L[Ta].ndim()]); + if (rshape.ndim() > 1) buf.insert(buf.end(), &R[!Tb][0], &R[!Tb][R[!Tb].ndim()]); + TShape oshape(buf.begin(), buf.end()); + SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape); + } + return true; +} + +template +void DotForwardEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + const DotParam& param = nnvm::get(attrs.parsed); + CHECK(!param.transpose_b) << "tranposing rhs of the op dot is not supported"; + auto lhs_stype = inputs[0].storage_type(); + auto rhs_stype = inputs[1].storage_type(); + auto out_stype = outputs[0].storage_type(); + if (lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage && out_stype == kDefaultStorage) { + TBlob ret = outputs[0].data(); + DotCsrDnsDnsImpl(ctx, inputs[0], inputs[1].data(), req[0], param.transpose_a, &ret); + } else if (lhs_stype == kCSRStorage && rhs_stype == kRowSparseStorage + && out_stype == kDefaultStorage) { + TBlob ret = outputs[0].data(); + DotCsrRspDnsImpl(ctx, inputs[0], inputs[1], req[0], param.transpose_a, &ret); + } else if (lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage + && out_stype == kRowSparseStorage) { + NDArray out = outputs[0]; + DotCsrDnsRspImpl(ctx, inputs[0], inputs[1].data(), req[0], param.transpose_a, &out); + } else if (lhs_stype == kCSRStorage && rhs_stype == kRowSparseStorage + && out_stype == kRowSparseStorage) { + NDArray ret = outputs[0]; + DotCsrRspRspImpl(ctx, inputs[0], inputs[1], req[0], param.transpose_a, &ret); + } else { + FCompExFallback(attrs, ctx, inputs, req, outputs, DotForward_, "DotForward_"); + } +} + +template +void DotBackwardEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 3U); + CHECK_EQ(outputs.size(), 2U); + CHECK_EQ(req.size(), 2U); + CHECK_EQ(kNullOp, req[0]) + << "sparse dot does not support computing the gradient of the csr/lhs"; + CHECK_NE(req[1], kWriteInplace) << "DotBackwardEx does not support WriteInplace"; + + const DotParam& param = nnvm::get(attrs.parsed); + CHECK(!param.transpose_b) << "sparse dot only supports dot(A, X) and dot(A.T(), X)"; + const auto ograd_stype = inputs[0].storage_type(); + const auto lhs_stype = inputs[1].storage_type(); + const auto rhs_stype = inputs[2].storage_type(); + const auto grad_rhs_stype = outputs[1].storage_type(); + + if (ograd_stype == kDefaultStorage // ograd dns format + && lhs_stype == kCSRStorage // csr input lhs of the op + && grad_rhs_stype == kDefaultStorage) { // grad(rhs) dns format + TBlob ret = outputs[1].data(); + DotCsrDnsDnsImpl(ctx, 
inputs[1], inputs[0].data(), req[1], !param.transpose_a, &ret); + } else if (ograd_stype == kDefaultStorage + && lhs_stype == kCSRStorage + && grad_rhs_stype == kRowSparseStorage) { + NDArray ret = outputs[1]; + DotCsrDnsRspImpl(ctx, inputs[1], inputs[0].data(), req[1], !param.transpose_a, &ret); + } else { + FCompExFallback(attrs, ctx, inputs, req, outputs, DotBackward_, "DotBackward_"); + } +} + +template +void BatchDotForward_(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow::expr; + mshadow::Stream *s = ctx.get_stream(); + const DotParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(outputs[0].type_flag_, inputs[0].type_flag_) + << "Binary function only support input/output with the same type"; + CHECK_EQ(outputs[0].type_flag_, inputs[1].type_flag_) + << "Binary function only support input/output with the same type"; + CHECK_EQ(outputs[0].type_flag_, mshadow::kFloat32) + << "dot only support 32 bit float so far"; + + mshadow::Tensor out = outputs[0].get(s); + mshadow::Tensor mlhs = inputs[0].get(s); + mshadow::Tensor mrhs = inputs[1].get(s); + mshadow::Tensor workspace = + ctx.requested[0].get_space_typed(mshadow::Shape1(3 * out.size(0)), s); + if (kNullOp != req[0]) { + if (param.transpose_a && param.transpose_b) { + mshadow::BatchGEMM(out, mlhs, mrhs, 1.0f, + (kAddTo == req[0]) ? 1.0f : 0.0f, + workspace); + } else if (!param.transpose_a && param.transpose_b) { + mshadow::BatchGEMM(out, mlhs, mrhs, 1.0f, + (kAddTo == req[0]) ? 1.0f : 0.0f, + workspace); + } else if (param.transpose_a && !param.transpose_b) { + mshadow::BatchGEMM(out, mlhs, mrhs, 1.0f, + (kAddTo == req[0]) ? 1.0f : 0.0f, + workspace); + } else { + mshadow::BatchGEMM(out, mlhs, mrhs, 1.0f, + (kAddTo == req[0]) ? 1.0f : 0.0f, + workspace); + } + } +} + +template +void BatchDotBackward_(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow::expr; + mshadow::Stream *s = ctx.get_stream(); + const DotParam& param = nnvm::get(attrs.parsed); + CHECK_NE(req[1], kWriteInplace); + CHECK_NE(req[0], kWriteInplace); + + mshadow::Tensor mout_grad = inputs[0].get(s); + mshadow::Tensor mlhs_data = inputs[1].get(s); + mshadow::Tensor mrhs_data = inputs[2].get(s); + mshadow::Tensor mlhs_grad = outputs[0].get(s); + mshadow::Tensor mrhs_grad = outputs[1].get(s); + mshadow::Tensor workspace = + ctx.requested[0].get_space_typed( + mshadow::Shape2(2, 3 * mout_grad.size(0)), s); + mshadow::Tensor rhs_workspace = workspace[0]; + mshadow::Tensor lhs_workspace = workspace[1]; + if (param.transpose_a && param.transpose_b) { + // Gradient of z = dot(x.T, y.T) + // dy = dot(x, dz).T = dot(dz.T, x.T) + // dx = dot(dz, y).T = dot(y.T, dz.T) + if (kNullOp != req[1]) { + mshadow::BatchGEMM(mrhs_grad, mout_grad, mlhs_data, 1.0f, + (kAddTo == req[1]) ? 1.0f : 0.0f, + rhs_workspace); + } + if (kNullOp != req[0]) { + mshadow::BatchGEMM(mlhs_grad, mrhs_data, mout_grad, 1.0f, + (kAddTo == req[0]) ? 1.0f : 0.0f, + lhs_workspace); + } + } else if (!param.transpose_a && param.transpose_b) { + // Gradient of z = dot(x, y.T) + // dy = dot(x.T, dz).T = dot(dz.T, x) + // dx = dot(dz, y) + if (kNullOp != req[1]) { + mshadow::BatchGEMM(mrhs_grad, mout_grad, mlhs_data, 1.0f, + (kAddTo == req[1]) ? 1.0f : 0.0f, + rhs_workspace); + } + if (kNullOp != req[0]) { + mshadow::BatchGEMM(mlhs_grad, mout_grad, mrhs_data, 1.0f, + (kAddTo == req[0]) ? 
1.0f : 0.0f, + lhs_workspace); + } + } else if (param.transpose_a && !param.transpose_b) { + // Gradient of z = dot(x.T, y) + // dy = dot(x, dz) + // dx = dot(dz, y.T).T = dot(y, dz.T) + if (kNullOp != req[1]) { + mshadow::BatchGEMM(mrhs_grad, mlhs_data, mout_grad, 1.0f, + (kAddTo == req[1]) ? 1.0f : 0.0f, + rhs_workspace); + } + if (kNullOp != req[0]) { + mshadow::BatchGEMM(mlhs_grad, mrhs_data, mout_grad, 1.0f, + (kAddTo == req[0]) ? 1.0f : 0.0f, + lhs_workspace); + } + } else { + // Gradient of z = dot(x, y) + // dy = dot(x.T, dz) + // dx = dot(dz, y.T) + if (kNullOp != req[1]) { + mshadow::BatchGEMM(mrhs_grad, mlhs_data, mout_grad, 1.0f, + (kAddTo == req[1]) ? 1.0f : 0.0f, + rhs_workspace); + } + if (kNullOp != req[0]) { + mshadow::BatchGEMM(mlhs_grad, mout_grad, mrhs_data, 1.0f, + (kAddTo == req[0]) ? 1.0f : 0.0f, + lhs_workspace); + } + } +} + +inline bool BatchDotShape(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + const DotParam& param = nnvm::get(attrs.parsed); + TShape& lshape = (*in_attrs)[0]; + TShape& rshape = (*in_attrs)[1]; + if (lshape.ndim() == 3 && rshape.ndim() == 3) { + CHECK(lshape[0] == rshape[0]) + << "batch_dot shape error(batch_size must be equal): " << lshape << " X " << rshape + << " trans_a=" << param.transpose_a << " trans_b=" << param.transpose_b; + index_t out_m = param.transpose_a ? lshape[2] : lshape[1]; + index_t lshape_k = param.transpose_a ? lshape[1] : lshape[2]; + index_t out_n = param.transpose_b ? rshape[1] : rshape[2]; + index_t rshape_k = param.transpose_b ? rshape[2] : rshape[1]; + CHECK(lshape_k == rshape_k) + << "batch_dot shape error(shape mismatch): " << lshape << " X " << rshape + << " trans_a=" << param.transpose_a << " trans_b=" << param.transpose_b; + SHAPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::Shape3(lshape[0], out_m, out_n)); + } else { + LOG(FATAL) << "batch_dot currently only support 3D*3D array" + << lshape << " v.s. " << rshape; + } + return true; +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_NN_MATRIX_DOT_INL_H_ diff --git a/src/operator/nn/matrix_dot.cc b/src/operator/nn/matrix_dot.cc new file mode 100644 index 000000000000..716efea6999c --- /dev/null +++ b/src/operator/nn/matrix_dot.cc @@ -0,0 +1,114 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file matrix_dot.cc + * \brief CPU Implementation of matrix dot + */ + +#include "./matrix_dot-inl.h" + +namespace mxnet { +namespace op { +DMLC_REGISTER_PARAMETER(DotParam); + +NNVM_REGISTER_OP(dot) +.describe(R"doc(Dot product of two arrays. + +``dot``'s behavior depends on the input array dimensions: + +- 1-D arrays: inner product of vectors +- 2-D arrays: matrix multiplication +- N-D arrays: a sum product over the last axis of the first input and the first + axis of the second input + + For example, given 3-D ``x`` with shape `(n,m,k)` and ``y`` with shape `(k,r,s)`, the + result array will have shape `(n,m,r,s)`. 
It is computed by:: + + dot(x,y)[i,j,a,b] = sum(x[i,j,:]*y[:,a,b]) + + Example:: + + x = reshape([0,1,2,3,4,5,6,7], shape=(2,2,2)) + y = reshape([7,6,5,4,3,2,1,0], shape=(2,2,2)) + dot(x,y)[0,0,1,1] = 0 + sum(x[0,0,:]*y[:,1,1]) = 0 +)doc" ADD_FILELINE) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"lhs", "rhs"}; + }) +.set_attr("FInferShape", DotShape) +.set_attr("FInferType", ElemwiseType<2, 1>) +.set_attr("FInferStorageType", DotForwardInferStorageType) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", DotForward_) +.set_attr("FComputeEx", DotForwardEx) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_dot"}) +.add_argument("lhs", "NDArray-or-Symbol", "The first input") +.add_argument("rhs", "NDArray-or-Symbol", "The second input") +.add_arguments(DotParam::__FIELDS__()); + +NNVM_REGISTER_OP(_backward_dot) +.set_num_inputs(3) +.set_num_outputs(2) +.set_attr_parser(ParamParser) +.set_attr("TIsBackward", true) +.set_attr("FInferStorageType", DotBackwardInferStorageType) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", DotBackward_) +.set_attr("FComputeEx", DotBackwardEx) +.add_arguments(DotParam::__FIELDS__()); + +NNVM_REGISTER_OP(batch_dot) +.describe(R"doc(Batchwise dot product. + +``batch_dot`` is used to compute dot product of ``x`` and ``y`` when ``x`` and +``y`` are data in batch, namely 3D arrays in shape of `(batch_size, :, :)`. + +For example, given ``x`` with shape `(batch_size, n, m)` and ``y`` with shape +`(batch_size, m, k)`, the result array will have shape `(batch_size, n, k)`, +which is computed by:: + + batch_dot(x,y)[i,:,:] = dot(x[i,:,:], y[i,:,:]) + +)doc" ADD_FILELINE) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"lhs", "rhs"}; + }) +.set_attr("FInferShape", BatchDotShape) +.set_attr("FInferType", ElemwiseType<2, 1>) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", BatchDotForward_) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_batch_dot"}) +.add_argument("lhs", "NDArray-or-Symbol", "The first input") +.add_argument("rhs", "NDArray-or-Symbol", "The second input") +.add_arguments(DotParam::__FIELDS__()); + +NNVM_REGISTER_OP(_backward_batch_dot) +.set_num_inputs(3) +.set_num_outputs(2) +.set_attr_parser(ParamParser) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("TIsBackward", true) +.set_attr("FCompute", BatchDotBackward_); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/nn/matrix_dot.cu b/src/operator/nn/matrix_dot.cu new file mode 100644 index 000000000000..21592e15449e --- /dev/null +++ b/src/operator/nn/matrix_dot.cu @@ -0,0 +1,27 @@ +/*! 
+ * Copyright (c) 2017 by Contributors + * \file matrix_dot.cu + * \brief GPU Implementation of matrix dot + */ + +#include "./matrix_dot-inl.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(dot) +.set_attr("FCompute", DotForward_) +.set_attr("FComputeEx", DotForwardEx); + +NNVM_REGISTER_OP(_backward_dot) +.set_attr("FCompute", DotBackward_) +.set_attr("FComputeEx", DotBackwardEx); + +NNVM_REGISTER_OP(batch_dot) +.set_attr("FCompute", BatchDotForward_); + +NNVM_REGISTER_OP(_backward_batch_dot) +.set_attr("FCompute", BatchDotBackward_); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h index 2a1590f9a8a5..5f6f85e46ca6 100644 --- a/src/operator/tensor/indexing_op.h +++ b/src/operator/tensor/indexing_op.h @@ -22,7 +22,7 @@ #include "../elemwise_op_common.h" #include "../mxnet_op.h" #include "./sort_op.h" -#include "./matrix_op-inl.h" +#include "../nn/matrix_dot-inl.h" namespace mxnet { namespace op { diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 4ddb6bb55491..e8b3936f3627 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -320,1018 +320,6 @@ inline bool ExpandDimShape(const nnvm::NodeAttrs& attrs, return true; } -struct DotParam : public dmlc::Parameter { - bool transpose_a; - bool transpose_b; - DMLC_DECLARE_PARAMETER(DotParam) { - DMLC_DECLARE_FIELD(transpose_a) - .describe("If true then transpose the first input before dot.") - .set_default(false); - DMLC_DECLARE_FIELD(transpose_b) - .describe("If true then transpose the second input before dot.") - .set_default(false); - } -}; - -template -void DotForward_(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - using namespace mshadow; - using namespace mshadow::expr; - const DotParam& param = nnvm::get(attrs.parsed); - Stream *s = ctx.get_stream(); - CHECK_EQ(outputs[0].type_flag_, inputs[0].type_flag_) - << "Binary function only support input/output with the same type"; - CHECK_EQ(outputs[0].type_flag_, inputs[1].type_flag_) - << "Binary function only support input/output with the same type"; - CHECK_EQ(outputs[0].type_flag_, kFloat32) - << "dot only support 32 bit float so far"; - - if (inputs[0].ndim() == 1 && inputs[1].ndim() == 1) { - CHECK_NE(req[0], kAddTo) << "AddTo not yet suported"; - Tensor out = outputs[0].get(s); - VectorDot(out, - inputs[0].get(s), - inputs[1].get(s)); - } else { - int ma, na, mb, nb, m, n; - if (param.transpose_a) { - ma = inputs[0].size(0); - na = inputs[0].Size()/ma; - m = na; - } else { - na = inputs[0].size(inputs[0].ndim()-1); - ma = inputs[0].Size()/na; - m = ma; - } - if (param.transpose_b) { - nb = inputs[1].size(inputs[1].ndim()-1); - mb = inputs[1].Size()/nb; - n = mb; - } else { - mb = inputs[1].size(0); - nb = inputs[1].Size()/mb; - n = nb; - } - - Tensor input0 = - inputs[0].get_with_shape(Shape2(ma, na), s); - Tensor input1 = - inputs[1].get_with_shape(Shape2(mb, nb), s); - Tensor out = - outputs[0].get_with_shape(Shape2(m, n), s); - if (param.transpose_a && param.transpose_b) { - ASSIGN_DISPATCH(out, req[0], dot(input0.T(), input1.T())); - } else if (!param.transpose_a && param.transpose_b) { - ASSIGN_DISPATCH(out, req[0], dot(input0, input1.T())); - } else if (param.transpose_a && !param.transpose_b) { - ASSIGN_DISPATCH(out, req[0], dot(input0.T(), input1)); - } else { - ASSIGN_DISPATCH(out, req[0], dot(input0, input1)); - } - } -} 
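The dense DotForward_ path above collapses an N-D operand into a 2-D matrix before calling mshadow's dot: when transpose_a is set, the leading axis of the stored array is the reduction axis and the flattened remainder (na) becomes the output row count m; otherwise the trailing axis is reduced and ma becomes m. Below is a minimal standalone sketch of that bookkeeping; the View2D/CollapseLhs names are illustrative only and not part of the mshadow or MXNet API, just the same arithmetic as the ma/na/m block above.

#include <cstddef>

// Sketch: how an N-D lhs is viewed as a 2-D (ma, na) matrix for dot.
// size_first = lhs.size(0), size_last = lhs.size(ndim-1), total = lhs.Size().
struct View2D {
  std::size_t ma;  // rows of the stored 2-D view
  std::size_t na;  // cols of the stored 2-D view
  std::size_t m;   // rows contributed to the output (after optional transpose)
};

inline View2D CollapseLhs(std::size_t size_first, std::size_t size_last,
                          std::size_t total, bool transpose_a) {
  View2D v;
  if (transpose_a) {
    v.ma = size_first;     // leading axis stays as stored rows (reduction axis)
    v.na = total / v.ma;   // remaining axes flattened into columns
    v.m  = v.na;           // lhs.T contributes its column count as output rows
  } else {
    v.na = size_last;      // trailing axis is the reduction axis
    v.ma = total / v.na;   // remaining axes flattened into rows
    v.m  = v.ma;
  }
  return v;
}

For example, a (2,3,4) lhs with transpose_a=false gives na=4, ma=6, so the output keeps 6 rows; the rhs is collapsed symmetrically to (mb, nb) with n = mb or nb depending on transpose_b.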
- -template -void DotBackward_(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - using namespace mshadow; - using namespace mshadow::expr; - const DotParam& param = nnvm::get(attrs.parsed); - Stream *s = ctx.get_stream(); - CHECK_NE(req[0], kWriteInplace); - CHECK_NE(req[1], kWriteInplace); - - if (inputs[1].ndim() == 1 && inputs[2].ndim() == 1) { - Tensor mout_grad = inputs[0].get(s); - Tensor mlhs_data = inputs[1].get(s); - Tensor mrhs_data = inputs[2].get(s); - Tensor mlhs_grad = outputs[0].get(s); - Tensor mrhs_grad = outputs[1].get(s); - ASSIGN_DISPATCH(mrhs_grad, req[1], - broadcast_scalar(mout_grad, mlhs_data.shape_) * mlhs_data); - ASSIGN_DISPATCH(mlhs_grad, req[0], - broadcast_scalar(mout_grad, mlhs_data.shape_) * mrhs_data); - } else { - int ma, na, mb, nb, m, n; - if (param.transpose_a) { - ma = outputs[0].size(0); - na = outputs[0].Size()/ma; - m = na; - } else { - na = outputs[0].size(outputs[0].ndim()-1); - ma = outputs[0].Size()/na; - m = ma; - } - if (param.transpose_b) { - nb = outputs[1].size(outputs[1].ndim()-1); - mb = outputs[1].Size()/nb; - n = mb; - } else { - mb = outputs[1].size(0); - nb = outputs[1].Size()/mb; - n = nb; - } - - Tensor mout_grad = - inputs[0].get_with_shape(Shape2(m, n), s); - Tensor mlhs_data = - inputs[1].get_with_shape(Shape2(ma, na), s); - Tensor mrhs_data = - inputs[2].get_with_shape(Shape2(mb, nb), s); - Tensor mlhs_grad = - outputs[0].get_with_shape(Shape2(ma, na), s); - Tensor mrhs_grad = - outputs[1].get_with_shape(Shape2(mb, nb), s); - if (param.transpose_a && param.transpose_b) { - // Gradient of z = dot(x.T, y.T) - // dy = dot(x, dz).T = dot(dz.T, x.T) - // dx = dot(dz, y).T = dot(y.T, dz.T) - ASSIGN_DISPATCH(mrhs_grad, req[1], dot(mout_grad.T(), mlhs_data.T())); - ASSIGN_DISPATCH(mlhs_grad, req[0], dot(mrhs_data.T(), mout_grad.T())); - } else if (!param.transpose_a && param.transpose_b) { - // Gradient of z = dot(x, y.T) - // dy = dot(x.T, dz).T = dot(dz.T, x) - // dx = dot(dz, y) - ASSIGN_DISPATCH(mrhs_grad, req[1], dot(mout_grad.T(), mlhs_data)); - ASSIGN_DISPATCH(mlhs_grad, req[0], dot(mout_grad, mrhs_data)); - } else if (param.transpose_a && !param.transpose_b) { - // Gradient of z = dot(x.T, y) - // dy = dot(x, dz) - // dx = dot(dz, y.T).T = dot(y, dz.T) - ASSIGN_DISPATCH(mrhs_grad, req[1], dot(mlhs_data, mout_grad)); - ASSIGN_DISPATCH(mlhs_grad, req[0], dot(mrhs_data, mout_grad.T())); - } else { - // Gradient of z = dot(x, y) - // dy = dot(x.T, dz) - // dx = dot(dz, y.T) - ASSIGN_DISPATCH(mrhs_grad, req[1], dot(mlhs_data.T(), mout_grad)); - ASSIGN_DISPATCH(mlhs_grad, req[0], dot(mout_grad, mrhs_data.T())); - } - } -} - -inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs, - const Context& ctx, - std::vector *in_attrs, - std::vector *out_attrs) { - CHECK_EQ(in_attrs->size(), 2U); - CHECK_EQ(out_attrs->size(), 1U); - const DotParam& param = nnvm::get(attrs.parsed); - // csr has many zero columns, so the result of dot(csr.T, matrix) should be rsp - if (param.transpose_a && kCSRStorage == (*in_attrs)[0]) { - STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 0, kRowSparseStorage); - } else { - STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 0, kDefaultStorage); - } - return true; -} - -inline bool DotBackwardInferStorageType(const nnvm::NodeAttrs& attrs, - const Context& ctx, - std::vector *in_attrs, - std::vector *out_attrs) { - CHECK_EQ(in_attrs->size(), 3U); - CHECK_EQ(out_attrs->size(), 2U); - const DotParam& param = nnvm::get(attrs.parsed); - 
STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 0, kDefaultStorage); - if (!param.transpose_a && kCSRStorage == (*in_attrs)[1]) { - STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 1, kRowSparseStorage); - } else { - STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 1, kDefaultStorage); - } - return true; -} - -/*! - * \brief Kernel of dot(csr, dns1) = dns2 - * Parallelization by output matrix elements - */ -template -struct DotCsrDnsDns { - /*! - * \brief This function represents performing an inner product between a row of lhs - * and a column of rhs and then assigning the value to out[i]. - * \param i i-th element in out 1D view - * \param out output matrix - * \param data_l csr values of lhs - * \param indptr_l csr indptr of lhs - * \param col_idx_l csr col_idx of lhs - * \param data_r dense data of rhs - * \param num_cols number of columns of output - */ - template - MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data_l, const IType* indptr_l, - const CType* col_idx_l, const DType* data_r, - const int num_cols) { - const int irow = i / num_cols; // row id of the lhs - const int icol = i % num_cols; // col id of the rhs - DType sum = 0; - for (IType j = indptr_l[irow]; j < indptr_l[irow+1]; ++j) { - const CType cur_col = col_idx_l[j]; // corresponding row id of the rhs - sum += data_l[j] * data_r[cur_col*num_cols+icol]; - } - KERNEL_ASSIGN(out[i], req, sum); - } -}; - -/*! - * \brief Kernel of dot(csr.T(), dns1) = dns2 - * Parallelization by output matrix elements - */ -template -struct DotCsrTransDnsDns { - /*! - * \brief This function represents performing an inner product between a column of lhs - * and a column of rhs and then assigning the value to out[i]. - * \param i i-th element in out 1D view - * \param out output matrix - * \param data_l csr values of lhs - * \param indptr_l csr indptr of lhs - * \param col_idx_l csr col_idx of lhs - * \param data_r dense data of rhs - * \param num_rows_l number of rows of lhs - * \param num_cols number of columns of outputs - */ - template - MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data_l, const IType* indptr_l, - const CType* col_idx_l, const DType* data_r, const int num_rows_l, - const int num_cols) { - const int irow = i / num_cols; // col id of the lhs - const int icol = i % num_cols; // col id of the rhs - DType sum = 0; - for (int k = 0; k < num_rows_l; ++k) { - const IType low = indptr_l[k]; - const IType high = indptr_l[k+1]; - if (low == high || irow < col_idx_l[low] || irow > col_idx_l[high-1]) continue; - int j = -1, l = low, r = high - 1; - while (l <= r) { - int m = l + (r - l) / 2; - if (col_idx_l[m] == irow) { - j = m; break; - } - if (col_idx_l[m] < irow) { - l = m + 1; - } else { - r = m - 1; - } - } - if (j >= 0) { - sum += data_l[j] * data_r[k*num_cols+icol]; - } - } - KERNEL_ASSIGN(out[i], req, sum); - } -}; - -/*! - * \brief Kernel of dot(csr, dns1) = dns2 - * Parallelization by row blocks - */ -struct DotCsrDnsDnsByRowBlocks { - /*! - * \brief - * \param i the i-th thread - */ - template - MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data_l, const IType* indptr_l, - const CType* col_idx_l, const DType* data_r, const size_t seg_len, - const size_t num_rows, const size_t num_cols) { - const size_t seg_start = i * seg_len; - if (seg_start >= num_rows) return; - const size_t seg_end = (seg_start+seg_len < num_rows? 
seg_start+seg_len : num_rows); - for (size_t j = seg_start; j < seg_end; ++j) { - if (indptr_l[j] == indptr_l[j+1]) continue; - const size_t offset_out = j * num_cols; - for (auto k = indptr_l[j]; k < indptr_l[j+1]; ++k) { - const auto val = data_l[k]; - const size_t offset_r = col_idx_l[k] * num_cols; - for (size_t l = 0; l < num_cols; ++l) { - out[offset_out+l] += data_r[offset_r+l] * val; - } - } - } - } -}; - -/*! - * \brief Kernel of dot(csr.T(), dns1) = dns2 - * Parallelization by row blocks - */ -struct DotCsrTransDnsDnsByRowBlocks { - /*! - * \brief - * \param i the i-th thread - */ - template - MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data_l, const IType* indptr_l, - const CType* col_idx_l, const DType* data_r, const size_t seg_len, - const size_t num_rows_l, const size_t num_rows, - const size_t num_cols) { - const size_t seg_start = i * seg_len; - if (seg_start >= num_rows) return; - const size_t seg_end = (i + 1) * seg_len; - for (size_t j = 0; j < num_rows_l; ++j) { - if (indptr_l[j] == indptr_l[j+1]) continue; - const size_t offset_r = j * num_cols; - for (auto k = indptr_l[j]; k < indptr_l[j+1]; ++k) { - const auto col_idx = col_idx_l[k]; - if (col_idx < seg_start || col_idx >= seg_end) continue; - const size_t offset_out = col_idx * num_cols; - const auto val = data_l[k]; - for (size_t l = 0; l < num_cols; ++l) { - out[offset_out+l] += data_r[offset_r+l] * val; - } - } - } - } -}; - -/*! - * \brief Kernel of dot(csr.T(), dns) = rsp - * Parallelization by row blocks. - * This kernel fills up the row_idx array - * of the rsp with 1 for nonzero rows and 0 - * for zero rows. - * The matrix will be compacted after this kernel call. - */ -struct DotCsrTransDnsRspByRowBlocks { - /*! - * \brief - * \param i the i-th thread - */ - template - MSHADOW_XINLINE static void Map(int i, DType* out, RType* row_idx, const DType* data_l, - const IType* indptr_l, const CType* col_idx_l, - const DType* data_r, const size_t seg_len, - const size_t num_rows_l, const size_t num_rows, - const size_t num_cols) { - const size_t seg_start = i * seg_len; - if (seg_start >= num_rows) return; - const size_t seg_end = (i + 1) * seg_len; - for (size_t j = 0; j < num_rows_l; ++j) { - if (indptr_l[j] == indptr_l[j+1]) continue; - const size_t offset_r = j * num_cols; - for (auto k = indptr_l[j]; k < indptr_l[j+1]; ++k) { - const auto col_idx = col_idx_l[k]; - if (col_idx < seg_start || col_idx >= seg_end) continue; - const size_t offset_out = col_idx * num_cols; - row_idx[col_idx] = 1; - const auto val = data_l[k]; - for (size_t l = 0; l < num_cols; ++l) { - out[offset_out+l] += data_r[offset_r+l] * val; - } - } - } - } -}; - -/*! - * \brief Kernel of dot(csr, rsp) = dns - * Parallelization by row blocks - */ -struct DotCsrRspDnsByRowBlocks { - /*! - * \brief - * \param i the i-th thread - * \param nnr_r storage_shape[0] of the rsp - * \param num_rows dns.shape[0] - * \param num_cols dns.shape[1] - */ - template - MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data_l, - const IType* indptr_l, const CType* col_idx_l, - const DType* data_r, const RType* row_idx_r, - const size_t nnr_r, const size_t num_rows, - const size_t num_cols, const size_t seg_len) { - const size_t seg_start = i * seg_len; - if (seg_start >= num_rows) return; - const size_t seg_end = (seg_start+seg_len < num_rows? 
seg_start+seg_len : num_rows); - for (size_t j = seg_start; j < seg_end; ++j) { - if (indptr_l[j] == indptr_l[j+1]) continue; - const size_t offset_out = j * num_cols; - // Use binary search to find the lower_bound of val in row_idx array - const RType* first = row_idx_r; - const RType* last = row_idx_r + nnr_r; - const auto val = col_idx_l[indptr_l[j]]; - const RType* it; - int count = last - first, step; - while (count > 0) { - it = first; - step = count / 2; - it += step; - if (*it < val) { - first = ++it; - count -= step + 1; - } else { - count = step; - } - } - const RType* row_idx_ptr = first; - // end of binary search - if (row_idx_ptr == row_idx_r+nnr_r || *row_idx_ptr> col_idx_l[indptr_l[j+1]-1]) continue; - for (auto k = indptr_l[j]; k < indptr_l[j+1] && row_idx_ptr != row_idx_r+nnr_r;) { - if (col_idx_l[k] == *row_idx_ptr) { - const size_t offset_r = (row_idx_ptr - row_idx_r) * num_cols; - for (size_t l = 0; l < num_cols; ++l) { - out[offset_out+l] += data_l[k] * data_r[offset_r+l]; - } - ++k; - ++row_idx_ptr; - } else if (col_idx_l[k] < *row_idx_ptr) { - ++k; - } else { - ++row_idx_ptr; - } - } - } - } -}; - -/*! - * \brief Kernel of dot(csr.T(), rsp) = dns with row_idx marked for non-zero rows - * Parallelization by row blocks - */ -struct DotCsrTransRspRspByRowBlocks { - /*! - * \brief - * \param i the i-th thread - * \param num_rows_l number of rows of lhs matrix - * \param nnr_r number of non-zero rows of rhs matrix - * \param num_rows number of rows of out matrix - * \param num_cols number of cols of out matrix - */ - template - MSHADOW_XINLINE static void Map(int i, DType* out, RType* row_idx_out, - const DType* data_l, const IType* indptr_l, - const CType* col_idx_l, const DType* data_r, - const RType* row_idx_r, const size_t num_rows_l, - const size_t nnr_r, const size_t num_rows, - const size_t num_cols, const size_t seg_len) { - const size_t seg_start = i * seg_len; - if (seg_start >= num_rows) return; - const size_t seg_end = (i + 1) * seg_len; - for (size_t rid = 0; rid < nnr_r; ++rid) { - const auto j = row_idx_r[rid]; - if (indptr_l[j] == indptr_l[j+1]) continue; - const size_t offset_r = rid * num_cols; - for (auto k = indptr_l[j]; k < indptr_l[j+1]; ++k) { - const auto col_idx = col_idx_l[k]; - if (col_idx < seg_start || col_idx >= seg_end) continue; - row_idx_out[col_idx] = 1; // mark nonzero row as 1 - const size_t offset_out = col_idx * num_cols; - for (size_t l = 0; l < num_cols; ++l) { - out[offset_out+l] += data_r[offset_r+l] * data_l[k]; - } - } - } - } -}; - -template -void DotCsrDnsDnsImpl(const OpContext& ctx, - const NDArray& lhs, - const TBlob& rhs, - const OpReqType req, - const bool trans_lhs, - TBlob* ret) { - if (kNullOp == req) return; - CHECK_EQ(lhs.storage_type(), kCSRStorage); - if (!lhs.storage_initialized()) return; - - mshadow::Stream *s = ctx.get_stream(); - const TBlob data_l = lhs.data(); - const TBlob indptr_l = lhs.aux_data(csr::kIndPtr); - const TBlob col_idx_l = lhs.aux_data(csr::kIdx); - const TBlob& data_r = rhs; - const TBlob data_out = *ret; - - MSHADOW_TYPE_SWITCH(data_l.type_flag_, DType, { // data type - MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type - MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, { // col idx type - if (std::is_same::value) { // cpu parallelization by row blocks - if (kWriteTo == req) { - mxnet_op::Kernel::Launch( - s, data_out.Size(), data_out.dptr()); - } - int num_threads = mxnet_op::get_num_threads(data_out.shape_[0]); - size_t seg_len = (data_out.shape_[0] + num_threads - 1) 
/ num_threads; - if (trans_lhs) { - mxnet_op::Kernel::Launch(s, num_threads, - data_out.dptr(), data_l.dptr(), indptr_l.dptr(), - col_idx_l.dptr(), data_r.dptr(), seg_len, - lhs.shape()[0], data_out.shape_[0], data_out.shape_[1]); - } else { - mxnet_op::Kernel::Launch(s, num_threads, - data_out.dptr(), data_l.dptr(), indptr_l.dptr(), - col_idx_l.dptr(), data_r.dptr(), seg_len, - data_out.shape_[0], data_out.shape_[1]); - } - } else { // gpu parallelization by output elements - if (trans_lhs) { - MXNET_ASSIGN_REQ_SWITCH(req, ReqType, { - mxnet_op::Kernel, xpu>::Launch(s, data_out.Size(), - data_out.dptr(), data_l.dptr(), indptr_l.dptr(), - col_idx_l.dptr(), data_r.dptr(), lhs.shape()[0], - data_out.shape_[1]); - }); - } else { - MXNET_ASSIGN_REQ_SWITCH(req, ReqType, { - mxnet_op::Kernel, xpu>::Launch(s, data_out.Size(), - data_out.dptr(), data_l.dptr(), indptr_l.dptr(), - col_idx_l.dptr(), data_r.dptr(), rhs.shape_[1]); - }); - } - } - }); - }); - }); -} - -/*! - * \brief Impl of dot(csr, rsp) - */ -template -void DotCsrDnsRspImpl(const OpContext& ctx, - const NDArray& lhs, - const TBlob& rhs, - const OpReqType req, - const bool trans_lhs, - NDArray* ret) { - if (kNullOp == req) return; - CHECK_EQ(lhs.storage_type(), kCSRStorage); - CHECK_EQ(ret->storage_type(), kRowSparseStorage); - if (!lhs.storage_initialized()) return; - - mshadow::Stream *s = ctx.get_stream(); - const TBlob data_l = lhs.data(); - const TBlob indptr_l = lhs.aux_data(csr::kIndPtr); - const TBlob col_idx_l = lhs.aux_data(csr::kIdx); - const TBlob& data_r = rhs; - - // pre-allocate spaces for ret using the dense dimension size - ret->CheckAndAlloc({mshadow::Shape1(lhs.shape()[1])}); - const TBlob data_out = ret->data(); - const TBlob row_idx_out = ret->aux_data(rowsparse::kIdx); - - MSHADOW_TYPE_SWITCH(data_l.type_flag_, DType, { // data type - MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type - MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, { // col idx type - MSHADOW_IDX_TYPE_SWITCH(row_idx_out.type_flag_, RType, { // col idx type - if (std::is_same::value) { // cpu parallelization by row blocks - if (kWriteTo == req) { - mxnet_op::Kernel::Launch( - s, data_out.Size(), data_out.dptr()); - } - RType* row_idx = row_idx_out.dptr(); - mxnet_op::Kernel::Launch( - s, row_idx_out.Size(), row_idx); - int num_threads = mxnet_op::get_num_threads(data_out.shape_[0]); - size_t seg_len = (data_out.shape_[0] + num_threads - 1) / num_threads; - if (trans_lhs) { - mxnet_op::Kernel::Launch(s, num_threads, - data_out.dptr(), row_idx, data_l.dptr(), - indptr_l.dptr(), col_idx_l.dptr(), data_r.dptr(), - seg_len, lhs.shape()[0], data_out.shape_[0], data_out.shape_[1]); - index_t nnr = 0; - nnr = mxnet::common::ParallelAccumulate(row_idx, ret->shape()[0], nnr); - ret->set_aux_shape(rowsparse::kIdx, mshadow::Shape1(nnr)); - ret->set_storage_shape(mshadow::Shape2(nnr, ret->shape()[1])); - if (0 == nnr) return; - mshadow::Tensor rsp_data = data_out.FlatTo2D(s); - size_t idx = 0; - for (index_t i = 0; i < ret->shape()[0]; ++i) { - if (row_idx[i] > 0) { - row_idx[idx] = i; - mshadow::Copy(rsp_data[idx], rsp_data[i], s); - ++idx; - } - } - } else { - LOG(FATAL) << "DotCsrDnsRspImpl has not implemented dot(csr, dns)=rsp yet." 
- " Only the cpu version of dot(csr.T, dns)=rsp is supported now"; - } - } else { - LOG(FATAL) << "DotCsrDnsRspImpl has not implemented GPU version yet."; - } - }); - }); - }); - }); -} - -template -void DotCsrRspDnsImpl(const OpContext& ctx, - const NDArray& lhs, - const NDArray& rhs, - const OpReqType req, - const bool trans_lhs, - TBlob* ret) { - // reuse csr dns implementation when storage_shape == shape for rhs - if (rhs.storage_shape()[0] == rhs.shape()[0]) { // if rsp is actually dense - DotCsrDnsDnsImpl(ctx, lhs, rhs.data(), req, trans_lhs, ret); - return; - } - - if (kNullOp == req) return; - CHECK_EQ(lhs.storage_type(), kCSRStorage); - CHECK_EQ(rhs.storage_type(), kRowSparseStorage); - mshadow::Stream *s = ctx.get_stream(); - if (!lhs.storage_initialized() || !rhs.storage_initialized()) { - if (kWriteTo == req) { - MSHADOW_TYPE_SWITCH(ret->type_flag_, DType, { // data type - mxnet_op::Kernel::Launch( - s, ret->Size(), ret->dptr()); - }); - } - return; - } - - const TBlob data_l = lhs.data(); - const TBlob indptr_l = lhs.aux_data(csr::kIndPtr); - const TBlob col_idx_l = lhs.aux_data(csr::kIdx); - const TBlob data_r = rhs.data(); - const TBlob row_idx_r = rhs.aux_data(rowsparse::kIdx); - - MSHADOW_TYPE_SWITCH(data_l.type_flag_, DType, { // data type - MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type - MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, { // col idx type - MSHADOW_IDX_TYPE_SWITCH(row_idx_r.type_flag_, RType, { // col idx type - if (std::is_same::value) { // cpu parallelization by row blocks - if (kWriteTo == req) { - mxnet_op::Kernel::Launch( - s, ret->Size(), ret->dptr()); - } - int num_threads = mxnet_op::get_num_threads(ret->shape_[0]); - size_t seg_len = (ret->shape_[0] + num_threads - 1) / num_threads; - if (trans_lhs) { - LOG(FATAL) << "DotCsrRspDnsImpl has not implemented dot(csr.T, rsp) = dns yet"; - } else { - mxnet_op::Kernel::Launch(s, num_threads, - ret->dptr(), data_l.dptr(), - indptr_l.dptr(), col_idx_l.dptr(), data_r.dptr(), - row_idx_r.dptr(), rhs.storage_shape()[0], - ret->shape_[0], ret->shape_[1], seg_len); - } - } else { - LOG(FATAL) << "DotCsrRspDnsImpl has not implemented GPU version yet"; - } - }); - }); - }); - }); -} - -/*! 
- * \brief Impl of dot(csr.T, rsp) = rsp2 - */ -template -void DotCsrRspRspImpl(const OpContext& ctx, - const NDArray& lhs, - const NDArray& rhs, - const OpReqType req, - const bool trans_lhs, - NDArray* ret) { - // reuse csr dns implementation when storage_shape == shape for rhs - if (rhs.storage_shape()[0] == rhs.shape()[0]) { // if rsp is actually dense - DotCsrDnsRspImpl(ctx, lhs, rhs.data(), req, trans_lhs, ret); - return; - } - - if (kNullOp == req) return; - CHECK_EQ(lhs.storage_type(), kCSRStorage); - CHECK_EQ(rhs.storage_type(), kRowSparseStorage); - CHECK_EQ(ret->storage_type(), kRowSparseStorage); - if (!lhs.storage_initialized() || !rhs.storage_initialized()) return; - - mshadow::Stream *s = ctx.get_stream(); - const TBlob data_l = lhs.data(); - const TBlob indptr_l = lhs.aux_data(csr::kIndPtr); - const TBlob col_idx_l = lhs.aux_data(csr::kIdx); - const TBlob data_r = rhs.data(); - const TBlob row_idx_r = rhs.aux_data(rowsparse::kIdx); - - // pre-allocate spaces for ret using the dense dimension size - if (ret->storage_type() == kRowSparseStorage) { - ret->CheckAndAlloc({mshadow::Shape1(lhs.shape()[1])}); - } - const TBlob data_out = ret->data(); - const TBlob row_idx_out = ret->aux_data(rowsparse::kIdx); - - MSHADOW_TYPE_SWITCH(data_l.type_flag_, DType, { // data type - MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type - MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, { // col idx type - MSHADOW_IDX_TYPE_SWITCH(row_idx_r.type_flag_, RType, { // col idx type - if (std::is_same::value) { // cpu parallelization by row blocks - if (kWriteTo == req) { - mxnet_op::Kernel::Launch( - s, data_out.Size(), data_out.dptr()); - } - int num_threads = mxnet_op::get_num_threads(data_out.shape_[0]); - size_t seg_len = (data_out.shape_[0] + num_threads - 1) / num_threads; - if (trans_lhs) { - RType* row_idx = row_idx_out.dptr(); - mxnet_op::Kernel::Launch( - s, row_idx_out.Size(), row_idx); - mxnet_op::Kernel::Launch(s, num_threads, - data_out.dptr(), row_idx, data_l.dptr(), - indptr_l.dptr(), col_idx_l.dptr(), data_r.dptr(), - row_idx_r.dptr(), lhs.shape()[0], rhs.storage_shape()[0], - ret->shape()[0], ret->shape()[1], seg_len); - index_t nnr = 0; - nnr = mxnet::common::ParallelAccumulate(row_idx, ret->shape()[0], nnr); - ret->set_aux_shape(rowsparse::kIdx, mshadow::Shape1(nnr)); - ret->set_storage_shape(mshadow::Shape2(nnr, ret->shape()[1])); - if (0 == nnr) return; - mshadow::Tensor rsp_data = data_out.FlatTo2D(s); - size_t idx = 0; - for (index_t i = 0; i < ret->shape()[0]; ++i) { - if (row_idx[i] > 0) { - row_idx[idx] = i; - mshadow::Copy(rsp_data[idx], rsp_data[i], s); - ++idx; - } - } - } else { - LOG(FATAL) << "DotCsrRspRspImpl has not implemented dot(csr.T, rsp) = rsp2 yet"; - } - } else { - LOG(FATAL) << "DotCsrRspRspImpl has not implemented GPU version yet"; - } - }); - }); - }); - }); -} - -inline bool DotShape(const nnvm::NodeAttrs& attrs, - std::vector *in_attrs, - std::vector *out_attrs) { - const DotParam& param = nnvm::get(attrs.parsed); - CHECK_EQ(in_attrs->size(), 2U); - CHECK_EQ(out_attrs->size(), 1U); - TShape& lshape = (*in_attrs)[0]; - TShape& rshape = (*in_attrs)[1]; - if (lshape.ndim() == 1 && rshape.ndim() == 1) { - CHECK(!param.transpose_a && !param.transpose_b) << "Cannot transpose vectors"; - CHECK_EQ(lshape[0], rshape[0]) << "dot shape error: " << lshape << " X " << rshape; - SHAPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::Shape1(1)); - } else { - bool Ta = param.transpose_a, Tb = param.transpose_b; - TShape L[2], R[2]; - if (Ta) { - L[0] = 
mshadow::Shape1(lshape[0]); - L[1] = lshape.ndim() > 1 ? TShape(&lshape[1], &lshape[lshape.ndim()]) : TShape(1); - } else { - L[0] = lshape.ndim() > 1 ? TShape(&lshape[0], &lshape[lshape.ndim()-1]) : TShape(1); - L[1] = mshadow::Shape1(lshape[lshape.ndim()-1]); - } - if (Tb) { - R[0] = rshape.ndim() > 1 ? TShape(&rshape[0], &rshape[rshape.ndim()-1]) : TShape(1); - R[1] = mshadow::Shape1(rshape[rshape.ndim()-1]); - } else { - R[0] = mshadow::Shape1(rshape[0]); - R[1] = rshape.ndim() > 1 ? TShape(&rshape[1], &rshape[rshape.ndim()]) : TShape(1); - } - - if (L[!Ta].Size() != 0 && R[Tb].Size() != 0) { - CHECK_EQ(L[!Ta].Size(), R[Tb].Size()) - << "dot shape error: " << lshape << " X " << rshape; - } - std::vector buf; - if (lshape.ndim() > 1) buf.insert(buf.end(), &L[Ta][0], &L[Ta][L[Ta].ndim()]); - if (rshape.ndim() > 1) buf.insert(buf.end(), &R[!Tb][0], &R[!Tb][R[!Tb].ndim()]); - TShape oshape(buf.begin(), buf.end()); - SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape); - } - return true; -} - -template -void DotForwardEx(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - CHECK_EQ(inputs.size(), 2U); - CHECK_EQ(outputs.size(), 1U); - CHECK_EQ(req.size(), 1U); - const DotParam& param = nnvm::get(attrs.parsed); - CHECK(!param.transpose_b) << "tranposing rhs of the op dot is not supported"; - auto lhs_stype = inputs[0].storage_type(); - auto rhs_stype = inputs[1].storage_type(); - auto out_stype = outputs[0].storage_type(); - if (lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage && out_stype == kDefaultStorage) { - TBlob ret = outputs[0].data(); - DotCsrDnsDnsImpl(ctx, inputs[0], inputs[1].data(), req[0], param.transpose_a, &ret); - } else if (lhs_stype == kCSRStorage && rhs_stype == kRowSparseStorage - && out_stype == kDefaultStorage) { - TBlob ret = outputs[0].data(); - DotCsrRspDnsImpl(ctx, inputs[0], inputs[1], req[0], param.transpose_a, &ret); - } else if (lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage - && out_stype == kRowSparseStorage) { - NDArray out = outputs[0]; - DotCsrDnsRspImpl(ctx, inputs[0], inputs[1].data(), req[0], param.transpose_a, &out); - } else if (lhs_stype == kCSRStorage && rhs_stype == kRowSparseStorage - && out_stype == kRowSparseStorage) { - NDArray ret = outputs[0]; - DotCsrRspRspImpl(ctx, inputs[0], inputs[1], req[0], param.transpose_a, &ret); - } else { - FCompExFallback(attrs, ctx, inputs, req, outputs, DotForward_, "DotForward_"); - } -} - -template -void DotBackwardEx(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - CHECK_EQ(inputs.size(), 3U); - CHECK_EQ(outputs.size(), 2U); - CHECK_EQ(req.size(), 2U); - CHECK_EQ(kNullOp, req[0]) - << "sparse dot does not support computing the gradient of the csr/lhs"; - CHECK_NE(req[1], kWriteInplace) << "DotBackwardEx does not support WriteInplace"; - - const DotParam& param = nnvm::get(attrs.parsed); - CHECK(!param.transpose_b) << "sparse dot only supports dot(A, X) and dot(A.T(), X)"; - const auto ograd_stype = inputs[0].storage_type(); - const auto lhs_stype = inputs[1].storage_type(); - const auto rhs_stype = inputs[2].storage_type(); - const auto grad_rhs_stype = outputs[1].storage_type(); - - if (ograd_stype == kDefaultStorage // ograd dns format - && lhs_stype == kCSRStorage // csr input lhs of the op - && grad_rhs_stype == kDefaultStorage) { // grad(rhs) dns format - TBlob ret = outputs[1].data(); - DotCsrDnsDnsImpl(ctx, 
inputs[1], inputs[0].data(), req[1], !param.transpose_a, &ret); - } else if (ograd_stype == kDefaultStorage - && lhs_stype == kCSRStorage - && grad_rhs_stype == kRowSparseStorage) { - NDArray ret = outputs[1]; - DotCsrDnsRspImpl(ctx, inputs[1], inputs[0].data(), req[1], !param.transpose_a, &ret); - } else { - FCompExFallback(attrs, ctx, inputs, req, outputs, DotBackward_, "DotBackward_"); - } -} - -template -void BatchDotForward_(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - using namespace mshadow::expr; - mshadow::Stream *s = ctx.get_stream(); - const DotParam& param = nnvm::get(attrs.parsed); - CHECK_EQ(outputs[0].type_flag_, inputs[0].type_flag_) - << "Binary function only support input/output with the same type"; - CHECK_EQ(outputs[0].type_flag_, inputs[1].type_flag_) - << "Binary function only support input/output with the same type"; - CHECK_EQ(outputs[0].type_flag_, mshadow::kFloat32) - << "dot only support 32 bit float so far"; - - mshadow::Tensor out = outputs[0].get(s); - mshadow::Tensor mlhs = inputs[0].get(s); - mshadow::Tensor mrhs = inputs[1].get(s); - mshadow::Tensor workspace = - ctx.requested[0].get_space_typed(mshadow::Shape1(3 * out.size(0)), s); - if (kNullOp != req[0]) { - if (param.transpose_a && param.transpose_b) { - mshadow::BatchGEMM(out, mlhs, mrhs, 1.0f, - (kAddTo == req[0]) ? 1.0f : 0.0f, - workspace); - } else if (!param.transpose_a && param.transpose_b) { - mshadow::BatchGEMM(out, mlhs, mrhs, 1.0f, - (kAddTo == req[0]) ? 1.0f : 0.0f, - workspace); - } else if (param.transpose_a && !param.transpose_b) { - mshadow::BatchGEMM(out, mlhs, mrhs, 1.0f, - (kAddTo == req[0]) ? 1.0f : 0.0f, - workspace); - } else { - mshadow::BatchGEMM(out, mlhs, mrhs, 1.0f, - (kAddTo == req[0]) ? 1.0f : 0.0f, - workspace); - } - } -} - -template -void BatchDotBackward_(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - using namespace mshadow::expr; - mshadow::Stream *s = ctx.get_stream(); - const DotParam& param = nnvm::get(attrs.parsed); - CHECK_NE(req[1], kWriteInplace); - CHECK_NE(req[0], kWriteInplace); - - mshadow::Tensor mout_grad = inputs[0].get(s); - mshadow::Tensor mlhs_data = inputs[1].get(s); - mshadow::Tensor mrhs_data = inputs[2].get(s); - mshadow::Tensor mlhs_grad = outputs[0].get(s); - mshadow::Tensor mrhs_grad = outputs[1].get(s); - mshadow::Tensor workspace = - ctx.requested[0].get_space_typed( - mshadow::Shape2(2, 3 * mout_grad.size(0)), s); - mshadow::Tensor rhs_workspace = workspace[0]; - mshadow::Tensor lhs_workspace = workspace[1]; - if (param.transpose_a && param.transpose_b) { - // Gradient of z = dot(x.T, y.T) - // dy = dot(x, dz).T = dot(dz.T, x.T) - // dx = dot(dz, y).T = dot(y.T, dz.T) - if (kNullOp != req[1]) { - mshadow::BatchGEMM(mrhs_grad, mout_grad, mlhs_data, 1.0f, - (kAddTo == req[1]) ? 1.0f : 0.0f, - rhs_workspace); - } - if (kNullOp != req[0]) { - mshadow::BatchGEMM(mlhs_grad, mrhs_data, mout_grad, 1.0f, - (kAddTo == req[0]) ? 1.0f : 0.0f, - lhs_workspace); - } - } else if (!param.transpose_a && param.transpose_b) { - // Gradient of z = dot(x, y.T) - // dy = dot(x.T, dz).T = dot(dz.T, x) - // dx = dot(dz, y) - if (kNullOp != req[1]) { - mshadow::BatchGEMM(mrhs_grad, mout_grad, mlhs_data, 1.0f, - (kAddTo == req[1]) ? 1.0f : 0.0f, - rhs_workspace); - } - if (kNullOp != req[0]) { - mshadow::BatchGEMM(mlhs_grad, mout_grad, mrhs_data, 1.0f, - (kAddTo == req[0]) ? 
1.0f : 0.0f, - lhs_workspace); - } - } else if (param.transpose_a && !param.transpose_b) { - // Gradient of z = dot(x.T, y) - // dy = dot(x, dz) - // dx = dot(dz, y.T).T = dot(y, dz.T) - if (kNullOp != req[1]) { - mshadow::BatchGEMM(mrhs_grad, mlhs_data, mout_grad, 1.0f, - (kAddTo == req[1]) ? 1.0f : 0.0f, - rhs_workspace); - } - if (kNullOp != req[0]) { - mshadow::BatchGEMM(mlhs_grad, mrhs_data, mout_grad, 1.0f, - (kAddTo == req[0]) ? 1.0f : 0.0f, - lhs_workspace); - } - } else { - // Gradient of z = dot(x, y) - // dy = dot(x.T, dz) - // dx = dot(dz, y.T) - if (kNullOp != req[1]) { - mshadow::BatchGEMM(mrhs_grad, mlhs_data, mout_grad, 1.0f, - (kAddTo == req[1]) ? 1.0f : 0.0f, - rhs_workspace); - } - if (kNullOp != req[0]) { - mshadow::BatchGEMM(mlhs_grad, mout_grad, mrhs_data, 1.0f, - (kAddTo == req[0]) ? 1.0f : 0.0f, - lhs_workspace); - } - } -} - -inline bool BatchDotShape(const nnvm::NodeAttrs& attrs, - std::vector *in_attrs, - std::vector *out_attrs) { - CHECK_EQ(in_attrs->size(), 2U); - CHECK_EQ(out_attrs->size(), 1U); - const DotParam& param = nnvm::get(attrs.parsed); - TShape& lshape = (*in_attrs)[0]; - TShape& rshape = (*in_attrs)[1]; - if (lshape.ndim() == 3 && rshape.ndim() == 3) { - CHECK(lshape[0] == rshape[0]) - << "batch_dot shape error(batch_size must be equal): " << lshape << " X " << rshape - << " trans_a=" << param.transpose_a << " trans_b=" << param.transpose_b; - index_t out_m = param.transpose_a ? lshape[2] : lshape[1]; - index_t lshape_k = param.transpose_a ? lshape[1] : lshape[2]; - index_t out_n = param.transpose_b ? rshape[1] : rshape[2]; - index_t rshape_k = param.transpose_b ? rshape[2] : rshape[1]; - CHECK(lshape_k == rshape_k) - << "batch_dot shape error(shape mismatch): " << lshape << " X " << rshape - << " trans_a=" << param.transpose_a << " trans_b=" << param.transpose_b; - SHAPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::Shape3(lshape[0], out_m, out_n)); - } else { - LOG(FATAL) << "batch_dot currently only support 3D*3D array" - << lshape << " v.s. " << rshape; - } - return true; -} - struct SliceParam : public dmlc::Parameter { nnvm::Tuple > begin, end; DMLC_DECLARE_PARAMETER(SliceParam) { diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc index 72d8aadbe90a..e6ab9798bef6 100644 --- a/src/operator/tensor/matrix_op.cc +++ b/src/operator/tensor/matrix_op.cc @@ -16,7 +16,6 @@ DMLC_REGISTER_PARAMETER(ClipParam); DMLC_REGISTER_PARAMETER(SimpleCropAssignScalarParam); DMLC_REGISTER_PARAMETER(SliceParam); DMLC_REGISTER_PARAMETER(SliceAxisParam); -DMLC_REGISTER_PARAMETER(DotParam); DMLC_REGISTER_PARAMETER(RepeatParam); DMLC_REGISTER_PARAMETER(TileParam); DMLC_REGISTER_PARAMETER(ReverseParam); @@ -344,106 +343,6 @@ NNVM_REGISTER_OP(_backward_slice_axis) .set_attr("TIsBackward", true) .set_attr("FCompute", SliceAxisGrad_); -NNVM_REGISTER_OP(dot) -.describe(R"doc(Dot product of two arrays. - -``dot``'s behavior depends on the input array dimensions: - -- 1-D arrays: inner product of vectors -- 2-D arrays: matrix multiplication -- N-D arrays: a sum product over the last axis of the first input and the first - axis of the second input - - For example, given 3-D ``x`` with shape `(n,m,k)` and ``y`` with shape `(k,r,s)`, the - result array will have shape `(n,m,r,s)`. 
It is computed by:: - - dot(x,y)[i,j,a,b] = sum(x[i,j,:]*y[:,a,b]) - - Example:: - - x = reshape([0,1,2,3,4,5,6,7], shape=(2,2,2)) - y = reshape([7,6,5,4,3,2,1,0], shape=(2,2,2)) - dot(x,y)[0,0,1,1] = 0 - sum(x[0,0,:]*y[:,1,1]) = 0 -)doc" ADD_FILELINE) -.set_num_inputs(2) -.set_num_outputs(1) -.set_attr_parser(ParamParser) -.set_attr("FListInputNames", - [](const NodeAttrs& attrs) { - return std::vector{"lhs", "rhs"}; - }) -.set_attr("FInferShape", DotShape) -.set_attr("FInferType", ElemwiseType<2, 1>) -.set_attr("FInferStorageType", DotForwardInferStorageType) -.set_attr("FResourceRequest", - [](const NodeAttrs& attrs) { - return std::vector{ResourceRequest::kTempSpace}; - }) -.set_attr("FCompute", DotForward_) -.set_attr("FComputeEx", DotForwardEx) -.set_attr("FGradient", ElemwiseGradUseIn{"_backward_dot"}) -.add_argument("lhs", "NDArray-or-Symbol", "The first input") -.add_argument("rhs", "NDArray-or-Symbol", "The second input") -.add_arguments(DotParam::__FIELDS__()); - -NNVM_REGISTER_OP(_backward_dot) -.set_num_inputs(3) -.set_num_outputs(2) -.set_attr_parser(ParamParser) -.set_attr("TIsBackward", true) -.set_attr("FInferStorageType", DotBackwardInferStorageType) -.set_attr("FResourceRequest", - [](const NodeAttrs& attrs) { - return std::vector{ResourceRequest::kTempSpace}; - }) -.set_attr("FCompute", DotBackward_) -.set_attr("FComputeEx", DotBackwardEx) -.add_arguments(DotParam::__FIELDS__()); - -NNVM_REGISTER_OP(batch_dot) -.describe(R"doc(Batchwise dot product. - -``batch_dot`` is used to compute dot product of ``x`` and ``y`` when ``x`` and -``y`` are data in batch, namely 3D arrays in shape of `(batch_size, :, :)`. - -For example, given ``x`` with shape `(batch_size, n, m)` and ``y`` with shape -`(batch_size, m, k)`, the result array will have shape `(batch_size, n, k)`, -which is computed by:: - - batch_dot(x,y)[i,:,:] = dot(x[i,:,:], y[i,:,:]) - -)doc" ADD_FILELINE) -.set_num_inputs(2) -.set_num_outputs(1) -.set_attr_parser(ParamParser) -.set_attr("FListInputNames", - [](const NodeAttrs& attrs) { - return std::vector{"lhs", "rhs"}; - }) -.set_attr("FInferShape", BatchDotShape) -.set_attr("FInferType", ElemwiseType<2, 1>) -.set_attr("FResourceRequest", - [](const NodeAttrs& attrs) { - return std::vector{ResourceRequest::kTempSpace}; - }) -.set_attr("FCompute", BatchDotForward_) -.set_attr("FGradient", ElemwiseGradUseIn{"_backward_batch_dot"}) -.add_argument("lhs", "NDArray-or-Symbol", "The first input") -.add_argument("rhs", "NDArray-or-Symbol", "The second input") -.add_arguments(DotParam::__FIELDS__()); - -NNVM_REGISTER_OP(_backward_batch_dot) -.set_num_inputs(3) -.set_num_outputs(2) -.set_attr_parser(ParamParser) -.set_attr("FResourceRequest", - [](const NodeAttrs& attrs) { - return std::vector{ResourceRequest::kTempSpace}; - }) -.set_attr("TIsBackward", true) -.set_attr("FCompute", BatchDotBackward_); - NNVM_REGISTER_OP(clip) .describe(R"code(Clips (limits) the values in an array. 
diff --git a/src/operator/tensor/matrix_op.cu b/src/operator/tensor/matrix_op.cu index 2e1effb9e560..91a6757b962c 100644 --- a/src/operator/tensor/matrix_op.cu +++ b/src/operator/tensor/matrix_op.cu @@ -39,21 +39,6 @@ NNVM_REGISTER_OP(slice_axis) NNVM_REGISTER_OP(_backward_slice_axis) .set_attr("FCompute", SliceAxisGrad_); -NNVM_REGISTER_OP(dot) -.set_attr("FCompute", DotForward_) -.set_attr("FComputeEx", DotForwardEx); - -NNVM_REGISTER_OP(_backward_dot) -.set_attr("FCompute", DotBackward_) -.set_attr("FComputeEx", DotBackwardEx); - - -NNVM_REGISTER_OP(batch_dot) -.set_attr("FCompute", BatchDotForward_); - -NNVM_REGISTER_OP(_backward_batch_dot) -.set_attr("FCompute", BatchDotBackward_); - NNVM_REGISTER_OP(clip) .set_attr("FCompute", Clip); From 4d43935bf7e14acac25b97fdd5dff5954f697867 Mon Sep 17 00:00:00 2001 From: reminisce Date: Fri, 30 Jun 2017 22:52:40 -0700 Subject: [PATCH 7/9] Move dot gpu impl to a .cuh file --- src/operator/nn/matrix_dot-inl.cuh | 161 +++++++++++++ src/operator/nn/matrix_dot-inl.h | 347 ++++++++++------------------- src/operator/tensor/indexing_op.h | 4 +- 3 files changed, 280 insertions(+), 232 deletions(-) create mode 100644 src/operator/nn/matrix_dot-inl.cuh diff --git a/src/operator/nn/matrix_dot-inl.cuh b/src/operator/nn/matrix_dot-inl.cuh new file mode 100644 index 000000000000..b1468efc3e98 --- /dev/null +++ b/src/operator/nn/matrix_dot-inl.cuh @@ -0,0 +1,161 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file matrix_dot-inl.cuh + * \brief implementation of matrix dot op on GPU + */ +#ifndef MXNET_OPERATOR_NN_MATRIX_DOT_INL_CUH_ +#define MXNET_OPERATOR_NN_MATRIX_DOT_INL_CUH_ + +#include +#include + +namespace mxnet { +namespace op { + +/*! + * \brief Kernel of dot(csr, dns1) = dns2 + * Parallelization by output matrix elements + */ +template +struct DotCsrDnsDns { + /*! + * \brief This function represents performing an inner product between a row of lhs + * and a column of rhs and then assigning the value to out[i]. + * \param i i-th element in out 1D view + * \param out output matrix + * \param data_l csr values of lhs + * \param indptr_l csr indptr of lhs + * \param col_idx_l csr col_idx of lhs + * \param data_r dense data of rhs + * \param num_cols number of columns of output + */ + template + MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data_l, const IType* indptr_l, + const CType* col_idx_l, const DType* data_r, + const int num_cols) { + const int irow = i / num_cols; // row id of the lhs + const int icol = i % num_cols; // col id of the rhs + DType sum = 0; + for (IType j = indptr_l[irow]; j < indptr_l[irow+1]; ++j) { + const CType cur_col = col_idx_l[j]; // corresponding row id of the rhs + sum += data_l[j] * data_r[cur_col*num_cols+icol]; + } + KERNEL_ASSIGN(out[i], req, sum); + } +}; + +/*! + * \brief Kernel of dot(csr.T(), dns1) = dns2 + * Parallelization by output matrix elements + */ +template +struct DotCsrTransDnsDns { + /*! + * \brief This function represents performing an inner product between a column of lhs + * and a column of rhs and then assigning the value to out[i]. 
+ * \param i i-th element in out 1D view + * \param out output matrix + * \param data_l csr values of lhs + * \param indptr_l csr indptr of lhs + * \param col_idx_l csr col_idx of lhs + * \param data_r dense data of rhs + * \param num_rows_l number of rows of lhs + * \param num_cols number of columns of outputs + */ + template + MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data_l, const IType* indptr_l, + const CType* col_idx_l, const DType* data_r, const int num_rows_l, + const int num_cols) { + const int irow = i / num_cols; // col id of the lhs + const int icol = i % num_cols; // col id of the rhs + DType sum = 0; + for (int k = 0; k < num_rows_l; ++k) { + const IType low = indptr_l[k]; + const IType high = indptr_l[k+1]; + if (low == high || irow < col_idx_l[low] || irow > col_idx_l[high-1]) continue; + int j = -1, l = low, r = high - 1; + while (l <= r) { + int m = l + (r - l) / 2; + if (col_idx_l[m] == irow) { + j = m; break; + } + if (col_idx_l[m] < irow) { + l = m + 1; + } else { + r = m - 1; + } + } + if (j >= 0) { + sum += data_l[j] * data_r[k*num_cols+icol]; + } + } + KERNEL_ASSIGN(out[i], req, sum); + } +}; + +inline void DotCsrDnsDnsImpl(mshadow::Stream* s, + const NDArray& lhs, + const TBlob& rhs, + const OpReqType req, + const bool trans_lhs, + TBlob* ret) { + if (kNullOp == req) return; + CHECK_EQ(lhs.storage_type(), kCSRStorage); + if (!lhs.storage_initialized()) return; + + const TBlob data_l = lhs.data(); + const TBlob indptr_l = lhs.aux_data(csr::kIndPtr); + const TBlob col_idx_l = lhs.aux_data(csr::kIdx); + const TBlob& data_r = rhs; + const TBlob data_out = *ret; + + MSHADOW_TYPE_SWITCH(data_l.type_flag_, DType, { // data type + MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type + MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, { // col idx type + if (trans_lhs) { + MXNET_ASSIGN_REQ_SWITCH(req, ReqType, { + mxnet_op::Kernel, gpu>::Launch(s, data_out.Size(), + data_out.dptr(), data_l.dptr(), indptr_l.dptr(), + col_idx_l.dptr(), data_r.dptr(), lhs.shape()[0], + data_out.shape_[1]); + }); + } else { + MXNET_ASSIGN_REQ_SWITCH(req, ReqType, { + mxnet_op::Kernel, gpu>::Launch(s, data_out.Size(), + data_out.dptr(), data_l.dptr(), indptr_l.dptr(), + col_idx_l.dptr(), data_r.dptr(), rhs.shape_[1]); + }); + } + }); + }); + }); +} + +/*! + * \brief Impl of dot(csr.T, dns) = rsp + */ +inline void DotCsrDnsRspImpl(mshadow::Stream* s, + const NDArray& lhs, + const TBlob& rhs, + const OpReqType req, + const bool trans_lhs, + NDArray* ret) { + LOG(FATAL) << "DotCsrDnsRspImpl gpu version is not implemented."; +} + +/*! + * \brief Impl of dot(csr.T, rsp) = rsp2 + */ +inline void DotCsrRspRspImpl(mshadow::Stream* s, + const NDArray& lhs, + const NDArray& rhs, + const OpReqType req, + const bool trans_lhs, + NDArray* ret) { + LOG(FATAL) << "DotCsrRspRspImpl gpu version is not implemented."; +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_NN_MATRIX_DOT_INL_CUH_ diff --git a/src/operator/nn/matrix_dot-inl.h b/src/operator/nn/matrix_dot-inl.h index c32453faf127..2025049edbdd 100644 --- a/src/operator/nn/matrix_dot-inl.h +++ b/src/operator/nn/matrix_dot-inl.h @@ -15,6 +15,9 @@ #include "../mshadow_op.h" #include "../elemwise_op_common.h" #include "../mxnet_op.h" +#ifdef __CUDACC__ +#include "./matrix_dot-inl.cuh" +#endif // __CUDACC__ namespace mxnet { namespace op { @@ -208,87 +211,6 @@ inline bool DotBackwardInferStorageType(const nnvm::NodeAttrs& attrs, return true; } -/*! 
- * \brief Kernel of dot(csr, dns1) = dns2 - * Parallelization by output matrix elements - */ -template -struct DotCsrDnsDns { - /*! - * \brief This function represents performing an inner product between a row of lhs - * and a column of rhs and then assigning the value to out[i]. - * \param i i-th element in out 1D view - * \param out output matrix - * \param data_l csr values of lhs - * \param indptr_l csr indptr of lhs - * \param col_idx_l csr col_idx of lhs - * \param data_r dense data of rhs - * \param num_cols number of columns of output - */ - template - MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data_l, const IType* indptr_l, - const CType* col_idx_l, const DType* data_r, - const int num_cols) { - const int irow = i / num_cols; // row id of the lhs - const int icol = i % num_cols; // col id of the rhs - DType sum = 0; - for (IType j = indptr_l[irow]; j < indptr_l[irow+1]; ++j) { - const CType cur_col = col_idx_l[j]; // corresponding row id of the rhs - sum += data_l[j] * data_r[cur_col*num_cols+icol]; - } - KERNEL_ASSIGN(out[i], req, sum); - } -}; - -/*! - * \brief Kernel of dot(csr.T(), dns1) = dns2 - * Parallelization by output matrix elements - */ -template -struct DotCsrTransDnsDns { - /*! - * \brief This function represents performing an inner product between a column of lhs - * and a column of rhs and then assigning the value to out[i]. - * \param i i-th element in out 1D view - * \param out output matrix - * \param data_l csr values of lhs - * \param indptr_l csr indptr of lhs - * \param col_idx_l csr col_idx of lhs - * \param data_r dense data of rhs - * \param num_rows_l number of rows of lhs - * \param num_cols number of columns of outputs - */ - template - MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data_l, const IType* indptr_l, - const CType* col_idx_l, const DType* data_r, const int num_rows_l, - const int num_cols) { - const int irow = i / num_cols; // col id of the lhs - const int icol = i % num_cols; // col id of the rhs - DType sum = 0; - for (int k = 0; k < num_rows_l; ++k) { - const IType low = indptr_l[k]; - const IType high = indptr_l[k+1]; - if (low == high || irow < col_idx_l[low] || irow > col_idx_l[high-1]) continue; - int j = -1, l = low, r = high - 1; - while (l <= r) { - int m = l + (r - l) / 2; - if (col_idx_l[m] == irow) { - j = m; break; - } - if (col_idx_l[m] < irow) { - l = m + 1; - } else { - r = m - 1; - } - } - if (j >= 0) { - sum += data_l[j] * data_r[k*num_cols+icol]; - } - } - KERNEL_ASSIGN(out[i], req, sum); - } -}; - /*! 
* \brief Kernel of dot(csr, dns1) = dns2 * Parallelization by row blocks @@ -493,18 +415,16 @@ struct DotCsrTransRspRspByRowBlocks { } }; -template -void DotCsrDnsDnsImpl(const OpContext& ctx, - const NDArray& lhs, - const TBlob& rhs, - const OpReqType req, - const bool trans_lhs, - TBlob* ret) { +inline void DotCsrDnsDnsImpl(mshadow::Stream* s, + const NDArray& lhs, + const TBlob& rhs, + const OpReqType req, + const bool trans_lhs, + TBlob* ret) { if (kNullOp == req) return; CHECK_EQ(lhs.storage_type(), kCSRStorage); if (!lhs.storage_initialized()) return; - mshadow::Stream *s = ctx.get_stream(); const TBlob data_l = lhs.data(); const TBlob indptr_l = lhs.aux_data(csr::kIndPtr); const TBlob col_idx_l = lhs.aux_data(csr::kIdx); @@ -514,39 +434,22 @@ void DotCsrDnsDnsImpl(const OpContext& ctx, MSHADOW_TYPE_SWITCH(data_l.type_flag_, DType, { // data type MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, { // col idx type - if (std::is_same::value) { // cpu parallelization by row blocks - if (kWriteTo == req) { - mxnet_op::Kernel::Launch( - s, data_out.Size(), data_out.dptr()); - } - int num_threads = mxnet_op::get_num_threads(data_out.shape_[0]); - size_t seg_len = (data_out.shape_[0] + num_threads - 1) / num_threads; - if (trans_lhs) { - mxnet_op::Kernel::Launch(s, num_threads, - data_out.dptr(), data_l.dptr(), indptr_l.dptr(), - col_idx_l.dptr(), data_r.dptr(), seg_len, - lhs.shape()[0], data_out.shape_[0], data_out.shape_[1]); - } else { - mxnet_op::Kernel::Launch(s, num_threads, - data_out.dptr(), data_l.dptr(), indptr_l.dptr(), - col_idx_l.dptr(), data_r.dptr(), seg_len, - data_out.shape_[0], data_out.shape_[1]); - } - } else { // gpu parallelization by output elements - if (trans_lhs) { - MXNET_ASSIGN_REQ_SWITCH(req, ReqType, { - mxnet_op::Kernel, xpu>::Launch(s, data_out.Size(), - data_out.dptr(), data_l.dptr(), indptr_l.dptr(), - col_idx_l.dptr(), data_r.dptr(), lhs.shape()[0], - data_out.shape_[1]); - }); - } else { - MXNET_ASSIGN_REQ_SWITCH(req, ReqType, { - mxnet_op::Kernel, xpu>::Launch(s, data_out.Size(), - data_out.dptr(), data_l.dptr(), indptr_l.dptr(), - col_idx_l.dptr(), data_r.dptr(), rhs.shape_[1]); - }); - } + if (kWriteTo == req) { + mxnet_op::Kernel::Launch( + s, data_out.Size(), data_out.dptr()); + } + int num_threads = mxnet_op::get_num_threads(data_out.shape_[0]); + size_t seg_len = (data_out.shape_[0] + num_threads - 1) / num_threads; + if (trans_lhs) { + mxnet_op::Kernel::Launch(s, num_threads, + data_out.dptr(), data_l.dptr(), indptr_l.dptr(), + col_idx_l.dptr(), data_r.dptr(), seg_len, + lhs.shape()[0], data_out.shape_[0], data_out.shape_[1]); + } else { + mxnet_op::Kernel::Launch(s, num_threads, + data_out.dptr(), data_l.dptr(), indptr_l.dptr(), + col_idx_l.dptr(), data_r.dptr(), seg_len, + data_out.shape_[0], data_out.shape_[1]); } }); }); @@ -556,19 +459,17 @@ void DotCsrDnsDnsImpl(const OpContext& ctx, /*! 
* \brief Impl of dot(csr, rsp) */ -template -void DotCsrDnsRspImpl(const OpContext& ctx, - const NDArray& lhs, - const TBlob& rhs, - const OpReqType req, - const bool trans_lhs, - NDArray* ret) { +inline void DotCsrDnsRspImpl(mshadow::Stream* s, + const NDArray& lhs, + const TBlob& rhs, + const OpReqType req, + const bool trans_lhs, + NDArray* ret) { if (kNullOp == req) return; CHECK_EQ(lhs.storage_type(), kCSRStorage); CHECK_EQ(ret->storage_type(), kRowSparseStorage); if (!lhs.storage_initialized()) return; - mshadow::Stream *s = ctx.get_stream(); const TBlob data_l = lhs.data(); const TBlob indptr_l = lhs.aux_data(csr::kIndPtr); const TBlob col_idx_l = lhs.aux_data(csr::kIdx); @@ -583,41 +484,37 @@ void DotCsrDnsRspImpl(const OpContext& ctx, MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, { // col idx type MSHADOW_IDX_TYPE_SWITCH(row_idx_out.type_flag_, RType, { // col idx type - if (std::is_same::value) { // cpu parallelization by row blocks - if (kWriteTo == req) { - mxnet_op::Kernel::Launch( - s, data_out.Size(), data_out.dptr()); - } - RType* row_idx = row_idx_out.dptr(); - mxnet_op::Kernel::Launch( - s, row_idx_out.Size(), row_idx); - int num_threads = mxnet_op::get_num_threads(data_out.shape_[0]); - size_t seg_len = (data_out.shape_[0] + num_threads - 1) / num_threads; - if (trans_lhs) { - mxnet_op::Kernel::Launch(s, num_threads, - data_out.dptr(), row_idx, data_l.dptr(), - indptr_l.dptr(), col_idx_l.dptr(), data_r.dptr(), - seg_len, lhs.shape()[0], data_out.shape_[0], data_out.shape_[1]); - index_t nnr = 0; - nnr = mxnet::common::ParallelAccumulate(row_idx, ret->shape()[0], nnr); - ret->set_aux_shape(rowsparse::kIdx, mshadow::Shape1(nnr)); - ret->set_storage_shape(mshadow::Shape2(nnr, ret->shape()[1])); - if (0 == nnr) return; - mshadow::Tensor rsp_data = data_out.FlatTo2D(s); - size_t idx = 0; - for (index_t i = 0; i < ret->shape()[0]; ++i) { - if (row_idx[i] > 0) { - row_idx[idx] = i; - mshadow::Copy(rsp_data[idx], rsp_data[i], s); - ++idx; - } + if (kWriteTo == req) { + mxnet_op::Kernel::Launch( + s, data_out.Size(), data_out.dptr()); + } + RType* row_idx = row_idx_out.dptr(); + mxnet_op::Kernel::Launch( + s, row_idx_out.Size(), row_idx); + int num_threads = mxnet_op::get_num_threads(data_out.shape_[0]); + size_t seg_len = (data_out.shape_[0] + num_threads - 1) / num_threads; + if (trans_lhs) { + mxnet_op::Kernel::Launch(s, num_threads, + data_out.dptr(), row_idx, data_l.dptr(), + indptr_l.dptr(), col_idx_l.dptr(), data_r.dptr(), + seg_len, lhs.shape()[0], data_out.shape_[0], data_out.shape_[1]); + index_t nnr = 0; + nnr = mxnet::common::ParallelAccumulate(row_idx, ret->shape()[0], nnr); + ret->set_aux_shape(rowsparse::kIdx, mshadow::Shape1(nnr)); + ret->set_storage_shape(mshadow::Shape2(nnr, ret->shape()[1])); + if (0 == nnr) return; + mshadow::Tensor rsp_data = data_out.FlatTo2D(s); + size_t idx = 0; + for (index_t i = 0; i < ret->shape()[0]; ++i) { + if (row_idx[i] > 0) { + row_idx[idx] = i; + mshadow::Copy(rsp_data[idx], rsp_data[i], s); + ++idx; } - } else { - LOG(FATAL) << "DotCsrDnsRspImpl has not implemented dot(csr, dns)=rsp yet." - " Only the cpu version of dot(csr.T, dns)=rsp is supported now"; } } else { - LOG(FATAL) << "DotCsrDnsRspImpl has not implemented GPU version yet."; + LOG(FATAL) << "DotCsrDnsRspImpl has not implemented dot(csr, dns)=rsp yet." 
+ " Only the cpu version of dot(csr.T, dns)=rsp is supported now"; } }); }); @@ -626,7 +523,7 @@ void DotCsrDnsRspImpl(const OpContext& ctx, } template -void DotCsrRspDnsImpl(const OpContext& ctx, +void DotCsrRspDnsImpl(mshadow::Stream* s, const NDArray& lhs, const NDArray& rhs, const OpReqType req, @@ -634,14 +531,13 @@ void DotCsrRspDnsImpl(const OpContext& ctx, TBlob* ret) { // reuse csr dns implementation when storage_shape == shape for rhs if (rhs.storage_shape()[0] == rhs.shape()[0]) { // if rsp is actually dense - DotCsrDnsDnsImpl(ctx, lhs, rhs.data(), req, trans_lhs, ret); + DotCsrDnsDnsImpl(s, lhs, rhs.data(), req, trans_lhs, ret); return; } if (kNullOp == req) return; CHECK_EQ(lhs.storage_type(), kCSRStorage); CHECK_EQ(rhs.storage_type(), kRowSparseStorage); - mshadow::Stream *s = ctx.get_stream(); if (!lhs.storage_initialized() || !rhs.storage_initialized()) { if (kWriteTo == req) { MSHADOW_TYPE_SWITCH(ret->type_flag_, DType, { // data type @@ -662,24 +558,20 @@ void DotCsrRspDnsImpl(const OpContext& ctx, MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, { // col idx type MSHADOW_IDX_TYPE_SWITCH(row_idx_r.type_flag_, RType, { // col idx type - if (std::is_same::value) { // cpu parallelization by row blocks - if (kWriteTo == req) { - mxnet_op::Kernel::Launch( - s, ret->Size(), ret->dptr()); - } - int num_threads = mxnet_op::get_num_threads(ret->shape_[0]); - size_t seg_len = (ret->shape_[0] + num_threads - 1) / num_threads; - if (trans_lhs) { - LOG(FATAL) << "DotCsrRspDnsImpl has not implemented dot(csr.T, rsp) = dns yet"; - } else { - mxnet_op::Kernel::Launch(s, num_threads, - ret->dptr(), data_l.dptr(), - indptr_l.dptr(), col_idx_l.dptr(), data_r.dptr(), - row_idx_r.dptr(), rhs.storage_shape()[0], - ret->shape_[0], ret->shape_[1], seg_len); - } + if (kWriteTo == req) { + mxnet_op::Kernel::Launch( + s, ret->Size(), ret->dptr()); + } + int num_threads = mxnet_op::get_num_threads(ret->shape_[0]); + size_t seg_len = (ret->shape_[0] + num_threads - 1) / num_threads; + if (trans_lhs) { + LOG(FATAL) << "DotCsrRspDnsImpl has not implemented dot(csr.T, rsp) = dns yet"; } else { - LOG(FATAL) << "DotCsrRspDnsImpl has not implemented GPU version yet"; + mxnet_op::Kernel::Launch(s, num_threads, + ret->dptr(), data_l.dptr(), + indptr_l.dptr(), col_idx_l.dptr(), data_r.dptr(), + row_idx_r.dptr(), rhs.storage_shape()[0], + ret->shape_[0], ret->shape_[1], seg_len); } }); }); @@ -690,16 +582,15 @@ void DotCsrRspDnsImpl(const OpContext& ctx, /*! 
* \brief Impl of dot(csr.T, rsp) = rsp2 */ -template -void DotCsrRspRspImpl(const OpContext& ctx, - const NDArray& lhs, - const NDArray& rhs, - const OpReqType req, - const bool trans_lhs, - NDArray* ret) { +inline void DotCsrRspRspImpl(mshadow::Stream* s, + const NDArray& lhs, + const NDArray& rhs, + const OpReqType req, + const bool trans_lhs, + NDArray* ret) { // reuse csr dns implementation when storage_shape == shape for rhs if (rhs.storage_shape()[0] == rhs.shape()[0]) { // if rsp is actually dense - DotCsrDnsRspImpl(ctx, lhs, rhs.data(), req, trans_lhs, ret); + DotCsrDnsRspImpl(s, lhs, rhs.data(), req, trans_lhs, ret); return; } @@ -709,7 +600,6 @@ void DotCsrRspRspImpl(const OpContext& ctx, CHECK_EQ(ret->storage_type(), kRowSparseStorage); if (!lhs.storage_initialized() || !rhs.storage_initialized()) return; - mshadow::Stream *s = ctx.get_stream(); const TBlob data_l = lhs.data(); const TBlob indptr_l = lhs.aux_data(csr::kIndPtr); const TBlob col_idx_l = lhs.aux_data(csr::kIdx); @@ -727,41 +617,37 @@ void DotCsrRspRspImpl(const OpContext& ctx, MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, { // col idx type MSHADOW_IDX_TYPE_SWITCH(row_idx_r.type_flag_, RType, { // col idx type - if (std::is_same::value) { // cpu parallelization by row blocks - if (kWriteTo == req) { - mxnet_op::Kernel::Launch( - s, data_out.Size(), data_out.dptr()); - } - int num_threads = mxnet_op::get_num_threads(data_out.shape_[0]); - size_t seg_len = (data_out.shape_[0] + num_threads - 1) / num_threads; - if (trans_lhs) { - RType* row_idx = row_idx_out.dptr(); - mxnet_op::Kernel::Launch( - s, row_idx_out.Size(), row_idx); - mxnet_op::Kernel::Launch(s, num_threads, - data_out.dptr(), row_idx, data_l.dptr(), - indptr_l.dptr(), col_idx_l.dptr(), data_r.dptr(), - row_idx_r.dptr(), lhs.shape()[0], rhs.storage_shape()[0], - ret->shape()[0], ret->shape()[1], seg_len); - index_t nnr = 0; - nnr = mxnet::common::ParallelAccumulate(row_idx, ret->shape()[0], nnr); - ret->set_aux_shape(rowsparse::kIdx, mshadow::Shape1(nnr)); - ret->set_storage_shape(mshadow::Shape2(nnr, ret->shape()[1])); - if (0 == nnr) return; - mshadow::Tensor rsp_data = data_out.FlatTo2D(s); - size_t idx = 0; - for (index_t i = 0; i < ret->shape()[0]; ++i) { - if (row_idx[i] > 0) { - row_idx[idx] = i; - mshadow::Copy(rsp_data[idx], rsp_data[i], s); - ++idx; - } + if (kWriteTo == req) { + mxnet_op::Kernel::Launch( + s, data_out.Size(), data_out.dptr()); + } + int num_threads = mxnet_op::get_num_threads(data_out.shape_[0]); + size_t seg_len = (data_out.shape_[0] + num_threads - 1) / num_threads; + if (trans_lhs) { + RType* row_idx = row_idx_out.dptr(); + mxnet_op::Kernel::Launch( + s, row_idx_out.Size(), row_idx); + mxnet_op::Kernel::Launch(s, num_threads, + data_out.dptr(), row_idx, data_l.dptr(), + indptr_l.dptr(), col_idx_l.dptr(), data_r.dptr(), + row_idx_r.dptr(), lhs.shape()[0], rhs.storage_shape()[0], + ret->shape()[0], ret->shape()[1], seg_len); + index_t nnr = 0; + nnr = mxnet::common::ParallelAccumulate(row_idx, ret->shape()[0], nnr); + ret->set_aux_shape(rowsparse::kIdx, mshadow::Shape1(nnr)); + ret->set_storage_shape(mshadow::Shape2(nnr, ret->shape()[1])); + if (0 == nnr) return; + mshadow::Tensor rsp_data = data_out.FlatTo2D(s); + size_t idx = 0; + for (index_t i = 0; i < ret->shape()[0]; ++i) { + if (row_idx[i] > 0) { + row_idx[idx] = i; + mshadow::Copy(rsp_data[idx], rsp_data[i], s); + ++idx; } - } else { - LOG(FATAL) << "DotCsrRspRspImpl has not implemented 
dot(csr.T, rsp) = rsp2 yet"; } } else { - LOG(FATAL) << "DotCsrRspRspImpl has not implemented GPU version yet"; + LOG(FATAL) << "DotCsrRspRspImpl has not implemented dot(csr.T, rsp) = rsp2 yet"; } }); }); @@ -826,21 +712,22 @@ void DotForwardEx(const nnvm::NodeAttrs& attrs, auto lhs_stype = inputs[0].storage_type(); auto rhs_stype = inputs[1].storage_type(); auto out_stype = outputs[0].storage_type(); + mshadow::Stream* s = ctx.get_stream(); if (lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage && out_stype == kDefaultStorage) { TBlob ret = outputs[0].data(); - DotCsrDnsDnsImpl(ctx, inputs[0], inputs[1].data(), req[0], param.transpose_a, &ret); + DotCsrDnsDnsImpl(s, inputs[0], inputs[1].data(), req[0], param.transpose_a, &ret); } else if (lhs_stype == kCSRStorage && rhs_stype == kRowSparseStorage && out_stype == kDefaultStorage) { TBlob ret = outputs[0].data(); - DotCsrRspDnsImpl(ctx, inputs[0], inputs[1], req[0], param.transpose_a, &ret); + DotCsrRspDnsImpl(s, inputs[0], inputs[1], req[0], param.transpose_a, &ret); } else if (lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage && out_stype == kRowSparseStorage) { NDArray out = outputs[0]; - DotCsrDnsRspImpl(ctx, inputs[0], inputs[1].data(), req[0], param.transpose_a, &out); + DotCsrDnsRspImpl(s, inputs[0], inputs[1].data(), req[0], param.transpose_a, &out); } else if (lhs_stype == kCSRStorage && rhs_stype == kRowSparseStorage && out_stype == kRowSparseStorage) { NDArray ret = outputs[0]; - DotCsrRspRspImpl(ctx, inputs[0], inputs[1], req[0], param.transpose_a, &ret); + DotCsrRspRspImpl(s, inputs[0], inputs[1], req[0], param.transpose_a, &ret); } else { FCompExFallback(attrs, ctx, inputs, req, outputs, DotForward_, "DotForward_"); } @@ -865,17 +752,17 @@ void DotBackwardEx(const nnvm::NodeAttrs& attrs, const auto lhs_stype = inputs[1].storage_type(); const auto rhs_stype = inputs[2].storage_type(); const auto grad_rhs_stype = outputs[1].storage_type(); - + mshadow::Stream* s = ctx.get_stream(); if (ograd_stype == kDefaultStorage // ograd dns format && lhs_stype == kCSRStorage // csr input lhs of the op && grad_rhs_stype == kDefaultStorage) { // grad(rhs) dns format TBlob ret = outputs[1].data(); - DotCsrDnsDnsImpl(ctx, inputs[1], inputs[0].data(), req[1], !param.transpose_a, &ret); + DotCsrDnsDnsImpl(s, inputs[1], inputs[0].data(), req[1], !param.transpose_a, &ret); } else if (ograd_stype == kDefaultStorage && lhs_stype == kCSRStorage && grad_rhs_stype == kRowSparseStorage) { NDArray ret = outputs[1]; - DotCsrDnsRspImpl(ctx, inputs[1], inputs[0].data(), req[1], !param.transpose_a, &ret); + DotCsrDnsRspImpl(s, inputs[1], inputs[0].data(), req[1], !param.transpose_a, &ret); } else { FCompExFallback(attrs, ctx, inputs, req, outputs, DotBackward_, "DotBackward_"); } diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h index 5f6f85e46ca6..5373badfd734 100644 --- a/src/operator/tensor/indexing_op.h +++ b/src/operator/tensor/indexing_op.h @@ -215,7 +215,7 @@ void SparseEmbeddingForwardRspImpl(const nnvm::NodeAttrs& attrs, TBlob out_blob = out->data(); // forward to dns implementation when storage_shape equals shape bool transpose_a = false; - DotCsrRspDnsImpl(ctx, data, weight, req, transpose_a, &out_blob); + DotCsrRspDnsImpl(ctx.get_stream(), data, weight, req, transpose_a, &out_blob); } template @@ -408,7 +408,7 @@ void SparseEmbeddingBackwardEx(const nnvm::NodeAttrs& attrs, if (data_stype == kCSRStorage && grad_stype == kDefaultStorage && output_stype == kDefaultStorage) { TBlob ret = outputs[1].data(); 
- DotCsrDnsDnsImpl(ctx, inputs[1], inputs[0].data(), req[1], true, &ret); + DotCsrDnsDnsImpl(ctx.get_stream(), inputs[1], inputs[0].data(), req[1], true, &ret); } else { LOG(FATAL) << "Not supported dot backward for sparse input(s) with sparse gradients"; } From fd1ae7114e9d32796a4db54df797b7f02107a5b2 Mon Sep 17 00:00:00 2001 From: reminisce Date: Mon, 3 Jul 2017 11:47:00 -0700 Subject: [PATCH 8/9] More refactor --- src/common/utils.cc | 2 +- src/common/utils.cu | 2 +- src/operator/{nn => tensor}/cast_storage-inl.cuh | 6 +++--- src/operator/{nn => tensor}/cast_storage-inl.h | 6 +++--- src/operator/{nn => tensor}/cast_storage.cc | 0 src/operator/{nn => tensor}/cast_storage.cu | 0 .../{nn/matrix_dot-inl.cuh => tensor/dot-inl.cuh} | 8 ++++---- src/operator/{nn/matrix_dot-inl.h => tensor/dot-inl.h} | 8 ++++---- src/operator/{nn/matrix_dot.cc => tensor/dot.cc} | 4 ++-- src/operator/{nn/matrix_dot.cu => tensor/dot.cu} | 4 ++-- src/operator/tensor/indexing_op.h | 2 +- 11 files changed, 21 insertions(+), 21 deletions(-) rename src/operator/{nn => tensor}/cast_storage-inl.cuh (78%) rename src/operator/{nn => tensor}/cast_storage-inl.h (98%) rename src/operator/{nn => tensor}/cast_storage.cc (100%) rename src/operator/{nn => tensor}/cast_storage.cu (100%) rename src/operator/{nn/matrix_dot-inl.cuh => tensor/dot-inl.cuh} (97%) rename src/operator/{nn/matrix_dot-inl.h => tensor/dot-inl.h} (99%) rename src/operator/{nn/matrix_dot.cc => tensor/dot.cc} (98%) rename src/operator/{nn/matrix_dot.cu => tensor/dot.cu} (92%) diff --git a/src/common/utils.cc b/src/common/utils.cc index 5bfb959fdf34..4bcae02e990c 100644 --- a/src/common/utils.cc +++ b/src/common/utils.cc @@ -5,7 +5,7 @@ */ #include "./utils.h" -#include "../operator/nn/cast_storage-inl.h" +#include "../operator/tensor/cast_storage-inl.h" namespace mxnet { namespace common { diff --git a/src/common/utils.cu b/src/common/utils.cu index a249be5bb9f5..7221a2b6ec6c 100644 --- a/src/common/utils.cu +++ b/src/common/utils.cu @@ -5,7 +5,7 @@ */ #include "./utils.h" -#include "../operator/nn/cast_storage-inl.h" +#include "../operator/tensor/cast_storage-inl.h" namespace mxnet { namespace common { diff --git a/src/operator/nn/cast_storage-inl.cuh b/src/operator/tensor/cast_storage-inl.cuh similarity index 78% rename from src/operator/nn/cast_storage-inl.cuh rename to src/operator/tensor/cast_storage-inl.cuh index b99d875eb612..0d4e601d0d2e 100644 --- a/src/operator/nn/cast_storage-inl.cuh +++ b/src/operator/tensor/cast_storage-inl.cuh @@ -3,8 +3,8 @@ * \file cast_storage-inl.cuh * \brief implementation of cast_storage op on GPU */ -#ifndef MXNET_OPERATOR_NN_CAST_STORAGE_INL_CUH_ -#define MXNET_OPERATOR_NN_CAST_STORAGE_INL_CUH_ +#ifndef MXNET_OPERATOR_TENSOR_CAST_STORAGE_INL_CUH_ +#define MXNET_OPERATOR_TENSOR_CAST_STORAGE_INL_CUH_ #include #include @@ -23,4 +23,4 @@ inline void CastStorageDnsCsrImpl(mshadow::Stream* s, const TBlob& dns, NDA } // namespace op } // namespace mxnet -#endif // MXNET_OPERATOR_NN_CAST_STORAGE_INL_CUH_ +#endif // MXNET_OPERATOR_TENSOR_CAST_STORAGE_INL_CUH_ diff --git a/src/operator/nn/cast_storage-inl.h b/src/operator/tensor/cast_storage-inl.h similarity index 98% rename from src/operator/nn/cast_storage-inl.h rename to src/operator/tensor/cast_storage-inl.h index f0268c797c74..9273b996d48e 100644 --- a/src/operator/nn/cast_storage-inl.h +++ b/src/operator/tensor/cast_storage-inl.h @@ -3,8 +3,8 @@ * \file cast_storage-inl.h * \brief cast_storage implementation for dense and sparse tensors */ -#ifndef 
MXNET_OPERATOR_NN_CAST_STORAGE_INL_H_ -#define MXNET_OPERATOR_NN_CAST_STORAGE_INL_H_ +#ifndef MXNET_OPERATOR_TENSOR_CAST_STORAGE_INL_H_ +#define MXNET_OPERATOR_TENSOR_CAST_STORAGE_INL_H_ #include #include @@ -333,4 +333,4 @@ void CastStorageComputeEx(const nnvm::NodeAttrs& attrs, } // namespace op } // namespace mxnet -#endif // MXNET_OPERATOR_NN_CAST_STORAGE_INL_H_ +#endif // MXNET_OPERATOR_TENSOR_CAST_STORAGE_INL_H_ diff --git a/src/operator/nn/cast_storage.cc b/src/operator/tensor/cast_storage.cc similarity index 100% rename from src/operator/nn/cast_storage.cc rename to src/operator/tensor/cast_storage.cc diff --git a/src/operator/nn/cast_storage.cu b/src/operator/tensor/cast_storage.cu similarity index 100% rename from src/operator/nn/cast_storage.cu rename to src/operator/tensor/cast_storage.cu diff --git a/src/operator/nn/matrix_dot-inl.cuh b/src/operator/tensor/dot-inl.cuh similarity index 97% rename from src/operator/nn/matrix_dot-inl.cuh rename to src/operator/tensor/dot-inl.cuh index b1468efc3e98..513fde306bab 100644 --- a/src/operator/nn/matrix_dot-inl.cuh +++ b/src/operator/tensor/dot-inl.cuh @@ -1,10 +1,10 @@ /*! * Copyright (c) 2017 by Contributors - * \file matrix_dot-inl.cuh + * \file dot-inl.cuh * \brief implementation of matrix dot op on GPU */ -#ifndef MXNET_OPERATOR_NN_MATRIX_DOT_INL_CUH_ -#define MXNET_OPERATOR_NN_MATRIX_DOT_INL_CUH_ +#ifndef MXNET_OPERATOR_TENSOR_DOT_INL_CUH_ +#define MXNET_OPERATOR_TENSOR_DOT_INL_CUH_ #include #include @@ -158,4 +158,4 @@ inline void DotCsrRspRspImpl(mshadow::Stream* s, } // namespace op } // namespace mxnet -#endif // MXNET_OPERATOR_NN_MATRIX_DOT_INL_CUH_ +#endif // MXNET_OPERATOR_TENSOR_DOT_INL_CUH_ diff --git a/src/operator/nn/matrix_dot-inl.h b/src/operator/tensor/dot-inl.h similarity index 99% rename from src/operator/nn/matrix_dot-inl.h rename to src/operator/tensor/dot-inl.h index 2025049edbdd..73252d0eda58 100644 --- a/src/operator/nn/matrix_dot-inl.h +++ b/src/operator/tensor/dot-inl.h @@ -1,11 +1,11 @@ /*! * Copyright (c) 2017 by Contributors - * \file matrix_dot-inl.h + * \file dot-inl.h * \brief Function definition of matrix dot operator */ -#ifndef MXNET_OPERATOR_NN_MATRIX_DOT_INL_H_ -#define MXNET_OPERATOR_NN_MATRIX_DOT_INL_H_ +#ifndef MXNET_OPERATOR_TENSOR_DOT_INL_H_ +#define MXNET_OPERATOR_TENSOR_DOT_INL_H_ #include #include @@ -921,4 +921,4 @@ inline bool BatchDotShape(const nnvm::NodeAttrs& attrs, } // namespace op } // namespace mxnet -#endif // MXNET_OPERATOR_NN_MATRIX_DOT_INL_H_ +#endif // MXNET_OPERATOR_TENSOR_DOT_INL_H_ diff --git a/src/operator/nn/matrix_dot.cc b/src/operator/tensor/dot.cc similarity index 98% rename from src/operator/nn/matrix_dot.cc rename to src/operator/tensor/dot.cc index 716efea6999c..fc476a75eec8 100644 --- a/src/operator/nn/matrix_dot.cc +++ b/src/operator/tensor/dot.cc @@ -1,10 +1,10 @@ /*! * Copyright (c) 2017 by Contributors - * \file matrix_dot.cc + * \file dot.cc * \brief CPU Implementation of matrix dot */ -#include "./matrix_dot-inl.h" +#include "./dot-inl.h" namespace mxnet { namespace op { diff --git a/src/operator/nn/matrix_dot.cu b/src/operator/tensor/dot.cu similarity index 92% rename from src/operator/nn/matrix_dot.cu rename to src/operator/tensor/dot.cu index 21592e15449e..ae00566d5d45 100644 --- a/src/operator/nn/matrix_dot.cu +++ b/src/operator/tensor/dot.cu @@ -1,10 +1,10 @@ /*! 
* Copyright (c) 2017 by Contributors - * \file matrix_dot.cu + * \file dot.cu * \brief GPU Implementation of matrix dot */ -#include "./matrix_dot-inl.h" +#include "./dot-inl.h" namespace mxnet { namespace op { diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h index 5373badfd734..46aa6fcd73a4 100644 --- a/src/operator/tensor/indexing_op.h +++ b/src/operator/tensor/indexing_op.h @@ -22,7 +22,7 @@ #include "../elemwise_op_common.h" #include "../mxnet_op.h" #include "./sort_op.h" -#include "../nn/matrix_dot-inl.h" +#include "./dot-inl.h" namespace mxnet { namespace op { From 6b8912cdcc5d29175573a1e89f63a7b7dcbf83b9 Mon Sep 17 00:00:00 2001 From: reminisce Date: Mon, 3 Jul 2017 11:49:32 -0700 Subject: [PATCH 9/9] Fix include error --- src/operator/tensor/dot-inl.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/tensor/dot-inl.h b/src/operator/tensor/dot-inl.h index 73252d0eda58..33cc095c0cee 100644 --- a/src/operator/tensor/dot-inl.h +++ b/src/operator/tensor/dot-inl.h @@ -16,7 +16,7 @@ #include "../elemwise_op_common.h" #include "../mxnet_op.h" #ifdef __CUDACC__ -#include "./matrix_dot-inl.cuh" +#include "./dot-inl.cuh" #endif // __CUDACC__ namespace mxnet {
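
Note on the element-parallel GPU kernel (DotCsrTransDnsDns) refactored above: it computes each output element of dot(csr.T, dns) independently — for out[irow, icol] it walks every csr row and binary-searches irow among that row's column indices, so no two threads ever write the same output element. Below is a minimal Python sketch of that strategy; it is illustrative only and not part of these patches, and the function name and the explicit num_out_rows argument are this note's own.

import bisect
import numpy as np

def csr_T_dot_dns_by_element(data, indices, indptr, dns, num_out_rows):
    # data/indices/indptr: CSR triple of the lhs; dns: dense rhs; out = csr.T @ dns
    num_rows_l, num_cols = len(indptr) - 1, dns.shape[1]
    out = np.zeros((num_out_rows, num_cols), dtype=dns.dtype)
    for i in range(out.size):                     # one "thread" per output element
        irow, icol = divmod(i, num_cols)          # irow is a column id of the csr lhs
        s = 0.0
        for k in range(num_rows_l):               # scan every csr row
            lo, hi = indptr[k], indptr[k + 1]
            if lo == hi or irow < indices[lo] or irow > indices[hi - 1]:
                continue                          # irow cannot appear in this row
            j = bisect.bisect_left(indices, irow, lo, hi)  # binary search, as in the kernel
            if j < hi and indices[j] == irow:
                s += data[j] * dns[k, icol]
        out[irow, icol] = s
    return out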
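
Separately, the storage-type semantics implemented by the Dot*Impl functions above — a dense output for dot(csr, dns) and dot(csr, rsp), and a row-sparse output with a compacted row_idx for dot(csr.T, dns) — can be summarized with a plain NumPy sketch. This is again illustrative only with hypothetical helper names; the real CPU kernels parallelize by row blocks and compact the non-zero rows in place.

import numpy as np

def csr_dot_dns(data, indices, indptr, dns):
    # dot(csr, dns) = dns: one output row per csr row.
    num_rows, num_cols = len(indptr) - 1, dns.shape[1]
    out = np.zeros((num_rows, num_cols), dtype=dns.dtype)
    for i in range(num_rows):
        for k in range(indptr[i], indptr[i + 1]):
            out[i] += data[k] * dns[indices[k]]
    return out

def csr_T_dot_dns_rsp(data, indices, indptr, dns, num_out_rows):
    # dot(csr.T, dns) = rsp: scatter each csr row into the output rows named by its
    # column indices, then keep only the touched rows (data + row_idx of the rsp).
    out = np.zeros((num_out_rows, dns.shape[1]), dtype=dns.dtype)
    hit = np.zeros(num_out_rows, dtype=bool)
    for i in range(len(indptr) - 1):
        for k in range(indptr[i], indptr[i + 1]):
            out[indices[k]] += data[k] * dns[i]
            hit[indices[k]] = True
    row_idx = np.flatnonzero(hit)      # aux row_idx; storage_shape[0] == len(row_idx)
    return out[row_idx], row_idx

# usage: a 2x3 csr lhs [[1,0,0],[0,0,2]] against small dense rhs matrices
data, indices, indptr = np.array([1., 2.]), np.array([0, 2]), np.array([0, 1, 2])
print(csr_dot_dns(data, indices, indptr, np.arange(6.).reshape(3, 2)))      # dense 2x2
print(csr_T_dot_dns_rsp(data, indices, indptr,
                        np.array([[1., 2.], [3., 4.]]), 3))                 # rsp: rows 0 and 2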