diff --git a/CMakeLists.txt b/CMakeLists.txt
index ede6c5b755af..a026888dfaad 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -50,6 +50,7 @@ option(HIDE_CXX_SYMBOLS "Build shared library and hide all C++ symbols" OFF)
 option(KEEP_BUILD_ARTIFACTS_IN_BINARY_DIR "Output build artifacts in CMake binary dir" OFF)
 ## CUDA
 option(USE_CUDA "Build with GPU acceleration" OFF)
+option(USE_PER_THREAD_DEFAULT_STREAM "Build with per-thread default stream" ON)
 option(USE_NCCL "Build with NCCL to enable distributed GPU support." OFF)
 option(BUILD_WITH_SHARED_NCCL "Build with shared NCCL library." OFF)
 set(GPU_COMPUTE_VER "" CACHE STRING
diff --git a/cmake/Utils.cmake b/cmake/Utils.cmake
index cb239f79c9fa..98e96e304cb9 100644
--- a/cmake/Utils.cmake
+++ b/cmake/Utils.cmake
@@ -133,6 +133,11 @@ function(xgboost_set_cuda_flags target)
     $<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=${OpenMP_CXX_FLAGS}>
     $<$<COMPILE_LANGUAGE:CUDA>:-Xfatbin=-compress-all>)
 
+  if (USE_PER_THREAD_DEFAULT_STREAM)
+    target_compile_options(${target} PRIVATE
+      $<$<COMPILE_LANGUAGE:CUDA>:--default-stream per-thread>)
+  endif (USE_PER_THREAD_DEFAULT_STREAM)
+
   if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.18")
     set_property(TARGET ${target} PROPERTY CUDA_ARCHITECTURES ${CMAKE_CUDA_ARCHITECTURES})
   endif (CMAKE_VERSION VERSION_GREATER_EQUAL "3.18")
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
index f514eaa68b20..30f8d036bf49 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
@@ -422,7 +422,10 @@ object XGBoost extends Serializable {
       }}
 
-      val (booster, metrics) = boostersAndMetrics.collect()(0)
+      // The repartition step is to make the training stage a ShuffleMapStage, so that when one
+      // of the training tasks fails, the training stage can retry. A ResultStage won't retry
+      // when it fails.
+      val (booster, metrics) = boostersAndMetrics.repartition(1).collect()(0)
       val trackerReturnVal = tracker.waitFor(0L)
       logger.info(s"Rabit returns with exit code $trackerReturnVal")
       if (trackerReturnVal != 0) {
diff --git a/src/collective/nccl_device_communicator.cu b/src/collective/nccl_device_communicator.cu
index 470700d2d36e..51fa5693cf50 100644
--- a/src/collective/nccl_device_communicator.cu
+++ b/src/collective/nccl_device_communicator.cu
@@ -44,16 +44,12 @@ NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal, bool needs_sy
   nccl_unique_id_ = GetUniqueId();
   dh::safe_cuda(cudaSetDevice(device_ordinal_));
   dh::safe_nccl(ncclCommInitRank(&nccl_comm_, world_size_, nccl_unique_id_, rank_));
-  dh::safe_cuda(cudaStreamCreate(&cuda_stream_));
 }
 
 NcclDeviceCommunicator::~NcclDeviceCommunicator() {
   if (world_size_ == 1) {
     return;
   }
-  if (cuda_stream_) {
-    dh::safe_cuda(cudaStreamDestroy(cuda_stream_));
-  }
   if (nccl_comm_) {
     dh::safe_nccl(ncclCommDestroy(nccl_comm_));
   }
@@ -123,8 +119,8 @@ ncclRedOp_t GetNcclRedOp(Operation const &op) {
 
 template <typename Func>
 void RunBitwiseAllreduce(char *out_buffer, char const *device_buffer, Func func, int world_size,
-                         std::size_t size, cudaStream_t stream) {
-  dh::LaunchN(size, stream, [=] __device__(std::size_t idx) {
+                         std::size_t size) {
+  dh::LaunchN(size, [=] __device__(std::size_t idx) {
     auto result = device_buffer[idx];
     for (auto rank = 1; rank < world_size; rank++) {
       result = func(result, device_buffer[rank * size + idx]);
@@ -142,25 +138,22 @@ void NcclDeviceCommunicator::BitwiseAllReduce(void *send_receive_buffer, std::si
 
   // First gather data from all the workers.
   dh::safe_nccl(ncclAllGather(send_receive_buffer, device_buffer, count, GetNcclDataType(data_type),
-                              nccl_comm_, cuda_stream_));
+                              nccl_comm_, dh::DefaultStream()));
   if (needs_sync_) {
-    dh::safe_cuda(cudaStreamSynchronize(cuda_stream_));
+    dh::DefaultStream().Sync();
   }
 
   // Then reduce locally.
   auto *out_buffer = static_cast<char *>(send_receive_buffer);
   switch (op) {
     case Operation::kBitwiseAND:
-      RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_and<char>(), world_size_, size,
-                          cuda_stream_);
+      RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_and<char>(), world_size_, size);
       break;
     case Operation::kBitwiseOR:
-      RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_or<char>(), world_size_, size,
-                          cuda_stream_);
+      RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_or<char>(), world_size_, size);
      break;
    case Operation::kBitwiseXOR:
-      RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_xor<char>(), world_size_, size,
-                          cuda_stream_);
+      RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_xor<char>(), world_size_, size);
      break;
    default:
      LOG(FATAL) << "Not a bitwise reduce operation.";
@@ -179,7 +172,7 @@ void NcclDeviceCommunicator::AllReduce(void *send_receive_buffer, std::size_t co
   } else {
     dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count,
                                 GetNcclDataType(data_type), GetNcclRedOp(op), nccl_comm_,
-                                cuda_stream_));
+                                dh::DefaultStream()));
   }
   allreduce_bytes_ += count * GetTypeSize(data_type);
   allreduce_calls_ += 1;
