
NCCL integration #8294

Merged 30 commits into apache:master on Nov 21, 2017

Conversation

ptrendx (Member) commented Oct 16, 2017

Description

This PR provides a new KVStore type integrating the NCCL communication library.

Checklist

Essentials

  • Passed code style checking (make lint)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage
  • For user-facing API changes, API doc string has been updated.
  • To the best of my knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Changes

  • New nccl type of kvstore, using ncclReduce and ncclBcast
  • test_nccl.py added to tests/python/gpu, but not enabled, since NCCL is not present or enabled in CI

Comments

  • Interesting edge cases to note here:
    • NCCL KVStore requires the same set of devices to be used for all communications (as is the case in typical data parallel training)
    • In the NCCL KVStore, push and pull are implemented in two steps: launching NCCL kernels in the first step and synchronizing in the second. This was done to enable seamless aggregation support; several reductions are scheduled before a synchronization (see the usage sketch below).
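
A minimal usage sketch (assuming the new 'nccl' kvstore is created with mx.kv.create('nccl') and then used exactly like the existing 'device' kvstore shown later in this thread; the key, shape, and device count are illustrative):

```python
import mxnet as mx

# Create the new NCCL-based kvstore (single-node, multi-GPU).
kv = mx.kv.create('nccl')

shape = (2, 3)
gpus = [mx.gpu(i) for i in range(2)]

# Initialize a key; the same set of devices must then be used for every
# subsequent push/pull (first edge case above).
kv.init(3, mx.nd.ones(shape))

# Push one gradient per GPU; the kvstore reduces them (ncclReduce).
grads = [mx.nd.ones(shape, ctx=gpu) for gpu in gpus]
kv.push(3, grads)

# Pull the reduced result back to every GPU (ncclBcast).
weights = [mx.nd.zeros(shape, ctx=gpu) for gpu in gpus]
kv.pull(3, out=weights)
print(weights[0].asnumpy())
```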

ptrendx mentioned this pull request Oct 16, 2017
piiswrong (Contributor): @mli @eric-haibin-lin

eric-haibin-lin self-requested a review October 17, 2017

eric-haibin-lin (Member) left a comment:

Awesome work bringing NCCL to MXNet! I haven't finished reading all the code; a few comments so far.

@@ -162,7 +162,7 @@ class KVStore {
* \param priority Priority of the action.
*/
virtual void Pull(const std::vector<int>& keys,
const std::vector<NDArray*>& values,
const std::vector<NDArray>& values,

Member:

Is it really necessary to change the interface here? Was this causing memory issues in pool_storage_manager?

Member Author:

This change is not really necessary, but I think it makes the interface more consistent (the C API is not changed, so it should not affect frontend languages); for example, it makes it simpler to reuse code between push and pull. It was originally introduced as part of the previous NCCL integration effort (which never got merged) to accommodate an allreduce-style API.

Member:

Is it still required for this PR?

Member Author:

No, and I changed it back (although this makes some of the other functions kind of ugly when you need to support both pointers and references to ndarrays).

# Use aggregation by default only with NCCL
default_batch = 16 if 'nccl' in kvstore.type else 1
batch = int(os.getenv('MXNET_UPDATE_AGGREGATION_SIZE', default_batch))
while(start < size):

Member:

nit: while start < size:

Member Author:

Will change.

size = len(grad_arrays)
start = 0
# Use aggregation by default only with NCCL
default_batch = 16 if 'nccl' in kvstore.type else 1

Member:

Where does the magic number 16 come from?

Member Author:

Performance experiments :-). The user may change that value with an environment variable, though.
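
For context, a condensed sketch of the batching loop being discussed (based on the surrounding diff; the wrapping function and its name are illustrative, the body follows the diff):

```python
import os

def push_pull_in_batches(kvstore, param_names, param_arrays, grad_arrays):
    """Push gradients and pull weights in batches so that several NCCL
    reductions can be scheduled before a synchronization."""
    size = len(grad_arrays)
    # Use aggregation by default only with NCCL.
    default_batch = 16 if 'nccl' in kvstore.type else 1
    batch = int(os.getenv('MXNET_UPDATE_AGGREGATION_SIZE', default_batch))
    start = 0
    while start < size:
        end = min(start + batch, size)
        # Push gradients; priority is the negative start index, so earlier
        # layers keep higher priority.
        kvstore.push(param_names[start:end], grad_arrays[start:end],
                     priority=-start)
        # Pull the reduced weights back for the same slice of keys.
        kvstore.pull(param_names[start:end], param_arrays[start:end],
                     priority=-start)
        start = end
```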

# pull back the weights
kvstore.pull(name, arg_list, priority=-index)
kvstore.pull(param_names[start:end], param_arrays[start:end], priority=-start)
start = end

def _update_params(param_arrays, grad_arrays, updater, num_device,
kvstore=None, param_names=None):

Member:

Is this function not updated with batch aggregation?

Member Author:

This function is not used in the GPU environment (only in the local kvstore), and this aggregation is not "real" aggregation (although it makes it possible to implement actual aggregation in the future, either explicitly by copying to a long buffer and adjusting pointers, or implicitly inside NCCL), so there is no sense in enabling it here.
What this aggregation does is basically delay the synchronization, so that multiple NCCL kernels may work at the same time and have a better chance of saturating the available links.

@@ -58,7 +76,10 @@ class Comm {
*/
virtual void Broadcast(
int key, const NDArray& src,
const std::vector<NDArray*> dst, int priority) = 0;
const std::vector<NDArray> dst, int priority) = 0;

Member:

Could you add brief comments for these two methods? Are they only for nccl? Do we want to declare them only when MXNET_USE_NCCL is set?

@@ -61,7 +61,10 @@ class KVStoreLocal : public KVStore {
}

virtual ~KVStoreLocal() {
delete comm_;

Member:

I think delete nullptr is safe

Member Author:

Ok, I will change that to just add comm_ = nullptr; after deleting. I added the check because the previous version of the destructor did not set comm_ to nullptr, and that gave me a segfault when kvstore_nccl called both its own destructor and the kvstore_local destructor (both trying to delete comm_).

const std::vector<NDArray*>& values,
std::vector<int> *uniq_keys,
std::vector<std::vector<NDArray*>> *grouped_vals) {
virtual void GroupKVPairsPull(const std::vector<int>& keys,

Member:

Why is virtual added here?

Member Author:

Because the validator may be different in inheriting classes. That is the case for the NCCL kvstore: it inherits from the local kvstore to avoid copy-pasted code, but I can't support sparse types there.

@@ -78,6 +78,16 @@ class Storage {
*/
virtual ~Storage() {}
/*!
* \brief Returns mutex used by storage manager
*/
std::mutex& GetMutex(Context::DeviceType dev) {

Member:

Could you add brief description when mutex is required?

Member Author:

Mutex is not really required outside of NCCL (and only for GPU allocations). See discussion from the previous NCCL PR: #5521 (comment)

eric-haibin-lin (Member) left a comment:

Do we have any performance numbers for using nccl in MXNet?
Also adding @rahul003 for review.

for (size_t i = 0; i < src.size(); ++i) {
NCCLEntry cur = nccl_data_[src[i].ctx().dev_id];
if (i == root_id) {
MSHADOW_TYPE_SWITCH(src[i].dtype(), DType,

Member:

nit: Please fix indentation

return ncclDouble;
case mshadow::kUint8:
return ncclChar;
case mshadow::kInt32:

Member:

Should kInt64 also be added?

Member Author:

Ok.

/**
* \brief store data in local machine using NCCL
*/
class KVStoreNCCL : public KVStoreLocal {

Member:

Does it mean multi-machine with GPUs cannot benefit from nccl?

Member Author:

Currently no; it is future work. The biggest problem is how to bootstrap NCCL in a multi-node scenario, and I do not yet understand MXNet's distributed kvstore well enough to use it for that task.

# specific language governing permissions and limitations
# under the License.

# pylint: skip-file

Member:

Could you enable pylint here?

Member Author:

Most current tests have pylint disabled; I copied that part from them. Sure, I will do that.

a = mx.nd.ones(shape, mx.gpu(0))
cur_key = str(key*max(gpus)+n_gpus)
kv_nccl.init(cur_key, a)
arr_list = [mx.nd.ones(shape, mx.gpu(x)) for x in xrange(n_gpus)]

Member:

xrange is not py3 compatible. Can you replace it with range?

Member Author:

Sure.

@@ -32,6 +35,21 @@
#include "mxnet/ndarray.h"
#include "../ndarray/ndarray_function.h"
#include "../operator/tensor/sparse_retain-inl.h"

#if MXNET_USE_NCCL

Member:

Can't we merge the two #if MXNET_USE_NCCL ?

Member Author:

The linter would complain about a system header coming after local headers (I assume those are the two #ifs you would want merged).

Member:

Ok, then we can leave them as is

dev_ids.push_back(e.ctx().dev_id);
}
std::sort(dev_ids.begin(), dev_ids.end());
CHECK(device_ids_ == dev_ids) << "NCCL KVStore supports only single set of devices";

Member:

Do you want to check here that the set of devices doesn't change during training?

Member Author:

Yes. Handling multiple sets of devices could be done, but not with the structure imposed by the Comm class. Basically, in order to keep the benefits of batching I need to ensure that the root of the reduction is the same for the whole batch, but I only know who participates during the actual push/pull, not during Init, and all of the data structures are initialized only once, during the first push. This should, by the way, also be checked in the device kvstore (and currently is not); otherwise you can do something like this:

>>> import mxnet as mx
>>> kv = mx.kv.create("device")
>>> shape = (2,3)
>>> kv.init(4, mx.nd.ones(shape))
>>> gpus = [mx.gpu(i) for i in range(2)]
>>> b = [mx.nd.ones(shape, gpu) for gpu in gpus]
>>> kv.push(4, b)
>>> a = mx.nd.zeros(shape)
>>> kv.pull(4, out = a)
>>> a
[[ 2.  2.  2.]
 [ 2.  2.  2.]]
<NDArray 2x3 @cpu(0)>
>>> gpus = [mx.gpu(i) for i in range(4)]
>>> 
>>> b = [mx.nd.ones(shape, gpu) for gpu in gpus]
>>> kv.push(4, b)
Segmentation fault

using KeyAttrs = std::tuple<int, TShape, int>;
// try to allocate buff on device evenly
void InitMergeBuffer(const std::vector<Context>& devs) {
for (size_t i = 0; i < sorted_key_attrs_.size(); ++i) {

Member:

It doesn't look like sorted_key_attrs_ is actually sorted in this case, and it doesn't look like it needs to be sorted here. If so, can you use a different variable name?

Member Author:

Ok.

auto& buf = merge_buf_[key];
Context ctx;
// use devs[0] as root
ctx = devs[0];

Member:

These three lines are strange. If you want devs[0] to always be the root, please use devs[0] directly as the argument on line 927 instead.

As I understand it, the buffers are no longer evenly allocated across all devices? If so, please remove the comment on line 917.
And why don't we want to do that anymore?

Member Author:

I will remove the comment.
We want to use devs[0] every time because it helps hide some of the latencies and keeps more inter-GPU links occupied if the flow of data is always the same.

}
},
Context::CPU(),
const_vars,

Member:

This can be more compact if const_vars is replaced with {}. Same in below function.

Member Author:

Ok.

#include "../common/cuda_utils.h"

#ifndef NCCL_MAJOR
#define NCCL_MAJOR 1

Member:

Could you add a comment explaining this?

Member Author:

Sure.

virtual void Init(int key, const NDArrayStorageType stype,
const TShape& shape, int dtype = mshadow::kFloat32) = 0;
virtual void Init(int key, const NDArrayStorageType stype, const TShape& shape,
int dtype = mshadow::kFloat32, Context pinned_ctx = Context::CPUPinned(0)) = 0;

rahul003 (Member) commented Nov 3, 2017:

Are you changing the pinned context to something other than CPU in KVStoreNCCL? Or is this change just to generalize the function?

Member Author:

That was a relic of a previous iteration of the code; I will remove it.

}
}
} else {
auto& buf = merge_buf_[key];

Member:

Could you please add some comments explaining the flow of data here?

if (dst.size() == 1) return;
std::vector<Engine::VarHandle> mutable_vars;
for (size_t i = 0; i < dst.size(); ++i) {
if ( i != root_id)

Member:

Please add braces around line 848. This style has potential for future bugs :)

Member Author:

Ok

size_t root_id = -1;
for (size_t i = 0; i < dst.size(); ++i) {
if (dst[i].ctx().dev_id == root) {
root_id = i;

Member:

Could you please wrap such tasks in small wrapper functions so the code becomes more readable?

Member Author:

I did for this snippet. Unfortunately the other small functions like this have slightly different elements so it's not as simple to wrap them and reuse. I also added some comments on what is being done.

eric-haibin-lin (Member):

Thanks for addressing all these review comments. Is anyone helping you to set up the CI test with NCCL?

ptrendx (Member Author) commented Nov 7, 2017:

Not really, no.

# push gradient, priority is negative index
kvstore.push(name, grad_list, priority=-index)
kvstore.push(param_names[start:end], grad_arrays[start:end], priority=-start)

Contributor:

What's the purpose of this? Why should it be done in the frontend?

Member Author:

This enables aggregation of the reductions (in NCCL 2.1 it is not real aggregation, since each reduction is done in its own launch and the reductions in a group merely benefit from the lack of synchronization between them, but NCCL 2.2 will introduce real aggregation support).
It needs to be done in the frontend because the kvstore itself has no information about which gradients should be aggregated, in which order, or how long it is supposed to wait before the real launch; MXNet's dependency engine does not really allow for that. That is why we need to provide data about all of the reductions in a group to benefit from aggregation.

Contributor:

Is it valid to push multiple keys in one call? If it's valid, why not push all keys in one call and decide the aggregation size in the backend?

Contributor:

Also the if grad_list[0] is None: logic is removed.

Member Author:

It is valid to push multiple keys in one call. But you still want to set the priority of those reductions in the frontend; the kvstore should not assume the order in which you push gradients to it, so aggregating everything would get rid of priorities altogether (see the sketch below).
Regarding the missing logic: good point, I forgot about it; will fix.
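
A hypothetical illustration of that trade-off (variable names mirror the surrounding diff; kvstore, param_names, grad_arrays, and batch are assumed to be defined as in that code):

```python
# Valid: push several keys in a single call. The whole group then shares
# one priority, so per-layer priorities chosen by the frontend are lost:
kvstore.push(param_names, grad_arrays, priority=0)

# Batched pushes keep the frontend in control of priorities while still
# letting the NCCL kvstore group the reductions inside each batch:
for start in range(0, len(param_names), batch):
    end = min(start + batch, len(param_names))
    kvstore.push(param_names[start:end], grad_arrays[start:end],
                 priority=-start)
```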

eric-haibin-lin (Member):

@mbaijal We'll need your help setting up a new CI test with NCCL build after this is merged.

/**
* \brief copy from src to dst[i] for every i
*/
virtual void Broadcast(
int key, const NDArray& src,
const std::vector<NDArray*> dst, int priority) = 0;

#if MXNET_USE_NCCL
// Aggregated reductions
virtual void Reduce(const std::vector<int> keys,

Contributor:

If NCCL is going to do everything differently, then it shouldn't inherit the Comm interface. Do this in KVStoreNCCL directly.

Member Author:

Ok, will do.

@@ -88,7 +92,7 @@ class GPUPooledStorageManager final : public StorageManager {
}; // class GPUPooledStorageManager

void GPUPooledStorageManager::Alloc(Storage::Handle* handle) {
std::lock_guard<std::mutex> lock(mutex_);
std::lock_guard<std::mutex> lock(Storage::Get()->GetMutex(Context::kGPU));

Contributor:

So now memory allocation on all GPUs shares the same mutex? This could slow down memory allocation, especially when using Gluon.

Member Author:

Could you suggest a Gluon benchmark to check the performance impact? If the impact is too big, we can move to cooperative launch in NCCL 2 (which would not need a shared mutex anymore), but that would mean compatibility only with NCCL 2 and CUDA 9. Also, cooperative launch is currently slower than parallel launch.

Member:

I may be looking for places to insert many-read/single-write shared mutexes to help performance. Do you think this would be a good candidate, or is there a reason that there needs to be a global lock for this operation?

Member Author:

There is a reason it needs to be a global lock: NCCL needs to finish scheduling all kernels before anybody can start allocating or deallocating GPU memory, otherwise a deadlock will happen.

Member:

For multiple GPUs, would this just apply to the two GPUs involved in the transfer? Or just one GPU? Or all of them?

Member Author:

NCCL is a collective communication library, so all GPUs are involved at the same time.

Member:

Thank you for the clarification

marcoabreu (Contributor):

@eric-haibin-lin Please create an issue and mention me so I can keep track of this case. I've been thinking about adding a p2/p3-specific job to the unit tests; this would cover features which are unsupported by our usual CI machines.

ptrendx (Member Author) commented Nov 21, 2017:

Integrated the last comments from @piiswrong and merged with the 2-bit compression PR (NCCL does not support gradient compression and prints an error message when trying to use it).

cjolivier01 (Member) commented Nov 21, 2017:

Is this ready to go in? (assuming CI passes)

ptrendx (Member Author) commented Nov 21, 2017:

@cjolivier01 As far as I am concerned, yes, it is done. @piiswrong?

piiswrong (Contributor):

Yes, we decided to deal with the aggregation policy later

@cjolivier01 cjolivier01 merged commit cace29f into apache:master Nov 21, 2017
eric-haibin-lin pushed a commit to eric-haibin-lin/mxnet that referenced this pull request Dec 3, 2017
* NCCL integration

* Skipping NCCL test (since it requires NCCL library to be present and
    enabled in build)

* Add Apache header to test_nccl.py

* Fixes from review

* Trigger CI

* Removing API change for Pull

* Fixes

* Fix

* Fix

* Fix

* Fix

* Fix

* Indentation fixes and importing unittest in test_nccl.py

* sorted_key_attrs -> key_attrs

* More fixes from review

* Fix

* Fix lint

* Support for aggregation in NCCL

* Fix typo

* Fix missing logic

* Move from CommNCCL to KVStoreNCCL

* Fix

* Moved nccl update to separate function

* Add message about not supporting gradient compression

* Fix lint

* Trigger CI
zhreshold pushed a commit to zhreshold/mxnet that referenced this pull request Dec 14, 2017
rahul003 pushed a commit to rahul003/mxnet that referenced this pull request Jun 4, 2018