Skip to content

Commit

Permalink
A small improvement to the parallel_for when task count > thread count (
Browse files Browse the repository at this point in the history
#1839)

* Revert "MlasGetMaximumThreadCount: plus 1 to the NumThreads from ORT thread pool (#1646)"

This reverts commit 4137303.

* A small fix to the parallel for
  • Loading branch information
snnn authored Sep 16, 2019
1 parent 8a9c4cd commit 166b1f8
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
18 changes: 16 additions & 2 deletions onnxruntime/core/common/threadpool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,23 @@ void ThreadPool::ParallelFor(int32_t total, std::function<void(int32_t)> fn) {
fn(0);
return;
}

// TODO: Eigen supports a more efficient ThreadPoolDevice mechanism
// We will simply rely on the work queue and stealing in the short term.
if (total > NumThreads()) {
//The dispatcher thread will be idle at here
Barrier barrier(static_cast<unsigned int>(total));
std::function<void(int32_t)> handle_iteration = [&barrier, &fn](int iteration) {
fn(iteration);
barrier.Notify();
};

for (int32_t id = 0; id < total; ++id) {
Schedule([=, &handle_iteration]() { handle_iteration(id); });
}

barrier.Wait();
return;
}
Barrier barrier(static_cast<unsigned int>(total - 1));
std::function<void(int32_t)> handle_iteration = [&barrier, &fn](int iteration) {
fn(iteration);
Expand All @@ -51,7 +65,7 @@ void ThreadPool::ParallelFor(int32_t total, std::function<void(int32_t)> fn) {
for (int32_t id = 1; id < total; ++id) {
Schedule([=, &handle_iteration]() { handle_iteration(id); });
}

//reuse the current thread for one task
fn(0);
barrier.Wait();
}
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/mlas/lib/mlasi.h
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,7 @@ MlasGetMaximumThreadCount(
MLAS_UNREFERENCED_PARAMETER(ThreadPool);
#else
if (ThreadPool != nullptr) {
return ThreadPool->NumThreads() + 1;
return ThreadPool->NumThreads();
}
#endif

Expand Down

0 comments on commit 166b1f8

Please sign in to comment.