Conversation
Awesome work bringing NCCL to MXNet! I haven't finished reading all the code; a few comments so far.
include/mxnet/kvstore.h
Outdated
@@ -162,7 +162,7 @@ class KVStore {
* \param priority Priority of the action.
*/
virtual void Pull(const std::vector<int>& keys,
const std::vector<NDArray*>& values,
const std::vector<NDArray>& values,
Is it really necessary to change the interface here? Was this causing memory issues in pool_storage_manager?
This change is not really necessary, but I think it makes the interface more consistent (and the C API interface is not changed, so it should not affect frontend languages) - it makes it simpler to reuse code between push and pull, for example. It was originally introduced as part of the previous NCCL integration effort (that never got merged) to accommodate the allreduce API interface.
Is it still required for this PR?
No, and I changed it back (although this makes some of the other functions kind of ugly when you need to support both pointers and references to ndarrays).
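(For illustration only - a minimal C++ sketch of one way to share code between the pointer-based and value-based containers mentioned above; the NDArray stub and helper names are made up, not the PR's code.)
#include <vector>

struct NDArray { /* stand-in for mxnet::NDArray */ };

// Unwrap either a pointer or a value to a const reference.
inline const NDArray& AsRef(const NDArray* p) { return *p; }
inline const NDArray& AsRef(const NDArray& r) { return r; }

// One helper usable from both Push (std::vector<NDArray>) and
// Pull (std::vector<NDArray*>) code paths.
template <typename Container>
void ForEachValue(const Container& values) {
  for (const auto& v : values) {
    const NDArray& nd = AsRef(v);
    (void)nd;  // ... common logic operating on nd ...
  }
}

int main() {
  NDArray a, b;
  std::vector<NDArray> by_value{a, b};
  std::vector<NDArray*> by_pointer{&a, &b};
  ForEachValue(by_value);     // values
  ForEachValue(by_pointer);   // pointers
  return 0;
}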
python/mxnet/model.py
Outdated
# Use aggregation by default only with NCCL
default_batch = 16 if 'nccl' in kvstore.type else 1
batch = int(os.getenv('MXNET_UPDATE_AGGREGATION_SIZE', default_batch))
while(start < size):
nit: while start < size:
Will change.
python/mxnet/model.py
Outdated
size = len(grad_arrays)
start = 0
# Use aggregation by default only with NCCL
default_batch = 16 if 'nccl' in kvstore.type else 1
Where does the magic number 16 come from?
Performance experiments :-). The user may change that value with an env variable though.
python/mxnet/model.py
Outdated
# pull back the weights
kvstore.pull(name, arg_list, priority=-index)
kvstore.pull(param_names[start:end], param_arrays[start:end], priority=-start)
start = end

def _update_params(param_arrays, grad_arrays, updater, num_device,
                   kvstore=None, param_names=None):
Is this function not updated with batch aggregation?
This function is not used in the GPU environment (only in the local kvstore), and this aggregation is not "real" aggregation (although it makes it possible to implement actual aggregation in the future, either explicitly by copying to a long buffer and modifying pointers, or implicitly inside NCCL), so there is no sense in enabling it here.
What this aggregation does is basically delay the synchronization, so that multiple NCCL kernels may work at the same time, having a better chance at saturating the available links.
src/kvstore/comm.h
Outdated
@@ -58,7 +76,10 @@ class Comm {
*/
virtual void Broadcast(
int key, const NDArray& src,
const std::vector<NDArray*> dst, int priority) = 0;
const std::vector<NDArray> dst, int priority) = 0;
Could you add brief comments for these two methods? Are they only for nccl? Do we want to declare it only when MXNET_USE_NCCL is set?
src/kvstore/kvstore_local.h
Outdated
@@ -61,7 +61,10 @@ class KVStoreLocal : public KVStore {
}

virtual ~KVStoreLocal() {
delete comm_;
I think delete nullptr is safe.
Ok, I will change that to only add comm_ = nullptr; after deleting. I added the check because the previous version of the destructor was not setting comm_ to nullptr, and it gave me a segfault when kvstore_nccl called both its own destructor and the kvstore_local destructor (both trying to delete comm_).
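(A minimal sketch of the double-delete scenario and the fix described above; the Comm stub and the constructor are illustrative, not the PR's exact code.)
#include <cstddef>

// Minimal stand-in for the real Comm object owned by the kvstore.
struct Comm {
  virtual ~Comm() = default;
};

class KVStoreLocal {
 public:
  virtual ~KVStoreLocal() {
    delete comm_;      // deleting a null pointer is a no-op, so no guard is needed
    comm_ = nullptr;
  }
 protected:
  Comm* comm_ = nullptr;
};

class KVStoreNCCL : public KVStoreLocal {
 public:
  KVStoreNCCL() { comm_ = new Comm(); }
  ~KVStoreNCCL() override {
    delete comm_;
    comm_ = nullptr;   // keeps ~KVStoreLocal() from deleting the same pointer again
  }
};

int main() {
  KVStoreLocal* kv = new KVStoreNCCL();
  delete kv;           // both destructors run; no double delete, no segfault
  return 0;
}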
const std::vector<NDArray*>& values,
std::vector<int> *uniq_keys,
std::vector<std::vector<NDArray*>> *grouped_vals) {
virtual void GroupKVPairsPull(const std::vector<int>& keys,
Why is virtual added here?
Because the validator may be different in inheriting classes. That is the case for the NCCL kvstore - it inherits from the local kvstore to avoid copy-pasted code, but I can't support sparse types there.
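(Illustrative sketch only - how a virtual grouping method lets the NCCL kvstore reuse the local kvstore's grouping logic while plugging in its own validator; the names and NDArray stub are made up, not the PR's exact code.)
#include <cstddef>
#include <vector>

struct NDArray { bool is_sparse = false; };  // stand-in for mxnet::NDArray

class KVStoreLocal {
 protected:
  // Shared grouping logic, parameterized by a per-kvstore validator.
  template <typename V, typename Validator>
  void GroupKVPairs(const std::vector<int>& keys, const std::vector<V>& values,
                    Validator is_valid) {
    for (size_t i = 0; i < keys.size(); ++i) {
      if (!is_valid(keys[i], values[i])) {
        // real code would CHECK() here and abort with a message
      }
      // ... bucket values[i] under the unique key keys[i] ...
    }
  }

  // Virtual so that inheriting classes can supply a different validator.
  virtual void GroupKVPairsPull(const std::vector<int>& keys,
                                const std::vector<NDArray*>& values) {
    GroupKVPairs(keys, values,
                 [](int, const NDArray*) { return true; });  // dense and sparse OK
  }
};

class KVStoreNCCL : public KVStoreLocal {
 protected:
  void GroupKVPairsPull(const std::vector<int>& keys,
                        const std::vector<NDArray*>& values) override {
    GroupKVPairs(keys, values,
                 [](int, const NDArray* nd) { return !nd->is_sparse; });
  }
};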
include/mxnet/storage.h
Outdated
@@ -78,6 +78,16 @@ class Storage {
*/
virtual ~Storage() {}
/*!
* \brief Returns mutex used by storage manager
*/
std::mutex& GetMutex(Context::DeviceType dev) {
Could you add a brief description of when the mutex is required?
The mutex is not really required outside of NCCL (and only for GPU allocations). See the discussion from the previous NCCL PR: #5521 (comment)
Do we have any performance numbers for using NCCL in MXNet?
Also adding @rahul003 for review
src/kvstore/comm.h
Outdated
for (size_t i = 0; i < src.size(); ++i) {
NCCLEntry cur = nccl_data_[src[i].ctx().dev_id];
if (i == root_id) {
MSHADOW_TYPE_SWITCH(src[i].dtype(), DType,
nit: Please fix indentation
src/kvstore/comm.h
Outdated
return ncclDouble;
case mshadow::kUint8:
return ncclChar;
case mshadow::kInt32:
Should kInt64 also be added?
Ok.
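(For context, a sketch of what the full mapping could look like with kInt64 included; it assumes the mshadow and NCCL 2.x headers are available, and the function name is illustrative rather than the PR's exact code.)
#include <mshadow/base.h>
#include <nccl.h>

inline ncclDataType_t GetNCCLType(int dtype) {
  switch (dtype) {
    case mshadow::kFloat32: return ncclFloat;
    case mshadow::kFloat16: return ncclHalf;
    case mshadow::kFloat64: return ncclDouble;
    case mshadow::kUint8:   return ncclChar;
    case mshadow::kInt32:   return ncclInt;
    case mshadow::kInt64:   return ncclInt64;  // the mapping requested above
    default:                return ncclFloat;  // real code should fail loudly here
  }
}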
/**
* \brief store data in local machine using NCCL
*/
class KVStoreNCCL : public KVStoreLocal {
Does it mean multi-machine with GPUs cannot benefit from nccl?
Currently no - it is future work. The biggest problem is how to bootstrap NCCL in a multi-node scenario, and I do not yet understand MXNet's distributed kvstore well enough to use it for that task.
tests/python/gpu/test_nccl.py
Outdated
# specific language governing permissions and limitations
# under the License.

# pylint: skip-file
Could you enable pylint here?
Most current tests have pylint disabled; I copied that part from them. Sure, I will do that.
tests/python/gpu/test_nccl.py
Outdated
a = mx.nd.ones(shape, mx.gpu(0))
cur_key = str(key*max(gpus)+n_gpus)
kv_nccl.init(cur_key, a)
arr_list = [mx.nd.ones(shape, mx.gpu(x)) for x in xrange(n_gpus)]
xrange is not py3 compatible. Can you replace it with range?
Sure.
src/kvstore/comm.h
Outdated
@@ -32,6 +35,21 @@
#include "mxnet/ndarray.h"
#include "../ndarray/ndarray_function.h"
#include "../operator/tensor/sparse_retain-inl.h"

#if MXNET_USE_NCCL
Can't we merge the two #if MXNET_USE_NCCL?
The linter would complain about a system header being placed after local headers (I assume those are the two #ifs you would want merged).
Ok, then we can leave them as is
src/kvstore/comm.h
Outdated
dev_ids.push_back(e.ctx().dev_id);
}
std::sort(dev_ids.begin(), dev_ids.end());
CHECK(device_ids_ == dev_ids) << "NCCL KVStore supports only single set of devices";
Do you want to check here that the set of devices doesn't change during training?
Yes. Handling multiple sets of devices can be done, but not with the structure imposed by the Comm class. Basically, in order to keep the benefits of batching, I need to ensure that the root for the reduction will be the same for the whole batch, but I know who participates only during the actual push/pull, not during Init, and all of the data structures are initialized only once, during the first push. This BTW should also be checked in the device kvstore (and currently is not), otherwise you can do something like this:
>>> import mxnet as mx
>>> kv = mx.kv.create("device")
>>> shape = (2,3)
>>> kv.init(4, mx.nd.ones(shape))
>>> gpus = [mx.gpu(i) for i in range(2)]
>>> b = [mx.nd.ones(shape, gpu) for gpu in gpus]
>>> kv.push(4, b)
>>> a = mx.nd.zeros(shape)
>>> kv.pull(4, out = a)
>>> a
[[ 2. 2. 2.]
[ 2. 2. 2.]]
<NDArray 2x3 @cpu(0)>
>>> gpus = [mx.gpu(i) for i in range(4)]
>>>
>>> b = [mx.nd.ones(shape, gpu) for gpu in gpus]
>>> kv.push(4, b)
Segmentation fault
src/kvstore/comm.h
Outdated
using KeyAttrs = std::tuple<int, TShape, int>;
// try to allocate buff on device evenly
void InitMergeBuffer(const std::vector<Context>& devs) {
for (size_t i = 0; i < sorted_key_attrs_.size(); ++i) {
Doesn't look like sorted_key_attrs_ is actually sorted in this case, and it also doesn't look like it needs to be sorted here. If so, can you use a different variable name?
Ok.
src/kvstore/comm.h
Outdated
auto& buf = merge_buf_[key];
Context ctx;
// use devs[0] as root
ctx = devs[0];
These three lines are strange. If you want to always make devs[0] the root, please use devs[0] directly as the argument in line 927 instead.
As I understand it, the buffers are no longer evenly allocated on all devices? If so, please remove the comment on line 917.
And why don't we want to do that anymore?
I will remove the comment.
We want to use devs[0] every time because this helps hide some of the latencies and keeps more inter-GPU links occupied if the flow of data is always the same.
src/kvstore/comm.h
Outdated
}
},
Context::CPU(),
const_vars,
This can be more compact if const_vars is replaced with {}. Same in the function below.
Ok.
src/kvstore/comm.h
Outdated
#include "../common/cuda_utils.h" | ||
|
||
#ifndef NCCL_MAJOR | ||
#define NCCL_MAJOR 1 |
Could you add a comment explaining this?
Sure.
src/kvstore/comm.h
Outdated
virtual void Init(int key, const NDArrayStorageType stype,
const TShape& shape, int dtype = mshadow::kFloat32) = 0;
virtual void Init(int key, const NDArrayStorageType stype, const TShape& shape,
int dtype = mshadow::kFloat32, Context pinned_ctx = Context::CPUPinned(0)) = 0;
Are you changing the pinned context to something other than CPU in KVStoreNCCL? Or is this change just to generalize the function?
That was a relic of a previous iteration of the code - I will remove it.
src/kvstore/comm.h
Outdated
}
}
} else {
auto& buf = merge_buf_[key];
Could you please add some comments explaining the flow of data here?
src/kvstore/comm.h
Outdated
if (dst.size() == 1) return;
std::vector<Engine::VarHandle> mutable_vars;
for (size_t i = 0; i < dst.size(); ++i) {
if ( i != root_id)
Please add braces around line 848. This style has potential for future bugs :)
Ok
src/kvstore/comm.h
Outdated
size_t root_id = -1;
for (size_t i = 0; i < dst.size(); ++i) {
if (dst[i].ctx().dev_id == root) {
root_id = i;
Could you please wrap such tasks in small wrapper functions so the code becomes more readable?
I did for this snippet. Unfortunately the other small functions like this have slightly different elements so it's not as simple to wrap them and reuse. I also added some comments on what is being done.
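(A minimal sketch of the kind of wrapper meant here - finding the index of the destination array that lives on the root device; the Context/NDArray stubs and the function name are illustrative, not the PR's exact code.)
#include <cstddef>
#include <vector>

struct Context { int dev_id; };
struct NDArray {
  Context ctx_;
  const Context& ctx() const { return ctx_; }
};

// Return the position of the array located on the root device
// (real code would CHECK that a match exists instead of falling back to 0).
size_t FindRootId(const std::vector<NDArray>& dst, int root) {
  for (size_t i = 0; i < dst.size(); ++i) {
    if (dst[i].ctx().dev_id == root) return i;
  }
  return 0;
}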
Thanks for addressing all these review comments. Is anyone helping you to set up the CI test with NCCL?
Not really, no.
python/mxnet/model.py
Outdated
# push gradient, priority is negative index
kvstore.push(name, grad_list, priority=-index)
kvstore.push(param_names[start:end], grad_arrays[start:end], priority=-start)
what's the purpose of this? Why should it be done in frontend?
This enables aggregation of the reductions (in NCCL 2.1 it is not real aggregation, since each reduction is done in its own launch and they benefit only from the lack of synchronization between reductions in a group, but NCCL 2.2 will introduce real aggregation support).
It needs to be done in the frontend because the kvstore itself does not have any information on which gradients should be aggregated, in which order, and how long it is supposed to wait before the real launch - MXNet's dependency engine does not really allow for that. That's why we need to provide data about all of the reductions in a group to benefit from aggregation.
Is it valid to push multiple keys in one call? If it's valid, why not push all keys in one call and decide the aggregation size in the backend?
Also, the if grad_list[0] is None: logic is removed.
It is valid to push multiple keys in one call. But you still want to set the priority of those reductions in the frontend - the kvstore should not assume the order in which you push gradients to it, so aggregating everything would get rid of priority altogether.
Regarding the missing logic - good point, I forgot about it - will fix.
@mbaijal We'll need your help setting up a new CI test with NCCL build after this is merged.
src/kvstore/comm.h
Outdated
/**
* \brief copy from src to dst[i] for every i
*/
virtual void Broadcast(
int key, const NDArray& src,
const std::vector<NDArray*> dst, int priority) = 0;

#if MXNET_USE_NCCL
// Aggregated reductions
virtual void Reduce(const std::vector<int> keys,
If NCCL is going to do everything differently, then it shouldn't inherit the Comm interface. Do this in KVStoreNCCL directly.
Ok, will do.
@@ -88,7 +92,7 @@ class GPUPooledStorageManager final : public StorageManager {
};  // class GPUPooledStorageManager

void GPUPooledStorageManager::Alloc(Storage::Handle* handle) {
std::lock_guard<std::mutex> lock(mutex_);
std::lock_guard<std::mutex> lock(Storage::Get()->GetMutex(Context::kGPU));
So now memory allocation on all GPUs shares the same mutex? This could slow down memory allocation, especially when using Gluon.
Could you suggest a Gluon benchmark to check the performance impact? If the impact is too big, we can move to cooperative launch in NCCL 2 (which would not need a shared mutex anymore), but this would mean it will be compatible only with NCCL 2 and CUDA 9. Also, cooperative launch is currently slower than parallel launch.
I may be looking for places to insert many-reader/single-writer shared mutexes to help performance. Do you think this would be a good candidate for that, or is there a reason there needs to be a global lock for this operation?
There is a reason it needs to be a global lock - NCCL needs to finish scheduling all kernels before anybody can start allocation/deallocation of GPU memory, otherwise a deadlock will happen.
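(A minimal sketch of the locking pattern being described, assuming a single global mutex stands in for Storage::Get()->GetMutex(Context::kGPU); the function names are illustrative, not the PR's exact code.)
#include <mutex>

std::mutex gpu_mutex;  // shared by the NCCL launch path and every GPU allocation

void LaunchNCCLCollectives(/* communicators, buffers, streams */) {
  std::lock_guard<std::mutex> lock(gpu_mutex);
  // Schedule ncclReduce/ncclBcast on every device while holding the lock, so
  // no allocation can interleave with the collective launches and deadlock.
}

void GPUPooledAlloc(/* Storage::Handle* handle */) {
  std::lock_guard<std::mutex> lock(gpu_mutex);
  // cudaMalloc / pooled reuse runs only after NCCL has finished scheduling.
}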
For multiple GPUs, would this just be applicable to the two GPUs involved in the transfer? Or just one GPU? Or all?
NCCL is a collective communication library, so all GPUs are involved at the same time.
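(For illustration, a sketch of a grouped per-device launch - every GPU in the communicator issues its call, and the collective completes only once all of them have; this assumes NCCL 2.x, with communicators, buffers and streams created elsewhere, and the function name is made up.)
#include <cstddef>
#include <cuda_runtime.h>
#include <nccl.h>

void ReduceToRoot(ncclComm_t* comms, float** sendbuf, float** recvbuf,
                  cudaStream_t* streams, int ndev, size_t count, int root) {
  ncclGroupStart();
  for (int i = 0; i < ndev; ++i) {
    // Each device contributes its buffer; the reduced result lands on root.
    ncclReduce(sendbuf[i], recvbuf[i], count, ncclFloat, ncclSum,
               root, comms[i], streams[i]);
  }
  ncclGroupEnd();
}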
Thank you for the clarification
@eric-haibin-lin Please create an issue and mention me so I can keep track of this case. I've been thinking about adding a p2/p3-specific job to the Unit Tests - this would cover features which are unsupported by our usual CI machines.
Integrated last comments from @piiswrong and merged with the 2-bit compression PR (NCCL does not support gradient compression and prints an error message when trying to use it).
Is this ready to go in? (assuming CI passes)
@cjolivier01 As far as I am concerned - yes, it is done. @piiswrong?
Yes, we decided to deal with the aggregation policy later |
* NCCL integration
* Skipping NCCL test (since it requires NCCL library to be present and enabled in build)
* Add Apache header to test_nccl.py
* Fixes from review
* Trigger CI
* Removing API change for Pull
* Fixes
* Fix
* Fix
* Fix
* Fix
* Fix
* Indentation fixes and importing unittest in test_nccl.py
* sorted_key_attrs -> key_attrs
* More fixes from review
* Fix
* Fix lint
* Support for aggregation in NCCL
* Fix typo
* Fix missing logic
* Move from CommNCCL to KVStoreNCCL
* Fix
* Moved nccl update to separate function
* Add message about not supporting gradient compression
* Fix lint
* Trigger CI
Description
This PR provides a new KVStore type with integration for the NCCL communication library.
Checklist
Essentials
Passed code style checking (make lint)
Changes
* nccl type of kvstore, using ncclReduce and ncclBcast
* test_nccl.py added to tests/python/gpu, but not enabled, since NCCL is not present and enabled in CI
Comments