diff --git a/CMakeLists.txt b/CMakeLists.txt index b281f79b..a6c4afe2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -30,8 +30,8 @@ set( CACHE STRING "" FORCE ) -include(get_cmaize) include(nwx_versions) +include(get_cmaize) ### Options ### # These are the default values, used only if the user hasn't already set them. diff --git a/include/parallelzone/runtime/runtime_view.hpp b/include/parallelzone/runtime/runtime_view.hpp index 5b8d255e..14164b54 100644 --- a/include/parallelzone/runtime/runtime_view.hpp +++ b/include/parallelzone/runtime/runtime_view.hpp @@ -97,6 +97,9 @@ class RuntimeView { /// Type of a pointer to the PIMPL using pimpl_pointer = std::shared_ptr; + /// Type of a callback function + using callback_function_type = std::function; + // ------------------------------------------------------------------------- // -- Ctors, Assignment, Dtor // ------------------------------------------------------------------------- @@ -490,6 +493,19 @@ class RuntimeView { // -- Utility methods // ------------------------------------------------------------------------- + /** @brief Adds callback function to call when destructed. + * + * Adds functions to a stack of callback functions that will be called + * upone the destruction of *this. N.b., the functions are called LIFO. + * + * @param[in] cb_func The callback function with the signature 'void()' + * to add to the stack. + * + * @throw std::bad_alloc if there is problem adding the function to the + * stack. Strong throw guarantee. + */ + void stack_callback(callback_function_type cb_func); + /** @brief Swaps the state of *this with @p other. * * This method simply swaps the pointers of the PIMPLs. As such all diff --git a/src/parallelzone/runtime/detail_/runtime_view_pimpl.hpp b/src/parallelzone/runtime/detail_/runtime_view_pimpl.hpp index 2026b61c..1de271cc 100644 --- a/src/parallelzone/runtime/detail_/runtime_view_pimpl.hpp +++ b/src/parallelzone/runtime/detail_/runtime_view_pimpl.hpp @@ -15,7 +15,9 @@ */ #pragma once +#include #include +#include namespace parallelzone::runtime::detail_ { @@ -62,6 +64,9 @@ struct RuntimeViewPIMPL { /// Ultimately a typedef of RuntimeView::argv_type using argv_type = parent_type::argv_type; + /// Type of a callback function + using callback_function_type = parent_type::callback_function_type; + /** @brief Initializes *this from the provided MPI communicator. * * Constructor for the RuntimeViewPIMPL class. @@ -117,6 +122,15 @@ struct RuntimeViewPIMPL { */ bool operator==(const RuntimeViewPIMPL& rhs) const noexcept; + /** @brief Adds callback function to call when destructed. + * + * @param[in] cb_func The callback function to add to the stack + * + * @throw std::bad_alloc if there is problem adding the function to the + * stack. Strong throw guarantee. + */ + void stack_callback(callback_function_type cb_func); + /// Did this PIMPL start MPI? bool m_did_i_start_mpi; @@ -158,6 +172,9 @@ struct RuntimeViewPIMPL { * ResourceSet in a const function we can do that. */ mutable resource_set_container m_resource_sets_; + + /// Stacks of initialize and finalize callback functions + std::stack m_callbacks_final_; }; } // namespace parallelzone::runtime::detail_ diff --git a/src/parallelzone/runtime/detail_/runtime_view_pimpl.ipp b/src/parallelzone/runtime/detail_/runtime_view_pimpl.ipp index 200e59c9..9d8ec7f7 100644 --- a/src/parallelzone/runtime/detail_/runtime_view_pimpl.ipp +++ b/src/parallelzone/runtime/detail_/runtime_view_pimpl.ipp @@ -25,6 +25,12 @@ namespace parallelzone::runtime::detail_ { +inline void RuntimeViewPIMPL::stack_callback(callback_function_type cb_func) { + m_callbacks_final_.push(std::move(cb_func)); +} + +inline void mpi_finalize_wrapper() { MPI_Finalize(); } + inline RuntimeViewPIMPL::RuntimeViewPIMPL(bool did_i_start_mpi, comm_type comm, logger_type logger) : m_did_i_start_mpi(did_i_start_mpi), @@ -33,11 +39,19 @@ inline RuntimeViewPIMPL::RuntimeViewPIMPL(bool did_i_start_mpi, comm_type comm, m_resource_sets_() { // Pre-populate the current rank's resource set. instantiate_resource_set_(m_comm.me()); + + /// Register the finalize callbacks + if(m_did_i_start_mpi) { + stack_callback(callback_function_type{&mpi_finalize_wrapper}); + } } inline RuntimeViewPIMPL::~RuntimeViewPIMPL() noexcept { - if(!m_did_i_start_mpi) return; - MPI_Finalize(); + /// call the initialize callback functions + while(!m_callbacks_final_.empty()) { + m_callbacks_final_.top()(); + m_callbacks_final_.pop(); + } } inline RuntimeViewPIMPL::const_resource_set_reference RuntimeViewPIMPL::at( diff --git a/src/parallelzone/runtime/runtime_view.cpp b/src/parallelzone/runtime/runtime_view.cpp index ac9558a9..dbd31583 100644 --- a/src/parallelzone/runtime/runtime_view.cpp +++ b/src/parallelzone/runtime/runtime_view.cpp @@ -120,6 +120,10 @@ RuntimeView::logger_reference RuntimeView::logger() const { // -- Utility methods // ----------------------------------------------------------------------------- +void RuntimeView::stack_callback(callback_function_type cb_func) { + pimpl_().stack_callback(std::move(cb_func)); +} + void RuntimeView::swap(RuntimeView& other) noexcept { m_pimpl_.swap(other.m_pimpl_); } diff --git a/src/python/runtime/runtime_view.cpp b/src/python/runtime/runtime_view.cpp index 6a78a33e..ee0a4325 100644 --- a/src/python/runtime/runtime_view.cpp +++ b/src/python/runtime/runtime_view.cpp @@ -16,6 +16,7 @@ #include "runtime.hpp" #include +#include #include namespace parallelzone::runtime { @@ -31,6 +32,7 @@ void export_runtime_view(pybind11::module_& m) { .def("my_resource_set", &RuntimeView::my_resource_set) .def("count", &RuntimeView::count) .def("logger", &RuntimeView::logger) + .def("stack_callback", &RuntimeView::stack_callback) .def(pybind11::self == pybind11::self) .def(pybind11::self != pybind11::self); } diff --git a/tests/cxx/unit_tests/parallelzone/runtime/detail_/runtime_view_pimpl.cpp b/tests/cxx/unit_tests/parallelzone/runtime/detail_/runtime_view_pimpl.cpp index 6c6fd4a5..4375d9c1 100644 --- a/tests/cxx/unit_tests/parallelzone/runtime/detail_/runtime_view_pimpl.cpp +++ b/tests/cxx/unit_tests/parallelzone/runtime/detail_/runtime_view_pimpl.cpp @@ -73,4 +73,37 @@ TEST_CASE("RuntimeViewPIMPL") { REQUIRE_FALSE(pimpl == other); } } + + SECTION("stack_callback I") { + // Simulate initialization + bool is_running = true; + + // Simulate finalize callback + auto turn_off = [&is_running]() { is_running = false; }; + + // RuntimeViewPIMPL will fall off, call the turn_off lambda + { + RuntimeViewPIMPL falls_off(false, comm, log); + falls_off.stack_callback(turn_off); + } + REQUIRE(is_running == false); + } + + SECTION("stack_callback II") { + // Testing the stack + int func_no = 1; + + // Two lambdas to be pushed into the stack + auto call_back_1 = [&func_no]() { func_no += 1; }; + auto call_back_2 = [&func_no]() { func_no *= 2; }; + + // RuntimeViewPIMPL will fall off, call the turn_off lambda + { + RuntimeViewPIMPL rt_pimpl(false, comm, log); + rt_pimpl.stack_callback(call_back_1); + rt_pimpl.stack_callback(call_back_2); + } + + REQUIRE(func_no == 3); + } } diff --git a/tests/cxx/unit_tests/parallelzone/runtime/runtime_view.cpp b/tests/cxx/unit_tests/parallelzone/runtime/runtime_view.cpp index ba296eed..53b6304c 100644 --- a/tests/cxx/unit_tests/parallelzone/runtime/runtime_view.cpp +++ b/tests/cxx/unit_tests/parallelzone/runtime/runtime_view.cpp @@ -208,6 +208,39 @@ TEST_CASE("RuntimeView") { } } + SECTION("stack_callback I") { + // Simulate initialization + bool is_running = true; + + // Simulate finalize callback + auto turn_off = [&is_running]() { is_running = false; }; + + // RuntimeView will fall off, call the turn_off lambda + { + RuntimeView falls_off; + falls_off.stack_callback(turn_off); + } + REQUIRE(is_running == false); + } + + SECTION("stack_callback II") { + // Testing the stack + int func_no = 1; + + // Two lambdas to be pushed into the stack + auto call_back_1 = [&func_no]() { func_no += 1; }; + auto call_back_2 = [&func_no]() { func_no *= 2; }; + + // RuntimeView will fall off, call the turn_off lambda + { + RuntimeView rt; + rt.stack_callback(call_back_1); + rt.stack_callback(call_back_2); + } + + REQUIRE(func_no == 3); + } + SECTION("gather") { using data_type = std::vector; data_type local_data(3, "Hello"); diff --git a/tests/python/unit_tests/runtime/test_runtime_view.py b/tests/python/unit_tests/runtime/test_runtime_view.py index 9c11bcc4..6c72a569 100644 --- a/tests/python/unit_tests/runtime/test_runtime_view.py +++ b/tests/python/unit_tests/runtime/test_runtime_view.py @@ -79,6 +79,34 @@ def test_logger(self): self.assertIsNotNone(self.defaulted.logger()) self.defaulted.logger().log("Hello").log("world") + def test_stack_callback_1(self): + is_running = [True] + + def turn_off(val=is_running): + val[0] = False + + falls_off = pz.runtime.RuntimeView() + falls_off.stack_callback(turn_off) + del falls_off + + self.assertFalse(is_running[0]) + + def test_stack_callback_2(self): + func_no = [1] + + def call_back_1(val=func_no): + val[0] = val[0] + 1 + + def call_back_2(val=func_no): + val[0] = val[0] * 2 + + rt = pz.runtime.RuntimeView() + rt.stack_callback(call_back_1) + rt.stack_callback(call_back_2) + del rt + + self.assertEqual(func_no[0], 3) + def test_comparisons(self): self.assertEqual(self.defaulted, pz.runtime.RuntimeView()) self.assertFalse(self.defaulted != pz.runtime.RuntimeView())