Skip to content

Commit

Permalink
fix scatterv, fixes #34
Browse files Browse the repository at this point in the history
  • Loading branch information
rabauke committed Mar 5, 2023
1 parent e211f5b commit 6f7c044
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 35 deletions.
4 changes: 2 additions & 2 deletions mpl/comm_group.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3063,7 +3063,7 @@ namespace mpl {
if (rank() == root_rank)
alltoallv(send_data, sendls, senddispls, recv_data, recvls, recvdispls);
else
alltoallv(send_data, sendls, senddispls, recv_data, mpl::layouts<T>(n), recvdispls);
alltoallv(send_data, mpl::layouts<T>(n), senddispls, recv_data, recvls, recvdispls);
}

/// Scatter messages with a variable amount of data from a single root process to all
Expand Down Expand Up @@ -3113,7 +3113,7 @@ namespace mpl {
if (rank() == root_rank)
return ialltoallv(send_data, sendls, senddispls, recv_data, recvls, recvdispls);
else
return ialltoallv(send_data, sendls, senddispls, recv_data, mpl::layouts<T>(n),
return ialltoallv(send_data, mpl::layouts<T>(n), senddispls, recv_data, recvls,
recvdispls);
}

Expand Down
64 changes: 46 additions & 18 deletions test/test_communicator_gatherv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#include "test_helper.hpp"


template<typename T>
template<use_non_root_overload variant, typename T>
bool gatherv_test(const T &val) {
const mpl::communicator &comm_world{mpl::environment::comm_world()};
const int N{(comm_world.size() * comm_world.size() + comm_world.size()) / 2};
Expand All @@ -25,17 +25,27 @@ bool gatherv_test(const T &val) {
++t_val;
std::iota(begin(v_send), end(v_send), t_val);
mpl::vector_layout<T> layout(comm_world.rank() + 1);
if (comm_world.rank() == 0) {
comm_world.gatherv(0, v_send.data(), layout, v_gather.data(), layouts);
return v_gather == v_expected;
if constexpr (variant == use_non_root_overload::yes) {
if (comm_world.rank() == 0) {
comm_world.gatherv(0, v_send.data(), layout, v_gather.data(), layouts);
return v_gather == v_expected;
} else {
comm_world.gatherv(0, v_send.data(), layout);
return true;
}
} else {
comm_world.gatherv(0, v_send.data(), layout);
return true;
if (comm_world.rank() == 0) {
comm_world.gatherv(0, v_send.data(), layout, v_gather.data(), layouts);
return v_gather == v_expected;
} else {
comm_world.gatherv(0, v_send.data(), layout, v_gather.data(), layouts);
return true;
}
}
}


template<typename T>
template<use_non_root_overload variant, typename T>
bool igatherv_test(const T &val) {
const mpl::communicator &comm_world{mpl::environment::comm_world()};
const int N{(comm_world.size() * comm_world.size() + comm_world.size()) / 2};
Expand All @@ -54,22 +64,40 @@ bool igatherv_test(const T &val) {
++t_val;
std::iota(begin(v_send), end(v_send), t_val);
mpl::vector_layout<T> layout(comm_world.rank() + 1);
if (comm_world.rank() == 0) {
auto r{comm_world.igatherv(0, v_send.data(), layout, v_gather.data(), layouts)};
r.wait();
return v_gather == v_expected;
if constexpr (variant == use_non_root_overload::yes) {
if (comm_world.rank() == 0) {
auto r{comm_world.igatherv(0, v_send.data(), layout, v_gather.data(), layouts)};
r.wait();
return v_gather == v_expected;
} else {
auto r{comm_world.igatherv(0, v_send.data(), layout)};
r.wait();
return true;
}
} else {
auto r{comm_world.igatherv(0, v_send.data(), layout)};
r.wait();
return true;
if (comm_world.rank() == 0) {
auto r{comm_world.igatherv(0, v_send.data(), layout, v_gather.data(), layouts)};
r.wait();
return v_gather == v_expected;
} else {
auto r{comm_world.igatherv(0, v_send.data(), layout, v_gather.data(), layouts)};
r.wait();
return true;
}
}
}


BOOST_AUTO_TEST_CASE(gatherv) {
BOOST_TEST(gatherv_test(1.0));
BOOST_TEST(gatherv_test(tuple{1, 2.0}));
BOOST_TEST(gatherv_test<use_non_root_overload::no>(1.0));
BOOST_TEST(gatherv_test<use_non_root_overload::no>(tuple{1, 2.0}));

BOOST_TEST(igatherv_test(1.0));
BOOST_TEST(igatherv_test(tuple{1, 2.0}));
BOOST_TEST(gatherv_test<use_non_root_overload::yes>(1.0));
BOOST_TEST(gatherv_test<use_non_root_overload::yes>(tuple{1, 2.0}));

BOOST_TEST(igatherv_test<use_non_root_overload::no>(1.0));
BOOST_TEST(igatherv_test<use_non_root_overload::no>(tuple{1, 2.0}));

BOOST_TEST(igatherv_test<use_non_root_overload::yes>(1.0));
BOOST_TEST(igatherv_test<use_non_root_overload::yes>(tuple{1, 2.0}));
}
45 changes: 30 additions & 15 deletions test/test_communicator_scatterv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

#include <boost/test/included/unit_test.hpp>
#include <mpl/mpl.hpp>
#include <tuple>
#include "test_helper.hpp"

template<typename T>

template<use_non_root_overload variant, typename T>
bool scatterv_test(const T &val) {
const mpl::communicator &comm_world{mpl::environment::comm_world()};
const int N{(comm_world.size() * comm_world.size() + comm_world.size()) / 2};
Expand All @@ -23,16 +25,18 @@ bool scatterv_test(const T &val) {
++t_val;
std::iota(begin(v_expected), end(v_expected), t_val);
mpl::vector_layout<T> layout(comm_world.rank() + 1);
if (comm_world.rank() == 0) {
if constexpr (variant == use_non_root_overload::yes) {
if (comm_world.rank() == 0)
comm_world.scatterv(0, v_scatter.data(), layouts, v_recv.data(), layout);
else
comm_world.scatterv(0, v_recv.data(), layout);
} else
comm_world.scatterv(0, v_scatter.data(), layouts, v_recv.data(), layout);
} else {
comm_world.scatterv(0, v_recv.data(), layout);
}
return v_recv == v_expected;
}


template<typename T>
template<use_non_root_overload variant, typename T>
bool iscatterv_test(const T &val) {
const mpl::communicator &comm_world{mpl::environment::comm_world()};
const int N{(comm_world.size() * comm_world.size() + comm_world.size()) / 2};
Expand All @@ -51,21 +55,32 @@ bool iscatterv_test(const T &val) {
++t_val;
std::iota(begin(v_expected), end(v_expected), t_val);
mpl::vector_layout<T> layout(comm_world.rank() + 1);
if (comm_world.rank() == 0) {
auto r{comm_world.iscatterv(0, v_scatter.data(), layouts, v_recv.data(), layout)};
r.wait();
if constexpr (variant == use_non_root_overload::yes) {
if (comm_world.rank() == 0) {
auto r{comm_world.iscatterv(0, v_scatter.data(), layouts, v_recv.data(), layout)};
r.wait();
} else {
auto r{comm_world.iscatterv(0, v_recv.data(), layout)};
r.wait();
}
} else {
auto r{comm_world.iscatterv(0, v_recv.data(), layout)};
r.wait();
auto r{comm_world.iscatterv(0, v_scatter.data(), layouts, v_recv.data(), layout)};
r.wait();
}
return v_recv == v_expected;
}


BOOST_AUTO_TEST_CASE(scatterv) {
BOOST_TEST(scatterv_test(1.0));
BOOST_TEST(scatterv_test(tuple{1, 2.0}));
BOOST_TEST(scatterv_test<use_non_root_overload::no>(1.0));
BOOST_TEST(scatterv_test<use_non_root_overload::no>(tuple{1, 2.0}));

BOOST_TEST(scatterv_test<use_non_root_overload::yes>(1.0));
BOOST_TEST(scatterv_test<use_non_root_overload::yes>(tuple{1, 2.0}));

BOOST_TEST(iscatterv_test<use_non_root_overload::no>(1.0));
BOOST_TEST(iscatterv_test<use_non_root_overload::no>(tuple{1, 2.0}));

BOOST_TEST(iscatterv_test(1.0));
BOOST_TEST(iscatterv_test(tuple{1, 2.0}));
BOOST_TEST(iscatterv_test<use_non_root_overload::yes>(1.0));
BOOST_TEST(iscatterv_test<use_non_root_overload::yes>(tuple{1, 2.0}));
}
3 changes: 3 additions & 0 deletions test/test_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,7 @@ class add {
T operator()(const T &a, const T &b) const { return a + b; }
};


enum class use_non_root_overload { no, yes };

#endif // MPL_TEST_HELPER_HPP

0 comments on commit 6f7c044

Please sign in to comment.