This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add operator for dot(dns, csr) = csr
anirudh2290 committed Dec 4, 2017
1 parent ddb4782 commit cffb161
Showing 3 changed files with 227 additions and 16 deletions.
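The commit adds a CPU storage-type path so that mx.nd.sparse.dot with a dense (default-storage) lhs and a CSR rhs produces a CSR output directly instead of falling back to dense. Below is a minimal usage sketch of the new path; the shapes, the thresholding used to make the rhs sparse, and the comparison against the dense fallback are illustrative and not taken from the commit.

import mxnet as mx

# Dense lhs, CSR rhs; with this commit the forward pass dispatches to the new
# dot(dns, csr) = csr kernel on CPU (no transpose on either operand).
lhs = mx.nd.random.uniform(shape=(4, 6))                               # default storage
rhs_dense = mx.nd.random.uniform(shape=(6, 8)) * (mx.nd.random.uniform(shape=(6, 8)) > 0.7)
rhs = rhs_dense.tostype('csr')                                         # CSR storage

out = mx.nd.sparse.dot(lhs, rhs)
print(out.stype)        # 'csr'

# The result agrees with the dense fallback.
dense_out = mx.nd.dot(lhs, rhs.tostype('default'))
print(abs(out.asnumpy() - dense_out.asnumpy()).max())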
53 changes: 37 additions & 16 deletions benchmark/python/sparse/dot.py
@@ -275,7 +275,10 @@ def bench_dot(lhs_shape, rhs_shape, lhs_stype, rhs_stype,
    # Create matrix instances
    lhs_nd = rand_ndarray(lhs_shape, lhs_stype, density=lhs_den, distribution=distribution)
    # only uniform distribution supported for non-csr rhs
    if rhs_stype == 'csr':
        rhs_nd = rand_ndarray(rhs_shape, rhs_stype, density=rhs_den, distribution=distribution)
    else:
        rhs_nd = rand_ndarray(rhs_shape, rhs_stype, density=rhs_den, distribution="uniform")
    lhs_dns = None
    rhs_dns = None
    dense_cost = None
@@ -337,27 +340,41 @@ def print_benchmark_info(lhs, rhs, lhs_trans, fw):

def run_benchmark(ctx=None, lhs="csr", lhs_trans=False, rhs="dns", fw="mxnet", rhs_density=1,
                  distribution="uniform"):
    if rhs_density > 1 or rhs_density < 0:
        raise ValueError("rhs_density has to be between 0 and 1")

    print_benchmark_info(lhs, rhs, lhs_trans, fw)

    if rhs == "csr":
        lhs_stype = "default"
        rhs_stype = "csr"
        assert (lhs_stype == 'default'), "Only dot(default, csr) supported"
        # Arrange dimensions according to use case. For the case below, csr will have num_rows << num_cols
        feature_dim_list = data_dict['batch_size']
        batch_size_list = data_dict['m']
        output_dim_list = data_dict['feature_dim']
        density_list = data_dict['density']
        default_output_index = data_dict['default_index']['feature_dim']
        default_density_index = data_dict['default_index']['density']
        default_feature_index = data_dict['default_index']['batch_size']
        default_batch_size_index = data_dict['default_index']['output_dim']
        num_repeat = data_dict['num_repeat']
    else:
        lhs_stype = "csr"
        rhs_stype = "row_sparse" if rhs == "rsp" else "default"

        feature_dim_list = data_dict['feature_dim']
        output_dim_list = data_dict['m']
        batch_size_list = data_dict['batch_size']
        density_list = data_dict['density']

        default_output_index = data_dict['default_index']['output_dim']
        default_batch_size_index = data_dict['default_index']['batch_size']
        default_feature_index = data_dict['default_index']['feature_dim']
        default_density_index = data_dict['default_index']['density']
        num_repeat = data_dict['num_repeat']

for output_dim in output_dim_list:
if lhs_trans:
@@ -403,7 +420,7 @@ def run_benchmark(ctx=None, lhs="csr", lhs_trans=False, rhs="dns", fw="mxnet", r
feature_dim_list[default_feature_index]),
(output_row_dim,
output_dim_list[default_output_index]),
lhs_stype, rhs_stype, density, density, lhs_trans, ctx,
num_repeat=num_repeat, fw=fw, distribution=distribution)

check_call(_LIB.MXSetNumOMPThreads(ctypes.c_int(ARGS.num_omp_threads)))
@@ -423,6 +440,10 @@ def run_benchmark(ctx=None, lhs="csr", lhs_trans=False, rhs="dns", fw="mxnet", r
rhs="rsp", lhs_trans=False,
fw="mxnet", rhs_density=0.05,
distribution=distribution)
run_benchmark(context, lhs="default",
rhs="csr", lhs_trans=False,
fw="mxnet", rhs_density=0.001,
distribution=distribution)
if not ARGS.gpu:
run_benchmark(context, lhs="csr",
rhs="default", lhs_trans=False,
163 changes: 163 additions & 0 deletions src/operator/tensor/dot-inl.h
@@ -231,6 +231,12 @@ inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs,
dispatched = storage_type_assign(&out_stype, kDefaultStorage,
dispatch_mode, DispatchMode::kFComputeEx);
}
if (!dispatched && lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage &&
!param.transpose_a && !param.transpose_b) {
// dns, csr -> csr
dispatched = storage_type_assign(&out_stype, kCSRStorage, dispatch_mode,
DispatchMode::kFComputeEx);
}
if (!dispatched) {
dispatch_fallback(out_attrs, dispatch_mode);
LogStorageFallback(attrs, dev_mask, in_attrs, out_attrs);
@@ -527,6 +533,69 @@ struct DotCsrTransRspRspByRowBlocks {
}
};

/*!
* \brief CPU Kernel of PopulateCsrForNNC
* Parallelization by individual rows
*/
struct PopulateCsrForNNC {
/*!
 * \brief Fill the indptr and column indices of the output csr; every output row stores the same nnc non-zero columns
* \param i the i-th thread
* \param nnc_idx all non zero column indexes
* \param indptr_out indptr array for output
* \param col_idx_out column indices for output
* \param nnc number of non zero columns in the output
* \param num_rows_l number of rows in lhs
*/
template <typename IType, typename CType>
MSHADOW_CINLINE static void Map(int i, const CType* nnc_idx,
IType* indptr_out, CType* col_idx_out,
const nnvm::dim_t nnc,
const nnvm::dim_t num_rows_l) {
const CType start_idx = i * nnc;
nnvm::dim_t cur = 0;
indptr_out[i] = start_idx;
if (i == static_cast<int>(num_rows_l - 1)) indptr_out[i + 1] = indptr_out[i] + nnc;
for (IType idx = start_idx; idx < (start_idx + nnc); idx++) {
col_idx_out[idx] = nnc_idx[cur++];
}
}
};
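/*
 * Worked example (editorial illustration, not part of the original source): with
 * nnc = 2 non-zero rhs columns nnc_idx = {1, 3} and num_rows_l = 3, every output
 * row reserves the same two column slots, so the kernel writes
 *   indptr_out  = {0, 2, 4, 6}
 *   col_idx_out = {1, 3, 1, 3, 1, 3}
 * i.e. each row of the dns*csr product can only be non-zero in columns that are
 * non-zero somewhere in the csr rhs.
 */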

/*!
 * \brief CPU Kernel of dot(dns, csr) = csr
 * Parallelization by row blocks
 */
struct DotDnsCsrCsrByRowBlocks {
template <typename DType, typename IType, typename CType>
MSHADOW_CINLINE static void Map(
int i, DType* out, const DType* data_l, const IType* indptr_r,
const CType* col_idx_r, const DType* data_r, const nnvm::dim_t seg_len,
const IType num_rows_r, const IType num_rows_l,
const nnvm::dim_t num_cols, const nnvm::dim_t nnc,
const CType* prefix_sum) {
using nnvm::dim_t;
const dim_t seg_start = i * seg_len;
if (seg_start >= num_rows_l) return;
const dim_t seg_end = std::min(seg_start + seg_len, num_rows_l);

for (dim_t j = seg_start; j < seg_end; j++) {
for (dim_t k = 0; k < num_rows_r; k++) {
const dim_t working_idx = j * num_rows_r + k;
const DType val = data_l[working_idx];
if (indptr_r[k] == indptr_r[k + 1]) continue;
const dim_t row_start = j * nnc;
for (dim_t cur = indptr_r[k]; cur < indptr_r[k + 1]; cur++) {
dim_t cur_col_idx_r = col_idx_r[cur];
const dim_t out_idx = row_start + prefix_sum[cur_col_idx_r] - 1;
out[out_idx] += val * data_r[cur];
}
}
}
}
};
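/*
 * Worked example (editorial illustration, not part of the original source):
 * continuing the case above with rhs columns {0, 1, 2, 3} and non-zero columns
 * {1, 3}, col_flg = {0, 1, 0, 1} gives prefix_sum = {0, 1, 1, 2}. A rhs entry in
 * column 3 therefore lands at compact slot prefix_sum[3] - 1 = 1 inside the
 * fixed-size block of output row j:
 *   out_idx = j * nnc + prefix_sum[col_idx_r[cur]] - 1
 * which is exactly the index computed in the inner loop of the kernel above.
 */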



/*!
* \brief CPU Impl of dot(csr, dns1) = dns2 and dot(csr.T, dns1) = dns2
*/
@@ -811,6 +880,95 @@ inline void DotCsrRspRspImpl(const OpContext& ctx,
});
}

/*
* \brief CPU Impl of dot(dns, csr) = csr
*/
inline void DotDnsCsrCsrImpl(const OpContext& ctx, const cpu& cpu_dev,
const TBlob& lhs, const NDArray& rhs,
const OpReqType req, NDArray* ret) {
if (kNullOp == req) return;
CHECK_EQ(rhs.storage_type(), kCSRStorage);
if (!rhs.storage_initialized()) return;

using namespace mshadow;
using namespace mshadow::expr;
using nnvm::dim_t;

/*Initialize data structures*/
mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
const NDArray& out = *ret;
const TBlob data_l = lhs;
const TBlob data_r = rhs.data();
const TBlob indptr_r = rhs.aux_data(csr::kIndPtr);
const TBlob col_idx_r = rhs.aux_data(csr::kIdx);

MSHADOW_SGL_DBL_TYPE_SWITCH(data_r.type_flag_, DType, { // data type
MSHADOW_IDX_TYPE_SWITCH(indptr_r.type_flag_, IType, { // indptr type
MSHADOW_IDX_TYPE_SWITCH(col_idx_r.type_flag_, CType, { // colidx type
/* Allocate workspace */
CType num_cols_out = out.shape()[1];
CType rhs_data_size = static_cast<CType>(col_idx_r.shape_.Size());
size_t workspace_size = 2 * num_cols_out * sizeof(CType);
Tensor<cpu, 1, char> workspace =
ctx.requested[0].get_space_typed<cpu, 1, char>(
Shape1(workspace_size), s);
CType* col_flg = reinterpret_cast<CType*>(workspace.dptr_);

CType* prefix_sum = col_flg;
CType* nnc_idx = prefix_sum + num_cols_out;

/* Set the column flags for nnz columns */
mxnet_op::Kernel<mxnet_op::set_zero, cpu>::Launch(s, num_cols_out,
col_flg);
mxnet_op::Kernel<MarkRowFlgKernel, cpu>::Launch(
s, rhs_data_size, col_flg, col_idx_r.dptr<CType>());

/* 1. Calculate prefix sum from col flgs
 * 2. Store all non-zero column indexes in nnc_idx
*/
CType cur = 0;
prefix_sum[0] = col_flg[0];
if (prefix_sum[0]) nnc_idx[cur++] = 0;
for (CType i = 1; i < num_cols_out; i++) {
prefix_sum[i] = prefix_sum[i - 1] + col_flg[i];
if (prefix_sum[i] > prefix_sum[i - 1]) nnc_idx[cur++] = i;
}

/* Allocate aux data for out */
IType num_rows_l = lhs.shape_[0];
dim_t nnc = prefix_sum[num_cols_out - 1];
dim_t nnz = nnc * num_rows_l;
out.CheckAndAllocAuxData(csr::kIndPtr, Shape1(num_rows_l + 1));
out.CheckAndAllocAuxData(csr::kIdx, Shape1(nnz));
out.CheckAndAllocData(Shape1(nnz));

/* Set csr indptr and index according to nnc_idx*/
IType* indptr_out = out.aux_data(csr::kIndPtr).dptr<IType>();
CType* col_idx_out = out.aux_data(csr::kIdx).dptr<CType>();
DType* data_out = out.data().dptr<DType>();
mxnet_op::Kernel<PopulateCsrForNNC, cpu>::Launch(
s, num_rows_l, nnc_idx, indptr_out, col_idx_out, nnc, num_rows_l);
mxnet_op::Kernel<mxnet_op::set_zero, cpu>::Launch(s, nnz, data_out);

if (nnc == 0) {
return;
}

dim_t num_threads = mxnet_op::get_num_threads<cpu>(num_rows_l);
dim_t seg_len = (num_rows_l + num_threads - 1) / num_threads;

IType num_rows_r = rhs.shape()[0];
mxnet_op::Kernel<DotDnsCsrCsrByRowBlocks, cpu>::Launch(
s, num_threads, data_out, data_l.dptr<DType>(),
indptr_r.dptr<IType>(), col_idx_r.dptr<CType>(),
data_r.dptr<DType>(), seg_len, num_rows_r, num_rows_l, num_cols_out,
nnc, prefix_sum);

});
});
});
}

inline bool DotShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
@@ -886,6 +1044,11 @@ void DotForwardEx(const nnvm::NodeAttrs& attrs,
&& out_stype == kRowSparseStorage && !param.transpose_b) {
NDArray ret = outputs[0];
DotCsrRspRspImpl(ctx, xpu(), inputs[0], inputs[1], req[0], param.transpose_a, &ret);
} else if (lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage &&
out_stype == kCSRStorage &&
!(param.transpose_a || param.transpose_b)) {
NDArray ret = outputs[0];
DotDnsCsrCsrImpl(ctx, xpu(), inputs[0].data(), inputs[1], req[0], &ret);
} else {
LOG(FATAL) << "Not implemented: " << operator_string(attrs, ctx, inputs, req, outputs);
}
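For readers who want a reference point outside the C++ kernels, the indexing scheme used by DotDnsCsrCsrImpl above (a prefix sum over the non-zero columns of the rhs, and one fixed-size block of nnc slots per output row) can be mirrored in a few lines of NumPy/SciPy. This is an editorial sketch for illustration only; the function name and the SciPy-based check are not part of the commit.

import numpy as np
import scipy.sparse as sp

def dns_dot_csr_to_csr(dense_lhs, csr_rhs):
    """Mirror of the dot(dns, csr) = csr layout: every output row reserves one
    slot per non-zero column of the rhs, so nnz == nnc * num_rows_lhs."""
    num_rows_l = dense_lhs.shape[0]
    col_flg = np.zeros(csr_rhs.shape[1], dtype=np.int64)
    col_flg[csr_rhs.indices] = 1                    # mark non-zero rhs columns
    prefix_sum = np.cumsum(col_flg)
    nnc_idx = np.flatnonzero(col_flg)               # all non-zero column indexes
    nnc = len(nnc_idx)

    data = np.zeros(nnc * num_rows_l, dtype=dense_lhs.dtype)
    indices = np.tile(nnc_idx, num_rows_l)          # col_idx_out
    indptr = np.arange(num_rows_l + 1) * nnc        # indptr_out

    for j in range(num_rows_l):                     # output / lhs row
        for k in range(csr_rhs.shape[0]):           # rhs row
            val = dense_lhs[j, k]
            for cur in range(csr_rhs.indptr[k], csr_rhs.indptr[k + 1]):
                col = csr_rhs.indices[cur]
                data[j * nnc + prefix_sum[col] - 1] += val * csr_rhs.data[cur]
    return sp.csr_matrix((data, indices, indptr),
                         shape=(num_rows_l, csr_rhs.shape[1]))

# sanity check against the dense product
lhs = np.random.rand(5, 4)
rhs = sp.random(4, 7, density=0.3, format='csr')
assert np.allclose(dns_dot_csr_to_csr(lhs, rhs).toarray(), lhs @ rhs.toarray())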
27 changes: 27 additions & 0 deletions tests/python/unittest/test_sparse_operator.py
@@ -1223,6 +1223,31 @@ def test_dot_csr(lhs_shape, rhs_shape, rhs_stype, trans_lhs, lhs_density, rhs_de
grad_req={'lhs': 'null', 'rhs': 'write'},
rtol=1e-3, atol=1e-4)

def test_dot_dns_csr(lhs_shape, rhs_shape, lhs_density, rhs_density, trans_lhs=False, trans_rhs=False):
    lhs_nd = rand_ndarray(lhs_shape, stype='default', density=lhs_density)
    rhs_nd = rand_ndarray(rhs_shape, stype='csr', density=rhs_density)
    rhs_dns = rhs_nd.tostype('default')

    out = mx.nd.sparse.dot(lhs_nd, rhs_nd, transpose_a=trans_lhs, transpose_b=trans_rhs)
    out_dns = mx.nd.dot(lhs_nd, rhs_dns, transpose_a=trans_lhs, transpose_b=trans_rhs)
    out_np = out_dns.asnumpy()
    assert_almost_equal(out.asnumpy(), out_np, rtol=1e-4, atol=1e-5)

    # test symbolic forward
    lhs = mx.symbol.Variable('lhs', stype='default')
    rhs = mx.symbol.Variable('rhs', stype='csr')
    out = mx.symbol.sparse.dot(lhs, rhs, transpose_a=trans_lhs, transpose_b=trans_rhs)
    location = {'lhs': lhs_nd, 'rhs': rhs_nd}
    check_symbolic_forward(out, location, [out_np], rtol=1e-3, atol=1e-4)

    # test symbolic backward
    backward_trans = not trans_lhs
    rhs_backward_grad = mx.nd.dot(lhs_nd, out_dns, transpose_a=backward_trans).asnumpy()
    expected = {'rhs': rhs_backward_grad}
    check_symbolic_backward(out, location, [out_np], expected,
                            grad_req={'lhs': 'null', 'rhs': 'write'},
                            rtol=1e-3, atol=1e-4)
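# Editorial note (not part of the original test): for out = dot(lhs, rhs) with a
# dense lhs, a CSR rhs and no transposes, the gradient w.r.t. rhs is
# dot(lhs.T, out_grad); the check above feeds out_np in as the output gradient,
# hence backward_trans = not trans_lhs when computing rhs_backward_grad.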

def test_sparse_dot_zero_output(lhs_shape, trans_lhs, rhs_num_cols):
"""Test for nnr_out = 0. Before the fix, the test would fail."""
lhs = mx.nd.zeros(lhs_shape)
@@ -1248,10 +1273,12 @@ def test_sparse_dot_zero_output(lhs_shape, trans_lhs, rhs_num_cols):
test_dot_csr(lhs_shape, (lhs_shape[0], 1), 'default', True, lhs_d, rhs_d) # (vector kernel)
test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(5, 10)), 'default', False, lhs_d, rhs_d) # test gpu SpMM
test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(5, 10)), 'default', True, lhs_d, rhs_d) # (scalar kernel)
test_dot_dns_csr(lhs_shape, (lhs_shape[1], rnd.randint(500, 1000)), lhs_d, lhs_d)
for rhs_d in density:
test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), 'row_sparse', False, lhs_d, rhs_d)
test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(1, 10)), 'row_sparse', True, lhs_d, rhs_d)


test_sparse_dot_zero_output(rand_shape_2d(50, 200), False, 40)
test_sparse_dot_zero_output(rand_shape_2d(50, 200), True, 40)

