From 2be9e5f9a016d5884423b6e2b59e43ed646cde07 Mon Sep 17 00:00:00 2001
From: Seunghwa Kang
Date: Thu, 1 Oct 2020 01:53:08 -0400
Subject: [PATCH] add a workaround for the source (or destination) == self
 case for isend/irecv

---
 .../patterns/copy_to_adj_matrix_row_col.cuh   | 175 +++++++++++-------
 .../update_frontier_v_push_if_out_nbr.cuh     | 122 ++++++++----
 2 files changed, 191 insertions(+), 106 deletions(-)

diff --git a/cpp/include/patterns/copy_to_adj_matrix_row_col.cuh b/cpp/include/patterns/copy_to_adj_matrix_row_col.cuh
index e8e11b85913..0aac0c0d053 100644
--- a/cpp/include/patterns/copy_to_adj_matrix_row_col.cuh
+++ b/cpp/include/patterns/copy_to_adj_matrix_row_col.cuh
@@ -194,28 +194,41 @@ void copy_to_matrix_minor(raft::handle_t const& handle,
     // partitioning
     auto comm_src_rank = row_comm_rank * col_comm_size + col_comm_rank;
     auto comm_dst_rank = (comm_rank % col_comm_size) * row_comm_size + comm_rank / col_comm_size;
-    auto constexpr tuple_size = thrust_tuple_size_or_one<
-      typename std::iterator_traits<VertexValueInputIterator>::value_type>::value;
-    std::vector<raft::comms::request_t> requests(2 * tuple_size);
-    device_isend(
-      comm,
-      vertex_value_input_first,
-      static_cast<size_t>(graph_view.get_number_of_local_vertices()),
-      comm_dst_rank,
-      int{0} /* base_tag */,
-      requests.data());
-    device_irecv(
-      comm,
-      matrix_minor_value_output_first +
-        (graph_view.get_vertex_partition_first(row_comm_rank * col_comm_size + col_comm_rank) -
-         graph_view.get_vertex_partition_first(row_comm_rank * col_comm_size)),
-      static_cast<size_t>(graph_view.get_vertex_partition_size(comm_src_rank)),
-      comm_src_rank,
-      int{0} /* base_tag */,
-      requests.data() + tuple_size);
-    // FIXME: this waitall can fail if MatrixMinorValueOutputIterator is a discard iterator or a
-    // zip iterator having one or more discard iterator
-    comm.waitall(requests.size(), requests.data());
+    // FIXME: it seems like raft::isend and raft::irecv do not properly handle the destination (or
+    // source) == self case. Need to double-check and fix this if this is indeed the case (or RAFT
+    // may use ncclSend/ncclRecv instead of UCX for device data).
+    if (comm_src_rank == comm_rank) {
+      assert(comm_dst_rank == comm_rank);
+      thrust::copy(rmm::exec_policy(handle.get_stream())->on(handle.get_stream()),
+                   vertex_value_input_first,
+                   vertex_value_input_first + graph_view.get_number_of_local_vertices(),
+                   matrix_minor_value_output_first +
+                     (graph_view.get_vertex_partition_first(comm_src_rank) -
+                      graph_view.get_vertex_partition_first(row_comm_rank * col_comm_size)));
+    } else {
+      auto constexpr tuple_size = thrust_tuple_size_or_one<
+        typename std::iterator_traits<VertexValueInputIterator>::value_type>::value;
+      std::vector<raft::comms::request_t> requests(2 * tuple_size);
+      device_isend(
+        comm,
+        vertex_value_input_first,
+        static_cast<size_t>(graph_view.get_number_of_local_vertices()),
+        comm_dst_rank,
+        int{0} /* base_tag */,
+        requests.data());
+      device_irecv(
+        comm,
+        matrix_minor_value_output_first +
+          (graph_view.get_vertex_partition_first(comm_src_rank) -
+           graph_view.get_vertex_partition_first(row_comm_rank * col_comm_size)),
+        static_cast<size_t>(graph_view.get_vertex_partition_size(comm_src_rank)),
+        comm_src_rank,
+        int{0} /* base_tag */,
+        requests.data() + tuple_size);
+      // FIXME: this waitall can fail if MatrixMinorValueOutputIterator is a discard iterator or a
+      // zip iterator having one or more discard iterators
+      comm.waitall(requests.size(), requests.data());
+    }
 
     // FIXME: these broadcast operations can be placed between ncclGroupStart() and
     // ncclGroupEnd()
@@ -272,23 +285,20 @@ void copy_to_matrix_minor(raft::handle_t const& handle,
     // hypergraph partitioning is applied or not
     auto comm_src_rank = row_comm_rank * col_comm_size + col_comm_rank;
     auto comm_dst_rank = (comm_rank % col_comm_size) * row_comm_size + comm_rank / col_comm_size;
-    auto constexpr tuple_size = thrust_tuple_size_or_one<
-      typename std::iterator_traits<VertexValueInputIterator>::value_type>::value;
-
-    std::vector<raft::comms::request_t> count_requests(2);
-    auto tx_count = thrust::distance(vertex_first, vertex_last);
-    auto rx_count = tx_count;
-    comm.isend(&tx_count, 1, comm_dst_rank, 0 /* tag */, count_requests.data());
-    comm.irecv(&rx_count, 1, comm_src_rank, 0 /* tag */, count_requests.data() + 1);
-    comm.waitall(count_requests.size(), count_requests.data());
-
-    auto src_tmp_buffer =
-      allocate_comm_buffer<typename std::iterator_traits<VertexValueInputIterator>::value_type>(
-        tx_count, handle.get_stream());
-    auto src_value_first =
-      get_comm_buffer_begin<typename std::iterator_traits<VertexValueInputIterator>::value_type>(
-        src_tmp_buffer);
-
+    size_t tx_count = thrust::distance(vertex_first, vertex_last);
+    size_t rx_count{};
+    // FIXME: it seems like raft::isend and raft::irecv do not properly handle the destination (or
+    // source) == self case. Need to double-check and fix this if this is indeed the case (or RAFT
+    // may use ncclSend/ncclRecv instead of UCX for device data).
+    if (comm_src_rank == comm_rank) {
+      assert(comm_dst_rank == comm_rank);
+      rx_count = tx_count;
+    } else {
+      std::vector<raft::comms::request_t> count_requests(2);
+      comm.isend(&tx_count, 1, comm_dst_rank, 0 /* tag */, count_requests.data());
+      comm.irecv(&rx_count, 1, comm_src_rank, 0 /* tag */, count_requests.data() + 1);
+      comm.waitall(count_requests.size(), count_requests.data());
+    }
 
     rmm::device_uvector<vertex_t> dst_vertices(rx_count, handle.get_stream());
     auto dst_tmp_buffer =
       allocate_comm_buffer<typename std::iterator_traits<VertexValueInputIterator>::value_type>(
@@ -296,39 +306,64 @@ void copy_to_matrix_minor(raft::handle_t const& handle,
     auto dst_value_first =
       get_comm_buffer_begin<typename std::iterator_traits<VertexValueInputIterator>::value_type>(
         dst_tmp_buffer);
-
-    thrust::gather(rmm::exec_policy(handle.get_stream())->on(handle.get_stream()),
+    if (comm_src_rank == comm_rank) {
+      thrust::copy(rmm::exec_policy(handle.get_stream())->on(handle.get_stream()),
                    vertex_first,
                    vertex_last,
-                   vertex_value_input_first,
-                   src_value_first);
-
-    std::vector<raft::comms::request_t> value_requests(2 * (1 + tuple_size));
-    device_isend(
-      comm, vertex_first, tx_count, comm_dst_rank, int{0} /* base_tag */, value_requests.data());
-    device_isend(comm,
-                 src_value_first,
-                 tx_count,
-                 comm_dst_rank,
-                 int{1} /* base_tag */,
-                 value_requests.data() + 1);
-    device_irecv(
-      comm,
-      dst_vertices.begin(),
-      rx_count,
-      comm_src_rank,
-      int{0} /* base_tag */,
-      value_requests.data() + (1 + tuple_size));
-    device_irecv(
-      comm,
-      dst_value_first,
-      rx_count,
-      comm_src_rank,
-      int{0} /* base_tag */,
-      value_requests.data() + ((1 + tuple_size) + 1));
-    // FIXME: this waitall can fail if MatrixMinorValueOutputIterator is a discard iterator or a
-    // zip iterator having one or more discard iterator
-    comm.waitall(value_requests.size(), value_requests.data());
+                   dst_vertices.begin());
+      thrust::gather(rmm::exec_policy(handle.get_stream())->on(handle.get_stream()),
+                     vertex_first,
+                     vertex_last,
+                     vertex_value_input_first,
+                     dst_value_first);
+    } else {
+      auto constexpr tuple_size = thrust_tuple_size_or_one<
+        typename std::iterator_traits<VertexValueInputIterator>::value_type>::value;
+
+      auto src_tmp_buffer =
+        allocate_comm_buffer<typename std::iterator_traits<VertexValueInputIterator>::value_type>(
+          tx_count, handle.get_stream());
+      auto src_value_first = get_comm_buffer_begin<
+        typename std::iterator_traits<VertexValueInputIterator>::value_type>(src_tmp_buffer);
+
+      thrust::gather(rmm::exec_policy(handle.get_stream())->on(handle.get_stream()),
+                     vertex_first,
+                     vertex_last,
+                     vertex_value_input_first,
+                     src_value_first);
+
+      std::vector<raft::comms::request_t> value_requests(2 * (1 + tuple_size));
+      device_isend(comm,
+                   vertex_first,
+                   tx_count,
+                   comm_dst_rank,
+                   int{0} /* base_tag */,
+                   value_requests.data());
+      device_isend(
+        comm,
+        src_value_first,
+        tx_count,
+        comm_dst_rank,
+        int{1} /* base_tag */,
+        value_requests.data() + 1);
+      device_irecv(
+        comm,
+        dst_vertices.begin(),
+        rx_count,
+        comm_src_rank,
+        int{0} /* base_tag */,
+        value_requests.data() + (1 + tuple_size));
+      device_irecv(
+        comm,
+        dst_value_first,
+        rx_count,
+        comm_src_rank,
+        int{1} /* base_tag */,
+        value_requests.data() + ((1 + tuple_size) + 1));
+      // FIXME: this waitall can fail if MatrixMinorValueOutputIterator is a discard iterator or a
+      // zip iterator having one or more discard iterators
+      comm.waitall(value_requests.size(), value_requests.data());
+    }
 
     // FIXME: now we can clear tx_tmp_buffer

diff --git a/cpp/include/patterns/update_frontier_v_push_if_out_nbr.cuh b/cpp/include/patterns/update_frontier_v_push_if_out_nbr.cuh
index a1d18e26d1c..838bc46dd71 100644
--- a/cpp/include/patterns/update_frontier_v_push_if_out_nbr.cuh
+++ b/cpp/include/patterns/update_frontier_v_push_if_out_nbr.cuh
@@ -15,6 +15,7 @@
  */
 
 #pragma once
 
+#include
 #include
 #include
 #include
@@ -37,6 +38,8 @@
 #include
 #include
+#include
+#include
 #include
 #include
 #include
@@ -501,6 +504,7 @@ void update_frontier_v_push_if_out_nbr(
 
   if (GraphViewType::is_multi_gpu) {
     auto& comm = handle.get_comms();
+    auto const comm_rank = comm.get_rank();
     auto& row_comm = handle.get_subcomm(cugraph::partition_2d::key_naming_t().row_name());
     auto const row_comm_rank = row_comm.get_rank();
     auto const row_comm_size = row_comm.get_size();
@@ -537,22 +541,41 @@ void update_frontier_v_push_if_out_nbr(
     std::vector<edge_t> rx_counts(graph_view.is_hypergraph_partitioned() ? row_comm_size
                                                                          : col_comm_size);
     std::vector<raft::comms::request_t> count_requests(tx_counts.size() + rx_counts.size());
+    size_t tx_self_i = std::numeric_limits<size_t>::max();
     for (size_t i = 0; i < tx_counts.size(); ++i) {
-      comm.isend(&tx_counts[i],
-                 1,
-                 graph_view.is_hypergraph_partitioned() ? col_comm_rank * row_comm_size + i
-                                                        : row_comm_rank * col_comm_size + i,
-                 0 /* tag */,
-                 count_requests.data() + i);
+      auto comm_dst_rank = graph_view.is_hypergraph_partitioned()
+                             ? col_comm_rank * row_comm_size + static_cast<int>(i)
+                             : row_comm_rank * col_comm_size + static_cast<int>(i);
+      if (comm_dst_rank == comm_rank) {
+        tx_self_i = i;
+        // FIXME: better define request_null (similar to MPI_REQUEST_NULL) under raft::comms
+        count_requests[i] = std::numeric_limits<raft::comms::request_t>::max();
+      } else {
+        comm.isend(&tx_counts[i], 1, comm_dst_rank, 0 /* tag */, count_requests.data() + i);
+      }
     }
     for (size_t i = 0; i < rx_counts.size(); ++i) {
-      comm.irecv(&rx_counts[i],
-                 1,
-                 graph_view.is_hypergraph_partitioned() ? col_comm_rank * row_comm_size + i
-                                                        : row_comm_rank + i * row_comm_size,
-                 0 /* tag */,
-                 count_requests.data() + tx_counts.size() + i);
+      auto comm_src_rank = graph_view.is_hypergraph_partitioned()
+                             ? col_comm_rank * row_comm_size + static_cast<int>(i)
+                             : row_comm_rank + static_cast<int>(i) * row_comm_size;
+      if (comm_src_rank == comm_rank) {
+        assert(tx_self_i != std::numeric_limits<size_t>::max());
+        rx_counts[i] = tx_counts[tx_self_i];
+        // FIXME: better define request_null (similar to MPI_REQUEST_NULL) under raft::comms
+        count_requests[tx_counts.size() + i] = std::numeric_limits<raft::comms::request_t>::max();
+      } else {
+        comm.irecv(&rx_counts[i],
+                   1,
+                   comm_src_rank,
+                   0 /* tag */,
+                   count_requests.data() + tx_counts.size() + i);
+      }
     }
+    // FIXME: better define request_null (similar to MPI_REQUEST_NULL) under raft::comms; if
+    // raft::comms::wait immediately returned on seeing request_null, this erase would be
+    // unnecessary
+    count_requests.erase(std::remove(count_requests.begin(),
+                                     count_requests.end(),
+                                     std::numeric_limits<raft::comms::request_t>::max()),
+                         count_requests.end());
     comm.waitall(count_requests.size(), count_requests.data());
 
     std::vector<edge_t> tx_offsets(tx_counts.size() + 1, edge_t{0});
@@ -577,36 +600,63 @@ void update_frontier_v_push_if_out_nbr(
       auto comm_dst_rank = graph_view.is_hypergraph_partitioned()
                              ? col_comm_rank * row_comm_size + i
                              : row_comm_rank * col_comm_size + i;
-      comm.isend(detail::iter_to_raw_ptr(buffer_key_first + tx_offsets[i]),
-                 static_cast<size_t>(tx_counts[i]),
-                 comm_dst_rank,
-                 int{0} /* tag */,
-                 buffer_requests.data() + i * (1 + tuple_size));
-      device_isend(
-        comm,
-        buffer_payload_first + tx_offsets[i],
-        static_cast<size_t>(tx_counts[i]),
-        comm_dst_rank,
-        int{1} /* base tag */,
-        buffer_requests.data() + (i * (1 + tuple_size) + 1));
+      if (comm_dst_rank == comm_rank) {
+        assert(i == tx_self_i);
+        // FIXME: better define request_null (similar to MPI_REQUEST_NULL) under raft::comms
+        std::fill(buffer_requests.data() + i * (1 + tuple_size),
+                  buffer_requests.data() + (i + 1) * (1 + tuple_size),
+                  std::numeric_limits<raft::comms::request_t>::max());
+      } else {
+        comm.isend(detail::iter_to_raw_ptr(buffer_key_first + tx_offsets[i]),
+                   static_cast<size_t>(tx_counts[i]),
+                   comm_dst_rank,
+                   int{0} /* tag */,
+                   buffer_requests.data() + i * (1 + tuple_size));
+        device_isend(
+          comm,
+          buffer_payload_first + tx_offsets[i],
+          static_cast<size_t>(tx_counts[i]),
+          comm_dst_rank,
+          int{1} /* base tag */,
+          buffer_requests.data() + (i * (1 + tuple_size) + 1));
+      }
     }
     for (size_t i = 0; i < rx_counts.size(); ++i) {
       auto comm_src_rank = graph_view.is_hypergraph_partitioned()
                              ? col_comm_rank * row_comm_size + i
                              : row_comm_rank + i * row_comm_size;
-      comm.irecv(detail::iter_to_raw_ptr(buffer_key_first + num_buffer_elements + rx_offsets[i]),
-                 static_cast<size_t>(rx_counts[i]),
-                 comm_src_rank,
-                 int{0} /* tag */,
-                 buffer_requests.data() + ((tx_counts.size() + i) * (1 + tuple_size)));
-      device_irecv(
-        comm,
-        buffer_payload_first + num_buffer_elements + rx_offsets[i],
-        static_cast<size_t>(rx_counts[i]),
-        comm_src_rank,
-        int{1} /* base tag */,
-        buffer_requests.data() + ((tx_counts.size() + i) * (1 + tuple_size) + 1));
+      if (comm_src_rank == comm_rank) {
+        assert(tx_self_i != std::numeric_limits<size_t>::max());
+        assert(rx_counts[i] == tx_counts[tx_self_i]);
+        thrust::copy(
+          rmm::exec_policy(handle.get_stream())->on(handle.get_stream()),
+          detail::iter_to_raw_ptr(buffer_key_first + tx_offsets[tx_self_i]),
+          detail::iter_to_raw_ptr(buffer_key_first + tx_offsets[tx_self_i] + tx_counts[tx_self_i]),
+          detail::iter_to_raw_ptr(buffer_key_first + num_buffer_elements + rx_offsets[i]));
+        // FIXME: better define request_null (similar to MPI_REQUEST_NULL) under raft::comms
+        std::fill(buffer_requests.data() + (tx_counts.size() + i) * (1 + tuple_size),
+                  buffer_requests.data() + (tx_counts.size() + i + 1) * (1 + tuple_size),
+                  std::numeric_limits<raft::comms::request_t>::max());
+      } else {
+        comm.irecv(detail::iter_to_raw_ptr(buffer_key_first + num_buffer_elements + rx_offsets[i]),
+                   static_cast<size_t>(rx_counts[i]),
+                   comm_src_rank,
+                   int{0} /* tag */,
+                   buffer_requests.data() + ((tx_counts.size() + i) * (1 + tuple_size)));
+        device_irecv(
+          comm,
+          buffer_payload_first + num_buffer_elements + rx_offsets[i],
+          static_cast<size_t>(rx_counts[i]),
+          comm_src_rank,
+          int{1} /* base tag */,
+          buffer_requests.data() + ((tx_counts.size() + i) * (1 + tuple_size) + 1));
+      }
     }
+    // FIXME: better define request_null (similar to MPI_REQUEST_NULL) under raft::comms; if
+    // raft::comms::wait immediately returned on seeing request_null, this erase would be
+    // unnecessary
+    buffer_requests.erase(std::remove(buffer_requests.begin(),
+                                      buffer_requests.end(),
+                                      std::numeric_limits<raft::comms::request_t>::max()),
+                          buffer_requests.end());
     comm.waitall(buffer_requests.size(), buffer_requests.data());
 
     // FIXME: this does not exploit the fact that each segment is sorted. Lost performance
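
The patch applies one pattern at every isend/irecv pair: when the computed peer rank
equals the caller's own rank, skip the point-to-point call and substitute a direct
local copy (thrust::copy on the handle's stream in the code above), marking any
pre-sized request slots with a sentinel so the final waitall only sees real requests.
Below is a minimal host-side sketch of that pattern; request_t, request_null, the
MPI-like Comm interface, and exchange_with_self_bypass are illustrative stand-ins for
this note, not the raft::comms API.

    #include <algorithm>
    #include <cassert>
    #include <cstddef>
    #include <limits>
    #include <vector>

    // Assumed sentinel request value, analogous to MPI_REQUEST_NULL (the
    // FIXMEs above ask for such a definition under raft::comms).
    using request_t                  = unsigned int;
    constexpr request_t request_null = std::numeric_limits<request_t>::max();

    // Exchange tx_count elements with dst_rank/src_rank, bypassing the
    // communicator when the peer is this rank itself. Comm is assumed to
    // expose MPI-like isend/irecv/waitall; the real code needs a
    // device-to-device copy in place of std::copy.
    template <typename Comm, typename T>
    void exchange_with_self_bypass(Comm& comm,
                                   int comm_rank,
                                   int dst_rank,
                                   int src_rank,
                                   T const* tx_first,
                                   std::size_t tx_count,
                                   T* rx_first,
                                   std::size_t rx_count)
    {
      if (src_rank == comm_rank) {
        // Point-to-point with self may hang or be unsupported (see the
        // FIXMEs), so copy locally and post no requests at all.
        assert(dst_rank == comm_rank);
        assert(rx_count == tx_count);
        std::copy(tx_first, tx_first + tx_count, rx_first);
      } else {
        std::vector<request_t> requests(2);
        comm.isend(tx_first, tx_count, dst_rank, 0 /* tag */, requests.data());
        comm.irecv(rx_first, rx_count, src_rank, 0 /* tag */, requests.data() + 1);
        comm.waitall(requests.size(), requests.data());
      }
    }

In the multi-peer loops of update_frontier_v_push_if_out_nbr the request array is
sized up front, so the self case instead fills its slots with the sentinel and those
slots are erased before the single waitall; a request_null honored by raft::comms
itself would make that erase unnecessary, which is what the FIXMEs request.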