Fix specifying gpu_id, add tests. #3851

Merged: 4 commits, Nov 6, 2018
doc/gpu/index.rst (2 changes: 1 addition & 1 deletion)

@@ -63,7 +63,7 @@ GPU accelerated prediction is enabled by default for the above mentioned ``tree_

The device ordinal can be selected using the ``gpu_id`` parameter, which defaults to 0.

- Multiple GPUs can be used with the ``gpu_hist`` tree method using the ``n_gpus`` parameter. which defaults to 1. If this is set to -1 all available GPUs will be used. If ``gpu_id`` is specified as non-zero, the gpu device order is ``mod(gpu_id + i) % n_visible_devices`` for ``i=0`` to ``n_gpus-1``. As with GPU vs. CPU, multi-GPU will not always be faster than a single GPU due to PCI bus bandwidth that can limit performance.
+ Multiple GPUs can be used with the ``gpu_hist`` tree method using the ``n_gpus`` parameter, which defaults to 1. If this is set to -1, all available GPUs will be used. If ``gpu_id`` is specified as non-zero, the selected GPU devices are ``gpu_id`` through ``gpu_id+n_gpus-1``; note that ``gpu_id+n_gpus`` must be less than or equal to the number of GPUs available on your system. As with GPU vs. CPU, multi-GPU will not always be faster than a single GPU due to PCI bus bandwidth that can limit performance.

.. note:: Enabling multi-GPU training

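The selection rule in the new documentation line is plain interval arithmetic. Here is a minimal standalone sketch of it (an illustration, not XGBoost code; `n_visible` stands in for the device count reported by the CUDA runtime):

```cpp
#include <cassert>
#include <vector>

// Devices selected for a given gpu_id/n_gpus per the rule documented above:
// the half-open range [gpu_id, gpu_id + n_gpus).
std::vector<int> SelectedDevices(int gpu_id, int n_gpus, int n_visible) {
  if (n_gpus == -1) { n_gpus = n_visible - gpu_id; }  // -1 means "all remaining"
  assert(gpu_id + n_gpus <= n_visible);               // the documented constraint
  std::vector<int> devices;
  for (int d = gpu_id; d < gpu_id + n_gpus; ++d) { devices.push_back(d); }
  return devices;
}

int main() {
  auto devs = SelectedDevices(/*gpu_id=*/1, /*n_gpus=*/2, /*n_visible=*/4);
  assert(devs.size() == 2 && devs[0] == 1 && devs[1] == 2);  // GPUs 1 and 2
  return 0;
}
```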
src/common/common.h (105 changes: 65 additions & 40 deletions)

@@ -147,61 +147,86 @@ struct AllVisibleImpl {
 */
class GPUSet {
 public:
+  using GpuIdType = int;
+  static constexpr GpuIdType kAll = -1;

  explicit GPUSet(int start = 0, int ndevices = 0)
      : devices_(start, start + ndevices) {}

  static GPUSet Empty() { return GPUSet(); }

-  static GPUSet Range(int start, int ndevices) {
-    return ndevices <= 0 ? Empty() : GPUSet{start, ndevices};
-  }
-  /*! \brief ndevices and num_rows both are upper bounds. */
-  static GPUSet All(int ndevices, int num_rows = std::numeric_limits<int>::max()) {
-    int n_devices_visible = AllVisible().Size();
-    if (ndevices < 0 || ndevices > n_devices_visible) {
-      ndevices = n_devices_visible;
-    }
-    // fix-up device number to be limited by number of rows
-    ndevices = ndevices > num_rows ? num_rows : ndevices;
-    return Range(0, ndevices);
-  }
+  static GPUSet Range(GpuIdType start, GpuIdType n_gpus) {
+    return n_gpus <= 0 ? Empty() : GPUSet{start, n_gpus};
+  }
+  /*! \brief n_gpus and num_rows both are upper bounds. */
+  static GPUSet All(GpuIdType gpu_id, GpuIdType n_gpus,
+                    GpuIdType num_rows = std::numeric_limits<GpuIdType>::max()) {
+    CHECK_GE(gpu_id, 0) << "gpu_id must be >= 0.";
+    CHECK_GE(n_gpus, -1) << "n_gpus must be >= -1.";
+
+    GpuIdType const n_devices_visible = AllVisible().Size();
+    if (n_devices_visible == 0) { return Empty(); }
+
+    GpuIdType const n_available_devices = n_devices_visible - gpu_id;
+
+    if (n_gpus == kAll) {  // Use all devices starting from `gpu_id'.
+      CHECK(gpu_id < n_devices_visible)
+          << "\ngpu_id should be less than number of visible devices.\ngpu_id: "
+          << gpu_id
+          << ", number of visible devices: "
+          << n_devices_visible;
+      GpuIdType n_devices =
+          n_available_devices < num_rows ? n_available_devices : num_rows;
+      return Range(gpu_id, n_devices);
+    } else {  // Use devices in [gpu_id, gpu_id + n_gpus).
+      CHECK_LE(n_gpus, n_available_devices)
+          << "Starting from gpu id: " << gpu_id << ", there are only "
+          << n_available_devices << " available devices, while n_gpus is set to: "
+          << n_gpus;
+      GpuIdType n_devices = n_gpus < num_rows ? n_gpus : num_rows;
+      return Range(gpu_id, n_devices);
+    }
+  }

  static GPUSet AllVisible() {
-    int n = AllVisibleImpl::AllVisible();
+    GpuIdType n = AllVisibleImpl::AllVisible();
    return Range(0, n);
  }
-  /*! \brief Ensure gpu_id is correct, so not dependent upon user knowing details */
-  static int GetDeviceIdx(int gpu_id) {
-    auto devices = AllVisible();
-    CHECK(!devices.IsEmpty()) << "Empty device.";
-    return (std::abs(gpu_id) + 0) % devices.Size();
-  }
-  /*! \brief Counting from gpu_id */
-  GPUSet Normalised(int gpu_id) const {
-    return Range(gpu_id, Size());
-  }
-  /*! \brief Counting from 0 */
-  GPUSet Unnormalised() const {
-    return Range(0, Size());
-  }

-  int Size() const {
-    int res = *devices_.end() - *devices_.begin();
-    return res < 0 ? 0 : res;
-  }
-  /*! \brief Get normalised device id. */
-  int operator[](int index) const {
-    CHECK(index >= 0 && index < Size());
-    return *devices_.begin() + index;
-  }
+  size_t Size() const {
+    GpuIdType size = *devices_.end() - *devices_.begin();
+    GpuIdType res = size < 0 ? 0 : size;
+    return static_cast<size_t>(res);
+  }

+  /*
+   * By default, we have two ways of identifying a device: one is the device
+   * id obtained from `cudaGetDevice'. But we sometimes store objects
+   * allocated one per device in a list, which requires a zero-based index.
+   *
+   * Hence, `DeviceId' converts a zero-based index to an actual device id,
+   * and `Index' converts a device id to a zero-based index.
+   */
+  GpuIdType DeviceId(size_t index) const {
+    GpuIdType result = *devices_.begin() + static_cast<GpuIdType>(index);
+    CHECK(Contains(result)) << "\nDevice " << result << " is not in GPUSet."
+                            << "\nIndex: " << index
+                            << "\nGPUSet: (" << *begin() << ", " << *end() << ")"
+                            << std::endl;
+    return result;
+  }
+  size_t Index(GpuIdType device) const {
+    CHECK(Contains(device)) << "\nDevice " << device << " is not in GPUSet."
+                            << "\nGPUSet: (" << *begin() << ", " << *end() << ")"
+                            << std::endl;
+    size_t result = static_cast<size_t>(device - *devices_.begin());
+    return result;
+  }

  bool IsEmpty() const { return Size() == 0; }
-  /*! \brief Get un-normalised index. */
-  int Index(int device) const {
-    CHECK(Contains(device));
-    return device - *devices_.begin();
-  }

-  bool Contains(int device) const {
+  bool Contains(GpuIdType device) const {
    return *devices_.begin() <= device && device < *devices_.end();
  }

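Taken together, the new `All` turns `gpu_id`/`n_gpus` into a concrete device range and fails loudly on invalid combinations, instead of wrapping `gpu_id` around with a modulus as the removed `GetDeviceIdx` did. A standalone model of the resolution logic (a sketch only: `ResolveAll` and `Range` are illustrative stand-ins, and the `assert`s play the role of the `CHECK`s above):

```cpp
#include <algorithm>
#include <cassert>

struct Range { int start, n; };  // models GPUSet's half-open device range

Range ResolveAll(int gpu_id, int n_gpus, int n_visible,
                 int num_rows = 1 << 30) {
  assert(gpu_id >= 0 && n_gpus >= -1);    // the CHECK_GE pair above
  if (n_visible == 0) { return {0, 0}; }  // Empty()
  int n_available = n_visible - gpu_id;
  if (n_gpus == -1) {                     // kAll: everything from gpu_id on
    assert(gpu_id < n_visible);
    return {gpu_id, std::min(n_available, num_rows)};
  }
  assert(n_gpus <= n_available);          // the CHECK_LE above
  return {gpu_id, std::min(n_gpus, num_rows)};
}

int main() {
  // Assuming 4 visible devices:
  Range r1 = ResolveAll(/*gpu_id=*/1, /*n_gpus=*/-1, /*n_visible=*/4);
  assert(r1.start == 1 && r1.n == 3);     // devices 1, 2, 3
  Range r2 = ResolveAll(1, 2, 4);
  assert(r2.start == 1 && r2.n == 2);     // devices 1, 2
  Range r3 = ResolveAll(0, 4, 4, /*num_rows=*/2);
  assert(r3.n == 2);                      // capped by the row count
  return 0;
}
```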
src/common/device_helpers.cuh (22 changes: 16 additions & 6 deletions)

@@ -53,6 +53,16 @@ T *Raw(thrust::device_vector<T> &v) {  // NOLINT
return raw_pointer_cast(v.data());
}

+inline void CudaCheckPointerDevice(void* ptr) {
+  cudaPointerAttributes attr;
+  dh::safe_cuda(cudaPointerGetAttributes(&attr, ptr));
+  int ptr_device = attr.device;
+  int cur_device = -1;
+  cudaGetDevice(&cur_device);
+  CHECK_EQ(ptr_device, cur_device) << "pointer device: " << ptr_device
+                                   << ", current device: " << cur_device;
+}

template <typename T>
const T *Raw(const thrust::device_vector<T> &v) { // NOLINT
return raw_pointer_cast(v.data());
@@ -61,7 +71,7 @@ const T *Raw(const thrust::device_vector<T> &v) {  // NOLINT
// if n_devices=-1, then use all visible devices
inline void SynchronizeNDevices(xgboost::GPUSet devices) {
devices = devices.IsEmpty() ? xgboost::GPUSet::AllVisible() : devices;
-  for (auto const d : devices.Unnormalised()) {
+  for (auto const d : devices) {
safe_cuda(cudaSetDevice(d));
safe_cuda(cudaDeviceSynchronize());
}
@@ -743,7 +753,8 @@ void SumReduction(dh::CubMemory &tmp_mem, dh::DVec<T> &in, dh::DVec<T> &out,
* @param nVals number of elements in the input array
*/
template <typename T>
-typename std::iterator_traits<T>::value_type SumReduction(dh::CubMemory &tmp_mem, T in, int nVals) {
+typename std::iterator_traits<T>::value_type SumReduction(
+    dh::CubMemory &tmp_mem, T in, int nVals) {
using ValueT = typename std::iterator_traits<T>::value_type;
size_t tmpSize;
dh::safe_cuda(cub::DeviceReduce::Sum(nullptr, tmpSize, in, in, nVals));
@@ -900,11 +911,10 @@ class AllReducer {
double *recvbuff, int count) {
#ifdef XGBOOST_USE_NCCL
CHECK(initialised);

-    dh::safe_cuda(cudaSetDevice(device_ordinals[communication_group_idx]));
+    dh::safe_cuda(cudaSetDevice(device_ordinals.at(communication_group_idx)));
    dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclDouble, ncclSum,
-                                comms[communication_group_idx],
-                                streams[communication_group_idx]));
+                                comms.at(communication_group_idx),
+                                streams.at(communication_group_idx)));
#endif
}

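Besides the new `CudaCheckPointerDevice` guard, the changes in this file swap `operator[]` for bounds-checked `.at()` on the NCCL bookkeeping vectors. A minimal standard-C++ illustration of what that buys (nothing CUDA-specific here):

```cpp
#include <iostream>
#include <stdexcept>
#include <vector>

int main() {
  std::vector<int> device_ordinals = {0, 1};
  // operator[] performs no bounds check: device_ordinals[2] is undefined
  // behaviour and may silently hand back garbage. .at() throws instead,
  // so a bad communication_group_idx fails loudly.
  try {
    std::cout << device_ordinals.at(2) << "\n";
  } catch (const std::out_of_range& e) {
    std::cout << "caught: " << e.what() << "\n";
  }
  return 0;
}
```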
src/common/hist_util.cu (8 changes: 4 additions & 4 deletions)

@@ -352,8 +352,9 @@ struct GPUSketcher {
dh::ExecuteIndexShards(&shards_, [&](int i, std::unique_ptr<DeviceShard>& shard) {
size_t start = dist_.ShardStart(info.num_row_, i);
size_t size = dist_.ShardSize(info.num_row_, i);
-      shard = std::unique_ptr<DeviceShard>
-          (new DeviceShard(dist_.Devices()[i], start, start + size, param_));
+      shard = std::unique_ptr<DeviceShard>(
+          new DeviceShard(dist_.Devices().DeviceId(i),
+                          start, start + size, param_));
});

// compute sketches for each shard
@@ -379,8 +380,7 @@
}

GPUSketcher(tree::TrainParam param, size_t n_rows) : param_(std::move(param)) {
-    dist_ = GPUDistribution::Block(GPUSet::All(param_.n_gpus, n_rows).
-                                   Normalised(param_.gpu_id));
+    dist_ = GPUDistribution::Block(GPUSet::All(param_.gpu_id, param_.n_gpus, n_rows));
}

std::vector<std::unique_ptr<DeviceShard>> shards_;
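Each `DeviceShard` above covers the row range `[start, start + size)`. As a sketch of the partitioning contract (hypothetical helpers, assuming an even block split; the real arithmetic lives in `GPUDistribution::ShardStart`/`ShardSize` and may differ in granularity handling):

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical block partition: shard i of n_shards gets a contiguous row
// range, with the remainder spread over the leading shards.
std::size_t ShardStart(std::size_t num_rows, int i, int n_shards) {
  std::size_t base = num_rows / n_shards;
  std::size_t rem = num_rows % n_shards;
  std::size_t extra = static_cast<std::size_t>(i) < rem ? i : rem;
  return base * i + extra;
}

std::size_t ShardSize(std::size_t num_rows, int i, int n_shards) {
  return ShardStart(num_rows, i + 1, n_shards) - ShardStart(num_rows, i, n_shards);
}

int main() {
  // 10 rows over 3 shards: sizes 4, 3, 3, contiguous and non-overlapping.
  std::size_t total = 0;
  for (int i = 0; i < 3; ++i) {
    assert(ShardStart(10, i, 3) == total);
    total += ShardSize(10, i, 3);
  }
  assert(total == 10);  // every row is owned by exactly one shard
  return 0;
}
```

The property checked in `main`, contiguous non-overlapping shards covering every row, is what the sketcher relies on when it builds one sketch per device.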
src/common/host_device_vector.cu (40 changes: 19 additions & 21 deletions)

@@ -46,14 +46,13 @@ template <typename T>
struct HostDeviceVectorImpl {
struct DeviceShard {
    DeviceShard()
-        : index_(-1), proper_size_(0), device_(-1), start_(0), perm_d_(false),
+        : proper_size_(0), device_(-1), start_(0), perm_d_(false),
          cached_size_(~0), vec_(nullptr) {}

void Init(HostDeviceVectorImpl<T>* vec, int device) {
if (vec_ == nullptr) { vec_ = vec; }
CHECK_EQ(vec, vec_);
device_ = device;
-      index_ = vec_->distribution_.devices_.Index(device);
LazyResize(vec_->Size());
perm_d_ = vec_->perm_h_.Complementary();
}
@@ -62,7 +61,6 @@
if (vec_ == nullptr) { vec_ = vec; }
CHECK_EQ(vec, vec_);
device_ = other.device_;
-      index_ = other.index_;
cached_size_ = other.cached_size_;
start_ = other.start_;
proper_size_ = other.proper_size_;
@@ -114,10 +112,11 @@
if (new_size == cached_size_) { return; }
// resize is required
int ndevices = vec_->distribution_.devices_.Size();
-      start_ = vec_->distribution_.ShardStart(new_size, index_);
-      proper_size_ = vec_->distribution_.ShardProperSize(new_size, index_);
+      int device_index = vec_->distribution_.devices_.Index(device_);
+      start_ = vec_->distribution_.ShardStart(new_size, device_index);
+      proper_size_ = vec_->distribution_.ShardProperSize(new_size, device_index);
      // The size on this device.
-      size_t size_d = vec_->distribution_.ShardSize(new_size, index_);
+      size_t size_d = vec_->distribution_.ShardSize(new_size, device_index);
SetDevice();
data_.resize(size_d);
cached_size_ = new_size;
@@ -154,7 +153,6 @@
}
}

-    int index_;
int device_;
thrust::device_vector<T> data_;
// cached vector size
@@ -183,13 +181,13 @@
distribution_(other.distribution_), mutex_() {
shards_.resize(other.shards_.size());
dh::ExecuteIndexShards(&shards_, [&](int i, DeviceShard& shard) {
-      shard.Init(this, other.shards_[i]);
+      shard.Init(this, other.shards_.at(i));
});
}

-  // Init can be std::vector<T> or std::initializer_list<T>
-  template <class Init>
-  HostDeviceVectorImpl(const Init& init, GPUDistribution distribution)
+  // Initializer can be std::vector<T> or std::initializer_list<T>
+  template <class Initializer>
+  HostDeviceVectorImpl(const Initializer& init, GPUDistribution distribution)
: distribution_(distribution), perm_h_(distribution.IsEmpty()), size_d_(0) {
if (!distribution_.IsEmpty()) {
size_d_ = init.size();
Expand All @@ -204,7 +202,7 @@ struct HostDeviceVectorImpl {
int ndevices = distribution_.devices_.Size();
shards_.resize(ndevices);
dh::ExecuteIndexShards(&shards_, [&](int i, DeviceShard& shard) {
-      shard.Init(this, distribution_.devices_[i]);
+      shard.Init(this, distribution_.devices_.DeviceId(i));
});
}

@@ -217,41 +215,41 @@
T* DevicePointer(int device) {
CHECK(distribution_.devices_.Contains(device));
LazySyncDevice(device, GPUAccess::kWrite);
-    return shards_[distribution_.devices_.Index(device)].data_.data().get();
+    return shards_.at(distribution_.devices_.Index(device)).data_.data().get();
}

const T* ConstDevicePointer(int device) {
CHECK(distribution_.devices_.Contains(device));
LazySyncDevice(device, GPUAccess::kRead);
-    return shards_[distribution_.devices_.Index(device)].data_.data().get();
+    return shards_.at(distribution_.devices_.Index(device)).data_.data().get();
}

common::Span<T> DeviceSpan(int device) {
GPUSet devices = distribution_.devices_;
CHECK(devices.Contains(device));
LazySyncDevice(device, GPUAccess::kWrite);
-    return {shards_[devices.Index(device)].data_.data().get(),
+    return {shards_.at(devices.Index(device)).data_.data().get(),
static_cast<typename common::Span<T>::index_type>(DeviceSize(device))};
}

common::Span<const T> ConstDeviceSpan(int device) {
GPUSet devices = distribution_.devices_;
CHECK(devices.Contains(device));
LazySyncDevice(device, GPUAccess::kRead);
-    return {shards_[devices.Index(device)].data_.data().get(),
+    return {shards_.at(devices.Index(device)).data_.data().get(),
static_cast<typename common::Span<const T>::index_type>(DeviceSize(device))};
}

size_t DeviceSize(int device) {
CHECK(distribution_.devices_.Contains(device));
LazySyncDevice(device, GPUAccess::kRead);
-    return shards_[distribution_.devices_.Index(device)].data_.size();
+    return shards_.at(distribution_.devices_.Index(device)).data_.size();
}

size_t DeviceStart(int device) {
CHECK(distribution_.devices_.Contains(device));
LazySyncDevice(device, GPUAccess::kRead);
-    return shards_[distribution_.devices_.Index(device)].start_;
+    return shards_.at(distribution_.devices_.Index(device)).start_;
}

thrust::device_ptr<T> tbegin(int device) { // NOLINT
@@ -316,7 +314,7 @@
size_d_ = other->size_d_;
}
dh::ExecuteIndexShards(&shards_, [&](int i, DeviceShard& shard) {
-      shard.Copy(&other->shards_[i]);
+      shard.Copy(&other->shards_.at(i));
});
}

@@ -405,15 +403,15 @@
void LazySyncDevice(int device, GPUAccess access) {
GPUSet devices = distribution_.Devices();
CHECK(devices.Contains(device));
-    shards_[devices.Index(device)].LazySyncDevice(access);
+    shards_.at(devices.Index(device)).LazySyncDevice(access);
}

bool HostCanAccess(GPUAccess access) { return perm_h_.CanAccess(access); }

bool DeviceCanAccess(int device, GPUAccess access) {
GPUSet devices = distribution_.Devices();
if (!devices.Contains(device)) { return false; }
-    return shards_[devices.Index(device)].perm_d_.CanAccess(access);
+    return shards_.at(devices.Index(device)).perm_d_.CanAccess(access);
}

std::vector<T> data_h_;
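The thread running through this file is the removal of the cached `index_` member: a shard now derives its list position from its device id on demand via `GPUSet::Index`. A tiny standalone model of why that is safe (illustrative types only, not the XGBoost classes):

```cpp
#include <cassert>

// Minimal stand-in for GPUSet's contiguous device range [begin, end).
struct DeviceRange {
  int begin, end;
  bool Contains(int device) const { return begin <= device && device < end; }
  int Index(int device) const {  // the list position is derivable from the id
    assert(Contains(device));
    return device - begin;
  }
};

int main() {
  DeviceRange devices{2, 5};  // devices 2, 3, 4
  // Deriving the index on demand, as the PR now does, leaves no cached
  // index_ to go stale when the distribution changes.
  assert(devices.Index(3) == 1);
  return 0;
}
```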
src/common/host_device_vector.h (9 changes: 5 additions & 4 deletions)

@@ -78,10 +78,11 @@ void SetCudaSetDeviceHandler(void (*handler)(int));

template <typename T> struct HostDeviceVectorImpl;

-// Distribution for the HostDeviceVector; it specifies such aspects as the devices it is
-// distributed on, whether there are copies of elements from other GPUs as well as the granularity
-// of splitting. It may also specify explicit boundaries for devices, in which case the size of the
-// array cannot be changed.
+// Distribution for the HostDeviceVector; it specifies such aspects as the
+// devices it is distributed on, whether there are copies of elements from
+// other GPUs as well as the granularity of splitting. It may also specify
+// explicit boundaries for devices, in which case the size of the array cannot
+// be changed.
class GPUDistribution {
template<typename T> friend struct HostDeviceVectorImpl;

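The `GPUDistribution` comment above names the one hard constraint: explicit per-device boundaries pin the vector's size. A toy model of that contract (hypothetical types, not the XGBoost API):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Toy model: a distribution with explicit per-device boundaries fixes the
// total size; without them, shard ranges are recomputed on resize.
struct ToyDistribution {
  std::vector<std::size_t> offsets;  // explicit boundaries, or empty
  bool FixedSize() const { return !offsets.empty(); }
  std::size_t Size() const { return offsets.empty() ? 0 : offsets.back(); }
};

bool CanResize(const ToyDistribution& dist, std::size_t new_size) {
  // Explicit boundaries cannot be rescaled, so only the pinned size is valid.
  return !dist.FixedSize() || new_size == dist.Size();
}

int main() {
  ToyDistribution block{};                      // no explicit boundaries
  ToyDistribution explicit_bounds{{0, 4, 10}};  // devices own [0,4) and [4,10)
  assert(CanResize(block, 42));
  assert(!CanResize(explicit_bounds, 42));
  assert(CanResize(explicit_bounds, 10));
  return 0;
}
```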