From bd29c4aa829bbdd4b521bedcc89a75993d444527 Mon Sep 17 00:00:00 2001
From: fis
Date: Thu, 1 Nov 2018 20:57:13 +1300
Subject: [PATCH] Rewrite gpu_id related code.

* Remove normalised/unnormalised operations.
* Address difference between `Index' and `Device ID'.
* Modify doc for `gpu_id'.
* Better LOG for GPUSet.
---
 doc/gpu/index.rst                        |  2 +-
 src/common/common.h                      | 85 ++++++++++++++----------
 src/common/device_helpers.cuh            | 22 ++++--
 src/common/hist_util.cu                  |  8 +--
 src/common/host_device_vector.cu         | 40 ++++++-----
 src/common/host_device_vector.h          |  9 +--
 src/common/transform.h                   | 24 ++++---
 src/linear/updater_gpu_coordinate.cu     | 31 +++++----
 src/objective/hinge.cu                   |  2 +-
 src/objective/multiclass_obj.cu          |  2 +-
 src/objective/regression_obj.cu          |  8 +--
 src/predictor/gpu_predictor.cu           | 14 ++--
 src/tree/updater_gpu.cu                  | 14 ++--
 src/tree/updater_gpu_hist.cu             | 85 ++++++++++++------------
 tests/cpp/common/test_common.cc          | 17 ++---
 tests/cpp/common/test_common.cu          | 30 ++++++---
 tests/cpp/common/test_transform_range.cu | 67 +++++++++++++++++++
 tests/cpp/tree/test_gpu_hist.cu          | 10 +--
 tests/python-gpu/test_gpu_updaters.py    | 34 +++++++---
 19 files changed, 312 insertions(+), 192 deletions(-)

diff --git a/doc/gpu/index.rst b/doc/gpu/index.rst
index 06cd2c08c0ae..44229338a1df 100644
--- a/doc/gpu/index.rst
+++ b/doc/gpu/index.rst
@@ -63,7 +63,7 @@ GPU accelerated prediction is enabled by default for the above mentioned ``tree_
 
 The device ordinal can be selected using the ``gpu_id`` parameter, which defaults to 0.
 
-Multiple GPUs can be used with the ``gpu_hist`` tree method using the ``n_gpus`` parameter. which defaults to 1. If this is set to -1 all available GPUs will be used. If ``gpu_id`` is specified as non-zero, the gpu device order is ``mod(gpu_id + i) % n_visible_devices`` for ``i=0`` to ``n_gpus-1``. As with GPU vs. CPU, multi-GPU will not always be faster than a single GPU due to PCI bus bandwidth that can limit performance.
+Multiple GPUs can be used with the ``gpu_hist`` tree method using the ``n_gpus`` parameter, which defaults to 1. If this is set to -1, all available GPUs will be used. If ``gpu_id`` is specified as non-zero, the selected GPU devices will be from ``gpu_id`` to ``gpu_id + n_gpus - 1``; please note that ``gpu_id + n_gpus`` must be less than or equal to the number of available GPUs on your system. As with GPU vs. CPU, multi-GPU will not always be faster than a single GPU due to PCI bus bandwidth that can limit performance.
 
 .. note:: Enabling multi-GPU training
 
diff --git a/src/common/common.h b/src/common/common.h
index f521d972d417..665870e1a8aa 100644
--- a/src/common/common.h
+++ b/src/common/common.h
@@ -147,61 +147,74 @@ struct AllVisibleImpl {
  */
 class GPUSet {
  public:
+  using GpuIndex = int;
+  static constexpr GpuIndex kAll = -1;
+
   explicit GPUSet(int start = 0, int ndevices = 0)
       : devices_(start, start + ndevices) {}
 
   static GPUSet Empty() { return GPUSet(); }
 
-  static GPUSet Range(int start, int ndevices) {
-    return ndevices <= 0 ? Empty() : GPUSet{start, ndevices};
+  static GPUSet Range(GpuIndex start, GpuIndex n_gpu) {
+    return n_gpu <= 0 ? Empty() : GPUSet{start, n_gpu};
   }
   /*! \brief ndevices and num_rows both are upper bounds.
*/ - static GPUSet All(int ndevices, int num_rows = std::numeric_limits::max()) { - int n_devices_visible = AllVisible().Size(); - if (ndevices < 0 || ndevices > n_devices_visible) { - ndevices = n_devices_visible; + static GPUSet All(GpuIndex gpu_id, GpuIndex n_gpu, + GpuIndex num_rows = std::numeric_limits::max()) { + GpuIndex n_devices_visible = AllVisible().Size(); + if (n_gpu == kAll) { // Use all devices starting from `gpu_id'. + CHECK(gpu_id < n_devices_visible || gpu_id == 0) + << "\ngpu_id should be less than available devices.\ngpu_id: " + << gpu_id + << ", number of available devices: " + << n_devices_visible << std::endl; + return Range(gpu_id, n_devices_visible - gpu_id); + } else { + GpuIndex n_available_devices = n_devices_visible - gpu_id; + GpuIndex n_devices = + n_available_devices < n_gpu ? n_devices_visible : n_gpu; + return Range(gpu_id, n_devices); } - // fix-up device number to be limited by number of rows - ndevices = ndevices > num_rows ? num_rows : ndevices; - return Range(0, ndevices); } + static GPUSet AllVisible() { - int n = AllVisibleImpl::AllVisible(); + GpuIndex n = AllVisibleImpl::AllVisible(); return Range(0, n); } - /*! \brief Ensure gpu_id is correct, so not dependent upon user knowing details */ - static int GetDeviceIdx(int gpu_id) { - auto devices = AllVisible(); - CHECK(!devices.IsEmpty()) << "Empty device."; - return (std::abs(gpu_id) + 0) % devices.Size(); - } - /*! \brief Counting from gpu_id */ - GPUSet Normalised(int gpu_id) const { - return Range(gpu_id, Size()); - } - /*! \brief Counting from 0 */ - GPUSet Unnormalised() const { - return Range(0, Size()); - } - int Size() const { - int res = *devices_.end() - *devices_.begin(); + GpuIndex Size() const { + GpuIndex res = *devices_.end() - *devices_.begin(); return res < 0 ? 0 : res; } - /*! \brief Get normalised device id. */ - int operator[](int index) const { - CHECK(index >= 0 && index < Size()); - return *devices_.begin() + index; + + /* + * By default, we have two configurations of identifying device, one + * is the device id obtained from `cudaGetDevice'. But we sometimes + * store objects that allocated one for each device in a list, which + * requires a zero-based index. + * + * Hence, `DeviceId' converts a zero-based index to actual device id, + * `Index' converts a device id to a zero-based index. + */ + GpuIndex DeviceId(GpuIndex index) const { + GpuIndex result = *devices_.begin() + index; + CHECK(Contains(result)) << "\nDevice " << result << " is not in GPUSet." + << "\nIndex: " << index + << "\nGPUSet: (" << *begin() << ", " << *end() << ")" + << std::endl; + return result; + } + GpuIndex Index(GpuIndex device) const { + CHECK(Contains(device)) << "\nDevice " << device << " is not in GPUSet." + << "\nGPUSet: (" << *begin() << ", " << *end() << ")" + << std::endl; + GpuIndex result = device - *devices_.begin(); + return result; } bool IsEmpty() const { return Size() == 0; } - /*! \brief Get un-normalised index. 
*/ - int Index(int device) const { - CHECK(Contains(device)); - return device - *devices_.begin(); - } - bool Contains(int device) const { + bool Contains(GpuIndex device) const { return *devices_.begin() <= device && device < *devices_.end(); } diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index 874fd311a98d..87aa050021df 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -53,6 +53,16 @@ T *Raw(thrust::device_vector &v) { // NOLINT return raw_pointer_cast(v.data()); } +inline void CudaCheckPointerDevice(void* ptr) { + cudaPointerAttributes attr; + dh::safe_cuda(cudaPointerGetAttributes(&attr, ptr)); + int ptr_device = attr.device; + int cur_device = -1; + cudaGetDevice(&cur_device); + CHECK_EQ(ptr_device, cur_device) << "pointer device: " << ptr_device + << "current device: " << cur_device; +} + template const T *Raw(const thrust::device_vector &v) { // NOLINT return raw_pointer_cast(v.data()); @@ -61,7 +71,7 @@ const T *Raw(const thrust::device_vector &v) { // NOLINT // if n_devices=-1, then use all visible devices inline void SynchronizeNDevices(xgboost::GPUSet devices) { devices = devices.IsEmpty() ? xgboost::GPUSet::AllVisible() : devices; - for (auto const d : devices.Unnormalised()) { + for (auto const d : devices) { safe_cuda(cudaSetDevice(d)); safe_cuda(cudaDeviceSynchronize()); } @@ -743,7 +753,8 @@ void SumReduction(dh::CubMemory &tmp_mem, dh::DVec &in, dh::DVec &out, * @param nVals number of elements in the input array */ template -typename std::iterator_traits::value_type SumReduction(dh::CubMemory &tmp_mem, T in, int nVals) { +typename std::iterator_traits::value_type SumReduction( + dh::CubMemory &tmp_mem, T in, int nVals) { using ValueT = typename std::iterator_traits::value_type; size_t tmpSize; dh::safe_cuda(cub::DeviceReduce::Sum(nullptr, tmpSize, in, in, nVals)); @@ -900,11 +911,10 @@ class AllReducer { double *recvbuff, int count) { #ifdef XGBOOST_USE_NCCL CHECK(initialised); - - dh::safe_cuda(cudaSetDevice(device_ordinals[communication_group_idx])); + dh::safe_cuda(cudaSetDevice(device_ordinals.at(communication_group_idx))); dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclDouble, ncclSum, - comms[communication_group_idx], - streams[communication_group_idx])); + comms.at(communication_group_idx), + streams.at(communication_group_idx))); #endif } diff --git a/src/common/hist_util.cu b/src/common/hist_util.cu index de7a21231d4b..7c9e7605955f 100644 --- a/src/common/hist_util.cu +++ b/src/common/hist_util.cu @@ -352,8 +352,9 @@ struct GPUSketcher { dh::ExecuteIndexShards(&shards_, [&](int i, std::unique_ptr& shard) { size_t start = dist_.ShardStart(info.num_row_, i); size_t size = dist_.ShardSize(info.num_row_, i); - shard = std::unique_ptr - (new DeviceShard(dist_.Devices()[i], start, start + size, param_)); + shard = std::unique_ptr( + new DeviceShard(dist_.Devices().DeviceId(i), + start, start + size, param_)); }); // compute sketches for each shard @@ -379,8 +380,7 @@ struct GPUSketcher { } GPUSketcher(tree::TrainParam param, size_t n_rows) : param_(std::move(param)) { - dist_ = GPUDistribution::Block(GPUSet::All(param_.n_gpus, n_rows). 
- Normalised(param_.gpu_id)); + dist_ = GPUDistribution::Block(GPUSet::All(param_.gpu_id, param_.n_gpus, n_rows)); } std::vector> shards_; diff --git a/src/common/host_device_vector.cu b/src/common/host_device_vector.cu index 5477394b7856..12811daa7048 100644 --- a/src/common/host_device_vector.cu +++ b/src/common/host_device_vector.cu @@ -46,14 +46,13 @@ template struct HostDeviceVectorImpl { struct DeviceShard { DeviceShard() - : index_(-1), proper_size_(0), device_(-1), start_(0), perm_d_(false), + : proper_size_(0), device_(-1), start_(0), perm_d_(false), cached_size_(~0), vec_(nullptr) {} void Init(HostDeviceVectorImpl* vec, int device) { if (vec_ == nullptr) { vec_ = vec; } CHECK_EQ(vec, vec_); device_ = device; - index_ = vec_->distribution_.devices_.Index(device); LazyResize(vec_->Size()); perm_d_ = vec_->perm_h_.Complementary(); } @@ -62,7 +61,6 @@ struct HostDeviceVectorImpl { if (vec_ == nullptr) { vec_ = vec; } CHECK_EQ(vec, vec_); device_ = other.device_; - index_ = other.index_; cached_size_ = other.cached_size_; start_ = other.start_; proper_size_ = other.proper_size_; @@ -114,10 +112,11 @@ struct HostDeviceVectorImpl { if (new_size == cached_size_) { return; } // resize is required int ndevices = vec_->distribution_.devices_.Size(); - start_ = vec_->distribution_.ShardStart(new_size, index_); - proper_size_ = vec_->distribution_.ShardProperSize(new_size, index_); + int device_index = vec_->distribution_.devices_.Index(device_); + start_ = vec_->distribution_.ShardStart(new_size, device_index); + proper_size_ = vec_->distribution_.ShardProperSize(new_size, device_index); // The size on this device. - size_t size_d = vec_->distribution_.ShardSize(new_size, index_); + size_t size_d = vec_->distribution_.ShardSize(new_size, device_index); SetDevice(); data_.resize(size_d); cached_size_ = new_size; @@ -154,7 +153,6 @@ struct HostDeviceVectorImpl { } } - int index_; int device_; thrust::device_vector data_; // cached vector size @@ -183,13 +181,13 @@ struct HostDeviceVectorImpl { distribution_(other.distribution_), mutex_() { shards_.resize(other.shards_.size()); dh::ExecuteIndexShards(&shards_, [&](int i, DeviceShard& shard) { - shard.Init(this, other.shards_[i]); + shard.Init(this, other.shards_.at(i)); }); } - // Init can be std::vector or std::initializer_list - template - HostDeviceVectorImpl(const Init& init, GPUDistribution distribution) + // Initializer can be std::vector or std::initializer_list + template + HostDeviceVectorImpl(const Initializer& init, GPUDistribution distribution) : distribution_(distribution), perm_h_(distribution.IsEmpty()), size_d_(0) { if (!distribution_.IsEmpty()) { size_d_ = init.size(); @@ -204,7 +202,7 @@ struct HostDeviceVectorImpl { int ndevices = distribution_.devices_.Size(); shards_.resize(ndevices); dh::ExecuteIndexShards(&shards_, [&](int i, DeviceShard& shard) { - shard.Init(this, distribution_.devices_[i]); + shard.Init(this, distribution_.devices_.DeviceId(i)); }); } @@ -217,20 +215,20 @@ struct HostDeviceVectorImpl { T* DevicePointer(int device) { CHECK(distribution_.devices_.Contains(device)); LazySyncDevice(device, GPUAccess::kWrite); - return shards_[distribution_.devices_.Index(device)].data_.data().get(); + return shards_.at(distribution_.devices_.Index(device)).data_.data().get(); } const T* ConstDevicePointer(int device) { CHECK(distribution_.devices_.Contains(device)); LazySyncDevice(device, GPUAccess::kRead); - return shards_[distribution_.devices_.Index(device)].data_.data().get(); + return 
shards_.at(distribution_.devices_.Index(device)).data_.data().get(); } common::Span DeviceSpan(int device) { GPUSet devices = distribution_.devices_; CHECK(devices.Contains(device)); LazySyncDevice(device, GPUAccess::kWrite); - return {shards_[devices.Index(device)].data_.data().get(), + return {shards_.at(devices.Index(device)).data_.data().get(), static_cast::index_type>(DeviceSize(device))}; } @@ -238,20 +236,20 @@ struct HostDeviceVectorImpl { GPUSet devices = distribution_.devices_; CHECK(devices.Contains(device)); LazySyncDevice(device, GPUAccess::kRead); - return {shards_[devices.Index(device)].data_.data().get(), + return {shards_.at(devices.Index(device)).data_.data().get(), static_cast::index_type>(DeviceSize(device))}; } size_t DeviceSize(int device) { CHECK(distribution_.devices_.Contains(device)); LazySyncDevice(device, GPUAccess::kRead); - return shards_[distribution_.devices_.Index(device)].data_.size(); + return shards_.at(distribution_.devices_.Index(device)).data_.size(); } size_t DeviceStart(int device) { CHECK(distribution_.devices_.Contains(device)); LazySyncDevice(device, GPUAccess::kRead); - return shards_[distribution_.devices_.Index(device)].start_; + return shards_.at(distribution_.devices_.Index(device)).start_; } thrust::device_ptr tbegin(int device) { // NOLINT @@ -316,7 +314,7 @@ struct HostDeviceVectorImpl { size_d_ = other->size_d_; } dh::ExecuteIndexShards(&shards_, [&](int i, DeviceShard& shard) { - shard.Copy(&other->shards_[i]); + shard.Copy(&other->shards_.at(i)); }); } @@ -405,7 +403,7 @@ struct HostDeviceVectorImpl { void LazySyncDevice(int device, GPUAccess access) { GPUSet devices = distribution_.Devices(); CHECK(devices.Contains(device)); - shards_[devices.Index(device)].LazySyncDevice(access); + shards_.at(devices.Index(device)).LazySyncDevice(access); } bool HostCanAccess(GPUAccess access) { return perm_h_.CanAccess(access); } @@ -413,7 +411,7 @@ struct HostDeviceVectorImpl { bool DeviceCanAccess(int device, GPUAccess access) { GPUSet devices = distribution_.Devices(); if (!devices.Contains(device)) { return false; } - return shards_[devices.Index(device)].perm_d_.CanAccess(access); + return shards_.at(devices.Index(device)).perm_d_.CanAccess(access); } std::vector data_h_; diff --git a/src/common/host_device_vector.h b/src/common/host_device_vector.h index 8daa19fe436b..d1abb604b736 100644 --- a/src/common/host_device_vector.h +++ b/src/common/host_device_vector.h @@ -78,10 +78,11 @@ void SetCudaSetDeviceHandler(void (*handler)(int)); template struct HostDeviceVectorImpl; -// Distribution for the HostDeviceVector; it specifies such aspects as the devices it is -// distributed on, whether there are copies of elements from other GPUs as well as the granularity -// of splitting. It may also specify explicit boundaries for devices, in which case the size of the -// array cannot be changed. +// Distribution for the HostDeviceVector; it specifies such aspects as the +// devices it is distributed on, whether there are copies of elements from +// other GPUs as well as the granularity of splitting. It may also specify +// explicit boundaries for devices, in which case the size of the array cannot +// be changed. 
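As a rough standalone sketch of the block/granular sharding described in the comment above: the helper name ShardStarts and the exact rounding are illustrative assumptions only; the real GPUDistribution below exposes ShardStart/ShardSize and may split rows differently.

// block_sharding_sketch.cc -- illustrative only; not the real GPUDistribution.
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Split `size` elements across `n_shards` devices in contiguous blocks whose
// sizes are multiples of `granularity` (except possibly the last shard).
// Shard i owns the half-open range [starts[i], starts[i + 1]).
std::vector<std::size_t> ShardStarts(std::size_t size, std::size_t n_shards,
                                     std::size_t granularity = 1) {
  std::size_t groups = (size + granularity - 1) / granularity;        // ceil
  std::size_t groups_per_shard = (groups + n_shards - 1) / n_shards;  // ceil
  std::vector<std::size_t> starts(n_shards + 1, size);
  for (std::size_t i = 0; i <= n_shards; ++i) {
    starts[i] = std::min(size, i * groups_per_shard * granularity);
  }
  return starts;
}

int main() {
  // 8990 elements with granularity 10 (e.g. a 10-class objective) on 2 GPUs:
  auto starts = ShardStarts(8990, 2, 10);
  assert(starts[0] == 0);
  assert(starts[1] % 10 == 0);  // shard boundary lands on a class boundary
  assert(starts[2] == 8990);
  return 0;
}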
class GPUDistribution { template friend struct HostDeviceVectorImpl; diff --git a/src/common/transform.h b/src/common/transform.h index 9c374beb86c2..ac5c91aeb75e 100644 --- a/src/common/transform.h +++ b/src/common/transform.h @@ -86,11 +86,13 @@ class Transform { // CUDA UnpackHDV template Span UnpackHDV(HostDeviceVector* _vec, int _device) const { - return _vec->DeviceSpan(_device); + auto span = _vec->DeviceSpan(_device); + return span; } template Span UnpackHDV(const HostDeviceVector* _vec, int _device) const { - return _vec->ConstDeviceSpan(_device); + auto span = _vec->ConstDeviceSpan(_device); + return span; } // CPU UnpackHDV template @@ -125,19 +127,23 @@ class Transform { GPUSet devices = distribution_.Devices(); size_t range_size = *range_.end() - *range_.begin(); + + // Extract index to deal with possible old OpenMP. + size_t device_beg = *(devices.begin()); + size_t device_end = *(devices.end()); #pragma omp parallel for schedule(static, 1) if (devices.Size() > 1) - for (omp_ulong i = 0; i < devices.Size(); ++i) { - int d = devices.Index(i); + for (omp_ulong device = device_beg; device < device_end; ++device) { // NOLINT // Ignore other attributes of GPUDistribution for spliting index. - size_t shard_size = - GPUDistribution::Block(devices).ShardSize(range_size, d); + // This deals with situation like multi-class setting where + // granularity is used in data vector. + size_t shard_size = GPUDistribution::Block(devices).ShardSize( + range_size, devices.Index(device)); Range shard_range {0, static_cast(shard_size)}; - dh::safe_cuda(cudaSetDevice(d)); + dh::safe_cuda(cudaSetDevice(device)); const int GRID_SIZE = static_cast(dh::DivRoundUp(*(range_.end()), kBlockThreads)); - detail::LaunchCUDAKernel<<>>( - _func, shard_range, UnpackHDV(_vectors, d)...); + _func, shard_range, UnpackHDV(_vectors, device)...); dh::safe_cuda(cudaGetLastError()); dh::safe_cuda(cudaDeviceSynchronize()); } diff --git a/src/linear/updater_gpu_coordinate.cu b/src/linear/updater_gpu_coordinate.cu index fbed99425e62..938df76af787 100644 --- a/src/linear/updater_gpu_coordinate.cu +++ b/src/linear/updater_gpu_coordinate.cu @@ -90,7 +90,6 @@ void RescaleIndices(size_t ridx_begin, dh::DVec *data) { class DeviceShard { int device_idx_; - int normalised_device_idx_; // Device index counting from param.gpu_id dh::BulkAllocator ba_; std::vector row_ptr_; dh::DVec data_; @@ -100,12 +99,11 @@ class DeviceShard { size_t ridx_end_; public: - DeviceShard(int device_idx, int normalised_device_idx, const SparsePage &batch, + DeviceShard(int device_idx, const SparsePage &batch, bst_uint row_begin, bst_uint row_end, const GPUCoordinateTrainParam ¶m, const gbm::GBLinearModelParam &model_param) : device_idx_(device_idx), - normalised_device_idx_(normalised_device_idx), ridx_begin_(row_begin), ridx_end_(row_end) { dh::safe_cuda(cudaSetDevice(device_idx)); @@ -215,16 +213,16 @@ class GPUCoordinateUpdater : public LinearUpdater { void LazyInitShards(DMatrix *p_fmat, const gbm::GBLinearModelParam &model_param) { if (!shards.empty()) return; - int n_devices = GPUSet::All(param.n_gpus, p_fmat->Info().num_row_).Size(); + + dist_ = GPUDistribution::Block(GPUSet::All(param.gpu_id, param.n_gpus, + p_fmat->Info().num_row_)); + auto devices = dist_.Devices(); + + int n_devices = devices.Size(); bst_uint row_begin = 0; bst_uint shard_size = std::ceil(static_cast(p_fmat->Info().num_row_) / n_devices); - device_list.resize(n_devices); - for (int d_idx = 0; d_idx < n_devices; ++d_idx) { - int device_idx = GPUSet::GetDeviceIdx(param.gpu_id + 
d_idx); - device_list[d_idx] = device_idx; - } // Partition input matrix into row segments std::vector row_segments; row_segments.push_back(0); @@ -240,13 +238,14 @@ class GPUCoordinateUpdater : public LinearUpdater { shards.resize(n_devices); // Create device shards - dh::ExecuteShards(&shards, [&](std::unique_ptr &shard) { - auto idx = &shard - &shards[0]; - shard = std::unique_ptr( - new DeviceShard(device_list[idx], idx, batch, row_segments[idx], - row_segments[idx + 1], param, model_param)); - }); + dh::ExecuteIndexShards(&shards, + [&](int i, std::unique_ptr& shard) { + shard = std::unique_ptr( + new DeviceShard(devices.DeviceId(i), batch, row_segments[i], + row_segments[i + 1], param, model_param)); + }); } + void Update(HostDeviceVector *in_gpair, DMatrix *p_fmat, gbm::GBLinearModel *model, double sum_instance_weight) override { param.DenormalizePenalties(sum_instance_weight); @@ -329,11 +328,11 @@ class GPUCoordinateUpdater : public LinearUpdater { // training parameter GPUCoordinateTrainParam param; + GPUDistribution dist_; std::unique_ptr selector; common::Monitor monitor; std::vector> shards; - std::vector device_list; }; DMLC_REGISTER_PARAMETER(GPUCoordinateTrainParam); diff --git a/src/objective/hinge.cu b/src/objective/hinge.cu index e46716ce4349..9c218a266b04 100644 --- a/src/objective/hinge.cu +++ b/src/objective/hinge.cu @@ -38,7 +38,7 @@ class HingeObj : public ObjFunction { void Configure( const std::vector > &args) override { param_.InitAllowUnknown(args); - devices_ = GPUSet::All(param_.n_gpus).Normalised(param_.gpu_id); + devices_ = GPUSet::All(param_.gpu_id, param_.n_gpus); label_correct_.Resize(devices_.IsEmpty() ? 1 : devices_.Size()); } diff --git a/src/objective/multiclass_obj.cu b/src/objective/multiclass_obj.cu index 317adda707f4..a7919f9a5130 100644 --- a/src/objective/multiclass_obj.cu +++ b/src/objective/multiclass_obj.cu @@ -50,7 +50,7 @@ class SoftmaxMultiClassObj : public ObjFunction { void Configure(const std::vector >& args) override { param_.InitAllowUnknown(args); CHECK(param_.n_gpus != 0) << "Must have at least one device"; // Default is -1 - devices_ = GPUSet::All(param_.n_gpus).Normalised(param_.gpu_id); + devices_ = GPUSet::All(param_.gpu_id, param_.n_gpus); label_correct_.Resize(devices_.IsEmpty() ? 1 : devices_.Size()); } void GetGradient(const HostDeviceVector& preds, diff --git a/src/objective/regression_obj.cu b/src/objective/regression_obj.cu index 590072d8f9d2..a6c474f2a399 100644 --- a/src/objective/regression_obj.cu +++ b/src/objective/regression_obj.cu @@ -54,7 +54,7 @@ class RegLossObj : public ObjFunction { void Configure(const std::vector >& args) override { param_.InitAllowUnknown(args); CHECK(param_.n_gpus != 0) << "Must have at least one device"; // Default is -1 - devices_ = GPUSet::All(param_.n_gpus).Normalised(param_.gpu_id); + devices_ = GPUSet::All(param_.gpu_id, param_.n_gpus); label_correct_.Resize(devices_.IsEmpty() ? 1 : devices_.Size()); } @@ -198,7 +198,7 @@ class PoissonRegression : public ObjFunction { void Configure(const std::vector >& args) override { param_.InitAllowUnknown(args); CHECK(param_.n_gpus != 0) << "Must have at least one device"; // Default is -1 - devices_ = GPUSet::All(param_.n_gpus).Normalised(param_.gpu_id); + devices_ = GPUSet::All(param_.gpu_id, param_.n_gpus); label_correct_.Resize(devices_.IsEmpty() ? 
1 : devices_.Size()); } @@ -380,7 +380,7 @@ class GammaRegression : public ObjFunction { void Configure(const std::vector >& args) override { param_.InitAllowUnknown(args); CHECK(param_.n_gpus != 0) << "Must have at least one device"; // Default is -1 - devices_ = GPUSet::All(param_.n_gpus).Normalised(param_.gpu_id); + devices_ = GPUSet::All(param_.gpu_id, param_.n_gpus); label_correct_.Resize(devices_.IsEmpty() ? 1 : devices_.Size()); } @@ -477,7 +477,7 @@ class TweedieRegression : public ObjFunction { void Configure(const std::vector >& args) override { param_.InitAllowUnknown(args); CHECK(param_.n_gpus != 0) << "Must have at least one device"; // Default is -1 - devices_ = GPUSet::All(param_.n_gpus).Normalised(param_.gpu_id); + devices_ = GPUSet::All(param_.gpu_id, param_.n_gpus); label_correct_.Resize(devices_.IsEmpty() ? 1 : devices_.Size()); } diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index 3787a9b1c9a9..cf0c7e9ddcb7 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -230,7 +230,7 @@ class GPUPredictor : public xgboost::Predictor { offsets[0] = 0; #pragma omp parallel for schedule(static, 1) if (devices_.Size() > 1) for (int shard = 0; shard < devices_.Size(); ++shard) { - int device = devices_[shard]; + int device = devices_.DeviceId(shard); auto data_span = data.DeviceSpan(device); dh::safe_cuda(cudaSetDevice(device)); // copy the last element from every shard @@ -271,6 +271,7 @@ class GPUPredictor : public xgboost::Predictor { const int BLOCK_THREADS = 128; size_t num_rows = batch.offset.DeviceSize(device_) - 1; + if (num_rows < 1) { return; } const int GRID_SIZE = static_cast(dh::DivRoundUp(num_rows, BLOCK_THREADS)); @@ -282,8 +283,8 @@ class GPUPredictor : public xgboost::Predictor { use_shared = false; } const auto& data_distr = batch.data.Distribution(); - int index = data_distr.Devices().Index(device_); - size_t entry_start = data_distr.ShardStart(batch.data.Size(), index); + size_t entry_start = data_distr.ShardStart(batch.data.Size(), + data_distr.Devices().Index(device_)); PredictKernel<<>> (dh::ToSpan(nodes), predictions->DeviceSpan(device_), dh::ToSpan(tree_segments), @@ -291,6 +292,7 @@ class GPUPredictor : public xgboost::Predictor { batch.data.DeviceSpan(device_), tree_begin, tree_end, info.num_col_, num_rows, entry_start, use_shared, model.param.num_output_group); + dh::safe_cuda(cudaGetLastError()); dh::safe_cuda(cudaDeviceSynchronize()); } @@ -350,7 +352,7 @@ class GPUPredictor : public xgboost::Predictor { const gbm::GBTreeModel& model, int tree_begin, unsigned ntree_limit = 0) override { GPUSet devices = GPUSet::All( - param.n_gpus, dmat->Info().num_row_).Normalised(param.gpu_id); + param.gpu_id, param.n_gpus, dmat->Info().num_row_); ConfigureShards(devices); if (this->PredictFromCache(dmat, out_preds, model, ntree_limit)) { @@ -464,7 +466,7 @@ class GPUPredictor : public xgboost::Predictor { cpu_predictor->Init(cfg, cache); param.InitAllowUnknown(cfg); - GPUSet devices = GPUSet::All(param.n_gpus).Normalised(param.gpu_id); + GPUSet devices = GPUSet::All(param.gpu_id, param.n_gpus); ConfigureShards(devices); } @@ -477,7 +479,7 @@ class GPUPredictor : public xgboost::Predictor { shards.clear(); shards.resize(devices_.Size()); dh::ExecuteIndexShards(&shards, [=](size_t i, DeviceShard& shard){ - shard.Init(devices_[i]); + shard.Init(devices_.DeviceId(i)); }); } diff --git a/src/tree/updater_gpu.cu b/src/tree/updater_gpu.cu index cba1287044e5..fff55c69b458 100644 --- a/src/tree/updater_gpu.cu +++ 
b/src/tree/updater_gpu.cu @@ -376,7 +376,7 @@ void argMaxByKey(ExactSplitCandidate* nodeSplits, const GradientPair* gradScans, NodeIdT nodeStart, int len, const TrainParam param, ArgMaxByKeyAlgo algo) { dh::FillConst( - GPUSet::GetDeviceIdx(param.gpu_id), nodeSplits, nUniqKeys, + param.gpu_id, nodeSplits, nUniqKeys, ExactSplitCandidate()); int nBlks = dh::DivRoundUp(len, ITEMS_PER_THREAD * BLKDIM); switch (algo) { @@ -517,7 +517,7 @@ class GPUMaker : public TreeUpdater { maxNodes = (1 << (param.max_depth + 1)) - 1; maxLeaves = 1 << param.max_depth; - devices_ = GPUSet::All(param.n_gpus).Normalised(param.gpu_id); + devices_ = GPUSet::All(param.gpu_id, param.n_gpus); } void Update(HostDeviceVector* gpair, DMatrix* dmat, @@ -625,7 +625,7 @@ class GPUMaker : public TreeUpdater { void allocateAllData(int offsetSize) { int tmpBuffSize = ScanTempBufferSize(nVals); - ba.Allocate(GPUSet::GetDeviceIdx(param.gpu_id), param.silent, &vals, nVals, + ba.Allocate(param.gpu_id, param.silent, &vals, nVals, &vals_cached, nVals, &instIds, nVals, &instIds_cached, nVals, &colOffsets, offsetSize, &gradsInst, nRows, &nodeAssigns, nVals, &nodeLocations, nVals, &nodes, maxNodes, &nodeAssignsPerInst, @@ -635,9 +635,9 @@ class GPUMaker : public TreeUpdater { } void setupOneTimeData(DMatrix* dmat) { - size_t free_memory = dh::AvailableMemory(GPUSet::GetDeviceIdx(param.gpu_id)); + size_t free_memory = dh::AvailableMemory(param.gpu_id); if (!dmat->SingleColBlock()) { - throw std::runtime_error("exact::GPUBuilder - must have 1 column block"); + LOG(FATAL) << "exact::GPUBuilder - must have 1 column block"; } std::vector fval; std::vector fId; @@ -724,7 +724,7 @@ class GPUMaker : public TreeUpdater { nodeAssigns.Current(), instIds.Current(), nodes.Data(), colOffsets.Data(), vals.Current(), nVals, nCols); // gather the node assignments across all other columns too - dh::Gather(GPUSet::GetDeviceIdx(param.gpu_id), nodeAssigns.Current(), + dh::Gather(param.gpu_id, nodeAssigns.Current(), nodeAssignsPerInst.Data(), instIds.Current(), nVals); sortKeys(level); } @@ -735,7 +735,7 @@ class GPUMaker : public TreeUpdater { // but we don't need more than level+1 bits for sorting! SegmentedSort(&tmp_mem, &nodeAssigns, &nodeLocations, nVals, nCols, colOffsets, 0, level + 1); - dh::Gather(GPUSet::GetDeviceIdx(param.gpu_id), vals.other(), + dh::Gather(param.gpu_id, vals.other(), vals.Current(), instIds.other(), instIds.Current(), nodeLocations.Current(), nVals); vals.buff().selector ^= 1; diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 45d307078508..f7cb35a51be6 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -251,15 +251,15 @@ struct DeviceHistogram { thrust::device_vector data; const size_t kStopGrowingSize = 1 << 26; // Do not grow beyond this size int n_bins; - int device_idx; + int device_id_; void Init(int device_idx, int n_bins) { this->n_bins = n_bins; - this->device_idx = device_idx; + this->device_id_ = device_idx; } void Reset() { - dh::safe_cuda(cudaSetDevice(device_idx)); + dh::safe_cuda(cudaSetDevice(device_id_)); data.resize(0); nidx_map.clear(); } @@ -281,7 +281,7 @@ struct DeviceHistogram { } else { // Append new node histogram nidx_map[nidx] = data.size(); - dh::safe_cuda(cudaSetDevice(device_idx)); + dh::safe_cuda(cudaSetDevice(device_id_)); // x 2: Hess and Grad. 
data.resize(data.size() + (n_bins * 2)); } @@ -396,13 +396,12 @@ struct DeviceShard; struct GPUHistBuilderBase { public: virtual void Build(DeviceShard* shard, int idx) = 0; + virtual ~GPUHistBuilderBase() = default; }; // Manage memory for a single GPU struct DeviceShard { - int device_idx; - /*! \brief Device index counting from param.gpu_id */ - int normalised_device_idx; + int device_idx_; dh::BulkAllocator ba; /*! \brief HistCutMatrix stored in device. */ @@ -463,10 +462,9 @@ struct DeviceShard { std::unique_ptr hist_builder; // TODO(canonizer): do add support multi-batch DMatrix here - DeviceShard(int device_idx, int normalised_device_idx, + DeviceShard(int device_idx, bst_uint row_begin, bst_uint row_end, TrainParam _param) : - device_idx(device_idx), - normalised_device_idx(normalised_device_idx), + device_idx_(device_idx), row_begin_idx(row_begin), row_end_idx(row_end), row_stride(0), @@ -479,7 +477,7 @@ struct DeviceShard { /* Init row_ptrs and row_stride */ void InitRowPtrs(const SparsePage& row_batch) { - dh::safe_cuda(cudaSetDevice(device_idx)); + dh::safe_cuda(cudaSetDevice(device_idx_)); const auto& offset_vec = row_batch.offset.HostVector(); row_ptrs.resize(n_rows + 1); thrust::copy(offset_vec.data() + row_begin_idx, @@ -537,7 +535,7 @@ struct DeviceShard { // Reset values for each update iteration void Reset(HostDeviceVector* dh_gpair) { - dh::safe_cuda(cudaSetDevice(device_idx)); + dh::safe_cuda(cudaSetDevice(device_idx_)); position.CurrentDVec().Fill(0); std::fill(node_sum_gradients.begin(), node_sum_gradients.end(), GradientPair()); @@ -546,7 +544,8 @@ struct DeviceShard { std::fill(ridx_segments.begin(), ridx_segments.end(), Segment(0, 0)); ridx_segments.front() = Segment(0, ridx.Size()); - this->gpair.copy(dh_gpair->tcbegin(device_idx), dh_gpair->tcend(device_idx)); + this->gpair.copy(dh_gpair->tcbegin(device_idx_), + dh_gpair->tcend(device_idx_)); SubsampleGradientPair(&gpair, param.subsample, row_begin_idx); hist.Reset(); } @@ -562,7 +561,7 @@ struct DeviceShard { auto d_node_hist_histogram = hist.GetHistPtr(nidx_histogram); auto d_node_hist_subtraction = hist.GetHistPtr(nidx_subtraction); - dh::LaunchN(device_idx, hist.n_bins, [=] __device__(size_t idx) { + dh::LaunchN(device_idx_, hist.n_bins, [=] __device__(size_t idx) { d_node_hist_subtraction[idx] = d_node_hist_parent[idx] - d_node_hist_histogram[idx]; }); @@ -589,7 +588,7 @@ struct DeviceShard { int64_t split_gidx, bool default_dir_left, bool is_dense, int fidx_begin, // cut.row_ptr[fidx] int fidx_end) { // cut.row_ptr[fidx + 1] - dh::safe_cuda(cudaSetDevice(device_idx)); + dh::safe_cuda(cudaSetDevice(device_idx_)); temp_memory.LazyAllocate(sizeof(int64_t)); int64_t* d_left_count = temp_memory.Pointer(); dh::safe_cuda(cudaMemset(d_left_count, 0, sizeof(int64_t))); @@ -600,7 +599,7 @@ struct DeviceShard { size_t row_stride = this->row_stride; // Launch 1 thread for each row dh::LaunchN<1, 512>( - device_idx, segment.Size(), [=] __device__(bst_uint idx) { + device_idx_, segment.Size(), [=] __device__(bst_uint idx) { idx += segment.begin; bst_uint ridx = d_ridx[idx]; auto row_begin = row_stride * ridx; @@ -669,7 +668,7 @@ struct DeviceShard { } void UpdatePredictionCache(bst_float* out_preds_d) { - dh::safe_cuda(cudaSetDevice(device_idx)); + dh::safe_cuda(cudaSetDevice(device_idx_)); if (!prediction_cache_initialised) { dh::safe_cuda(cudaMemcpy( prediction_cache.Data(), out_preds_d, @@ -689,7 +688,7 @@ struct DeviceShard { auto d_prediction_cache = prediction_cache.Data(); dh::LaunchN( - device_idx, 
prediction_cache.Size(), [=] __device__(int local_idx) { + device_idx_, prediction_cache.Size(), [=] __device__(int local_idx) { int pos = d_position[local_idx]; bst_float weight = CalcWeight(param_d, d_node_sum_gradients[pos]); d_prediction_cache[d_ridx[local_idx]] += @@ -723,7 +722,7 @@ struct SharedMemHistBuilder : public GPUHistBuilderBase { if (grid_size <= 0) { return; } - dh::safe_cuda(cudaSetDevice(shard->device_idx)); + dh::safe_cuda(cudaSetDevice(shard->device_idx_)); sharedMemHistKernel<<>> (shard->row_stride, d_ridx, d_gidx, null_gidx_value, d_node_hist, d_gpair, segment_begin, n_elements); @@ -742,7 +741,7 @@ struct GlobalMemHistBuilder : public GPUHistBuilderBase { size_t const row_stride = shard->row_stride; int const null_gidx_value = shard->null_gidx_value; - dh::LaunchN(shard->device_idx, n_elements, [=] __device__(size_t idx) { + dh::LaunchN(shard->device_idx_, n_elements, [=] __device__(size_t idx) { int ridx = d_ridx[(idx / row_stride) + segment.begin]; // lookup the index (bin) of histogram. int gidx = d_gidx[ridx * row_stride + idx % row_stride]; @@ -762,7 +761,7 @@ inline void DeviceShard::InitCompressedData( int max_nodes = param.max_leaves > 0 ? param.max_leaves * 2 : MaxNodesDepth(param.max_depth); - ba.Allocate(device_idx, param.silent, + ba.Allocate(device_idx_, param.silent, &gpair, n_rows, &ridx, n_rows, &position, n_rows, @@ -780,7 +779,7 @@ inline void DeviceShard::InitCompressedData( node_sum_gradients.resize(max_nodes); ridx_segments.resize(max_nodes); - dh::safe_cuda(cudaSetDevice(device_idx)); + dh::safe_cuda(cudaSetDevice(device_idx_)); // allocate compressed bin data int num_symbols = n_bins + 1; @@ -792,7 +791,7 @@ inline void DeviceShard::InitCompressedData( CHECK(!(param.max_leaves == 0 && param.max_depth == 0)) << "Max leaves and max depth cannot both be unconstrained for " "gpu_hist."; - ba.Allocate(device_idx, param.silent, &gidx_buffer, compressed_size_bytes); + ba.Allocate(device_idx_, param.silent, &gidx_buffer, compressed_size_bytes); gidx_buffer.Fill(0); int nbits = common::detail::SymbolBits(num_symbols); @@ -804,7 +803,7 @@ inline void DeviceShard::InitCompressedData( // check if we can use shared memory for building histograms // (assuming atleast we need 2 CTAs per SM to maintain decent latency hiding) auto histogram_size = sizeof(GradientPairSumT) * null_gidx_value; - auto max_smem = dh::MaxSharedMemory(device_idx); + auto max_smem = dh::MaxSharedMemory(device_idx_); if (histogram_size <= max_smem) { hist_builder.reset(new SharedMemHistBuilder); } else { @@ -812,7 +811,7 @@ inline void DeviceShard::InitCompressedData( } // Init histogram - hist.Init(device_idx, hmat.row_ptr.back()); + hist.Init(device_idx_, hmat.row_ptr.back()); dh::safe_cuda(cudaMallocHost(&tmp_pinned, sizeof(int64_t))); } @@ -820,9 +819,10 @@ inline void DeviceShard::InitCompressedData( inline void DeviceShard::CreateHistIndices(const SparsePage& row_batch) { int num_symbols = n_bins + 1; // bin and compress entries in batches of rows - size_t gpu_batch_nrows = std::min - (dh::TotalMemory(device_idx) / (16 * row_stride * sizeof(Entry)), - static_cast(n_rows)); + size_t gpu_batch_nrows = + std::min + (dh::TotalMemory(device_idx_) / (16 * row_stride * sizeof(Entry)), + static_cast(n_rows)); const std::vector& data_vec = row_batch.data.HostVector(); thrust::device_vector entries_d(gpu_batch_nrows * row_stride); @@ -876,8 +876,7 @@ class GPUHistMaker : public TreeUpdater { param_.InitAllowUnknown(args); CHECK(param_.n_gpus != 0) << "Must have at least one device"; 
n_devices_ = param_.n_gpus; - dist_ = GPUDistribution::Block(GPUSet::All(param_.n_gpus) - .Normalised(param_.gpu_id)); + dist_ = GPUDistribution::Block(GPUSet::All(param_.gpu_id, param_.n_gpus)); dh::CheckComputeCapability(); @@ -914,12 +913,12 @@ class GPUHistMaker : public TreeUpdater { void InitDataOnce(DMatrix* dmat) { info_ = &dmat->Info(); - int n_devices = GPUSet::All(param_.n_gpus, info_->num_row_).Size(); + int n_devices = dist_.Devices().Size(); device_list_.resize(n_devices); - for (int d_idx = 0; d_idx < n_devices; ++d_idx) { - int device_idx = GPUSet::GetDeviceIdx(param_.gpu_id + d_idx); - device_list_[d_idx] = device_idx; + for (int index = 0; index < n_devices; ++index) { + int device_id = dist_.Devices().DeviceId(index); + device_list_[index] = device_id; } reducer_.Init(device_list_); @@ -932,8 +931,8 @@ class GPUHistMaker : public TreeUpdater { size_t start = dist_.ShardStart(info_->num_row_, i); size_t size = dist_.ShardSize(info_->num_row_, i); shard = std::unique_ptr - (new DeviceShard(device_list_.at(i), i, - start, start + size, param_)); + (new DeviceShard(dist_.Devices().DeviceId(i), + start, start + size, param_)); shard->InitRowPtrs(batch); }); @@ -979,7 +978,7 @@ class GPUHistMaker : public TreeUpdater { for (auto& shard : shards_) { auto d_node_hist = shard->hist.GetHistPtr(nidx); reducer_.AllReduceSum( - shard->normalised_device_idx, + dist_.Devices().Index(shard->device_idx_), reinterpret_cast(d_node_hist), reinterpret_cast(d_node_hist), n_bins_ * (sizeof(GradientPairSumT) / sizeof(GradientPairSumT::ValueT))); @@ -1050,7 +1049,7 @@ class GPUHistMaker : public TreeUpdater { // FIXME: Multi-gpu support? // Use first device auto& shard = shards_.front(); - dh::safe_cuda(cudaSetDevice(shard->device_idx)); + dh::safe_cuda(cudaSetDevice(shard->device_idx_)); shard->temp_memory.LazyAllocate(candidates_size_bytes); auto d_split = shard->temp_memory.Pointer(); @@ -1063,7 +1062,7 @@ class GPUHistMaker : public TreeUpdater { int depth = p_tree->GetDepth(nidx); HostDeviceVector& feature_set = column_sampler_.GetFeatureSet(depth); - feature_set.Reshard(GPUSet::Range(shard->device_idx, 1)); + feature_set.Reshard(GPUSet::Range(shard->device_idx_, 1)); auto& h_feature_set = feature_set.HostVector(); // One block for each feature int constexpr BLOCK_THREADS = 256; @@ -1071,7 +1070,7 @@ class GPUHistMaker : public TreeUpdater { <<>>( shard->hist.GetHistPtr(nidx), info_->num_col_, - feature_set.DevicePointer(shard->device_idx), + feature_set.DevicePointer(shard->device_idx_), node, shard->cut_.feature_segments.Data(), shard->cut_.min_fvalue.Data(), @@ -1105,7 +1104,7 @@ class GPUHistMaker : public TreeUpdater { std::vector tmp_sums(shards_.size()); dh::ExecuteIndexShards(&shards_, [&](int i, std::unique_ptr& shard) { - dh::safe_cuda(cudaSetDevice(shard->device_idx)); + dh::safe_cuda(cudaSetDevice(shard->device_idx_)); tmp_sums[i] = dh::SumReduction(shard->temp_memory, shard->gpair.Data(), shard->gpair.Size()); @@ -1265,7 +1264,8 @@ class GPUHistMaker : public TreeUpdater { return false; p_out_preds->Reshard(dist_.Devices()); dh::ExecuteShards(&shards_, [&](std::unique_ptr& shard) { - shard->UpdatePredictionCache(p_out_preds->DevicePointer(shard->device_idx)); + shard->UpdatePredictionCache( + p_out_preds->DevicePointer(shard->device_idx_)); }); monitor_.Stop("UpdatePredictionCache", dist_.Devices()); return true; @@ -1336,6 +1336,7 @@ class GPUHistMaker : public TreeUpdater { common::Monitor monitor_; dh::AllReducer reducer_; std::vector node_value_constraints_; + /*! 
List storing device id. */ std::vector device_list_; DMatrix* p_last_fmat_; diff --git a/tests/cpp/common/test_common.cc b/tests/cpp/common/test_common.cc index 655e9a00028c..ba6946e80854 100644 --- a/tests/cpp/common/test_common.cc +++ b/tests/cpp/common/test_common.cc @@ -10,11 +10,7 @@ TEST(GPUSet, Basic) { ASSERT_TRUE(devices != GPUSet::Empty()); EXPECT_EQ(devices.Size(), 1); - EXPECT_ANY_THROW(devices.Index(1)); - EXPECT_ANY_THROW(devices.Index(-1)); - devices = GPUSet::Range(1, 0); - EXPECT_EQ(devices, GPUSet::Empty()); EXPECT_EQ(devices.Size(), 0); EXPECT_TRUE(devices.IsEmpty()); @@ -25,18 +21,17 @@ TEST(GPUSet, Basic) { EXPECT_EQ(devices.Size(), 0); EXPECT_TRUE(devices.IsEmpty()); - devices = GPUSet::Range(2, 8); + devices = GPUSet::Range(2, 8); // 2 ~ 10 EXPECT_EQ(devices.Size(), 8); - EXPECT_ANY_THROW(devices[8]); - EXPECT_ANY_THROW(devices.Index(0)); + EXPECT_ANY_THROW(devices.DeviceId(8)); - devices = devices.Unnormalised(); + auto device_id = devices.DeviceId(0); + EXPECT_EQ(device_id, 2); + auto device_index = devices.Index(2); + EXPECT_EQ(device_index, 0); - EXPECT_EQ(*devices.begin(), 0); - EXPECT_EQ(*devices.end(), devices.Size()); #ifndef XGBOOST_USE_CUDA EXPECT_EQ(GPUSet::AllVisible(), GPUSet::Empty()); #endif } } // namespace xgboost - diff --git a/tests/cpp/common/test_common.cu b/tests/cpp/common/test_common.cu index 90ad56a1493a..069fe7aef798 100644 --- a/tests/cpp/common/test_common.cu +++ b/tests/cpp/common/test_common.cu @@ -7,12 +7,10 @@ TEST(GPUSet, GPUBasic) { GPUSet devices = GPUSet::Empty(); ASSERT_TRUE(devices.IsEmpty()); - devices = GPUSet{0, 1}; + devices = GPUSet{1, 1}; ASSERT_TRUE(devices != GPUSet::Empty()); EXPECT_EQ(devices.Size(), 1); - - EXPECT_ANY_THROW(devices.Index(1)); - EXPECT_ANY_THROW(devices.Index(-1)); + EXPECT_EQ(*(devices.begin()), 1); devices = GPUSet::Range(1, 0); EXPECT_EQ(devices, GPUSet::Empty()); @@ -23,15 +21,12 @@ TEST(GPUSet, GPUBasic) { devices = GPUSet::Range(2, -1); EXPECT_EQ(devices, GPUSet::Empty()); - EXPECT_EQ(devices.Size(), 0); - EXPECT_TRUE(devices.IsEmpty()); devices = GPUSet::Range(2, 8); EXPECT_EQ(devices.Size(), 8); - devices = devices.Unnormalised(); - EXPECT_EQ(*devices.begin(), 0); - EXPECT_EQ(*devices.end(), devices.Size()); + EXPECT_EQ(*devices.begin(), 2); + EXPECT_EQ(*devices.end(), 2 + devices.Size()); EXPECT_EQ(8, devices.Size()); ASSERT_NO_THROW(GPUSet::AllVisible()); @@ -41,4 +36,21 @@ TEST(GPUSet, GPUBasic) { } } +#if defined(XGBOOST_USE_NCCL) +TEST(GPUSet, MGPU_GPUBasic) { + { + GPUSet devices = GPUSet::All(1, 1); + ASSERT_EQ(*(devices.begin()), 1); + ASSERT_EQ(*(devices.end()), 2); + ASSERT_EQ(devices.Size(), 1); + ASSERT_TRUE(devices.Contains(1)); + } + + { + GPUSet devices = GPUSet::All(0, -1); + ASSERT_GE(devices.Size(), 2); + } +} +#endif + } // namespace xgboost diff --git a/tests/cpp/common/test_transform_range.cu b/tests/cpp/common/test_transform_range.cu index 39517cedcda2..abd11e86b1e2 100644 --- a/tests/cpp/common/test_transform_range.cu +++ b/tests/cpp/common/test_transform_range.cu @@ -38,6 +38,73 @@ TEST(Transform, MGPU_Basic) { ASSERT_TRUE(std::equal(h_sol.begin(), h_sol.end(), res.begin())); } +// Test for multi-classes setting. 
+template +struct TestTransformRangeGranular { + const size_t granularity = 8; + + TestTransformRangeGranular(const size_t granular) : granularity{granular} {} + void XGBOOST_DEVICE operator()(size_t _idx, + Span _out, Span _in) { + auto in_sub = _in.subspan(_idx * granularity, granularity); + auto out_sub = _out.subspan(_idx * granularity, granularity); + for (size_t i = 0; i < granularity; ++i) { + out_sub[i] = in_sub[i]; + } + } +}; + +TEST(Transform, MGPU_Granularity) { + GPUSet devices = GPUSet::All(0, -1); + + const size_t size {8990}; + const size_t granularity = 10; + + GPUDistribution distribution = + GPUDistribution::Granular(devices, granularity); + + std::vector h_in(size); + std::vector h_out(size); + InitializeRange(h_in.begin(), h_in.end()); + std::vector h_sol(size); + InitializeRange(h_sol.begin(), h_sol.end()); + + const HostDeviceVector in_vec {h_in, distribution}; + HostDeviceVector out_vec {h_out, distribution}; + + ASSERT_NO_THROW( + Transform<>::Init( + TestTransformRangeGranular{granularity}, + Range{0, size / granularity}, + distribution) + .Eval(&out_vec, &in_vec)); + std::vector res = out_vec.HostVector(); + + ASSERT_TRUE(std::equal(h_sol.begin(), h_sol.end(), res.begin())); +} + +TEST(Transform, MGPU_SpecifiedGpuId) { + // Use 1 GPU, Numbering of GPU starts from 1 + auto devices = GPUSet::All(1, 1); + const size_t size {256}; + std::vector h_in(size); + std::vector h_out(size); + InitializeRange(h_in.begin(), h_in.end()); + std::vector h_sol(size); + InitializeRange(h_sol.begin(), h_sol.end()); + + const HostDeviceVector in_vec {h_in, + GPUDistribution::Block(devices)}; + HostDeviceVector out_vec {h_out, + GPUDistribution::Block(devices)}; + + ASSERT_NO_THROW( + Transform<>::Init(TestTransformRange{}, Range{0, size}, devices) + .Eval(&out_vec, &in_vec)); + std::vector res = out_vec.HostVector(); + ASSERT_TRUE(std::equal(h_sol.begin(), h_sol.end(), res.begin())); +} + } // namespace xgboost } // namespace common #endif \ No newline at end of file diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index 40228b8c1201..cf7641ee0f16 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -48,7 +48,7 @@ TEST(GpuHist, BuildGidxDense) { param.n_gpus = 1; param.max_leaves = 0; - DeviceShard shard(0, 0, 0, n_rows, param); + DeviceShard shard(0, 0, n_rows, param); BuildGidx(&shard, n_rows, n_cols); std::vector h_gidx_buffer; @@ -87,7 +87,7 @@ TEST(GpuHist, BuildGidxSparse) { param.n_gpus = 1; param.max_leaves = 0; - DeviceShard shard(0, 0, 0, n_rows, param); + DeviceShard shard(0, 0, n_rows, param); BuildGidx(&shard, n_rows, n_cols, 0.9f); std::vector h_gidx_buffer; @@ -130,7 +130,7 @@ void TestBuildHist(GPUHistBuilderBase& builder) { param.n_gpus = 1; param.max_leaves = 0; - DeviceShard shard(0, 0, 0, n_rows, param); + DeviceShard shard(0, 0, n_rows, param); BuildGidx(&shard, n_rows, n_cols); @@ -236,7 +236,7 @@ TEST(GpuHist, EvaluateSplits) { int max_bins = 4; // Initialize DeviceShard - std::unique_ptr shard {new DeviceShard(0, 0, 0, n_rows, param)}; + std::unique_ptr shard {new DeviceShard(0, 0, n_rows, param)}; // Initialize DeviceShard::node_sum_gradients shard->node_sum_gradients = {{6.4, 12.8}}; @@ -316,7 +316,7 @@ TEST(GpuHist, ApplySplit) { } hist_maker.shards_.resize(1); - hist_maker.shards_[0].reset(new DeviceShard(0, 0, 0, n_rows, param)); + hist_maker.shards_[0].reset(new DeviceShard(0, 0, n_rows, param)); auto& shard = hist_maker.shards_.at(0); shard->ridx_segments.resize(3); // 3 nodes. 
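The expectations in the C++ tests above come down to the two mappings this patch standardises on GPUSet: DeviceId() turns a zero-based shard index into a CUDA device ordinal, and Index() maps an ordinal back to a shard index. A minimal self-contained mock that reproduces the behaviour checked in test_common.cc (illustrative only, not the real xgboost::GPUSet):

// gpuset_index_sketch.cc -- simplified stand-in for xgboost::GPUSet,
// covering only the DeviceId()/Index() round trip exercised by the tests.
#include <cassert>

struct GpuRange {
  int begin;  // first CUDA device ordinal in the set
  int end;    // one past the last ordinal

  int Size() const { return end - begin; }
  bool Contains(int device) const { return begin <= device && device < end; }
  // zero-based shard index -> CUDA device ordinal
  int DeviceId(int index) const {
    assert(index >= 0 && index < Size());
    return begin + index;
  }
  // CUDA device ordinal -> zero-based shard index
  int Index(int device) const {
    assert(Contains(device));
    return device - begin;
  }
};

int main() {
  GpuRange devices{2, 2 + 8};        // mirrors GPUSet::Range(2, 8) above
  assert(devices.Size() == 8);
  assert(devices.DeviceId(0) == 2);  // first shard runs on device 2
  assert(devices.Index(2) == 0);     // device 2 is shard 0
  assert(devices.DeviceId(7) == 9);  // last shard runs on device 9
  assert(!devices.Contains(10));
  return 0;
}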
diff --git a/tests/python-gpu/test_gpu_updaters.py b/tests/python-gpu/test_gpu_updaters.py index aab07431d2f4..4f8ac58b5a02 100644 --- a/tests/python-gpu/test_gpu_updaters.py +++ b/tests/python-gpu/test_gpu_updaters.py @@ -1,21 +1,25 @@ import numpy as np import sys import unittest +from nose.plugins.attrib import attr sys.path.append("tests/python") import xgboost as xgb from regression_test_utilities import run_suite, parameter_combinations, \ assert_results_non_increasing -from nose.plugins.attrib import attr + def assert_gpu_results(cpu_results, gpu_results): for cpu_res, gpu_res in zip(cpu_results, gpu_results): # Check final eval result roughly equivalent - assert np.allclose(cpu_res["eval"][-1], gpu_res["eval"][-1], 1e-2, 1e-2) + assert np.allclose(cpu_res["eval"][-1], + gpu_res["eval"][-1], 1e-2, 1e-2) + datasets = ["Boston", "Cancer", "Digits", "Sparse regression", "Sparse regression with weights", "Small weights regression"] + class TestGPU(unittest.TestCase): def test_gpu_exact(self): variable_param = {'max_depth': [2, 6, 15], } @@ -28,8 +32,10 @@ def test_gpu_exact(self): assert_gpu_results(cpu_results, gpu_results) def test_gpu_hist(self): - variable_param = {'n_gpus': [-1], 'max_depth': [2, 8], 'max_leaves': [255, 4], - 'max_bin': [2, 256], 'min_child_weight': [0, 1], 'lambda': [0.0, 1.0], + variable_param = {'n_gpus': [-1], 'max_depth': [2, 8], + 'max_leaves': [255, 4], + 'max_bin': [2, 256], 'min_child_weight': [0, 1], + 'lambda': [0.0, 1.0], 'grow_policy': ['lossguide']} for param in parameter_combinations(variable_param): param['tree_method'] = 'gpu_hist' @@ -41,14 +47,24 @@ def test_gpu_hist(self): @attr('mgpu') def test_gpu_hist_mgpu(self): - variable_param = {'n_gpus': [-1], 'max_depth': [2, 10], 'max_leaves': [255, 4], + variable_param = {'n_gpus': [-1], 'max_depth': [2, 10], + 'max_leaves': [255, 4], 'max_bin': [2, 256], 'grow_policy': ['lossguide']} for param in parameter_combinations(variable_param): param['tree_method'] = 'gpu_hist' gpu_results = run_suite(param, select_datasets=datasets) assert_results_non_increasing(gpu_results, 1e-2) - # FIXME: re-enable next three lines, to compare against CPU - #param['tree_method'] = 'hist' - #cpu_results = run_suite(param, select_datasets=datasets) - #assert_gpu_results(cpu_results, gpu_results) + + @attr('mgpu') + def test_specified_gpu_id_gpu_update(self): + variable_param = {'n_gpus': [1], + 'gpu_id': [1], + 'max_depth': [8], + 'max_leaves': [255, 4], + 'max_bin': [2, 64], + 'grow_policy': ['lossguide'], + 'tree_method': ['gpu_hist', 'gpu_exact']} + for param in parameter_combinations(variable_param): + gpu_results = run_suite(param, select_datasets=datasets) + assert_results_non_increasing(gpu_results, 1e-2)
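For reference, the gpu_id/n_gpus selection rule that the documentation change and the new test_specified_gpu_id_gpu_update test assume can be sketched as follows. This is a simplified illustration under the documented precondition that gpu_id + n_gpus does not exceed the number of visible devices; the function name SelectDevices is hypothetical and the real GPUSet::All additionally handles oversized requests.

// device_selection_sketch.cc -- illustrates the gpu_id / n_gpus semantics
// described in doc/gpu/index.rst; not the actual GPUSet::All implementation.
#include <cassert>
#include <utility>

constexpr int kAll = -1;

// Returns {first device ordinal, number of devices} for a request,
// assuming gpu_id + n_gpus <= n_visible as the documentation requires.
std::pair<int, int> SelectDevices(int gpu_id, int n_gpus, int n_visible) {
  if (n_gpus == kAll) {
    return {gpu_id, n_visible - gpu_id};  // everything from gpu_id onwards
  }
  return {gpu_id, n_gpus};                // gpu_id .. gpu_id + n_gpus - 1
}

int main() {
  // n_gpus = -1 on a 4-GPU machine: use all devices.
  assert(SelectDevices(0, kAll, 4) == std::make_pair(0, 4));
  // gpu_id = 1, n_gpus = -1: devices 1, 2, 3.
  assert(SelectDevices(1, kAll, 4) == std::make_pair(1, 3));
  // gpu_id = 1, n_gpus = 1: train on device 1 only.
  assert(SelectDevices(1, 1, 4) == std::make_pair(1, 1));
  return 0;
}

With this rule, the configuration used in the Python test above (gpu_id=1, n_gpus=1) trains exclusively on the second visible device.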