Skip to content

Commit

Permalink
#2070: variant: replace SafeUnion with std::variant
Browse files Browse the repository at this point in the history
  • Loading branch information
lifflander committed Aug 3, 2023
1 parent 23bb0a9 commit da59a8b
Show file tree
Hide file tree
Showing 13 changed files with 70 additions and 1,118 deletions.
2 changes: 1 addition & 1 deletion cmake/trace_only_functions.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ function(create_trace_only_target)
# vt/utils
vt/utils/demangle/demangle.h vt/utils/bits/bits_counter.h
vt/utils/bits/bits_common.h vt/utils/bits/bits_packer.h
vt/utils/bits/bits_packer.impl.h vt/utils/adt/union.h
vt/utils/bits/bits_packer.impl.h
vt/utils/adt/histogram_approx.h

# vt/collective
Expand Down
2 changes: 1 addition & 1 deletion src/vt/collective/reduce/reduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ Reduce::Reduce(
detail::ReduceStamp Reduce::generateNextID() {
*next_seq_ = *next_seq_ + 1;
detail::ReduceStamp stamp;
stamp.init<detail::StrongSeq>(next_seq_);
stamp = detail::StrongSeq{next_seq_};
return stamp;
}

Expand Down
5 changes: 3 additions & 2 deletions src/vt/collective/reduce/reduce.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,10 @@ detail::ReduceStamp Reduce::reduceImmediate(
NodeType root, MsgT* const msg, detail::ReduceStamp id,
ReduceNumType num_contrib
) {
if (scope_.get().is<detail::StrongGroup>()) {
envelopeSetGroup(msg->env, scope_.get().get<detail::StrongGroup>().get());
if (std::holds_alternative<detail::StrongGroup>(scope_.get())) {
envelopeSetGroup(msg->env, std::get<detail::StrongGroup>(scope_.get()).get());
}

auto cur_id = id == detail::ReduceStamp{} ? generateNextID() : id;

auto const han = auto_registry::makeAutoHandler<MsgT,f>();
Expand Down
30 changes: 16 additions & 14 deletions src/vt/collective/reduce/reduce_scope.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@
#define INCLUDED_VT_COLLECTIVE_REDUCE_REDUCE_SCOPE_H

#include "vt/collective/reduce/scoping/strong_types.h"
#include "vt/utils/adt/union.h"

#include <unordered_map>
#include <variant>

namespace vt { namespace collective { namespace reduce { namespace detail {

Expand All @@ -58,7 +58,9 @@ namespace vt { namespace collective { namespace reduce { namespace detail {
* proxy, virtual proxy, group ID, or component ID.
*/
struct ReduceScope {
using ValueType = vt::adt::SafeUnion<
using isByteCopyable = std::true_type;

using ValueType = std::variant<
StrongObjGroup, StrongVrtProxy, StrongGroup, StrongCom, StrongUserID
>;

Expand All @@ -83,16 +85,16 @@ struct ReduceScope {
}

std::string str() const {
if (l0_.is<StrongObjGroup>()) {
return fmt::format("objgroup[{:x}]", l0_.get<StrongObjGroup>().get());
} else if (l0_.is<StrongVrtProxy>()) {
return fmt::format("vrtproxy[{:x}]", l0_.get<StrongVrtProxy>().get());
} else if (l0_.is<StrongGroup>()) {
return fmt::format("group[{:x}]", l0_.get<StrongGroup>().get());
} else if (l0_.is<StrongCom>()) {
return fmt::format("component[{}]", l0_.get<StrongCom>().get());
} else if (l0_.is<StrongUserID>()) {
return fmt::format("userID[{}]", l0_.get<StrongUserID>().get());
if (std::holds_alternative<StrongObjGroup>(l0_)) {
return fmt::format("objgroup[{:x}]", std::get<StrongObjGroup>(l0_).get());
} else if (std::holds_alternative<StrongVrtProxy>(l0_)) {
return fmt::format("vrtproxy[{:x}]", std::get<StrongVrtProxy>(l0_).get());
} else if (std::holds_alternative<StrongGroup>(l0_)) {
return fmt::format("group[{:x}]", std::get<StrongGroup>(l0_).get());
} else if (std::holds_alternative<StrongCom>(l0_)) {
return fmt::format("component[{}]", std::get<StrongCom>(l0_).get());
} else if (std::holds_alternative<StrongUserID>(l0_)) {
return fmt::format("userID[{}]", std::get<StrongUserID>(l0_).get());
} else {
return "<unknown-type>";
}
Expand All @@ -115,7 +117,7 @@ inline ReduceScope makeScope(Args&&... args);
/**
* \brief Reduction stamp bits to identify a specific instance of a reduction.
*/
using ReduceStamp = vt::adt::SafeUnion<
using ReduceStamp = std::variant<
StrongTag, TagPair, StrongSeq, StrongUserID, StrongEpoch
>;

Expand Down Expand Up @@ -251,7 +253,7 @@ using TagPair = detail::TagPair;
template <typename T, typename... Args>
ReduceStamp makeStamp(Args&&... args) {
ReduceStamp stamp;
stamp.init<T>(std::forward<Args>(args)...);
stamp = T{std::forward<Args>(args)...};
return stamp;
}

Expand Down
26 changes: 13 additions & 13 deletions src/vt/collective/reduce/reduce_scope.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,25 +51,25 @@ namespace vt { namespace collective { namespace reduce { namespace detail {
template <typename T, typename... Args>
inline ReduceScope makeScope(Args&&... args) {
ReduceScope scope;
scope.l0_.init<T>(std::forward<Args>(args)...);
scope.l0_ = T{std::forward<Args>(args)...};
return scope;
}

inline std::string stringizeStamp(ReduceStamp const& stamp) {
if (stamp.is<StrongTag>()) {
return fmt::format("tag[{}]", stamp.get<StrongTag>().get());
} else if (stamp.is<TagPair>()) {
if (std::holds_alternative<StrongTag>(stamp)) {
return fmt::format("tag[{}]", std::get<StrongTag>(stamp).get());
} else if (std::holds_alternative<TagPair>(stamp)) {
return fmt::format(
"tagPair[{},{}]",
stamp.get<TagPair>().first(),
stamp.get<TagPair>().second()
std::get<TagPair>(stamp).first(),
std::get<TagPair>(stamp).second()
);
} else if (stamp.is<StrongSeq>()) {
return fmt::format("seq[{}]", stamp.get<StrongSeq>().get());
} else if (stamp.is<StrongUserID>()) {
return fmt::format("userID[{}]", stamp.get<StrongUserID>().get());
} else if (stamp.is<StrongEpoch>()) {
return fmt::format("epoch[{:x}]", stamp.get<StrongEpoch>().get());
} else if (std::holds_alternative<StrongSeq>(stamp)) {
return fmt::format("seq[{}]", std::get<StrongSeq>(stamp).get());
} else if (std::holds_alternative<StrongUserID>(stamp)) {
return fmt::format("userID[{}]", std::get<StrongUserID>(stamp).get());
} else if (std::holds_alternative<StrongEpoch>(stamp)) {
return fmt::format("epoch[{:x}]", std::get<StrongEpoch>(stamp).get());
} else {
return "<no-stamp>";
}
Expand All @@ -94,7 +94,7 @@ T& ReduceScopeHolder<T>::getOnDemand(U&& scope) {
auto iter = scopes_.find(scope);
if (iter == scopes_.end()) {
vtAssert(
not scope.get().template is<StrongGroup>(),
not std::holds_alternative<StrongGroup>(scope.get()),
"Group reducers cannot be on-demand created -- needs spanning tree"
);

Expand Down
9 changes: 3 additions & 6 deletions src/vt/runtime/component/diagnostic_erased_value.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,11 @@
#define INCLUDED_VT_RUNTIME_COMPONENT_DIAGNOSTIC_ERASED_VALUE_H

#include "vt/runtime/component/diagnostic_types.h"
#include "vt/utils/adt/union.h"
#include "vt/utils/adt/histogram_approx.h"
#include "vt/utils/strong/strong_type.h"

#include <string>
#include <variant>

namespace vt { namespace runtime { namespace component {

Expand All @@ -61,11 +62,7 @@ namespace vt { namespace runtime { namespace component {
struct DiagnosticErasedValue {
/// These are the set of valid diagnostic value types after being erased from
/// \c DiagnosticValue<T> get turned into this union for saving the value.
using UnionValueType = vt::adt::SafeUnion<
double, float,
int8_t, int16_t, int32_t, int64_t,
uint8_t, uint16_t, uint32_t, uint64_t
>;
using UnionValueType = std::variant<double, int64_t>;

// The trio (min, max, sum) save the actual type with the value to print it
// correctly
Expand Down
9 changes: 1 addition & 8 deletions src/vt/runtime/component/diagnostic_value.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,14 +109,7 @@ void reduceHelper(


DIAGNOSTIC_VALUE_INSTANCE(int64_t)
DIAGNOSTIC_VALUE_INSTANCE(int32_t)
DIAGNOSTIC_VALUE_INSTANCE(int16_t)
DIAGNOSTIC_VALUE_INSTANCE(int8_t)
DIAGNOSTIC_VALUE_INSTANCE(uint64_t)
DIAGNOSTIC_VALUE_INSTANCE(uint32_t)
DIAGNOSTIC_VALUE_INSTANCE(uint16_t)
DIAGNOSTIC_VALUE_INSTANCE(uint8_t)
// DIAGNOSTIC_VALUE_INSTANCE(uint64_t)
DIAGNOSTIC_VALUE_INSTANCE(double)
DIAGNOSTIC_VALUE_INSTANCE(float)

}}}} /* end namespace vt::runtime::component::detail */
6 changes: 3 additions & 3 deletions src/vt/runtime/component/diagnostic_value_format.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,9 @@ struct DiagnosticEraser {
*/
static DiagnosticErasedValue get(DiagnosticValueWrapper<T> wrapper) {
DiagnosticErasedValue eval;
eval.min_.template init<T>(wrapper.min());
eval.max_.template init<T>(wrapper.max());
eval.sum_.template init<T>(wrapper.sum());
eval.min_ = T{wrapper.min()};
eval.max_ = T{wrapper.max()};
eval.sum_ = T{wrapper.sum()};
eval.avg_ = wrapper.avg();
eval.std_ = wrapper.stdv();
return eval;
Expand Down
25 changes: 14 additions & 11 deletions src/vt/runtime/runtime_diagnostics.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,9 @@ struct FormatHelper {
std::string default_format = is_decimal ? decimal_format : std::string{"{}"};

if (base_) {
return fmt::format(default_format, eval.get<T>());
return fmt::format(default_format, std::get<T>(eval));
} else {
return DF::getValueWithUnits(eval.get<T>(), unit_, default_format, align_);
return DF::getValueWithUnits(std::get<T>(eval), unit_, default_format, align_);
}
}

Expand All @@ -100,22 +100,25 @@ struct FormatHelper {
bool base_ = false;
};

template <>
std::string FormatHelper::apply<void>(
typename component::DiagnosticErasedValue::UnionValueType
) {
vtAssert(false, "Failed to extract type from union");
return "";
}

std::string valueFormatHelper(
typename component::DiagnosticErasedValue::UnionValueType eval,
component::DiagnosticUnit unit,
bool align = true,
bool base = false
) {
FormatHelper fn(unit, align, base);
return eval.switchOn(fn);
std::string out = "";
std::visit([&](auto&& t){
using T = std::decay_t<decltype(t)>;
if constexpr (std::is_same_v<T, double>) {
out = fn.apply<double>(eval);
} else if constexpr (std::is_same_v<T, int64_t>) {
out = fn.apply<int64_t>(eval);
} else if constexpr (std::is_same_v<T, uint64_t>) {
out = fn.apply<uint64_t>(eval);
}
}, eval);
return out;
}

std::string valueFormatHelper(
Expand Down
Loading

0 comments on commit da59a8b

Please sign in to comment.