Skip to content

Commit

Permalink
Merge pull request #128 from NWChemEx/add_callback
Browse files Browse the repository at this point in the history
Add finalize callback
  • Loading branch information
yzhang-23 authored Jan 22, 2024
2 parents f9c04cf + 08cc66b commit 8799d74
Show file tree
Hide file tree
Showing 9 changed files with 150 additions and 3 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ set(
CACHE STRING "" FORCE
)

include(get_cmaize)
include(nwx_versions)
include(get_cmaize)

### Options ###
# These are the default values, used only if the user hasn't already set them.
Expand Down
16 changes: 16 additions & 0 deletions include/parallelzone/runtime/runtime_view.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ class RuntimeView {
/// Type of a pointer to the PIMPL
using pimpl_pointer = std::shared_ptr<pimpl_type>;

/// Type of a callback function
using callback_function_type = std::function<void()>;

// -------------------------------------------------------------------------
// -- Ctors, Assignment, Dtor
// -------------------------------------------------------------------------
Expand Down Expand Up @@ -490,6 +493,19 @@ class RuntimeView {
// -- Utility methods
// -------------------------------------------------------------------------

/** @brief Adds callback function to call when destructed.
*
* Adds functions to a stack of callback functions that will be called
* upone the destruction of *this. N.b., the functions are called LIFO.
*
* @param[in] cb_func The callback function with the signature 'void()'
* to add to the stack.
*
* @throw std::bad_alloc if there is problem adding the function to the
* stack. Strong throw guarantee.
*/
void stack_callback(callback_function_type cb_func);

/** @brief Swaps the state of *this with @p other.
*
* This method simply swaps the pointers of the PIMPLs. As such all
Expand Down
17 changes: 17 additions & 0 deletions src/parallelzone/runtime/detail_/runtime_view_pimpl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
*/

#pragma once
#include <functional>
#include <parallelzone/runtime/runtime_view.hpp>
#include <stack>

namespace parallelzone::runtime::detail_ {

Expand Down Expand Up @@ -62,6 +64,9 @@ struct RuntimeViewPIMPL {
/// Ultimately a typedef of RuntimeView::argv_type
using argv_type = parent_type::argv_type;

/// Type of a callback function
using callback_function_type = parent_type::callback_function_type;

/** @brief Initializes *this from the provided MPI communicator.
*
* Constructor for the RuntimeViewPIMPL class.
Expand Down Expand Up @@ -117,6 +122,15 @@ struct RuntimeViewPIMPL {
*/
bool operator==(const RuntimeViewPIMPL& rhs) const noexcept;

/** @brief Adds callback function to call when destructed.
*
* @param[in] cb_func The callback function to add to the stack
*
* @throw std::bad_alloc if there is problem adding the function to the
* stack. Strong throw guarantee.
*/
void stack_callback(callback_function_type cb_func);

/// Did this PIMPL start MPI?
bool m_did_i_start_mpi;

Expand Down Expand Up @@ -158,6 +172,9 @@ struct RuntimeViewPIMPL {
* ResourceSet in a const function we can do that.
*/
mutable resource_set_container m_resource_sets_;

/// Stacks of initialize and finalize callback functions
std::stack<callback_function_type> m_callbacks_final_;
};

} // namespace parallelzone::runtime::detail_
Expand Down
18 changes: 16 additions & 2 deletions src/parallelzone/runtime/detail_/runtime_view_pimpl.ipp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@

namespace parallelzone::runtime::detail_ {

inline void RuntimeViewPIMPL::stack_callback(callback_function_type cb_func) {
m_callbacks_final_.push(std::move(cb_func));
}

inline void mpi_finalize_wrapper() { MPI_Finalize(); }

inline RuntimeViewPIMPL::RuntimeViewPIMPL(bool did_i_start_mpi, comm_type comm,
logger_type logger) :
m_did_i_start_mpi(did_i_start_mpi),
Expand All @@ -33,11 +39,19 @@ inline RuntimeViewPIMPL::RuntimeViewPIMPL(bool did_i_start_mpi, comm_type comm,
m_resource_sets_() {
// Pre-populate the current rank's resource set.
instantiate_resource_set_(m_comm.me());

/// Register the finalize callbacks
if(m_did_i_start_mpi) {
stack_callback(callback_function_type{&mpi_finalize_wrapper});
}
}

inline RuntimeViewPIMPL::~RuntimeViewPIMPL() noexcept {
if(!m_did_i_start_mpi) return;
MPI_Finalize();
/// call the initialize callback functions
while(!m_callbacks_final_.empty()) {
m_callbacks_final_.top()();
m_callbacks_final_.pop();
}
}

inline RuntimeViewPIMPL::const_resource_set_reference RuntimeViewPIMPL::at(
Expand Down
4 changes: 4 additions & 0 deletions src/parallelzone/runtime/runtime_view.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,10 @@ RuntimeView::logger_reference RuntimeView::logger() const {
// -- Utility methods
// -----------------------------------------------------------------------------

void RuntimeView::stack_callback(callback_function_type cb_func) {
pimpl_().stack_callback(std::move(cb_func));
}

void RuntimeView::swap(RuntimeView& other) noexcept {
m_pimpl_.swap(other.m_pimpl_);
}
Expand Down
2 changes: 2 additions & 0 deletions src/python/runtime/runtime_view.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "runtime.hpp"
#include <parallelzone/runtime/runtime_view.hpp>
#include <pybind11/functional.h>
#include <pybind11/operators.h>

namespace parallelzone::runtime {
Expand All @@ -31,6 +32,7 @@ void export_runtime_view(pybind11::module_& m) {
.def("my_resource_set", &RuntimeView::my_resource_set)
.def("count", &RuntimeView::count)
.def("logger", &RuntimeView::logger)
.def("stack_callback", &RuntimeView::stack_callback)
.def(pybind11::self == pybind11::self)
.def(pybind11::self != pybind11::self);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,37 @@ TEST_CASE("RuntimeViewPIMPL") {
REQUIRE_FALSE(pimpl == other);
}
}

SECTION("stack_callback I") {
// Simulate initialization
bool is_running = true;

// Simulate finalize callback
auto turn_off = [&is_running]() { is_running = false; };

// RuntimeViewPIMPL will fall off, call the turn_off lambda
{
RuntimeViewPIMPL falls_off(false, comm, log);
falls_off.stack_callback(turn_off);
}
REQUIRE(is_running == false);
}

SECTION("stack_callback II") {
// Testing the stack
int func_no = 1;

// Two lambdas to be pushed into the stack
auto call_back_1 = [&func_no]() { func_no += 1; };
auto call_back_2 = [&func_no]() { func_no *= 2; };

// RuntimeViewPIMPL will fall off, call the turn_off lambda
{
RuntimeViewPIMPL rt_pimpl(false, comm, log);
rt_pimpl.stack_callback(call_back_1);
rt_pimpl.stack_callback(call_back_2);
}

REQUIRE(func_no == 3);
}
}
33 changes: 33 additions & 0 deletions tests/cxx/unit_tests/parallelzone/runtime/runtime_view.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,39 @@ TEST_CASE("RuntimeView") {
}
}

SECTION("stack_callback I") {
// Simulate initialization
bool is_running = true;

// Simulate finalize callback
auto turn_off = [&is_running]() { is_running = false; };

// RuntimeView will fall off, call the turn_off lambda
{
RuntimeView falls_off;
falls_off.stack_callback(turn_off);
}
REQUIRE(is_running == false);
}

SECTION("stack_callback II") {
// Testing the stack
int func_no = 1;

// Two lambdas to be pushed into the stack
auto call_back_1 = [&func_no]() { func_no += 1; };
auto call_back_2 = [&func_no]() { func_no *= 2; };

// RuntimeView will fall off, call the turn_off lambda
{
RuntimeView rt;
rt.stack_callback(call_back_1);
rt.stack_callback(call_back_2);
}

REQUIRE(func_no == 3);
}

SECTION("gather") {
using data_type = std::vector<std::string>;
data_type local_data(3, "Hello");
Expand Down
28 changes: 28 additions & 0 deletions tests/python/unit_tests/runtime/test_runtime_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,34 @@ def test_logger(self):
self.assertIsNotNone(self.defaulted.logger())
self.defaulted.logger().log("Hello").log("world")

def test_stack_callback_1(self):
is_running = [True]

def turn_off(val=is_running):
val[0] = False

falls_off = pz.runtime.RuntimeView()
falls_off.stack_callback(turn_off)
del falls_off

self.assertFalse(is_running[0])

def test_stack_callback_2(self):
func_no = [1]

def call_back_1(val=func_no):
val[0] = val[0] + 1

def call_back_2(val=func_no):
val[0] = val[0] * 2

rt = pz.runtime.RuntimeView()
rt.stack_callback(call_back_1)
rt.stack_callback(call_back_2)
del rt

self.assertEqual(func_no[0], 3)

def test_comparisons(self):
self.assertEqual(self.defaulted, pz.runtime.RuntimeView())
self.assertFalse(self.defaulted != pz.runtime.RuntimeView())
Expand Down

0 comments on commit 8799d74

Please sign in to comment.