
Commit

add a workaround for source (or destination) == self case for isend/irecv
seunghwak committed Oct 1, 2020
1 parent 252f660 commit 2be9e5f
Showing 2 changed files with 191 additions and 106 deletions.
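
For orientation before the diffs: both files apply the same guard. When the computed peer rank equals the caller's own rank, the isend/irecv pair is skipped entirely and the data is copied locally (thrust::copy on the handle's stream in the real code). A minimal standalone sketch of that shape, with an illustrative toy_comm standing in for raft::comms (its isend/irecv deliberately reject self peers to mimic the suspected UCX behavior; none of these names are cuGraph or RAFT API):

#include <cassert>
#include <cstdio>
#include <algorithm>
#include <vector>

struct toy_comm {
  int rank{0};
  // Stand-ins for raft::comms; they abort on peer == self, mirroring the
  // behavior this commit works around.
  void isend(int const*, int, int dst, int /* tag */) { assert(dst != rank && "self-send unsupported"); }
  void irecv(int*, int, int src, int /* tag */) { assert(src != rank && "self-recv unsupported"); }
  void waitall() {}
};

// Exchange tx with a peer into rx; short-circuit to a local copy when the peer is self.
void exchange(toy_comm& comm, std::vector<int> const& tx, std::vector<int>& rx,
              int dst_rank, int src_rank) {
  if (src_rank == comm.rank) {  // the workaround: no comm call for the self case
    assert(dst_rank == comm.rank);
    std::copy(tx.begin(), tx.end(), rx.begin());
  } else {
    comm.isend(tx.data(), static_cast<int>(tx.size()), dst_rank, 0 /* tag */);
    comm.irecv(rx.data(), static_cast<int>(rx.size()), src_rank, 0 /* tag */);
    comm.waitall();
  }
}

int main() {
  toy_comm comm{};  // single "GPU", so src == dst == self
  std::vector<int> tx{1, 2, 3};
  std::vector<int> rx(3);
  exchange(comm, tx, rx, comm.rank, comm.rank);
  std::printf("%d %d %d\n", rx[0], rx[1], rx[2]);  // prints: 1 2 3
  return 0;
}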
175 changes: 105 additions & 70 deletions cpp/include/patterns/copy_to_adj_matrix_row_col.cuh
@@ -194,28 +194,41 @@ void copy_to_matrix_minor(raft::handle_t const& handle,
   // partitioning
   auto comm_src_rank = row_comm_rank * col_comm_size + col_comm_rank;
   auto comm_dst_rank = (comm_rank % col_comm_size) * row_comm_size + comm_rank / col_comm_size;
-  auto constexpr tuple_size = thrust_tuple_size_or_one<
-    typename std::iterator_traits<VertexValueInputIterator>::value_type>::value;
-  std::vector<raft::comms::request_t> requests(2 * tuple_size);
-  device_isend<VertexValueInputIterator, MatrixMinorValueOutputIterator>(
-    comm,
-    vertex_value_input_first,
-    static_cast<size_t>(graph_view.get_number_of_local_vertices()),
-    comm_dst_rank,
-    int{0} /* base_tag */,
-    requests.data());
-  device_irecv<VertexValueInputIterator, MatrixMinorValueOutputIterator>(
-    comm,
-    matrix_minor_value_output_first +
-      (graph_view.get_vertex_partition_first(row_comm_rank * col_comm_size + col_comm_rank) -
-       graph_view.get_vertex_partition_first(row_comm_rank * col_comm_size)),
-    static_cast<size_t>(graph_view.get_vertex_partition_size(comm_src_rank)),
-    comm_src_rank,
-    int{0} /* base_tag */,
-    requests.data() + tuple_size);
-  // FIXME: this waitall can fail if MatrixMinorValueOutputIterator is a discard iterator or a
-  // zip iterator having one or more discard iterators
-  comm.waitall(requests.size(), requests.data());
+  // FIXME: it seems like raft::isend and raft::irecv do not properly handle the destination (or
+  // source) == self case. Need to double-check and fix this if this is indeed the case (or RAFT
+  // may use ncclSend/ncclRecv instead of UCX for device data).
+  if (comm_src_rank == comm_rank) {
+    assert(comm_dst_rank == comm_rank);
+    thrust::copy(rmm::exec_policy(handle.get_stream())->on(handle.get_stream()),
+                 vertex_value_input_first,
+                 vertex_value_input_first + graph_view.get_number_of_local_vertices(),
+                 matrix_minor_value_output_first +
+                   (graph_view.get_vertex_partition_first(comm_src_rank) -
+                    graph_view.get_vertex_partition_first(row_comm_rank * col_comm_size)));
+  } else {
+    auto constexpr tuple_size = thrust_tuple_size_or_one<
+      typename std::iterator_traits<VertexValueInputIterator>::value_type>::value;
+    std::vector<raft::comms::request_t> requests(2 * tuple_size);
+    device_isend<VertexValueInputIterator, MatrixMinorValueOutputIterator>(
+      comm,
+      vertex_value_input_first,
+      static_cast<size_t>(graph_view.get_number_of_local_vertices()),
+      comm_dst_rank,
+      int{0} /* base_tag */,
+      requests.data());
+    device_irecv<VertexValueInputIterator, MatrixMinorValueOutputIterator>(
+      comm,
+      matrix_minor_value_output_first +
+        (graph_view.get_vertex_partition_first(comm_src_rank) -
+         graph_view.get_vertex_partition_first(row_comm_rank * col_comm_size)),
+      static_cast<size_t>(graph_view.get_vertex_partition_size(comm_src_rank)),
+      comm_src_rank,
+      int{0} /* base_tag */,
+      requests.data() + tuple_size);
+    // FIXME: this waitall can fail if MatrixMinorValueOutputIterator is a discard iterator or a
+    // zip iterator having one or more discard iterators
+    comm.waitall(requests.size(), requests.data());
+  }
 
   // FIXME: these broadcast operations can be placed between ncclGroupStart() and
   // ncclGroupEnd()
@@ -272,63 +285,85 @@ void copy_to_matrix_minor(raft::handle_t const& handle,
   // hypergraph partitioning is applied or not
   auto comm_src_rank = row_comm_rank * col_comm_size + col_comm_rank;
   auto comm_dst_rank = (comm_rank % col_comm_size) * row_comm_size + comm_rank / col_comm_size;
-  auto constexpr tuple_size = thrust_tuple_size_or_one<
-    typename std::iterator_traits<VertexValueInputIterator>::value_type>::value;
-
-  std::vector<raft::comms::request_t> count_requests(2);
-  auto tx_count = thrust::distance(vertex_first, vertex_last);
-  auto rx_count = tx_count;
-  comm.isend(&tx_count, 1, comm_dst_rank, 0 /* tag */, count_requests.data());
-  comm.irecv(&rx_count, 1, comm_src_rank, 0 /* tag */, count_requests.data() + 1);
-  comm.waitall(count_requests.size(), count_requests.data());
-
-  auto src_tmp_buffer =
-    allocate_comm_buffer<typename std::iterator_traits<VertexValueInputIterator>::value_type>(
-      tx_count, handle.get_stream());
-  auto src_value_first =
-    get_comm_buffer_begin<typename std::iterator_traits<VertexValueInputIterator>::value_type>(
-      src_tmp_buffer);
-
+  size_t tx_count = thrust::distance(vertex_first, vertex_last);
+  size_t rx_count{};
+  // FIXME: it seems like raft::isend and raft::irecv do not properly handle the destination (or
+  // source) == self case. Need to double-check and fix this if this is indeed the case (or RAFT
+  // may use ncclSend/ncclRecv instead of UCX for device data).
+  if (comm_src_rank == comm_rank) {
+    assert(comm_dst_rank == comm_rank);
+    rx_count = tx_count;
+  } else {
+    std::vector<raft::comms::request_t> count_requests(2);
+    comm.isend(&tx_count, 1, comm_dst_rank, 0 /* tag */, count_requests.data());
+    comm.irecv(&rx_count, 1, comm_src_rank, 0 /* tag */, count_requests.data() + 1);
+    comm.waitall(count_requests.size(), count_requests.data());
+  }
   rmm::device_uvector<vertex_t> dst_vertices(rx_count, handle.get_stream());
   auto dst_tmp_buffer =
     allocate_comm_buffer<typename std::iterator_traits<VertexValueInputIterator>::value_type>(
       rx_count, handle.get_stream());
   auto dst_value_first =
     get_comm_buffer_begin<typename std::iterator_traits<VertexValueInputIterator>::value_type>(
       dst_tmp_buffer);
 
-  thrust::gather(rmm::exec_policy(handle.get_stream())->on(handle.get_stream()),
-                 vertex_first,
-                 vertex_last,
-                 vertex_value_input_first,
-                 src_value_first);
-
-  std::vector<raft::comms::request_t> value_requests(2 * (1 + tuple_size));
-  device_isend<decltype(vertex_first), decltype(dst_vertices.begin())>(
-    comm, vertex_first, tx_count, comm_dst_rank, int{0} /* base_tag */, value_requests.data());
-  device_isend<decltype(src_value_first), decltype(dst_value_first)>(comm,
-                                                                     src_value_first,
-                                                                     tx_count,
-                                                                     comm_dst_rank,
-                                                                     int{1} /* base_tag */,
-                                                                     value_requests.data() + 1);
-  device_irecv<decltype(vertex_first), decltype(dst_vertices.begin())>(
-    comm,
-    dst_vertices.begin(),
-    rx_count,
-    comm_src_rank,
-    int{0} /* base_tag */,
-    value_requests.data() + (1 + tuple_size));
-  device_irecv<decltype(src_value_first), decltype(dst_value_first)>(
-    comm,
-    dst_value_first,
-    rx_count,
-    comm_src_rank,
-    int{0} /* base_tag */,
-    value_requests.data() + ((1 + tuple_size) + 1));
-  // FIXME: this waitall can fail if MatrixMinorValueOutputIterator is a discard iterator or a
-  // zip iterator having one or more discard iterators
-  comm.waitall(value_requests.size(), value_requests.data());
+  if (comm_src_rank == comm_rank) {
+    thrust::copy(rmm::exec_policy(handle.get_stream())->on(handle.get_stream()),
+                 vertex_first,
+                 vertex_last,
+                 dst_vertices.begin());
+    thrust::gather(rmm::exec_policy(handle.get_stream())->on(handle.get_stream()),
+                   vertex_first,
+                   vertex_last,
+                   vertex_value_input_first,
+                   dst_value_first);
+  } else {
+    auto constexpr tuple_size = thrust_tuple_size_or_one<
+      typename std::iterator_traits<VertexValueInputIterator>::value_type>::value;
+
+    auto src_tmp_buffer =
+      allocate_comm_buffer<typename std::iterator_traits<VertexValueInputIterator>::value_type>(
+        tx_count, handle.get_stream());
+    auto src_value_first = get_comm_buffer_begin<
+      typename std::iterator_traits<VertexValueInputIterator>::value_type>(src_tmp_buffer);
+
+    thrust::gather(rmm::exec_policy(handle.get_stream())->on(handle.get_stream()),
+                   vertex_first,
+                   vertex_last,
+                   vertex_value_input_first,
+                   src_value_first);
+
+    std::vector<raft::comms::request_t> value_requests(2 * (1 + tuple_size));
+    device_isend<decltype(vertex_first), decltype(dst_vertices.begin())>(comm,
+                                                                         vertex_first,
+                                                                         tx_count,
+                                                                         comm_dst_rank,
+                                                                         int{0} /* base_tag */,
+                                                                         value_requests.data());
+    device_isend<decltype(src_value_first), decltype(dst_value_first)>(
+      comm,
+      src_value_first,
+      tx_count,
+      comm_dst_rank,
+      int{1} /* base_tag */,
+      value_requests.data() + 1);
+    device_irecv<decltype(vertex_first), decltype(dst_vertices.begin())>(
+      comm,
+      dst_vertices.begin(),
+      rx_count,
+      comm_src_rank,
+      int{0} /* base_tag */,
+      value_requests.data() + (1 + tuple_size));
+    device_irecv<decltype(src_value_first), decltype(dst_value_first)>(
+      comm,
+      dst_value_first,
+      rx_count,
+      comm_src_rank,
+      int{1} /* base_tag, must match the base_tag of the corresponding device_isend */,
+      value_requests.data() + ((1 + tuple_size) + 1));
+    // FIXME: this waitall can fail if MatrixMinorValueOutputIterator is a discard iterator or a
+    // zip iterator having one or more discard iterators
+    comm.waitall(value_requests.size(), value_requests.data());
+  }
 
   // FIXME: now we can clear tx_tmp_buffer
 
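A note on the FIXMEs in the file above: they suggest RAFT could use ncclSend/ncclRecv instead of UCX for device data, which would remove the need for the thrust::copy special case, since NCCL's grouped point-to-point calls are posted together and complete even when peer == self. A sketch of that alternative, assuming NCCL >= 2.7's point-to-point API (error checking elided; device_sendrecv is an illustrative name, not an existing RAFT function):

#include <nccl.h>
#include <cuda_runtime.h>

void device_sendrecv(ncclComm_t comm,
                     float const* send_buf, size_t send_count, int dst,
                     float* recv_buf, size_t recv_count, int src,
                     cudaStream_t stream)
{
  // Both calls are queued inside one group, so neither blocks waiting for the
  // other; this pairing is also what makes the src == dst == self case safe.
  ncclGroupStart();
  ncclSend(send_buf, send_count, ncclFloat, dst, comm, stream);
  ncclRecv(recv_buf, recv_count, ncclFloat, src, comm, stream);
  ncclGroupEnd();
}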
122 changes: 86 additions & 36 deletions cpp/include/patterns/update_frontier_v_push_if_out_nbr.cuh
@@ -15,6 +15,7 @@
  */
 #pragma once
 
+#include <cstdlib>
 #include <experimental/graph_view.hpp>
 #include <matrix_partition_device.cuh>
 #include <partition_manager.hpp>
@@ -37,6 +38,8 @@
 #include <thrust/type_traits/integer_sequence.h>
 #include <cub/cub.cuh>
 
+#include <algorithm>
+#include <limits>
 #include <numeric>
 #include <type_traits>
 #include <utility>
@@ -501,6 +504,7 @@ void update_frontier_v_push_if_out_nbr(
 
   if (GraphViewType::is_multi_gpu) {
     auto& comm = handle.get_comms();
+    auto const comm_rank = comm.get_rank();
     auto& row_comm = handle.get_subcomm(cugraph::partition_2d::key_naming_t().row_name());
     auto const row_comm_rank = row_comm.get_rank();
     auto const row_comm_size = row_comm.get_size();
@@ -537,22 +541,41 @@
     std::vector<edge_t> rx_counts(graph_view.is_hypergraph_partitioned() ? row_comm_size
                                                                          : col_comm_size);
     std::vector<raft::comms::request_t> count_requests(tx_counts.size() + rx_counts.size());
+    size_t tx_self_i = std::numeric_limits<size_t>::max();
     for (size_t i = 0; i < tx_counts.size(); ++i) {
-      comm.isend(&tx_counts[i],
-                 1,
-                 graph_view.is_hypergraph_partitioned() ? col_comm_rank * row_comm_size + i
-                                                        : row_comm_rank * col_comm_size + i,
-                 0 /* tag */,
-                 count_requests.data() + i);
+      auto comm_dst_rank = graph_view.is_hypergraph_partitioned()
+                             ? col_comm_rank * row_comm_size + static_cast<int>(i)
+                             : row_comm_rank * col_comm_size + static_cast<int>(i);
+      if (comm_dst_rank == comm_rank) {
+        tx_self_i = i;
+        // FIXME: better define request_null (similar to MPI_REQUEST_NULL) under raft::comms
+        count_requests[i] = std::numeric_limits<raft::comms::request_t>::max();
+      } else {
+        comm.isend(&tx_counts[i], 1, comm_dst_rank, 0 /* tag */, count_requests.data() + i);
+      }
     }
     for (size_t i = 0; i < rx_counts.size(); ++i) {
-      comm.irecv(&rx_counts[i],
-                 1,
-                 graph_view.is_hypergraph_partitioned() ? col_comm_rank * row_comm_size + i
-                                                        : row_comm_rank + i * row_comm_size,
-                 0 /* tag */,
-                 count_requests.data() + tx_counts.size() + i);
+      auto comm_src_rank = graph_view.is_hypergraph_partitioned()
+                             ? col_comm_rank * row_comm_size + static_cast<int>(i)
+                             : row_comm_rank + static_cast<int>(i) * row_comm_size;
+      if (comm_src_rank == comm_rank) {
+        assert(tx_self_i != std::numeric_limits<size_t>::max());
+        rx_counts[i] = tx_counts[tx_self_i];
+        // FIXME: better define request_null (similar to MPI_REQUEST_NULL) under raft::comms
+        count_requests[tx_counts.size() + i] = std::numeric_limits<raft::comms::request_t>::max();
+      } else {
+        comm.irecv(&rx_counts[i],
+                   1,
+                   comm_src_rank,
+                   0 /* tag */,
+                   count_requests.data() + tx_counts.size() + i);
+      }
     }
-    comm.waitall(count_requests.size(), count_requests.data());
+    // FIXME: better define request_null (similar to MPI_REQUEST_NULL) under raft::comms; if
+    // raft::comms::wait immediately returns on seeing request_null, this remove is unnecessary
+    auto count_requests_last = std::remove(count_requests.begin(),
+                                           count_requests.end(),
+                                           std::numeric_limits<raft::comms::request_t>::max());
+    comm.waitall(std::distance(count_requests.begin(), count_requests_last),
+                 count_requests.data());
 
     std::vector<edge_t> tx_offsets(tx_counts.size() + 1, edge_t{0});
@@ -577,36 +600,63 @@
       auto comm_dst_rank = graph_view.is_hypergraph_partitioned()
                              ? col_comm_rank * row_comm_size + i
                              : row_comm_rank * col_comm_size + i;
-      comm.isend(detail::iter_to_raw_ptr(buffer_key_first + tx_offsets[i]),
-                 static_cast<size_t>(tx_counts[i]),
-                 comm_dst_rank,
-                 int{0} /* tag */,
-                 buffer_requests.data() + i * (1 + tuple_size));
-      device_isend<decltype(buffer_payload_first), decltype(buffer_payload_first)>(
-        comm,
-        buffer_payload_first + tx_offsets[i],
-        static_cast<size_t>(tx_counts[i]),
-        comm_dst_rank,
-        int{1} /* base tag */,
-        buffer_requests.data() + (i * (1 + tuple_size) + 1));
+      if (comm_dst_rank == comm_rank) {
+        assert(i == tx_self_i);
+        // FIXME: better define request_null (similar to MPI_REQUEST_NULL) under raft::comms
+        std::fill(buffer_requests.data() + i * (1 + tuple_size),
+                  buffer_requests.data() + (i + 1) * (1 + tuple_size),
+                  std::numeric_limits<raft::comms::request_t>::max());
+      } else {
+        comm.isend(detail::iter_to_raw_ptr(buffer_key_first + tx_offsets[i]),
+                   static_cast<size_t>(tx_counts[i]),
+                   comm_dst_rank,
+                   int{0} /* tag */,
+                   buffer_requests.data() + i * (1 + tuple_size));
+        device_isend<decltype(buffer_payload_first), decltype(buffer_payload_first)>(
+          comm,
+          buffer_payload_first + tx_offsets[i],
+          static_cast<size_t>(tx_counts[i]),
+          comm_dst_rank,
+          int{1} /* base tag */,
+          buffer_requests.data() + (i * (1 + tuple_size) + 1));
+      }
     }
     for (size_t i = 0; i < rx_counts.size(); ++i) {
       auto comm_src_rank = graph_view.is_hypergraph_partitioned()
                              ? col_comm_rank * row_comm_size + i
                              : row_comm_rank + i * row_comm_size;
-      comm.irecv(detail::iter_to_raw_ptr(buffer_key_first + num_buffer_elements + rx_offsets[i]),
-                 static_cast<size_t>(rx_counts[i]),
-                 comm_src_rank,
-                 int{0} /* tag */,
-                 buffer_requests.data() + ((tx_counts.size() + i) * (1 + tuple_size)));
-      device_irecv<decltype(buffer_payload_first), decltype(buffer_payload_first)>(
-        comm,
-        buffer_payload_first + num_buffer_elements + rx_offsets[i],
-        static_cast<size_t>(rx_counts[i]),
-        comm_src_rank,
-        int{1} /* base tag */,
-        buffer_requests.data() + ((tx_counts.size() + i) * (1 + tuple_size) + 1));
+      if (comm_src_rank == comm_rank) {
+        assert(tx_self_i != std::numeric_limits<size_t>::max());
+        assert(rx_counts[i] == tx_counts[tx_self_i]);
+        thrust::copy(
+          rmm::exec_policy(handle.get_stream())->on(handle.get_stream()),
+          detail::iter_to_raw_ptr(buffer_key_first + tx_offsets[tx_self_i]),
+          detail::iter_to_raw_ptr(buffer_key_first + tx_offsets[tx_self_i] + tx_counts[tx_self_i]),
+          detail::iter_to_raw_ptr(buffer_key_first + num_buffer_elements + rx_offsets[i]));
+        // the payload values need to be copied along with the keys in the self case
+        thrust::copy(rmm::exec_policy(handle.get_stream())->on(handle.get_stream()),
+                     buffer_payload_first + tx_offsets[tx_self_i],
+                     buffer_payload_first + tx_offsets[tx_self_i] + tx_counts[tx_self_i],
+                     buffer_payload_first + num_buffer_elements + rx_offsets[i]);
+        // FIXME: better define request_null (similar to MPI_REQUEST_NULL) under raft::comms
+        std::fill(buffer_requests.data() + (tx_counts.size() + i) * (1 + tuple_size),
+                  buffer_requests.data() + (tx_counts.size() + i + 1) * (1 + tuple_size),
+                  std::numeric_limits<raft::comms::request_t>::max());
+      } else {
+        comm.irecv(detail::iter_to_raw_ptr(buffer_key_first + num_buffer_elements + rx_offsets[i]),
+                   static_cast<size_t>(rx_counts[i]),
+                   comm_src_rank,
+                   int{0} /* tag */,
+                   buffer_requests.data() + ((tx_counts.size() + i) * (1 + tuple_size)));
+        device_irecv<decltype(buffer_payload_first), decltype(buffer_payload_first)>(
+          comm,
+          buffer_payload_first + num_buffer_elements + rx_offsets[i],
+          static_cast<size_t>(rx_counts[i]),
+          comm_src_rank,
+          int{1} /* base tag */,
+          buffer_requests.data() + ((tx_counts.size() + i) * (1 + tuple_size) + 1));
+      }
     }
-    comm.waitall(buffer_requests.size(), buffer_requests.data());
+    // FIXME: better define request_null (similar to MPI_REQUEST_NULL) under raft::comms; if
+    // raft::comms::wait immediately returns on seeing request_null, this remove is unnecessary
+    auto buffer_requests_last = std::remove(buffer_requests.begin(),
+                                            buffer_requests.end(),
+                                            std::numeric_limits<raft::comms::request_t>::max());
+    comm.waitall(std::distance(buffer_requests.begin(), buffer_requests_last),
+                 buffer_requests.data());
 
     // FIXME: this does not exploit the fact that each segment is sorted. Lost performance
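The recurring request_null FIXMEs in this file reduce to one missing primitive: a null request value that waitall skips, the way MPI_Waitall treats MPI_REQUEST_NULL. Until raft::comms defines one, the compact-then-wait pattern used above can live in a small helper. A sketch, with request_null and waitall_valid as illustrative names (not raft::comms API) and the comms_t::waitall signature and header path assumed from the calls in this diff:

#include <algorithm>
#include <iterator>
#include <limits>
#include <vector>

#include <raft/comms/comms.hpp>

constexpr raft::comms::request_t request_null =
  std::numeric_limits<raft::comms::request_t>::max();

inline void waitall_valid(raft::comms::comms_t const& comm,
                          std::vector<raft::comms::request_t>& requests)
{
  // std::remove only shifts the surviving elements to the front and leaves the
  // tail unspecified, so the count passed to waitall must come from the returned
  // iterator, never from requests.size().
  auto last = std::remove(requests.begin(), requests.end(), request_null);
  comm.waitall(std::distance(requests.begin(), last), requests.data());
}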
