
Commit

Fixes
anirudh2290 committed Dec 10, 2017
1 parent 3fad020 commit 2521627
Showing 1 changed file with 22 additions and 19 deletions.
src/operator/tensor/dot-inl.h
@@ -202,22 +202,25 @@ void DotBackward_(const nnvm::NodeAttrs& attrs,
 inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs,
                                        const int dev_mask,
                                        DispatchMode* dispatch_mode,
-                                       std::vector<int> *in_attrs,
-                                       std::vector<int> *out_attrs) {
+                                       std::vector<int>* in_attrs,
+                                       std::vector<int>* out_attrs) {
   CHECK_EQ(in_attrs->size(), 2U);
   CHECK_EQ(out_attrs->size(), 1U);
   const DotParam& param = nnvm::get<DotParam>(attrs.parsed);
-  // csr has many zero columns, so the result of dot(csr.T, matrix) should be rsp
+  // csr has many zero columns, so the result of dot(csr.T, matrix) should be
+  // rsp
   const auto& lhs_stype = in_attrs->at(0);
   const auto& rhs_stype = in_attrs->at(1);
   auto& out_stype = out_attrs->at(0);
   bool dispatched = false;
   bool only_lhs_transpose = param.transpose_a && !param.transpose_b;
-  bool rhs_rsp_or_dns = rhs_stype == kRowSparseStorage || rhs_stype == kDefaultStorage;
-  if (!dispatched && lhs_stype == kDefaultStorage && rhs_stype == kDefaultStorage) {
+  bool rhs_rsp_or_dns =
+      rhs_stype == kRowSparseStorage || rhs_stype == kDefaultStorage;
+  if (!dispatched && lhs_stype == kDefaultStorage &&
+      rhs_stype == kDefaultStorage) {
     // dns, dns -> dns
-    dispatched = storage_type_assign(&out_stype, kDefaultStorage,
-                                     dispatch_mode, DispatchMode::kFCompute);
+    dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
+                                     DispatchMode::kFCompute);
   }
   if (!dispatched && lhs_stype == kCSRStorage && only_lhs_transpose &&
       (rhs_stype == kRowSparseStorage || rhs_stype == kDefaultStorage)) {
@@ -228,17 +231,16 @@ inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs,
   if (!dispatched && lhs_stype == kCSRStorage && rhs_rsp_or_dns &&
       !param.transpose_a && !param.transpose_b) {
     // csr, rsp/dns -> dns
-    dispatched = storage_type_assign(&out_stype, kDefaultStorage,
-                                     dispatch_mode, DispatchMode::kFComputeEx);
+    dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
+                                     DispatchMode::kFComputeEx);
   }
   if (!dispatched && lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage &&
       !param.transpose_a && !param.transpose_b) {
     // dns, csr -> csr
-    const bool invalid_ctx = dev_mask != mshadow::cpu::kDevMask;
-    const DispatchMode dispatch_ex = invalid_ctx ? DispatchMode::kFComputeFallback
-                                                 : DispatchMode::kFComputeEx;
-    dispatched = storage_type_assign(&out_stype, kCSRStorage, dispatch_mode,
-                                     dispatch_ex);
+    if (dev_mask == mshadow::cpu::kDevMask) {
+      dispatched = storage_type_assign(&out_stype, kCSRStorage, dispatch_mode,
+                                       DispatchMode::kFComputeEx);
+    }
   }
   if (!dispatched) {
     dispatch_fallback(out_attrs, dispatch_mode);
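
A minimal, self-contained sketch of the dispatch pattern in the two hunks above (not MXNet source: `storage_type_assign` is reduced to a stub and the enums are stand-ins). Each storage combination is tried in order; after this commit, the dns, csr -> csr case assigns `kFComputeEx` only on CPU and otherwise falls through to `dispatch_fallback`:

```cpp
#include <iostream>

enum StorageType { kDefaultStorage, kCSRStorage, kRowSparseStorage };
enum class DispatchMode { kUndefined, kFCompute, kFComputeEx, kFComputeFallback };

// Stand-in for the real storage_type_assign: record the chosen output
// storage type and dispatch mode, and report that dispatch succeeded.
bool storage_type_assign(StorageType* out_stype, StorageType stype,
                         DispatchMode* mode, DispatchMode target) {
  *out_stype = stype;
  *mode = target;
  return true;
}

int main() {
  const bool on_cpu = true;  // dev_mask == mshadow::cpu::kDevMask in the real code
  StorageType lhs = kDefaultStorage, rhs = kCSRStorage;
  StorageType out = kDefaultStorage;
  DispatchMode mode = DispatchMode::kUndefined;
  bool dispatched = false;
  // dns, csr -> csr: the FComputeEx path is only assigned on CPU;
  // other contexts fall through to the generic fallback.
  if (!dispatched && lhs == kDefaultStorage && rhs == kCSRStorage) {
    if (on_cpu) {
      dispatched = storage_type_assign(&out, kCSRStorage, &mode,
                                       DispatchMode::kFComputeEx);
    }
  }
  std::cout << "dispatched=" << dispatched << " out_stype=" << out << '\n';
}
```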
@@ -897,14 +899,15 @@ inline void DotCsrRspRspImpl(const OpContext& ctx,
 /*
  * \brief CPU Impl of dot(dns, csr) = csr
  */
-inline void DotDnsCsrCsrImpl(const OpContext& ctx, const cpu& cpu_dev,
+template<typename xpu>
+inline void DotDnsCsrCsrImpl(const OpContext& ctx,
                              const TBlob& lhs, const NDArray& rhs,
                              const OpReqType req, NDArray* ret) {
   if (kNullOp == req) return;
 
   CHECK_EQ(req, kWriteTo);
 
   CHECK_EQ(rhs.storage_type(), kCSRStorage);
 
   using namespace mshadow;
   using namespace mshadow::expr;
   using nnvm::dim_t;
@@ -918,10 +921,10 @@ inline void DotDnsCsrCsrImpl(const OpContext& ctx, const cpu& cpu_dev,
   const TBlob col_idx_r = rhs.aux_data(csr::kIdx);
   const dim_t out_data_size = lhs.shape_[0] * rhs.shape()[1];
   if (!rhs.storage_initialized()) {
-      FillZerosCsrImpl(s, *ret);
+    FillZerosCsrImpl(s, *ret);
     return;
   }
-
+
   MSHADOW_SGL_DBL_TYPE_SWITCH(data_r.type_flag_, DType, {  // data type
     MSHADOW_IDX_TYPE_SWITCH(indptr_r.type_flag_, IType, {  // indptr type
       MSHADOW_IDX_TYPE_SWITCH(col_idx_r.type_flag_, CType, {  // colidx type
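
The nested `MSHADOW_*_TYPE_SWITCH` macros in the hunk above map runtime type flags onto concrete template parameters. A hand-expanded sketch of the idea in plain C++ (the flag values and `Kernel` are hypothetical stand-ins, not the real macro expansion):

```cpp
#include <cstdint>
#include <iostream>

// Runtime type flags, standing in for mshadow's kFloat32/kInt64 etc.
enum TypeFlag { kFloat32, kFloat64, kInt32, kInt64 };

// The body that the innermost switch instantiates once per concrete
// (DType, IType, CType) combination.
template <typename DType, typename IType, typename CType>
void Kernel() { std::cout << "kernel instantiated\n"; }

void Dispatch(int data_flag, int indptr_flag, int colidx_flag) {
  // Two of the many branches, hand-expanded; the macros nest three
  // switches so the body sees all three concrete types at compile time.
  if (data_flag == kFloat32 && indptr_flag == kInt64 && colidx_flag == kInt64) {
    Kernel<float, int64_t, int64_t>();
  } else if (data_flag == kFloat64 && indptr_flag == kInt64 && colidx_flag == kInt64) {
    Kernel<double, int64_t, int64_t>();
  }  // ... remaining combinations elided
}

int main() { Dispatch(kFloat32, kInt64, kInt64); }
```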
Expand Down Expand Up @@ -1067,7 +1070,7 @@ void DotForwardEx(const nnvm::NodeAttrs& attrs,
out_stype == kCSRStorage &&
!(param.transpose_a || param.transpose_b)) {
NDArray ret = outputs[0];
DotDnsCsrCsrImpl(ctx, xpu(), inputs[0].data(), inputs[1], req[0], &ret);
DotDnsCsrCsrImpl<xpu>(ctx, inputs[0].data(), inputs[1], req[0], &ret);
} else {
LOG(FATAL) << "Not implemented: " << operator_string(attrs, ctx, inputs, req, outputs);
}
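
Taken together, the signature change above and this call-site change swap a `cpu` tag argument for an explicit `xpu` template parameter, so `DotForwardEx`, itself templated on `xpu`, can forward its device type directly. A minimal sketch of the two call shapes, with stand-in tag types rather than mshadow's:

```cpp
#include <iostream>

// Stand-in device tag types (mshadow defines the real cpu/gpu tags).
struct cpu {};
struct gpu {};

// Old shape: the device is a tag argument, so call sites must
// construct a tag object, as in DotDnsCsrCsrImpl(ctx, xpu(), ...).
void DotDnsCsrCsrImplOld(const cpu& /*dev*/) { std::cout << "cpu impl\n"; }

// New shape: the device is an explicit template parameter, so a caller
// that is itself templated on xpu can forward it directly, as in
// DotDnsCsrCsrImpl<xpu>(ctx, ...).
template <typename xpu>
void DotDnsCsrCsrImpl() { std::cout << "templated impl\n"; }

int main() {
  DotDnsCsrCsrImplOld(cpu{});  // old call shape
  DotDnsCsrCsrImpl<cpu>();     // new call shape
}
```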
