diff --git a/onnxruntime/core/common/threadpool.cc b/onnxruntime/core/common/threadpool.cc index 6cdcb3add7cf0..9292b4a8c7c8d 100644 --- a/onnxruntime/core/common/threadpool.cc +++ b/onnxruntime/core/common/threadpool.cc @@ -39,9 +39,23 @@ void ThreadPool::ParallelFor(int32_t total, std::function fn) { fn(0); return; } - // TODO: Eigen supports a more efficient ThreadPoolDevice mechanism // We will simply rely on the work queue and stealing in the short term. + if (total > NumThreads()) { + // The dispatcher thread is idle here: every iteration is passed to Schedule() and this thread only waits on the barrier. + Barrier barrier(static_cast(total)); + std::function handle_iteration = [&barrier, &fn](int iteration) { + fn(iteration); + barrier.Notify(); + }; + + for (int32_t id = 0; id < total; ++id) { + Schedule([=, &handle_iteration]() { handle_iteration(id); }); + } + + barrier.Wait(); + return; + } Barrier barrier(static_cast(total - 1)); std::function handle_iteration = [&barrier, &fn](int iteration) { fn(iteration); @@ -51,7 +65,7 @@ void ThreadPool::ParallelFor(int32_t total, std::function fn) { for (int32_t id = 1; id < total; ++id) { Schedule([=, &handle_iteration]() { handle_iteration(id); }); } - + // Reuse the current thread for one task. fn(0); barrier.Wait(); } diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 6ba13c728d849..c19ac16694183 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -545,7 +545,7 @@ MlasGetMaximumThreadCount( MLAS_UNREFERENCED_PARAMETER(ThreadPool); #else if (ThreadPool != nullptr) { - return ThreadPool->NumThreads() + 1; + return ThreadPool->NumThreads(); } #endif