NCCL integration (#8294)
* NCCL integration

* Skipping NCCL test (since it requires the NCCL library to be present and enabled in the build)

* Add Apache header to test_nccl.py

* Fixes from review

* Trigger CI

* Removing API change for Pull

* Fixes

* Fix

* Fix

* Fix

* Fix

* Fix

* Indentation fixes and importing unittest in test_nccl.py

* sorted_key_attrs -> key_attrs

* More fixes from review

* Fix

* Fix lint

* Support for aggregation in NCCL

* Fix typo

* Fix missing logic

* Move from CommNCCL to KVStoreNCCL

* Fix

* Moved nccl update to separate function

* Add message about not supporting gradient compression

* Fix lint

* Trigger CI
ptrendx authored and cjolivier01 committed Nov 21, 2017
1 parent 068b589 commit cace29f
Showing 14 changed files with 707 additions and 25 deletions.
11 changes: 11 additions & 0 deletions Makefile
@@ -331,8 +331,19 @@ ifeq ($(USE_CUDA), 1)
ALL_DEP += $(CUOBJ) $(EXTRA_CUOBJ) $(PLUGIN_CUOBJ)
LDFLAGS += -lcuda -lcufft -lnvrtc
SCALA_PKG_PROFILE := $(SCALA_PKG_PROFILE)-gpu
ifeq ($(USE_NCCL), 1)
ifneq ($(USE_NCCL_PATH), NONE)
CFLAGS += -I$(USE_NCCL_PATH)/include
LDFLAGS += -L$(USE_NCCL_PATH)/lib
endif
LDFLAGS += -lnccl
CFLAGS += -DMXNET_USE_NCCL=1
else
CFLAGS += -DMXNET_USE_NCCL=0
endif
else
SCALA_PKG_PROFILE := $(SCALA_PKG_PROFILE)-cpu
CFLAGS += -DMXNET_USE_NCCL=0
endif

ifeq ($(USE_LIBJPEG_TURBO), 1)
2 changes: 1 addition & 1 deletion include/mxnet/ndarray.h
@@ -936,7 +936,7 @@ size_t num_aux_data(NDArrayStorageType stype);
* \note The function name explicitly marks the order of from and to
* due to different possible convention carried by copy function.
*/
void CopyFromTo(const NDArray &from, NDArray *to, int priority = 0);
void CopyFromTo(const NDArray &from, const NDArray *to, int priority = 0);

/*!
* \brief issue an copy operation from one NDArray to another
14 changes: 14 additions & 0 deletions include/mxnet/storage.h
@@ -99,6 +99,16 @@ class Storage {
* \brief Destructor.
*/
virtual ~Storage() {}
/*!
* \brief Returns mutex used by storage manager
*/
std::mutex& GetMutex(Context::DeviceType dev) {
if (dev == Context::kCPU) {
return cpu_mutex_;
} else {
return gpu_mutex_;
}
}
/*!
* \return Storage singleton.
*/
@@ -112,6 +122,10 @@
* \return A shared pointer to Storage singleton.
*/
static std::shared_ptr<Storage> _GetSharedRef();

private:
std::mutex cpu_mutex_;
std::mutex gpu_mutex_;
}; // class Storage
} // namespace mxnet
#endif // MXNET_STORAGE_H_
5 changes: 5 additions & 0 deletions make/config.mk
@@ -57,6 +57,11 @@ USE_CUDA_PATH = NONE
# whether use CuDNN R3 library
USE_CUDNN = 0

# whether to use NCCL library
USE_NCCL = 0
# add the path to NCCL library
USE_NCCL_PATH = NONE

# whether use opencv during compilation
# you can disable it, however, you will not able to use
# imbin iterator
2 changes: 1 addition & 1 deletion python/mxnet/kvstore.py
@@ -622,7 +622,7 @@ def create(name='local'):
Parameters
----------
name : {'local', 'device', 'dist_sync', 'dist_device_sync', 'dist_async'}
name : {'local', 'device', 'nccl', 'dist_sync', 'dist_device_sync', 'dist_async'}
The type of KVStore.
Returns
-------
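
For orientation, here is a minimal usage sketch of the new 'nccl' kvstore type documented above. It is not part of this diff and assumes an MXNet build with USE_NCCL=1 and a machine with at least two GPUs; the key name 'weight' is arbitrary.

import mxnet as mx

# Create the NCCL-based kvstore introduced by this commit.
kv = mx.kv.create('nccl')

shape = (2, 3)
gpus = [mx.gpu(0), mx.gpu(1)]

# One copy of the gradient per GPU.
kv.init('weight', mx.nd.zeros(shape, ctx=gpus[0]))
grads = [mx.nd.ones(shape, ctx=g) for g in gpus]

# push reduces the per-GPU gradients with NCCL ...
kv.push('weight', grads)

# ... and pull broadcasts the reduced result back to every GPU.
outs = [mx.nd.zeros(shape, ctx=g) for g in gpus]
kv.pull('weight', out=outs)
print(outs[0].asnumpy())
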
32 changes: 29 additions & 3 deletions python/mxnet/model.py
@@ -20,6 +20,7 @@
"""MXNet model module"""
from __future__ import absolute_import, print_function

import os
import time
import logging
import warnings
@@ -102,6 +103,26 @@ def _initialize_kvstore(kvstore, param_arrays, arg_params, param_names, update_on_kvstore):
if update_on_kvstore:
kvstore.pull(name, param_on_devs, priority=-idx)

def _update_params_on_kvstore_nccl(param_arrays, grad_arrays, kvstore, param_names):
"""Perform update of param_arrays from grad_arrays on NCCL kvstore."""
valid_indices = [index for index, grad_list in
enumerate(grad_arrays) if grad_list[0] is not None]
valid_grad_arrays = [grad_arrays[i] for i in valid_indices]
valid_param_arrays = [param_arrays[i] for i in valid_indices]
valid_param_names = [param_names[i] for i in valid_indices]
size = len(valid_grad_arrays)
start = 0
# Use aggregation by default only with NCCL
default_batch = 16
batch = int(os.getenv('MXNET_UPDATE_AGGREGATION_SIZE', default_batch))
while start < size:
end = start + batch if start + batch < size else size
# push gradient, priority is negative index
kvstore.push(valid_param_names[start:end], valid_grad_arrays[start:end], priority=-start)
# pull back the weights
kvstore.pull(valid_param_names[start:end], valid_param_arrays[start:end], priority=-start)
start = end

def _update_params_on_kvstore(param_arrays, grad_arrays, kvstore, param_names):
"""Perform update of param_arrays from grad_arrays on kvstore."""
for index, pair in enumerate(zip(param_arrays, grad_arrays)):
@@ -263,9 +284,14 @@ def _train_multi_device(symbol, ctx, arg_names, param_names, aux_names,
executor_manager.backward()

if update_on_kvstore:
_update_params_on_kvstore(executor_manager.param_arrays,
executor_manager.grad_arrays,
kvstore, executor_manager.param_names)
if 'nccl' in kvstore.type:
_update_params_on_kvstore_nccl(executor_manager.param_arrays,
executor_manager.grad_arrays,
kvstore, executor_manager.param_names)
else:
_update_params_on_kvstore(executor_manager.param_arrays,
executor_manager.grad_arrays,
kvstore, executor_manager.param_names)
else:
_update_params(executor_manager.param_arrays,
executor_manager.grad_arrays,
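
The _update_params_on_kvstore_nccl helper above batches push/pull calls so that NCCL can aggregate several small gradients per operation; the batch size is read from MXNET_UPDATE_AGGREGATION_SIZE (default 16). Below is a rough end-to-end sketch of driving this code path. The toy symbol, the iterator, and the assumption that Module training with kvstore='nccl' reaches this helper are illustrative, not taken from this diff.

import os
import numpy as np
import mxnet as mx

# Optional: tune how many gradients are aggregated per NCCL push/pull batch.
os.environ['MXNET_UPDATE_AGGREGATION_SIZE'] = '8'

# Toy network and data.
data = mx.sym.Variable('data')
net = mx.sym.FullyConnected(data, num_hidden=10, name='fc')
net = mx.sym.SoftmaxOutput(net, name='softmax')

train_iter = mx.io.NDArrayIter(
    data=np.random.uniform(size=(128, 20)).astype(np.float32),
    label=np.arange(128) % 10,
    batch_size=32)

mod = mx.mod.Module(net, context=[mx.gpu(0), mx.gpu(1)])
# kvstore='nccl' selects the NCCL update path added in this commit.
mod.fit(train_iter, num_epoch=1, kvstore='nccl',
        optimizer='sgd', optimizer_params={'learning_rate': 0.1})
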
14 changes: 13 additions & 1 deletion src/kvstore/kvstore.cc
@@ -29,6 +29,9 @@
#if MXNET_USE_DIST_KVSTORE
#include "./kvstore_dist.h"
#endif // MXNET_USE_DIST_KVSTORE
#if MXNET_USE_NCCL
#include "./kvstore_nccl.h"
#endif // MXNET_USE_NCCL

namespace mxnet {

@@ -56,7 +59,16 @@ KVStore* KVStore::Create(const char *type_name) {
return nullptr;
#endif // MXNET_USE_DIST_KVSTORE
} else {
kv = new kvstore::KVStoreLocal(use_device_comm);
if (has("nccl")) {
#if MXNET_USE_NCCL
kv = new kvstore::KVStoreNCCL();
#else
LOG(FATAL) << "compile with USE_NCCL=1 to use " << tname;
return nullptr;
#endif
} else {
kv = new kvstore::KVStoreLocal(use_device_comm);
}
}
kv->type_ = tname;
return kv;
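
Since KVStore::Create now fails with LOG(FATAL) when 'nccl' is requested from a build compiled with MXNET_USE_NCCL=0, a small runtime probe can tell whether the running binary supports it. This is a sketch, not part of the commit; it relies on the C API surfacing that failure as MXNetError.

import mxnet as mx

def nccl_kvstore_available():
    """Return True if this MXNet build can create the 'nccl' kvstore."""
    try:
        mx.kv.create('nccl')
        return True
    except mx.base.MXNetError:
        # Raised when the library was built with USE_NCCL=0 (see LOG(FATAL) above).
        return False

print('NCCL kvstore available:', nccl_kvstore_available())
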
30 changes: 17 additions & 13 deletions src/kvstore/kvstore_local.h
@@ -64,6 +64,7 @@ class KVStoreLocal : public KVStore {

virtual ~KVStoreLocal() {
delete comm_;
comm_ = nullptr;
}

void Init(const std::vector<int>& keys,
@@ -234,6 +235,7 @@
}

protected:
KVStoreLocal() : KVStore() {}
/**
* \brief set the key type of the kvstore if haven't already.
* If the key type is already defined, check if it matches the provided key type
@@ -246,10 +248,10 @@
/**
* \brief group values on keys for push
*/
void GroupKVPairsPush(const std::vector<int>& keys,
const std::vector<NDArray>& values,
std::vector<int> *uniq_keys,
std::vector<std::vector<NDArray>> *grouped_vals) {
virtual void GroupKVPairsPush(const std::vector<int>& keys,
const std::vector<NDArray>& values,
std::vector<int> *uniq_keys,
std::vector<std::vector<NDArray>> *grouped_vals) {
// check if the storage type of a value is valid
auto validator = [this](const int key, const NDArray& nd) -> bool {
auto stype = nd.storage_type();
@@ -264,10 +266,10 @@
/**
* \brief group values on keys for pull
*/
void GroupKVPairsPull(const std::vector<int>& keys,
const std::vector<NDArray*>& values,
std::vector<int> *uniq_keys,
std::vector<std::vector<NDArray*>> *grouped_vals) {
virtual void GroupKVPairsPull(const std::vector<int>& keys,
const std::vector<NDArray*>& values,
std::vector<int> *uniq_keys,
std::vector<std::vector<NDArray*>> *grouped_vals) {
// check if the storage type of a value is valid
auto validator = [this](const int key, const NDArray* nd) -> bool {
// valid
@@ -283,15 +285,17 @@
};
GroupKVPairs(keys, values, uniq_keys, grouped_vals, validator);
}

typedef std::pair<NDArray*, NDArray> RSPVal;
/**
* \brief group values on keys for row_sparse_pull
*/
void GroupKVPairsPullRsp(const std::vector<int>& keys,
const std::vector<std::pair<NDArray*, NDArray>>& values,
std::vector<int> *uniq_keys,
std::vector<std::vector<std::pair<NDArray*, NDArray>>> *grouped_vals) {
virtual void GroupKVPairsPullRsp(const std::vector<int>& keys,
const std::vector<RSPVal>& values,
std::vector<int> *uniq_keys,
std::vector<std::vector<RSPVal>> *grouped_vals) {
// check if the storage type of a value is valid
auto validator = [this](const int key, const std::pair<NDArray*, NDArray>& val_rowid) -> bool {
auto validator = [this](const int key, const RSPVal& val_rowid) -> bool {
auto val_stype = val_rowid.first->storage_type();
auto rowid_stype = val_rowid.second.storage_type();
// check storage types
