diff --git a/cpp/src/prims/detail/sample_and_compute_local_nbr_indices.cuh b/cpp/src/prims/detail/sample_and_compute_local_nbr_indices.cuh index dd0da77851..79eab41ab2 100644 --- a/cpp/src/prims/detail/sample_and_compute_local_nbr_indices.cuh +++ b/cpp/src/prims/detail/sample_and_compute_local_nbr_indices.cuh @@ -336,8 +336,8 @@ std::tuple std::vector> compute_unique_keys(raft::handle_t const& handle, KeyIterator aggregate_local_frontier_key_first, - std::vector const& local_frontier_displacements, - std::vector const& local_frontier_sizes) + raft::host_span local_frontier_displacements, + raft::host_span local_frontier_sizes) { using key_t = typename thrust::iterator_traits::value_type; @@ -411,8 +411,8 @@ std::tuple, rmm::device_uvector> compute_frontier_value_sums_and_partitioned_local_value_sum_displacements( raft::handle_t const& handle, raft::device_span aggregate_local_frontier_local_value_sums, - std::vector const& local_frontier_displacements, - std::vector const& local_frontier_sizes) + raft::host_span local_frontier_displacements, + raft::host_span local_frontier_sizes) { auto& minor_comm = handle.get_subcomm(cugraph::partition_manager::minor_comm_name()); auto minor_comm_rank = minor_comm.get_rank(); @@ -453,8 +453,8 @@ compute_valid_local_nbr_count_inclusive_sums( raft::handle_t const& handle, GraphViewType const& graph_view, VertexIterator aggregate_local_frontier_major_first, - std::vector const& local_frontier_displacements, - std::vector const& local_frontier_sizes) + raft::host_span local_frontier_displacements, + raft::host_span local_frontier_sizes) { using vertex_t = typename GraphViewType::vertex_type; using edge_t = typename GraphViewType::edge_type; @@ -1237,8 +1237,8 @@ compute_aggregate_local_frontier_local_degrees( raft::handle_t const& handle, GraphViewType const& graph_view, VertexIterator aggregate_local_frontier_major_first, - std::vector const& local_frontier_displacements, - std::vector const& local_frontier_sizes) + raft::host_span local_frontier_displacements, + raft::host_span local_frontier_sizes) { using vertex_t = typename GraphViewType::vertex_type; using edge_t = typename GraphViewType::edge_type; @@ -1307,8 +1307,8 @@ compute_aggregate_local_frontier_biases(raft::handle_t const& handle, EdgeDstValueInputWrapper edge_dst_value_input, EdgeValueInputWrapper edge_value_input, EdgeBiasOp e_bias_op, - std::vector const& local_frontier_displacements, - std::vector const& local_frontier_sizes, + raft::host_span local_frontier_displacements, + raft::host_span local_frontier_sizes, bool do_expensive_check) { using vertex_t = typename GraphViewType::vertex_type; @@ -1460,286 +1460,32 @@ shuffle_and_compute_local_nbr_values(raft::handle_t const& handle, std::move(local_frontier_sample_offsets)); } -// skip conversion if local neighbor index is cugraph::invalid_edge_id_v -template -rmm::device_uvector convert_to_unmasked_local_nbr_idx( - raft::handle_t const& handle, - GraphViewType const& graph_view, - VertexIterator aggregate_local_frontier_major_first, - rmm::device_uvector&& local_nbr_indices, - std::optional> key_indices, - std::vector const& local_frontier_sample_offsets, - std::vector const& local_frontier_displacements, - std::vector const& local_frontier_sizes, - size_t K) -{ - using vertex_t = typename GraphViewType::vertex_type; - using edge_t = typename GraphViewType::edge_type; - static_assert( - std::is_same_v::value_type>); - - auto edge_mask_view = graph_view.edge_mask_view(); - - auto [aggregate_local_frontier_unique_majors, - aggregate_local_frontier_major_idx_to_unique_major_idx, - local_frontier_unique_major_displacements, - local_frontier_unique_major_sizes] = - compute_unique_keys(handle, - aggregate_local_frontier_major_first, - local_frontier_displacements, - local_frontier_sizes); - - // to avoid searching the entire neighbor list K times for high degree vertices with edge masking - auto local_frontier_unique_major_valid_local_nbr_count_inclusive_sums = - compute_valid_local_nbr_count_inclusive_sums(handle, - graph_view, - aggregate_local_frontier_unique_majors.begin(), - local_frontier_unique_major_displacements, - local_frontier_unique_major_sizes); - - auto sample_major_idx_first = thrust::make_transform_iterator( - thrust::make_counting_iterator(size_t{0}), - cuda::proclaim_return_type( - [K, - key_indices = key_indices ? thrust::make_optional>( - (*key_indices).data(), (*key_indices).size()) - : thrust::nullopt] __device__(size_t i) { - return key_indices ? (*key_indices)[i] : i / K; - })); - auto pair_first = thrust::make_zip_iterator(local_nbr_indices.begin(), sample_major_idx_first); - for (size_t i = 0; i < graph_view.number_of_local_edge_partitions(); ++i) { - auto edge_partition = - edge_partition_device_view_t( - graph_view.local_edge_partition_view(i)); - auto edge_partition_e_mask = - edge_mask_view - ? thrust::make_optional< - detail::edge_partition_edge_property_device_view_t>( - *edge_mask_view, i) - : thrust::nullopt; - - auto edge_partition_frontier_major_first = - aggregate_local_frontier_major_first + local_frontier_displacements[i]; - thrust::transform_if( - handle.get_thrust_policy(), - pair_first + local_frontier_sample_offsets[i], - pair_first + local_frontier_sample_offsets[i + 1], - local_nbr_indices.begin() + local_frontier_sample_offsets[i], - local_nbr_indices.begin() + local_frontier_sample_offsets[i], - find_nth_valid_nbr_idx_t{ - edge_partition, - edge_partition_e_mask, - edge_partition_frontier_major_first, - raft::device_span( - aggregate_local_frontier_major_idx_to_unique_major_idx.data() + - local_frontier_displacements[i], - local_frontier_sizes[i]), - thrust::make_tuple( - raft::device_span( - std::get<0>(local_frontier_unique_major_valid_local_nbr_count_inclusive_sums[i]).data(), - std::get<0>(local_frontier_unique_major_valid_local_nbr_count_inclusive_sums[i]) - .size()), - raft::device_span( - std::get<1>(local_frontier_unique_major_valid_local_nbr_count_inclusive_sums[i]).data(), - std::get<1>(local_frontier_unique_major_valid_local_nbr_count_inclusive_sums[i]) - .size()))}, - is_not_equal_t{cugraph::invalid_edge_id_v}); - } - - return std::move(local_nbr_indices); -} - -template -std::tuple, - std::optional>, - std::vector> -uniform_sample_and_compute_local_nbr_indices( +template +std::tuple /* local_nbr_indices */, + std::optional> /* key_indices */, + std::vector /* local_frontier_sample_offsets */> +biased_sample( raft::handle_t const& handle, - GraphViewType const& graph_view, - KeyIterator aggregate_local_frontier_key_first, - std::vector const& local_frontier_displacements, - std::vector const& local_frontier_sizes, + raft::host_span local_frontier_displacements, + raft::host_span local_frontier_sizes, + raft::device_span aggregate_local_frontier_key_idx_to_unique_key_idx, + raft::host_span local_frontier_unique_key_displacements, + raft::host_span local_frontier_unique_key_sizes, + raft::device_span aggregate_local_frontier_unique_key_biases, + raft::device_span aggregate_local_frontier_unique_key_local_degree_offsets, raft::random::RngState& rng_state, size_t K, bool with_replacement) { - using edge_t = typename GraphViewType::edge_type; - using vertex_t = typename GraphViewType::vertex_type; - using key_t = typename thrust::iterator_traits::value_type; - - int minor_comm_size{1}; - if constexpr (GraphViewType::is_multi_gpu) { - auto& minor_comm = handle.get_subcomm(cugraph::partition_manager::minor_comm_name()); - minor_comm_size = minor_comm.get_size(); - } - - auto aggregate_local_frontier_major_first = - thrust_tuple_get_or_identity(aggregate_local_frontier_key_first); - - auto edge_mask_view = graph_view.edge_mask_view(); - - // 1. compute degrees - - rmm::device_uvector frontier_degrees(0, handle.get_stream()); - std::optional> frontier_partitioned_local_degree_displacements{ - std::nullopt}; - { - auto aggregate_local_frontier_local_degrees = - compute_aggregate_local_frontier_local_degrees(handle, - graph_view, - aggregate_local_frontier_major_first, - local_frontier_displacements, - local_frontier_sizes); - - if (minor_comm_size > 1) { - std::tie(frontier_degrees, frontier_partitioned_local_degree_displacements) = - compute_frontier_value_sums_and_partitioned_local_value_sum_displacements( - handle, - raft::device_span(aggregate_local_frontier_local_degrees.data(), - aggregate_local_frontier_local_degrees.size()), - local_frontier_displacements, - local_frontier_sizes); - aggregate_local_frontier_local_degrees.resize(0, handle.get_stream()); - aggregate_local_frontier_local_degrees.shrink_to_fit(handle.get_stream()); - } else { - frontier_degrees = std::move(aggregate_local_frontier_local_degrees); - } - } - - // 2. sample neighbor indices - - rmm::device_uvector nbr_indices(0, handle.get_stream()); - - if (with_replacement) { - if (frontier_degrees.size() > 0) { - nbr_indices.resize(frontier_degrees.size() * K, handle.get_stream()); - cugraph::legacy::ops::graph::get_sampling_index(nbr_indices.data(), - rng_state, - frontier_degrees.data(), - static_cast(frontier_degrees.size()), - static_cast(K), - with_replacement, - handle.get_stream()); - frontier_degrees.resize(0, handle.get_stream()); - frontier_degrees.shrink_to_fit(handle.get_stream()); - } - } else { - nbr_indices = compute_uniform_sampling_index_without_replacement( - handle, std::move(frontier_degrees), rng_state, K); - } - - // 3. shuffle neighbor indices - - auto [local_nbr_indices, key_indices, local_frontier_sample_offsets] = - shuffle_and_compute_local_nbr_values( - handle, - std::move(nbr_indices), - frontier_partitioned_local_degree_displacements - ? std::make_optional>( - (*frontier_partitioned_local_degree_displacements).data(), - (*frontier_partitioned_local_degree_displacements).size()) - : std::nullopt, - K, - cugraph::invalid_edge_id_v); - - // 4. convert neighbor indices in the neighbor list considering edge mask to neighbor indices in - // the neighbor list ignoring edge mask - - if (edge_mask_view) { - local_nbr_indices = convert_to_unmasked_local_nbr_idx( - handle, - graph_view, - aggregate_local_frontier_major_first, - std::move(local_nbr_indices), - key_indices ? std::make_optional>((*key_indices).data(), - (*key_indices).size()) - : std::nullopt, - local_frontier_sample_offsets, - local_frontier_displacements, - local_frontier_sizes, - K); - } - - return std::make_tuple( - std::move(local_nbr_indices), std::move(key_indices), std::move(local_frontier_sample_offsets)); -} - -template -std::tuple, - std::optional>, - std::vector> -biased_sample_and_compute_local_nbr_indices( - raft::handle_t const& handle, - GraphViewType const& graph_view, - KeyIterator aggregate_local_frontier_key_first, - EdgeSrcValueInputWrapper edge_src_value_input, - EdgeDstValueInputWrapper edge_dst_value_input, - EdgeValueInputWrapper edge_value_input, - EdgeBiasOp e_bias_op, - std::vector const& local_frontier_displacements, - std::vector const& local_frontier_sizes, - raft::random::RngState& rng_state, - size_t K, - bool with_replacement, - bool do_expensive_check /* check e_bias_op return values */) -{ - using vertex_t = typename GraphViewType::vertex_type; - using edge_t = typename GraphViewType::edge_type; - using key_t = typename thrust::iterator_traits::value_type; - - using bias_t = typename edge_op_result_type::type; - int minor_comm_rank{0}; int minor_comm_size{1}; - if constexpr (GraphViewType::is_multi_gpu) { + if constexpr (multi_gpu) { auto& minor_comm = handle.get_subcomm(cugraph::partition_manager::minor_comm_name()); minor_comm_rank = minor_comm.get_rank(); minor_comm_size = minor_comm.get_size(); } - assert(minor_comm_size == graph_view.number_of_local_edge_partitions()); - - auto aggregate_local_frontier_major_first = - thrust_tuple_get_or_identity(aggregate_local_frontier_key_first); - - auto edge_mask_view = graph_view.edge_mask_view(); - - // 1. compute biases for unique keys (to reduce memory footprint) - - auto [aggregate_local_frontier_unique_keys, - aggregate_local_frontier_key_idx_to_unique_key_idx, - local_frontier_unique_key_displacements, - local_frontier_unique_key_sizes] = compute_unique_keys(handle, - aggregate_local_frontier_key_first, - local_frontier_displacements, - local_frontier_sizes); - - auto [aggregate_local_frontier_unique_key_biases, - aggregate_local_frontier_unique_key_local_degree_offsets] = - compute_aggregate_local_frontier_biases( - handle, - graph_view, - get_dataframe_buffer_begin(aggregate_local_frontier_unique_keys), - edge_src_value_input, - edge_dst_value_input, - edge_value_input, - e_bias_op, - local_frontier_unique_key_displacements, - local_frontier_unique_key_sizes, - do_expensive_check); - // 2. sample neighbor indices and shuffle neighbor indices + auto num_local_edge_partitions = local_frontier_unique_key_displacements.size(); rmm::device_uvector local_nbr_indices(0, handle.get_stream()); std::optional> key_indices{std::nullopt}; @@ -1756,29 +1502,29 @@ biased_sample_and_compute_local_nbr_indices( offsets.begin() + 1, thrust::upper_bound(thrust::seq, offsets.begin() + 1, offsets.end(), i))); })); + rmm::device_uvector + aggregate_local_frontier_unique_key_bias_segmented_local_inclusive_sums( + aggregate_local_frontier_unique_key_biases.size(), handle.get_stream()); thrust::inclusive_scan_by_key( handle.get_thrust_policy(), unique_key_first, unique_key_first + aggregate_local_frontier_unique_key_biases.size(), - get_dataframe_buffer_begin(aggregate_local_frontier_unique_key_biases), - get_dataframe_buffer_begin(aggregate_local_frontier_unique_key_biases)); - - auto aggregate_local_frontier_unique_key_bias_segmented_local_inclusive_sums = - std::move(aggregate_local_frontier_unique_key_biases); + aggregate_local_frontier_unique_key_biases.begin(), + aggregate_local_frontier_unique_key_bias_segmented_local_inclusive_sums.begin()); auto aggregate_local_frontier_bias_local_sums = rmm::device_uvector( local_frontier_displacements.back() + local_frontier_sizes.back(), handle.get_stream()); - for (size_t i = 0; i < graph_view.number_of_local_edge_partitions(); ++i) { + for (size_t i = 0; i < num_local_edge_partitions; ++i) { thrust::tabulate( handle.get_thrust_policy(), get_dataframe_buffer_begin(aggregate_local_frontier_bias_local_sums) + local_frontier_displacements[i], get_dataframe_buffer_begin(aggregate_local_frontier_bias_local_sums) + local_frontier_displacements[i] + local_frontier_sizes[i], - [key_idx_to_unique_key_idx = - raft::device_span(aggregate_local_frontier_key_idx_to_unique_key_idx.data() + - local_frontier_displacements[i], - local_frontier_sizes[i]), + [key_idx_to_unique_key_idx = raft::device_span( + aggregate_local_frontier_key_idx_to_unique_key_idx.data() + + local_frontier_displacements[i], + local_frontier_sizes[i]), unique_key_local_degree_offsets = raft::device_span( aggregate_local_frontier_unique_key_local_degree_offsets.data() + local_frontier_unique_key_displacements[i], @@ -1841,7 +1587,7 @@ biased_sample_and_compute_local_nbr_indices( rmm::device_uvector sample_local_random_numbers(0, handle.get_stream()); std::tie(sample_local_random_numbers, key_indices, local_frontier_sample_offsets) = - shuffle_and_compute_local_nbr_values( + shuffle_and_compute_local_nbr_values( handle, std::move(sample_random_numbers), frontier_partitioned_bias_local_sum_displacements @@ -1853,7 +1599,7 @@ biased_sample_and_compute_local_nbr_indices( std::numeric_limits::infinity()); local_nbr_indices.resize(sample_local_random_numbers.size(), handle.get_stream()); - for (size_t i = 0; i < graph_view.number_of_local_edge_partitions(); ++i) { + for (size_t i = 0; i < num_local_edge_partitions; ++i) { thrust::tabulate( handle.get_thrust_policy(), local_nbr_indices.begin() + local_frontier_sample_offsets[i], @@ -1867,12 +1613,12 @@ biased_sample_and_compute_local_nbr_indices( (*key_indices).data() + local_frontier_sample_offsets[i], local_frontier_sample_offsets[i + 1] - local_frontier_sample_offsets[i]) : thrust::nullopt, - key_idx_to_unique_key_idx = - raft::device_span(aggregate_local_frontier_key_idx_to_unique_key_idx.data() + - local_frontier_displacements[i], - local_frontier_sizes[i]), + key_idx_to_unique_key_idx = raft::device_span( + aggregate_local_frontier_key_idx_to_unique_key_idx.data() + + local_frontier_displacements[i], + local_frontier_sizes[i]), aggregate_local_frontier_unique_key_bias_segmented_local_inclusive_sums = - raft::device_span( + raft::device_span( aggregate_local_frontier_unique_key_bias_segmented_local_inclusive_sums.data(), aggregate_local_frontier_unique_key_bias_segmented_local_inclusive_sums.size()), unique_key_local_degree_offsets = raft::device_span( @@ -1909,17 +1655,17 @@ biased_sample_and_compute_local_nbr_indices( { rmm::device_uvector aggregate_local_frontier_local_degrees( local_frontier_displacements.back() + local_frontier_sizes.back(), handle.get_stream()); - for (size_t i = 0; i < graph_view.number_of_local_edge_partitions(); ++i) { + for (size_t i = 0; i < num_local_edge_partitions; ++i) { thrust::tabulate( handle.get_thrust_policy(), aggregate_local_frontier_local_degrees.begin() + local_frontier_displacements[i], aggregate_local_frontier_local_degrees.begin() + local_frontier_displacements[i] + local_frontier_sizes[i], - [key_idx_to_unique_key_idx = - raft::device_span(aggregate_local_frontier_key_idx_to_unique_key_idx.data() + - local_frontier_displacements[i], - local_frontier_sizes[i]), - unique_key_local_degree_offsets = raft::device_span( + [key_idx_to_unique_key_idx = raft::device_span( + aggregate_local_frontier_key_idx_to_unique_key_idx.data() + + local_frontier_displacements[i], + local_frontier_sizes[i]), + unique_key_local_degree_offsets = raft::device_span( aggregate_local_frontier_unique_key_local_degree_offsets.data() + local_frontier_unique_key_displacements[i], local_frontier_unique_key_sizes[i] + 1)] __device__(size_t i) { @@ -1984,16 +1730,16 @@ biased_sample_and_compute_local_nbr_indices( handle.get_stream()); rmm::device_scalar counter(0, handle.get_stream()); std::vector zero_bias_count_inclusive_sums(low_local_frontier_sizes.size()); - for (size_t i = 0; i < graph_view.number_of_local_edge_partitions(); ++i) { + for (size_t i = 0; i < num_local_edge_partitions; ++i) { thrust::for_each( handle.get_thrust_policy(), aggregate_low_local_frontier_indices.begin() + low_local_frontier_displacements[i], aggregate_low_local_frontier_indices.begin() + (low_local_frontier_displacements[i] + low_local_frontier_sizes[i]), - [key_idx_to_unique_key_idx = - raft::device_span(aggregate_local_frontier_key_idx_to_unique_key_idx.data() + - local_frontier_displacements[i], - local_frontier_sizes[i]), + [key_idx_to_unique_key_idx = raft::device_span( + aggregate_local_frontier_key_idx_to_unique_key_idx.data() + + local_frontier_displacements[i], + local_frontier_sizes[i]), aggregate_local_frontier_unique_key_biases = raft::device_span(aggregate_local_frontier_unique_key_biases.data(), aggregate_local_frontier_unique_key_biases.size()), @@ -2053,7 +1799,7 @@ biased_sample_and_compute_local_nbr_indices( auto pair_first = thrust::make_zip_iterator(low_frontier_gathered_zero_bias_frontier_indices.begin(), low_frontier_gathered_zero_bias_nbr_indices.begin()); - for (size_t i = 0; i < graph_view.number_of_local_edge_partitions(); ++i) { + for (size_t i = 0; i < num_local_edge_partitions; ++i) { thrust::transform( handle.get_thrust_policy(), pair_first + rx_displacements[i], @@ -2157,7 +1903,7 @@ biased_sample_and_compute_local_nbr_indices( rmm::device_uvector aggregate_mid_local_frontier_local_degrees( aggregate_mid_local_frontier_indices.size(), handle.get_stream()); - for (size_t i = 0; i < graph_view.number_of_local_edge_partitions(); ++i) { + for (size_t i = 0; i < num_local_edge_partitions; ++i) { thrust::transform( handle.get_thrust_policy(), aggregate_mid_local_frontier_indices.begin() + mid_local_frontier_displacements[i], @@ -2166,7 +1912,7 @@ biased_sample_and_compute_local_nbr_indices( aggregate_mid_local_frontier_local_degrees.begin() + mid_local_frontier_displacements[i], cuda::proclaim_return_type( - [key_idx_to_unique_key_idx = raft::device_span( + [key_idx_to_unique_key_idx = raft::device_span( aggregate_local_frontier_key_idx_to_unique_key_idx.data() + local_frontier_displacements[i], local_frontier_sizes[i]), @@ -2199,19 +1945,19 @@ biased_sample_and_compute_local_nbr_indices( std::vector mid_local_frontier_degree_sum_lasts( mid_local_frontier_degree_sums.size()); - for (size_t i = 0; i < graph_view.number_of_local_edge_partitions(); ++i) { + for (size_t i = 0; i < num_local_edge_partitions; ++i) { thrust::for_each( handle.get_thrust_policy(), thrust::make_counting_iterator(size_t{0}), thrust::make_counting_iterator(mid_local_frontier_sizes[i]), - [key_idx_to_unique_key_idx = raft::device_span( + [key_idx_to_unique_key_idx = raft::device_span( aggregate_local_frontier_key_idx_to_unique_key_idx.data() + local_frontier_displacements[i], local_frontier_sizes[i]), aggregate_local_frontier_unique_key_biases = - raft::device_span(aggregate_local_frontier_unique_key_biases.data(), - aggregate_local_frontier_unique_key_biases.size()), - unique_key_local_degree_offsets = raft::device_span( + raft::device_span(aggregate_local_frontier_unique_key_biases.data(), + aggregate_local_frontier_unique_key_biases.size()), + unique_key_local_degree_offsets = raft::device_span( aggregate_local_frontier_unique_key_local_degree_offsets.data() + local_frontier_unique_key_displacements[i], local_frontier_unique_key_sizes[i] + 1), @@ -2221,7 +1967,7 @@ biased_sample_and_compute_local_nbr_indices( aggregate_mid_local_frontier_biases = raft::device_span(aggregate_mid_local_frontier_biases.data(), aggregate_mid_local_frontier_biases.size()), - aggregate_mid_local_frontier_local_degree_offsets = raft::device_span( + aggregate_mid_local_frontier_local_degree_offsets = raft::device_span( aggregate_mid_local_frontier_local_degree_offsets.data(), aggregate_mid_local_frontier_local_degree_offsets.size()), output_offset = mid_local_frontier_displacements[i]] __device__(size_t i) { @@ -2372,7 +2118,7 @@ biased_sample_and_compute_local_nbr_indices( handle.get_stream()); rmm::device_uvector aggregate_high_local_frontier_keys( aggregate_high_local_frontier_local_nbr_indices.size(), handle.get_stream()); - for (size_t i = 0; i < graph_view.number_of_local_edge_partitions(); ++i) { + for (size_t i = 0; i < num_local_edge_partitions; ++i) { rmm::device_uvector unique_key_indices_for_key_indices( high_local_frontier_sizes[i], handle.get_stream()); thrust::gather( @@ -2391,8 +2137,8 @@ biased_sample_and_compute_local_nbr_indices( aggregate_local_frontier_unique_key_local_degree_offsets.data() + local_frontier_unique_key_displacements[i], local_frontier_unique_key_sizes[i] + 1), - raft::device_span(aggregate_local_frontier_unique_key_biases.data(), - aggregate_local_frontier_unique_key_biases.size()), + raft::device_span(aggregate_local_frontier_unique_key_biases.data(), + aggregate_local_frontier_unique_key_biases.size()), std::nullopt, raft::device_span(aggregate_high_local_frontier_local_nbr_indices.data() + high_local_frontier_displacements[i] * K, @@ -2537,13 +2283,13 @@ biased_sample_and_compute_local_nbr_indices( handle.get_thrust_policy(), frontier_indices.begin(), frontier_indices.begin() + frontier_partition_offsets[1], - [key_idx_to_unique_key_idx = - raft::device_span(aggregate_local_frontier_key_idx_to_unique_key_idx.data(), - aggregate_local_frontier_key_idx_to_unique_key_idx.size()), + [key_idx_to_unique_key_idx = raft::device_span( + aggregate_local_frontier_key_idx_to_unique_key_idx.data(), + aggregate_local_frontier_key_idx_to_unique_key_idx.size()), aggregate_local_frontier_unique_key_biases = - raft::device_span(aggregate_local_frontier_unique_key_biases.data(), - aggregate_local_frontier_unique_key_biases.size()), - aggregate_local_frontier_unique_key_local_degree_offsets = raft::device_span( + raft::device_span(aggregate_local_frontier_unique_key_biases.data(), + aggregate_local_frontier_unique_key_biases.size()), + aggregate_local_frontier_unique_key_local_degree_offsets = raft::device_span( aggregate_local_frontier_unique_key_local_degree_offsets.data(), aggregate_local_frontier_unique_key_local_degree_offsets.size()), nbr_indices = raft::device_span(nbr_indices.data(), nbr_indices.size()), @@ -2598,7 +2344,7 @@ biased_sample_and_compute_local_nbr_indices( } std::tie(local_nbr_indices, key_indices, local_frontier_sample_offsets) = - shuffle_and_compute_local_nbr_values( + shuffle_and_compute_local_nbr_values( handle, std::move(nbr_indices), frontier_partitioned_local_degree_displacements @@ -2610,6 +2356,317 @@ biased_sample_and_compute_local_nbr_indices( cugraph::invalid_edge_id_v); } + return std::make_tuple( + std::move(local_nbr_indices), std::move(key_indices), std::move(local_frontier_sample_offsets)); +} + +// skip conversion if local neighbor index is cugraph::invalid_edge_id_v +template +rmm::device_uvector convert_to_unmasked_local_nbr_idx( + raft::handle_t const& handle, + GraphViewType const& graph_view, + VertexIterator aggregate_local_frontier_major_first, + rmm::device_uvector&& local_nbr_indices, + std::optional> key_indices, + raft::host_span local_frontier_sample_offsets, + raft::host_span local_frontier_displacements, + raft::host_span local_frontier_sizes, + size_t K) +{ + using vertex_t = typename GraphViewType::vertex_type; + using edge_t = typename GraphViewType::edge_type; + static_assert( + std::is_same_v::value_type>); + + auto edge_mask_view = graph_view.edge_mask_view(); + + auto [aggregate_local_frontier_unique_majors, + aggregate_local_frontier_major_idx_to_unique_major_idx, + local_frontier_unique_major_displacements, + local_frontier_unique_major_sizes] = + compute_unique_keys(handle, + aggregate_local_frontier_major_first, + local_frontier_displacements, + local_frontier_sizes); + + // to avoid searching the entire neighbor list K times for high degree vertices with edge masking + auto local_frontier_unique_major_valid_local_nbr_count_inclusive_sums = + compute_valid_local_nbr_count_inclusive_sums( + handle, + graph_view, + aggregate_local_frontier_unique_majors.begin(), + raft::host_span(local_frontier_unique_major_displacements.data(), + local_frontier_unique_major_displacements.size()), + raft::host_span(local_frontier_unique_major_sizes.data(), + local_frontier_unique_major_sizes.size())); + + auto sample_major_idx_first = thrust::make_transform_iterator( + thrust::make_counting_iterator(size_t{0}), + cuda::proclaim_return_type( + [K, + key_indices = key_indices ? thrust::make_optional>( + (*key_indices).data(), (*key_indices).size()) + : thrust::nullopt] __device__(size_t i) { + return key_indices ? (*key_indices)[i] : i / K; + })); + auto pair_first = thrust::make_zip_iterator(local_nbr_indices.begin(), sample_major_idx_first); + for (size_t i = 0; i < graph_view.number_of_local_edge_partitions(); ++i) { + auto edge_partition = + edge_partition_device_view_t( + graph_view.local_edge_partition_view(i)); + auto edge_partition_e_mask = + edge_mask_view + ? thrust::make_optional< + detail::edge_partition_edge_property_device_view_t>( + *edge_mask_view, i) + : thrust::nullopt; + + auto edge_partition_frontier_major_first = + aggregate_local_frontier_major_first + local_frontier_displacements[i]; + thrust::transform_if( + handle.get_thrust_policy(), + pair_first + local_frontier_sample_offsets[i], + pair_first + local_frontier_sample_offsets[i + 1], + local_nbr_indices.begin() + local_frontier_sample_offsets[i], + local_nbr_indices.begin() + local_frontier_sample_offsets[i], + find_nth_valid_nbr_idx_t{ + edge_partition, + edge_partition_e_mask, + edge_partition_frontier_major_first, + raft::device_span( + aggregate_local_frontier_major_idx_to_unique_major_idx.data() + + local_frontier_displacements[i], + local_frontier_sizes[i]), + thrust::make_tuple( + raft::device_span( + std::get<0>(local_frontier_unique_major_valid_local_nbr_count_inclusive_sums[i]).data(), + std::get<0>(local_frontier_unique_major_valid_local_nbr_count_inclusive_sums[i]) + .size()), + raft::device_span( + std::get<1>(local_frontier_unique_major_valid_local_nbr_count_inclusive_sums[i]).data(), + std::get<1>(local_frontier_unique_major_valid_local_nbr_count_inclusive_sums[i]) + .size()))}, + is_not_equal_t{cugraph::invalid_edge_id_v}); + } + + return std::move(local_nbr_indices); +} + +template +std::tuple, + std::optional>, + std::vector> +uniform_sample_and_compute_local_nbr_indices( + raft::handle_t const& handle, + GraphViewType const& graph_view, + KeyIterator aggregate_local_frontier_key_first, + raft::host_span local_frontier_displacements, + raft::host_span local_frontier_sizes, + raft::random::RngState& rng_state, + size_t K, + bool with_replacement) +{ + using edge_t = typename GraphViewType::edge_type; + using vertex_t = typename GraphViewType::vertex_type; + using key_t = typename thrust::iterator_traits::value_type; + + int minor_comm_size{1}; + if constexpr (GraphViewType::is_multi_gpu) { + auto& minor_comm = handle.get_subcomm(cugraph::partition_manager::minor_comm_name()); + minor_comm_size = minor_comm.get_size(); + } + + auto aggregate_local_frontier_major_first = + thrust_tuple_get_or_identity(aggregate_local_frontier_key_first); + + auto edge_mask_view = graph_view.edge_mask_view(); + + // 1. compute degrees + + rmm::device_uvector frontier_degrees(0, handle.get_stream()); + std::optional> frontier_partitioned_local_degree_displacements{ + std::nullopt}; + { + auto aggregate_local_frontier_local_degrees = + compute_aggregate_local_frontier_local_degrees(handle, + graph_view, + aggregate_local_frontier_major_first, + local_frontier_displacements, + local_frontier_sizes); + + if (minor_comm_size > 1) { + std::tie(frontier_degrees, frontier_partitioned_local_degree_displacements) = + compute_frontier_value_sums_and_partitioned_local_value_sum_displacements( + handle, + raft::device_span(aggregate_local_frontier_local_degrees.data(), + aggregate_local_frontier_local_degrees.size()), + local_frontier_displacements, + local_frontier_sizes); + aggregate_local_frontier_local_degrees.resize(0, handle.get_stream()); + aggregate_local_frontier_local_degrees.shrink_to_fit(handle.get_stream()); + } else { + frontier_degrees = std::move(aggregate_local_frontier_local_degrees); + } + } + + // 2. sample neighbor indices + + rmm::device_uvector nbr_indices(0, handle.get_stream()); + + if (with_replacement) { + if (frontier_degrees.size() > 0) { + nbr_indices.resize(frontier_degrees.size() * K, handle.get_stream()); + cugraph::legacy::ops::graph::get_sampling_index(nbr_indices.data(), + rng_state, + frontier_degrees.data(), + static_cast(frontier_degrees.size()), + static_cast(K), + with_replacement, + handle.get_stream()); + frontier_degrees.resize(0, handle.get_stream()); + frontier_degrees.shrink_to_fit(handle.get_stream()); + } + } else { + nbr_indices = compute_uniform_sampling_index_without_replacement( + handle, std::move(frontier_degrees), rng_state, K); + } + + // 3. shuffle neighbor indices + + auto [local_nbr_indices, key_indices, local_frontier_sample_offsets] = + shuffle_and_compute_local_nbr_values( + handle, + std::move(nbr_indices), + frontier_partitioned_local_degree_displacements + ? std::make_optional>( + (*frontier_partitioned_local_degree_displacements).data(), + (*frontier_partitioned_local_degree_displacements).size()) + : std::nullopt, + K, + cugraph::invalid_edge_id_v); + + // 4. convert neighbor indices in the neighbor list considering edge mask to neighbor indices in + // the neighbor list ignoring edge mask + + if (edge_mask_view) { + local_nbr_indices = convert_to_unmasked_local_nbr_idx( + handle, + graph_view, + aggregate_local_frontier_major_first, + std::move(local_nbr_indices), + key_indices ? std::make_optional>((*key_indices).data(), + (*key_indices).size()) + : std::nullopt, + raft::host_span(local_frontier_sample_offsets.data(), + local_frontier_sample_offsets.size()), + local_frontier_displacements, + local_frontier_sizes, + K); + } + + return std::make_tuple( + std::move(local_nbr_indices), std::move(key_indices), std::move(local_frontier_sample_offsets)); +} + +template +std::tuple, + std::optional>, + std::vector> +biased_sample_and_compute_local_nbr_indices( + raft::handle_t const& handle, + GraphViewType const& graph_view, + KeyIterator aggregate_local_frontier_key_first, + EdgeSrcValueInputWrapper edge_src_value_input, + EdgeDstValueInputWrapper edge_dst_value_input, + EdgeValueInputWrapper edge_value_input, + EdgeBiasOp e_bias_op, + raft::host_span local_frontier_displacements, + raft::host_span local_frontier_sizes, + raft::random::RngState& rng_state, + size_t K, + bool with_replacement, + bool do_expensive_check /* check e_bias_op return values */) +{ + using vertex_t = typename GraphViewType::vertex_type; + using edge_t = typename GraphViewType::edge_type; + using key_t = typename thrust::iterator_traits::value_type; + + using bias_t = typename edge_op_result_type::type; + + int minor_comm_rank{0}; + int minor_comm_size{1}; + if constexpr (GraphViewType::is_multi_gpu) { + auto& minor_comm = handle.get_subcomm(cugraph::partition_manager::minor_comm_name()); + minor_comm_rank = minor_comm.get_rank(); + minor_comm_size = minor_comm.get_size(); + } + assert(minor_comm_size == graph_view.number_of_local_edge_partitions()); + + auto aggregate_local_frontier_major_first = + thrust_tuple_get_or_identity(aggregate_local_frontier_key_first); + + auto edge_mask_view = graph_view.edge_mask_view(); + + // 1. compute biases for unique keys (to reduce memory footprint) + + auto [aggregate_local_frontier_unique_keys, + aggregate_local_frontier_key_idx_to_unique_key_idx, + local_frontier_unique_key_displacements, + local_frontier_unique_key_sizes] = compute_unique_keys(handle, + aggregate_local_frontier_key_first, + local_frontier_displacements, + local_frontier_sizes); + + auto [aggregate_local_frontier_unique_key_biases, + aggregate_local_frontier_unique_key_local_degree_offsets] = + compute_aggregate_local_frontier_biases( + handle, + graph_view, + get_dataframe_buffer_begin(aggregate_local_frontier_unique_keys), + edge_src_value_input, + edge_dst_value_input, + edge_value_input, + e_bias_op, + raft::host_span(local_frontier_unique_key_displacements.data(), + local_frontier_unique_key_displacements.size()), + raft::host_span(local_frontier_unique_key_sizes.data(), + local_frontier_unique_key_sizes.size()), + do_expensive_check); + + // 2. sample neighbor indices and shuffle neighbor indices + + auto [local_nbr_indices, key_indices, local_frontier_sample_offsets] = + biased_sample( + handle, + local_frontier_displacements, + local_frontier_sizes, + raft::device_span(aggregate_local_frontier_key_idx_to_unique_key_idx.data(), + aggregate_local_frontier_key_idx_to_unique_key_idx.size()), + raft::host_span(local_frontier_unique_key_displacements.data(), + local_frontier_unique_key_displacements.size()), + raft::host_span(local_frontier_unique_key_sizes.data(), + local_frontier_unique_key_sizes.size()), + raft::device_span(aggregate_local_frontier_unique_key_biases.data(), + aggregate_local_frontier_unique_key_biases.size()), + raft::device_span( + aggregate_local_frontier_unique_key_local_degree_offsets.data(), + aggregate_local_frontier_unique_key_local_degree_offsets.size()), + rng_state, + K, + with_replacement); + // 3. convert neighbor indices in the neighbor list considering edge mask to neighbor indices in // the neighbor list ignoring edge mask @@ -2622,7 +2679,8 @@ biased_sample_and_compute_local_nbr_indices( key_indices ? std::make_optional>((*key_indices).data(), (*key_indices).size()) : std::nullopt, - local_frontier_sample_offsets, + raft::host_span(local_frontier_sample_offsets.data(), + local_frontier_sample_offsets.size()), local_frontier_displacements, local_frontier_sizes, K); diff --git a/cpp/src/prims/per_v_random_select_transform_outgoing_e.cuh b/cpp/src/prims/per_v_random_select_transform_outgoing_e.cuh index 30706632ad..b15afd3980 100644 --- a/cpp/src/prims/per_v_random_select_transform_outgoing_e.cuh +++ b/cpp/src/prims/per_v_random_select_transform_outgoing_e.cuh @@ -216,8 +216,7 @@ template -std::tuple>, - decltype(allocate_dataframe_buffer(size_t{0}, rmm::cuda_stream_view{}))> +std::tuple>, dataframe_buffer_type_t> per_v_random_select_transform_e(raft::handle_t const& handle, GraphViewType const& graph_view, KeyBucketType const& key_list, @@ -352,8 +351,9 @@ per_v_random_select_transform_e(raft::handle_t const& handle, graph_view, (minor_comm_size > 1) ? get_dataframe_buffer_cbegin(*aggregate_local_key_list) : key_list.begin(), - local_key_list_displacements, - local_key_list_sizes, + raft::host_span(local_key_list_displacements.data(), + local_key_list_displacements.size()), + raft::host_span(local_key_list_sizes.data(), local_key_list_sizes.size()), rng_state, K, with_replacement); @@ -368,8 +368,9 @@ per_v_random_select_transform_e(raft::handle_t const& handle, edge_bias_dst_value_input, edge_bias_value_input, e_bias_op, - local_key_list_displacements, - local_key_list_sizes, + raft::host_span(local_key_list_displacements.data(), + local_key_list_displacements.size()), + raft::host_span(local_key_list_sizes.data(), local_key_list_sizes.size()), rng_state, K, with_replacement, @@ -594,7 +595,10 @@ per_v_random_select_transform_e(raft::handle_t const& handle, } // namespace detail /** - * @brief Randomly select and transform the input (tagged-)vertices' outgoing edges with biases. + * @brief Randomly select and transform the input (tagged-)vertices' outgoing edges. + * + * This function assumes that every outgoing edge of a given vertex has the same odd to be selected + * (uniform neighbor sampling). * * @tparam GraphViewType Type of the passed non-owning graph object. * @tparam KeyBucketType Type of the key bucket class which abstracts the current (tagged-)vertex @@ -602,8 +606,6 @@ per_v_random_select_transform_e(raft::handle_t const& handle, * @tparam EdgeSrcValueInputWrapper Type of the wrapper for edge source property values. * @tparam EdgeDstValueInputWrapper Type of the wrapper for edge destination property values. * @tparam EdgeValueInputWrapper Type of the wrapper for edge property values. - * @tparam EdgeBiasOp Type of the quinary edge operator to set-up selection bias - * values. * @tparam EdgeOp Type of the quinary edge operator. * @tparam T Type of the selected and transformed edge output values. * @param handle RAFT handle object to encapsulate resources (e.g. CUDA stream, communicator, and @@ -625,12 +627,6 @@ per_v_random_select_transform_e(raft::handle_t const& handle, * to this process in multi-GPU). Use either cugraph::edge_property_t::view() (if @p e_op needs to * access edge property values) or cugraph::edge_dummy_property_t::view() (if @p e_op does not * access edge property values). - * @param e_bias_op Quinary operator takes (tagged-)edge source, edge destination, property values - * for the source, destination, and edge and returns a floating point bias value to be used in - * biased random selection. The return value should be non-negative. The bias value of 0 indicates - * that the corresponding edge cannot be selected. Assuming that the return value type is bias_t, - * the sum of the bias values for any seed vertex should not exceed - * std::numeric_limits::max(). * @param e_op Quinary operator takes (tagged-)edge source, edge destination, property values for * the source, destination, and edge and returns a value to be collected in the output. This * function is called only for the selected edges. @@ -652,24 +648,15 @@ per_v_random_select_transform_e(raft::handle_t const& handle, */ template -std::tuple>, - decltype(allocate_dataframe_buffer(size_t{0}, rmm::cuda_stream_view{}))> +std::tuple>, dataframe_buffer_type_t> per_v_random_select_transform_outgoing_e(raft::handle_t const& handle, GraphViewType const& graph_view, KeyBucketType const& key_list, - EdgeBiasSrcValueInputWrapper edge_bias_src_value_input, - EdgeBiasDstValueInputWrapper edge_bias_dst_value_input, - EdgeBiasValueInputWrapper edge_bias_value_input, - EdgeBiasOp e_bias_op, EdgeSrcValueInputWrapper edge_src_value_input, EdgeDstValueInputWrapper edge_dst_value_input, EdgeValueInputWrapper edge_value_input, @@ -680,29 +667,31 @@ per_v_random_select_transform_outgoing_e(raft::handle_t const& handle, std::optional invalid_value, bool do_expensive_check = false) { - return detail::per_v_random_select_transform_e(handle, - graph_view, - key_list, - edge_bias_src_value_input, - edge_bias_dst_value_input, - edge_bias_value_input, - e_bias_op, - edge_src_value_input, - edge_dst_value_input, - edge_value_input, - e_op, - rng_state, - K, - with_replacement, - invalid_value, - do_expensive_check); + return detail::per_v_random_select_transform_e( + handle, + graph_view, + key_list, + edge_src_dummy_property_t{}.view(), + edge_dst_dummy_property_t{}.view(), + edge_dummy_property_t{}.view(), + detail::constant_e_bias_op_t{}, + edge_src_value_input, + edge_dst_value_input, + edge_value_input, + e_op, + rng_state, + K, + with_replacement, + invalid_value, + do_expensive_check); } /** - * @brief Randomly select and transform the input (tagged-)vertices' outgoing edges. - * - * This function assumes that every outgoing edge of a given vertex has the same odd to be selected - * (uniform neighbor sampling). + * @brief Randomly select and transform the input (tagged-)vertices' outgoing edges with biases. * * @tparam GraphViewType Type of the passed non-owning graph object. * @tparam KeyBucketType Type of the key bucket class which abstracts the current (tagged-)vertex @@ -710,6 +699,8 @@ per_v_random_select_transform_outgoing_e(raft::handle_t const& handle, * @tparam EdgeSrcValueInputWrapper Type of the wrapper for edge source property values. * @tparam EdgeDstValueInputWrapper Type of the wrapper for edge destination property values. * @tparam EdgeValueInputWrapper Type of the wrapper for edge property values. + * @tparam EdgeBiasOp Type of the quinary edge operator to set-up selection bias + * values. * @tparam EdgeOp Type of the quinary edge operator. * @tparam T Type of the selected and transformed edge output values. * @param handle RAFT handle object to encapsulate resources (e.g. CUDA stream, communicator, and @@ -731,6 +722,12 @@ per_v_random_select_transform_outgoing_e(raft::handle_t const& handle, * to this process in multi-GPU). Use either cugraph::edge_property_t::view() (if @p e_op needs to * access edge property values) or cugraph::edge_dummy_property_t::view() (if @p e_op does not * access edge property values). + * @param e_bias_op Quinary operator takes (tagged-)edge source, edge destination, property values + * for the source, destination, and edge and returns a floating point bias value to be used in + * biased random selection. The return value should be non-negative. The bias value of 0 indicates + * that the corresponding edge cannot be selected. Assuming that the return value type is bias_t, + * the sum of the bias values for any seed vertex should not exceed + * std::numeric_limits::max(). * @param e_op Quinary operator takes (tagged-)edge source, edge destination, property values for * the source, destination, and edge and returns a value to be collected in the output. This * function is called only for the selected edges. @@ -752,16 +749,23 @@ per_v_random_select_transform_outgoing_e(raft::handle_t const& handle, */ template -std::tuple>, - decltype(allocate_dataframe_buffer(size_t{0}, rmm::cuda_stream_view{}))> +std::tuple>, dataframe_buffer_type_t> per_v_random_select_transform_outgoing_e(raft::handle_t const& handle, GraphViewType const& graph_view, KeyBucketType const& key_list, + EdgeBiasSrcValueInputWrapper edge_bias_src_value_input, + EdgeBiasDstValueInputWrapper edge_bias_dst_value_input, + EdgeBiasValueInputWrapper edge_bias_value_input, + EdgeBiasOp e_bias_op, EdgeSrcValueInputWrapper edge_src_value_input, EdgeDstValueInputWrapper edge_dst_value_input, EdgeValueInputWrapper edge_value_input, @@ -772,27 +776,22 @@ per_v_random_select_transform_outgoing_e(raft::handle_t const& handle, std::optional invalid_value, bool do_expensive_check = false) { - return detail::per_v_random_select_transform_e( - handle, - graph_view, - key_list, - edge_src_dummy_property_t{}.view(), - edge_dst_dummy_property_t{}.view(), - edge_dummy_property_t{}.view(), - detail::constant_e_bias_op_t{}, - edge_src_value_input, - edge_dst_value_input, - edge_value_input, - e_op, - rng_state, - K, - with_replacement, - invalid_value, - do_expensive_check); + return detail::per_v_random_select_transform_e(handle, + graph_view, + key_list, + edge_bias_src_value_input, + edge_bias_dst_value_input, + edge_bias_value_input, + e_bias_op, + edge_src_value_input, + edge_dst_value_input, + edge_value_input, + e_op, + rng_state, + K, + with_replacement, + invalid_value, + do_expensive_check); } } // namespace cugraph diff --git a/cpp/src/sampling/neighbor_sampling_impl.hpp b/cpp/src/sampling/neighbor_sampling_impl.hpp index ed77b33043..b3204d54a5 100644 --- a/cpp/src/sampling/neighbor_sampling_impl.hpp +++ b/cpp/src/sampling/neighbor_sampling_impl.hpp @@ -184,7 +184,7 @@ neighbor_sample_impl(raft::handle_t const& handle, std::vector level_sizes{}; - for (auto hop = 0; hop < num_hops; hop++) { + for (size_t hop = 0; hop < num_hops; ++hop) { rmm::device_uvector level_result_src(0, handle.get_stream()); rmm::device_uvector level_result_dst(0, handle.get_stream());