From 0bb069b131b09882e34f9fd759afae826aaa8bf1 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Sat, 5 Oct 2024 15:32:41 +0800 Subject: [PATCH] Limit the maximum number of threads. (#10872) --- src/common/threading_utils.cc | 11 ++++++----- src/common/threading_utils.h | 4 ++-- tests/cpp/common/test_threading_utils.cc | 20 ++++++++++++++------ 3 files changed, 22 insertions(+), 13 deletions(-) diff --git a/src/common/threading_utils.cc b/src/common/threading_utils.cc index 1f4d5be2f361..46a007e3c750 100644 --- a/src/common/threading_utils.cc +++ b/src/common/threading_utils.cc @@ -3,7 +3,7 @@ */ #include "threading_utils.h" -#include // for max +#include // for max, min #include // for exception #include // for path, exists #include // for ifstream @@ -99,17 +99,18 @@ std::int32_t GetCfsCPUCount() noexcept { return -1; } -std::int32_t OmpGetNumThreads(std::int32_t n_threads) { +std::int32_t OmpGetNumThreads(std::int32_t n_threads) noexcept(true) { // Don't use parallel if we are in a parallel region. if (omp_in_parallel()) { return 1; } + // Honor the openmp thread limit, which can be set via environment variable. + auto max_n_threads = std::min({omp_get_num_procs(), omp_get_max_threads(), OmpGetThreadLimit()}); // If -1 or 0 is specified by the user, we default to maximum number of threads. if (n_threads <= 0) { - n_threads = std::min(omp_get_num_procs(), omp_get_max_threads()); + n_threads = max_n_threads; } - // Honor the openmp thread limit, which can be set via environment variable. - n_threads = std::min(n_threads, OmpGetThreadLimit()); + n_threads = std::min(n_threads, max_n_threads); n_threads = std::max(n_threads, 1); return n_threads; } diff --git a/src/common/threading_utils.h b/src/common/threading_utils.h index ac71190353a7..38db8e3a5f99 100644 --- a/src/common/threading_utils.h +++ b/src/common/threading_utils.h @@ -257,9 +257,9 @@ inline std::int32_t OmpGetThreadLimit() { std::int32_t GetCfsCPUCount() noexcept; /** - * \brief Get the number of available threads based on n_threads specified by users. + * @brief Get the number of available threads based on n_threads specified by users. */ -std::int32_t OmpGetNumThreads(std::int32_t n_threads); +std::int32_t OmpGetNumThreads(std::int32_t n_threads) noexcept(true); /*! * \brief A C-style array with in-stack allocation. As long as the array is smaller than diff --git a/tests/cpp/common/test_threading_utils.cc b/tests/cpp/common/test_threading_utils.cc index 2b1a2580a90a..844adbc56477 100644 --- a/tests/cpp/common/test_threading_utils.cc +++ b/tests/cpp/common/test_threading_utils.cc @@ -1,17 +1,16 @@ /** - * Copyright 2019-2023 by XGBoost Contributors + * Copyright 2019-2024, XGBoost Contributors */ #include #include // std::size_t +#include // for std::thread #include "../../../src/common/threading_utils.h" // BlockedSpace2d,ParallelFor2d,ParallelFor #include "dmlc/omp.h" // omp_in_parallel #include "xgboost/context.h" // Context -namespace xgboost { -namespace common { - +namespace xgboost::common { TEST(ParallelFor2d, CreateBlockedSpace2d) { constexpr size_t kDim1 = 5; constexpr size_t kDim2 = 3; @@ -102,5 +101,14 @@ TEST(ParallelFor, Basic) { }); ASSERT_FALSE(omp_in_parallel()); } -} // namespace common -} // namespace xgboost + +TEST(OmpGetNumThreads, Max) { +#if defined(_OPENMP) + auto n_threads = OmpGetNumThreads(1 << 18); + ASSERT_LE(n_threads, std::thread::hardware_concurrency()); // le due to container + n_threads = OmpGetNumThreads(0); + ASSERT_GE(n_threads, 1); + ASSERT_LE(n_threads, std::thread::hardware_concurrency()); +#endif +} +} // namespace xgboost::common