diff --git a/example/sparse/get_data.py b/example/sparse/get_data.py new file mode 100644 index 000000000000..6b6723f07be4 --- /dev/null +++ b/example/sparse/get_data.py @@ -0,0 +1,15 @@ +# pylint: skip-file +import os + +def get_libsvm_data(data_dir, data_name, url, data_origin_name): + if not os.path.isdir(data_dir): + os.system("mkdir " + data_dir) + os.chdir(data_dir) + if (not os.path.exists(data_name)): + import urllib + zippath = os.path.join(data_dir, data_origin_name) + urllib.urlretrieve(url, zippath) + os.system("bzip2 -d %r" % data_origin_name) + os.chdir("..") diff --git a/example/sparse/linear_regression.py b/example/sparse/linear_regression.py new file mode 100644 index 000000000000..6aa1cbadbcb2 --- /dev/null +++ b/example/sparse/linear_regression.py @@ -0,0 +1,178 @@ +import mxnet as mx +from mxnet.test_utils import * +from get_data import get_libsvm_data +import time +import argparse +import os + +parser = argparse.ArgumentParser(description="Run sparse linear regression " \ "with distributed kvstore", formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument('--profiler', type=int, default=0, + help='whether to use profiler') +parser.add_argument('--num-epoch', type=int, default=1, + help='number of epochs to train') +parser.add_argument('--batch-size', type=int, default=512, + help='number of examples per batch') +parser.add_argument('--num-batch', type=int, default=99999999, + help='number of batches per epoch') +parser.add_argument('--dummy-iter', type=int, default=0, + help='whether to use a dummy iterator to exclude io cost') +parser.add_argument('--kvstore', type=str, default='dist_sync', + help='what kvstore to use [local, dist_sync, etc]') +parser.add_argument('--log-level', type=str, default='debug', + help='logging level [debug, info, error]') +parser.add_argument('--dataset', type=str, default='avazu', + help='what test dataset to use') + +class DummyIter(mx.io.DataIter): + "A dummy iterator that always returns the same batch, used for speed testing" + def __init__(self, real_iter): + super(DummyIter, self).__init__() + self.real_iter = real_iter + self.provide_data = real_iter.provide_data + self.provide_label = real_iter.provide_label + self.batch_size = real_iter.batch_size + + for batch in real_iter: + self.the_batch = batch + break + + def __iter__(self): + return self + + def next(self): + return self.the_batch + +# testing dataset sources +avazu = { + 'data_name': 'avazu-app.t', + 'data_origin_name': 'avazu-app.t.bz2', + 'url': "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/avazu-app.t.bz2", + 'feature_dim': 1000000, +} + +kdda = { + 'data_name': 'kdda.t', + 'data_origin_name': 'kdda.t.bz2', + 'url': "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/kdda.t.bz2", + 'feature_dim': 20216830, +} + +datasets = { 'kdda' : kdda, 'avazu' : avazu } + +def regression_model(feature_dim): + x = mx.symbol.Variable("data", stype='csr') + norm_init = mx.initializer.Normal(sigma=0.01) + v = mx.symbol.Variable("v", shape=(feature_dim, 1), init=norm_init, stype='row_sparse') + embed = mx.symbol.dot(x, v) + y = mx.symbol.Variable("softmax_label") + model = mx.symbol.LinearRegressionOutput(data=embed, label=y, name="out") + return model + +if __name__ == '__main__': + + # arg parser + args = parser.parse_args() + num_epoch = args.num_epoch + num_batch = args.num_batch + kvstore = args.kvstore + profiler = args.profiler > 0 + batch_size =
args.batch_size + dummy_iter = args.dummy_iter + dataset = args.dataset + log_level = args.log_level + + # create kvstore + kv = mx.kvstore.create(kvstore) + rank = kv.rank + num_worker = kv.num_workers + + # only print log for rank 0 worker + import logging + if rank != 0: + log_level = logging.ERROR + elif log_level == 'debug': + log_level = logging.DEBUG + else: + log_level = logging.INFO + head = '%(asctime)-15s %(message)s' + logging.basicConfig(level=log_level, format=head) + + # dataset + assert(dataset in datasets), "unknown dataset " + dataset + metadata = datasets[dataset] + feature_dim = metadata['feature_dim'] + logging.debug('preparing data ... ') + data_dir = os.path.join(os.getcwd(), 'data') + path = os.path.join(data_dir, metadata['data_name']) + if not os.path.exists(path): + get_libsvm_data(data_dir, metadata['data_name'], metadata['url'], + metadata['data_origin_name']) + assert os.path.exists(path) + + # data iterator + train_data = mx.io.LibSVMIter(data_libsvm=path, data_shape=(feature_dim,), + batch_size=batch_size, num_parts=num_worker, + part_index=rank) + if dummy_iter: + train_data = DummyIter(train_data) + + # model + model = regression_model(feature_dim) + + # module + mod = mx.mod.Module(symbol=model, data_names=['data'], label_names=['softmax_label']) + mod.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label) + mod.init_params(initializer=mx.init.Uniform(scale=.1)) + sgd = mx.optimizer.SGD(momentum=0.0, clip_gradient=5.0, + learning_rate=0.1, rescale_grad=1.0/batch_size/num_worker) + mod.init_optimizer(optimizer=sgd, kvstore=kv) + # use MSE as the metric + metric = mx.metric.create('MSE') + + # start profiler + if profiler: + name = 'profile_output_' + str(num_worker) + '.json' + mx.profiler.profiler_set_config(mode='all', filename=name) + mx.profiler.profiler_set_state('run') + + logging.debug('start training ...') + start = time.time() + data_iter = iter(train_data) + for epoch in range(num_epoch): + nbatch = 0 + end_of_batch = False + data_iter.reset() + metric.reset() + next_batch = next(data_iter) + while not end_of_batch: + nbatch += 1 + batch = next_batch + # TODO(haibin) remove extra copy after Jun's change + row_ids = batch.data[0].indices.copyto(mx.cpu()) + # pull sparse weight + index = mod._exec_group.param_names.index('v') + kv.row_sparse_pull('v', mod._exec_group.param_arrays[index], + priority=-index, row_ids=[row_ids]) + mod.forward_backward(batch) + # update parameters + mod.update() + try: + # prefetch the next batch + next_batch = next(data_iter) + if nbatch == num_batch: + raise StopIteration + except StopIteration: + end_of_batch = True + # accumulate the prediction metric + mod.update_metric(metric, batch.label) + logging.info('epoch %d, %s' % (epoch, metric.get())) + if profiler: + mx.profiler.profiler_set_state('stop') + end = time.time() + time_cost = end - start + logging.info('num_worker = ' + str(num_worker) + ', time cost = ' + str(time_cost)) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 608eec5591db..2db96aa3e859 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -1505,6 +1505,26 @@ MXNET_DLL int MXKVStorePullEx(KVStoreHandle handle, const char** keys, NDArrayHandle* vals, int priority); + +/*! + * \brief pull a list of (key, value) pairs from the kvstore, where each key is a string. + * The NDArray pulled back will be in row_sparse storage with only the specified + * row_ids present (other rows are zeros).
+ * \param handle handle to the kvstore + * \param num the number of key-value pairs + * \param keys the list of keys + * \param vals the list of values + * \param row_ids the list of row_id NDArrays + * \param priority the priority of the action + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXKVStorePullRowSparse(KVStoreHandle handle, + mx_uint num, + const char** keys, + NDArrayHandle* vals, + const NDArrayHandle* row_ids, + int priority); + /*! * \brief user-defined updater for the kvstore * It's this updater's responsibility to delete \a recv and \a local diff --git a/include/mxnet/kvstore.h b/include/mxnet/kvstore.h index a77f653d492c..7dc2217b6389 100644 --- a/include/mxnet/kvstore.h +++ b/include/mxnet/kvstore.h @@ -7,6 +7,7 @@ #define MXNET_KVSTORE_H_ #include #include +#include <utility> #include #include #include @@ -155,6 +156,29 @@ class KVStore { const std::vector& values, int priority = 0) = 0; + /*! + * \brief pull a list of key-value pairs from the store. + * The NDArray pulled back will be in row_sparse storage with only the + * specified row_ids present (other rows are zeros). + * \param keys the list of keys + * \param val_rowids the list of (buffer, row_id) pairs + * \param priority the priority of the action. + */ + virtual void PullRowSparse(const std::vector<int>& keys, + const std::vector<std::pair<NDArray*, NDArray>>& val_rowids, + const int priority = 0) = 0; + + /*! + * \brief pull a list of key-value pairs from the store, where each key is a string. + * The NDArray pulled back will be in row_sparse storage with only the + * specified row_ids present (other rows are zeros). + * \param str_keys the list of keys in string format + * \param val_rowids the list of (buffer, row_id) pairs + * \param priority the priority of the action. + */ + virtual void PullRowSparse(const std::vector<std::string>& str_keys, + const std::vector<std::pair<NDArray*, NDArray>>& val_rowids, + const int priority = 0) = 0; /** * \brief the prototype of user-defined updater diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h index cbe815eb2c24..3d46e69a4565 100644 --- a/include/mxnet/ndarray.h +++ b/include/mxnet/ndarray.h @@ -182,12 +182,13 @@ class NDArray { return shape_; } /*! - * \return the shape of underlying chunk which stores the NDArray values. - * For default storage, it is the same as shape(). For row-sparse storage, it is the shape of + * \return the shape of underlying chunk which stores the NDArray data/value. + * It is only intended for non-default storage. For row-sparse storage, it is the shape of + * the tensor which stores the non-zero values.
*/ inline const TShape &storage_shape() const { CHECK(ptr_ != nullptr); + CHECK_NE(storage_type(), kDefaultStorage); return ptr_->storage_shape; } @@ -271,7 +272,11 @@ class NDArray { if (is_none()) return false; auto stype = storage_type(); CHECK_NE(stype, kDefaultStorage); - if (stype == kRowSparseStorage || stype == kCSRStorage) { + if (stype == kRowSparseStorage) { + CHECK_EQ(aux_shape(rowsparse::kIdx)[0], storage_shape()[0]); + return aux_shape(0).Size() != 0; + } else if (stype == kCSRStorage) { + CHECK_EQ(aux_shape(csr::kIdx)[0], storage_shape()[0]); return aux_shape(0).Size() != 0; } else { LOG(FATAL) << "Unknown storage type"; diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py index 655f602856da..8d96c751ccb3 100644 --- a/python/mxnet/kvstore.py +++ b/python/mxnet/kvstore.py @@ -5,58 +5,47 @@ import ctypes import pickle from .ndarray import NDArray +from .ndarray import _ndarray_cls from .base import _LIB from .base import check_call, c_array, c_str, string_types, mx_uint, py_str from .base import NDArrayHandle, KVStoreHandle from . import optimizer as opt -def _ctype_str_key_value(keys, vals): - names = [] - if isinstance(keys, str): - if isinstance(vals, NDArray): - names.append(c_str(keys)) - return (c_array(ctypes.c_char_p, names), - c_array(NDArrayHandle, [vals.handle])) - else: - for value in vals: - assert(isinstance(value, NDArray)) - return (c_array(ctypes.c_char_p, [c_str(keys)] * len(vals)), - c_array(NDArrayHandle, [value.handle for value in vals])) - else: +def _ctype_key_value(keys, vals): + if isinstance(keys, (tuple, list)): assert(len(keys) == len(vals)) - for k in keys: - assert(isinstance(k, str)) c_keys = [] c_vals = [] for key, val in zip(keys, vals): - c_key_i, c_val_i = _ctype_str_key_value(key, val) + c_key_i, c_val_i = _ctype_key_value(key, val) c_keys += c_key_i c_vals += c_val_i return (c_array(ctypes.c_char_p, c_keys), c_array(NDArrayHandle, c_vals)) - -def _cast_to_str_keys(keys): - if isinstance(keys, str): - return keys - if isinstance(keys, int): - return str(keys) - str_keys = [] - for key in keys: - str_keys.append(str(key) if isinstance(key, int) else key) - return str_keys + names = [] + keys = str(keys) + if isinstance(vals, NDArray): + names.append(c_str(keys)) + return (c_array(ctypes.c_char_p, names), + c_array(NDArrayHandle, [vals.handle])) + else: + for value in vals: + assert(isinstance(value, NDArray)) + return (c_array(ctypes.c_char_p, [c_str(keys)] * len(vals)), + c_array(NDArrayHandle, [value.handle for value in vals])) def _updater_wrapper(updater): """A wrapper for the user-defined handle.""" def updater_handle(key, lhs_handle, rhs_handle, _): """ ctypes function """ - lhs = NDArray(NDArrayHandle(lhs_handle)) - rhs = NDArray(NDArrayHandle(rhs_handle)) + lhs = _ndarray_cls(NDArrayHandle(lhs_handle)) + rhs = _ndarray_cls(NDArrayHandle(rhs_handle)) updater(key, lhs, rhs) return updater_handle class KVStore(object): """A key-value store for synchronization of values, over multiple devices.""" - def __init__(self, handle, name2idx=None): + def __init__(self, handle): """Initializes a new KVStore. 
Parameters @@ -66,7 +55,6 @@ def __init__(self, handle, name2idx=None): """ assert isinstance(handle, KVStoreHandle) self.handle = handle - self.name2idx = name2idx if name2idx is not None else {} self._updater = None self._updater_func = None @@ -104,8 +92,7 @@ def init(self, key, value): >>> keys = ['5', '7', '9'] >>> kv.init(keys, [mx.nd.ones(shape)]*len(keys)) """ - key = _cast_to_str_keys(key) - ckeys, cvals = _ctype_str_key_value(key, value) + ckeys, cvals = _ctype_key_value(key, value) check_call(_LIB.MXKVStoreInitEx(self.handle, mx_uint(len(ckeys)), ckeys, cvals)) def push(self, key, value, priority=0): @@ -165,8 +152,7 @@ def push(self, key, value, priority=0): [[ 4. 4. 4.] [ 4. 4. 4.]] """ - key = _cast_to_str_keys(key) - ckeys, cvals = _ctype_str_key_value(key, value) + ckeys, cvals = _ctype_key_value(key, value) check_call(_LIB.MXKVStorePushEx( self.handle, mx_uint(len(ckeys)), ckeys, cvals, ctypes.c_int(priority))) @@ -184,6 +170,8 @@ def pull(self, key, out=None, priority=0): The returned values are gauranteed to be the latest values in the store. + For row_sparse values, please use `row_sparse_pull` instead. + Parameters ---------- key : int or list of int @@ -229,12 +217,88 @@ def pull(self, key, out=None, priority=0): [ 2. 2. 2.]] """ assert(out is not None) - key = _cast_to_str_keys(key) - ckeys, cvals = _ctype_str_key_value(key, out) + if not isinstance(out, (list, tuple)): + out = [out] + for val in out: + if not isinstance(val, (list, tuple)): + assert(val.stype == 'default') + else: + for v in val: + assert(v.stype == 'default') + ckeys, cvals = _ctype_key_value(key, out) check_call(_LIB.MXKVStorePullEx( self.handle, mx_uint(len(ckeys)), ckeys, cvals, ctypes.c_int(priority))) + def row_sparse_pull(self, key, out=None, priority=0, row_ids=None): + """ Pulls a single row_sparse value or a sequence of row_sparse values from the store + with specified row_ids. + + `row_sparse_pull` is executed asynchronously after all previous + `push`/`pull`/`row_sparse_pull` calls for the same input key(s) are finished. + + The returned values are guaranteed to be the latest values in the store. + + Parameters + ---------- + key : str or list of str + Keys. + + out: NDArray or list of NDArray or list of list of NDArray + Values corresponding to the keys. The stype is expected to be row_sparse + + priority : int, optional + The priority of the pull operation. + Higher priority pull operations are likely to be executed before + other pull actions. + + row_ids : NDArray or list of NDArray + The row_ids for which to pull for each value. The row_ids doesn't have to be unique + or sorted. + Examples + -------- + >>> shape = (3, 3) + >>> kv.init('3', mx.nd.ones(shape)._to_rsp()) + >>> a = mx.nd.zeros(shape) + >>> row_ids = mx.nd.array([0, 2], dtype='int64') + >>> kv.row_sparse_pull('3', out=a, row_ids=row_ids) + >>> print a.asnumpy() + [[ 1. 1. 1.] + [ 0. 0. 0.] + [ 1. 1. 1.]] + >>> duplicate_row_ids = mx.nd.array([2, 2], dtype='int64') + >>> kv.row_sparse_pull('3', out=a, row_ids=duplicate_row_ids) + >>> print a.asnumpy() + [[ 0. 0. 0.] + [ 0. 0. 0.] + [ 1. 1. 1.]] + >>> unsorted_row_ids = mx.nd.array([1, 0], dtype='int64') + >>> kv.row_sparse_pull('3', out=a, row_ids=unsorted_row_ids) + >>> print a.asnumpy() + [[ 1. 1. 1.] + [ 1. 1. 1.] + [ 0. 0. 
0.]] """ + assert(out is not None) + assert(row_ids is not None) + if isinstance(row_ids, NDArray): + row_ids = [row_ids] + if not isinstance(out, (list, tuple)): + out = [out] + for val in out: + if not isinstance(val, (list, tuple)): + assert(val.stype == 'row_sparse') + else: + for v in val: + assert(v.stype == 'row_sparse') + ckeys, cvals = _ctype_key_value(key, out) + _, crow_ids = _ctype_key_value(key, row_ids) + assert(len(crow_ids) == len(cvals)), "number of row_ids doesn't match number of values" + + check_call(_LIB.MXKVStorePullRowSparse( + self.handle, mx_uint(len(ckeys)), ckeys, cvals, crow_ids, ctypes.c_int(priority))) + + def set_optimizer(self, optimizer): """ Registers an optimizer with the kvstore. @@ -407,7 +471,7 @@ def _send_command_to_servers(self, head, body): check_call(_LIB.MXKVStoreSendCommmandToServers( self.handle, mx_uint(head), c_str(body))) -def create(name='local', name2idx=None): +def create(name='local'): """Creates a new KVStore. For single machine training, there are two commonly used types: @@ -447,4 +511,4 @@ handle = KVStoreHandle() check_call(_LIB.MXKVStoreCreate(c_str(name), ctypes.byref(handle))) - return KVStore(handle, name2idx=name2idx) + return KVStore(handle) diff --git a/python/mxnet/model.py b/python/mxnet/model.py index c91ef5474601..e30f9f332c8c 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -37,7 +37,7 @@ 'eval_metric', 'locals']) -def _create_kvstore(kvstore, num_device, arg_params, name2idx=None): +def _create_kvstore(kvstore, num_device, arg_params): """Create kvstore This function selects and creates a proper kvstore given the kvstore type. @@ -61,7 +61,7 @@ # no need to use kv for single device and single machine kv = None else: - kv = kvs.create(kvstore, name2idx=name2idx) + kv = kvs.create(kvstore) if kvstore == 'local': # automatically select a proper local max_size = max(np.prod(param.shape) for param in @@ -76,15 +76,29 @@ return (kv, update_on_kvstore) -def _initialize_kvstore(kvstore, param_arrays, arg_params, param_names, - update_on_kvstore): +def _contains_non_default_storage(params): + if isinstance(params, (list, tuple)): + for param in params: + if param.stype != 'default': + return True + elif isinstance(params, NDArray): + return params.stype != 'default' + else: + return False + +def _initialize_kvstore(kvstore, param_arrays, arg_params, param_names, update_on_kvstore): """Initialize kvstore""" for idx, param_on_devs in enumerate(param_arrays): name = param_names[idx] kvstore.init(name, arg_params[name]) if update_on_kvstore: - kvstore.pull(name, param_on_devs, priority=-idx) + if _contains_non_default_storage(param_on_devs): + # skip pulling row_sparse weights + warnings.warn('Detected non-default weight in kvstore to pull.
Please make ' \ 'sure to pull it with row_ids explicitly', RuntimeWarning) + else: + kvstore.pull(name, param_on_devs, priority=-idx) def _update_params_on_kvstore(param_arrays, grad_arrays, kvstore, param_names): """Perform update of param_arrays from grad_arrays on kvstore.""" @@ -96,7 +110,12 @@ def _update_params_on_kvstore(param_arrays, grad_arrays, kvstore, param_names): # push gradient, priority is negative index kvstore.push(name, grad_list, priority=-index) # pull back the weights - kvstore.pull(name, arg_list, priority=-index) + if _contains_non_default_storage(arg_list): + # skip pulling row_sparse weights + warnings.warn('Detected non-default weight in kvstore to pull. Please make ' \ 'sure to pull it with row_ids', RuntimeWarning) + else: + kvstore.pull(name, arg_list, priority=-index) def _update_params(param_arrays, grad_arrays, updater, num_device, kvstore=None, param_names=None): @@ -111,7 +130,12 @@ # push gradient, priority is negative index kvstore.push(name, grad_list, priority=-index) # pull back the summed gradients, to the same locations. - kvstore.pull(name, grad_list, priority=-index) + if _contains_non_default_storage(grad_list): + # skip pulling row_sparse gradients + warnings.warn('Detected non-default gradient in kvstore to pull. Please make ' \ 'sure to pull it with row_ids', RuntimeWarning) + else: + kvstore.pull(name, grad_list, priority=-index) for k, p in enumerate(zip(arg_list, grad_list)): # faked an index here, to make optimizer create diff # state for the same index but on diff devs, TODO(mli) diff --git a/python/mxnet/module/base_module.py b/python/mxnet/module/base_module.py index 820841087a9c..05076cec46b7 100644 --- a/python/mxnet/module/base_module.py +++ b/python/mxnet/module/base_module.py @@ -933,7 +933,8 @@ def bind(self, data_shapes, label_shapes=None, for_training=True, def init_optimizer(self, kvstore='local', optimizer='sgd', optimizer_params=(('learning_rate', 0.01),), force_init=False): - """Installs and initializes optimizers. + """Installs and initializes optimizers, as well as initializing the kvstore for + distributed training. Parameters ---------- diff --git a/python/mxnet/module/module.py b/python/mxnet/module/module.py index 5288a32c3384..1594665bf5ef 100644 --- a/python/mxnet/module/module.py +++ b/python/mxnet/module/module.py @@ -453,12 +453,9 @@ def init_optimizer(self, kvstore='local', optimizer='sgd', if self._params_dirty: self._sync_params_from_devices() - name2idx = {} - for idx, name in enumerate(self._exec_group.param_names): - name2idx[name] = idx (kvstore, update_on_kvstore) = \ - _create_kvstore(kvstore, len(self._context), self._arg_params, name2idx=name2idx) + _create_kvstore(kvstore, len(self._context), self._arg_params) batch_size = self._exec_group.batch_size if kvstore and 'dist' in kvstore.type and '_sync' in kvstore.type: diff --git a/python/mxnet/ndarray/sparse_ndarray.py b/python/mxnet/ndarray/sparse_ndarray.py index bd98e58e9547..83f0683431ce 100644 --- a/python/mxnet/ndarray/sparse_ndarray.py +++ b/python/mxnet/ndarray/sparse_ndarray.py @@ -88,19 +88,24 @@ class SparseNDArray(NDArray): for more details.
""" def __iadd__(self, other): - raise NotImplementedError("SparseND doesn't support __iadd__") + (self + other).copyto(self) + return self def __isub__(self, other): - raise NotImplementedError("SparseND doesn't support __isub__") + (self - other).copyto(self) + return self def __imul__(self, other): - raise NotImplementedError("SparseND doesn't support __imul__") + (self * other).copyto(self) + return self def __idiv__(self, other): - raise NotImplementedError("SparseND doesn't support __idiv__") + (self / other).copyto(self) + return self def __itruediv__(self, other): - raise NotImplementedError("SparseND doesn't support __itruediv__") + (self / other).copyto(self) + return self def __setitem__(self, key, value): """x.__setitem__(i, y) <=> x[i]=y @@ -179,14 +184,13 @@ def __getitem__(self, key): array([[ 3., 4., 5.]], dtype=float32) """ stype = self.stype - if stype != 'csr': - raise Exception("__getitem__ for " + str(stype) + " is not implemented yet") if isinstance(key, int): raise Exception("__getitem__ with int key is not implemented yet") if isinstance(key, py_slice): if key.step is not None: raise ValueError('NDArray only supports continuous slicing on axis 0') if key.start is not None or key.stop is not None: + assert(stype == 'csr'), "__getitem__ with slice is only implemented for CSRNDArray" begin = key.start if key.start else 0 end = key.stop if key.stop else self.shape[0] return nd_slice(self, begin=begin, end=end) diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 1f7335a2928b..5f176dc647b0 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -802,6 +802,24 @@ int MXKVStorePullEx(KVStoreHandle handle, API_END(); } +int MXKVStorePullRowSparse(KVStoreHandle handle, + mx_uint num, + const char** keys, + NDArrayHandle* vals, + const NDArrayHandle* row_ids, + int priority) { + API_BEGIN(); + std::vector v_keys(num); + std::vector> v_val_rowids(num); + for (mx_uint i = 0; i < num; ++i) { + v_keys[i] = keys[i]; + v_val_rowids[i] = std::make_pair(static_cast(vals[i]), + *static_cast(row_ids[i])); + } + static_cast(handle)->PullRowSparse(v_keys, v_val_rowids, priority); + API_END(); +} + int MXKVStoreSetUpdater(KVStoreHandle handle, MXKVStoreUpdater updater, void* updater_handle) { diff --git a/src/kvstore/comm.h b/src/kvstore/comm.h index 3d522c83efac..46c83a7441e9 100644 --- a/src/kvstore/comm.h +++ b/src/kvstore/comm.h @@ -13,6 +13,7 @@ #include #include "mxnet/ndarray.h" #include "../ndarray/ndarray_function.h" +#include "../operator/tensor/indexing_op.h" namespace mxnet { namespace kvstore { /** @@ -48,6 +49,18 @@ class Comm { int key, const NDArray& src, const std::vector dst, int priority) = 0; + /** + * \brief broadcast src to dst[i] with target row_ids for every i + * \param dst a list of destination row_sparse NDArray and its target row_ids to broadcast, + where the row_ids are expected to be unique and sorted + * \param use_copy if set to true, directly copy src to dst[i] without looking up the + provided row_ids + */ + virtual void BroadcastRowSparse(int key, const NDArray& src, + const std::vector>& dst, + const bool use_copy, + const int priority) = 0; + /** * \brief return a pinned contex */ @@ -164,6 +177,38 @@ class CommCPU : public Comm { } } + // TODO(haibin) support broadcast row_sparse on GPU + void BroadcastRowSparse(int key, const NDArray& src, + const std::vector>& dst, + const bool use_copy, + const int priority) override { + using namespace mshadow; + auto size = dst.size(); + for (size_t i = 0; i < size; i++) { + auto out = dst[i].first; + 
+ auto row_id = dst[i].second; + if (use_copy) { + CopyFromTo(src, out, priority); + } else { + CHECK_EQ(out->storage_type(), kRowSparseStorage) + << "BroadcastRowSparse expects row_sparse dst NDArray"; + CHECK_EQ(out->ctx().dev_mask(), Context::kCPU) + << "BroadcastRowSparse with dst on gpu context not supported"; + CHECK_EQ(row_id.ctx().dev_mask(), Context::kCPU) + << "BroadcastRowSparse with row_ids on gpu context not supported"; + // retain according to unique indices + Engine::Get()->PushSync([src, out, row_id](RunContext rctx) { + NDArray *output = out; + const auto indices = row_id.data(); + op::SparseRetainOpForwardRspImpl<cpu>(rctx.get_stream<cpu>(), + src, indices, kWriteTo, + output); + }, Context::CPU(), {src.var(), row_id.var()}, {out->var()}, + FnProperty::kNormal, priority, PROFILER_MESSAGE("KVStoreSparseRetain")); + } + } + } + private: // reduce sum into val[0] inline void ReduceSumCPU(const std::vector<NDArray> &in_data) { @@ -180,7 +225,6 @@ } // serial implementation of reduce sum for row sparse NDArray. - // TODO(haibin) use openmp kernel to parallelize the summation inline void ReduceSumCPUExSerial(const std::vector<NDArray> &in, NDArray *out) { using namespace rowsparse; using namespace mshadow; @@ -410,6 +454,13 @@ class CommDevice : public Comm { } } + void BroadcastRowSparse(int key, const NDArray& src, + const std::vector<std::pair<NDArray*, NDArray>>& dst, + const bool use_copy, + const int priority) override { + LOG(FATAL) << "Not implemented yet"; + } + private: void EnableP2P(const std::vector<Context>& devs) { #if MXNET_USE_CUDA diff --git a/src/kvstore/kvstore_dist.h b/src/kvstore/kvstore_dist.h index 59d9158012ef..9a3bc31f4ac4 100644 --- a/src/kvstore/kvstore_dist.h +++ b/src/kvstore/kvstore_dist.h @@ -7,11 +7,12 @@ #define MXNET_KVSTORE_KVSTORE_DIST_H_ #include #include +#include <algorithm> +#include <utility> #include "./kvstore_local.h" #include "mxnet/engine.h" #include "ps/ps.h" #include "./kvstore_dist_server.h" -#include "../operator/tensor/init_op.h" #if MKL_EXPERIMENTAL == 1 #include #include "../operator/mkl/mkl_memory-inl.h" @@ -43,7 +44,7 @@ class KVStoreDist : public KVStoreLocal { } } bigarray_bound_ = dmlc::GetEnv("MXNET_KVSTORE_BIGARRAY_BOUND", 1000 * 1000); - row_sparse_verbose_ = dmlc::GetEnv("MXNET_KVSTORE_DIST_ROW_SPARSE_VERBOSE", false); + log_verbose_ = dmlc::GetEnv("MXNET_KVSTORE_DIST_ROW_SPARSE_VERBOSE", false); } virtual ~KVStoreDist() { @@ -100,52 +101,82 @@ class KVStoreDist : public KVStoreLocal { // after the previous push on this key auto& recv_buf = comm_buf_[key]; const auto storage_type = grouped_vals[i][0]->storage_type(); + CHECK_EQ(storage_type, kDefaultStorage) + << "Expected stype of value to be kDefaultStorage"; if (recv_buf.is_none()) { // this may happen the first time a non-rank-0 worker pulls the weight.
- if (storage_type == kDefaultStorage) { - recv_buf = NDArray(grouped_vals[i][0]->shape(), pinned_ctx_, - false, grouped_vals[i][0]->dtype()); - } else { - recv_buf = NDArray(storage_type, grouped_vals[i][0]->shape(), - pinned_ctx_, true, grouped_vals[i][0]->dtype()); - // initialize the buffer with sufficient memory - op::FillDnsZerosRspImpl(nullptr, &recv_buf); - } + recv_buf = NDArray(grouped_vals[i][0]->shape(), pinned_ctx_, + true, grouped_vals[i][0]->dtype()); } - if (storage_type == kDefaultStorage) { #if MKL_EXPERIMENTAL == 1 - mkl_set_tblob_eager_mode(recv_buf.data()); + mkl_set_tblob_eager_mode(recv_buf.data()); #endif - real_t* data = static_cast<real_t*>(recv_buf.data().dptr_); - size_t size = recv_buf.shape().Size(); - auto pull_from_servers = [this, key, data, size]( - RunContext rctx, Engine::CallbackOnComplete cb) { - // convert to ps keys - PSKV& pskv = EncodeKey(key, size); + real_t* data = static_cast<real_t*>(recv_buf.data().dptr_); + size_t size = recv_buf.shape().Size(); + auto pull_from_servers = [this, key, data, size]( RunContext rctx, Engine::CallbackOnComplete cb) { + // convert to ps keys + PSKV& pskv = EncodeKey(key, size); + + // issue pull, false means no delete + auto vals = new ps::SArray<real_t>(data, size, false); + CHECK_NOTNULL(ps_worker_)->ZPull( pskv.keys, vals, &pskv.lens, kDefaultPushPull, [vals, cb](){ delete vals; cb(); }); + }; + + CHECK_NOTNULL(Engine::Get())->PushAsync( pull_from_servers, pinned_ctx_, {}, {recv_buf.var()}, FnProperty::kNormal, priority, PROFILER_MESSAGE("KVStoreDistDefaultPull")); - // issue pull, false means no delete - auto vals = new ps::SArray<real_t>(data, size, false); - CHECK_NOTNULL(ps_worker_)->ZPull( pskv.keys, vals, &pskv.lens, kDefaultPushPull, [vals, cb](){ delete vals; cb(); }); - }; + comm_->Broadcast(key, recv_buf, grouped_vals[i], priority); + } + } - CHECK_NOTNULL(Engine::Get())->PushAsync( - pull_from_servers, - pinned_ctx_, - {}, - {recv_buf.var()}, - FnProperty::kNormal, - priority, - PROFILER_MESSAGE("KVStoreDistDefaultPull")); - } else if (storage_type == kRowSparseStorage) { - recv_buf.WaitToRead(); - grouped_vals[i][0]->WaitToRead(); - PullRowSparse(key, &recv_buf, grouped_vals[i][0]->aux_ndarray(rowsparse::kIdx), priority); + void PullRowSparse(const std::vector<int>& keys, + const std::vector<std::pair<NDArray*, NDArray>>& val_rowids, + const int priority = 0) { + std::vector<int> uniq_keys; + std::vector<std::vector<std::pair<NDArray*, NDArray>>> grouped_val_rowids; + GroupKVPairs(keys, val_rowids, &uniq_keys, &grouped_val_rowids); + + for (size_t i = 0; i < uniq_keys.size(); ++i) { + int key = uniq_keys[i]; + // use the same array for merging to guarantee that pull always happens + // after the previous push on this key + auto& recv_buf = comm_buf_[key]; + auto& grouped_val_rowid = grouped_val_rowids[i]; + const auto storage_type = grouped_val_rowid[0].first->storage_type(); + CHECK_EQ(storage_type, kRowSparseStorage) + << "expected kRowSparseStorage, but got " << storage_type; + if (recv_buf.is_none()) { + // this may happen the first time a non-rank-0 worker pulls the weight.
+ recv_buf = NDArray(storage_type, grouped_val_rowid[0].first->shape(), + pinned_ctx_, true, grouped_val_rowid[0].first->dtype()); + } + auto &target_val_rowids = grouped_val_rowids[i]; + const size_t num_vals = target_val_rowids.size(); + size_t num_rows = 0; + // TODO(haibin) refactor this for loop + for (size_t i = 0; i < num_vals; i++) { + auto &row_id = target_val_rowids[i].second; + NDArray indices = row_id.Copy(pinned_ctx_); + Unique(&indices, priority); + target_val_rowids[i].second = indices; + num_rows += indices.shape().Size(); + } + if (num_vals > 1) { + // TODO(haibin) aggregate over all unique indices + LOG(FATAL) << "RowSparsePull with multiple values is not implemented yet"; } else { - LOG(FATAL) << "unknown storage type " << storage_type; + auto& indices = target_val_rowids[0].second; + PullRowSparse_(key, &recv_buf, indices, priority); + comm_->BroadcastRowSparse(key, recv_buf, grouped_val_rowid, num_vals == 1, priority); } - - comm_->Broadcast(key, recv_buf, grouped_vals[i], priority); } } @@ -223,6 +254,8 @@ class KVStoreDist : public KVStoreLocal { auto& send_buf = comm_buf_[key]; const auto storage_type = merged.storage_type(); if (merged.ctx().dev_mask() == cpu::kDevMask) { + // make sure the previous push/pull is completed + send_buf.WaitToWrite(); send_buf = merged; // avoid memory copy } else { if (send_buf.is_none()) { @@ -230,8 +263,6 @@ class KVStoreDist : public KVStoreLocal { send_buf = NDArray(merged.shape(), pinned_ctx_, false, merged.dtype()); } else { send_buf = NDArray(storage_type, merged.shape(), pinned_ctx_, true, merged.dtype()); - // initialize the buffer with sufficient memory - op::FillDnsZerosRspImpl(nullptr, &send_buf); } } CopyFromTo(merged, &send_buf); @@ -271,29 +302,34 @@ class KVStoreDist : public KVStoreLocal { } // pull row sparse weight into `recv_buf` based on indices given by `indices` - void PullRowSparse(int key, NDArray *recv_buf, const NDArray indices, int priority) { + void PullRowSparse_(int key, NDArray *recv_buf, const NDArray& indices, int priority) { using namespace rowsparse; - auto pull_from_servers = [this, key, recv_buf, &indices] + auto pull_from_servers = [this, key, recv_buf, indices] (RunContext rctx, Engine::CallbackOnComplete cb) { - // reading aux_shape & aux_data should be inside the engine + // allocate memory for the buffer size_t num_rows = indices.shape().Size(); recv_buf->CheckAndAlloc({mshadow::Shape1(num_rows)}); #if MKL_EXPERIMENTAL == 1 mkl_set_tblob_eager_mode(recv_buf->data()); #endif real_t* data = static_cast<real_t*>(recv_buf->data().dptr_); - const auto offsets = indices.data().dptr<int64_t>(); + auto indices_data = indices.data(); + const auto offsets = indices_data.dptr<int64_t>(); const auto unit_len = recv_buf->shape().ProdShape(1, recv_buf->shape().ndim()); - size_t size = num_rows * unit_len; + const int64_t size = num_rows * unit_len; // convert to ps keys in row sparse format - PSKV& pskv = EncodeRowSparseKey(key, size, num_rows, offsets, unit_len); - if (this->row_sparse_verbose_) { - LOG(INFO) << "pull lens: " << pskv.lens << " keys: " << pskv.keys - << " size: " << size; + PSKV& pskv = EncodeRowSparseKey(key, size, num_rows, offsets, + unit_len, recv_buf->shape()[0]); + if (this->log_verbose_) { + LOG(INFO) << "worker " << get_rank() << " pull lens: " << pskv.lens << " keys: " + << pskv.keys << " size: " << size; } auto vals = new ps::SArray<real_t>(data, size, false); CHECK_NOTNULL(ps_worker_)->ZPull(pskv.keys, vals, &pskv.lens, kRowSparsePushPull, [vals, cb]() { delete vals; cb(); }); + // copy indices to recv_buf +
mshadow::Copy(recv_buf->aux_data(kIdx).FlatTo1D<cpu, int64_t>(), + indices_data.FlatTo1D<cpu, int64_t>()); }; CHECK_NOTNULL(Engine::Get())->PushAsync( pull_from_servers, pinned_ctx_, {indices.var()}, {recv_buf->var()}, FnProperty::kNormal, priority, PROFILER_MESSAGE("KVStoreDistRowSparsePull")); - recv_buf->WaitToRead(); - // copy indices pulled - auto recv_buf_idx = recv_buf->aux_ndarray(kIdx); - CopyFromTo(indices, &recv_buf_idx); } // push row sparse gradient @@ -318,17 +350,18 @@ class KVStoreDist : public KVStoreLocal { mkl_set_tblob_eager_mode(send_buf.data()); #endif real_t* data = static_cast<real_t*>(send_buf.data().dptr_); - if (!send_buf.storage_initialized()) return; - size_t num_rows = send_buf.aux_shape(kIdx).Size(); - const auto offsets = send_buf.aux_data(kIdx).dptr<int64_t>(); + bool init = send_buf.storage_initialized(); + const int64_t num_rows = init ? send_buf.aux_shape(kIdx)[0] : 0; + const auto offsets = init ? send_buf.aux_data(kIdx).dptr<int64_t>() : nullptr; const auto unit_len = send_buf.shape().ProdShape(1, send_buf.shape().ndim()); - const auto size = num_rows * unit_len; + const int64_t size = num_rows * unit_len; // convert to ps keys in row sparse format - PSKV& pskv = EncodeRowSparseKey(key, size, num_rows, offsets, unit_len); - if (this->row_sparse_verbose_) { - LOG(INFO) << "push lens: " << pskv.lens << " keys: " << pskv.keys - << " size: " << size; + PSKV& pskv = EncodeRowSparseKey(key, size, num_rows, offsets, + unit_len, send_buf.shape()[0]); + if (this->log_verbose_) { + LOG(INFO) << "worker " << get_rank() << " push lens: " << pskv.lens << " keys: " + << pskv.keys << " size: " << size; } ps::SArray<real_t> vals(data, size, false); CHECK_NOTNULL(ps_worker_)->ZPush(pskv.keys, vals, pskv.lens, kRowSparsePushPull, [cb]() { @@ -417,8 +450,11 @@ class KVStoreDist : public KVStoreLocal { return pskv; } - inline PSKV& EncodeRowSparseKey(int key, size_t size, int64_t num_rows, - const int64_t *offsets, size_t unit_len) { + // TODO(haibin) this encoding method for row sparse keys doesn't allow cross-layer batching + inline PSKV& EncodeRowSparseKey(const int key, const int64_t size, const int64_t num_rows, + const int64_t *offsets, const size_t unit_len, + const int64_t total_num_rows) { + using namespace common; mu_.lock(); PSKV& pskv = ps_kv_[key]; mu_.unlock(); @@ -429,21 +465,45 @@ class KVStoreDist : public KVStoreLocal { int num_servers = krs.size(); CHECK_GT(num_servers, 0); - if (size >= bigarray_bound_ && row_sparse_verbose_) { - LOG(INFO) << "WARNING: big row_sparse weight array sharding is not implemented"; - } - // send it to a single random picked server - int server = (key * 9973) % num_servers; - ps::Key master_key = krs[server].begin() + key; - pskv.keys.push_back(master_key); - pskv.lens.push_back(0); - for (int64_t i = 0; i < num_rows; i++) { - ps::Key ps_key = krs[server].begin() + key + offsets[i]; - CHECK_LT(ps_key, krs[server].end()); - pskv.keys.push_back(ps_key); - pskv.lens.push_back(unit_len); + if (total_num_rows * unit_len >= bigarray_bound_) { + pskv.size = 0; + int64_t start_row = 0; + // partition it across all servers + for (int i = 0; i < num_servers; ++i) { + // calculate partition ranges + int64_t part_num_rows = + llround(static_cast<double>(total_num_rows) / num_servers * (i + 1)) - + llround(static_cast<double>(total_num_rows) / num_servers * i); + auto end_row = start_row + part_num_rows; + auto lb = std::lower_bound(offsets, offsets + num_rows, start_row); + auto ub = std::upper_bound(offsets, offsets + num_rows, end_row - 1); + ps::Key master_key = krs[i].begin() + key;
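+ // each server partition leads with a zero-length master key; one ps key per retained row follows, each carrying unit_len values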
+ pskv.keys.push_back(master_key); + pskv.lens.push_back(0); + for (auto offset = lb; offset < ub; offset++) { + ps::Key ps_key = krs[i].begin() + key + (*offset - start_row); + CHECK_LT(ps_key, krs[i].end()); + pskv.keys.push_back(ps_key); + pskv.lens.push_back(unit_len); + pskv.size += unit_len; + } + start_row = end_row; + } + CHECK_EQ(static_cast<int64_t>(pskv.size), size); + } else { + // send it to a single randomly picked server + int server = (key * 9973) % num_servers; + ps::Key master_key = krs[server].begin() + key; + pskv.keys.push_back(master_key); + pskv.lens.push_back(0); + for (int64_t i = 0; i < num_rows; i++) { + ps::Key ps_key = krs[server].begin() + key + offsets[i]; + CHECK_LT(ps_key, krs[server].end()); + pskv.keys.push_back(ps_key); + pskv.lens.push_back(unit_len); + } + pskv.size = size; } - pskv.size = size; return pskv; } @@ -462,7 +522,7 @@ class KVStoreDist : public KVStoreLocal { size_t bigarray_bound_; /// \brief send & recver buffer std::unordered_map<int, NDArray> comm_buf_; - bool row_sparse_verbose_; + bool log_verbose_; }; } // namespace kvstore diff --git a/src/kvstore/kvstore_dist_server.h b/src/kvstore/kvstore_dist_server.h index 59d2cb705654..9d1cd3c0bed5 100644 --- a/src/kvstore/kvstore_dist_server.h +++ b/src/kvstore/kvstore_dist_server.h @@ -15,6 +15,7 @@ #include #include "ps/ps.h" #include "mxnet/kvstore.h" +#include "../operator/tensor/elemwise_binary_op.h" namespace mxnet { namespace kvstore { @@ -96,6 +97,7 @@ class KVStoreDistServer { ps_server_->set_request_handle( std::bind(&KVStoreDistServer::DataHandleEx, this, _1, _2, _3)); sync_mode_ = false; + log_verbose_ = dmlc::GetEnv("MXNET_KVSTORE_DIST_ROW_SPARSE_VERBOSE", false); } ~KVStoreDistServer() { @@ -120,6 +122,11 @@ } private: + struct MergeBuf { + std::vector<ps::KVMeta> request; + NDArray array; + }; + void CommandHandle(const ps::SimpleData& recved, ps::SimpleApp* app) { if (recved.head == kStopServer) { exec_.Stop(); @@ -146,19 +153,40 @@ return; } - inline void MergeUpdates(const NDArray& recved, int key, - std::unordered_set<int> *change_set) { - auto& merged = merge_buf_[key]; - if (merged.is_none()) { - merged = NDArray(recved.shape(), Context()); - } - if (change_set->find(key) == change_set->end()) { - CopyFromTo(recved, &merged, 0); + inline void ApplyUpdates(const int key, MergeBuf *merged, NDArray *stored, + ps::KVServer<real_t>* server) { + if (merged->request.size() == (size_t) ps::NumWorkers()) { + // let the main thread execute updater_, which is necessary for python + if (updater_) { + exec_.Exec([this, key, merged, stored](){ + CHECK(updater_); + updater_(key, merged->array, stored); + }); + } else { + // if no updater, just copy + CopyFromTo(merged->array, stored); + } + if (log_verbose_) { + LOG(INFO) << "sync response to " << merged->request.size() << " workers"; + } + for (const auto& req : merged->request) { + server->Response(req); + } + merged->request.clear(); + stored->WaitToRead(); } else { - // TODO(haibin) handle row sparse gradient NDArray with `ReduceSumCPUExParallel` - merged += recved; + merged->array.WaitToRead(); + } + } + + void DecodeRowIds(const ps::SArray<ps::Key> &keys, int64_t *indices, + const int64_t master_key, const int64_t num_rows) { + indices[0] = 0; + for (int64_t i = 1; i <= num_rows; i++) { + int key = DecodeKey(keys[i]); + auto row_id = key - master_key; + indices[i - 1] = row_id; } - change_set->insert(key); } void DataHandleRowSparse(const ps::KVMeta& req_meta, const ps::KVPairs<real_t>& req_data, ps::KVServer<real_t>* server) { int
master_key = DecodeKey(req_data.keys[0]); auto num_rows = req_data.keys.size() - 1; + auto& stored = store_[master_key]; if (req_meta.push) { + CHECK_GT(req_data.lens.size(), 0) << "req_data.lens cannot be empty"; CHECK_EQ(req_data.lens[0], 0); - CHECK_GT(req_data.lens.size(), 0); - auto unit_len = req_data.lens[1]; - CHECK_GT(unit_len, 0); real_t* data = req_data.vals.data(); - auto& stored = store_[master_key]; if (stored.is_none()) { - // LOG(INFO) << "initial push: " << master_key << " size = " << num_rows * unit_len; + if (log_verbose_) LOG(INFO) << "initial push: " << master_key; // initialization + CHECK_GT(num_rows, 0) << "init with empty data is not supported"; + auto unit_len = req_data.lens[1]; + CHECK_GT(unit_len, 0); size_t ds[] = {num_rows, (size_t) unit_len}; TShape dshape(ds, ds + 2); CHECK_EQ(req_data.vals.size(), num_rows * unit_len); TBlob recv_blob(data, dshape, cpu::kDevMask); // NOLINT(*) NDArray recved = NDArray(recv_blob, 0); - stored = NDArray(dshape, Context()); + // TODO(haibin) temporarily initialized as dense NDArray. We need inplace operator + // support for rowsparse ndarrays. And after that `stored` should be initialized as + // RowSparse NDArray + stored = NDArray(kRowSparseStorage, dshape, Context()); CopyFromTo(recved, &stored, 0); stored.WaitToRead(); server->Response(req_meta); @@ -189,91 +221,104 @@ return; } // synced push if (sync_mode_) { - // LOG(INFO) << "sync push: " << master_key; - size_t offset = 0; - auto& stored = store_[master_key]; - // merge updates - auto& request_buf = request_buf_[master_key]; - for (size_t i = 1; i <= num_rows; i++) { - // TODO(haibin) decode once and cache result - int key = DecodeKey(req_data.keys[i]); - auto len = req_data.lens[i]; - size_t ds[] = {(size_t)len}; - TShape dshape(ds, ds + 1); - TBlob recv_blob(data, // NOLINT(*) - dshape, cpu::kDevMask); - NDArray recved = NDArray(recv_blob, 0); - MergeUpdates(recved, key, &request_buf.change_set); - offset += len; + if (log_verbose_) LOG(INFO) << "sync push: " << master_key << " " << req_data.keys; + auto& merged = merge_buf_[master_key]; + if (merged.array.is_none()) { + merged.array = NDArray(kRowSparseStorage, stored.shape(), Context()); } - // perform updates - request_buf.requests.push_back(req_meta); - if (request_buf.requests.size() == (size_t) ps::NumWorkers()) { - // let the main thread to execute updater_, which is necessary for python - for (auto key : request_buf.change_set) { - // slice a row - auto row_id = key - master_key; - NDArray slice = stored.At(row_id); - NDArray update = merge_buf_[key]; - if (updater_) { - exec_.Exec([this, key, &update, &slice](){ - CHECK(updater_); - updater_(key, update, &slice); - }); - } else { - // if no updater, just copy - CopyFromTo(update, &slice); - } - slice.WaitToRead(); - } - request_buf.change_set.clear(); - // LOG(INFO) << "RESPONSE SYNC to " << request_buf.requests.size() << " clients"; - for (const auto& req : request_buf.requests) { - server->Response(req); + if (num_rows == 0) { + // reset to zeros + if (merged.request.size() == 0) { + merged.array = NDArray(kRowSparseStorage, stored.shape(), Context()); + } else { + // nothing to aggregate } - request_buf.requests.clear(); + merged.request.push_back(req_meta); + ApplyUpdates(master_key, &merged, &stored, server); + return; + } + auto unit_len = req_data.lens[1]; + CHECK_GT(unit_len, 0); + // indices + std::vector<int64_t> indices(num_rows); + DecodeRowIds(req_data.keys, indices.data(), master_key, num_rows); + // data + TBlob
idx_blob(indices.data(), mshadow::Shape1(num_rows), cpu::kDevMask); + size_t ds[] = {(size_t) num_rows, (size_t) unit_len}; + TShape dshape(ds, ds + 2); + TBlob recv_blob(data, dshape, cpu::kDevMask); // NOLINT(*) + // row_sparse NDArray + NDArray recved(kRowSparseStorage, stored.shape(), recv_blob, {idx_blob}, 0); + + if (merged.request.size() == 0) { + CopyFromTo(recved, &merged.array, 0); } else { - for (size_t i = 1; i <= num_rows; i++) { - int key = DecodeKey(req_data.keys[i]); - merge_buf_[key].WaitToRead(); - } + NDArray out(kRowSparseStorage, stored.shape(), Context()); + std::vector<Engine::VarHandle> const_vars; + const_vars.push_back(recved.var()); + const_vars.push_back(merged.array.var()); + // accumulate row_sparse gradients + // TODO(haibin) override the + operator for row_sparse NDArray + // instead of calling BinaryComputeRspRspImpl directly + using namespace mshadow; + Engine::Get()->PushSync([recved, merged, out](RunContext ctx) { + std::vector<NDArray> inputs, outputs; + inputs.push_back(recved); + inputs.push_back(merged.array); + outputs.push_back(out); + op::BinaryComputeRspRspImpl<cpu>({}, {}, inputs, {kWriteTo}, outputs); + }, recved.ctx(), const_vars, {out.var()}, + FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME); + CopyFromTo(out, &merged.array, 0); } + merged.request.push_back(req_meta); + ApplyUpdates(master_key, &merged, &stored, server); } else { // async push - auto& stored = store_[master_key]; - for (size_t i = 1; i <= num_rows; i++) { - int key = DecodeKey(req_data.keys[i]); - auto row_id = key - master_key; - auto len = req_data.lens[i]; - size_t ds[] = {(size_t)len}; - TShape dshape(ds, ds + 1); - TBlob recv_blob(data, // NOLINT(*) - dshape, cpu::kDevMask); - NDArray recved = NDArray(recv_blob, 0); - NDArray slice = stored.At(row_id); - exec_.Exec([this, key, &recved, &slice](){ - CHECK(updater_); - updater_(key, recved, &slice); - }); + if (log_verbose_) LOG(INFO) << "async push: " << master_key; + if (num_rows == 0) { + server->Response(req_meta); + return; } + auto unit_len = req_data.lens[1]; + CHECK_GT(unit_len, 0); + // indices + std::vector<int64_t> indices(num_rows); + DecodeRowIds(req_data.keys, indices.data(), master_key, num_rows); + TBlob idx_blob(indices.data(), mshadow::Shape1(num_rows), cpu::kDevMask); + size_t ds[] = {(size_t) num_rows, (size_t) unit_len}; + TShape dshape(ds, ds + 2); + TBlob recv_blob(data, dshape, cpu::kDevMask); // NOLINT(*) + NDArray recved(kRowSparseStorage, stored.shape(), recv_blob, {idx_blob}, 0); + exec_.Exec([this, master_key, &recved, &stored](){ + CHECK(updater_); + updater_(master_key, recved, &stored); + }); server->Response(req_meta); stored.WaitToRead(); } } else { // pull + if (log_verbose_) LOG(INFO) << "pull: " << master_key; ps::KVPairs<real_t> response; - auto& stored = store_[master_key]; + if (num_rows == 0) { + std::vector<int> lens(req_data.keys.size(), 0); + response.keys = req_data.keys; + response.lens.CopyFrom(lens.begin(), lens.end()); + server->Response(req_meta, response); + return; + } CHECK(!stored.is_none()) << "init " << master_key << " first"; auto shape = stored.shape(); auto unit_len = shape.ProdShape(1, shape.ndim()); const float* data = stored.data().dptr<float>(); auto len = unit_len * num_rows; - // LOG(INFO) << "received pull: " << len; - // concat response values + // concat values response.vals.resize(len); for (size_t i = 1; i <= num_rows; i++) { int key = DecodeKey(req_data.keys[i]); - const auto src = data + key * unit_len; + int64_t row_id = key - master_key; + const auto src = data + row_id * unit_len; auto begin = (i - 1) * unit_len;
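+ // [begin, end) marks this row's segment of the flattened response values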
auto end = i * unit_len; response.vals.segment(begin, end).CopyFrom(src, unit_len); @@ -319,30 +364,16 @@ } else if (sync_mode_) { // synced push auto& merged = merge_buf_[key]; - auto& request_buf = request_buf_[key]; - MergeUpdates(recved, key, &request_buf.change_set); - request_buf.requests.push_back(req_meta); - if (request_buf.requests.size() == (size_t) ps::NumWorkers()) { - CHECK_EQ(request_buf.change_set.size(), 1); - // let the main thread to execute updater_, which is necessary for python - if (updater_) { - exec_.Exec([this, key, &merged, &stored](){ - CHECK(updater_); - updater_(key, merged, &stored); - }); - } else { - // if no updater, just copy - CopyFromTo(merged, &stored); - } - request_buf.change_set.clear(); - for (const auto& req : request_buf.requests) { - server->Response(req); - } - request_buf.requests.clear(); - stored.WaitToRead(); + if (merged.array.is_none()) { + merged.array = NDArray(dshape, Context()); + } + if (merged.request.size() == 0) { + CopyFromTo(recved, &merged.array, 0); } else { - merged.WaitToRead(); + merged.array += recved; } + merged.request.push_back(req_meta); + ApplyUpdates(key, &merged, &stored, server); } else { // async push exec_.Exec([this, key, &recved, &stored](){ @@ -378,19 +409,13 @@ KVStore::Updater updater_; std::unordered_map<int, NDArray> store_; - - struct RequestBuf { - std::vector<ps::KVMeta> requests; - std::unordered_set<int> change_set; - }; - - std::unordered_map<int, NDArray> merge_buf_; - std::unordered_map<int, RequestBuf> request_buf_; - + std::unordered_map<int, MergeBuf> merge_buf_; Executor exec_; ps::KVServer<real_t>* ps_server_; + + // whether to LOG verbose information + bool log_verbose_; }; } // namespace kvstore diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h index e159dd42e596..88a73a56a17f 100644 --- a/src/kvstore/kvstore_local.h +++ b/src/kvstore/kvstore_local.h @@ -106,6 +106,30 @@ class KVStoreLocal : public KVStore { } } + void PullRowSparse(const std::vector<int>& keys, + const std::vector<std::pair<NDArray*, NDArray>>& val_rowids, + int priority = 0) { + std::vector<int> uniq_keys; + std::vector<std::vector<std::pair<NDArray*, NDArray>>> grouped_val_rowids; + GroupKVPairs(keys, val_rowids, &uniq_keys, &grouped_val_rowids); + for (size_t i = 0; i < uniq_keys.size(); ++i) { + int key = uniq_keys[i]; + const NDArray& local = local_[key]; + CHECK(!local.is_none()) << "key " << key << " has not been initialized"; + CHECK_EQ(local.storage_type(), kRowSparseStorage) + << "PullRowSparse expects row_sparse src NDArray"; + auto &target_val_rowids = grouped_val_rowids[i]; + const size_t num_vals = target_val_rowids.size(); + for (size_t i = 0; i < num_vals; i++) { + auto &row_id = target_val_rowids[i].second; + NDArray indices = row_id.Copy(pinned_ctx_); + Unique(&indices, priority); + target_val_rowids[i].second = indices; + } + comm_->BroadcastRowSparse(key, local, grouped_val_rowids[i], false, priority); + } + } + void Push(const std::vector<std::string>& str_keys, const std::vector<NDArray>& values, int priority) override { @@ -122,6 +146,14 @@ Pull(keys, values, priority); } + void PullRowSparse(const std::vector<std::string>& str_keys, + const std::vector<std::pair<NDArray*, NDArray>>& val_rowids, + const int priority = 0) override { + std::vector<int> keys(str_keys.size()); + LookupKeys(str_keys, &keys); + PullRowSparse(keys, val_rowids, priority); + } + protected: /** * \brief group values on keys @@ -164,6 +196,28 @@ } } + /** + * \brief sort and get unique values.
Output is expected to be on cpu_pinned context */ + void Unique(NDArray *out, int priority = 0) { + CHECK_EQ(out->ctx().dev_mask(), pinned_ctx_.dev_mask()) + << "Unique expects input with `pinned_ctx_`"; + Engine::Get()->PushSync([out](RunContext rctx) { + NDArray *output = out; + CHECK_EQ(out->shape().ndim(), 1) << "Unique expects 1D inputs"; + const auto size = out->shape()[0]; + auto out_data = output->data(); + MSHADOW_IDX_TYPE_SWITCH(out_data.type_flag_, IType, { + auto dptr = output->data().dptr<IType>(); + common::ParallelSort(dptr, dptr + size, omp_get_max_threads()); + auto num_unique_idx = std::unique(dptr, dptr + size) - dptr; + *output = output->Reshape(mshadow::Shape1(num_unique_idx)); + }); + }, pinned_ctx_, {}, {out->var()}, + FnProperty::kCPUPrioritized, priority, PROFILER_MESSAGE("KVStoreUnique")); + out->WaitToRead(); + } + /// reducer and broadcaster Comm* comm_; /// pinned context diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h index 005bfa8fa1c6..f764aa48dd82 100755 --- a/src/operator/optimizer_op-inl.h +++ b/src/operator/optimizer_op-inl.h @@ -238,10 +238,7 @@ inline void SGDUpdateEx(const nnvm::NodeAttrs& attrs, const SGDParam& param = nnvm::get<SGDParam>(attrs.parsed); auto weight_stype = inputs[0].storage_type(); auto grad_stype = inputs[1].storage_type(); - if (weight_stype == kDefaultStorage && grad_stype == kRowSparseStorage) { - TBlob out = outputs[0].data(); - SGDUpdateDnsRspImpl<xpu>(param, ctx, inputs[0].data(), inputs[1], req[0], &out); - } else if (weight_stype == kRowSparseStorage && grad_stype == kRowSparseStorage) { + if (weight_stype == kRowSparseStorage && grad_stype == kRowSparseStorage) { NDArray out = outputs[0]; SGDUpdateRspRspImpl<xpu>(param, ctx, inputs[0], inputs[1], req[0], &out); } else if (weight_stype == kRowSparseStorage && grad_stype == kDefaultStorage) { @@ -502,12 +499,7 @@ inline void SGDMomUpdateEx(const nnvm::NodeAttrs& attrs, auto weight_stype = weight.storage_type(); auto grad_stype = grad.storage_type(); auto mom_stype = mom.storage_type(); - if (weight_stype == kDefaultStorage && grad_stype == kRowSparseStorage && - mom_stype == kDefaultStorage) { - TBlob out = outputs[0].data(); - SGDMomUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad, - mom.data(), req[0], &out); - } else if (weight_stype == kRowSparseStorage && grad_stype == kRowSparseStorage && + if (weight_stype == kRowSparseStorage && grad_stype == kRowSparseStorage && mom_stype == kRowSparseStorage) { NDArray out = outputs[0]; SGDMomUpdateRspRspRspImpl<xpu>(param, ctx, weight, grad, mom, req[0], &out); diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h index d2ff99b9c195..2f7818951947 100644 --- a/src/operator/tensor/indexing_op.h +++ b/src/operator/tensor/indexing_op.h @@ -23,6 +23,8 @@ #include "../mxnet_op.h" #include "./sort_op.h" #include "./dot-inl.h" +#include "./init_op.h" +#include "./matrix_op-inl.h" namespace mxnet { namespace op { @@ -757,6 +759,35 @@ struct SparseRetainRspForward { } }; +template<typename xpu> +void SparseRetainOpForwardRspImpl(mshadow::Stream<xpu> *s, const NDArray &input, + const TBlob &indices, OpReqType req, + NDArray *output) { + using namespace rowsparse; + if (req == kNullOp || !input.storage_initialized() || indices.Size() == 0U) { + FillZerosRspImpl(s, output); + return; + } + const TBlob input_data = input.data(); + const TBlob input_idx = input.aux_data(rowsparse::kIdx); + output->CheckAndAlloc({mshadow::Shape1(indices.Size())}); + TBlob output_data = output->data(); + TBlob output_idx =
output->aux_data(rowsparse::kIdx); + const index_t row_length = input_data.shape_.ProdShape(1, input_data.shape_.ndim()); + + using namespace mxnet_op; + MSHADOW_TYPE_SWITCH(output_data.type_flag_, DType, { // output data type + MSHADOW_IDX_TYPE_SWITCH(output_idx.type_flag_, RType, { // row index data type + MSHADOW_TYPE_SWITCH(indices.type_flag_, IType, { // index array data type + Kernel<set_zero, xpu>::Launch(s, output_data.Size(), output_data.dptr<DType>()); + Kernel<SparseRetainRspForward, xpu>::Launch(s, indices.Size(), output_data.dptr<DType>(), + output_idx.dptr<RType>(), input_data.dptr<DType>(), input_idx.dptr<RType>(), + indices.dptr<IType>(), input_data.shape_[0], row_length); + }); + }); + }); +} + template<typename xpu> void SparseRetainOpForwardEx(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -777,33 +808,9 @@ void SparseRetainOpForwardEx(const nnvm::NodeAttrs& attrs, const NDArray& input_nd = inputs[sr::kArr]; const TBlob idx_data = inputs[sr::kIdx].data(); - - if (req[sr::kOut] == kNullOp - || !input_nd.storage_initialized() - || idx_data.Size() == 0U) return; - - const TBlob input_data = input_nd.data(); - if (input_data.shape_[0] == 0) return; - const TBlob input_idx = input_nd.aux_data(rowsparse::kIdx); - NDArray output_nd = outputs[sr::kOut]; - output_nd.CheckAndAlloc({mshadow::Shape1(idx_data.Size())}); - TBlob output_data = output_nd.data(); - TBlob output_idx = output_nd.aux_data(rowsparse::kIdx); - const auto row_length = input_data.shape_.ProdShape(1, input_data.shape_.ndim()); - - using namespace mxnet_op; - Stream<xpu> *s = ctx.get_stream<xpu>(); - MSHADOW_TYPE_SWITCH(output_data.type_flag_, DType, { // output data type - MSHADOW_IDX_TYPE_SWITCH(output_idx.type_flag_, RType, { // row index data type - MSHADOW_TYPE_SWITCH(idx_data.type_flag_, IType, { // index array data type - Kernel<set_zero, xpu>::Launch(s, output_data.Size(), output_data.dptr<DType>()); - Kernel<SparseRetainRspForward, xpu>::Launch(s, idx_data.Size(), output_data.dptr<DType>(), - output_idx.dptr<RType>(), input_data.dptr<DType>(), input_idx.dptr<RType>(), - idx_data.dptr<IType>(), input_data.shape_[0], row_length); - }); - }); - }); + NDArray output_nd = outputs[sr::kOut]; + mshadow::Stream<xpu> *s = ctx.get_stream<xpu>(); + SparseRetainOpForwardRspImpl<xpu>(s, input_nd, idx_data, req[sr::kOut], &output_nd); } template diff --git a/tests/nightly/dist_sync_kvstore.py b/tests/nightly/dist_sync_kvstore.py index c30aaed13a7a..f88b412b027c 100644 --- a/tests/nightly/dist_sync_kvstore.py +++ b/tests/nightly/dist_sync_kvstore.py @@ -4,37 +4,31 @@ sys.path.insert(0, "../../python/") import mxnet as mx import numpy as np +import numpy.random as rnd import time -def check_diff_to_scalar(A, x): +def check_diff_to_scalar(A, x, rank=None): """ assert A == x""" - assert(np.sum(np.abs((A - x).asnumpy())) == 0), A.asnumpy() + assert(np.sum(np.abs((A - x).asnumpy())) == 0), (rank, A.asnumpy(), x) # setup keys = ['3', '5', '7'] rsp_keys = ['9', '11', '13'] rate = 2 -shape = (2, 2) -big_shape = (1200, 1200) # bigger than BIGARRAY_BOUND +shape = (2, 3) +big_shape = (1200, 1200) # bigger than BIGARRAY_BOUND def init_kv(): kv = mx.kv.create('dist_sync') - # init kv + # init kv dns keys kv.init(keys, [mx.nd.ones(shape)] * len(keys)) kv.init('99', mx.nd.ones(big_shape)) - my_rank = kv.rank - nworker = kv.num_workers - # init updater on servers - kv.set_optimizer(mx.optimizer.create('test', rescale_grad=rate)) - return kv, my_rank, nworker - -def init_kv_rsp(): - kv = mx.kv.create('dist_sync') - # init kv + # init kv row_sparse keys kv.init(rsp_keys, [mx.nd.ones(shape)._to_rsp()] * len(rsp_keys)) - # kv.init(99, mx.nd.ones(big_shape)) + kv.init('100', mx.nd.ones(big_shape)._to_rsp()) + # worker info my_rank = kv.rank nworker = kv.num_workers # init updater
@@ -43,57 +37,122 @@ def init_kv_rsp():
 
 def test_sync_push_pull():
     kv, my_rank, nworker = init_kv()
-    nrepeat = 3
-    for i in range(nrepeat):
-        kv.push('3', mx.nd.ones(shape)*(my_rank+1))
-        kv.push('99', mx.nd.ones(big_shape)*(my_rank+1))
-
-    num = (nworker + 1 ) * nworker * rate / 2 * nrepeat + 1
-    val = mx.nd.zeros(shape)
-    kv.pull('3', out = val)
-    check_diff_to_scalar(val, num)
-
-    val2 = mx.nd.zeros(big_shape)
-    kv.pull('99', out = val2)
-    check_diff_to_scalar(val2, num)
-    print('done')
-
-def test_sync_push_pull_row_sparse():
-    kv, my_rank, nworker = init_kv_rsp()
-    nrepeat = 2
-
-    v = mx.nd.zeros(shape)
-    my_row = my_rank % shape[0]
-    for col in range(shape[1]):
-        v[my_row][col] = my_rank + 1
-
-    for i in range(nrepeat):
-        kv.push('9', v._to_rsp())
-        # kv.push(99, mx.nd.ones(big_shape)*(my_rank+1))
-
-    # pull a subset of rows this worker is interested in
-    val = v.copyto(mx.cpu())._to_rsp()
-    kv.pull('9', out = val)
-
-    expected = mx.nd.zeros(shape)
-    # initial value
-    for col in range(shape[1]):
-        expected[my_row][col] = 1
-    # apply updates from workers
-    for rank in range(nworker):
-        row = rank % shape[0]
-        if row != my_row:
-            continue
-        for col in range(shape[1]):
-            expected[my_row][col] += (rank + 1) * rate * nrepeat
-    #print("expect ", expected.asnumpy())
-
-    check_diff_to_scalar(val, expected)
-    # print('done')
-    #val2 = mx.nd.zeros(big_shape)
-    #kv.pull(99, out = val2)
-    #check_diff_to_scalar(val2, num)
+    def check_default_keys(kv, my_rank, nworker):
+        nrepeat = 3
+        for i in range(nrepeat):
+            kv.push('3', mx.nd.ones(shape)*(my_rank+1))
+            kv.push('99', mx.nd.ones(big_shape)*(my_rank+1))
+
+        num = (nworker + 1) * nworker * rate / 2 * nrepeat + 1
+        val = mx.nd.zeros(shape)
+        kv.pull('3', out=val)
+        check_diff_to_scalar(val, num)
+
+        val2 = mx.nd.zeros(big_shape)
+        kv.pull('99', out=val2)
+        check_diff_to_scalar(val2, num)
+
+    def check_row_sparse_keys(kv, my_rank, nworker):
+        nrepeat = 3
+        # prepare gradient
+        v = mx.nd.zeros(shape)
+        my_row = my_rank % shape[0]
+        v[my_row] = my_rank + 1
+        # push
+        for i in range(nrepeat):
+            kv.push('9', v._to_rsp())
+        # select a random subset of rows this worker is interested in
+        num_rows = shape[0]
+        row_ids_np = np.random.randint(num_rows, size=num_rows)
+        row_ids = mx.nd.array(row_ids_np, dtype='int64')
+        # perform pull
+        val = mx.nd.zeros(shape, stype='row_sparse')
+        kv.row_sparse_pull('9', out=val, row_ids=row_ids)
+        # prepare updated values
+        updated_val = mx.nd.ones(shape)
+        for rank in range(nworker):
+            row = rank % shape[0]
+            updated_val[row] += (rank + 1) * rate * nrepeat
+        # verify subset of updated values
+        expected = mx.nd.zeros(shape)
+        for row in row_ids_np:
+            expected[row] = updated_val[row]
+        check_diff_to_scalar(val, expected)
+
+    def check_row_sparse_keys_with_zeros(kv, my_rank, nworker):
+        nrepeat = 3
+        # prepare gradient
+        v = mx.nd.zeros(shape)
+        big_v = mx.nd.zeros(big_shape)
+        # push
+        for i in range(nrepeat):
+            kv.push('11', v._to_rsp())
+            kv.push('100', big_v._to_rsp())
+
+        # pull a subset of rows this worker is interested in
+        all_row_ids = np.arange(shape[0])
+        val = mx.nd.ones(shape)._to_rsp()
+        big_val = mx.nd.ones(big_shape)._to_rsp()
+        kv.row_sparse_pull('11', out=val, row_ids=mx.nd.array(all_row_ids, dtype='int64'))
+        big_num_rows = shape[0]
+        big_all_row_ids = np.arange(big_shape[0])
+        kv.row_sparse_pull('100', out=big_val, row_ids=mx.nd.array(big_all_row_ids, dtype='int64'))
+        # verify results
+        check_diff_to_scalar(val, mx.nd.ones(shape))
+        check_diff_to_scalar(big_val, mx.nd.ones(big_shape))
+
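Before the big-array variant below, it is worth spelling out the arithmetic all of these checks share (a sketch restating the formula from check_default_keys; it assumes the server-side 'test' optimizer scales every pushed gradient by rescale_grad=rate, as configured in init_kv):

def expected_entry(nworker, rate, nrepeat):
    # each worker pushes (rank + 1), the server scales each push by `rate`,
    # and the cycle repeats `nrepeat` times on top of the initial weight 1
    pushed = sum(rank + 1 for rank in range(nworker))  # nworker * (nworker + 1) / 2
    return 1 + pushed * rate * nrepeat

which is exactly the num = (nworker + 1) * nworker * rate / 2 * nrepeat + 1 used in check_default_keys.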
+    def check_big_row_sparse_keys(kv, my_rank, nworker):
+        mx.random.seed(123)
+        rnd.seed(123)
+        density = 0.3
+        nrepeat = 3
+        # prepare gradient
+        v = mx.nd.zeros(big_shape)
+        idx_sample = rnd.rand(big_shape[0])
+        indices = np.argwhere(idx_sample < density).flatten()
+        # each worker chooses a subset of the indices to update
+        update_rows = []
+        for rank in range(nworker):
+            rows = []
+            i = 0
+            step = (rank + 1) * 2
+            while i < len(indices):
+                rows.append(indices[i])
+                i += step
+            update_rows.append(np.array(rows))
+        # rows to update for this worker
+        for row in update_rows[my_rank]:
+            v[row] = my_rank + 1
+        # push
+        for i in range(nrepeat):
+            kv.push('100', v._to_rsp())
+
+        # select a random subset of rows this worker is interested in
+        mx.random.seed(my_rank)
+        rnd.seed(my_rank)
+        num_rows = big_shape[0]
+        row_ids_np = np.random.randint(num_rows, size=num_rows)
+        row_ids = mx.nd.array(row_ids_np, dtype='int64')
+        # perform pull
+        val = mx.nd.zeros(big_shape, stype='row_sparse')
+        kv.row_sparse_pull('100', out=val, row_ids=row_ids)
+        # prepare expected result
+        updated_val = mx.nd.ones(big_shape)
+        # apply updates from each worker
+        for rank in range(nworker):
+            for row in update_rows[rank]:
+                updated_val[row] += (rank + 1) * rate * nrepeat
+
+        expected = mx.nd.zeros(big_shape)
+        for row in row_ids_np:
+            expected[row] = updated_val[row]
+        check_diff_to_scalar(val, expected, rank=my_rank)
+
+    check_default_keys(kv, my_rank, nworker)
+    check_row_sparse_keys(kv, my_rank, nworker)
+    check_row_sparse_keys_with_zeros(kv, my_rank, nworker)
+    check_big_row_sparse_keys(kv, my_rank, nworker)
+    print('worker ' + str(my_rank) + ' is done')
 
 if __name__ == "__main__":
     test_sync_push_pull()
-    test_sync_push_pull_row_sparse()
diff --git a/tests/python/unittest/test_kvstore.py b/tests/python/unittest/test_kvstore.py
index 1489b8687c26..665467854977 100644
--- a/tests/python/unittest/test_kvstore.py
+++ b/tests/python/unittest/test_kvstore.py
@@ -16,13 +16,13 @@ def init_kv(stype='default'):
     kv.init(keys, [mx.nd.zeros(shape=shape, stype=stype)] * len(keys))
     return kv
 
-def init_kv_with_str():
+def init_kv_with_str(stype='default'):
     """init kv """
     kv = mx.kv.create()
     # single
-    kv.init('a', mx.nd.zeros(shape))
+    kv.init('a', mx.nd.zeros(shape, stype=stype))
     # list
-    kv.init(str_keys, [mx.nd.zeros(shape)] * len(keys))
+    kv.init(str_keys, [mx.nd.zeros(shape=shape, stype=stype)] * len(keys))
     return kv
 
 def check_diff_to_scalar(A, x):
@@ -41,6 +41,33 @@ def check_single_kv_pair(kv, key):
     check_single_kv_pair(init_kv(), 3)
     check_single_kv_pair(init_kv_with_str(), 'a')
 
+def test_row_sparse_pull():
+    kv = init_kv_with_str('row_sparse')
+    kv.init('e', mx.nd.ones(shape)._to_rsp())
+
+    def check_row_sparse_pull(kv, count):
+        num_rows = shape[0]
+        vals = []
+        row_ids = []
+        all_row_ids = np.arange(num_rows)
+        for i in range(count):
+            vals.append(mx.nd.zeros(shape)._to_rsp())
+            row_id = np.random.randint(num_rows, size=num_rows)
+            row_ids.append(mx.nd.array(row_id, dtype='int64'))
+        row_ids_to_pull = row_ids[0] if len(row_ids) == 1 else row_ids
+        vals_to_pull = vals[0] if len(vals) == 1 else vals
+
+        kv.row_sparse_pull('e', out=vals_to_pull, row_ids=row_ids_to_pull)
+        for val, row_id in zip(vals, row_ids):
+            retained = val.asnumpy()
+            excluded_row_ids = np.setdiff1d(all_row_ids, row_id.asnumpy())
+            for row in range(num_rows):
+                expected_val = np.zeros_like(retained[row])
+                expected_val += 0 if row in excluded_row_ids else 1
+                assert_almost_equal(retained[row], expected_val)
+
+    check_row_sparse_pull(kv, 1)
+    check_row_sparse_pull(kv, 4)
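In isolation, the pull pattern this new test exercises looks as follows (a minimal sketch on the default local kvstore; the key 'w' and the shapes are illustrative, not part of the patch):

import numpy as np
import mxnet as mx

kv = mx.kv.create()
kv.init('w', mx.nd.ones((4, 2))._to_rsp())
row_ids = mx.nd.array(np.array([0, 2]), dtype='int64')
out = mx.nd.zeros((4, 2), stype='row_sparse')
kv.row_sparse_pull('w', out=out, row_ids=row_ids)
# rows 0 and 2 of `out` now hold ones; rows 1 and 3 stay zero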
 
 def test_init():
     """test init"""
@@ -53,7 +80,6 @@ def check_init(kv, key):
 
     check_init(mx.kv.create(), 3)
     check_init(mx.kv.create(), 'a')
 
-
 def test_list_kv_pair():
     """list key-value pair push & pull"""
@@ -101,7 +127,7 @@ def test_sparse_aggregator():
     """aggregate sparse ndarray on multiple devices"""
 
     stype = 'row_sparse'
-    kv = init_kv(stype)
+    kv = init_kv_with_str(stype)
 
     # devices
     num_devs = 4
@@ -113,8 +139,10 @@ def test_sparse_aggregator():
     for v in vals:
         expected_sum += v.asnumpy()
 
-    kv.push(3, vals)
-    kv.pull(3, out = vals)
+    # prepare row_ids
+    all_rows = mx.nd.array(np.arange(shape[0]), dtype='int64')
+    kv.push('a', vals)
+    kv.row_sparse_pull('a', out=vals, row_ids=[all_rows] * len(vals))
 
     result_sum = np.zeros(shape)
     for v in vals:
         result_sum += v.asnumpy()
@@ -126,15 +154,14 @@ def test_sparse_aggregator():
     for v in vals[0]:
         expected_sum += v.asnumpy()
 
-    kv.push(keys, vals)
-    kv.pull(keys, out = vals)
+    kv.push(str_keys, vals)
+    kv.row_sparse_pull(str_keys, out=vals, row_ids=[[all_rows] * num_devs] * len(vals))
 
     for vv in vals:
         result_sum = np.zeros(shape)
         for v in vv:
             result_sum += v.asnumpy()
         assert_almost_equal(result_sum, expected_sum * num_devs)
 
-
 def updater(key, recv, local):
     """use updater: +="""
     local += recv
@@ -179,7 +206,6 @@ def check_updater(kv, key, key_list):
 
     check_updater(str_kv, 'a', str_keys)
 
-
 def test_get_type():
     kvtype = 'local_allreduce_cpu'
     kv = mx.kv.create(kvtype)
@@ -193,3 +219,4 @@ def test_get_type():
     test_sparse_aggregator()
     test_aggregator()
     test_updater()
+    test_row_sparse_pull()
diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py
index b6dac5f937d3..76e121e462dc 100644
--- a/tests/python/unittest/test_sparse_ndarray.py
+++ b/tests/python/unittest/test_sparse_ndarray.py
@@ -21,10 +21,10 @@ def sparse_nd_ones(shape, stype):
 def check_sparse_nd_elemwise_binary(shapes, stypes, f, g):
     # generate inputs
     nds = []
-    for i, storage_type in enumerate(stypes):
-        if storage_type == 'row_sparse':
-            nd, _ = rand_sparse_ndarray(shapes[i], storage_type)
-        elif storage_type == 'default':
+    for i, stype in enumerate(stypes):
+        if stype == 'row_sparse':
+            nd, _ = rand_sparse_ndarray(shapes[i], stype)
+        elif stype == 'default':
             nd = mx.nd.array(random_arrays(shapes[i]), dtype = np.float32)
         else:
             assert(False)
@@ -78,7 +78,7 @@ def check_sparse_nd_copy(from_stype, to_stype, shape):
     check_sparse_nd_copy('default', 'row_sparse', shape)
     check_sparse_nd_copy('default', 'csr', shape)
     check_sparse_nd_copy('row_sparse', 'row_sparse', shape_3d)
-    check_sparse_nd_copy('default', 'row_sparse', shape_3d)
+
 
 def test_sparse_nd_basic():
     def check_rsp_creation(values, indices, shape):
@@ -140,8 +140,8 @@ def check_sparse_nd_setitem(stype, shape, dst):
 
 def test_sparse_nd_slice():
     def check_sparse_nd_csr_slice(shape):
-        storage_type = 'csr'
-        A, _ = rand_sparse_ndarray(shape, storage_type)
+        stype = 'csr'
+        A, _ = rand_sparse_ndarray(shape, stype)
         A2 = A.asnumpy()
         start = rnd.randint(0, shape[0] - 1)
         end = rnd.randint(start + 1, shape[0])
@@ -232,7 +232,7 @@ def test_sparse_nd_lesser_equal():
 
 
 def test_sparse_nd_binary():
-    N = 100
+    N = 10
     def check_binary(fn):
         for _ in range(N):
             ndim = 2
@@ -270,7 +270,7 @@ def check_binary(fn):
 
 
 def test_sparse_nd_binary_rop():
-    N = 100
+    N = 10
     def check(fn):
         for _ in range(N):
             ndim = 2
@@ -297,6 +297,33 @@ def check(fn):
     check(lambda x: 0.5 <= x)
     check(lambda x: 0.5 == x)
 
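The in-place test added next builds its sparse operands with mx.nd.cast_storage; the conversion it relies on can be sketched in isolation (a hedged example with made-up values):

import numpy as np
import mxnet as mx

dense = mx.nd.array(np.array([[0., 1.], [2., 0.]]))
csr = mx.nd.cast_storage(dense, stype='csr')   # dense -> csr storage
# asnumpy() densifies again, so the round trip should be lossless
assert np.array_equal(dense.asnumpy(), csr.asnumpy())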
+def test_sparse_nd_binary_iop():
+    N = 10
+    def check_binary(fn, stype):
+        for _ in range(N):
+            ndim = 2
+            oshape = np.random.randint(1, 6, size=(ndim,))
+            lshape = list(oshape)
+            rshape = list(oshape)
+            lhs = np.random.uniform(0, 1, size=lshape)
+            rhs = np.random.uniform(0, 1, size=rshape)
+            lhs_nd = mx.nd.cast_storage(mx.nd.array(lhs), stype=stype)
+            rhs_nd = mx.nd.cast_storage(mx.nd.array(rhs), stype=stype)
+            assert_allclose(fn(lhs, rhs),
+                            fn(lhs_nd, rhs_nd).asnumpy(),
+                            rtol=1e-4, atol=1e-4)
+
+    def inplace_add(x, y):
+        x += y
+        return x
+    def inplace_mul(x, y):
+        x *= y
+        return x
+    stypes = ['csr', 'row_sparse']
+    fns = [inplace_add, inplace_mul]
+    for stype in stypes:
+        for fn in fns:
+            check_binary(fn, stype)
 
 def test_sparse_nd_negate():
     npy = np.random.uniform(-10, 10, rand_shape_2d())