diff --git a/src/common/threadpool.h b/src/common/threadpool.h index 95d1deaaabc3..21e27aa760a1 100644 --- a/src/common/threadpool.h +++ b/src/common/threadpool.h @@ -26,20 +26,25 @@ class ThreadPool { bool stop_{false}; public: - explicit ThreadPool(std::int32_t n_threads) { + /** + * @param n_threads The number of threads this pool should hold. + * @param init_fn Function called once during thread creation. + */ + template + explicit ThreadPool(std::int32_t n_threads, InitFn&& init_fn) { for (std::int32_t i = 0; i < n_threads; ++i) { - pool_.emplace_back([&] { + pool_.emplace_back([&, init_fn = std::forward(init_fn)] { + init_fn(); + while (true) { std::unique_lock lock{mu_}; cv_.wait(lock, [this] { return !this->tasks_.empty() || stop_; }); if (this->stop_) { - if (!tasks_.empty()) { - while (!tasks_.empty()) { - auto fn = tasks_.front(); - tasks_.pop(); - fn(); - } + while (!tasks_.empty()) { + auto fn = tasks_.front(); + tasks_.pop(); + fn(); } return; } @@ -81,8 +86,13 @@ class ThreadPool { // Use shared ptr to make the task copy constructible. auto p{std::make_shared>()}; auto fut = p->get_future(); - auto ffn = std::function{[task = std::move(p), fn = std::move(fn)]() mutable { - task->set_value(fn()); + auto ffn = std::function{[task = std::move(p), fn = std::forward(fn)]() mutable { + if constexpr (std::is_void_v) { + fn(); + task->set_value(); + } else { + task->set_value(fn()); + } }}; std::unique_lock lock{mu_}; diff --git a/src/data/sparse_page_source.h b/src/data/sparse_page_source.h index 89aa86ace614..7bc8f77112d0 100644 --- a/src/data/sparse_page_source.h +++ b/src/data/sparse_page_source.h @@ -237,7 +237,6 @@ class SparsePageSourceImpl : public BatchIteratorImpl, public FormatStreamPol exce_.Rethrow(); - auto const config = *GlobalConfigThreadLocalStore::Get(); for (std::int32_t i = 0; i < n_prefetch_batches; ++i, ++fetch_it) { fetch_it %= n_batches_; // ring if (ring_->at(fetch_it).valid()) { @@ -245,8 +244,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl, public FormatStreamPol } auto const* self = this; // make sure it's const CHECK_LT(fetch_it, cache_info_->offset.size()); - ring_->at(fetch_it) = this->workers_.Submit([fetch_it, self, config, this] { - *GlobalConfigThreadLocalStore::Get() = config; + ring_->at(fetch_it) = this->workers_.Submit([fetch_it, self, this] { auto page = std::make_shared(); this->exce_.Run([&] { std::unique_ptr fmt{this->CreatePageFormat()}; @@ -296,7 +294,10 @@ class SparsePageSourceImpl : public BatchIteratorImpl, public FormatStreamPol public: SparsePageSourceImpl(float missing, int nthreads, bst_feature_t n_features, bst_idx_t n_batches, std::shared_ptr cache) - : workers_{std::max(2, std::min(nthreads, 16))}, // Don't use too many threads. + : workers_{std::max(2, std::min(nthreads, 16)), + [config = *GlobalConfigThreadLocalStore::Get()] { + *GlobalConfigThreadLocalStore::Get() = config; + }}, missing_{missing}, nthreads_{nthreads}, n_features_{n_features}, diff --git a/tests/cpp/common/test_threadpool.cc b/tests/cpp/common/test_threadpool.cc index bd54a9dedbe2..ca8a73b55ff6 100644 --- a/tests/cpp/common/test_threadpool.cc +++ b/tests/cpp/common/test_threadpool.cc @@ -2,6 +2,7 @@ * Copyright 2024, XGBoost Contributors */ #include +#include // for GlobalConfigThreadLocalStore #include // for size_t #include // for int32_t @@ -13,7 +14,23 @@ namespace xgboost::common { TEST(ThreadPool, Basic) { std::int32_t n_threads = std::thread::hardware_concurrency(); - ThreadPool pool{n_threads}; + + // Set verbosity to 0 for thread-local variable. + auto orig = GlobalConfigThreadLocalStore::Get()->verbosity; + GlobalConfigThreadLocalStore::Get()->verbosity = 4; + // 4 is an invalid value, it's only possible to set it by bypassing the parameter + // validation. + ASSERT_NE(orig, GlobalConfigThreadLocalStore::Get()->verbosity); + ThreadPool pool{n_threads, [config = *GlobalConfigThreadLocalStore::Get()] { + *GlobalConfigThreadLocalStore::Get() = config; + }}; + GlobalConfigThreadLocalStore::Get()->verbosity = orig; // restore + + { + auto fut = pool.Submit([] { return GlobalConfigThreadLocalStore::Get()->verbosity; }); + ASSERT_EQ(fut.get(), 4); + ASSERT_EQ(GlobalConfigThreadLocalStore::Get()->verbosity, orig); + } { auto fut = pool.Submit([] { return 3; }); ASSERT_EQ(fut.get(), 3); @@ -45,5 +62,12 @@ TEST(ThreadPool, Basic) { ASSERT_EQ(futures[i].get(), i); } } + { + std::int32_t val{0}; + auto fut = pool.Submit([&] { val = 3; }); + static_assert(std::is_void_v); + fut.get(); + ASSERT_EQ(val, 3); + } } } // namespace xgboost::common