diff --git a/scripts/JSON_data_files_validator.py b/scripts/JSON_data_files_validator.py index 63daa81e21..bacf268dde 100644 --- a/scripts/JSON_data_files_validator.py +++ b/scripts/JSON_data_files_validator.py @@ -115,7 +115,8 @@ def _get_valid_schema(self) -> Schema: }, 'bytes': float } - ] + ], + Optional('user_defined'): dict }, ] } diff --git a/src/vt/phase/phase_manager.h b/src/vt/phase/phase_manager.h index 62f209a100..4e8591eff8 100644 --- a/src/vt/phase/phase_manager.h +++ b/src/vt/phase/phase_manager.h @@ -49,6 +49,7 @@ #include "vt/phase/phase_hook_enum.h" #include "vt/phase/phase_hook_id.h" #include "vt/vrt/collection/balance/lb_invoke/phase_info.h" +#include "vt/vrt/collection/balance/node_lb_data.h" #include #include @@ -169,6 +170,36 @@ struct PhaseManager : runtime::component::Component { */ void setStartTime(); + +template +struct is_jsonable : std::false_type {}; + +template +struct is_jsonable()))>> : std::true_type {}; + +template +void addUserDefinedData(PhaseType phase, const KeyT& key, const ValueT& value) { + static_assert( + is_jsonable::value, + "PhaseManager::addUserDefinedData: KeyT type is not jsonable" + ); + static_assert( + is_jsonable::value, + "PhaseManager::addUserDefinedData: ValueT type is not jsonable" + ); + + auto perPhase = theNodeLBData()->getLBData()->user_per_phase_json_; + + if (perPhase.find(phase) == perPhase.end()) { + auto j = std::make_shared(); + j->emplace(key, value); + + theNodeLBData()->getLBData()->user_per_phase_json_[phase] = j; + } else { + perPhase[phase]->emplace(key, value); + } +} + /** * \brief Print summary for the current phase. * diff --git a/src/vt/vrt/collection/balance/lb_data_holder.cc b/src/vt/vrt/collection/balance/lb_data_holder.cc index 3f0da67f2a..d679c70fdb 100644 --- a/src/vt/vrt/collection/balance/lb_data_holder.cc +++ b/src/vt/vrt/collection/balance/lb_data_holder.cc @@ -218,6 +218,14 @@ std::unique_ptr LBDataHolder::toJson(PhaseType phase) const { } } + if (user_per_phase_json_.find(phase) != user_per_phase_json_.end()) { + auto& user_def_this_phase = user_per_phase_json_.at(phase); + + if (!user_def_this_phase->empty()) { + j["user_defined"] = *user_def_this_phase; + } + } + return std::make_unique(std::move(j)); } @@ -362,6 +370,12 @@ LBDataHolder::LBDataHolder(nlohmann::json const& j) } } } + + if (phase.find("user_defined") != phase.end()) { + auto userDefined = phase["user_defined"]; + user_per_phase_json_[phase] = std::make_shared(); + *(user_per_phase_json_[phase]) = userDefined; + } } } diff --git a/src/vt/vrt/collection/balance/lb_data_holder.h b/src/vt/vrt/collection/balance/lb_data_holder.h index f12a614426..3a69999d98 100644 --- a/src/vt/vrt/collection/balance/lb_data_holder.h +++ b/src/vt/vrt/collection/balance/lb_data_holder.h @@ -79,6 +79,7 @@ struct LBDataHolder { s | node_comm_; s | node_subphase_comm_; s | user_defined_json_; + s | user_per_phase_json_; s | node_idx_; s | count_; s | skipped_phases_; @@ -134,6 +135,8 @@ struct LBDataHolder { std::unordered_map >> user_defined_json_; + + std::unordered_map> user_per_phase_json_; /// User-defined data from each phase for LB std::unordered_map user_defined_lb_info_; /// Node indices for each ID along with the proxy ID diff --git a/tests/unit/collection/test_lb.extended.cc b/tests/unit/collection/test_lb.extended.cc index 0d01f82dca..d9d0dd199c 100644 --- a/tests/unit/collection/test_lb.extended.cc +++ b/tests/unit/collection/test_lb.extended.cc @@ -408,6 +408,12 @@ TEST_P(TestNodeLBDataDumper, test_node_lb_data_dumping_with_interval) { proxy.broadcastCollective(); }); + vt::thePhase()->addUserDefinedData( + phase, std::string{"time"}, static_cast(phase)); + + vt::thePhase()->addUserDefinedData( + phase, std::string{"new_time"}, static_cast(phase)); + // Go to the next phase vt::thePhase()->nextPhaseCollective(); } @@ -424,6 +430,14 @@ TEST_P(TestNodeLBDataDumper, test_node_lb_data_dumping_with_interval) { EXPECT_TRUE(json.find("phases") != json.end()); EXPECT_EQ(json["phases"].size(), num_phases); + for(const auto& phase : json["phases"]){ + EXPECT_TRUE(phase.find("user_defined") != phase.end()); + EXPECT_TRUE(phase["user_defined"].contains("time")); + EXPECT_EQ(phase["user_defined"]["time"], static_cast(phase["id"])); + EXPECT_TRUE(phase["user_defined"].contains("new_time")); + EXPECT_EQ(phase["user_defined"]["new_time"], static_cast(phase["id"])); + } + }); if (vt::theContext()->getNode() == 0) {