diff --git a/src/vt/termination/termination.cc b/src/vt/termination/termination.cc index f4d42809a4..96c0a1cfa9 100644 --- a/src/vt/termination/termination.cc +++ b/src/vt/termination/termination.cc @@ -1233,4 +1233,18 @@ std::size_t TerminationDetector::getNumTerminatedCollectiveEpochs() const { return window->getTotalTerminated(); } +void TerminationDetector::disableTD(EpochType in_epoch) { + vtAssert(not isDS(in_epoch), "Must be a wave based epoch"); + auto& state = in_epoch == any_epoch_sentinel ? + any_epoch_state_ : findOrCreateState(in_epoch, false); + state.incrementDependency(); +} + +void TerminationDetector::enableTD(EpochType in_epoch) { + vtAssert(not isDS(in_epoch), "Must be a wave based epoch"); + auto& state = in_epoch == any_epoch_sentinel ? + any_epoch_state_ : findOrCreateState(in_epoch, false); + state.decrementDependency(); +} + }} // end namespace vt::term diff --git a/src/vt/termination/termination.h b/src/vt/termination/termination.h index 8834006308..42128bfdf1 100644 --- a/src/vt/termination/termination.h +++ b/src/vt/termination/termination.h @@ -644,6 +644,25 @@ struct TerminationDetector : */ void addDependency(EpochType predecessor, EpochType successoor); + /** + * \brief Disable termination detection on an epoch. Local counting is still + * enabled, but any non-local progress is halted until it is enabled + * + * \warning Does not work with DS epochs. It must be a wave based epoch. + * + * \param[in] in_epoch the epoch + */ + void disableTD(EpochType in_epoch = any_epoch_sentinel); + + /** + * \brief Enable termination detection on an epoch. + * + * \warning Does not work with DS epochs. It must be a wave based epoch. + * + * \param[in] in_epoch the epoch + */ + void enableTD(EpochType in_epoch = any_epoch_sentinel); + public: // Methods for testing state of TD from unit tests EpochContainerType const& getEpochState() { return epoch_state_; }