@@ -206,7 +199,7 @@ void NcclDeviceCommunicator::AllGatherV(void const *send_buffer, size_t length_b
   for (int32_t i = 0; i < world_size_; ++i) {
     size_t as_bytes = segments->at(i);
     dh::safe_nccl(ncclBroadcast(send_buffer, receive_buffer->data().get() + offset, as_bytes,
-                                ncclChar, i, nccl_comm_, cuda_stream_));
+                                ncclChar, i, nccl_comm_, dh::DefaultStream()));
     offset += as_bytes;
   }
   dh::safe_nccl(ncclGroupEnd());
@@ -217,7 +210,7 @@ void NcclDeviceCommunicator::Synchronize() {
     return;
   }
   dh::safe_cuda(cudaSetDevice(device_ordinal_));
-  dh::safe_cuda(cudaStreamSynchronize(cuda_stream_));
+  dh::DefaultStream().Sync();
 }
 
 }  // namespace collective
diff --git a/src/collective/nccl_device_communicator.cuh b/src/collective/nccl_device_communicator.cuh
index bb3fce45c0ff..d99002685eb2 100644
--- a/src/collective/nccl_device_communicator.cuh
+++ b/src/collective/nccl_device_communicator.cuh
@@ -77,7 +77,6 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
   int const world_size_;
   int const rank_;
   ncclComm_t nccl_comm_{};
-  cudaStream_t cuda_stream_{};
   ncclUniqueId nccl_unique_id_{};
   size_t allreduce_bytes_{0};  // Keep statistics of the number of bytes communicated.
   size_t allreduce_calls_{0};  // Keep statistics of the number of reduce calls.
diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh
index db38b2222e4c..dfaac9c35984 100644
--- a/src/common/device_helpers.cuh
+++ b/src/common/device_helpers.cuh
@@ -480,7 +480,7 @@ struct XGBCachingDeviceAllocatorImpl : XGBBaseDeviceAllocator<T> {
   cub::CachingDeviceAllocator& GetGlobalCachingAllocator() {
     // Configure allocator with maximum cached bin size of ~1GB and no limit on
     // maximum cached bytes
-    static cub::CachingDeviceAllocator *allocator = new cub::CachingDeviceAllocator(2, 9, 29);
+    thread_local cub::CachingDeviceAllocator *allocator = new cub::CachingDeviceAllocator(2, 9, 29);
     return *allocator;
   }
   pointer allocate(size_t n) {  // NOLINT
@@ -1176,7 +1176,13 @@ inline void CUDAEvent::Record(CUDAStreamView stream) {  // NOLINT
   dh::safe_cuda(cudaEventRecord(event_, cudaStream_t{stream}));
 }
 
-inline CUDAStreamView DefaultStream() { return CUDAStreamView{cudaStreamLegacy}; }
+inline CUDAStreamView DefaultStream() {
+#ifdef CUDA_API_PER_THREAD_DEFAULT_STREAM
+  return CUDAStreamView{cudaStreamPerThread};
+#else
+  return CUDAStreamView{cudaStreamLegacy};
+#endif
+}
 
 class CUDAStream {
   cudaStream_t stream_;
diff --git a/src/common/hist_util.cuh b/src/common/hist_util.cuh
index f13f01b3e9ed..d7be12749a02 100644
--- a/src/common/hist_util.cuh
+++ b/src/common/hist_util.cuh
@@ -134,12 +134,12 @@ void LaunchGetColumnSizeKernel(std::int32_t device, IterSpan<BatchIt> batch_iter
       CHECK(!force_use_u64);
       auto kernel = GetColumnSizeSharedMemKernel<kBlockThreads, std::uint32_t, BatchIt>;
       auto grid_size = EstimateGridSize<kBlockThreads>(device, kernel, required_shared_memory);
-      dh::LaunchKernel{grid_size, kBlockThreads, required_shared_memory, dh::DefaultStream()}(
+      dh::LaunchKernel{grid_size, kBlockThreads, required_shared_memory}(
          kernel, batch_iter, is_valid, out_column_size);
    } else {
      auto kernel = GetColumnSizeSharedMemKernel<kBlockThreads, std::size_t, BatchIt>;
      auto grid_size = EstimateGridSize<kBlockThreads>(device, kernel, required_shared_memory);
-      dh::LaunchKernel{grid_size, kBlockThreads, required_shared_memory, dh::DefaultStream()}(
+      dh::LaunchKernel{grid_size, kBlockThreads, required_shared_memory}(
          kernel, batch_iter, is_valid, out_column_size);
    }
  } else {
diff --git a/src/tree/gpu_hist/row_partitioner.cu b/src/tree/gpu_hist/row_partitioner.cu
index 015d817f3640..78b04883ce32 100644
--- a/src/tree/gpu_hist/row_partitioner.cu
+++ b/src/tree/gpu_hist/row_partitioner.cu
@@ -18,12 +18,10 @@ RowPartitioner::RowPartitioner(int device_idx, size_t num_rows)
   dh::safe_cuda(cudaSetDevice(device_idx_));
   ridx_segments_.emplace_back(NodePositionInfo{Segment(0, num_rows)});
   thrust::sequence(thrust::device, ridx_.data(), ridx_.data() + ridx_.size());
-  dh::safe_cuda(cudaStreamCreate(&stream_));
 }
 
 RowPartitioner::~RowPartitioner() {
   dh::safe_cuda(cudaSetDevice(device_idx_));
-  dh::safe_cuda(cudaStreamDestroy(stream_));
 }
 
 common::Span<const RowPartitioner::RowIndexT> RowPartitioner::GetRows(bst_node_t nidx) {
diff --git a/src/tree/gpu_hist/row_partitioner.cuh b/src/tree/gpu_hist/row_partitioner.cuh
index f1c420ba0c82..215a0e49bde9 100644
--- a/src/tree/gpu_hist/row_partitioner.cuh
+++ b/src/tree/gpu_hist/row_partitioner.cuh
@@ -116,7 +116,7 @@ template <typename RowIndexT, typename OpT, typename OpDataT>
 void SortPositionBatch(common::Span<const PerNodeData<OpDataT>> d_batch_info,
                        common::Span<RowIndexT> ridx, common::Span<RowIndexT> ridx_tmp,
                        common::Span<RowIndexT> d_counts, std::size_t total_rows, OpT op,
-                       dh::device_vector<int8_t>* tmp, cudaStream_t stream) {
+                       dh::device_vector<int8_t>* tmp) {
   dh::LDGIterator<PerNodeData<OpDataT>> batch_info_itr(d_batch_info.data());
   WriteResultsFunctor<OpDataT> write_results{batch_info_itr, ridx.data(), ridx_tmp.data(),
                                              d_counts.data()};
@@ -135,12 +135,12 @@ void SortPositionBatch(common::Span<const PerNodeData<OpDataT>> d_batch_info,
   size_t temp_bytes = 0;
   if (tmp->empty()) {
     cub::DeviceScan::InclusiveScan(nullptr, temp_bytes, input_iterator, discard_write_iterator,
-                                   IndexFlagOp(), total_rows, stream);
+                                   IndexFlagOp(), total_rows);
     tmp->resize(temp_bytes);
   }
   temp_bytes = tmp->size();
   cub::DeviceScan::InclusiveScan(tmp->data().get(), temp_bytes, input_iterator,
-                                 discard_write_iterator, IndexFlagOp(), total_rows, stream);
+                                 discard_write_iterator, IndexFlagOp(), total_rows);
 
   constexpr int kBlockSize = 256;
 
@@ -149,7 +149,7 @@ void SortPositionBatch(common::Span<const PerNodeData<OpDataT>> d_batch_info,
   const int grid_size = xgboost::common::DivRoundUp(total_rows, kBlockSize * kItemsThread);
 
   SortPositionCopyKernel<kBlockSize, RowIndexT, OpDataT>
-      <<<grid_size, kBlockSize, 0, stream>>>(batch_info_itr, ridx, ridx_tmp, total_rows);
+      <<<grid_size, kBlockSize, 0>>>(batch_info_itr, ridx, ridx_tmp, total_rows);
 }
 
 struct NodePositionInfo {
@@ -221,7 +220,6 @@ class RowPartitioner {
   dh::device_vector<int8_t> tmp_;
   dh::PinnedMemory pinned_;
   dh::PinnedMemory pinned2_;
-  cudaStream_t stream_;
 
  public:
   RowPartitioner(int device_idx, size_t num_rows);
@@ -278,7 +277,7 @@ class RowPartitioner {
     }
     dh::safe_cuda(cudaMemcpyAsync(d_batch_info.data().get(), h_batch_info.data(),
                                   h_batch_info.size() * sizeof(PerNodeData<OpDataT>),
-                                  cudaMemcpyDefault, stream_));
+                                  cudaMemcpyDefault));
 
     // Temporary arrays
     auto h_counts = pinned_.GetSpan<RowIndexT>(nidx.size(), 0);
@@ -287,12 +286,12 @@ class RowPartitioner {
     // Partition the rows according to the operator
     SortPositionBatch<RowIndexT, UpdatePositionOpT, OpDataT>(
         dh::ToSpan(d_batch_info), dh::ToSpan(ridx_), dh::ToSpan(ridx_tmp_), dh::ToSpan(d_counts),
-        total_rows, op, &tmp_, stream_);
+        total_rows, op, &tmp_);
     dh::safe_cuda(cudaMemcpyAsync(h_counts.data(), d_counts.data().get(), h_counts.size_bytes(),
-                                  cudaMemcpyDefault, stream_));
+                                  cudaMemcpyDefault));
     // TODO(Rory): this synchronisation hurts performance a lot
     // Future optimisation should find a way to skip this
-    dh::safe_cuda(cudaStreamSynchronize(stream_));
+    dh::DefaultStream().Sync();
 
     // Update segments
     for (size_t i = 0; i < nidx.size(); i++) {
@@ -327,13 +326,13 @@ class RowPartitioner {
     dh::TemporaryArray<NodePositionInfo> d_node_info_storage(ridx_segments_.size());
     dh::safe_cuda(cudaMemcpyAsync(d_node_info_storage.data().get(), ridx_segments_.data(),
                                   sizeof(NodePositionInfo) * ridx_segments_.size(),
-                                  cudaMemcpyDefault, stream_));
+                                  cudaMemcpyDefault));
 
     constexpr int kBlockSize = 512;
     const int kItemsThread = 8;
     const int grid_size = xgboost::common::DivRoundUp(ridx_.size(), kBlockSize * kItemsThread);
     common::Span<RowIndexT> d_ridx(ridx_.data().get(), ridx_.size());
-    FinalisePositionKernel<kBlockSize><<<grid_size, kBlockSize, 0, stream_>>>(
+    FinalisePositionKernel<kBlockSize><<<grid_size, kBlockSize, 0>>>(
         dh::ToSpan(d_node_info_storage), d_ridx, d_out_position, op);
   }
 };
diff --git a/tests/cpp/tree/gpu_hist/test_row_partitioner.cu b/tests/cpp/tree/gpu_hist/test_row_partitioner.cu
index f82123452cd8..05098040024e 100644
--- a/tests/cpp/tree/gpu_hist/test_row_partitioner.cu
+++ b/tests/cpp/tree/gpu_hist/test_row_partitioner.cu
@@ -73,7 +73,7 @@ void TestSortPositionBatch(const std::vector<int>& ridx_in, const std::vector<Se
   dh::device_vector<int8_t> tmp;
   SortPositionBatch<uint32_t, decltype(op), int>(dh::ToSpan(d_batch_info), dh::ToSpan(ridx),
                                                  dh::ToSpan(ridx_tmp), dh::ToSpan(counts),
-                                                 total_rows, op, &tmp, nullptr);
+                                                 total_rows, op, &tmp);
 
   auto op_without_data = [=] __device__(auto ridx) { return ridx % 2 == 0; };
   for (size_t i = 0; i < segments.size(); i++) {