From 116b0596f3a2104a6105dae15d17661fee8e401c Mon Sep 17 00:00:00 2001 From: Sheng Zha Date: Sun, 19 Nov 2017 15:51:42 -0800 Subject: [PATCH] Revert "2bit gradient compression (#8662)" (#8711) This reverts commit a499f892c9ee6f59ccfb57c9e431c91014078891. --- example/image-classification/common/fit.py | 44 +-- example/rnn/lstm_bucketing.py | 1 + include/mxnet/c_api.h | 13 - include/mxnet/kvstore.h | 15 - python/mxnet/gluon/trainer.py | 12 +- python/mxnet/kvstore.py | 62 ---- python/mxnet/module/bucketing_module.py | 17 +- python/mxnet/module/module.py | 11 +- src/c_api/c_api.cc | 14 - src/kvstore/comm.h | 87 +---- src/kvstore/gradient_compression-inl.h | 155 -------- src/kvstore/gradient_compression.cc | 193 ---------- src/kvstore/gradient_compression.cu | 40 --- src/kvstore/gradient_compression.h | 138 -------- src/kvstore/kvstore.cc | 2 +- src/kvstore/kvstore_dist.h | 388 ++++++--------------- src/kvstore/kvstore_dist_server.h | 143 +------- src/kvstore/kvstore_local.h | 7 - tests/nightly/dist_sync_kvstore.py | 120 +------ tests/nightly/test_kvstore.py | 200 +---------- tools/bandwidth/measure.py | 6 +- 21 files changed, 167 insertions(+), 1501 deletions(-) delete mode 100644 src/kvstore/gradient_compression-inl.h delete mode 100644 src/kvstore/gradient_compression.cc delete mode 100644 src/kvstore/gradient_compression.cu delete mode 100644 src/kvstore/gradient_compression.h diff --git a/example/image-classification/common/fit.py b/example/image-classification/common/fit.py index 2b002c770266..51a1abec7c48 100755 --- a/example/image-classification/common/fit.py +++ b/example/image-classification/common/fit.py @@ -103,11 +103,6 @@ def add_fit_args(parser): help='1 means test reading speed without training') train.add_argument('--dtype', type=str, default='float32', help='precision: float32 or float16') - train.add_argument('--gc-type', type=str, default='none', - help='type of gradient compression to use, \ - takes `2bit` or `none` for now') - train.add_argument('--gc-threshold', type=float, default=0.5, - help='threshold for 2bit gradient compression') return train def fit(args, network, data_loader, **kwargs): @@ -119,9 +114,6 @@ def fit(args, network, data_loader, **kwargs): """ # kvstore kv = mx.kvstore.create(args.kv_store) - if args.gc_type != 'none': - kv.set_gradient_compression({'type': args.gc_type, - 'threshold': args.gc_threshold}) # logging head = '%(asctime)-15s Node[' + str(kv.rank) + '] %(message)s' @@ -170,10 +162,10 @@ def fit(args, network, data_loader, **kwargs): lr_scheduler = lr_scheduler optimizer_params = { - 'learning_rate': lr, - 'wd' : args.wd, - 'lr_scheduler': lr_scheduler, - 'multi_precision': True} + 'learning_rate': lr, + 'wd' : args.wd, + 'lr_scheduler': lr_scheduler, + 'multi_precision': True} # Only a limited number of optimizers have 'momentum' property has_momentum = {'sgd', 'dcasgd', 'nag'} @@ -203,17 +195,17 @@ def fit(args, network, data_loader, **kwargs): # run model.fit(train, - begin_epoch = args.load_epoch if args.load_epoch else 0, - num_epoch = args.num_epochs, - eval_data = val, - eval_metric = eval_metrics, - kvstore = kv, - optimizer = args.optimizer, - optimizer_params = optimizer_params, - initializer = initializer, - arg_params = arg_params, - aux_params = aux_params, - batch_end_callback = batch_end_callbacks, - epoch_end_callback = checkpoint, - allow_missing = True, - monitor = monitor) + begin_epoch = args.load_epoch if args.load_epoch else 0, + num_epoch = args.num_epochs, + eval_data = val, + eval_metric = eval_metrics, + kvstore = kv, + optimizer = args.optimizer, + optimizer_params = optimizer_params, + initializer = initializer, + arg_params = arg_params, + aux_params = aux_params, + batch_end_callback = batch_end_callbacks, + epoch_end_callback = checkpoint, + allow_missing = True, + monitor = monitor) diff --git a/example/rnn/lstm_bucketing.py b/example/rnn/lstm_bucketing.py index 0e7f064f0078..2e7bc65d437a 100644 --- a/example/rnn/lstm_bucketing.py +++ b/example/rnn/lstm_bucketing.py @@ -48,6 +48,7 @@ parser.add_argument('--disp-batches', type=int, default=50, help='show progress for every n batches') + def tokenize_text(fname, vocab=None, invalid_label=-1, start_label=0): if not os.path.isfile(fname): raise IOError("Please use get_ptb_data.sh to download requied file (data/ptb.train.txt)") diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index a81193e6735c..77fc6a5f5080 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -1550,19 +1550,6 @@ MXNET_DLL int MXInitPSEnv(mx_uint num_vars, */ MXNET_DLL int MXKVStoreCreate(const char *type, KVStoreHandle *out); - -/*! - * \brief Set parameters to use low-bit compressed gradients - * \param handle handle to the kvstore - * \param keys keys for compression parameters - * \param vals values for compression parameters - * \return 0 when success, -1 when failure happens - */ -MXNET_DLL int MXKVStoreSetGradientCompression(KVStoreHandle handle, - mx_uint num_params, - const char** keys, - const char** vals); - /*! * \brief Delete a KVStore handle. * \param handle handle to the kvstore diff --git a/include/mxnet/kvstore.h b/include/mxnet/kvstore.h index 4e99a9c861f2..1649c4368079 100644 --- a/include/mxnet/kvstore.h +++ b/include/mxnet/kvstore.h @@ -31,7 +31,6 @@ #include #include #include -#include "../../src/kvstore/gradient_compression.h" #include "./ndarray.h" #if MXNET_USE_DIST_KVSTORE #include "ps/ps.h" @@ -66,14 +65,6 @@ class KVStore { */ inline const std::string& type() { return type_; } - /** - * \brief Set parameters to use low-bit compressed gradients - * \param compression_type type of compression - * \param threshold threshold for 2bit compression - */ - virtual void SetGradientCompression(const std::vector > - & kwargs) = 0; - /*! * \brief Initialize a list of key-value pair to the store. * @@ -397,12 +388,6 @@ class KVStore { */ std::string type_; - /** \brief Gradient compression object starts with GC_NONE mode - * Used if SetGradientCompression sets the type. - * Currently there is no support for un-setting gradient compression - */ - std::shared_ptr gradient_compression_; - /** * \brief whether to do barrier when finalize */ diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py index f3a14609587f..115d1ff09ce5 100644 --- a/python/mxnet/gluon/trainer.py +++ b/python/mxnet/gluon/trainer.py @@ -44,11 +44,6 @@ class Trainer(object): kvstore : str or KVStore kvstore type for multi-gpu and distributed training. See help on :any:`mxnet.kvstore.create` for more information. - compression_params : dict - Specifies type of gradient compression and additional arguments depending - on the type of compression being used. For example, 2bit compression requires a threshold. - Arguments would then be {'type':'2bit', 'threshold':0.5} - See mxnet.KVStore.set_gradient_compression method for more details on gradient compression. Properties ---------- @@ -56,8 +51,7 @@ class Trainer(object): The current learning rate of the optimizer. Given an Optimizer object optimizer, its learning rate can be accessed as optimizer.learning_rate. """ - def __init__(self, params, optimizer, optimizer_params=None, kvstore='device', - compression_params=None): + def __init__(self, params, optimizer, optimizer_params=None, kvstore='device'): if isinstance(params, (dict, ParameterDict)): params = list(params.values()) if not isinstance(params, (list, tuple)): @@ -71,7 +65,7 @@ def __init__(self, params, optimizer, optimizer_params=None, kvstore='device', "First argument must be a list or dict of Parameters, " \ "got list of %s."%(type(param))) self._params.append(param) - self._compression_params = compression_params + optimizer_params = optimizer_params if optimizer_params else {} self._scale = optimizer_params.get('rescale_grad', 1.0) self._contexts = self._check_contexts() @@ -110,8 +104,6 @@ def _init_kvstore(self): kvstore, update_on_kvstore = _create_kvstore(self._kvstore, len(self._contexts), arg_arrays) if kvstore: - if self._compression_params: - kvstore.set_gradient_compression(self._compression_params) if 'dist' in kvstore.type: update_on_kvstore = False for i, param in enumerate(self._params): diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py index bf424559df8d..8625303ee40e 100644 --- a/python/mxnet/kvstore.py +++ b/python/mxnet/kvstore.py @@ -64,16 +64,6 @@ def _ctype_key_value(keys, vals): else c_array_buf(ctypes.c_int, array('i', [keys] * len(vals))) return (c_keys, c_handle_array(vals), use_str_keys) -def _ctype_dict(param_dict): - """ - Returns ctype arrays for keys and values(converted to strings) in a dictionary - """ - assert(isinstance(param_dict, dict)), \ - "unexpected type for param_dict: " + str(type(param_dict)) - c_keys = c_array(ctypes.c_char_p, [c_str(k) for k in param_dict.keys()]) - c_vals = c_array(ctypes.c_char_p, [c_str(str(v)) for v in param_dict.values()]) - return (c_keys, c_vals) - def _updater_wrapper(updater): """A wrapper for the user-defined handle.""" def updater_handle(key, lhs_handle, rhs_handle, _): @@ -360,58 +350,6 @@ def row_sparse_pull(self, key, out=None, priority=0, row_ids=None): check_call(_LIB.MXKVStorePullRowSparse( self.handle, mx_uint(len(ckeys)), ckeys, cvals, crow_ids, ctypes.c_int(priority))) - def set_gradient_compression(self, compression_params): - """ Specifies type of low-bit quantization for gradient compression \ - and additional arguments depending on the type of compression being used. - - 2bit Gradient Compression takes a positive float `threshold`. - The technique works by thresholding values such that positive values in the - gradient above threshold will be set to threshold. Negative values whose absolute - values are higher than threshold, will be set to the negative of threshold. - Values whose absolute values are less than threshold will be set to 0. - By doing so, each value in the gradient is in one of three states. 2bits are - used to represent these states, and every 16 float values in the original - gradient can be represented using one float. This compressed representation - can reduce communication costs. The difference between these thresholded values and - original values is stored at the sender's end as residual and added to the - gradient in the next iteration. - - When kvstore is 'local', gradient compression is used to reduce communication - between multiple devices (gpus). Gradient is quantized on each GPU which - computed the gradients, then sent to the GPU which merges the gradients. This - receiving GPU dequantizes the gradients and merges them. Note that this - increases memory usage on each GPU because of the residual array stored. - - When kvstore is 'dist', gradient compression is used to reduce communication - from worker to sender. Gradient is quantized on each worker which - computed the gradients, then sent to the server which dequantizes - this data and merges the gradients from each worker. Note that this - increases CPU memory usage on each worker because of the residual array stored. - Only worker to server communication is compressed in this setting. - If each machine has multiple GPUs, currently this GPU to GPU or GPU to CPU communication - is not compressed. Server to worker communication (in the case of pull) - is also not compressed. - - To use 2bit compression, we need to specify `type` as `2bit`. - Only specifying `type` would use default value for the threshold. - To completely specify the arguments for 2bit compression, we would need to pass - a dictionary which includes `threshold` like: - {'type': '2bit', 'threshold': 0.5} - - Parameters - ---------- - compression_params : dict - A dictionary specifying the type and parameters for gradient compression. - The key `type` in this dictionary is a - required string argument and specifies the type of gradient compression. - Currently `type` can be only `2bit` - Other keys in this dictionary are optional and specific to the type - of gradient compression. - """ - ckeys, cvals = _ctype_dict(compression_params) - check_call(_LIB.MXKVStoreSetGradientCompression(self.handle, - mx_uint(len(compression_params)), - ckeys, cvals)) def set_optimizer(self, optimizer): """ Registers an optimizer with the kvstore. diff --git a/python/mxnet/module/bucketing_module.py b/python/mxnet/module/bucketing_module.py index 0bea260cd3d9..fa92c5d1a1bf 100644 --- a/python/mxnet/module/bucketing_module.py +++ b/python/mxnet/module/bucketing_module.py @@ -54,16 +54,10 @@ class BucketingModule(BaseModule): Instead they are initialized to 0 and can be set by set_states() group2ctxs : list of dict of str to context Default is `None`. Mapping the `ctx_group` attribute to the context assignment. - compression_params : dict - Specifies type of gradient compression and additional arguments depending - on the type of compression being used. For example, 2bit compression requires a threshold. - Arguments would then be {'type':'2bit', 'threshold':0.5} - See mxnet.KVStore.set_gradient_compression method for more details on gradient compression. """ def __init__(self, sym_gen, default_bucket_key=None, logger=logging, context=ctx.cpu(), work_load_list=None, - fixed_param_names=None, state_names=None, group2ctxs=None, - compression_params=None): + fixed_param_names=None, state_names=None, group2ctxs=None): super(BucketingModule, self).__init__(logger=logger) assert default_bucket_key is not None @@ -81,7 +75,6 @@ def __init__(self, sym_gen, default_bucket_key=None, logger=logging, _check_input_names(symbol, state_names, "state", True) _check_input_names(symbol, fixed_param_names, "fixed_param", True) - self._compression_params = compression_params self._fixed_param_names = fixed_param_names self._state_names = state_names self._context = context @@ -330,9 +323,7 @@ def bind(self, data_shapes, label_shapes=None, for_training=True, module = Module(symbol, data_names, label_names, logger=self.logger, context=self._context, work_load_list=self._work_load_list, fixed_param_names=self._fixed_param_names, - state_names=self._state_names, - group2ctxs=self._group2ctxs, - compression_params=self._compression_params) + state_names=self._state_names, group2ctxs=self._group2ctxs) module.bind(data_shapes, label_shapes, for_training, inputs_need_grad, force_rebind=False, shared_module=None, grad_req=grad_req) self._curr_module = module @@ -362,9 +353,7 @@ def switch_bucket(self, bucket_key, data_shapes, label_shapes=None): logger=self.logger, context=self._context, work_load_list=self._work_load_list, fixed_param_names=self._fixed_param_names, - state_names=self._state_names, - group2ctxs=self._group2ctxs, - compression_params=self._compression_params) + state_names=self._state_names, group2ctxs=self._group2ctxs) module.bind(data_shapes, label_shapes, self._curr_module.for_training, self._curr_module.inputs_need_grad, force_rebind=False, shared_module=self._buckets[self._default_bucket_key]) diff --git a/python/mxnet/module/module.py b/python/mxnet/module/module.py index a9c6516a32ed..8301330313ae 100644 --- a/python/mxnet/module/module.py +++ b/python/mxnet/module/module.py @@ -61,16 +61,10 @@ class Module(BaseModule): Instead they are initialized to 0 and can be set by `set_states()`. group2ctxs : list of dict of str to context Default is `None`. Mapping the `ctx_group` attribute to the context assignment. - compression_params : dict - Specifies type of gradient compression and additional arguments depending - on the type of compression being used. For example, 2bit compression requires a threshold. - Arguments would then be {'type':'2bit', 'threshold':0.5} - See mxnet.KVStore.set_gradient_compression method for more details on gradient compression. """ def __init__(self, symbol, data_names=('data',), label_names=('softmax_label',), logger=logging, context=ctx.cpu(), work_load_list=None, - fixed_param_names=None, state_names=None, group2ctxs=None, - compression_params=None): + fixed_param_names=None, state_names=None, group2ctxs=None): super(Module, self).__init__(logger=logger) if isinstance(context, ctx.Context): @@ -109,7 +103,6 @@ def __init__(self, symbol, data_names=('data',), label_names=('softmax_label',), self._aux_params = None self._params_dirty = False - self._compression_params = compression_params self._optimizer = None self._kvstore = None self._update_on_kvstore = None @@ -532,8 +525,6 @@ def init_optimizer(self, kvstore='local', optimizer='sgd', self._updater = None if kvstore: - if self._compression_params: - kvstore.set_gradient_compression(self._compression_params) # copy initialized local parameters to kvstore _initialize_kvstore(kvstore=kvstore, param_arrays=self._exec_group.param_arrays, diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 027f00ba8762..0dde00443a0b 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -748,20 +748,6 @@ int MXKVStoreCreate(const char *type, API_END(); } -int MXKVStoreSetGradientCompression(KVStoreHandle handle, mx_uint num_params, - const char** keys, const char** vals) { - API_BEGIN(); - std::vector > params; - for (mx_uint i = 0; i < num_params; ++i) { - std::pair p; - p.first = keys[i]; - p.second = vals[i]; - params.push_back(p); - } - static_cast(handle)->SetGradientCompression(params); - API_END(); -} - int MXKVStoreFree(KVStoreHandle handle) { API_BEGIN(); delete static_cast(handle); diff --git a/src/kvstore/comm.h b/src/kvstore/comm.h index 5e15c2a085f1..fcf1e6b17f00 100644 --- a/src/kvstore/comm.h +++ b/src/kvstore/comm.h @@ -31,7 +31,6 @@ #include #include #include "mxnet/ndarray.h" -#include "gradient_compression.h" #include "../ndarray/ndarray_function.h" #include "../operator/tensor/sparse_retain-inl.h" namespace mxnet { @@ -81,18 +80,8 @@ class Comm { return pinned_ctx_; } - /** - * \brief Sets gradient compression parameters to be able to - * perform reduce with compressed gradients - */ - void SetGradientCompression(std::shared_ptr gc) { - gc_ = gc; - } - protected: Context pinned_ctx_; - - std::shared_ptr gc_; }; /** @@ -496,7 +485,14 @@ class CommDevice : public Comm { } } - void InitBuffersAndComm(const std::vector& src) { + const NDArray& Reduce(int key, const std::vector& src, + int priority) override { + // avoid extra copy for single device, but it may bring problems for + // abnormal usage of kvstore + if (src.size() == 1) { + return src[0]; + } + if (!inited_) { std::vector devs; for (const auto& a : src) { @@ -507,23 +503,7 @@ class CommDevice : public Comm { EnableP2P(devs); } } - } - - const NDArray& Reduce(int key, const std::vector& src, - int priority) override { - // when this reduce is called from kvstore_dist, gc is not set - // we don't do compression twice in dist_sync_device - if ((gc_ != nullptr) && (gc_->get_type() != CompressionType::kNone)) { - return ReduceCompressed(key, src, priority); - } - - // avoid extra copy for single device, but it may bring problems for - // abnormal usage of kvstore - if (src.size() == 1) { - return src[0]; - } - InitBuffersAndComm(src); auto& buf = merge_buf_[key]; std::vector reduce(src.size()); CopyFromTo(src[0], &(buf.merged), priority); @@ -546,52 +526,7 @@ class CommDevice : public Comm { } ElementwiseSum(reduce, &buf.merged); - return buf.merged; - } - - const NDArray& ReduceCompressed(int key, const std::vector& src, - int priority) { - InitBuffersAndComm(src); - auto& buf = merge_buf_[key]; - std::vector reduce(src.size()); - if (buf.copy_buf.empty()) { - // one buf for each context - buf.copy_buf.resize(src.size()); - buf.compressed_recv_buf.resize(src.size()); - buf.compressed_send_buf.resize(src.size()); - buf.residual.resize(src.size()); - for (size_t i = 0; i < src.size(); ++i) { - buf.copy_buf[i] = NDArray(buf.merged.shape(), buf.merged.ctx(), - false, buf.merged.dtype()); - buf.residual[i] = NDArray(buf.merged.shape(), src[i].ctx(), - false, buf.merged.dtype()); - buf.residual[i] = 0; - int64_t small_size = gc_->GetCompressedSize(buf.merged.shape().Size()); - buf.compressed_recv_buf[i] = NDArray(TShape{small_size}, buf.merged.ctx(), - false, buf.merged.dtype()); - buf.compressed_send_buf[i] = NDArray(TShape{small_size}, src[i].ctx(), - false, buf.merged.dtype()); - } - } - - for (size_t i = 0; i < src.size(); ++i) { - // compress before copy - // this is done even if the data is on same context as copy_buf because - // we don't want the training to be biased towards data on this GPU - gc_->Quantize(src[i], &(buf.compressed_send_buf[i]), &(buf.residual[i]), priority); - - if (buf.compressed_send_buf[i].ctx() != buf.compressed_recv_buf[i].ctx()) { - CopyFromTo(buf.compressed_send_buf[i], &(buf.compressed_recv_buf[i]), priority); - } else { - // avoid memory copy when they are on same context - buf.compressed_recv_buf[i] = buf.compressed_send_buf[i]; - } - - gc_->Dequantize(buf.compressed_recv_buf[i], &(buf.copy_buf[i]), priority); - reduce[i] = buf.copy_buf[i]; - } - ElementwiseSum(reduce, &buf.merged); return buf.merged; } @@ -704,12 +639,6 @@ class CommDevice : public Comm { NDArray merged; /// \brief the gpu buffer std::vector copy_buf; - /// \brief the residual buffer for gradient compression - std::vector residual; - /// \brief the small buffer for compressed data in sender - std::vector compressed_send_buf; - /// \brief the small buffer for compressed data in receiver - std::vector compressed_recv_buf; }; std::unordered_map merge_buf_; bool inited_; diff --git a/src/kvstore/gradient_compression-inl.h b/src/kvstore/gradient_compression-inl.h deleted file mode 100644 index 9b69bd11472c..000000000000 --- a/src/kvstore/gradient_compression-inl.h +++ /dev/null @@ -1,155 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file gradient_compression-inl.h - * \author Rahul Huilgol - * \brief Declares and defines functions used to quantize and dequantize data - */ -#ifndef MXNET_KVSTORE_GRADIENT_COMPRESSION_INL_H_ -#define MXNET_KVSTORE_GRADIENT_COMPRESSION_INL_H_ - -#include -#include "../operator/mxnet_op.h" - -namespace mxnet { -namespace kvstore { - -// these gpu functions are defined in gradient_compression.cu -void Quantize2BitImpl(mshadow::Stream *s, const std::vector &inputs, - const float threshold); -void Dequantize2BitImpl(mshadow::Stream *s, const std::vector &inputs, - const float threshold); - -struct quantize_2bit { - MSHADOW_XINLINE static void Map(int out_block_id, - int original_size, - float *out, - float *grad, - float *residual, - const float neg_threshold, - const float pos_threshold) { - // this block contains the compressed representation of - // upto 16 values starting from out_block_id*16 - float *compr_block = out + out_block_id; - // init to 0 - *compr_block = 0; - // start and end are indices in original grad array - const int start = out_block_id << 4; - const int end = (start + 16 <= original_size) ? start + 16 : original_size; - // cast as char* to manipulate bits of float addresses - char *block_ptr = reinterpret_cast < char * > (compr_block); - // masks to set bits when value meets pos_threshold - // 0xc0 is mask when value is to be represented by the first two bits in a char* - // 0xc0 means first two bits are set to 11 - const uint8_t posbits[] = {0xc0, 0x30, 0x0c, 0x03}; - // masks to set bits when value meets neg_threshold - const uint8_t negbits[] = {0x80, 0x20, 0x08, 0x02}; - for (int i = start; i < end; i++) { - // adds offset to reach appropriate byte - char *curr_byte = block_ptr + ((i - start) >> 2); - // adds gradient to existing residual to get updated grad - residual[i] += grad[i]; - if (residual[i] >= pos_threshold) { - // set data to 11 - *curr_byte |= posbits[(i & 3)]; - // reduce residual by pos_threshold - residual[i] -= pos_threshold; - } else if (residual[i] <= neg_threshold) { - // set data to 10 - *curr_byte |= negbits[(i & 3)]; - residual[i] -= neg_threshold; - } - } - } -}; - -template -void Quantize2BitKernelLaunch(mshadow::Stream *s, const std::vector &inputs, - const float threshold) { - mxnet::op::mxnet_op::Kernel - ::Launch(s, - inputs[2].Size(), // compressed array size - inputs[0].Size(), // original size - inputs[2].dptr(), // compressed array - inputs[0].dptr(), // original array - inputs[1].dptr(), // residual array - -1 *threshold, // negative threshold - threshold); // positive threshold -} - -struct dequantize_2bit { - MSHADOW_XINLINE static void Map(int i, - float *out, - float *in, - const float neg_threshold, - const float pos_threshold) { - // get position of dequantized value to fill - float *outval = out + i; - // gets byte which holds quantized value for this position - char *ch_ptr = reinterpret_cast(in + (i >> 4)); - ch_ptr += ((i & 15) >> 2); - // masks used to quantize data - const uint8_t posbits[] = {0xc0, 0x30, 0x0c, 0x03}; - const uint8_t negbits[] = {0x80, 0x20, 0x08, 0x02}; - // col denotes which two bits of a byte are set for this value - // col=0 implies first two bits, col=3 implies last two bits,... - const int col = i & 3; - const uint8_t mask = posbits[col]; - const uint8_t negmask = negbits[col]; - const uint8_t masked = *ch_ptr & mask; - if (masked == mask) { - *outval = pos_threshold; - } else if (masked == negmask) { - // use posbits for mask as posbits are both 1s - // then compare masked with negbits to see if only negbits were set - *outval = neg_threshold; - } else { - *outval = 0; - } - } -}; - -template -void Dequantize2BitKernelLaunch(mshadow::Stream *s, const std::vector &inputs, - const float threshold) { - mxnet::op::mxnet_op::Kernel - ::Launch(s, - inputs[1].Size(), // original size - inputs[1].dptr(), // out array - inputs[0].dptr(), // compressed array - -1 *threshold, // negative threshold - threshold); // positive threshold -} - -inline void Quantize2BitImpl(mshadow::Stream *s, - const std::vector &inputs, - const float threshold) { - Quantize2BitKernelLaunch(s, inputs, threshold); -} - -inline void Dequantize2BitImpl(mshadow::Stream *s, - const std::vector &inputs, - const float threshold) { - Dequantize2BitKernelLaunch(s, inputs, threshold); -} -} // namespace kvstore -} // namespace mxnet - -#endif // MXNET_KVSTORE_GRADIENT_COMPRESSION_INL_H_ diff --git a/src/kvstore/gradient_compression.cc b/src/kvstore/gradient_compression.cc deleted file mode 100644 index b8c626cd53a8..000000000000 --- a/src/kvstore/gradient_compression.cc +++ /dev/null @@ -1,193 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file gradient_compression.cc - * \brief Gradient compression for kvstore - * \author Rahul Huilgol - */ - -#include -#include -#include "gradient_compression.h" -#include "gradient_compression-inl.h" - -namespace mxnet { -namespace kvstore { - -/*! - * \brief Splits a string into smaller strings using char as delimiter - * Example: "a,b,c,,d" is split into ["a","b","c","","d"] - * \param s string to split - * \param delim char to split string around - * \param result container for tokens extracted after splitting - */ -template -void split(const std::string &s, const char delim, Out result) { - std::stringstream ss; - ss.str(s); - std::string item; - while (std::getline(ss, item, delim)) { - *(result++) = item; - } -} - -DMLC_REGISTER_PARAMETER(GradientCompressionParam); - -GradientCompression::GradientCompression() { - type_ = CompressionType::kNone; -} - -void GradientCompression::SetParams(const std::vector > - & kwargs) { - GradientCompressionParam params; - params.InitAllowUnknown(kwargs); - CHECK_GT(params.threshold, 0) << "threshold must be greater than 0"; - if (params.type == "2bit") { - SetTwoBitCompression(params.threshold); - } else { - LOG(FATAL) << "Unknown type for gradient compression " << params.type; - } -} - -CompressionType GradientCompression::get_type() { - return type_; -} - -std::string GradientCompression::get_type_str() { - return std::to_string(static_cast(type_)); -} - -void GradientCompression::SetTwoBitCompression(const float threshold) { - type_ = CompressionType::kTwoBit; - threshold_ = threshold; -} - -std::string GradientCompression::EncodeParams() { - using namespace std; // to reduce length of next line - string rval = get_type_str(); - if (type_ == CompressionType::kTwoBit) { - rval += "," + to_string(threshold_); - } - return rval; -} - -void GradientCompression::DecodeParams(const std::string &s) { - std::vector elems; - split(s, ',', std::back_inserter(elems)); - type_ = static_cast(stoi(elems[0])); - if (elems.size() > 1) { - if (!elems[1].empty()) { - threshold_ = stof(elems[1]); - } - } -} - -int GradientCompression::GetCompressionFactor() { - if (type_ == CompressionType::kTwoBit) { - return 16; - } else { - LOG(FATAL) << "Unsupported compression type: " << get_type_str(); - return 0; - } -} - -int64_t GradientCompression::GetCompressedSize(const int64_t original_size) { - const int bits = GetCompressionFactor(); - return ((original_size % bits == 0) ? - original_size / bits : - original_size / bits + 1); -} - -void GradientCompression::Quantize(const mxnet::NDArray &from, mxnet::NDArray *to, - mxnet::NDArray *residual, const int priority) { - CHECK(from.shape().ndim() != 0) << "source operand has zero dimension shape"; - CHECK(to->shape().ndim() != 0) << "destination operand has zero dimension shape"; - CHECK(residual->shape().ndim() != 0) << "residual operand has zero dimension shape"; - const int a = from.ctx().dev_mask(); - const int b = to->ctx().dev_mask(); - const float threshold = threshold_; - if (type_ == CompressionType::kTwoBit) { - if (a == mshadow::cpu::kDevMask && b == mshadow::cpu::kDevMask) { - mxnet::Engine::Get()->PushSync([from, to, residual, threshold](mxnet::RunContext ctx) { - std::vector inputs = {from.data(), residual->data(), to->data()}; - Quantize2BitImpl(ctx.get_stream(), inputs, threshold); - }, from.ctx(), {from.var()}, {to->var(), residual->var()}, - mxnet::FnProperty::kNormal, priority, PROFILER_MESSAGE("QuantizeCPU")); - } else { -#if MXNET_USE_CUDA - if (a == mshadow::gpu::kDevMask && b == mshadow::gpu::kDevMask) { - mxnet::Engine::Get()->PushSync([from, to, residual, threshold](mxnet::RunContext ctx) { - std::vector inputs = {from.data(), residual->data(), to->data()}; - Quantize2BitImpl(ctx.get_stream(), inputs, threshold); - // Wait GPU kernel to complete - ctx.get_stream()->Wait(); - }, from.ctx(), {from.var()}, {to->var(), residual->var()}, - mxnet::FnProperty::kNormal, priority, PROFILER_MESSAGE("QuantizeGPU")); - } else { - LOG(FATAL) << "unknown device mask"; - } -#else - LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; -#endif - } - } else { - LOG(FATAL) << "Unsupported quantization of type " << get_type_str(); - } -} - -void GradientCompression::Dequantize(const mxnet::NDArray &from, mxnet::NDArray *to, - const int priority) { - CHECK(from.shape().ndim() != 0) << "source operands has zero dimension shape"; - CHECK(to->shape().ndim() != 0) << "destination operand has zero dimension shape"; - const int a = from.ctx().dev_mask(); - const int b = to->ctx().dev_mask(); - const float threshold = threshold_; - if (type_ == CompressionType::kTwoBit) { - if (a == mshadow::cpu::kDevMask && b == mshadow::cpu::kDevMask) { - mxnet::Engine::Get()->PushSync([from, to, threshold](mxnet::RunContext ctx) { - std::vector inputs = {from.data(), to->data()}; - Dequantize2BitImpl(ctx.get_stream(), inputs, threshold); - }, from.ctx(), {from.var()}, {to->var()}, - mxnet::FnProperty::kNormal, priority, PROFILER_MESSAGE("DequantizeCPU")); - } else { -#if MXNET_USE_CUDA - if (a == mshadow::gpu::kDevMask && b == mshadow::gpu::kDevMask) { - mxnet::Engine::Get()->PushSync([from, to, threshold](mxnet::RunContext ctx) { - std::vector inputs = {from.data(), to->data()}; - Dequantize2BitImpl(ctx.get_stream(), inputs, threshold); - // Wait GPU kernel to complete - ctx.get_stream()->Wait(); - }, from.ctx(), {from.var()}, {to->var()}, - mxnet::FnProperty::kNormal, priority, PROFILER_MESSAGE("DequantizeGPU")); - } else { - LOG(FATAL) << "unknown device mask"; - } -#else - LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; -#endif - } - } else { - LOG(FATAL) << "Unsupported dequantization of type " << get_type_str(); - } -} - -} // namespace kvstore -} // namespace mxnet - diff --git a/src/kvstore/gradient_compression.cu b/src/kvstore/gradient_compression.cu deleted file mode 100644 index b0d9662520b2..000000000000 --- a/src/kvstore/gradient_compression.cu +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file gradient_compression.cu - * \author Rahul Huilgol - * \brief Implementation for gpu version of code - */ - -#include "gradient_compression-inl.h" - -namespace mxnet { -namespace kvstore { -void Quantize2BitImpl(mshadow::Stream* s, const std::vector& inputs, - const float threshold) { - Quantize2BitKernelLaunch(s, inputs, threshold); -} - -void Dequantize2BitImpl(mshadow::Stream* s, const std::vector& inputs, - const float threshold) { - Dequantize2BitKernelLaunch(s, inputs, threshold); -} -} // namespace kvstore -} // namespace mxnet diff --git a/src/kvstore/gradient_compression.h b/src/kvstore/gradient_compression.h deleted file mode 100644 index f40b45f5a513..000000000000 --- a/src/kvstore/gradient_compression.h +++ /dev/null @@ -1,138 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file gradient_compression.h - * \brief Gradient compression for kvstore - * \author Rahul Huilgol - */ - -#ifndef MXNET_KVSTORE_GRADIENT_COMPRESSION_H_ -#define MXNET_KVSTORE_GRADIENT_COMPRESSION_H_ -#include -#include -#include -#include -#include "mxnet/ndarray.h" - -namespace mxnet { -namespace kvstore { - -enum class CompressionType { - kNone, kTwoBit -}; - -struct GradientCompressionParam : public dmlc::Parameter { - std::string type; - float threshold; - DMLC_DECLARE_PARAMETER(GradientCompressionParam) { - DMLC_DECLARE_FIELD(type) - .describe("Type of gradient compression to use, like `2bit` for example"); - DMLC_DECLARE_FIELD(threshold).set_default(0.5) - .describe("Threshold to use for 2bit gradient compression"); - } -}; - -class GradientCompression { - public: - GradientCompression(); - - virtual ~GradientCompression() {} - - /*! - * \brief sets parameters for gradient compression - * \param kwargs a vector of pair of strings. A pair represents key and value - * of the parameter. Will be parsed by GradientCompressionParam - */ - void SetParams(const std::vector >& kwargs); - - /*! - * \brief returns type of compression if any - */ - CompressionType get_type(); - - /*! - * \brief returns as string the enum value of compression type - */ - std::string get_type_str(); - - /*! - * \brief sets two bit gradient compression - * \param threshold float value used for thresholding gradients - */ - void SetTwoBitCompression(const float threshold); - - /*! - * \brief encodes parameters of gc into a string - */ - std::string EncodeParams(); - - /*! - * \brief decodes parameters of gc from a string and assigns them to member variables - */ - void DecodeParams(const std::string &s); - - /*! - * \brief returns compression factor, which is the factor by which size of gradient - * reduces when using a particular type of compression - */ - int GetCompressionFactor(); - - /*! - * \brief returns the size of compressed gradients given an original sized gradient array - */ - int64_t GetCompressedSize(const int64_t original_size); - - /*! - * \brief Issues quantize operation to be scheduled by the engine - * Compresses `from` into `to` and accumulates the quantization error - * into 'residual', using the quantization of type `type_` - * \param from the ndarray containing original data to be quantized - * \param to the target ndarray which contains quantized data - * \param residual the ndarray which accumulates quantization error - * \param priority Priority of the action. - */ - void Quantize(const mxnet::NDArray &from, mxnet::NDArray *to, - mxnet::NDArray *residual, const int priority); - - /*! - * \brief Issues dequantize operation to be scheduled by the engine - * Decompresses `from` into `to` using current parameters of `type` and `threshold` - * \param from the ndarray containing quantized data - * \param to the target ndarray which contains final dequantized data - * \param priority Priority of the action. - */ - void Dequantize(const mxnet::NDArray &from, mxnet::NDArray *to, const int priority); - - private: - /*! - * \brief denotes the type of gradient compression which has been set - */ - CompressionType type_; - - /*! - * \brief denotes threshold used for quantization and dequantization - * Must be a positive value. All positive gradients will be thresholded to `threshold_` and - * all negative gradients will be thresholded to -1*`threshold_` - */ - float threshold_ = 0; -}; -} // namespace kvstore -} // namespace mxnet -#endif // MXNET_KVSTORE_GRADIENT_COMPRESSION_H_ diff --git a/src/kvstore/kvstore.cc b/src/kvstore/kvstore.cc index ac158736f4d4..ac37d5d32cd0 100644 --- a/src/kvstore/kvstore.cc +++ b/src/kvstore/kvstore.cc @@ -49,7 +49,7 @@ KVStore* KVStore::Create(const char *type_name) { kv = new kvstore::KVStoreDist(use_device_comm); if (!has("_async") && kv->IsWorkerNode() && kv->get_rank() == 0) { // configure the server to be the sync mode - kv->SendCommandToServers(static_cast(kvstore::CommandType::kSyncMode), ""); + kv->SendCommandToServers(kvstore::kSyncMode, ""); } #else LOG(FATAL) << "compile with USE_DIST_KVSTORE=1 to use " << tname; diff --git a/src/kvstore/kvstore_dist.h b/src/kvstore/kvstore_dist.h index b00d0de935f7..571767db7ab3 100644 --- a/src/kvstore/kvstore_dist.h +++ b/src/kvstore/kvstore_dist.h @@ -69,7 +69,7 @@ class KVStoreDist : public KVStoreLocal { Barrier(); if (get_rank() == 0) { // stop the executor at servers - SendCommandToServers(static_cast(CommandType::kStopServer), ""); + SendCommandToServers(kStopServer, ""); } } ps::Finalize(barrier_before_exit_); @@ -86,15 +86,6 @@ class KVStoreDist : public KVStoreLocal { } } - void SetGradientCompression(const std::vector > - & kwargs) override { - KVStoreLocal::SetGradientCompression(kwargs); - if (get_rank() == 0) { - SendCommandToServers(static_cast(CommandType::kSetGradientCompression), - gradient_compression_->EncodeParams()); - } - } - void Barrier() override { ps::Postoffice::Get()->Barrier(ps::kWorkerGroup); } @@ -141,38 +132,6 @@ class KVStoreDist : public KVStoreLocal { } private: - /** - * \brief struct for ps keys and lens - */ - struct PSKV { - ps::SArray keys; // n keys - ps::SArray lens; // the length of the i-th value - int size; - }; - - struct ComprPSKV { - PSKV push; - PSKV pull; - }; - - /** - * \brief cache all key partitions - * - * `ps_kv_` is used for pushes and pulls without gradient compression - * `compr_ps_kv_` is used for gradient compression. It contains different - * pskv for push and pull because sizes would be different in both cases. - * Note: `ps_kv_[k]` for some key k may not be the same as `compr_ps_kv_[k].pull` - * This is because sharding may cause slightly different divisions when size is - * not perfectly divisible. - */ - std::unordered_map ps_kv_; - std::unordered_map compr_ps_kv_; - - /** - * \brief serialize access to ps_kv_ or push_ps_kv_/pull_ps_kv_ while encoding keys - */ - std::mutex mu_; - void InitImpl(const std::vector& keys, const std::vector& values) override { CheckUnique(keys); @@ -184,7 +143,6 @@ class KVStoreDist : public KVStoreLocal { // wait until the push is finished for (const int key : keys) { comm_buf_[key].WaitToWrite(); - compr_buf_[key].WaitToWrite(); } } else { // do nothing @@ -224,10 +182,7 @@ class KVStoreDist : public KVStoreLocal { RunContext rctx, Engine::CallbackOnComplete cb) { // convert to ps keys size_t size = recv_buf.shape().Size(); - - PSKV& pskv = (gradient_compression_->get_type() == CompressionType::kNone) ? - EncodeDefaultKey(key, size, false) : - EncodeCompressedKey(key, size, false); + PSKV& pskv = EncodeKey(key, size); #if MKL_EXPERIMENTAL == 1 mkl_set_tblob_eager_mode(recv_buf.data()); #endif @@ -235,11 +190,8 @@ class KVStoreDist : public KVStoreLocal { // false means not to delete data when SArray is deleted auto vals = new ps::SArray(data, size, false); // issue pull - int cmd = (gradient_compression_->get_type() != CompressionType::kNone) ? - static_cast(DataHandleType::kCompressedPushPull) : - static_cast(DataHandleType::kDefaultPushPull); CHECK_NOTNULL(ps_worker_)->ZPull( - pskv.keys, vals, &pskv.lens, cmd, [vals, cb](){ delete vals; cb(); }); + pskv.keys, vals, &pskv.lens, kDefaultPushPull, [vals, cb](){ delete vals; cb(); }); }; CHECK_NOTNULL(Engine::Get())->PushAsync( @@ -249,7 +201,7 @@ class KVStoreDist : public KVStoreLocal { {recv_buf.var()}, FnProperty::kNormal, priority, - PROFILER_MESSAGE("KVStoreDistDefaultStoragePull")); + PROFILER_MESSAGE("KVStoreDistDefaultPull")); comm_->Broadcast(key, recv_buf, grouped_vals[i], priority); } @@ -309,121 +261,103 @@ class KVStoreDist : public KVStoreLocal { GroupKVPairsPush(keys, values, &uniq_keys, &grouped_vals); for (size_t i = 0; i < uniq_keys.size(); ++i) { - // merge over devices + // merge over devcies int key = uniq_keys[i]; const auto& vals = grouped_vals[i]; NDArray merged = do_merge ? comm_->Reduce(key, vals, priority) : vals[0]; + auto& send_buf = comm_buf_[key]; const auto storage_type = merged.storage_type(); - auto &comm_buf = comm_buf_[key]; if (merged.ctx().dev_mask() == cpu::kDevMask) { // Start of a push doesn't guarantee that the previous pushes are completed. // This shouldn't affect training of networks though because training involves // a sequence of push, pull, then push. This imposes ordering that the // second push happens after the first pull, and the pull happens after first push. - comm_buf = merged; // avoid memory copy + send_buf = merged; // avoid memory copy } else { - if (comm_buf.is_none()) { + if (send_buf.is_none()) { if (storage_type == kDefaultStorage) { - comm_buf = NDArray(merged.shape(), pinned_ctx_, true, merged.dtype()); + send_buf = NDArray(merged.shape(), pinned_ctx_, true, merged.dtype()); } else { - comm_buf = NDArray(storage_type, merged.shape(), pinned_ctx_, true, merged.dtype()); + send_buf = NDArray(storage_type, merged.shape(), pinned_ctx_, true, merged.dtype()); } } - CopyFromTo(merged, &comm_buf); + CopyFromTo(merged, &send_buf); } // push to servers if (storage_type == kDefaultStorage) { - if (gradient_compression_->get_type() == CompressionType::kNone) { - PSKV& pskv = EncodeDefaultKey(key, comm_buf.shape().Size(), true); - PushDefault(key, comm_buf, pskv, priority); - } else { - // Note: gradient compression uses `do_merge` as proxy to - // detect whether the push is initialization of a key or not. - // is_active is false when push is initialization of key - bool is_active = do_merge; - PSKV &pskv = EncodeCompressedKey(key, comm_buf.shape().Size(), is_active); - // Returns push_pskv if active, else pull_pskv - // we want inactive gc to send uncompressed gradients, - // but sharded in the same way as later pushes would when gc becomes active - if (is_active) { - PushCompressed(key, comm_buf, pskv, priority); - } else { - PushDefault(key, comm_buf, pskv, priority); - } - } - } else if (storage_type == kRowSparseStorage) { - CHECK(gradient_compression_->get_type() == CompressionType::kNone) - << "Gradient compression for row sparse storage type is not supported"; - PushRowSparse(key, comm_buf, priority); - } else { - LOG(FATAL) << "unknown storage type"; - } - } - } - - void PushCompressed(int key, const NDArray& comm_buf, const PSKV& pskv, int priority) { - auto &small_buf = compr_buf_[key]; - auto &res_buf = residual_[key]; - size_t original_size = comm_buf.shape().Size(); - - // Init the small buffer and residual_ buffer for quantize - if (small_buf.is_none()) { - small_buf = NDArray(TShape{pskv.size}, comm_buf.ctx(), false, comm_buf.dtype()); - res_buf = NDArray(TShape{(int64_t) original_size}, comm_buf.ctx(), - false, comm_buf.dtype()); - res_buf = 0; - } - gradient_compression_->Quantize(comm_buf, &small_buf, &res_buf, priority); - auto push_to_servers = - [this, key, pskv, small_buf](RunContext rctx, Engine::CallbackOnComplete cb) { - size_t size = small_buf.shape().Size(); - real_t* data = small_buf.data().dptr(); -#if MKL_EXPERIMENTAL == 1 - mkl_set_tblob_eager_mode(small_buf.data()); -#endif - // do push. false means no delete - ps::SArray vals(data, size, false); - CHECK_NOTNULL(ps_worker_)->ZPush( - pskv.keys, vals, pskv.lens, - static_cast(DataHandleType::kCompressedPushPull), [cb]() { cb(); }); - }; - // acquire locks on both comm_buf and small_buf so that - // pull (which uses comm_buf) for the same key waits till push finishes - Engine::Get()->PushAsync( - push_to_servers, - pinned_ctx_, - {small_buf.var(), comm_buf.var()}, - {}, - FnProperty::kNormal, - priority, - PROFILER_MESSAGE("KVStoreDistCompressedPush")); - } - - void PushDefault(int key, const NDArray &send_buf, const PSKV& pskv, int priority) { - auto push_to_servers = - [this, key, pskv, send_buf](RunContext rctx, Engine::CallbackOnComplete cb) { + auto push_to_servers = + [this, key, send_buf](RunContext rctx, Engine::CallbackOnComplete cb) { // convert to ps keys size_t size = send_buf.shape().Size(); - real_t* data = send_buf.data().dptr(); + PSKV& pskv = EncodeKey(key, size); + #if MKL_EXPERIMENTAL == 1 mkl_set_tblob_eager_mode(send_buf.data()); #endif + real_t* data = send_buf.data().dptr(); // do push. false means no delete ps::SArray vals(data, size, false); CHECK_NOTNULL(ps_worker_)->ZPush( - pskv.keys, vals, pskv.lens, - static_cast(DataHandleType::kDefaultPushPull), [cb]() { cb(); }); + pskv.keys, vals, pskv.lens, 0, [cb]() { cb(); }); }; - Engine::Get()->PushAsync( - push_to_servers, + Engine::Get()->PushAsync( + push_to_servers, + pinned_ctx_, + {send_buf.var()}, + {}, + FnProperty::kNormal, + priority, + PROFILER_MESSAGE("KVStoreDistDefaultPush")); + } else if (storage_type == kRowSparseStorage) { + PushRowSparse(key, send_buf, priority); + } else { + LOG(FATAL) << "unknown storage type"; + } + } + } + + // pull row sparse weight into `recv_buf` based on indices given by `indices` + void PullRowSparse_(const int key, const NDArray& recv_buf, + const NDArray& indices, int priority) { + using namespace rowsparse; + auto pull_from_servers = [this, key, recv_buf, indices] + (RunContext rctx, Engine::CallbackOnComplete cb) { + // allocate memory for the buffer + size_t num_rows = indices.shape().Size(); + recv_buf.CheckAndAlloc({mshadow::Shape1(num_rows)}); +#if MKL_EXPERIMENTAL == 1 + mkl_set_tblob_eager_mode(recv_buf.data()); +#endif + real_t* data = recv_buf.data().dptr(); + const auto offsets = indices.data().dptr(); + const auto unit_len = recv_buf.shape().ProdShape(1, recv_buf.shape().ndim()); + const int64_t size = num_rows * unit_len; + // convert to ps keys in row sparse format + PSKV& pskv = EncodeRowSparseKey(key, size, num_rows, offsets, + unit_len, recv_buf.shape()[0]); + if (this->log_verbose_) { + LOG(INFO) << "worker " << get_rank() << " pull lens: " << pskv.lens << " keys: " + << pskv.keys << " size: " << size; + } + auto vals = new ps::SArray(data, size, false); + // copy indices to recv_buf. this needs to be done before ZPull + // because after pull is done, the callback function returns and locks are released. + // at this point, later functions may access the indices variable while copy happens + mshadow::Copy(recv_buf.aux_data(kIdx).FlatTo1D(), + indices.data().FlatTo1D()); + CHECK_NOTNULL(ps_worker_)->ZPull(pskv.keys, vals, &pskv.lens, kRowSparsePushPull, + [vals, cb]() { delete vals; cb(); }); + }; + CHECK_NOTNULL(Engine::Get())->PushAsync( + pull_from_servers, pinned_ctx_, - {send_buf.var()}, - {}, + {indices.var()}, + {recv_buf.var()}, FnProperty::kNormal, priority, - PROFILER_MESSAGE("KVStoreDistDefaultPush")); + PROFILER_MESSAGE("KVStoreDistRowSparsePull")); } // push row sparse gradient @@ -448,9 +382,9 @@ class KVStoreDist : public KVStoreLocal { << pskv.keys << " size: " << size; } ps::SArray vals(data, size, false); - CHECK_NOTNULL(ps_worker_)->ZPush(pskv.keys, vals, pskv.lens, - static_cast(DataHandleType::kRowSparsePushPull), - [cb]() { cb(); }); + CHECK_NOTNULL(ps_worker_)->ZPush(pskv.keys, vals, pskv.lens, kRowSparsePushPull, [cb]() { + cb(); + }); }; Engine::Get()->PushAsync( push_to_servers, @@ -462,50 +396,6 @@ class KVStoreDist : public KVStoreLocal { PROFILER_MESSAGE("KVStoreDistRowSparsePush")); } - - // pull row sparse weight into `recv_buf` based on indices given by `indices` - void PullRowSparse_(const int key, const NDArray& recv_buf, - const NDArray& indices, int priority) { - using namespace rowsparse; - auto pull_from_servers = [this, key, recv_buf, indices] - (RunContext rctx, Engine::CallbackOnComplete cb) { - // allocate memory for the buffer - size_t num_rows = indices.shape().Size(); - recv_buf.CheckAndAlloc({mshadow::Shape1(num_rows)}); -#if MKL_EXPERIMENTAL == 1 - mkl_set_tblob_eager_mode(recv_buf.data()); -#endif - real_t* data = recv_buf.data().dptr(); - const auto offsets = indices.data().dptr(); - const auto unit_len = recv_buf.shape().ProdShape(1, recv_buf.shape().ndim()); - const int64_t size = num_rows * unit_len; - // convert to ps keys in row sparse format - PSKV& pskv = EncodeRowSparseKey(key, size, num_rows, offsets, - unit_len, recv_buf.shape()[0]); - if (this->log_verbose_) { - LOG(INFO) << "worker " << get_rank() << " pull lens: " << pskv.lens << " keys: " - << pskv.keys << " size: " << size; - } - auto vals = new ps::SArray(data, size, false); - // copy indices to recv_buf. this needs to be done before ZPull - // because after pull is done, the callback function returns and locks are released. - // at this point, later functions may access the indices variable while copy happens - mshadow::Copy(recv_buf.aux_data(kIdx).FlatTo1D(), - indices.data().FlatTo1D()); - CHECK_NOTNULL(ps_worker_)->ZPull(pskv.keys, vals, &pskv.lens, - static_cast(DataHandleType::kRowSparsePushPull), - [vals, cb]() { delete vals; cb(); }); - }; - CHECK_NOTNULL(Engine::Get())->PushAsync( - pull_from_servers, - pinned_ctx_, - {indices.var()}, - {recv_buf.var()}, - FnProperty::kNormal, - priority, - PROFILER_MESSAGE("KVStoreDistRowSparsePull")); - } - /** * \brief check if the keys are all unique */ @@ -516,13 +406,33 @@ class KVStoreDist : public KVStoreLocal { static_cast(keys.size())); } + /** + * \brief struct for ps keys and lens + */ + struct PSKV { + ps::SArray keys; // n keys + ps::SArray lens; // the length of the i-th value + int size; + }; + + /** + * \brief cache all key partitions + */ + std::unordered_map ps_kv_; + + /** + * \brief serizelize EncodeRowSparseKey and EncodeKey + */ + std::mutex mu_; + /** * \brief convert to keys in ps */ - inline PSKV& EncodeDefaultKey(int key, size_t size, bool is_push) { + inline PSKV& EncodeKey(int key, size_t size) { mu_.lock(); PSKV& pskv = ps_kv_[key]; mu_.unlock(); + if (!pskv.keys.empty()) { CHECK_EQ(static_cast(pskv.size), size) << "The value size cannot be changed"; } else { @@ -544,8 +454,8 @@ class KVStoreDist : public KVStoreLocal { pskv.size = 0; for (int i = 0; i < num_servers; ++i) { size_t part_size = - static_cast(round(static_cast(size)/num_servers*(i+1))) - - static_cast(round(static_cast(size)/num_servers*i)); + static_cast(round(static_cast(size)/num_servers*(i+1))) - + static_cast(round(static_cast(size)/num_servers*i)); ps::Key ps_key = krs[i].begin() + key; CHECK_LT(ps_key, krs[i].end()); pskv.keys.push_back(ps_key); @@ -558,94 +468,6 @@ class KVStoreDist : public KVStoreLocal { return pskv; } - /** - * \brief Convert to keys in ps for compressed values - * Divides original array into equal parts for each server - * Populates both push and pull pskv on first call - */ - inline PSKV& EncodeCompressedKey(int key, size_t original_size, bool is_push) { - auto krs = ps::Postoffice::Get()->GetServerKeyRanges(); - int num_servers = krs.size(); - CHECK_GT(num_servers, 0); - - // represents size of data to be sent - size_t compr_size = gradient_compression_->GetCompressedSize(original_size); - - mu_.lock(); - PSKV& pskv = (is_push) ? compr_ps_kv_[key].push : compr_ps_kv_[key].pull; - mu_.unlock(); - - if (!pskv.keys.empty()) { - size_t size = (is_push) ? compr_size : original_size; - CHECK_EQ(static_cast(pskv.size), size)<< "The value size can't be changed"; - } else { - // populate both pull and push pskvs - // push pskv has sizes corresponding to compressed data - // pull pskv has decompressed sizes for parts in push_pskv - mu_.lock(); - PSKV& pull_pskv = compr_ps_kv_[key].pull; - PSKV& push_pskv = compr_ps_kv_[key].push; - mu_.unlock(); - - if (original_size < bigarray_bound_) { - // a simple heuristic for load balancing - // send it to a single random picked server - int server = (key * 9973) % num_servers; - ps::Key ps_key = krs[server].begin() + key; - CHECK_LT(ps_key, krs[server].end()); - // meta info - push_pskv.keys.push_back(krs[server].begin() + original_size); - push_pskv.lens.push_back(0); - // data - push_pskv.keys.push_back(ps_key); - pull_pskv.keys.push_back(ps_key); - push_pskv.lens.push_back(compr_size); - pull_pskv.lens.push_back(original_size); - push_pskv.size = compr_size; - pull_pskv.size = original_size; - } else { - // partition it to all servers - push_pskv.size = 0; - pull_pskv.size = 0; - - for (int i = 0; i < num_servers; ++i) { - size_t part_compr, part_orig; - if (i == num_servers-1) { - part_compr = compr_size - push_pskv.size; - part_orig = original_size - pull_pskv.size; - } else { - part_compr = - static_cast (round(static_cast(compr_size)/num_servers*(i+1))) - - static_cast (round(static_cast(compr_size)/num_servers*(i))); - part_orig = part_compr * gradient_compression_->GetCompressionFactor(); - } - - // meta info - ps::Key ps_key_dummy = krs[i].begin() + part_orig; - CHECK_LT(ps_key_dummy, krs[i].end()); - push_pskv.keys.push_back(ps_key_dummy); - push_pskv.lens.push_back(0); - - // data - ps::Key ps_key = krs[i].begin() + key; - CHECK_LT(ps_key, krs[i].end()); - push_pskv.keys.push_back(ps_key); - pull_pskv.keys.push_back(ps_key); - // push_pskv stores lengths of compressed blocks - push_pskv.lens.push_back(part_compr); - // pull_pskv stores lengths of original data - pull_pskv.lens.push_back(part_orig); - push_pskv.size += part_compr; - pull_pskv.size += part_orig; - } - CHECK_EQ(static_cast(push_pskv.size), compr_size); - CHECK_EQ(static_cast(pull_pskv.size), original_size); - CHECK_EQ(push_pskv.lens.size(), num_servers*2); - } - } - return pskv; - } - // Note: this encoding method for row sparse keys doesn't allow cross-layer batching inline PSKV& EncodeRowSparseKey(const int key, const int64_t size, const int64_t num_rows, const int64_t *offsets, const size_t unit_len, @@ -706,6 +528,7 @@ class KVStoreDist : public KVStoreLocal { return pskv; } + /** * \brief for worker to push and pull data */ @@ -718,23 +541,8 @@ class KVStoreDist : public KVStoreLocal { * \brief threshold for partition */ size_t bigarray_bound_; - /** - * \brief buffer for non-compressed data. - * When gradient compression is active, this is used - * for the data in pull and for original data in push - */ + /// \brief send & recver buffer std::unordered_map comm_buf_; - /** - * \brief buffer for compressed data - * Used when gradient compression is active and action - * is push - */ - std::unordered_map compr_buf_; - /** - * \brief residual buffer to accumulate quantization error - * during gradient compression - */ - std::unordered_map residual_; bool log_verbose_; }; diff --git a/src/kvstore/kvstore_dist_server.h b/src/kvstore/kvstore_dist_server.h index de94c8669abb..f2123e765f0d 100644 --- a/src/kvstore/kvstore_dist_server.h +++ b/src/kvstore/kvstore_dist_server.h @@ -40,13 +40,10 @@ namespace mxnet { namespace kvstore { -enum class CommandType { - kController, kStopServer, kSyncMode, kSetGradientCompression -}; - -enum class DataHandleType { - kDefaultPushPull, kCompressedPushPull, kRowSparsePushPull -}; +static const int kRowSparsePushPull = 1; +static const int kDefaultPushPull = 0; +static const int kStopServer = -1; +static const int kSyncMode = -2; /** * \brief executor runs a function using the thread called \ref Start @@ -120,7 +117,6 @@ class KVStoreDistServer { ps_server_->set_request_handle( std::bind(&KVStoreDistServer::DataHandleEx, this, _1, _2, _3)); sync_mode_ = false; - gradient_compression_ = std::make_shared(); log_verbose_ = dmlc::GetEnv("MXNET_KVSTORE_DIST_ROW_SPARSE_VERBOSE", false); } @@ -152,15 +148,11 @@ class KVStoreDistServer { }; void CommandHandle(const ps::SimpleData& recved, ps::SimpleApp* app) { - CommandType recved_type = static_cast(recved.head); - if (recved_type == CommandType::kStopServer) { + if (recved.head == kStopServer) { exec_.Stop(); - } else if (recved_type == CommandType::kSyncMode) { + } else if (recved.head == kSyncMode) { sync_mode_ = true; - } else if (recved_type == CommandType::kSetGradientCompression) { - gradient_compression_->DecodeParams(recved.body); } else { - // this uses value 0 for message id from frontend // let the main thread to execute ctrl, which is necessary for python exec_.Exec([this, recved]() { CHECK(controller_); @@ -173,11 +165,8 @@ class KVStoreDistServer { void DataHandleEx(const ps::KVMeta& req_meta, const ps::KVPairs& req_data, ps::KVServer* server) { - DataHandleType recved_type = static_cast(req_meta.cmd); - if (recved_type == DataHandleType::kRowSparsePushPull) { + if (req_meta.cmd == kRowSparsePushPull) { DataHandleRowSparse(req_meta, req_data, server); - } else if (recved_type == DataHandleType::kCompressedPushPull) { - DataHandleCompressed(req_meta, req_data, server); } else { DataHandleDefault(req_meta, req_data, server); } @@ -370,91 +359,10 @@ class KVStoreDistServer { } } - void DefaultStorageResponse(int key, const NDArray& stored, - const ps::KVMeta& req_meta, - const ps::KVPairs &req_data, - ps::KVServer* server) { - ps::KVPairs response; - CHECK(!stored.is_none()) << "init " << key << " first"; - auto len = stored.shape().Size(); - response.keys = req_data.keys; - response.lens = {len}; - // TODO(mli) try to remove this CopyFrom - response.vals.CopyFrom(static_cast(stored.data().dptr_), len); - server->Response(req_meta, response); - } - - void DataHandleCompressed(const ps::KVMeta& req_meta, - const ps::KVPairs &req_data, - ps::KVServer* server) { - if (req_meta.push) { - // there used several WaitToRead, this is because \a recved's memory - // could be deallocated when this function returns. so we need to make sure - // the operators with \a NDArray are actually finished - - // first for dummy key which represents original size of array, whose len is 0 - CHECK_EQ(req_data.keys.size(), (size_t)2); - CHECK_EQ(req_data.lens.size(), (size_t)2); - CHECK_EQ(req_data.vals.size(), (size_t)req_data.lens[1]); - - int original_size = DecodeKey(req_data.keys[0]); - int key = DecodeKey(req_data.keys[1]); - auto& stored = store_[key]; - - size_t ds[] = {(size_t)req_data.lens[1]}; - TShape dshape(ds, ds + 1); - TBlob recv_blob((real_t*) req_data.vals.data(), // NOLINT(*) - dshape, cpu::kDevMask); - NDArray recved = NDArray(recv_blob, 0); - - NDArray decomp_buf = decomp_buf_[key]; - dshape = TShape{(int64_t) original_size}; - - if (decomp_buf.is_none()) { - decomp_buf = NDArray(dshape, Context()); - } - - if (stored.is_none()) { - stored = NDArray(dshape, Context()); - gradient_compression_->Dequantize(recved, &stored, 0); - server->Response(req_meta); - stored.WaitToRead(); - } else if (sync_mode_) { - // synced push - auto& merged = merge_buf_[key]; - if (merged.array.is_none()) { - merged.array = NDArray(dshape, Context()); - } - if (merged.request.size() == 0) { - gradient_compression_->Dequantize(recved, &merged.array, 0); - } else { - gradient_compression_->Dequantize(recved, &decomp_buf, 0); - merged.array += decomp_buf; - } - merged.request.push_back(req_meta); - ApplyUpdates(key, &merged, &stored, server); - } else { - // async push - gradient_compression_->Dequantize(recved, &decomp_buf, 0); - exec_.Exec([this, key, &decomp_buf, &stored]() { - CHECK(updater_); - updater_(key, decomp_buf, &stored); - }); - server->Response(req_meta); - stored.WaitToRead(); - } - } else { // pull - CHECK_EQ(req_data.keys.size(), (size_t)1); - CHECK_EQ(req_data.lens.size(), (size_t)0); - int key = DecodeKey(req_data.keys[0]); - DefaultStorageResponse(key, store_[key], req_meta, req_data, server); - } - } - void DataHandleDefault(const ps::KVMeta& req_meta, const ps::KVPairs &req_data, ps::KVServer* server) { - CHECK_EQ(req_meta.cmd, static_cast(DataHandleType::kDefaultPushPull)); + CHECK_EQ(req_meta.cmd, kDefaultPushPull); // do some check CHECK_EQ(req_data.keys.size(), (size_t)1); if (req_meta.push) { @@ -503,7 +411,15 @@ class KVStoreDistServer { stored.WaitToRead(); } } else { - DefaultStorageResponse(key, stored, req_meta, req_data, server); + // pull + ps::KVPairs response; + CHECK(!stored.is_none()) << "init " << key << " first"; + auto len = stored.shape().Size(); + response.keys = req_data.keys; + response.lens = {len}; + // TODO(mli) try to remove this CopyFrom + response.vals.CopyFrom(static_cast(stored.data().dptr_), len); + server->Response(req_meta, response); } } @@ -512,44 +428,21 @@ class KVStoreDistServer { return key - kr.begin(); } - /** - * \brief user defined mode for push + * \brief user defined */ bool sync_mode_; KVStore::Controller controller_; KVStore::Updater updater_; - /** - * \brief store_ contains the value at kvstore for each key - */ std::unordered_map store_; - - /** - * \brief merge_buf_ is a buffer used if sync_mode is true. It represents - * values from different workers being merged. The store will be updated - * to this value when values from all workers are pushed into this buffer. - */ std::unordered_map merge_buf_; - /** - * \brief decomp_buf_ is a buffer into which compressed values are - * decompressed before merging to the store. used when compress_!='none' - */ - std::unordered_map decomp_buf_; - Executor exec_; ps::KVServer* ps_server_; // whether to LOG verbose information bool log_verbose_; - - /** - * \brief gradient compression object. - * starts with none, used after SetGradientCompression sets the type - * currently there is no support for unsetting gradient compression - */ - std::shared_ptr gradient_compression_; }; } // namespace kvstore diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h index 9fe161c4b0ee..1a4ced8a4f58 100644 --- a/src/kvstore/kvstore_local.h +++ b/src/kvstore/kvstore_local.h @@ -59,7 +59,6 @@ class KVStoreLocal : public KVStore { comm_ = new CommCPU(); } pinned_ctx_ = comm_->pinned_ctx(); - gradient_compression_ = std::make_shared(); } virtual ~KVStoreLocal() { @@ -137,11 +136,6 @@ class KVStoreLocal : public KVStore { PullRowSparseImpl(keys, val_rowids, priority); } - void SetGradientCompression(const std::vector > - & kwargs) override { - gradient_compression_->SetParams(kwargs); - } - private: virtual void InitImpl(const std::vector& keys, const std::vector& values) { @@ -151,7 +145,6 @@ class KVStoreLocal : public KVStore { local_[keys[i]] = values[i].Copy(pinned_ctx_); comm_->Init(keys[i], values[i].storage_type(), values[i].shape(), values[i].dtype()); } - comm_->SetGradientCompression(gradient_compression_); } virtual void PushImpl(const std::vector& keys, diff --git a/tests/nightly/dist_sync_kvstore.py b/tests/nightly/dist_sync_kvstore.py index df85fe586054..900d6bb6afb7 100644 --- a/tests/nightly/dist_sync_kvstore.py +++ b/tests/nightly/dist_sync_kvstore.py @@ -23,8 +23,7 @@ import mxnet as mx import numpy as np import numpy.random as rnd -from mxnet.test_utils import assert_almost_equal -from test_kvstore import compute_expected_2bit_quantization +import time def check_diff_to_scalar(A, x, rank=None): """ assert A == x""" @@ -40,7 +39,6 @@ def check_diff_to_scalar(A, x, rank=None): rate = 2 shape = (2, 3) -irregular_shape = (1211,1211) big_shape = (1200, 1200) # bigger than MXNET_KVSTORE_BIGARRAY_BOUND kv = mx.kv.create('dist_sync') @@ -59,17 +57,6 @@ def init_kv(): kv.set_optimizer(mx.optimizer.create('test', rescale_grad=rate)) return kv, my_rank, nworker -def init_kv_compressed(kv): - threshold = 0.5 - kv.set_gradient_compression({'type': '2bit', 'threshold':threshold}) - # init kv compression keys - kv.init('11221', mx.nd.zeros(big_shape)) - kv.init('112221', mx.nd.zeros(irregular_shape)) - kv.init('1121', mx.nd.zeros(shape)) - # to test inactive mode - kv.init('1122', mx.nd.ones(shape)) - return kv, threshold - def test_sync_push_pull(): kv, my_rank, nworker = init_kv() def check_default_keys(kv, my_rank, nworker): @@ -186,114 +173,11 @@ def check_big_row_sparse_keys(kv, my_rank, nworker): expected[row] = updated_val[row] check_diff_to_scalar(val, expected, rank=my_rank) - def check_compr_residual(kv, threshold, nworker): - for k,s in [('1121', shape),('112221',irregular_shape),('11221', big_shape)]: - # doesn't meet threshold - kv.push(k, mx.nd.ones(s)*0.4) - val=mx.nd.zeros(s) - kv.pull(k,val) - check_diff_to_scalar(val, 0) - - # just meets threshold with residual - kv.push(k, mx.nd.ones(s)*(threshold - 0.4)) - val2 = mx.nd.zeros(s) - kv.pull(k,val2) - curval = threshold * rate * nworker - check_diff_to_scalar(val2, curval) - - # doesn't meet threshold - kv.push(k, mx.nd.ones(s)*0.2) - val3= mx.nd.zeros(s) - kv.pull(k, val3) - check_diff_to_scalar(val3, curval) - - # exceeds again - kv.push(k, mx.nd.ones(s)*(threshold-0.2)) - val4 = mx.nd.zeros(s) - kv.pull(k,val4) - curval += threshold*rate*nworker - check_diff_to_scalar(val4, curval) - # residual is 0 now - - def check_compr_ones(kv, threshold, nworker): - for k,s in [('1121', shape),('112221',irregular_shape),('11221', big_shape)]: - val = mx.nd.zeros(s) - kv.pull(k, val) - curval = val[0][0].asnumpy()[0] - kv.push(k,mx.nd.ones(s)*threshold) - val2 = mx.nd.zeros(s) - kv.pull(k, val2) - newval = curval + rate*nworker*threshold - check_diff_to_scalar(val2, newval) - # residual = 0 again - - def check_compr_pull_before_push(kv): - for k,s in [('1121', shape),('112221',irregular_shape), - ('11221', big_shape), ('1122',shape)]: - if k=='1122': - # tests that GC is not used for init of a key - val = mx.nd.zeros(s) - kv.pull(k, val) - check_diff_to_scalar(val, 1) - else: - val = mx.nd.ones(s) - kv.pull(k, val) - check_diff_to_scalar(val, 0) - - def check_compr_zero(kv): - for k,s in [('1121', shape),('112221',irregular_shape),('11221', big_shape)]: - kv.push(k, mx.nd.zeros(s)) - # to check that all are set to 0s - val = mx.nd.ones(s) - kv.pull(k, val) - check_diff_to_scalar(val, 0) - - def check_compr_random(kv, threshold, nworker): - # set a seed so all workers generate same data. knowing this helps - # calculate expected value after pull - mx.random.seed(123) - rnd.seed(123) - nrepeat = 5 - compr_random_keys_shapes = [('2121', shape),('212221',irregular_shape),('21221', big_shape)] - # use new keys so residual is 0 for calculation of expected - for k,s in compr_random_keys_shapes: - kv.init(k, mx.nd.zeros(s)) - for k,s in compr_random_keys_shapes: - curr_residual = np.zeros(s) - for l in range(nrepeat): - orig_val = mx.nd.zeros(s) - kv.pull(k, orig_val) - - grad = mx.nd.array(rnd.rand(s[0], s[1])) - # creates a copy because push changes grad because of assignment - grad_cpy = mx.nd.array(grad) - kv.push(k, grad) - val = mx.nd.zeros(s) - kv.pull(k, val) - - diff = val - orig_val - - # compute expected by using simulation of operator - compr, curr_residual, decompr = compute_expected_2bit_quantization(grad_cpy, curr_residual, threshold) - decompr *= nworker * rate - assert_almost_equal(diff.asnumpy(), decompr) - - print ('worker '+str(my_rank)+' started with non compression tests') check_default_keys(kv, my_rank, nworker) check_row_sparse_keys(kv, my_rank, nworker) check_row_sparse_keys_with_zeros(kv, my_rank, nworker) check_big_row_sparse_keys(kv, my_rank, nworker) - print('worker ' + str(my_rank) + ' is done with non compression tests') - - # don't run non compressed keys after this as kvstore now is set to compressed - print ('worker '+str(my_rank)+' started with compression tests') - kv, threshold = init_kv_compressed(kv) - check_compr_pull_before_push(kv) - check_compr_zero(kv) - check_compr_residual(kv, threshold, nworker) - check_compr_ones(kv, threshold, nworker) - check_compr_random(kv, threshold, nworker) - print('worker ' + str(my_rank) + ' is done with compression tests') + print('worker ' + str(my_rank) + ' is done') def test_sync_init(): def check_init(kv, cur_keys, cur_shape, device=False): diff --git a/tests/nightly/test_kvstore.py b/tests/nightly/test_kvstore.py index a14feac7a3aa..081bc9c5a456 100644 --- a/tests/nightly/test_kvstore.py +++ b/tests/nightly/test_kvstore.py @@ -21,59 +21,17 @@ sys.path.insert(0, "../../python/") import mxnet as mx import numpy as np -import numpy.random as rnd -import copy -from mxnet.test_utils import assert_almost_equal +keys = [3, 5, 7] +# let the last shape exceed MXNET_KVSTORE_BIGARRAY_BOUND +shapes = [(4, 4), (100, 100), (2000, 2000)]; -def check_diff_to_scalar(A, x, rank=None): - """ assert A == x""" - assert(np.sum(np.abs((A - x).asnumpy())) == 0), (rank, A.asnumpy(), x) +lr = .1 +nworker = 4 +nrepeat = 10 -def compute_expected_2bit_quantization(arr, curr_residual, threshold): - from struct import pack,unpack - def bits2int(bits): - bits = [int(x) for x in bits[::-1]] - x = 0 - for i in range(len(bits)): - x += bits[i]*2**i - return x - - def as_float32(s): - return unpack("f",pack("I", bits2int(s)))[0] - - # str_quant stores the quantized representation as a sequence of bits - str_quant = '' - new_residual = [] - decompr = [] - - arr_npy = arr.asnumpy() - for i, a in np.ndenumerate(arr_npy): - a += curr_residual[i] - if a >= threshold: - str_quant += '11' - new_residual.append(a - threshold) - decompr.append(threshold) - elif a <= (-1*threshold): - str_quant += '10' - new_residual.append(a + threshold) - decompr.append(-1*threshold) - else: - str_quant += '00' - new_residual.append(a) - decompr.append(0) - # append extra bits when size of array not a factor of 16 - if len(str_quant)%16 != 0: - str_quant += '0'*(16 - len(str_quant)%16) - - compr = [] - # converts the string generated into integers 32chars at a time - i = 0 - while i