diff --git a/example/image-classification/common/fit.py b/example/image-classification/common/fit.py index 51a1abec7c48..2b002c770266 100755 --- a/example/image-classification/common/fit.py +++ b/example/image-classification/common/fit.py @@ -103,6 +103,11 @@ def add_fit_args(parser): help='1 means test reading speed without training') train.add_argument('--dtype', type=str, default='float32', help='precision: float32 or float16') + train.add_argument('--gc-type', type=str, default='none', + help='type of gradient compression to use, \ + takes `2bit` or `none` for now') + train.add_argument('--gc-threshold', type=float, default=0.5, + help='threshold for 2bit gradient compression') return train def fit(args, network, data_loader, **kwargs): @@ -114,6 +119,9 @@ def fit(args, network, data_loader, **kwargs): """ # kvstore kv = mx.kvstore.create(args.kv_store) + if args.gc_type != 'none': + kv.set_gradient_compression({'type': args.gc_type, + 'threshold': args.gc_threshold}) # logging head = '%(asctime)-15s Node[' + str(kv.rank) + '] %(message)s' @@ -162,10 +170,10 @@ def fit(args, network, data_loader, **kwargs): lr_scheduler = lr_scheduler optimizer_params = { - 'learning_rate': lr, - 'wd' : args.wd, - 'lr_scheduler': lr_scheduler, - 'multi_precision': True} + 'learning_rate': lr, + 'wd' : args.wd, + 'lr_scheduler': lr_scheduler, + 'multi_precision': True} # Only a limited number of optimizers have 'momentum' property has_momentum = {'sgd', 'dcasgd', 'nag'} @@ -195,17 +203,17 @@ def fit(args, network, data_loader, **kwargs): # run model.fit(train, - begin_epoch = args.load_epoch if args.load_epoch else 0, - num_epoch = args.num_epochs, - eval_data = val, - eval_metric = eval_metrics, - kvstore = kv, - optimizer = args.optimizer, - optimizer_params = optimizer_params, - initializer = initializer, - arg_params = arg_params, - aux_params = aux_params, - batch_end_callback = batch_end_callbacks, - epoch_end_callback = checkpoint, - allow_missing = True, - monitor = monitor) + begin_epoch = args.load_epoch if args.load_epoch else 0, + num_epoch = args.num_epochs, + eval_data = val, + eval_metric = eval_metrics, + kvstore = kv, + optimizer = args.optimizer, + optimizer_params = optimizer_params, + initializer = initializer, + arg_params = arg_params, + aux_params = aux_params, + batch_end_callback = batch_end_callbacks, + epoch_end_callback = checkpoint, + allow_missing = True, + monitor = monitor) diff --git a/example/rnn/lstm_bucketing.py b/example/rnn/lstm_bucketing.py index 2e7bc65d437a..0e7f064f0078 100644 --- a/example/rnn/lstm_bucketing.py +++ b/example/rnn/lstm_bucketing.py @@ -48,7 +48,6 @@ parser.add_argument('--disp-batches', type=int, default=50, help='show progress for every n batches') - def tokenize_text(fname, vocab=None, invalid_label=-1, start_label=0): if not os.path.isfile(fname): raise IOError("Please use get_ptb_data.sh to download requied file (data/ptb.train.txt)") diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 6fa7afdb0690..3486367278f7 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -1549,6 +1549,19 @@ MXNET_DLL int MXInitPSEnv(mx_uint num_vars, */ MXNET_DLL int MXKVStoreCreate(const char *type, KVStoreHandle *out); + +/*! 
+ * \brief Set parameters to use low-bit compressed gradients + * \param handle handle to the kvstore + * \param keys keys for compression parameters + * \param vals values for compression parameters + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXKVStoreSetGradientCompression(KVStoreHandle handle, + mx_uint num_params, + const char** keys, + const char** vals); + /*! * \brief Delete a KVStore handle. * \param handle handle to the kvstore diff --git a/include/mxnet/kvstore.h b/include/mxnet/kvstore.h index ddaa207daba2..6957876b6c42 100644 --- a/include/mxnet/kvstore.h +++ b/include/mxnet/kvstore.h @@ -30,6 +30,7 @@ #include #include #include +#include "../../src/kvstore/gradient_compression.h" #include "./ndarray.h" #if MXNET_USE_DIST_KVSTORE #include "ps/ps.h" @@ -64,6 +65,14 @@ class KVStore { */ inline const std::string& type() { return type_; } + /** + * \brief Set parameters to use low-bit compressed gradients + * \param compression_type type of compression + * \param threshold threshold for 2bit compression + */ + virtual void SetGradientCompression(const std::vector > + & kwargs) = 0; + /*! * \brief Initialize a list of key-value pair to the store. * @@ -387,6 +396,12 @@ class KVStore { */ std::string type_; + /** \brief Gradient compression object starts with GC_NONE mode + * Used if SetGradientCompression sets the type. + * Currently there is no support for un-setting gradient compression + */ + std::shared_ptr gradient_compression_; + /** * \brief whether to do barrier when finalize */ diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py index 115d1ff09ce5..f3a14609587f 100644 --- a/python/mxnet/gluon/trainer.py +++ b/python/mxnet/gluon/trainer.py @@ -44,6 +44,11 @@ class Trainer(object): kvstore : str or KVStore kvstore type for multi-gpu and distributed training. See help on :any:`mxnet.kvstore.create` for more information. + compression_params : dict + Specifies type of gradient compression and additional arguments depending + on the type of compression being used. For example, 2bit compression requires a threshold. + Arguments would then be {'type':'2bit', 'threshold':0.5} + See mxnet.KVStore.set_gradient_compression method for more details on gradient compression. Properties ---------- @@ -51,7 +56,8 @@ class Trainer(object): The current learning rate of the optimizer. Given an Optimizer object optimizer, its learning rate can be accessed as optimizer.learning_rate. 
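    A minimal usage sketch, assuming an already-initialized Gluon network `net`
    and a distributed training job; identifiers below are illustrative only:

        from mxnet import gluon
        # compression_params is forwarded to kvstore.set_gradient_compression()
        # when the Trainer creates its kvstore
        trainer = gluon.Trainer(net.collect_params(), 'sgd',
                                {'learning_rate': 0.1},
                                kvstore='dist_sync',
                                compression_params={'type': '2bit', 'threshold': 0.5})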
""" - def __init__(self, params, optimizer, optimizer_params=None, kvstore='device'): + def __init__(self, params, optimizer, optimizer_params=None, kvstore='device', + compression_params=None): if isinstance(params, (dict, ParameterDict)): params = list(params.values()) if not isinstance(params, (list, tuple)): @@ -65,7 +71,7 @@ def __init__(self, params, optimizer, optimizer_params=None, kvstore='device'): "First argument must be a list or dict of Parameters, " \ "got list of %s."%(type(param))) self._params.append(param) - + self._compression_params = compression_params optimizer_params = optimizer_params if optimizer_params else {} self._scale = optimizer_params.get('rescale_grad', 1.0) self._contexts = self._check_contexts() @@ -104,6 +110,8 @@ def _init_kvstore(self): kvstore, update_on_kvstore = _create_kvstore(self._kvstore, len(self._contexts), arg_arrays) if kvstore: + if self._compression_params: + kvstore.set_gradient_compression(self._compression_params) if 'dist' in kvstore.type: update_on_kvstore = False for i, param in enumerate(self._params): diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py index 8625303ee40e..bf424559df8d 100644 --- a/python/mxnet/kvstore.py +++ b/python/mxnet/kvstore.py @@ -64,6 +64,16 @@ def _ctype_key_value(keys, vals): else c_array_buf(ctypes.c_int, array('i', [keys] * len(vals))) return (c_keys, c_handle_array(vals), use_str_keys) +def _ctype_dict(param_dict): + """ + Returns ctype arrays for keys and values(converted to strings) in a dictionary + """ + assert(isinstance(param_dict, dict)), \ + "unexpected type for param_dict: " + str(type(param_dict)) + c_keys = c_array(ctypes.c_char_p, [c_str(k) for k in param_dict.keys()]) + c_vals = c_array(ctypes.c_char_p, [c_str(str(v)) for v in param_dict.values()]) + return (c_keys, c_vals) + def _updater_wrapper(updater): """A wrapper for the user-defined handle.""" def updater_handle(key, lhs_handle, rhs_handle, _): @@ -350,6 +360,58 @@ def row_sparse_pull(self, key, out=None, priority=0, row_ids=None): check_call(_LIB.MXKVStorePullRowSparse( self.handle, mx_uint(len(ckeys)), ckeys, cvals, crow_ids, ctypes.c_int(priority))) + def set_gradient_compression(self, compression_params): + """ Specifies type of low-bit quantization for gradient compression \ + and additional arguments depending on the type of compression being used. + + 2bit Gradient Compression takes a positive float `threshold`. + The technique works by thresholding values such that positive values in the + gradient above threshold will be set to threshold. Negative values whose absolute + values are higher than threshold, will be set to the negative of threshold. + Values whose absolute values are less than threshold will be set to 0. + By doing so, each value in the gradient is in one of three states. 2bits are + used to represent these states, and every 16 float values in the original + gradient can be represented using one float. This compressed representation + can reduce communication costs. The difference between these thresholded values and + original values is stored at the sender's end as residual and added to the + gradient in the next iteration. + + When kvstore is 'local', gradient compression is used to reduce communication + between multiple devices (gpus). Gradient is quantized on each GPU which + computed the gradients, then sent to the GPU which merges the gradients. This + receiving GPU dequantizes the gradients and merges them. 
Note that this
+        increases memory usage on each GPU because of the residual array stored.
+
+        When kvstore is 'dist', gradient compression is used to reduce communication
+        from worker to server. Gradient is quantized on each worker which
+        computed the gradients, then sent to the server which dequantizes
+        this data and merges the gradients from each worker. Note that this
+        increases CPU memory usage on each worker because of the residual array stored.
+        Only worker to server communication is compressed in this setting.
+        If each machine has multiple GPUs, this GPU to GPU or GPU to CPU communication
+        is currently not compressed. Server to worker communication (in the case of pull)
+        is also not compressed.
+
+        To use 2bit compression, we need to specify `type` as `2bit`.
+        Specifying only `type` uses the default value for the threshold.
+        To completely specify the arguments for 2bit compression, we would need to pass
+        a dictionary which includes `threshold`, like:
+        {'type': '2bit', 'threshold': 0.5}
+
+        Parameters
+        ----------
+        compression_params : dict
+            A dictionary specifying the type and parameters for gradient compression.
+            The key `type` in this dictionary is a
+            required string argument and specifies the type of gradient compression.
+            Currently `type` can only be `2bit`.
+            Other keys in this dictionary are optional and specific to the type
+            of gradient compression.
+        """
+        ckeys, cvals = _ctype_dict(compression_params)
+        check_call(_LIB.MXKVStoreSetGradientCompression(self.handle,
+                                                        mx_uint(len(compression_params)),
+                                                        ckeys, cvals))

     def set_optimizer(self, optimizer):
         """ Registers an optimizer with the kvstore.
diff --git a/python/mxnet/module/bucketing_module.py b/python/mxnet/module/bucketing_module.py
index dd6cafb277f0..4a5330ea2c5a 100644
--- a/python/mxnet/module/bucketing_module.py
+++ b/python/mxnet/module/bucketing_module.py
@@ -54,10 +54,16 @@ class BucketingModule(BaseModule):
         Instead they are initialized to 0 and can be set by set_states()
     group2ctxs : list of dict of str to context
         Default is `None`. Mapping the `ctx_group` attribute to the context assignment.
+    compression_params : dict
+        Specifies type of gradient compression and additional arguments depending
+        on the type of compression being used. For example, 2bit compression requires a threshold.
+        Arguments would then be {'type':'2bit', 'threshold':0.5}
+        See mxnet.KVStore.set_gradient_compression method for more details on gradient compression.
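        As a concrete illustration of 2bit quantization with threshold 0.5, a gradient
        [0.7, -0.9, 0.2] is communicated as [0.5, -0.5, 0] and the residual kept at the
        sender becomes [0.2, -0.4, 0.2]. A minimal sketch of enabling it, assuming
        `symbol` is an existing network symbol; identifiers are illustrative:

            kv = mx.kv.create('dist_sync')
            kv.set_gradient_compression({'type': '2bit', 'threshold': 0.5})
            # equivalently, a Module can pass compression_params and will call
            # set_gradient_compression on its kvstore during init_optimizer
            mod = mx.mod.Module(symbol, context=[mx.gpu(0), mx.gpu(1)],
                                compression_params={'type': '2bit', 'threshold': 0.5})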
""" def __init__(self, sym_gen, default_bucket_key=None, logger=logging, context=ctx.cpu(), work_load_list=None, - fixed_param_names=None, state_names=None, group2ctxs=None): + fixed_param_names=None, state_names=None, group2ctxs=None, + compression_params=None): super(BucketingModule, self).__init__(logger=logger) assert default_bucket_key is not None @@ -75,6 +81,7 @@ def __init__(self, sym_gen, default_bucket_key=None, logger=logging, _check_input_names(symbol, state_names, "state", True) _check_input_names(symbol, fixed_param_names, "fixed_param", True) + self._compression_params = compression_params self._fixed_param_names = fixed_param_names self._state_names = state_names self._context = context @@ -322,7 +329,9 @@ def bind(self, data_shapes, label_shapes=None, for_training=True, module = Module(symbol, data_names, label_names, logger=self.logger, context=self._context, work_load_list=self._work_load_list, fixed_param_names=self._fixed_param_names, - state_names=self._state_names, group2ctxs=self._group2ctxs) + state_names=self._state_names, + group2ctxs=self._group2ctxs, + compression_params=self._compression_params) module.bind(data_shapes, label_shapes, for_training, inputs_need_grad, force_rebind=False, shared_module=None, grad_req=grad_req) self._curr_module = module @@ -352,7 +361,9 @@ def switch_bucket(self, bucket_key, data_shapes, label_shapes=None): logger=self.logger, context=self._context, work_load_list=self._work_load_list, fixed_param_names=self._fixed_param_names, - state_names=self._state_names, group2ctxs=self._group2ctxs) + state_names=self._state_names, + group2ctxs=self._group2ctxs, + compression_params=self._compression_params) module.bind(data_shapes, label_shapes, self._curr_module.for_training, self._curr_module.inputs_need_grad, force_rebind=False, shared_module=self._buckets[self._default_bucket_key]) diff --git a/python/mxnet/module/module.py b/python/mxnet/module/module.py index 8301330313ae..a9c6516a32ed 100644 --- a/python/mxnet/module/module.py +++ b/python/mxnet/module/module.py @@ -61,10 +61,16 @@ class Module(BaseModule): Instead they are initialized to 0 and can be set by `set_states()`. group2ctxs : list of dict of str to context Default is `None`. Mapping the `ctx_group` attribute to the context assignment. + compression_params : dict + Specifies type of gradient compression and additional arguments depending + on the type of compression being used. For example, 2bit compression requires a threshold. + Arguments would then be {'type':'2bit', 'threshold':0.5} + See mxnet.KVStore.set_gradient_compression method for more details on gradient compression. 
""" def __init__(self, symbol, data_names=('data',), label_names=('softmax_label',), logger=logging, context=ctx.cpu(), work_load_list=None, - fixed_param_names=None, state_names=None, group2ctxs=None): + fixed_param_names=None, state_names=None, group2ctxs=None, + compression_params=None): super(Module, self).__init__(logger=logger) if isinstance(context, ctx.Context): @@ -103,6 +109,7 @@ def __init__(self, symbol, data_names=('data',), label_names=('softmax_label',), self._aux_params = None self._params_dirty = False + self._compression_params = compression_params self._optimizer = None self._kvstore = None self._update_on_kvstore = None @@ -525,6 +532,8 @@ def init_optimizer(self, kvstore='local', optimizer='sgd', self._updater = None if kvstore: + if self._compression_params: + kvstore.set_gradient_compression(self._compression_params) # copy initialized local parameters to kvstore _initialize_kvstore(kvstore=kvstore, param_arrays=self._exec_group.param_arrays, diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index da759fe2f49f..87edd1863214 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -747,6 +747,20 @@ int MXKVStoreCreate(const char *type, API_END(); } +int MXKVStoreSetGradientCompression(KVStoreHandle handle, mx_uint num_params, + const char** keys, const char** vals) { + API_BEGIN(); + std::vector > params; + for (mx_uint i = 0; i < num_params; ++i) { + std::pair p; + p.first = keys[i]; + p.second = vals[i]; + params.push_back(p); + } + static_cast(handle)->SetGradientCompression(params); + API_END(); +} + int MXKVStoreFree(KVStoreHandle handle) { API_BEGIN(); delete static_cast(handle); diff --git a/src/kvstore/comm.h b/src/kvstore/comm.h index 028ab5992c76..ca261a96ef3d 100644 --- a/src/kvstore/comm.h +++ b/src/kvstore/comm.h @@ -30,6 +30,7 @@ #include #include #include "mxnet/ndarray.h" +#include "gradient_compression.h" #include "../ndarray/ndarray_function.h" #include "../operator/tensor/sparse_retain-inl.h" namespace mxnet { @@ -79,8 +80,18 @@ class Comm { return pinned_ctx_; } + /** + * \brief Sets gradient compression parameters to be able to + * perform reduce with compressed gradients + */ + void SetGradientCompression(std::shared_ptr gc) { + gc_ = gc; + } + protected: Context pinned_ctx_; + + std::shared_ptr gc_; }; /** @@ -484,14 +495,7 @@ class CommDevice : public Comm { } } - const NDArray& Reduce(int key, const std::vector& src, - int priority) override { - // avoid extra copy for single device, but it may bring problems for - // abnormal usage of kvstore - if (src.size() == 1) { - return src[0]; - } - + void InitBuffersAndComm(const std::vector& src) { if (!inited_) { std::vector devs; for (const auto& a : src) { @@ -502,7 +506,23 @@ class CommDevice : public Comm { EnableP2P(devs); } } + } + + const NDArray& Reduce(int key, const std::vector& src, + int priority) override { + // when this reduce is called from kvstore_dist, gc is not set + // we don't do compression twice in dist_sync_device + if ((gc_ != nullptr) && (gc_->get_type() != CompressionType::kNone)) { + return ReduceCompressed(key, src, priority); + } + + // avoid extra copy for single device, but it may bring problems for + // abnormal usage of kvstore + if (src.size() == 1) { + return src[0]; + } + InitBuffersAndComm(src); auto& buf = merge_buf_[key]; std::vector reduce(src.size()); CopyFromTo(src[0], &(buf.merged), priority); @@ -525,7 +545,52 @@ class CommDevice : public Comm { } ElementwiseSum(reduce, &buf.merged); + return buf.merged; + } + + const NDArray& ReduceCompressed(int 
key, const std::vector& src, + int priority) { + InitBuffersAndComm(src); + auto& buf = merge_buf_[key]; + std::vector reduce(src.size()); + if (buf.copy_buf.empty()) { + // one buf for each context + buf.copy_buf.resize(src.size()); + buf.compressed_recv_buf.resize(src.size()); + buf.compressed_send_buf.resize(src.size()); + buf.residual.resize(src.size()); + for (size_t i = 0; i < src.size(); ++i) { + buf.copy_buf[i] = NDArray(buf.merged.shape(), buf.merged.ctx(), + false, buf.merged.dtype()); + buf.residual[i] = NDArray(buf.merged.shape(), src[i].ctx(), + false, buf.merged.dtype()); + buf.residual[i] = 0; + int64_t small_size = gc_->GetCompressedSize(buf.merged.shape().Size()); + buf.compressed_recv_buf[i] = NDArray(TShape{small_size}, buf.merged.ctx(), + false, buf.merged.dtype()); + buf.compressed_send_buf[i] = NDArray(TShape{small_size}, src[i].ctx(), + false, buf.merged.dtype()); + } + } + + for (size_t i = 0; i < src.size(); ++i) { + // compress before copy + // this is done even if the data is on same context as copy_buf because + // we don't want the training to be biased towards data on this GPU + gc_->Quantize(src[i], &(buf.compressed_send_buf[i]), &(buf.residual[i]), priority); + + if (buf.compressed_send_buf[i].ctx() != buf.compressed_recv_buf[i].ctx()) { + CopyFromTo(buf.compressed_send_buf[i], &(buf.compressed_recv_buf[i]), priority); + } else { + // avoid memory copy when they are on same context + buf.compressed_recv_buf[i] = buf.compressed_send_buf[i]; + } + + gc_->Dequantize(buf.compressed_recv_buf[i], &(buf.copy_buf[i]), priority); + reduce[i] = buf.copy_buf[i]; + } + ElementwiseSum(reduce, &buf.merged); return buf.merged; } @@ -638,6 +703,12 @@ class CommDevice : public Comm { NDArray merged; /// \brief the gpu buffer std::vector copy_buf; + /// \brief the residual buffer for gradient compression + std::vector residual; + /// \brief the small buffer for compressed data in sender + std::vector compressed_send_buf; + /// \brief the small buffer for compressed data in receiver + std::vector compressed_recv_buf; }; std::unordered_map merge_buf_; bool inited_; diff --git a/src/kvstore/gradient_compression-inl.h b/src/kvstore/gradient_compression-inl.h new file mode 100644 index 000000000000..9b69bd11472c --- /dev/null +++ b/src/kvstore/gradient_compression-inl.h @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file gradient_compression-inl.h
+ * \author Rahul Huilgol
+ * \brief Declares and defines functions used to quantize and dequantize data
+ */
+#ifndef MXNET_KVSTORE_GRADIENT_COMPRESSION_INL_H_
+#define MXNET_KVSTORE_GRADIENT_COMPRESSION_INL_H_
+
+#include <vector>
+#include "../operator/mxnet_op.h"
+
+namespace mxnet {
+namespace kvstore {
+
+// these gpu functions are defined in gradient_compression.cu
+void Quantize2BitImpl(mshadow::Stream<mshadow::gpu> *s, const std::vector<mxnet::TBlob> &inputs,
+                      const float threshold);
+void Dequantize2BitImpl(mshadow::Stream<mshadow::gpu> *s, const std::vector<mxnet::TBlob> &inputs,
+                        const float threshold);
+
+struct quantize_2bit {
+  MSHADOW_XINLINE static void Map(int out_block_id,
+                                  int original_size,
+                                  float *out,
+                                  float *grad,
+                                  float *residual,
+                                  const float neg_threshold,
+                                  const float pos_threshold) {
+    // this block contains the compressed representation of
+    // up to 16 values starting from out_block_id*16
+    float *compr_block = out + out_block_id;
+    // init to 0
+    *compr_block = 0;
+    // start and end are indices in original grad array
+    const int start = out_block_id << 4;
+    const int end = (start + 16 <= original_size) ? start + 16 : original_size;
+    // cast as char* to manipulate bits of float addresses
+    char *block_ptr = reinterpret_cast<char *>(compr_block);
+    // masks to set bits when value meets pos_threshold
+    // 0xc0 is mask when value is to be represented by the first two bits in a char*
+    // 0xc0 means first two bits are set to 11
+    const uint8_t posbits[] = {0xc0, 0x30, 0x0c, 0x03};
+    // masks to set bits when value meets neg_threshold
+    const uint8_t negbits[] = {0x80, 0x20, 0x08, 0x02};
+    for (int i = start; i < end; i++) {
+      // adds offset to reach appropriate byte
+      char *curr_byte = block_ptr + ((i - start) >> 2);
+      // adds gradient to existing residual to get updated grad
+      residual[i] += grad[i];
+      if (residual[i] >= pos_threshold) {
+        // set data to 11
+        *curr_byte |= posbits[(i & 3)];
+        // reduce residual by pos_threshold
+        residual[i] -= pos_threshold;
+      } else if (residual[i] <= neg_threshold) {
+        // set data to 10
+        *curr_byte |= negbits[(i & 3)];
+        residual[i] -= neg_threshold;
+      }
+    }
+  }
+};
+
+template<typename xpu>
+void Quantize2BitKernelLaunch(mshadow::Stream<xpu> *s, const std::vector<mxnet::TBlob> &inputs,
+                              const float threshold) {
+  mxnet::op::mxnet_op::Kernel<quantize_2bit, xpu>
+    ::Launch(s,
+             inputs[2].Size(),         // compressed array size
+             inputs[0].Size(),         // original size
+             inputs[2].dptr<float>(),  // compressed array
+             inputs[0].dptr<float>(),  // original array
+             inputs[1].dptr<float>(),  // residual array
+             -1 * threshold,           // negative threshold
+             threshold);               // positive threshold
+}
+
+struct dequantize_2bit {
+  MSHADOW_XINLINE static void Map(int i,
+                                  float *out,
+                                  float *in,
+                                  const float neg_threshold,
+                                  const float pos_threshold) {
+    // get position of dequantized value to fill
+    float *outval = out + i;
+    // gets byte which holds quantized value for this position
+    char *ch_ptr = reinterpret_cast<char *>(in + (i >> 4));
+    ch_ptr += ((i & 15) >> 2);
+    // masks used to quantize data
+    const uint8_t posbits[] = {0xc0, 0x30, 0x0c, 0x03};
+    const uint8_t negbits[] = {0x80, 0x20, 0x08, 0x02};
+    // col denotes which two bits of a byte are set for this value
+    // col=0 implies first two bits, col=3 implies last two bits,...
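+    // worked example (illustrative): for i = 5 the quantized bits live in
+    // compressed float 5 >> 4 = 0, in byte (5 & 15) >> 2 = 1 of that float,
+    // at bit pair col = 5 & 3 = 1, so posbits[1] = 0x30 selects bits 5-4 of the byte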
+ const int col = i & 3; + const uint8_t mask = posbits[col]; + const uint8_t negmask = negbits[col]; + const uint8_t masked = *ch_ptr & mask; + if (masked == mask) { + *outval = pos_threshold; + } else if (masked == negmask) { + // use posbits for mask as posbits are both 1s + // then compare masked with negbits to see if only negbits were set + *outval = neg_threshold; + } else { + *outval = 0; + } + } +}; + +template +void Dequantize2BitKernelLaunch(mshadow::Stream *s, const std::vector &inputs, + const float threshold) { + mxnet::op::mxnet_op::Kernel + ::Launch(s, + inputs[1].Size(), // original size + inputs[1].dptr(), // out array + inputs[0].dptr(), // compressed array + -1 *threshold, // negative threshold + threshold); // positive threshold +} + +inline void Quantize2BitImpl(mshadow::Stream *s, + const std::vector &inputs, + const float threshold) { + Quantize2BitKernelLaunch(s, inputs, threshold); +} + +inline void Dequantize2BitImpl(mshadow::Stream *s, + const std::vector &inputs, + const float threshold) { + Dequantize2BitKernelLaunch(s, inputs, threshold); +} +} // namespace kvstore +} // namespace mxnet + +#endif // MXNET_KVSTORE_GRADIENT_COMPRESSION_INL_H_ diff --git a/src/kvstore/gradient_compression.cc b/src/kvstore/gradient_compression.cc new file mode 100644 index 000000000000..b8c626cd53a8 --- /dev/null +++ b/src/kvstore/gradient_compression.cc @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file gradient_compression.cc + * \brief Gradient compression for kvstore + * \author Rahul Huilgol + */ + +#include +#include +#include "gradient_compression.h" +#include "gradient_compression-inl.h" + +namespace mxnet { +namespace kvstore { + +/*! 
+ * \brief Splits a string into smaller strings using char as delimiter + * Example: "a,b,c,,d" is split into ["a","b","c","","d"] + * \param s string to split + * \param delim char to split string around + * \param result container for tokens extracted after splitting + */ +template +void split(const std::string &s, const char delim, Out result) { + std::stringstream ss; + ss.str(s); + std::string item; + while (std::getline(ss, item, delim)) { + *(result++) = item; + } +} + +DMLC_REGISTER_PARAMETER(GradientCompressionParam); + +GradientCompression::GradientCompression() { + type_ = CompressionType::kNone; +} + +void GradientCompression::SetParams(const std::vector > + & kwargs) { + GradientCompressionParam params; + params.InitAllowUnknown(kwargs); + CHECK_GT(params.threshold, 0) << "threshold must be greater than 0"; + if (params.type == "2bit") { + SetTwoBitCompression(params.threshold); + } else { + LOG(FATAL) << "Unknown type for gradient compression " << params.type; + } +} + +CompressionType GradientCompression::get_type() { + return type_; +} + +std::string GradientCompression::get_type_str() { + return std::to_string(static_cast(type_)); +} + +void GradientCompression::SetTwoBitCompression(const float threshold) { + type_ = CompressionType::kTwoBit; + threshold_ = threshold; +} + +std::string GradientCompression::EncodeParams() { + using namespace std; // to reduce length of next line + string rval = get_type_str(); + if (type_ == CompressionType::kTwoBit) { + rval += "," + to_string(threshold_); + } + return rval; +} + +void GradientCompression::DecodeParams(const std::string &s) { + std::vector elems; + split(s, ',', std::back_inserter(elems)); + type_ = static_cast(stoi(elems[0])); + if (elems.size() > 1) { + if (!elems[1].empty()) { + threshold_ = stof(elems[1]); + } + } +} + +int GradientCompression::GetCompressionFactor() { + if (type_ == CompressionType::kTwoBit) { + return 16; + } else { + LOG(FATAL) << "Unsupported compression type: " << get_type_str(); + return 0; + } +} + +int64_t GradientCompression::GetCompressedSize(const int64_t original_size) { + const int bits = GetCompressionFactor(); + return ((original_size % bits == 0) ? 
+ original_size / bits : + original_size / bits + 1); +} + +void GradientCompression::Quantize(const mxnet::NDArray &from, mxnet::NDArray *to, + mxnet::NDArray *residual, const int priority) { + CHECK(from.shape().ndim() != 0) << "source operand has zero dimension shape"; + CHECK(to->shape().ndim() != 0) << "destination operand has zero dimension shape"; + CHECK(residual->shape().ndim() != 0) << "residual operand has zero dimension shape"; + const int a = from.ctx().dev_mask(); + const int b = to->ctx().dev_mask(); + const float threshold = threshold_; + if (type_ == CompressionType::kTwoBit) { + if (a == mshadow::cpu::kDevMask && b == mshadow::cpu::kDevMask) { + mxnet::Engine::Get()->PushSync([from, to, residual, threshold](mxnet::RunContext ctx) { + std::vector inputs = {from.data(), residual->data(), to->data()}; + Quantize2BitImpl(ctx.get_stream(), inputs, threshold); + }, from.ctx(), {from.var()}, {to->var(), residual->var()}, + mxnet::FnProperty::kNormal, priority, PROFILER_MESSAGE("QuantizeCPU")); + } else { +#if MXNET_USE_CUDA + if (a == mshadow::gpu::kDevMask && b == mshadow::gpu::kDevMask) { + mxnet::Engine::Get()->PushSync([from, to, residual, threshold](mxnet::RunContext ctx) { + std::vector inputs = {from.data(), residual->data(), to->data()}; + Quantize2BitImpl(ctx.get_stream(), inputs, threshold); + // Wait GPU kernel to complete + ctx.get_stream()->Wait(); + }, from.ctx(), {from.var()}, {to->var(), residual->var()}, + mxnet::FnProperty::kNormal, priority, PROFILER_MESSAGE("QuantizeGPU")); + } else { + LOG(FATAL) << "unknown device mask"; + } +#else + LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; +#endif + } + } else { + LOG(FATAL) << "Unsupported quantization of type " << get_type_str(); + } +} + +void GradientCompression::Dequantize(const mxnet::NDArray &from, mxnet::NDArray *to, + const int priority) { + CHECK(from.shape().ndim() != 0) << "source operands has zero dimension shape"; + CHECK(to->shape().ndim() != 0) << "destination operand has zero dimension shape"; + const int a = from.ctx().dev_mask(); + const int b = to->ctx().dev_mask(); + const float threshold = threshold_; + if (type_ == CompressionType::kTwoBit) { + if (a == mshadow::cpu::kDevMask && b == mshadow::cpu::kDevMask) { + mxnet::Engine::Get()->PushSync([from, to, threshold](mxnet::RunContext ctx) { + std::vector inputs = {from.data(), to->data()}; + Dequantize2BitImpl(ctx.get_stream(), inputs, threshold); + }, from.ctx(), {from.var()}, {to->var()}, + mxnet::FnProperty::kNormal, priority, PROFILER_MESSAGE("DequantizeCPU")); + } else { +#if MXNET_USE_CUDA + if (a == mshadow::gpu::kDevMask && b == mshadow::gpu::kDevMask) { + mxnet::Engine::Get()->PushSync([from, to, threshold](mxnet::RunContext ctx) { + std::vector inputs = {from.data(), to->data()}; + Dequantize2BitImpl(ctx.get_stream(), inputs, threshold); + // Wait GPU kernel to complete + ctx.get_stream()->Wait(); + }, from.ctx(), {from.var()}, {to->var()}, + mxnet::FnProperty::kNormal, priority, PROFILER_MESSAGE("DequantizeGPU")); + } else { + LOG(FATAL) << "unknown device mask"; + } +#else + LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; +#endif + } + } else { + LOG(FATAL) << "Unsupported dequantization of type " << get_type_str(); + } +} + +} // namespace kvstore +} // namespace mxnet + diff --git a/src/kvstore/gradient_compression.cu b/src/kvstore/gradient_compression.cu new file mode 100644 index 000000000000..b0d9662520b2 --- /dev/null +++ b/src/kvstore/gradient_compression.cu @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under 
one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file gradient_compression.cu + * \author Rahul Huilgol + * \brief Implementation for gpu version of code + */ + +#include "gradient_compression-inl.h" + +namespace mxnet { +namespace kvstore { +void Quantize2BitImpl(mshadow::Stream* s, const std::vector& inputs, + const float threshold) { + Quantize2BitKernelLaunch(s, inputs, threshold); +} + +void Dequantize2BitImpl(mshadow::Stream* s, const std::vector& inputs, + const float threshold) { + Dequantize2BitKernelLaunch(s, inputs, threshold); +} +} // namespace kvstore +} // namespace mxnet diff --git a/src/kvstore/gradient_compression.h b/src/kvstore/gradient_compression.h new file mode 100644 index 000000000000..f40b45f5a513 --- /dev/null +++ b/src/kvstore/gradient_compression.h @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file gradient_compression.h + * \brief Gradient compression for kvstore + * \author Rahul Huilgol + */ + +#ifndef MXNET_KVSTORE_GRADIENT_COMPRESSION_H_ +#define MXNET_KVSTORE_GRADIENT_COMPRESSION_H_ +#include +#include +#include +#include +#include "mxnet/ndarray.h" + +namespace mxnet { +namespace kvstore { + +enum class CompressionType { + kNone, kTwoBit +}; + +struct GradientCompressionParam : public dmlc::Parameter { + std::string type; + float threshold; + DMLC_DECLARE_PARAMETER(GradientCompressionParam) { + DMLC_DECLARE_FIELD(type) + .describe("Type of gradient compression to use, like `2bit` for example"); + DMLC_DECLARE_FIELD(threshold).set_default(0.5) + .describe("Threshold to use for 2bit gradient compression"); + } +}; + +class GradientCompression { + public: + GradientCompression(); + + virtual ~GradientCompression() {} + + /*! + * \brief sets parameters for gradient compression + * \param kwargs a vector of pair of strings. A pair represents key and value + * of the parameter. Will be parsed by GradientCompressionParam + */ + void SetParams(const std::vector >& kwargs); + + /*! + * \brief returns type of compression if any + */ + CompressionType get_type(); + + /*! 
+ * \brief returns as string the enum value of compression type + */ + std::string get_type_str(); + + /*! + * \brief sets two bit gradient compression + * \param threshold float value used for thresholding gradients + */ + void SetTwoBitCompression(const float threshold); + + /*! + * \brief encodes parameters of gc into a string + */ + std::string EncodeParams(); + + /*! + * \brief decodes parameters of gc from a string and assigns them to member variables + */ + void DecodeParams(const std::string &s); + + /*! + * \brief returns compression factor, which is the factor by which size of gradient + * reduces when using a particular type of compression + */ + int GetCompressionFactor(); + + /*! + * \brief returns the size of compressed gradients given an original sized gradient array + */ + int64_t GetCompressedSize(const int64_t original_size); + + /*! + * \brief Issues quantize operation to be scheduled by the engine + * Compresses `from` into `to` and accumulates the quantization error + * into 'residual', using the quantization of type `type_` + * \param from the ndarray containing original data to be quantized + * \param to the target ndarray which contains quantized data + * \param residual the ndarray which accumulates quantization error + * \param priority Priority of the action. + */ + void Quantize(const mxnet::NDArray &from, mxnet::NDArray *to, + mxnet::NDArray *residual, const int priority); + + /*! + * \brief Issues dequantize operation to be scheduled by the engine + * Decompresses `from` into `to` using current parameters of `type` and `threshold` + * \param from the ndarray containing quantized data + * \param to the target ndarray which contains final dequantized data + * \param priority Priority of the action. + */ + void Dequantize(const mxnet::NDArray &from, mxnet::NDArray *to, const int priority); + + private: + /*! + * \brief denotes the type of gradient compression which has been set + */ + CompressionType type_; + + /*! + * \brief denotes threshold used for quantization and dequantization + * Must be a positive value. 
All positive gradients will be thresholded to `threshold_` and + * all negative gradients will be thresholded to -1*`threshold_` + */ + float threshold_ = 0; +}; +} // namespace kvstore +} // namespace mxnet +#endif // MXNET_KVSTORE_GRADIENT_COMPRESSION_H_ diff --git a/src/kvstore/kvstore.cc b/src/kvstore/kvstore.cc index a288676102cb..059961e1781f 100644 --- a/src/kvstore/kvstore.cc +++ b/src/kvstore/kvstore.cc @@ -48,7 +48,7 @@ KVStore* KVStore::Create(const char *type_name) { kv = new kvstore::KVStoreDist(use_device_comm); if (!has("_async") && kv->IsWorkerNode() && kv->get_rank() == 0) { // configure the server to be the sync mode - kv->SendCommandToServers(kvstore::kSyncMode, ""); + kv->SendCommandToServers(static_cast(kvstore::CommandType::kSyncMode), ""); } #else LOG(FATAL) << "compile with USE_DIST_KVSTORE=1 to use " << tname; diff --git a/src/kvstore/kvstore_dist.h b/src/kvstore/kvstore_dist.h index 5e62be8c4c40..002d30d7161c 100644 --- a/src/kvstore/kvstore_dist.h +++ b/src/kvstore/kvstore_dist.h @@ -68,7 +68,7 @@ class KVStoreDist : public KVStoreLocal { Barrier(); if (get_rank() == 0) { // stop the executor at servers - SendCommandToServers(kStopServer, ""); + SendCommandToServers(static_cast(CommandType::kStopServer), ""); } } ps::Finalize(barrier_before_exit_); @@ -85,6 +85,15 @@ class KVStoreDist : public KVStoreLocal { } } + void SetGradientCompression(const std::vector > + & kwargs) override { + KVStoreLocal::SetGradientCompression(kwargs); + if (get_rank() == 0) { + SendCommandToServers(static_cast(CommandType::kSetGradientCompression), + gradient_compression_->EncodeParams()); + } + } + void Barrier() override { ps::Postoffice::Get()->Barrier(ps::kWorkerGroup); } @@ -131,6 +140,38 @@ class KVStoreDist : public KVStoreLocal { } private: + /** + * \brief struct for ps keys and lens + */ + struct PSKV { + ps::SArray keys; // n keys + ps::SArray lens; // the length of the i-th value + int size; + }; + + struct ComprPSKV { + PSKV push; + PSKV pull; + }; + + /** + * \brief cache all key partitions + * + * `ps_kv_` is used for pushes and pulls without gradient compression + * `compr_ps_kv_` is used for gradient compression. It contains different + * pskv for push and pull because sizes would be different in both cases. + * Note: `ps_kv_[k]` for some key k may not be the same as `compr_ps_kv_[k].pull` + * This is because sharding may cause slightly different divisions when size is + * not perfectly divisible. + */ + std::unordered_map ps_kv_; + std::unordered_map compr_ps_kv_; + + /** + * \brief serialize access to ps_kv_ or push_ps_kv_/pull_ps_kv_ while encoding keys + */ + std::mutex mu_; + void InitImpl(const std::vector& keys, const std::vector& values) override { CheckUnique(keys); @@ -142,6 +183,7 @@ class KVStoreDist : public KVStoreLocal { // wait until the push is finished for (const int key : keys) { comm_buf_[key].WaitToWrite(); + compr_buf_[key].WaitToWrite(); } } else { // do nothing @@ -181,7 +223,10 @@ class KVStoreDist : public KVStoreLocal { RunContext rctx, Engine::CallbackOnComplete cb) { // convert to ps keys size_t size = recv_buf.shape().Size(); - PSKV& pskv = EncodeKey(key, size); + + PSKV& pskv = (gradient_compression_->get_type() == CompressionType::kNone) ? 
+ EncodeDefaultKey(key, size, false) : + EncodeCompressedKey(key, size, false); #if MKL_EXPERIMENTAL == 1 mkl_set_tblob_eager_mode(recv_buf.data()); #endif @@ -189,8 +234,11 @@ class KVStoreDist : public KVStoreLocal { // false means not to delete data when SArray is deleted auto vals = new ps::SArray(data, size, false); // issue pull + int cmd = (gradient_compression_->get_type() != CompressionType::kNone) ? + static_cast(DataHandleType::kCompressedPushPull) : + static_cast(DataHandleType::kDefaultPushPull); CHECK_NOTNULL(ps_worker_)->ZPull( - pskv.keys, vals, &pskv.lens, kDefaultPushPull, [vals, cb](){ delete vals; cb(); }); + pskv.keys, vals, &pskv.lens, cmd, [vals, cb](){ delete vals; cb(); }); }; CHECK_NOTNULL(Engine::Get())->PushAsync( @@ -200,7 +248,7 @@ class KVStoreDist : public KVStoreLocal { {recv_buf.var()}, FnProperty::kNormal, priority, - PROFILER_MESSAGE("KVStoreDistDefaultPull")); + PROFILER_MESSAGE("KVStoreDistDefaultStoragePull")); comm_->Broadcast(key, recv_buf, grouped_vals[i], priority); } @@ -260,103 +308,121 @@ class KVStoreDist : public KVStoreLocal { GroupKVPairsPush(keys, values, &uniq_keys, &grouped_vals); for (size_t i = 0; i < uniq_keys.size(); ++i) { - // merge over devcies + // merge over devices int key = uniq_keys[i]; const auto& vals = grouped_vals[i]; NDArray merged = do_merge ? comm_->Reduce(key, vals, priority) : vals[0]; - auto& send_buf = comm_buf_[key]; const auto storage_type = merged.storage_type(); + auto &comm_buf = comm_buf_[key]; if (merged.ctx().dev_mask() == cpu::kDevMask) { // Start of a push doesn't guarantee that the previous pushes are completed. // This shouldn't affect training of networks though because training involves // a sequence of push, pull, then push. This imposes ordering that the // second push happens after the first pull, and the pull happens after first push. - send_buf = merged; // avoid memory copy + comm_buf = merged; // avoid memory copy } else { - if (send_buf.is_none()) { + if (comm_buf.is_none()) { if (storage_type == kDefaultStorage) { - send_buf = NDArray(merged.shape(), pinned_ctx_, true, merged.dtype()); + comm_buf = NDArray(merged.shape(), pinned_ctx_, true, merged.dtype()); } else { - send_buf = NDArray(storage_type, merged.shape(), pinned_ctx_, true, merged.dtype()); + comm_buf = NDArray(storage_type, merged.shape(), pinned_ctx_, true, merged.dtype()); } } - CopyFromTo(merged, &send_buf); + CopyFromTo(merged, &comm_buf); } // push to servers if (storage_type == kDefaultStorage) { - auto push_to_servers = - [this, key, send_buf](RunContext rctx, Engine::CallbackOnComplete cb) { - // convert to ps keys - size_t size = send_buf.shape().Size(); - PSKV& pskv = EncodeKey(key, size); - -#if MKL_EXPERIMENTAL == 1 - mkl_set_tblob_eager_mode(send_buf.data()); -#endif - real_t* data = send_buf.data().dptr(); - // do push. false means no delete - ps::SArray vals(data, size, false); - CHECK_NOTNULL(ps_worker_)->ZPush( - pskv.keys, vals, pskv.lens, 0, [cb]() { cb(); }); - }; - Engine::Get()->PushAsync( - push_to_servers, - pinned_ctx_, - {send_buf.var()}, - {}, - FnProperty::kNormal, - priority, - PROFILER_MESSAGE("KVStoreDistDefaultPush")); + if (gradient_compression_->get_type() == CompressionType::kNone) { + PSKV& pskv = EncodeDefaultKey(key, comm_buf.shape().Size(), true); + PushDefault(key, comm_buf, pskv, priority); + } else { + // Note: gradient compression uses `do_merge` as proxy to + // detect whether the push is initialization of a key or not. 
+ // is_active is false when push is initialization of key + bool is_active = do_merge; + PSKV &pskv = EncodeCompressedKey(key, comm_buf.shape().Size(), is_active); + // Returns push_pskv if active, else pull_pskv + // we want inactive gc to send uncompressed gradients, + // but sharded in the same way as later pushes would when gc becomes active + if (is_active) { + PushCompressed(key, comm_buf, pskv, priority); + } else { + PushDefault(key, comm_buf, pskv, priority); + } + } } else if (storage_type == kRowSparseStorage) { - PushRowSparse(key, send_buf, priority); + CHECK(gradient_compression_->get_type() == CompressionType::kNone) + << "Gradient compression for row sparse storage type is not supported"; + PushRowSparse(key, comm_buf, priority); } else { LOG(FATAL) << "unknown storage type"; } } } - // pull row sparse weight into `recv_buf` based on indices given by `indices` - void PullRowSparse_(const int key, const NDArray& recv_buf, - const NDArray& indices, int priority) { - using namespace rowsparse; - auto pull_from_servers = [this, key, recv_buf, indices] - (RunContext rctx, Engine::CallbackOnComplete cb) { - // allocate memory for the buffer - size_t num_rows = indices.shape().Size(); - recv_buf.CheckAndAlloc({mshadow::Shape1(num_rows)}); + void PushCompressed(int key, const NDArray& comm_buf, const PSKV& pskv, int priority) { + auto &small_buf = compr_buf_[key]; + auto &res_buf = residual_[key]; + size_t original_size = comm_buf.shape().Size(); + + // Init the small buffer and residual_ buffer for quantize + if (small_buf.is_none()) { + small_buf = NDArray(TShape{pskv.size}, comm_buf.ctx(), false, comm_buf.dtype()); + res_buf = NDArray(TShape{(int64_t) original_size}, comm_buf.ctx(), + false, comm_buf.dtype()); + res_buf = 0; + } + gradient_compression_->Quantize(comm_buf, &small_buf, &res_buf, priority); + auto push_to_servers = + [this, key, pskv, small_buf](RunContext rctx, Engine::CallbackOnComplete cb) { + size_t size = small_buf.shape().Size(); + real_t* data = small_buf.data().dptr(); #if MKL_EXPERIMENTAL == 1 - mkl_set_tblob_eager_mode(recv_buf.data()); + mkl_set_tblob_eager_mode(small_buf.data()); #endif - real_t* data = recv_buf.data().dptr(); - const auto offsets = indices.data().dptr(); - const auto unit_len = recv_buf.shape().ProdShape(1, recv_buf.shape().ndim()); - const int64_t size = num_rows * unit_len; - // convert to ps keys in row sparse format - PSKV& pskv = EncodeRowSparseKey(key, size, num_rows, offsets, - unit_len, recv_buf.shape()[0]); - if (this->log_verbose_) { - LOG(INFO) << "worker " << get_rank() << " pull lens: " << pskv.lens << " keys: " - << pskv.keys << " size: " << size; - } - auto vals = new ps::SArray(data, size, false); - // copy indices to recv_buf. this needs to be done before ZPull - // because after pull is done, the callback function returns and locks are released. - // at this point, later functions may access the indices variable while copy happens - mshadow::Copy(recv_buf.aux_data(kIdx).FlatTo1D(), - indices.data().FlatTo1D()); - CHECK_NOTNULL(ps_worker_)->ZPull(pskv.keys, vals, &pskv.lens, kRowSparsePushPull, - [vals, cb]() { delete vals; cb(); }); - }; - CHECK_NOTNULL(Engine::Get())->PushAsync( - pull_from_servers, + // do push. 
false means no delete + ps::SArray vals(data, size, false); + CHECK_NOTNULL(ps_worker_)->ZPush( + pskv.keys, vals, pskv.lens, + static_cast(DataHandleType::kCompressedPushPull), [cb]() { cb(); }); + }; + // acquire locks on both comm_buf and small_buf so that + // pull (which uses comm_buf) for the same key waits till push finishes + Engine::Get()->PushAsync( + push_to_servers, + pinned_ctx_, + {small_buf.var(), comm_buf.var()}, + {}, + FnProperty::kNormal, + priority, + PROFILER_MESSAGE("KVStoreDistCompressedPush")); + } + + void PushDefault(int key, const NDArray &send_buf, const PSKV& pskv, int priority) { + auto push_to_servers = + [this, key, pskv, send_buf](RunContext rctx, Engine::CallbackOnComplete cb) { + // convert to ps keys + size_t size = send_buf.shape().Size(); + real_t* data = send_buf.data().dptr(); +#if MKL_EXPERIMENTAL == 1 + mkl_set_tblob_eager_mode(send_buf.data()); +#endif + // do push. false means no delete + ps::SArray vals(data, size, false); + CHECK_NOTNULL(ps_worker_)->ZPush( + pskv.keys, vals, pskv.lens, + static_cast(DataHandleType::kDefaultPushPull), [cb]() { cb(); }); + }; + Engine::Get()->PushAsync( + push_to_servers, pinned_ctx_, - {indices.var()}, - {recv_buf.var()}, + {send_buf.var()}, + {}, FnProperty::kNormal, priority, - PROFILER_MESSAGE("KVStoreDistRowSparsePull")); + PROFILER_MESSAGE("KVStoreDistDefaultPush")); } // push row sparse gradient @@ -381,9 +447,9 @@ class KVStoreDist : public KVStoreLocal { << pskv.keys << " size: " << size; } ps::SArray vals(data, size, false); - CHECK_NOTNULL(ps_worker_)->ZPush(pskv.keys, vals, pskv.lens, kRowSparsePushPull, [cb]() { - cb(); - }); + CHECK_NOTNULL(ps_worker_)->ZPush(pskv.keys, vals, pskv.lens, + static_cast(DataHandleType::kRowSparsePushPull), + [cb]() { cb(); }); }; Engine::Get()->PushAsync( push_to_servers, @@ -395,6 +461,50 @@ class KVStoreDist : public KVStoreLocal { PROFILER_MESSAGE("KVStoreDistRowSparsePush")); } + + // pull row sparse weight into `recv_buf` based on indices given by `indices` + void PullRowSparse_(const int key, const NDArray& recv_buf, + const NDArray& indices, int priority) { + using namespace rowsparse; + auto pull_from_servers = [this, key, recv_buf, indices] + (RunContext rctx, Engine::CallbackOnComplete cb) { + // allocate memory for the buffer + size_t num_rows = indices.shape().Size(); + recv_buf.CheckAndAlloc({mshadow::Shape1(num_rows)}); +#if MKL_EXPERIMENTAL == 1 + mkl_set_tblob_eager_mode(recv_buf.data()); +#endif + real_t* data = recv_buf.data().dptr(); + const auto offsets = indices.data().dptr(); + const auto unit_len = recv_buf.shape().ProdShape(1, recv_buf.shape().ndim()); + const int64_t size = num_rows * unit_len; + // convert to ps keys in row sparse format + PSKV& pskv = EncodeRowSparseKey(key, size, num_rows, offsets, + unit_len, recv_buf.shape()[0]); + if (this->log_verbose_) { + LOG(INFO) << "worker " << get_rank() << " pull lens: " << pskv.lens << " keys: " + << pskv.keys << " size: " << size; + } + auto vals = new ps::SArray(data, size, false); + // copy indices to recv_buf. this needs to be done before ZPull + // because after pull is done, the callback function returns and locks are released. 
+ // at this point, later functions may access the indices variable while copy happens + mshadow::Copy(recv_buf.aux_data(kIdx).FlatTo1D(), + indices.data().FlatTo1D()); + CHECK_NOTNULL(ps_worker_)->ZPull(pskv.keys, vals, &pskv.lens, + static_cast(DataHandleType::kRowSparsePushPull), + [vals, cb]() { delete vals; cb(); }); + }; + CHECK_NOTNULL(Engine::Get())->PushAsync( + pull_from_servers, + pinned_ctx_, + {indices.var()}, + {recv_buf.var()}, + FnProperty::kNormal, + priority, + PROFILER_MESSAGE("KVStoreDistRowSparsePull")); + } + /** * \brief check if the keys are all unique */ @@ -405,33 +515,13 @@ class KVStoreDist : public KVStoreLocal { static_cast(keys.size())); } - /** - * \brief struct for ps keys and lens - */ - struct PSKV { - ps::SArray keys; // n keys - ps::SArray lens; // the length of the i-th value - int size; - }; - - /** - * \brief cache all key partitions - */ - std::unordered_map ps_kv_; - - /** - * \brief serizelize EncodeRowSparseKey and EncodeKey - */ - std::mutex mu_; - /** * \brief convert to keys in ps */ - inline PSKV& EncodeKey(int key, size_t size) { + inline PSKV& EncodeDefaultKey(int key, size_t size, bool is_push) { mu_.lock(); PSKV& pskv = ps_kv_[key]; mu_.unlock(); - if (!pskv.keys.empty()) { CHECK_EQ(static_cast(pskv.size), size) << "The value size cannot be changed"; } else { @@ -453,8 +543,8 @@ class KVStoreDist : public KVStoreLocal { pskv.size = 0; for (int i = 0; i < num_servers; ++i) { size_t part_size = - static_cast(round(static_cast(size)/num_servers*(i+1))) - - static_cast(round(static_cast(size)/num_servers*i)); + static_cast(round(static_cast(size)/num_servers*(i+1))) - + static_cast(round(static_cast(size)/num_servers*i)); ps::Key ps_key = krs[i].begin() + key; CHECK_LT(ps_key, krs[i].end()); pskv.keys.push_back(ps_key); @@ -467,6 +557,94 @@ class KVStoreDist : public KVStoreLocal { return pskv; } + /** + * \brief Convert to keys in ps for compressed values + * Divides original array into equal parts for each server + * Populates both push and pull pskv on first call + */ + inline PSKV& EncodeCompressedKey(int key, size_t original_size, bool is_push) { + auto krs = ps::Postoffice::Get()->GetServerKeyRanges(); + int num_servers = krs.size(); + CHECK_GT(num_servers, 0); + + // represents size of data to be sent + size_t compr_size = gradient_compression_->GetCompressedSize(original_size); + + mu_.lock(); + PSKV& pskv = (is_push) ? compr_ps_kv_[key].push : compr_ps_kv_[key].pull; + mu_.unlock(); + + if (!pskv.keys.empty()) { + size_t size = (is_push) ? 
compr_size : original_size; + CHECK_EQ(static_cast(pskv.size), size)<< "The value size can't be changed"; + } else { + // populate both pull and push pskvs + // push pskv has sizes corresponding to compressed data + // pull pskv has decompressed sizes for parts in push_pskv + mu_.lock(); + PSKV& pull_pskv = compr_ps_kv_[key].pull; + PSKV& push_pskv = compr_ps_kv_[key].push; + mu_.unlock(); + + if (original_size < bigarray_bound_) { + // a simple heuristic for load balancing + // send it to a single random picked server + int server = (key * 9973) % num_servers; + ps::Key ps_key = krs[server].begin() + key; + CHECK_LT(ps_key, krs[server].end()); + // meta info + push_pskv.keys.push_back(krs[server].begin() + original_size); + push_pskv.lens.push_back(0); + // data + push_pskv.keys.push_back(ps_key); + pull_pskv.keys.push_back(ps_key); + push_pskv.lens.push_back(compr_size); + pull_pskv.lens.push_back(original_size); + push_pskv.size = compr_size; + pull_pskv.size = original_size; + } else { + // partition it to all servers + push_pskv.size = 0; + pull_pskv.size = 0; + + for (int i = 0; i < num_servers; ++i) { + size_t part_compr, part_orig; + if (i == num_servers-1) { + part_compr = compr_size - push_pskv.size; + part_orig = original_size - pull_pskv.size; + } else { + part_compr = + static_cast (round(static_cast(compr_size)/num_servers*(i+1))) - + static_cast (round(static_cast(compr_size)/num_servers*(i))); + part_orig = part_compr * gradient_compression_->GetCompressionFactor(); + } + + // meta info + ps::Key ps_key_dummy = krs[i].begin() + part_orig; + CHECK_LT(ps_key_dummy, krs[i].end()); + push_pskv.keys.push_back(ps_key_dummy); + push_pskv.lens.push_back(0); + + // data + ps::Key ps_key = krs[i].begin() + key; + CHECK_LT(ps_key, krs[i].end()); + push_pskv.keys.push_back(ps_key); + pull_pskv.keys.push_back(ps_key); + // push_pskv stores lengths of compressed blocks + push_pskv.lens.push_back(part_compr); + // pull_pskv stores lengths of original data + pull_pskv.lens.push_back(part_orig); + push_pskv.size += part_compr; + pull_pskv.size += part_orig; + } + CHECK_EQ(static_cast(push_pskv.size), compr_size); + CHECK_EQ(static_cast(pull_pskv.size), original_size); + CHECK_EQ(push_pskv.lens.size(), num_servers*2); + } + } + return pskv; + } + // Note: this encoding method for row sparse keys doesn't allow cross-layer batching inline PSKV& EncodeRowSparseKey(const int key, const int64_t size, const int64_t num_rows, const int64_t *offsets, const size_t unit_len, @@ -527,7 +705,6 @@ class KVStoreDist : public KVStoreLocal { return pskv; } - /** * \brief for worker to push and pull data */ @@ -540,8 +717,23 @@ class KVStoreDist : public KVStoreLocal { * \brief threshold for partition */ size_t bigarray_bound_; - /// \brief send & recver buffer + /** + * \brief buffer for non-compressed data. 
+ * When gradient compression is active, this is used + * for the data in pull and for original data in push + */ std::unordered_map comm_buf_; + /** + * \brief buffer for compressed data + * Used when gradient compression is active and action + * is push + */ + std::unordered_map compr_buf_; + /** + * \brief residual buffer to accumulate quantization error + * during gradient compression + */ + std::unordered_map residual_; bool log_verbose_; }; diff --git a/src/kvstore/kvstore_dist_server.h b/src/kvstore/kvstore_dist_server.h index 15034086186d..83c0c6200447 100644 --- a/src/kvstore/kvstore_dist_server.h +++ b/src/kvstore/kvstore_dist_server.h @@ -39,10 +39,13 @@ namespace mxnet { namespace kvstore { -static const int kRowSparsePushPull = 1; -static const int kDefaultPushPull = 0; -static const int kStopServer = -1; -static const int kSyncMode = -2; +enum class CommandType { + kController, kStopServer, kSyncMode, kSetGradientCompression +}; + +enum class DataHandleType { + kDefaultPushPull, kCompressedPushPull, kRowSparsePushPull +}; /** * \brief executor runs a function using the thread called \ref Start @@ -116,6 +119,7 @@ class KVStoreDistServer { ps_server_->set_request_handle( std::bind(&KVStoreDistServer::DataHandleEx, this, _1, _2, _3)); sync_mode_ = false; + gradient_compression_ = std::make_shared(); log_verbose_ = dmlc::GetEnv("MXNET_KVSTORE_DIST_ROW_SPARSE_VERBOSE", false); } @@ -147,11 +151,15 @@ class KVStoreDistServer { }; void CommandHandle(const ps::SimpleData& recved, ps::SimpleApp* app) { - if (recved.head == kStopServer) { + CommandType recved_type = static_cast(recved.head); + if (recved_type == CommandType::kStopServer) { exec_.Stop(); - } else if (recved.head == kSyncMode) { + } else if (recved_type == CommandType::kSyncMode) { sync_mode_ = true; + } else if (recved_type == CommandType::kSetGradientCompression) { + gradient_compression_->DecodeParams(recved.body); } else { + // this uses value 0 for message id from frontend // let the main thread to execute ctrl, which is necessary for python exec_.Exec([this, recved]() { CHECK(controller_); @@ -164,8 +172,11 @@ class KVStoreDistServer { void DataHandleEx(const ps::KVMeta& req_meta, const ps::KVPairs& req_data, ps::KVServer* server) { - if (req_meta.cmd == kRowSparsePushPull) { + DataHandleType recved_type = static_cast(req_meta.cmd); + if (recved_type == DataHandleType::kRowSparsePushPull) { DataHandleRowSparse(req_meta, req_data, server); + } else if (recved_type == DataHandleType::kCompressedPushPull) { + DataHandleCompressed(req_meta, req_data, server); } else { DataHandleDefault(req_meta, req_data, server); } @@ -358,10 +369,91 @@ class KVStoreDistServer { } } + void DefaultStorageResponse(int key, const NDArray& stored, + const ps::KVMeta& req_meta, + const ps::KVPairs &req_data, + ps::KVServer* server) { + ps::KVPairs response; + CHECK(!stored.is_none()) << "init " << key << " first"; + auto len = stored.shape().Size(); + response.keys = req_data.keys; + response.lens = {len}; + // TODO(mli) try to remove this CopyFrom + response.vals.CopyFrom(static_cast(stored.data().dptr_), len); + server->Response(req_meta, response); + } + + void DataHandleCompressed(const ps::KVMeta& req_meta, + const ps::KVPairs &req_data, + ps::KVServer* server) { + if (req_meta.push) { + // there used several WaitToRead, this is because \a recved's memory + // could be deallocated when this function returns. 
+  void DataHandleCompressed(const ps::KVMeta& req_meta,
+                            const ps::KVPairs<real_t> &req_data,
+                            ps::KVServer<real_t>* server) {
+    if (req_meta.push) {
+      // several WaitToRead calls are used here because \a recved's memory
+      // could be deallocated when this function returns, so we need to make sure
+      // the operators using \a NDArray have actually finished
+
+      // first is the dummy key which represents the original size of the array, whose len is 0
+      CHECK_EQ(req_data.keys.size(), (size_t)2);
+      CHECK_EQ(req_data.lens.size(), (size_t)2);
+      CHECK_EQ(req_data.vals.size(), (size_t)req_data.lens[1]);
+
+      int original_size = DecodeKey(req_data.keys[0]);
+      int key = DecodeKey(req_data.keys[1]);
+      auto& stored = store_[key];
+
+      size_t ds[] = {(size_t)req_data.lens[1]};
+      TShape dshape(ds, ds + 1);
+      TBlob recv_blob((real_t*) req_data.vals.data(),  // NOLINT(*)
+                      dshape, cpu::kDevMask);
+      NDArray recved = NDArray(recv_blob, 0);
+
+      NDArray decomp_buf = decomp_buf_[key];
+      dshape = TShape{(int64_t) original_size};
+
+      if (decomp_buf.is_none()) {
+        decomp_buf = NDArray(dshape, Context());
+      }
+
+      if (stored.is_none()) {
+        stored = NDArray(dshape, Context());
+        gradient_compression_->Dequantize(recved, &stored, 0);
+        server->Response(req_meta);
+        stored.WaitToRead();
+      } else if (sync_mode_) {
+        // synced push
+        auto& merged = merge_buf_[key];
+        if (merged.array.is_none()) {
+          merged.array = NDArray(dshape, Context());
+        }
+        if (merged.request.size() == 0) {
+          gradient_compression_->Dequantize(recved, &merged.array, 0);
+        } else {
+          gradient_compression_->Dequantize(recved, &decomp_buf, 0);
+          merged.array += decomp_buf;
+        }
+        merged.request.push_back(req_meta);
+        ApplyUpdates(key, &merged, &stored, server);
+      } else {
+        // async push
+        gradient_compression_->Dequantize(recved, &decomp_buf, 0);
+        exec_.Exec([this, key, &decomp_buf, &stored]() {
+          CHECK(updater_);
+          updater_(key, decomp_buf, &stored);
+        });
+        server->Response(req_meta);
+        stored.WaitToRead();
+      }
+    } else {       // pull
+      CHECK_EQ(req_data.keys.size(), (size_t)1);
+      CHECK_EQ(req_data.lens.size(), (size_t)0);
+      int key = DecodeKey(req_data.keys[0]);
+      DefaultStorageResponse(key, store_[key], req_meta, req_data, server);
+    }
+  }
+
   void DataHandleDefault(const ps::KVMeta& req_meta,
                          const ps::KVPairs<real_t> &req_data,
                          ps::KVServer<real_t>* server) {
-    CHECK_EQ(req_meta.cmd, kDefaultPushPull);
+    CHECK_EQ(req_meta.cmd, static_cast<int>(DataHandleType::kDefaultPushPull));
     // do some check
     CHECK_EQ(req_data.keys.size(), (size_t)1);
     if (req_meta.push) {
@@ -410,15 +502,7 @@ class KVStoreDistServer {
         stored.WaitToRead();
       }
     } else {
-      // pull
-      ps::KVPairs<real_t> response;
-      CHECK(!stored.is_none()) << "init " << key << " first";
-      auto len = stored.shape().Size();
-      response.keys = req_data.keys;
-      response.lens = {len};
-      // TODO(mli) try to remove this CopyFrom
-      response.vals.CopyFrom(static_cast<const real_t*>(stored.data().dptr_), len);
-      server->Response(req_meta, response);
+      DefaultStorageResponse(key, stored, req_meta, req_data, server);
     }
   }
 
@@ -427,21 +511,44 @@ class KVStoreDistServer {
     return key - kr.begin();
   }
 
+
   /**
-   * \brief user defined
+   * \brief user defined mode for push
    */
   bool sync_mode_;
   KVStore::Controller controller_;
   KVStore::Updater updater_;
+  /**
+   * \brief store_ contains the value at the kvstore for each key
+   */
   std::unordered_map<int, NDArray> store_;
+
+  /**
+   * \brief merge_buf_ is a buffer used if sync_mode is true. It represents
+   * values from different workers being merged. The store will be updated
+   * to this value when values from all workers are pushed into this buffer.
+   */
   std::unordered_map<int, MergeBuf> merge_buf_;
+  /**
+   * \brief decomp_buf_ is a buffer into which compressed values are
+   * decompressed before merging to the store.
+   * used when compress_ != 'none'
+   */
+  std::unordered_map<int, NDArray> decomp_buf_;
+
   Executor exec_;
   ps::KVServer<real_t>* ps_server_;
 
   // whether to LOG verbose information
   bool log_verbose_;
+
+  /**
+   * \brief gradient compression object.
+   * starts with none, used after SetGradientCompression sets the type.
+   * currently there is no support for unsetting gradient compression
+   */
+  std::shared_ptr<GradientCompression> gradient_compression_;
 };
 
 }  // namespace kvstore
diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h
index 4038185244a7..d3ef35901a16 100644
--- a/src/kvstore/kvstore_local.h
+++ b/src/kvstore/kvstore_local.h
@@ -58,6 +58,7 @@ class KVStoreLocal : public KVStore {
       comm_ = new CommCPU();
     }
     pinned_ctx_ = comm_->pinned_ctx();
+    gradient_compression_ = std::make_shared<GradientCompression>();
   }
 
   virtual ~KVStoreLocal() {
@@ -135,6 +136,11 @@ class KVStoreLocal : public KVStore {
     PullRowSparseImpl(keys, val_rowids, priority);
   }
 
+  void SetGradientCompression(const std::vector<std::pair<std::string, std::string> >
+                              & kwargs) override {
+    gradient_compression_->SetParams(kwargs);
+  }
+
  private:
   virtual void InitImpl(const std::vector<int>& keys,
                         const std::vector<NDArray>& values) {
@@ -144,6 +150,7 @@ class KVStoreLocal : public KVStore {
       local_[keys[i]] = values[i].Copy(pinned_ctx_);
       comm_->Init(keys[i], values[i].storage_type(), values[i].shape(), values[i].dtype());
     }
+    comm_->SetGradientCompression(gradient_compression_);
   }
 
   virtual void PushImpl(const std::vector<int>& keys,
diff --git a/tests/nightly/dist_sync_kvstore.py b/tests/nightly/dist_sync_kvstore.py
index 900d6bb6afb7..df85fe586054 100644
--- a/tests/nightly/dist_sync_kvstore.py
+++ b/tests/nightly/dist_sync_kvstore.py
@@ -23,7 +23,8 @@
 import mxnet as mx
 import numpy as np
 import numpy.random as rnd
-import time
+from mxnet.test_utils import assert_almost_equal
+from test_kvstore import compute_expected_2bit_quantization
 
 def check_diff_to_scalar(A, x, rank=None):
     """ assert A == x"""
@@ -39,6 +40,7 @@ def check_diff_to_scalar(A, x, rank=None):
 
 rate = 2
 shape = (2, 3)
+irregular_shape = (1211, 1211)
 big_shape = (1200, 1200)  # bigger than MXNET_KVSTORE_BIGARRAY_BOUND
 
 kv = mx.kv.create('dist_sync')
@@ -57,6 +59,17 @@ def init_kv():
     kv.set_optimizer(mx.optimizer.create('test', rescale_grad=rate))
     return kv, my_rank, nworker
 
+def init_kv_compressed(kv):
+    threshold = 0.5
+    kv.set_gradient_compression({'type': '2bit', 'threshold': threshold})
+    # init kv compression keys
+    kv.init('11221', mx.nd.zeros(big_shape))
+    kv.init('112221', mx.nd.zeros(irregular_shape))
+    kv.init('1121', mx.nd.zeros(shape))
+    # to test inactive mode
+    kv.init('1122', mx.nd.ones(shape))
+    return kv, threshold
+
 def test_sync_push_pull():
     kv, my_rank, nworker = init_kv()
     def check_default_keys(kv, my_rank, nworker):
@@ -173,11 +186,114 @@ def check_big_row_sparse_keys(kv, my_rank, nworker):
             expected[row] = updated_val[row]
         check_diff_to_scalar(val, expected, rank=my_rank)
 
+
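+    # Note on the expected values below: a push only takes effect once the
+    # accumulated gradient (including the residual) reaches the threshold; it is
+    # then quantized to +threshold, scaled by rescale_grad (rate) by the 'test'
+    # optimizer, and summed over all nworker workers, hence increments of
+    # threshold * rate * nworker.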
+    def check_compr_residual(kv, threshold, nworker):
+        for k, s in [('1121', shape), ('112221', irregular_shape), ('11221', big_shape)]:
+            # doesn't meet threshold
+            kv.push(k, mx.nd.ones(s)*0.4)
+            val = mx.nd.zeros(s)
+            kv.pull(k, val)
+            check_diff_to_scalar(val, 0)
+
+            # just meets threshold with residual
+            kv.push(k, mx.nd.ones(s)*(threshold - 0.4))
+            val2 = mx.nd.zeros(s)
+            kv.pull(k, val2)
+            curval = threshold * rate * nworker
+            check_diff_to_scalar(val2, curval)
+
+            # doesn't meet threshold
+            kv.push(k, mx.nd.ones(s)*0.2)
+            val3 = mx.nd.zeros(s)
+            kv.pull(k, val3)
+            check_diff_to_scalar(val3, curval)
+
+            # exceeds threshold again
+            kv.push(k, mx.nd.ones(s)*(threshold - 0.2))
+            val4 = mx.nd.zeros(s)
+            kv.pull(k, val4)
+            curval += threshold * rate * nworker
+            check_diff_to_scalar(val4, curval)
+            # residual is 0 now
+
+    def check_compr_ones(kv, threshold, nworker):
+        for k, s in [('1121', shape), ('112221', irregular_shape), ('11221', big_shape)]:
+            val = mx.nd.zeros(s)
+            kv.pull(k, val)
+            curval = val[0][0].asnumpy()[0]
+            kv.push(k, mx.nd.ones(s)*threshold)
+            val2 = mx.nd.zeros(s)
+            kv.pull(k, val2)
+            newval = curval + rate*nworker*threshold
+            check_diff_to_scalar(val2, newval)
+            # residual = 0 again
+
+    def check_compr_pull_before_push(kv):
+        for k, s in [('1121', shape), ('112221', irregular_shape),
+                     ('11221', big_shape), ('1122', shape)]:
+            if k == '1122':
+                # tests that gradient compression is not applied to the init of a key
+                val = mx.nd.zeros(s)
+                kv.pull(k, val)
+                check_diff_to_scalar(val, 1)
+            else:
+                val = mx.nd.ones(s)
+                kv.pull(k, val)
+                check_diff_to_scalar(val, 0)
+
+    def check_compr_zero(kv):
+        for k, s in [('1121', shape), ('112221', irregular_shape), ('11221', big_shape)]:
+            kv.push(k, mx.nd.zeros(s))
+            # to check that all values are set to 0
+            val = mx.nd.ones(s)
+            kv.pull(k, val)
+            check_diff_to_scalar(val, 0)
+
+    def check_compr_random(kv, threshold, nworker):
+        # set a seed so all workers generate the same data. knowing this helps
+        # calculate the expected value after pull
+        mx.random.seed(123)
+        rnd.seed(123)
+        nrepeat = 5
+        compr_random_keys_shapes = [('2121', shape), ('212221', irregular_shape), ('21221', big_shape)]
+        # use new keys so residual is 0 for calculation of expected
+        for k, s in compr_random_keys_shapes:
+            kv.init(k, mx.nd.zeros(s))
+        for k, s in compr_random_keys_shapes:
+            curr_residual = np.zeros(s)
+            for l in range(nrepeat):
+                orig_val = mx.nd.zeros(s)
+                kv.pull(k, orig_val)
+
+                grad = mx.nd.array(rnd.rand(s[0], s[1]))
+                # creates a copy because push modifies grad due to assignment
+                grad_cpy = mx.nd.array(grad)
+                kv.push(k, grad)
+                val = mx.nd.zeros(s)
+                kv.pull(k, val)
+
+                diff = val - orig_val
+
+                # compute expected result by simulating the quantization operator
+                compr, curr_residual, decompr = compute_expected_2bit_quantization(grad_cpy, curr_residual, threshold)
+                decompr *= nworker * rate
+                assert_almost_equal(diff.asnumpy(), decompr)
+
+    print('worker ' + str(my_rank) + ' started with non compression tests')
     check_default_keys(kv, my_rank, nworker)
     check_row_sparse_keys(kv, my_rank, nworker)
     check_row_sparse_keys_with_zeros(kv, my_rank, nworker)
     check_big_row_sparse_keys(kv, my_rank, nworker)
-    print('worker ' + str(my_rank) + ' is done')
+    print('worker ' + str(my_rank) + ' is done with non compression tests')
+
+    # don't push non-compressed keys after this point, as the kvstore is now set to compressed
+    print('worker ' + str(my_rank) + ' started with compression tests')
+    kv, threshold = init_kv_compressed(kv)
+    check_compr_pull_before_push(kv)
+    check_compr_zero(kv)
+    check_compr_residual(kv, threshold, nworker)
+    check_compr_ones(kv, threshold, nworker)
+    check_compr_random(kv, threshold, nworker)
+    print('worker ' + str(my_rank) + ' is done with compression tests')
 
 def test_sync_init():
     def check_init(kv, cur_keys, cur_shape, device=False):
diff --git a/tests/nightly/test_kvstore.py b/tests/nightly/test_kvstore.py
index 081bc9c5a456..a14feac7a3aa 100644
--- a/tests/nightly/test_kvstore.py
+++ b/tests/nightly/test_kvstore.py
@@ -21,17 +21,59 @@
 sys.path.insert(0, "../../python/")
 import mxnet as mx
 import numpy as np
+import numpy.random as rnd
+import copy
 
-keys = [3, 5, 7]
-# let the last shape exceed MXNET_KVSTORE_BIGARRAY_BOUND
-shapes = [(4, 4), (100, 100), (2000, 2000)];
+from mxnet.test_utils import assert_almost_equal
 
-lr = .1
-nworker = 4
-nrepeat = 10
+def check_diff_to_scalar(A, x, rank=None):
+    """ assert A == x"""
+    assert(np.sum(np.abs((A - x).asnumpy())) == 0), (rank, A.asnumpy(), x)
 
-## generate data
-data = [[[np.random.random(s)*2-1 for i in range(nworker)] for s in shapes] for j in range(nrepeat)]
+def compute_expected_2bit_quantization(arr, curr_residual, threshold):
+    """Simulate 2bit quantization of arr; returns the compressed representation,
+    the new residual, and the dequantized array."""
+    from struct import pack, unpack
+    def bits2int(bits):
+        bits = [int(x) for x in bits[::-1]]
+        x = 0
+        for i in range(len(bits)):
+            x += bits[i]*2**i
+        return x
+
+    def as_float32(s):
+        return unpack("f", pack("I", bits2int(s)))[0]
+
+    # str_quant stores the quantized representation as a sequence of bits
+    str_quant = ''
+    new_residual = []
+    decompr = []
+
+    arr_npy = arr.asnumpy()
+    for i, a in np.ndenumerate(arr_npy):
+        a += curr_residual[i]
+        if a >= threshold:
+            str_quant += '11'
+            new_residual.append(a - threshold)
+            decompr.append(threshold)
+        elif a <= (-1*threshold):
+            str_quant += '10'
+            new_residual.append(a + threshold)
+            decompr.append(-1*threshold)
+        else:
+            str_quant += '00'
+            new_residual.append(a)
+            decompr.append(0)
+    # pad with extra bits when the size of the array is not a multiple of 16
+    if len(str_quant)%16 != 0:
+        str_quant += '0'*(16 - len(str_quant)%16)
+
+    compr = []
+    # converts the generated string into integers 32 chars at a time
+    i = 0
+    while i