Skip to content

Commit

Permalink
Limit the maximum number of threads. (#10872)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis authored Oct 5, 2024
1 parent 1b4c5fb commit b01ec53
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 13 deletions.
11 changes: 6 additions & 5 deletions src/common/threading_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
*/
#include "threading_utils.h"

#include <algorithm> // for max
#include <algorithm> // for max, min
#include <exception> // for exception
#include <filesystem> // for path, exists
#include <fstream> // for ifstream
Expand Down Expand Up @@ -105,17 +105,18 @@ std::int32_t GetCfsCPUCount() noexcept {
return -1;
}

/**
 * Resolve the number of OpenMP threads to use from the user-specified n_threads.
 *
 * - Returns 1 when called from inside an active parallel region (no nested parallelism).
 * - A non-positive n_threads selects the maximum number of threads.
 * - The result is capped by the number of processors, the OpenMP maximum thread count and the
 *   OpenMP thread limit, and is never smaller than 1.
 */
std::int32_t OmpGetNumThreads(std::int32_t n_threads) {
std::int32_t OmpGetNumThreads(std::int32_t n_threads) noexcept(true) {
// Don't use parallel if we are in a parallel region.
if (omp_in_parallel()) {
return 1;
}
// Honor the openmp thread limit, which can be set via environment variable.
// The upper bound is the smallest of: processor count, OpenMP max threads, thread limit.
auto max_n_threads = std::min({omp_get_num_procs(), omp_get_max_threads(), OmpGetThreadLimit()});
// If a non-positive value (e.g. -1 or 0) is specified by the user, we default to the maximum
// number of threads.
if (n_threads <= 0) {
n_threads = std::min(omp_get_num_procs(), omp_get_max_threads());
n_threads = max_n_threads;
}
// Honor the openmp thread limit, which can be set via environment variable.
n_threads = std::min(n_threads, OmpGetThreadLimit());
// Clamp an explicit request to the upper bound computed above.
n_threads = std::min(n_threads, max_n_threads);
// Always hand back at least one thread, whatever the limits evaluate to.
n_threads = std::max(n_threads, 1);
return n_threads;
}
Expand Down
4 changes: 2 additions & 2 deletions src/common/threading_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -258,9 +258,9 @@ inline std::int32_t OmpGetThreadLimit() {
std::int32_t GetCfsCPUCount() noexcept;

/**
* \brief Get the number of available threads based on n_threads specified by users.
* @brief Get the number of available threads based on n_threads specified by users.
*
* @param n_threads Number of threads requested by the user. A value less than or equal to 0
*                  selects the maximum number of threads.
*
* @return The number of threads to use, always at least 1. Returns 1 when called from inside
*         a parallel region.
*/
std::int32_t OmpGetNumThreads(std::int32_t n_threads);
std::int32_t OmpGetNumThreads(std::int32_t n_threads) noexcept(true);

/*!
* \brief A C-style array with in-stack allocation. As long as the array is smaller than
Expand Down
19 changes: 13 additions & 6 deletions tests/cpp/common/test_threading_utils.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2019-2023 by XGBoost Contributors
* Copyright 2019-2024, XGBoost Contributors
*/
#include <gtest/gtest.h>

Expand All @@ -9,9 +9,7 @@
#include "dmlc/omp.h" // omp_in_parallel
#include "xgboost/context.h" // Context

namespace xgboost {
namespace common {

namespace xgboost::common {
TEST(ParallelFor2d, CreateBlockedSpace2d) {
constexpr size_t kDim1 = 5;
constexpr size_t kDim2 = 3;
Expand Down Expand Up @@ -102,5 +100,14 @@ TEST(ParallelFor, Basic) {
});
ASSERT_FALSE(omp_in_parallel());
}
} // namespace common
} // namespace xgboost

/**
 * @brief OmpGetNumThreads must return a value in [1, hardware threads] both for an absurdly
 *        large request and for the "use the default" request (0).
 */
TEST(OmpGetNumThreads, Max) {
#if defined(_OPENMP)
  // hardware_concurrency is only a hint and returns 0 when the value cannot be determined. The
  // upper bound is checked only when it is known; otherwise `n_threads <= 0` would fail
  // spuriously. The cast keeps the comparisons signed-vs-signed.
  auto const n_hw_threads = static_cast<std::int32_t>(std::thread::hardware_concurrency());

  // A request far beyond any real machine must be clamped.
  auto n_threads = OmpGetNumThreads(1 << 18);
  ASSERT_GE(n_threads, 1);
  if (n_hw_threads > 0) {
    // Less-or-equal rather than equal: the effective limit can be lower than the hardware
    // thread count, e.g. inside a container or when an OpenMP thread limit is set.
    ASSERT_LE(n_threads, n_hw_threads);
  }

  // A non-positive request selects the default, which obeys the same bounds.
  n_threads = OmpGetNumThreads(0);
  ASSERT_GE(n_threads, 1);
  if (n_hw_threads > 0) {
    ASSERT_LE(n_threads, n_hw_threads);
  }
#endif
}
} // namespace xgboost::common

0 comments on commit b01ec53

Please sign in to comment.