diff --git a/src/vt/vrt/collection/balance/greedylb/greedylb.cc b/src/vt/vrt/collection/balance/greedylb/greedylb.cc index 50f5efd0f0..7ae604ab1b 100644 --- a/src/vt/vrt/collection/balance/greedylb/greedylb.cc +++ b/src/vt/vrt/collection/balance/greedylb/greedylb.cc @@ -56,6 +56,7 @@ #include "vt/context/context.h" #include "vt/vrt/collection/manager.h" #include "vt/collective/reduce/reduce.h" +#include "vt/vrt/collection/balance/lb_args_enum_converter.h" #include #include @@ -72,11 +73,20 @@ void GreedyLB::init(objgroup::proxy::Proxy in_proxy) { } void GreedyLB::inputParams(balance::SpecEntry* spec) { - std::vector allowed{"min", "max", "auto"}; + std::vector allowed{"min", "max", "auto", "strategy"}; spec->checkAllowedKeys(allowed); min_threshold = spec->getOrDefault("min", greedy_threshold_p); max_threshold = spec->getOrDefault("max", greedy_max_threshold_p); auto_threshold = spec->getOrDefault("auto", greedy_auto_threshold_p); + + balance::LBArgsEnumConverter strategy_converter_( + "strategy", "DataDistStrategy", { + {DataDistStrategy::scatter, "scatter"}, + {DataDistStrategy::pt2pt, "pt2pt"}, + {DataDistStrategy::bcast, "bcast"} + } + ); + strat_ = strategy_converter_.getFromSpec(spec, strat_); } void GreedyLB::runLB() { @@ -214,24 +224,39 @@ GreedyLB::ObjIDType GreedyLB::objSetNode( return new_id; } -void GreedyLB::recvObjsDirect(GreedyLBTypes::ObjIDType* objs) { +void GreedyLB::recvObjs(GreedySendMsg* msg) { + vt_debug_print( + normal, lb, + "recvObjs: msg->transfer_.size={}\n", msg->transfer_.size() + ); + recvObjsDirect(msg->transfer_.size(), &msg->transfer_[0]); +} + +void GreedyLB::recvObjsBcast(GreedyBcastMsg* msg) { + auto const n = theContext()->getNode(); + vt_debug_print( + normal, lb, + "recvObjs: msg->transfer_.size={}\n", msg->transfer_[n].size() + ); + recvObjsDirect(msg->transfer_[n].size(), &msg->transfer_[n][0]); +} + +void GreedyLB::recvObjsDirect(std::size_t len, GreedyLBTypes::ObjIDType* objs) { auto const& this_node = theContext()->getNode(); - auto const& num_recs = *objs; - auto recs = objs + 1; + auto const& num_recs = len; vt_debug_print( normal, lb, "recvObjsDirect: num_recs={}\n", num_recs ); - for (decltype(+num_recs.id) i = 0; i < num_recs.id; i++) { - auto const to_node = objGetNode(recs[i]); - auto const new_obj_id = objSetNode(this_node,recs[i]); + for (std::size_t i = 0; i < len; i++) { + auto const to_node = objGetNode(objs[i]); + auto const new_obj_id = objSetNode(this_node,objs[i]); vt_debug_print( verbose, lb, - "\t recvObjs: i={}, to_node={}, obj={}, new_obj_id={}, num_recs={}, " - "byte_offset={}\n", - i, to_node, recs[i], new_obj_id, num_recs, - reinterpret_cast(recs) - reinterpret_cast(objs) + "\t recvObjs: i={}, to_node={}, obj={}, new_obj_id={}, num_recs={}" + "\n", + i, to_node, objs[i], new_obj_id, num_recs ); migrateObjectTo(new_obj_id, to_node); @@ -243,11 +268,11 @@ void GreedyLB::recvObjsDirect(GreedyLBTypes::ObjIDType* objs) { verbose, lb, "recvObjsHan: num_recs={}\n", *objs ); - scatter_proxy.get()->recvObjsDirect(objs); + scatter_proxy.get()->recvObjsDirect(static_cast(objs->id), objs+1); } void GreedyLB::transferObjs(std::vector&& in_load) { - std::size_t max_recs = 0, max_bytes = 0; + std::size_t max_recs = 1; std::vector load(std::move(in_load)); std::vector> node_transfer(load.size()); for (auto&& elm : load) { @@ -263,23 +288,36 @@ void GreedyLB::transferObjs(std::vector&& in_load) { } } } - max_bytes = max_recs * sizeof(GreedyLBTypes::ObjIDType); - vt_debug_print( - normal, lb, - "GreedyLB::transferObjs: max_recs={}, max_bytes={}\n", - max_recs, max_bytes - ); - theCollective()->scatter( - max_bytes*load.size(),max_bytes,nullptr,[&](NodeType node, void* ptr){ - auto ptr_out = reinterpret_cast(ptr); - auto const& proc = node_transfer[node]; - auto const& rec_size = proc.size(); - ptr_out->id = rec_size; - for (size_t i = 0; i < rec_size; i++) { - *(ptr_out + i + 1) = proc[i]; + + if (strat_ == DataDistStrategy::scatter) { + std::size_t max_bytes = max_recs * sizeof(GreedyLBTypes::ObjIDType); + vt_debug_print( + normal, lb, + "GreedyLB::transferObjs: max_recs={}, max_bytes={}\n", + max_recs, max_bytes + ); + theCollective()->scatter( + max_bytes*load.size(),max_bytes,nullptr,[&](NodeType node, void* ptr){ + auto ptr_out = reinterpret_cast(ptr); + auto const& proc = node_transfer[node]; + auto const& rec_size = proc.size(); + ptr_out->id = rec_size; + for (size_t i = 0; i < rec_size; i++) { + *(ptr_out + i + 1) = proc[i]; + } } + ); + } else if (strat_ == DataDistStrategy::pt2pt) { + for (NodeType n = 0; n < theContext()->getNumNodes(); n++) { + vtAssert( + node_transfer.size() == static_cast(theContext()->getNumNodes()), + "Must contain all nodes" + ); + proxy[n].send(node_transfer[n]); } - ); + } else if (strat_ == DataDistStrategy::bcast) { + proxy.broadcast(node_transfer); + } } double GreedyLB::getAvgLoad() const { diff --git a/src/vt/vrt/collection/balance/greedylb/greedylb.h b/src/vt/vrt/collection/balance/greedylb/greedylb.h index edb3bc0103..614c7ce728 100644 --- a/src/vt/vrt/collection/balance/greedylb/greedylb.h +++ b/src/vt/vrt/collection/balance/greedylb/greedylb.h @@ -60,6 +60,17 @@ namespace vt { namespace vrt { namespace collection { namespace lb { +/** + * \enum DataDistStrategy + * + * \brief How to distribute the data after the centralized LB makes a decision. + */ +enum struct DataDistStrategy : uint8_t { + scatter = 0, + bcast = 1, + pt2pt = 2 +}; + struct GreedyLB : BaseLB { using ElementLoadType = std::unordered_map; using TransferType = std::map>; @@ -86,7 +97,9 @@ struct GreedyLB : BaseLB { void runBalancer(ObjSampleType&& objs, LoadProfileType&& profile); void transferObjs(std::vector&& load); ObjIDType objSetNode(NodeType const& node, ObjIDType const& id); - void recvObjsDirect(GreedyLBTypes::ObjIDType* objs); + void recvObjsDirect(std::size_t len, GreedyLBTypes::ObjIDType* objs); + void recvObjs(GreedySendMsg* msg); + void recvObjsBcast(GreedyBcastMsg* msg); void finishedTransferExchange(); void collectHandler(GreedyCollectMsg* msg); @@ -105,8 +118,24 @@ struct GreedyLB : BaseLB { double max_threshold = 0.0f; double min_threshold = 0.0f; bool auto_threshold = true; + + DataDistStrategy strat_ = DataDistStrategy::scatter; }; }}}} /* end namespace vt::vrt::collection::lb */ +namespace std { + +template <> +struct hash<::vt::vrt::collection::lb::DataDistStrategy> { + size_t operator()(::vt::vrt::collection::lb::DataDistStrategy const& in) const { + using under = std::underlying_type< + ::vt::vrt::collection::lb::DataDistStrategy + >::type; + return std::hash()(static_cast(in)); + } +}; + +} /* end namespace std */ + #endif /*INCLUDED_VT_VRT_COLLECTION_BALANCE_GREEDYLB_GREEDYLB_H*/ diff --git a/src/vt/vrt/collection/balance/greedylb/greedylb_msgs.h b/src/vt/vrt/collection/balance/greedylb/greedylb_msgs.h index 0112a575f9..401545d3f3 100644 --- a/src/vt/vrt/collection/balance/greedylb/greedylb_msgs.h +++ b/src/vt/vrt/collection/balance/greedylb/greedylb_msgs.h @@ -129,6 +129,44 @@ struct GreedyCollectMsg : GreedyLBTypes, collective::ReduceTMsg { } }; +struct GreedySendMsg : GreedyLBTypes, vt::Message { + using MessageParentType = vt::Message; + vt_msg_serialize_required(); // vector + + GreedySendMsg() = default; + explicit GreedySendMsg(std::vector const& in) + : transfer_(in) + { } + + template + void serialize(SerializerT& s) { + MessageParentType::serialize(s); + s | transfer_; + } + + std::vector transfer_; +}; + +struct GreedyBcastMsg : GreedyLBTypes, vt::Message { + using MessageParentType = vt::Message; + vt_msg_serialize_required(); // vector + + using DataType = std::vector>; + + GreedyBcastMsg() = default; + explicit GreedyBcastMsg(DataType const& in) + : transfer_(in) + { } + + template + void serialize(SerializerT& s) { + MessageParentType::serialize(s); + s | transfer_; + } + + DataType transfer_; +}; + }}}} /* end namespace vt::vrt::collection::lb */ #endif /*INCLUDED_VT_VRT_COLLECTION_BALANCE_GREEDYLB_GREEDYLB_MSGS_H*/ diff --git a/tests/unit/collection/test_lb.extended.cc b/tests/unit/collection/test_lb.extended.cc index e2eec57736..7de9b142b4 100644 --- a/tests/unit/collection/test_lb.extended.cc +++ b/tests/unit/collection/test_lb.extended.cc @@ -101,6 +101,12 @@ void TestLoadBalancer::runTest() { fmt::print("Using lb_args {}\n", lb_args); } } + if (lb_name.substr(0, 8).compare("GreedyLB") == 0) { + vt::theConfig()->vt_lb_name = "GreedyLB"; + auto strat_arg = lb_name.substr(9, lb_name.size() - 9); + fmt::print("strat_arg={}\n", strat_arg); + vt::theConfig()->vt_lb_args = strat_arg; + } vt::theCollective()->barrier(); @@ -161,7 +167,9 @@ auto balancers = ::testing::Values( "RotateLB", "HierarchicalLB", "TemperedLB", - "GreedyLB" + "GreedyLB:strategy=scatter", + "GreedyLB:strategy=pt2pt", + "GreedyLB:strategy=bcast" # if vt_check_enabled(zoltan) , "ZoltanLB" # endif