diff --git a/src/vt/context/runnable_context/td.cc b/src/vt/context/runnable_context/td.cc index a397c9d8f3..00958539db 100644 --- a/src/vt/context/runnable_context/td.cc +++ b/src/vt/context/runnable_context/td.cc @@ -63,43 +63,52 @@ TD::TD(EpochType in_ep) void TD::begin() { theMsg()->pushEpoch(ep_); - epoch_stack_size_ = theMsg()->getEpochStack().size(); -} -void TD::end() { auto& epoch_stack = theMsg()->getEpochStack(); vt_debug_print( verbose, context, - "TD::end: top={:x}, size={}\n", + "TD::begin: top={:x}, size={}\n", epoch_stack.size() > 0 ? epoch_stack.top(): no_epoch, epoch_stack.size() ); - vtAssertNot( - epoch_stack_size_ < epoch_stack.size(), - "Epoch stack popped below preceding push size in handler" + base_epoch_stack_size_ = epoch_stack.size(); +} + +void TD::end() { + auto& epoch_stack = theMsg()->getEpochStack(); + + vt_debug_print( + verbose, context, + "TD::end: top={:x}, size={}, base_size={}\n", + epoch_stack.size() > 0 ? epoch_stack.top(): no_epoch, + epoch_stack.size(), base_epoch_stack_size_ ); vtAssert( - epoch_stack_size_ == epoch_stack.size(), "Stack must be same size" + base_epoch_stack_size_ <= epoch_stack.size(), + "Epoch stack popped below preceding push size in handler" ); - vtAssertNotExpr(epoch_stack.size() == 0); - - while (epoch_stack.size() > epoch_stack_size_) { + while (epoch_stack.size() > base_epoch_stack_size_) { theMsg()->popEpoch(); } - vtAssertExpr(epoch_stack.size() == epoch_stack_size_); - theMsg()->popEpoch(ep_); } void TD::suspend() { auto& epoch_stack = theMsg()->getEpochStack(); - while (epoch_stack.size() > epoch_stack_size_) { + vt_debug_print( + verbose, context, + "TD::suspend: top={:x}, size={}, base_size={}\n", + epoch_stack.size() > 0 ? epoch_stack.top(): no_epoch, + epoch_stack.size(), base_epoch_stack_size_ + ); + + while (epoch_stack.size() > base_epoch_stack_size_) { suspended_epochs_.push_back(theMsg()->getEpoch()); theMsg()->popEpoch(); } @@ -108,13 +117,25 @@ void TD::suspend() { } void TD::resume() { - auto const sz = suspended_epochs_.size(); - for (std::size_t i = 0; i < sz; i++) { - theMsg()->pushEpoch(suspended_epochs_[sz - i - 1]); + theMsg()->pushEpoch(ep_); + + auto& epoch_stack = theMsg()->getEpochStack(); + base_epoch_stack_size_ = epoch_stack.size(); + + vt_debug_print( + verbose, context, + "TD::resume: top={:x}, size={}, base_size={}\n", + epoch_stack.size() > 0 ? epoch_stack.top(): no_epoch, + epoch_stack.size(), base_epoch_stack_size_ + ); + + for (auto it = suspended_epochs_.rbegin(); + it != suspended_epochs_.rend(); + ++it) { + theMsg()->pushEpoch(*it); } - suspended_epochs_.clear(); - theMsg()->pushEpoch(ep_); + suspended_epochs_.clear(); } }} /* end namespace vt::ctx */ diff --git a/src/vt/context/runnable_context/td.h b/src/vt/context/runnable_context/td.h index 4bcd052645..40174e188e 100644 --- a/src/vt/context/runnable_context/td.h +++ b/src/vt/context/runnable_context/td.h @@ -108,7 +108,7 @@ struct TD final : Base { private: EpochType ep_ = no_epoch; /**< The epoch for the task */ - std::size_t epoch_stack_size_ = 0; /**< Epoch stack size at start */ + std::size_t base_epoch_stack_size_ = 0; /**< Epoch stack size at start */ std::vector suspended_epochs_; /**< Suspended epoch stack */ }; diff --git a/tests/unit/active/test_async_op_threads.cc b/tests/unit/active/test_async_op_threads.cc index 771264387a..9d56c1ad58 100644 --- a/tests/unit/active/test_async_op_threads.cc +++ b/tests/unit/active/test_async_op_threads.cc @@ -45,6 +45,7 @@ #include #include +#include #include