Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

1395 Fix GreedyLB Bug #1470

Merged
merged 8 commits into from
Aug 23, 2021
94 changes: 66 additions & 28 deletions src/vt/vrt/collection/balance/greedylb/greedylb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
#include "vt/context/context.h"
#include "vt/vrt/collection/manager.h"
#include "vt/collective/reduce/reduce.h"
#include "vt/vrt/collection/balance/lb_args_enum_converter.h"

#include <unordered_map>
#include <memory>
Expand All @@ -72,11 +73,20 @@ void GreedyLB::init(objgroup::proxy::Proxy<GreedyLB> in_proxy) {
}

void GreedyLB::inputParams(balance::SpecEntry* spec) {
std::vector<std::string> allowed{"min", "max", "auto"};
std::vector<std::string> allowed{"min", "max", "auto", "strategy"};
spec->checkAllowedKeys(allowed);
min_threshold = spec->getOrDefault<double>("min", greedy_threshold_p);
max_threshold = spec->getOrDefault<double>("max", greedy_max_threshold_p);
auto_threshold = spec->getOrDefault<bool>("auto", greedy_auto_threshold_p);

balance::LBArgsEnumConverter<DataDistStrategy> strategy_converter_(
"strategy", "DataDistStrategy", {
{DataDistStrategy::scatter, "scatter"},
{DataDistStrategy::pt2pt, "pt2pt"},
{DataDistStrategy::bcast, "bcast"}
}
);
strat_ = strategy_converter_.getFromSpec(spec, strat_);
}

