Skip to content

Commit

Permalink
[fwtest] Migrate to tbb::task_group
Browse files Browse the repository at this point in the history
  • Loading branch information
makortel committed Feb 14, 2022
1 parent b419c7d commit bec05c4
Show file tree
Hide file tree
Showing 13 changed files with 288 additions and 152 deletions.
27 changes: 0 additions & 27 deletions src/fwtest/Framework/EmptyWaitingTask.h

This file was deleted.

16 changes: 6 additions & 10 deletions src/fwtest/Framework/FunctorTask.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,30 +23,26 @@
#include <exception>
#include <memory>

#include <tbb/task.h>

// user include files
#include "Framework/TaskBase.h"

// forward declarations

namespace edm {
template <typename F>
class FunctorTask : public tbb::task {
class FunctorTask : public TaskBase {
public:
explicit FunctorTask(F f) : func_(std::move(f)) {}

task* execute() override {
func_();
return nullptr;
};
void execute() final { func_(); };

private:
F func_;
};

template <typename ALLOC, typename F>
FunctorTask<F>* make_functor_task(ALLOC&& iAlloc, F f) {
return new (iAlloc) FunctorTask<F>(std::move(f));
template <typename F>
FunctorTask<F>* make_functor_task(F f) {
return new FunctorTask<F>(std::move(f));
}
} // namespace edm

Expand Down
65 changes: 65 additions & 0 deletions src/fwtest/Framework/TaskBase.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
#ifndef FWCore_Concurrency_TaskBase_h
#define FWCore_Concurrency_TaskBase_h
// -*- C++ -*-
//
// Package: Concurrency
// Class : TaskBase
//
/**\class TaskBase TaskBase.h FWCore/Concurrency/interface/TaskBase.h
Description: Base class for tasks.
Usage:
Used as a callback to happen after a task has been completed.
*/
//
// Original Author: Chris Jones
// Created: Tue Jan 5 13:46:31 CST 2020
// $Id$
//

// system include files
#include <atomic>
#include <exception>
#include <memory>

// user include files

// forward declarations

namespace edm {
class TaskBase {
public:
friend class TaskSentry;

///Constructor
TaskBase() : m_refCount{0} {}
virtual ~TaskBase() = default;

virtual void execute() = 0;

void increment_ref_count() { ++m_refCount; }
unsigned int decrement_ref_count() { return --m_refCount; }

private:
virtual void recycle() { delete this; }

std::atomic<unsigned int> m_refCount{0};
};

class TaskSentry {
public:
TaskSentry(TaskBase* iTask) : m_task{iTask} {}
~TaskSentry() { m_task->recycle(); }
TaskSentry() = delete;
TaskSentry(TaskSentry const&) = delete;
TaskSentry(TaskSentry&&) = delete;
TaskSentry operator=(TaskSentry const&) = delete;
TaskSentry operator=(TaskSentry&&) = delete;

private:
TaskBase* m_task;
};
} // namespace edm

#endif
33 changes: 23 additions & 10 deletions src/fwtest/Framework/WaitingTask.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,8 @@
#include <exception>
#include <memory>

#include <tbb/task.h>

// user include files
#include "Framework/TaskBase.h"

// forward declarations

Expand All @@ -34,7 +33,7 @@ namespace edm {
class WaitingTaskHolder;
class WaitingTaskWithArenaHolder;

class WaitingTask : public tbb::task {
class WaitingTask : public TaskBase {
public:
friend class WaitingTaskList;
friend class WaitingTaskHolder;
Expand Down Expand Up @@ -70,23 +69,37 @@ namespace edm {
std::atomic<std::exception_ptr*> m_ptr;
};

/** Use this class on the stack to signal the final task to be run.
Call done() to check to see if the task was run and check value of
exceptionPtr() to see if an exception was thrown by any task in the group.
*/
class FinalWaitingTask : public WaitingTask {
public:
FinalWaitingTask() : m_done{false} {}

void execute() final { m_done = true; }

bool done() const { return m_done.load(); }

private:
void recycle() final {}
std::atomic<bool> m_done;
};

template <typename F>
class FunctorWaitingTask : public WaitingTask {
public:
explicit FunctorWaitingTask(F f) : func_(std::move(f)) {}

task* execute() override {
func_(exceptionPtr());
return nullptr;
};
void execute() final { func_(exceptionPtr()); };

private:
F func_;
};

template <typename ALLOC, typename F>
FunctorWaitingTask<F>* make_waiting_task(ALLOC&& iAlloc, F f) {
return new (iAlloc) FunctorWaitingTask<F>(std::move(f));
template <typename F>
FunctorWaitingTask<F>* make_waiting_task(F f) {
return new FunctorWaitingTask<F>(std::move(f));
}

} // namespace edm
Expand Down
48 changes: 40 additions & 8 deletions src/fwtest/Framework/WaitingTaskHolder.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

// system include files
#include <cassert>
#include <tbb/task_group.h>

// user include files
#include "Framework/WaitingTask.h"
Expand All @@ -29,28 +30,50 @@
namespace edm {
class WaitingTaskHolder {
public:
WaitingTaskHolder() : m_task(nullptr) {}
friend class WaitingTaskList;
friend class WaitingTaskWithArenaHolder;

explicit WaitingTaskHolder(edm::WaitingTask* iTask) : m_task(iTask) { m_task->increment_ref_count(); }
WaitingTaskHolder() : m_task(nullptr), m_group(nullptr) {}

explicit WaitingTaskHolder(tbb::task_group& iGroup, edm::WaitingTask* iTask) : m_task(iTask), m_group(&iGroup) {
m_task->increment_ref_count();
}
~WaitingTaskHolder() {
if (m_task) {
doneWaiting(std::exception_ptr{});
}
}

WaitingTaskHolder(const WaitingTaskHolder& iHolder) : m_task(iHolder.m_task) { m_task->increment_ref_count(); }
WaitingTaskHolder(const WaitingTaskHolder& iHolder) : m_task(iHolder.m_task), m_group(iHolder.m_group) {
m_task->increment_ref_count();
}

WaitingTaskHolder(WaitingTaskHolder&& iOther) : m_task(iOther.m_task) { iOther.m_task = nullptr; }
WaitingTaskHolder(WaitingTaskHolder&& iOther) : m_task(iOther.m_task), m_group(iOther.m_group) {
iOther.m_task = nullptr;
}

WaitingTaskHolder& operator=(const WaitingTaskHolder& iRHS) {
WaitingTaskHolder tmp(iRHS);
std::swap(m_task, tmp.m_task);
std::swap(m_group, tmp.m_group);
return *this;
}

WaitingTaskHolder& operator=(WaitingTaskHolder&& iRHS) {
WaitingTaskHolder tmp(std::move(iRHS));
std::swap(m_task, tmp.m_task);
std::swap(m_group, tmp.m_group);
return *this;
}

// ---------- const member functions ---------------------
bool taskHasFailed() const { return m_task->exceptionPtr() != nullptr; }
bool taskHasFailed() const noexcept { return m_task->exceptionPtr() != nullptr; }

bool hasTask() const noexcept { return m_task != nullptr; }
/** since tbb::task_group is thread safe, we can return it non-const from here since
the object is not really part of the state of the holder
*/
tbb::task_group* group() const noexcept { return m_group; }
// ---------- static member functions --------------------

// ---------- member functions ---------------------------
Expand All @@ -59,7 +82,7 @@ namespace edm {
failure before some other child task which may be run later reports
a different, but related failure. You must later call doneWaiting
in the same thread passing the same exceptoin.
*/
*/
void presetTaskAsFailed(std::exception_ptr iExcept) {
if (iExcept) {
m_task->dependentTaskFailed(iExcept);
Expand All @@ -70,20 +93,29 @@ namespace edm {
if (iExcept) {
m_task->dependentTaskFailed(iExcept);
}
//spawn can run the task before we finish
//task_group::run can run the task before we finish
// doneWaiting and some other thread might
// try to reuse this object. Resetting
// before spawn avoids problems
auto task = m_task;
m_task = nullptr;
if (0 == task->decrement_ref_count()) {
tbb::task::spawn(*task);
m_group->run([task]() {
TaskSentry s{task};
task->execute();
});
}
}

private:
WaitingTask* release_no_decrement() noexcept {
auto t = m_task;
m_task = nullptr;
return t;
}
// ---------- member data --------------------------------
WaitingTask* m_task;
tbb::task_group* m_group;
};
} // namespace edm

Expand Down
Loading

0 comments on commit bec05c4

Please sign in to comment.