-
Notifications
You must be signed in to change notification settings - Fork 6.8k
2bit gradient compression #8662
Changes from 250 commits
407c01a
8cbb7f6
0dd1874
bbd21e4
72640e9
aaafa84
2d85430
5a99e6a
861fca5
54c6f06
03e47a4
fedd4b4
13ff1bc
b75d7ca
b84b762
260b606
7d78e3a
1b550eb
2a90dae
baba1d8
aac5292
b63673a
ce0f3b2
3d3ac92
f797271
5807469
7dbce8b
d1fdfc4
112b683
fe10b7a
cea9199
0ad7acc
e44f8fb
09ceb54
09971bf
237dc9b
2ffcfeb
c91cca3
261e244
54bb44b
6baa79e
91df1b3
ec8bbc7
39f2e44
fd42f8c
f743ab1
e18331f
7ec80ed
e24bc9b
3657869
95e073e
d595fa5
f6e2b92
6cf214e
4b0e756
bc31c4c
8630dbe
3a8f709
3f9256e
d6be11f
75363bb
e34d263
c0894b1
381941e
d5c37eb
c0dc329
38e94f5
3f27a14
5888e31
dbcec87
2d53722
0bc1da3
d120e9a
4d5315a
13b2ce5
648e0e9
a9bdcdc
1fac41f
1fdbdf0
5b4e405
c1a9add
24bc361
3a7985a
953ca95
8c6ba4f
96fa9b3
baae59d
2d5696e
36e1b51
f73e463
647b2ef
8fd1cde
a0c2a2a
52f47e5
a334924
be8d01d
bb473a4
c8cfae5
e2b405a
3ee9249
15e4f9c
39f3bac
50fa0fa
6bb1411
558e1b5
69f9e11
48591f2
4146690
71296f8
25cdda3
3234aa4
5e333bf
72d28b6
8357301
287e040
9290c23
99154c9
52b6905
b560d25
60b1b69
e3153ce
eeb454b
2f936ee
5e849e1
69608da
847a7f2
2f8e86e
69af018
32b9e7c
39e2d22
bf3ea61
9c9ae58
5c42ebb
49e4ee0
44b20e7
f74d317
b67a392
804f7d1
9629410
0feabd5
d3e4df8
d6801dd
d63e0b4
e97c477
18df71e
9f480ee
92dd85f
505d3e7
51d0349
3e17ec3
12d4499
248908c
35b42f7
cabb948
b8d2b50
e09a8fd
b2c9f29
6e651ed
0c48ebb
fe66ef9
f5204ca
5e473b2
88cc0fd
5c7a1ff
5283035
5294d4d
6bb9933
2a7f2f5
feaae67
080882b
a5abca4
7e5301d
594b40c
642cfe4
3c8686a
483d610
bc245b4
2f257e5
106feb8
46cbf5c
a351723
c84af06
d316700
f044830
e8aa9b5
5f130dd
c1fbeb7
180af91
4e0bded
8a083d2
c4d9a45
3a2060b
193586e
75399ff
ac2886a
e6e41e4
a7d6c68
970acbb
7ec0655
b72df8e
2913b56
f2e2469
30eae11
f5ddf7f
d3b668d
f19e7ee
42cdbdf
82f7964
b96c3c0
4bb3701
5c0114c
9557795
3945a8f
bbdfe1a
80957a7
222f33c
dc3b8e6
ac55cdc
48d54df
eea86ff
b37d36d
aa6fb6f
b694f15
e4c46e0
b84f179
2578883
b60b3fb
8328923
aa242b8
62c5255
b8b1d66
b66a3f2
f32b391
0743f60
6fd68f7
d7aea02
40f71f8
eabc503
806586f
6070450
2289129
f41e102
5acbc9a
3c1bacb
18d6a90
dfe7a7d
4b6f34a
30a197b
3073bf7
d5e4b2e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,6 +30,7 @@ | |
#include <string> | ||
#include <functional> | ||
#include <atomic> | ||
#include "../../src/kvstore/gradient_compression.h" | ||
#include "./ndarray.h" | ||
#if MXNET_USE_DIST_KVSTORE | ||
#include "ps/ps.h" | ||
|
@@ -64,6 +65,14 @@ class KVStore { | |
*/ | ||
inline const std::string& type() { return type_; } | ||
|
||
/** | ||
* \brief Set parameters to use low-bit compressed gradients | ||
* \param compression_type type of compression | ||
* \param threshold threshold for 2bit compression | ||
*/ | ||
virtual void SetGradientCompression(const std::string& compression_type, | ||
const float threshold) = 0; | ||
|
||
/*! | ||
* \brief Initialize a list of key-value pair to the store. | ||
* | ||
|
@@ -387,6 +396,13 @@ class KVStore { | |
*/ | ||
std::string type_; | ||
|
||
/** \brief Gradient compression object; starts in GC_NONE mode. | ||
* It is used only if SetGradientCompression sets the type. | ||
* Currently there is no support for un-setting gradient compression | ||
*/ | ||
std::shared_ptr<kvstore::GradientCompression> gradient_compression_; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if user uses |
||
|
||
|
||
/** | ||
* \brief whether to do barrier when finalize | ||
*/ | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,7 +24,7 @@ | |
from .ndarray import NDArray | ||
from .ndarray import _ndarray_cls | ||
from .base import _LIB | ||
from .base import check_call, c_array, c_str, string_types, mx_uint, py_str | ||
from .base import check_call, c_array, c_str, string_types, numeric_types, mx_uint, mx_float, py_str | ||
from .base import NDArrayHandle, KVStoreHandle | ||
from . import optimizer as opt | ||
|
||
|
@@ -349,6 +349,77 @@ def row_sparse_pull(self, key, out=None, priority=0, row_ids=None): | |
check_call(_LIB.MXKVStorePullRowSparse( | ||
self.handle, mx_uint(len(ckeys)), ckeys, cvals, crow_ids, ctypes.c_int(priority))) | ||
|
||
def set_gradient_compression(self, compression_params=(('compression', '2bit'),)): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think there should be a default value at all. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. rename key |
||
""" Specifies type of low-bit quantization for gradient compression if any, \ | ||
and additional arguments depending on the type of compression being used. | ||
|
||
2bit Gradient Compression takes a positive float `threshold`. | ||
The technique works by thresholding values such that positive values in the | ||
gradient above threshold will be set to threshold. Negative values whose absolute | ||
values are higher than threshold will be set to the negative of threshold. | ||
Values whose absolute values are less than threshold will be set to 0. | ||
By doing so, each value in the gradient is in one of three states. Two bits are | ||
used to represent these states, and every 16 float values in the original | ||
gradient can be represented using one float. This compressed representation | ||
can reduce communication costs. The difference between these thresholded values and | ||
original values is stored at the sender's end as residual and added to the | ||
gradient in the next iteration. | ||
|
||
When kvstore is 'local', gradient compression is used to reduce communication | ||
between multiple devices (gpus). Gradient is quantized on each GPU which | ||
computed the gradients, then sent to the GPU which merges the gradients. This | ||
receiving GPU dequantizes the gradients and merges them. Note that this | ||
increases memory usage on each GPU because of the residual array stored. | ||
|
||
When kvstore is 'dist', gradient compression is used to reduce communication | ||
from worker to server. Gradient is quantized on each worker which | ||
computed the gradients, then sent to the server which dequantizes | ||
this data and merges the gradients from each worker. Note that this | ||
increases CPU memory usage on each worker because of the residual array stored. | ||
Only worker to server communication is compressed in this setting. | ||
If each machine has multiple GPUs, currently this GPU to GPU communication is | ||
not compressed. Server to worker communication (in the case of pull) is also not | ||
compressed. | ||
|
||
To use 2bit compression, we need to specify `compression` as `2bit`. | ||
Only specifying `compression` would use default value for the threshold. | ||
To completely specify the arguments for 2bit compression, we would need to pass | ||
a dictionary which includes `threshold` like: | ||
{'compression': '2bit', 'threshold': 0.5} | ||
|
||
Parameters | ||
---------- | ||
compression_params : dict | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would this doc render correctly? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
A dictionary specifying the type and parameters for gradient compression. | ||
The key `compression` in this dictionary is a | ||
required string argument and specifies the type of gradient compression. | ||
Other keys in this dictionary are optional and specific to the type | ||
of gradient compression. Defaults to (('compression', '2bit'),). | ||
The default value is not a dict, | ||
just to avoid pylint warning on dangerous default values. | ||
""" | ||
if compression_params: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. superfluous if? |
||
if not isinstance(compression_params, dict): | ||
raise ValueError("compression_params needs to be a dictionary") | ||
if 'compression' not in compression_params: | ||
raise ValueError('compression_params requires `compression` to be set') | ||
elif not isinstance(compression_params['compression'], string_types): | ||
raise TypeError('compression must be a string') | ||
elif compression_params['compression'] not in ['none', '2bit']: | ||
raise ValueError('Unsupported type of compression') | ||
|
||
if compression_params['compression'] == '2bit': | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. These parsing should be done in backend with dmlc::Parameter. The frontend should pass strings of key value pairs. |
||
if 'threshold' in compression_params: | ||
if not isinstance(compression_params['threshold'], numeric_types): | ||
raise TypeError('threshold must be a numeric type') | ||
if compression_params['threshold'] <= 0: | ||
raise ValueError('threshold must be greater than 0') | ||
else: | ||
compression_params['threshold'] = 0.5 | ||
|
||
check_call(_LIB.MXKVStoreSetGradientCompression( | ||
self.handle, c_str(compression_params['compression']), | ||
mx_float(compression_params['threshold']))) | ||
|
||
def set_optimizer(self, optimizer): | ||
""" Registers an optimizer with the kvstore. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
API should be
MXKVStoreSetGradientCompression(KVStoreHandle handle, mx_uint num_params, const char **keys, const char **vals)
The values should be parsed in backend with
dmlc::Parameter