void GreedyLB::runLB() {
Expand Down Expand Up @@ -214,24 +224,39 @@ GreedyLB::ObjIDType GreedyLB::objSetNode(
return new_id;
}

void GreedyLB::recvObjsDirect(GreedyLBTypes::ObjIDType* objs) {
/**
 * \brief Receive the list of objects sent to this node in a \c GreedySendMsg
 * and hand it to \c recvObjsDirect.
 *
 * \param[in] msg message whose \c transfer_ vector holds the object IDs
 */
void GreedyLB::recvObjs(GreedySendMsg* msg) {
  auto& transfer = msg->transfer_;
  vt_debug_print(
    normal, lb,
    "recvObjs: msg->transfer_.size={}\n", transfer.size()
  );
  // Use data() rather than &transfer[0]: indexing element 0 of an empty
  // vector is undefined behavior, while data() is valid for any size. The
  // pointer is never dereferenced by recvObjsDirect when the length is zero.
  recvObjsDirect(transfer.size(), transfer.data());
}

/**
 * \brief Receive a broadcast \c GreedyBcastMsg, select the entry of
 * \c transfer_ indexed by this node, and hand it to \c recvObjsDirect.
 *
 * \param[in] msg message whose \c transfer_ holds one object list per node
 */
void GreedyLB::recvObjsBcast(GreedyBcastMsg* msg) {
  auto const n = theContext()->getNode();
  // Only the slice addressed to this node is processed.
  auto& local_transfer = msg->transfer_[n];
  vt_debug_print(
    normal, lb,
    "recvObjs: msg->transfer_.size={}\n", local_transfer.size()
  );
  // Use data() rather than &local_transfer[0]: indexing element 0 of an empty
  // vector is undefined behavior, while data() is valid for any size.
  recvObjsDirect(local_transfer.size(), local_transfer.data());
}

void GreedyLB::recvObjsDirect(std::size_t len, GreedyLBTypes::ObjIDType* objs) {
auto const& this_node = theContext()->getNode();
auto const& num_recs = *objs;
auto recs = objs + 1;
auto const& num_recs = len;
vt_debug_print(
normal, lb,
"recvObjsDirect: num_recs={}\n", num_recs
);

for (decltype(+num_recs.id) i = 0; i < num_recs.id; i++) {
auto const to_node = objGetNode(recs[i]);
auto const new_obj_id = objSetNode(this_node,recs[i]);
for (std::size_t i = 0; i < len; i++) {
auto const to_node = objGetNode(objs[i]);
auto const new_obj_id = objSetNode(this_node,objs[i]);
vt_debug_print(
verbose, lb,
"\t recvObjs: i={}, to_node={}, obj={}, new_obj_id={}, num_recs={}, "
"byte_offset={}\n",
i, to_node, recs[i], new_obj_id, num_recs,
reinterpret_cast<char*>(recs) - reinterpret_cast<char*>(objs)
"\t recvObjs: i={}, to_node={}, obj={}, new_obj_id={}, num_recs={}"
"\n",
i, to_node, objs[i], new_obj_id, num_recs
);

migrateObjectTo(new_obj_id, to_node);
Expand All @@ -243,11 +268,11 @@ void GreedyLB::recvObjsDirect(GreedyLBTypes::ObjIDType* objs) {
verbose, lb,
"recvObjsHan: num_recs={}\n", *objs
);
scatter_proxy.get()->recvObjsDirect(objs);
scatter_proxy.get()->recvObjsDirect(static_cast<std::size_t>(objs->id), objs+1);
}

void GreedyLB::transferObjs(std::vector<GreedyProc>&& in_load) {
std::size_t max_recs = 0, max_bytes = 0;
std::size_t max_recs = 1;
std::vector<GreedyProc> load(std::move(in_load));
std::vector<std::vector<GreedyLBTypes::ObjIDType>> node_transfer(load.size());
for (auto&& elm : load) {
Expand All @@ -263,23 +288,36 @@ void GreedyLB::transferObjs(std::vector<GreedyProc>&& in_load) {
}
}
}
max_bytes = max_recs * sizeof(GreedyLBTypes::ObjIDType);
vt_debug_print(
normal, lb,
"GreedyLB::transferObjs: max_recs={}, max_bytes={}\n",
max_recs, max_bytes
);
theCollective()->scatter<GreedyLBTypes::ObjIDType,recvObjsHan>(
max_bytes*load.size(),max_bytes,nullptr,[&](NodeType node, void* ptr){
auto ptr_out = reinterpret_cast<GreedyLBTypes::ObjIDType*>(ptr);
auto const& proc = node_transfer[node];
auto const& rec_size = proc.size();
ptr_out->id = rec_size;
for (size_t i = 0; i < rec_size; i++) {
*(ptr_out + i + 1) = proc[i];

if (strat_ == DataDistStrategy::scatter) {
std::size_t max_bytes = max_recs * sizeof(GreedyLBTypes::ObjIDType);
vt_debug_print(
normal, lb,
"GreedyLB::transferObjs: max_recs={}, max_bytes={}\n",
max_recs, max_bytes
);
theCollective()->scatter<GreedyLBTypes::ObjIDType,recvObjsHan>(
max_bytes*load.size(),max_bytes,nullptr,[&](NodeType node, void* ptr){
auto ptr_out = reinterpret_cast<GreedyLBTypes::ObjIDType*>(ptr);
auto const& proc = node_transfer[node];
auto const& rec_size = proc.size();
ptr_out->id = rec_size;
for (size_t i = 0; i < rec_size; i++) {
*(ptr_out + i + 1) = proc[i];
}
}
);
} else if (strat_ == DataDistStrategy::pt2pt) {
for (NodeType n = 0; n < theContext()->getNumNodes(); n++) {
vtAssert(
node_transfer.size() == static_cast<size_t>(theContext()->getNumNodes()),
"Must contain all nodes"
);
proxy[n].send<GreedySendMsg, &GreedyLB::recvObjs>(node_transfer[n]);
}
);
} else if (strat_ == DataDistStrategy::bcast) {
proxy.broadcast<GreedyBcastMsg, &GreedyLB::recvObjsBcast>(node_transfer);
}
}

double GreedyLB::getAvgLoad() const {
Expand Down
31 changes: 30 additions & 1 deletion src/vt/vrt/collection/balance/greedylb/greedylb.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,17 @@

namespace vt { namespace vrt { namespace collection { namespace lb {

/**
 * \enum DataDistStrategy
 *
 * \brief Selects the communication pattern used to deliver the centralized
 * LB's migration decisions to the nodes.
 */
enum class DataDistStrategy : uint8_t {
  /// Distribute the per-node object lists with a collective scatter
  scatter = 0,
  /// Broadcast the lists for all nodes; each node reads its own entry
  bcast = 1,
  /// Send each node its own list in a separate message
  pt2pt = 2
};

struct GreedyLB : BaseLB {
using ElementLoadType = std::unordered_map<ObjIDType,TimeType>;
using TransferType = std::map<NodeType, std::vector<ObjIDType>>;
Expand All @@ -86,7 +97,9 @@ struct GreedyLB : BaseLB {
void runBalancer(ObjSampleType&& objs, LoadProfileType&& profile);
void transferObjs(std::vector<GreedyProc>&& load);
ObjIDType objSetNode(NodeType const& node, ObjIDType const& id);
void recvObjsDirect(GreedyLBTypes::ObjIDType* objs);
void recvObjsDirect(std::size_t len, GreedyLBTypes::ObjIDType* objs);
void recvObjs(GreedySendMsg* msg);
void recvObjsBcast(GreedyBcastMsg* msg);
void finishedTransferExchange();
void collectHandler(GreedyCollectMsg* msg);

Expand All @@ -105,8 +118,24 @@ struct GreedyLB : BaseLB {
double max_threshold = 0.0f;
double min_threshold = 0.0f;
bool auto_threshold = true;

DataDistStrategy strat_ = DataDistStrategy::scatter;
};

}}}} /* end namespace vt::vrt::collection::lb */

namespace std {

/// Hash support for DataDistStrategy: the enumerator is converted to its
/// underlying integer type and hashed with the standard hash for that type.
template <>
struct hash<::vt::vrt::collection::lb::DataDistStrategy> {
  size_t operator()(::vt::vrt::collection::lb::DataDistStrategy const& in) const {
    using StrategyType = ::vt::vrt::collection::lb::DataDistStrategy;
    using UnderlyingType = std::underlying_type<StrategyType>::type;
    auto const raw_value = static_cast<UnderlyingType>(in);
    std::hash<UnderlyingType> const hasher{};
    return hasher(raw_value);
  }
};

} /* end namespace std */

#endif /*INCLUDED_VT_VRT_COLLECTION_BALANCE_GREEDYLB_GREEDYLB_H*/
38 changes: 38 additions & 0 deletions src/vt/vrt/collection/balance/greedylb/greedylb_msgs.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,44 @@ struct GreedyCollectMsg : GreedyLBTypes, collective::ReduceTMsg<GreedyPayload> {
}
};

/**
 * \struct GreedySendMsg
 *
 * \brief Message that carries a single list of object IDs in \c transfer_.
 *
 * It is the message type handled by \c GreedyLB::recvObjs.
 */
struct GreedySendMsg : GreedyLBTypes, vt::Message {
using MessageParentType = vt::Message;
vt_msg_serialize_required(); // required because of the std::vector member

GreedySendMsg() = default;
/// Construct the message with a copy of the object list \c in
explicit GreedySendMsg(std::vector<GreedyLBTypes::ObjIDType> const& in)
: transfer_(in)
{ }

/// Serialize the parent message first, then the object list
template <typename SerializerT>
void serialize(SerializerT& s) {
MessageParentType::serialize(s);
s | transfer_;
}

/// Object IDs carried by this message
std::vector<GreedyLBTypes::ObjIDType> transfer_;
};

/**
 * \struct GreedyBcastMsg
 *
 * \brief Message that carries one list of object IDs per entry of
 * \c transfer_.
 *
 * It is the message type handled by \c GreedyLB::recvObjsBcast, which indexes
 * \c transfer_ by the receiving node.
 */
struct GreedyBcastMsg : GreedyLBTypes, vt::Message {
using MessageParentType = vt::Message;
vt_msg_serialize_required(); // required because of the std::vector member

/// Outer vector: one entry per node; inner vector: that entry's object IDs
using DataType = std::vector<std::vector<GreedyLBTypes::ObjIDType>>;

GreedyBcastMsg() = default;
/// Construct the message with a copy of the full set of lists \c in
explicit GreedyBcastMsg(DataType const& in)
: transfer_(in)
{ }

/// Serialize the parent message first, then the nested object lists
template <typename SerializerT>
void serialize(SerializerT& s) {
MessageParentType::serialize(s);
s | transfer_;
}

/// Object lists carried by this message, indexed by node on receipt
DataType transfer_;
};

}}}} /* end namespace vt::vrt::collection::lb */

#endif /*INCLUDED_VT_VRT_COLLECTION_BALANCE_GREEDYLB_GREEDYLB_MSGS_H*/
10 changes: 9 additions & 1 deletion tests/unit/collection/test_lb.extended.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,12 @@ void TestLoadBalancer::runTest() {
fmt::print("Using lb_args {}\n", lb_args);
}
}
if (lb_name.substr(0, 8).compare("GreedyLB") == 0) {
vt::theConfig()->vt_lb_name = "GreedyLB";
auto strat_arg = lb_name.substr(9, lb_name.size() - 9);
fmt::print("strat_arg={}\n", strat_arg);
vt::theConfig()->vt_lb_args = strat_arg;
}

vt::theCollective()->barrier();

Expand Down Expand Up @@ -161,7 +167,9 @@ auto balancers = ::testing::Values(
"RotateLB",
"HierarchicalLB",
"TemperedLB",
"GreedyLB"
"GreedyLB:strategy=scatter",
"GreedyLB:strategy=pt2pt",
"GreedyLB:strategy=bcast"
# if vt_check_enabled(zoltan)
, "ZoltanLB"
# endif
Expand Down