diff --git a/src/vt/runtime/runtime.cc b/src/vt/runtime/runtime.cc index fd93df5e18..615d2f0b9b 100644 --- a/src/vt/runtime/runtime.cc +++ b/src/vt/runtime/runtime.cc @@ -54,15 +54,12 @@ #include "vt/rdma/rdma_headers.h" #include "vt/parameterization/parameterization.h" #include "vt/sequence/sequencer_headers.h" -#include "vt/trace/trace.h" #include "vt/pipe/pipe_manager.h" #include "vt/objgroup/manager.h" #include "vt/scheduler/scheduler.h" -#include "vt/termination/termination.h" #include "vt/topos/location/location_headers.h" #include "vt/vrt/context/context_vrtmanager.h" #include "vt/vrt/collection/balance/lb_type.h" -#include "vt/vrt/collection/balance/stats_lb_reader.h" #include "vt/vrt/collection/collection_headers.h" #include "vt/worker/worker_headers.h" #include "vt/configs/generated/vt_git_revision.h" @@ -826,7 +823,7 @@ bool Runtime::initialize(bool const force_now) { auto lbNames = vrt::collection::balance::lb_names_; auto mapLB = vrt::collection::balance::LBType::StatsMapLB; if (ArgType::vt_lb_name == lbNames[mapLB]) { - vrt::collection::balance::StatsLBReader::init(); + vrt::collection::balance::ProcStats::readRestartInfo(); } } #endif diff --git a/src/vt/runtime/runtime.h b/src/vt/runtime/runtime.h index 95e4a27c6f..1f0ca1a8f0 100644 --- a/src/vt/runtime/runtime.h +++ b/src/vt/runtime/runtime.h @@ -48,10 +48,13 @@ #include "vt/config.h" #include "vt/runtime/runtime_common.h" #include "vt/runtime/runtime_component_fwd.h" -#include "vt/trace/trace.h" #include "vt/worker/worker_headers.h" #include "vt/configs/arguments/args.h" +#if backend_check_enabled(trace_enabled) +#include "vt/trace/trace.h" +#endif + #include #include #include diff --git a/src/vt/vrt/collection/balance/lb_invoke/invoke.cc b/src/vt/vrt/collection/balance/lb_invoke/invoke.cc index 40fd03ed6d..75e1362dde 100644 --- a/src/vt/vrt/collection/balance/lb_invoke/invoke.cc +++ b/src/vt/vrt/collection/balance/lb_invoke/invoke.cc @@ -53,7 +53,6 @@ #include
"vt/vrt/collection/balance/rotatelb/rotatelb.h" #include "vt/vrt/collection/balance/gossiplb/gossiplb.h" #include "vt/vrt/collection/balance/statsmaplb/statsmaplb.h" -#include "vt/vrt/collection/balance/stats_lb_reader.h" #include "vt/vrt/collection/messages/system_create.h" #include "vt/vrt/collection/manager.fwd.h" diff --git a/src/vt/vrt/collection/balance/proc_stats.cc b/src/vt/vrt/collection/balance/proc_stats.cc index da2c89bd87..000ce2120f 100644 --- a/src/vt/vrt/collection/balance/proc_stats.cc +++ b/src/vt/vrt/collection/balance/proc_stats.cc @@ -44,6 +44,7 @@ #include "vt/config.h" #include "vt/vrt/collection/balance/proc_stats.h" +#include "vt/vrt/collection/balance/proc_stats.util.h" #include "vt/vrt/collection/manager.h" #include "vt/timing/timing.h" #include "vt/configs/arguments/args.h" @@ -51,9 +52,7 @@ #include #include -#include #include -#include #include #include "fmt/format.h" @@ -87,6 +86,241 @@ std::unordered_map /*static*/ std::vector< bool > ProcStats::proc_phase_runs_LB_ = {}; +/*static*/ StatsRestartReader *ProcStats::proc_reader_ = nullptr; + +StatsRestartReader::~StatsRestartReader() { /* Destroy the objgroup behind `proxy`, unless the proxy still reports no_obj_group. */ + if (proxy.getProxy() != no_obj_group) { + theObjGroup()->destroyCollective(proxy); + } +} + +/*static*/ +void StatsRestartReader::readStats() { /* Entry point of the restart reader: parse this node's stats file into a per-phase history, size the ProcStats containers from it, then send the per-phase migration lists (createMigrationInfo). Dereferences ProcStats::proc_reader_, so that pointer must already be allocated (ProcStats::readRestartInfo does this before calling). */ + + // Read the input files + std::deque< std::set > elements_history; + inputStatsFile(elements_history); + if (elements_history.empty()) { + vtWarn("No element history provided"); /* NOTE(review): returning here skips makeCollective() on this node only; gatherMsgs waits until every node has reported for a phase, so a node that bails out here may leave the others waiting -- confirm all nodes are guaranteed to agree. */ + return; + } + + ProcStats::proc_reader_->proxy = + theObjGroup()->makeCollective(); + + const auto num_iters = elements_history.size() - 1; /* one transition per pair of consecutive phases; the history is non-empty here, so size() - 1 cannot wrap */ + ProcStats::proc_move_list_.resize(num_iters + 1); + ProcStats::proc_phase_runs_LB_.resize(num_iters, true); + if (theContext()->getNode() == 0) { /* only node 0 needs the gather buffers: createMigrationInfo sends every list to proxy index 0 */ + ProcStats::proc_reader_->msgsReceived.resize(num_iters, 0); + ProcStats::proc_reader_->totalMove.resize(num_iters); + } + + // Communicate the migration information + createMigrationInfo(elements_history); +} + +/*static*/ +void
StatsRestartReader::inputStatsFile( + std::deque< std::set > &element_history +) +{ /* Parse this node's LB stats file "{dir}/{base}.{node}.out" and append one set of element IDs per phase to element_history. Records that continue with a further ",..." tail (communication records) are consumed and ignored. When the file cannot be opened nothing is appended, so the caller sees an empty history. */ + using ArgType = vt::arguments::ArgConfig; + auto const node = theContext()->getNode(); + const std::string &base_file = ArgType::vt_lb_stats_file_in; + const std::string &dir = ArgType::vt_lb_stats_dir_in; + auto const file = fmt::format("{}.{}.out", base_file, node); + auto const file_name = fmt::format("{}/{}", dir, file); + + vt_print(lb, "inputStatFile: file={}, iter={}\n", file_name, 0); + + std::FILE *pFile = std::fopen(file_name.c_str(), "r"); + if (pFile == nullptr) { + vtAssert(pFile, "File opening failed"); /* if the assertion does not stop execution, never hand a null stream to fscanf()/fclose() */ return; + } + + std::set buffer; + + // Load: Format of a line :size_t, ElementIDType, TimeType + size_t phaseID = 0, prevPhaseID = 0; + ElementIDType elmID; + TimeType tval; + CommBytesType d_buffer; + using vtCommType = typename std::underlying_type::type; + vtCommType typeID; + char separator; + fpos_t pos; + bool finished = false; + while (!finished) { + if (fscanf(pFile, "%zu %c %lli %c %lf", &phaseID, &separator, &elmID, + &separator, &tval) == 5) { /* accept a record only when all five conversions succeeded; a partial match would leave elmID/tval stale or uninitialised */ + fgetpos (pFile,&pos); + const int nsep = fscanf (pFile, "%c", &separator); /* at end of file nothing is read and separator keeps its previous value (normally ','), so the conversion count has to be checked as well */ + if (nsep == 1 && separator == ',') { + // COM detected, read the end of line and do nothing else + int res = fscanf (pFile, "%lf %c %hhi", &d_buffer, &separator, &typeID); + vtAssertExpr(res == 3); + } else { + // Load detected, create the new element + fsetpos (pFile,&pos); + if (prevPhaseID != phaseID) { + prevPhaseID = phaseID; + element_history.push_back(buffer); + buffer.clear(); + } + buffer.insert(elmID); + } + } else { + finished = true; + } + } + + if (!buffer.empty()) { + element_history.push_back(buffer); + } + + std::fclose(pFile); +} + +/*static*/ +void StatsRestartReader::createMigrationInfo( + std::deque< std::set > &element_history +) +{ + const auto num_iters = element_history.size() - 1; + const auto myNodeID = static_cast(theContext()->getNode()); + auto myProxy = ProcStats::proc_reader_->proxy; + + for (size_t ii = 0; ii < num_iters; ++ii) { + auto &elms =
element_history[ii]; + auto &elmsNext = element_history[ii + 1]; + std::set diff; + std::set_difference(elmsNext.begin(), elmsNext.end(), elms.begin(), + elms.end(), std::inserter(diff, diff.begin())); + const size_t qi = diff.size(); /* qi: elements present only in the next phase (arriving on this node) */ + const size_t pi = elms.size() - (elmsNext.size() - qi); /* pi: elements present only in the current phase (leaving this node) */ + auto &myList = ProcStats::proc_move_list_[ii]; + myList.reserve(3 * (pi + qi) + 1); + //--- Store the iteration number + myList.push_back(static_cast(ii)); + //--- Store partial migration information (i.e. nodes moving in) + for (auto iEle : diff) { + myList.push_back(iEle); //--- permID to receive + myList.push_back(no_element_id); // node moving from + myList.push_back(myNodeID); // node moving to + } + diff.clear(); + //--- Store partial migration information (i.e. nodes moving out) + std::set_difference(elms.begin(), elms.end(), elmsNext.begin(), + elmsNext.end(), std::inserter(diff, diff.begin())); + for (auto iEle : diff) { + myList.push_back(iEle); //--- permID to send + myList.push_back(myNodeID); // node migrating from + myList.push_back(no_element_id); // node migrating to + } + // + // Create a message storing the vector + // + auto msg = makeSharedMessage(myList); + myProxy[0].send(msg); + // + // Clear old distribution of elements + // + elms.clear(); + } + +} + +void StatsRestartReader::gatherMsgs(VecMsg *msg) { /* Merge one node's partial migration list (layout: phase, then (permID, from, to) triples) into totalMove for that phase. Once as many lists as there are nodes have arrived for the phase, build for every node the list of elements it has to send away (layout: phase, total migration count, then (permID, destination) pairs): node 0 stores its own list directly, every other node receives it as a message. NOTE(review): the two halves of a move are merged with std::max, which presumably relies on no_element_id converting to a value below every valid node -- confirm. */ + auto sentVec = msg->getTransfer(); + vtAssert(sentVec.size() % 3 == 1, "Expecting vector of length 3n+1"); + const ElementIDType phaseID = sentVec[0]; + // + // --- Combine the different pieces of information + // + ProcStats::proc_reader_->msgsReceived[phaseID] += 1; + auto &migrate = ProcStats::proc_reader_->totalMove[phaseID]; + for (size_t ii = 1; ii < sentVec.size(); ii += 3) { + const auto permID = sentVec[ii]; + const auto nodeFrom = static_cast(sentVec[ii + 1]); + const auto nodeTo = static_cast(sentVec[ii + 2]); + auto iptr = migrate.find(permID); + if (iptr == migrate.end()) { + migrate.insert(std::make_pair(permID, std::make_pair(nodeFrom,
nodeTo))); + } + else { + auto &nodePair = iptr->second; + nodePair.first = std::max(nodePair.first, nodeFrom); + nodePair.second = std::max(nodePair.second, nodeTo); + } + } + // + // --- Check whether all the messages have been received + // + const NodeType numNodes = theContext()->getNumNodes(); + if (ProcStats::proc_reader_->msgsReceived[phaseID] < numNodes) + return; + // + //--- Distribute the information when everything has been received + // + auto myProxy = ProcStats::proc_reader_->proxy; + const size_t header = 2; + for (NodeType in = 0; in < numNodes; ++in) { + size_t iCount = 0; + for (const auto &iNode : migrate) { + if (iNode.second.first == in) + iCount += 1; + } + std::vector toMove(2 * iCount + header); + iCount = 0; + toMove[iCount++] = phaseID; + toMove[iCount++] = static_cast(migrate.size()); + for (const auto &iNode : migrate) { + if (iNode.second.first == in) { + toMove[iCount++] = iNode.first; + toMove[iCount++] = static_cast(iNode.second.second); + } + } + if (in > 0) { + auto msg2 = makeSharedMessage(toMove); + myProxy[in].send(msg2); + } else { + ProcStats::proc_phase_runs_LB_[phaseID] = (!migrate.empty()); + auto &myList = ProcStats::proc_move_list_[phaseID]; + myList.resize(toMove.size() - header); + std::copy(toMove.begin() + header, toMove.end(), /* iterator range: when node 0 sends nothing away, toMove holds only the two header entries and &toMove[header] would index past the end */ + myList.begin()); + } + } + migrate.clear(); +} + +void StatsRestartReader::scatterMsgs(VecMsg *msg) { /* Store the move list received for one phase: entry 0 is the phase, entry 1 the total migration count, the remaining entries are copied into ProcStats::proc_move_list_ for that phase. */ + const size_t header = 2; + auto recvVec = msg->getTransfer(); + vtAssert((recvVec.size() -header) % 2 == 0, + "Expecting vector of length 2n+2"); + //--- Get the iteration number associated with the message + const ElementIDType phaseID = recvVec[0]; + //--- Check whether some migration will be done + ProcStats::proc_phase_runs_LB_[phaseID] = static_cast(recvVec[1] > 0); + auto &myList = ProcStats::proc_move_list_[phaseID]; + if (!ProcStats::proc_phase_runs_LB_[phaseID]) { + myList.clear(); + return; + } + //--- Copy the migration information + myList.resize(recvVec.size() - header); +
std::copy(recvVec.begin() + header, recvVec.end(), myList.begin()); /* iterator range: stays valid when the message carries only the two header entries, where &recvVec[header] would index past the end */ +} + +/*static*/ void ProcStats::readRestartInfo() { /* Allocate the restart reader on first use, then let it read the stats file and exchange the migration info. */ + if (ProcStats::proc_reader_ == nullptr) { + ProcStats::proc_reader_ = new StatsRestartReader; + } + ProcStats::proc_reader_->readStats(); +} + /*static*/ void ProcStats::clearStats() { ProcStats::proc_comm_.clear(); ProcStats::proc_data_.clear(); @@ -94,6 +328,9 @@ std::unordered_map ProcStats::proc_temp_to_perm_.clear(); ProcStats::proc_perm_to_temp_.clear(); next_elm_ = 1; + ProcStats::proc_move_list_.clear(); + ProcStats::proc_phase_runs_LB_.clear(); + delete ProcStats::proc_reader_; ProcStats::proc_reader_ = nullptr; /* reset the pointer: readRestartInfo() tests it against nullptr before dereferencing, and a repeated clearStats() must not delete the same object twice */ } /*static*/ void ProcStats::startIterCleanup() { diff --git a/src/vt/vrt/collection/balance/proc_stats.h b/src/vt/vrt/collection/balance/proc_stats.h index af27c0d1ba..99a8979f0e 100644 --- a/src/vt/vrt/collection/balance/proc_stats.h +++ b/src/vt/vrt/collection/balance/proc_stats.h @@ -64,7 +64,7 @@ struct StatsMapLB; namespace vt { namespace vrt { namespace collection { namespace balance { struct LBManager; -struct StatsLBReader; +struct StatsRestartReader; struct ProcStats { using MigrateFnType = std::function; @@ -82,6 +82,8 @@ struct ProcStats { static void outputStatsFile(); + static void readRestartInfo(); + private: static void createStatsFile(); static void closeStatsFile(); @@ -102,6 +104,7 @@ struct ProcStats { static FILE* stats_file_; static bool created_dir_; +#if backend_check_enabled(lblite) /// \brief Queue of migrations for each iteration. /// \note At each iteration, a vector of length 2 times (# of migrations) /// is specified. The vector contains the "permanent" ID of the element @@ -112,9 +115,13 @@ struct ProcStats { /// map migrates elements for a specific iteration.
static std::vector< bool > proc_phase_runs_LB_; - friend struct lb::StatsMapLB; - friend struct balance::StatsLBReader; + /// \brief Private object to migrate information from a (restart) input file + static StatsRestartReader *proc_reader_; +#endif + friend struct balance::LBManager; + friend struct balance::StatsRestartReader; + friend struct lb::StatsMapLB; }; diff --git a/src/vt/vrt/collection/balance/stats_lb_reader.h b/src/vt/vrt/collection/balance/proc_stats.util.h similarity index 64% rename from src/vt/vrt/collection/balance/stats_lb_reader.h rename to src/vt/vrt/collection/balance/proc_stats.util.h index 279804497e..b337aa7602 100644 --- a/src/vt/vrt/collection/balance/stats_lb_reader.h +++ b/src/vt/vrt/collection/balance/proc_stats.util.h @@ -2,7 +2,7 @@ //@HEADER // ***************************************************************************** // -// stats_lb_reader.h +// proc_stats.util.h // DARMA Toolkit v. 1.0.0 // DARMA/vt => Virtual Transport // @@ -42,71 +42,54 @@ //@HEADER */ -#if !defined INCLUDED_VRT_COLLECTION_BALANCE_STATS_LB_READER_H -#define INCLUDED_VRT_COLLECTION_BALANCE_STATS_LB_READER_H +#if !defined INCLUDED_VRT_COLLECTION_BALANCE_PROC_STATS_UTIL_H +#define INCLUDED_VRT_COLLECTION_BALANCE_PROC_STATS_UTIL_H -#include "vt/config.h" -#include "vt/vrt/collection/balance/lb_common.h" -#include "vt/objgroup/headers.h" - -#include -#include #include -#include #include -#include -#include #include - -namespace vt { namespace vrt { namespace collection { namespace lb { - -template -struct TransferMsg; - -using VecMsg = TransferMsg< std::vector< balance::ElementIDType > >; - -} } } } - +#include "vt/config.h" +#include "vt/vrt/collection/balance/baselb/baselb_msgs.h" namespace vt { namespace vrt { namespace collection { namespace balance { -struct StatsLBReader { +struct StatsRestartReader { /* Reads this node's LB stats file on restart and exchanges the resulting migration lists between nodes through an objgroup proxy; ProcStats keeps the instance in proc_reader_. */ + /// \brief Vector counting the received messages per iteration + /// \note Only node 0 will use this vector.
+ std::vector msgsReceived = {}; -public: - StatsLBReader() = default; - StatsLBReader(StatsLBReader const&) = delete; - StatsLBReader(StatsLBReader&&) = default; + /// \brief Queue for storing all the migrations per iteration. + /// \note Only node 0 will use this queue. + std::deque>> + totalMove = {}; - static void init(); - static void destroy(); - static void clearStats(); - static void inputStatsFile(); - static void loadPhaseChangedMap(); + /// \brief Proxy for communicating the migration information + objgroup::proxy::Proxy proxy = {}; public: + StatsRestartReader() = default; - /// \brief Queue to store a map of elements specified by input file. - static std::deque< std::set > user_specified_map_; + ~StatsRestartReader(); /* Destroys the objgroup behind `proxy` if one is set. NOTE(review): the destructor releases a resource but the copy operations are left implicit (the replaced StatsLBReader deleted its copy constructor) -- confirm no copies are ever made. */ -protected: + static void readStats(); /* Reads the stats file and starts the exchange of migration info; works on ProcStats::proc_reader_. */ - void doSend(lb::VecMsg *msg); - void scatterSend(lb::VecMsg *msg); +private: + static void inputStatsFile( + std::deque > &element_history + ); /* Fills element_history with one set of element IDs per phase found in the file. */ - /// \brief Vector counting the received messages per iteration - /// \note Only node 0 will use this vector. - static std::vector msgsReceived; + static void createMigrationInfo( + std::deque > &element_history + ); /* Sends, for each pair of consecutive phases, this node's arriving/leaving elements to proxy index 0. */ - /// \brief Queue for storing all the migrations per iteration. - /// \note Only node 0 will use this queue. - static std::deque>> - totalMove; + using VecMsg = lb::TransferMsg >; + void gatherMsgs(VecMsg *msg); /* Merges the per-node lists of a phase and, once all nodes have reported, hands every node its moves (presumably the handler for the sends to proxy index 0 -- handler names are not visible here). */ - static objgroup::proxy::Proxy proxy_; + void scatterMsgs(VecMsg *msg); /* Stores a received per-node move list into ProcStats (presumably the handler for the per-node sends made by gatherMsgs -- confirm). */ }; -}}}} /* end namespace vt::vrt::collection::balance */ +}}}} -#endif /*INCLUDED_VRT_COLLECTION_BALANCE_STATS_LB_READER_H*/ +#endif diff --git a/src/vt/vrt/collection/balance/stats_lb_reader.cc b/src/vt/vrt/collection/balance/stats_lb_reader.cc deleted file mode 100644 index f582e52e5c..0000000000 --- a/src/vt/vrt/collection/balance/stats_lb_reader.cc +++ /dev/null @@ -1,277 +0,0 @@ -/* -//@HEADER -// ***************************************************************************** -// -// stats_lb_reader.cc -// DARMA Toolkit v.
1.0.0 -// DARMA/vt => Virtual Transport -// -// Copyright 2019 National Technology & Engineering Solutions of Sandia, LLC -// (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the U.S. -// Government retains certain rights in this software. -// -// Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are met: -// -// * Redistributions of source code must retain the above copyright notice, -// this list of conditions and the following disclaimer. -// -// * Redistributions in binary form must reproduce the above copyright notice, -// this list of conditions and the following disclaimer in the documentation -// and/or other materials provided with the distribution. -// -// * Neither the name of the copyright holder nor the names of its -// contributors may be used to endorse or promote products derived from this -// software without specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE -// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -// POSSIBILITY OF SUCH DAMAGE. -// -// Questions? 
Contact darma@sandia.gov -// -// ***************************************************************************** -//@HEADER -*/ - -#include "vt/config.h" -#include "vt/vrt/collection/balance/stats_lb_reader.h" -#include "vt/vrt/collection/balance/baselb/baselb.h" -#include "vt/vrt/collection/manager.h" -#include "vt/timing/timing.h" -#include "vt/configs/arguments/args.h" -#include "vt/runtime/runtime.h" - -#include -#include -#include -#include - -#include "fmt/format.h" - -namespace vt { namespace vrt { namespace collection { namespace balance { - -/*static*/ -std::deque> StatsLBReader::user_specified_map_ = {}; - -/*static*/ -std::vector StatsLBReader::msgsReceived = {}; - -/*static*/ -std::deque>> - StatsLBReader::totalMove = {}; - -/*static*/ objgroup::proxy::Proxy StatsLBReader::proxy_ = {}; - -/*static*/ void StatsLBReader::init() { - // Create the new class dedicated to the input reader - StatsLBReader::proxy_ = theObjGroup()->makeCollective(); - StatsLBReader::inputStatsFile(); - StatsLBReader::loadPhaseChangedMap(); -} - -/*static*/ void StatsLBReader::destroy() { - theObjGroup()->destroyCollective(StatsLBReader::proxy_); -} - -/*static*/ void StatsLBReader::clearStats() { - StatsLBReader::user_specified_map_.clear(); -} - -/*static*/ void StatsLBReader::inputStatsFile() { - - using ArgType = vt::arguments::ArgConfig; - - auto const node = theContext()->getNode(); - auto const base_file = std::string(ArgType::vt_lb_stats_file_in); - auto const dir = std::string(ArgType::vt_lb_stats_dir_in); - auto const file = fmt::format("{}.{}.out", base_file, node); - auto const file_name = fmt::format("{}/{}", dir, file); - - vt_print(lb, "inputStatFile: file={}, iter={}\n", file_name, 0); - - std::FILE *pFile = std::fopen(file_name.c_str(), "r"); - if (pFile == nullptr) { - vtAssert(pFile, "File opening failed"); - } - - std::set buffer; - - // Load: Format of a line :size_t,ElementIDType,TimeType - size_t c1; - ElementIDType c2; - TimeType c3; - CommBytesType c4; - 
using E = typename std::underlying_type::type; - E c5; - char separator; - fpos_t pos; - bool finished = false; - size_t c1PreviousValue = 0; - while (!finished) { - if (fscanf(pFile, "%zu %c %lli %c %lf", &c1, &separator, &c2, - &separator, &c3) > 0) { - fgetpos (pFile,&pos); - fscanf (pFile, "%c", &separator); - if (separator == ',') { - // COM detected, read the end of line and do nothing else - int res = fscanf (pFile, "%lf %c %hhi", &c4, &separator, &c5); - vtAssertExpr(res == 3); - } else { - // Load detected, create the new element - fsetpos (pFile,&pos); - if (c1PreviousValue != c1) { - c1PreviousValue = c1; - StatsLBReader::user_specified_map_.push_back(buffer); - buffer.clear(); - } - buffer.insert(c2); - } - } else { - finished = true; - } - } - - if (!buffer.empty()) { - StatsLBReader::user_specified_map_.push_back(buffer); - } - - std::fclose(pFile); -} - -/*static*/ void StatsLBReader::loadPhaseChangedMap() { - const auto num_iters = StatsLBReader::user_specified_map_.size() - 1; - vt_print(lb, "StatsLBReader::loadPhaseChangedMap size : {}\n", num_iters); - - balance::ProcStats::proc_move_list_.resize(num_iters + 1); - balance::ProcStats::proc_phase_runs_LB_.resize(num_iters, true); - - const auto myNodeID = static_cast(theContext()->getNode()); - if (myNodeID == 0) { - StatsLBReader::msgsReceived.resize(num_iters, 0); - StatsLBReader::totalMove.resize(num_iters); - } - - for (size_t ii = 0; ii < num_iters; ++ii) { - auto elms = StatsLBReader::user_specified_map_[ii]; - auto elmsNext = StatsLBReader::user_specified_map_[ii + 1]; - std::set diff; - std::set_difference(elmsNext.begin(), elmsNext.end(), elms.begin(), - elms.end(), std::inserter(diff, diff.begin())); - const size_t qi = diff.size(); - const size_t pi = elms.size() - (elmsNext.size() - qi); - auto &myList = ProcStats::proc_move_list_[ii]; - myList.reserve(3 * (pi + qi) + 1); - //--- Store the iteration number - myList.push_back(static_cast(ii)); - //--- Store partial migration information 
(i.e. nodes moving in) - for (auto iEle : diff) { - myList.push_back(iEle); //--- permID to receive - myList.push_back(no_element_id); // node moving from - myList.push_back(myNodeID); // node moving to - } - diff.clear(); - //--- Store partial migration information (i.e. nodes moving out) - std::set_difference(elms.begin(), elms.end(), elmsNext.begin(), - elmsNext.end(), std::inserter(diff, diff.begin())); - for (auto iEle : diff) { - myList.push_back(iEle); //--- permID to send - myList.push_back(myNodeID); // node migrating from - myList.push_back(no_element_id); // node migrating to - } - // - // Create a message storing the vector - // - auto msg = makeSharedMessage(myList); - StatsLBReader::proxy_[0].send(msg); - } - -} - -void StatsLBReader::doSend(lb::VecMsg *msg) { - auto sendVec = msg->getTransfer(); - const ElementIDType phaseID = sendVec[0]; - // - // --- Combine the different pieces of information - // - StatsLBReader::msgsReceived[phaseID] += 1; - auto &migrate = StatsLBReader::totalMove[phaseID]; - for (size_t ii = 1; ii < sendVec.size(); ii += 3) { - const auto permID = sendVec[ii]; - const auto nodeFrom = static_cast(sendVec[ii+1]); - const auto nodeTo = static_cast(sendVec[ii+2]); - auto iptr = migrate.find(permID); - if (iptr == migrate.end()) { - migrate.insert(std::make_pair(permID, std::make_pair(nodeFrom, nodeTo))); - } - else { - auto &nodePair = iptr->second; - nodePair.first = std::max(nodePair.first, nodeFrom); - nodePair.second = std::max(nodePair.second, nodeTo); - } - } - // - // --- Check whether all the messages have been received - // - const NodeType numNodes = theContext()->getNumNodes(); - if (StatsLBReader::msgsReceived[phaseID] < numNodes) - return; - // - //--- Distribute the information when everything has been received - // - for (NodeType in = 0; in < numNodes; ++in) { - size_t iCount = 0; - for (auto iNode : migrate) { - if (iNode.second.first == in) - iCount += 1; - } - const size_t header = 2; - std::vector toMove(2 * 
iCount + header); - iCount = 0; - toMove[iCount++] = phaseID; - toMove[iCount++] = static_cast(migrate.size()); - for (auto iNode : migrate) { - if (iNode.second.first == in) { - toMove[iCount++] = iNode.first; - toMove[iCount++] = static_cast(iNode.second.second); - } - } - if (in > 0) { - auto msg2 = makeSharedMessage(toMove); - StatsLBReader::proxy_[in].send - (msg2); - } else { - ProcStats::proc_phase_runs_LB_[phaseID] = (migrate.size() > 0); - auto &myList = ProcStats::proc_move_list_[phaseID]; - myList.resize(toMove.size() - header); - std::copy(&toMove[header], &toMove[0] + toMove.size(), - myList.begin()); - } - } - migrate.clear(); -} - -void StatsLBReader::scatterSend(lb::VecMsg *msg) { - const size_t header = 2; - auto recvVec = msg->getTransfer(); - const ElementIDType phaseID = recvVec[0]; - ProcStats::proc_phase_runs_LB_[phaseID] = static_cast(recvVec[1] > 0); - auto &myList = ProcStats::proc_move_list_[phaseID]; - if (recvVec.size() <= header) { - myList.clear(); - return; - } - // - myList.resize(recvVec.size() - header); - std::copy(&recvVec[header], &recvVec[0]+recvVec.size(), myList.begin()); -} - -}}}} /* end namespace vt::vrt::collection::balance */ diff --git a/src/vt/vrt/collection/balance/statsmaplb/statsmaplb.cc b/src/vt/vrt/collection/balance/statsmaplb/statsmaplb.cc index 4f060017ce..6d3dbd3928 100644 --- a/src/vt/vrt/collection/balance/statsmaplb/statsmaplb.cc +++ b/src/vt/vrt/collection/balance/statsmaplb/statsmaplb.cc @@ -44,13 +44,10 @@ #include "vt/config.h" #include "vt/vrt/collection/balance/baselb/baselb.h" +#include "vt/vrt/collection/balance/proc_stats.h" #include "vt/vrt/collection/balance/statsmaplb/statsmaplb.h" -#include "vt/vrt/collection/balance/stats_lb_reader.h" #include "vt/context/context.h" - -#include - namespace vt { namespace vrt { namespace collection { namespace lb { void StatsMapLB::init(objgroup::proxy::Proxy in_proxy) { @@ -77,7 +74,7 @@ void StatsMapLB::runLB() { finishMigrationCollective(); 
myNewList.clear(); - + } }}}} /* end namespace vt::vrt::collection::lb */