Skip to content

Commit

Permalink
Save WIP for page rank with STF
Browse files Browse the repository at this point in the history
  • Loading branch information
caugonnet committed Jan 22, 2025
1 parent 22b54cc commit f960650
Showing 1 changed file with 63 additions and 24 deletions.
87 changes: 63 additions & 24 deletions cpp/src/prims/detail/per_v_transform_reduce_e.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@
#include <type_traits>
#include <utility>

#include <cuda/experimental/stf.cuh>
#include <raft/core/resource/custom_resource.hpp>

// NOTE(review): this using-directive sits at file scope of a header (.cuh), before
// `namespace cugraph` is opened, so every translation unit that includes this file
// gets all of cuda::experimental::stf injected into its global lookup scope.
// The STF names used below (async_resources_handle, stream_ctx, logical_data,
// void_interface) rely on it; consider qualifying them or moving the directive
// into function scope before this leaves WIP.
using namespace cuda::experimental::stf;

namespace cugraph {

namespace detail {
Expand Down Expand Up @@ -1151,6 +1156,15 @@ void per_v_transform_reduce_e_edge_partition(
std::optional<raft::host_span<size_t const>> key_segment_offsets,
std::optional<raft::host_span<size_t const>> const& edge_partition_stream_pool_indices)
{
async_resources_handle& cudastf_handle = *raft::resource::get_custom_resource<async_resources_handle>(handle);
stream_ctx cudastf_ctx(handle.get_stream(), cudastf_handle);

logical_data<void_interface> output_tokens[4];
for (size_t i = 0; i < 4; i++)
{
output_tokens[i] = cudastf_ctx.logical_token();
}

constexpr bool use_input_key = !std::is_same_v<OptionalKeyIterator, void*>;

using vertex_t = typename GraphViewType::vertex_type;
Expand All @@ -1174,10 +1188,13 @@ void per_v_transform_reduce_e_edge_partition(

if constexpr (update_major && !use_input_key) { // this is necessary as we don't visit
// every vertex in the hypersparse segment
thrust::fill(rmm::exec_policy_nosync(exec_stream),
output_buffer + (*key_segment_offsets)[3],
output_buffer + (*key_segment_offsets)[4],
major_init);
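// output_tokens[3] stands for the output range
// [(*key_segment_offsets)[3], (*key_segment_offsets)[4]); this task initializes
// that range with major_init, hence the write() access mode.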
cudastf_ctx.task(output_tokens[3].write())->*[=](cudaStream_t stream) {
thrust::fill(rmm::exec_policy_nosync(stream),
output_buffer + (*key_segment_offsets)[3],
output_buffer + (*key_segment_offsets)[4],
major_init);
};
}

auto segment_size = use_input_key
Expand All @@ -1187,8 +1204,9 @@ void per_v_transform_reduce_e_edge_partition(
raft::grid_1d_thread_t update_grid(segment_size,
detail::per_v_transform_reduce_e_kernel_block_size,
handle.get_device_properties().maxGridSize[0]);
size_t token_idx = 0;
auto segment_output_buffer = output_buffer;
if constexpr (update_major) { segment_output_buffer += (*key_segment_offsets)[3]; }
if constexpr (update_major) { segment_output_buffer += (*key_segment_offsets)[3]; token_idx +=3; }
auto segment_key_first = edge_partition_key_first;
auto segment_key_last = edge_partition_key_last;
if constexpr (use_input_key) {
Expand All @@ -1199,20 +1217,22 @@ void per_v_transform_reduce_e_edge_partition(
assert(segment_key_first == nullptr);
assert(segment_key_last == nullptr);
}
detail::per_v_transform_reduce_e_hypersparse<update_major, GraphViewType>
<<<update_grid.num_blocks, update_grid.block_size, 0, exec_stream>>>(
edge_partition,
segment_key_first,
segment_key_last,
edge_partition_src_value_input,
edge_partition_dst_value_input,
edge_partition_e_value_input,
edge_partition_e_mask,
segment_output_buffer,
e_op,
major_init,
reduce_op,
pred_op);
cudastf_ctx.task(output_tokens[token_idx].rw())->*[=](cudaStream_t stream) {
detail::per_v_transform_reduce_e_hypersparse<update_major, GraphViewType>
<<<update_grid.num_blocks, update_grid.block_size, 0, stream>>>(
edge_partition,
segment_key_first,
segment_key_last,
edge_partition_src_value_input,
edge_partition_dst_value_input,
edge_partition_e_value_input,
edge_partition_e_mask,
segment_output_buffer,
e_op,
major_init,
reduce_op,
pred_op);
};
}
}
if ((*key_segment_offsets)[3] - (*key_segment_offsets)[2]) {
Expand All @@ -1223,8 +1243,9 @@ void per_v_transform_reduce_e_edge_partition(
raft::grid_1d_thread_t update_grid((*key_segment_offsets)[3] - (*key_segment_offsets)[2],
detail::per_v_transform_reduce_e_kernel_block_size,
handle.get_device_properties().maxGridSize[0]);
size_t token_idx = 0;
auto segment_output_buffer = output_buffer;
if constexpr (update_major) { segment_output_buffer += (*key_segment_offsets)[2]; }
if constexpr (update_major) { segment_output_buffer += (*key_segment_offsets)[2]; token_idx += 2; }
std::optional<segment_key_iterator_t>
segment_key_first{}; // std::optional as thrust::transform_iterator's default constructor
// is a deleted function, segment_key_first should always have a value
Expand All @@ -1234,8 +1255,10 @@ void per_v_transform_reduce_e_edge_partition(
segment_key_first = thrust::make_counting_iterator(edge_partition.major_range_first());
}
*segment_key_first += (*key_segment_offsets)[2];

cudastf_ctx.task(output_tokens[token_idx].rw())->*[=](cudaStream_t stream) {
detail::per_v_transform_reduce_e_low_degree<update_major, GraphViewType>
<<<update_grid.num_blocks, update_grid.block_size, 0, exec_stream>>>(
<<<update_grid.num_blocks, update_grid.block_size, 0, stream>>>(
edge_partition,
*segment_key_first,
*segment_key_first + ((*key_segment_offsets)[3] - (*key_segment_offsets)[2]),
Expand All @@ -1248,6 +1271,7 @@ void per_v_transform_reduce_e_edge_partition(
major_init,
reduce_op,
pred_op);
};
}
if ((*key_segment_offsets)[2] - (*key_segment_offsets)[1] > 0) {
auto exec_stream = edge_partition_stream_pool_indices
Expand All @@ -1257,8 +1281,9 @@ void per_v_transform_reduce_e_edge_partition(
raft::grid_1d_warp_t update_grid((*key_segment_offsets)[2] - (*key_segment_offsets)[1],
detail::per_v_transform_reduce_e_kernel_block_size,
handle.get_device_properties().maxGridSize[0]);
size_t token_idx = 0;
auto segment_output_buffer = output_buffer;
if constexpr (update_major) { segment_output_buffer += (*key_segment_offsets)[1]; }
if constexpr (update_major) { segment_output_buffer += (*key_segment_offsets)[1]; token_idx += 1;}
std::optional<segment_key_iterator_t>
segment_key_first{}; // std::optional as thrust::transform_iterator's default constructor
// is a deleted function, segment_key_first should always have a value
Expand All @@ -1268,8 +1293,10 @@ void per_v_transform_reduce_e_edge_partition(
segment_key_first = thrust::make_counting_iterator(edge_partition.major_range_first());
}
*segment_key_first += (*key_segment_offsets)[1];

cudastf_ctx.task(output_tokens[token_idx].rw())->*[=](cudaStream_t stream) {
detail::per_v_transform_reduce_e_mid_degree<update_major, GraphViewType>
<<<update_grid.num_blocks, update_grid.block_size, 0, exec_stream>>>(
<<<update_grid.num_blocks, update_grid.block_size, 0, stream>>>(
edge_partition,
*segment_key_first,
*segment_key_first + ((*key_segment_offsets)[2] - (*key_segment_offsets)[1]),
Expand All @@ -1283,6 +1310,7 @@ void per_v_transform_reduce_e_edge_partition(
major_identity_element,
reduce_op,
pred_op);
};
}
if ((*key_segment_offsets)[1] > 0) {
auto exec_stream = edge_partition_stream_pool_indices
Expand All @@ -1303,8 +1331,9 @@ void per_v_transform_reduce_e_edge_partition(
} else {
segment_key_first = thrust::make_counting_iterator(edge_partition.major_range_first());
}
cudastf_ctx.task(output_tokens[0].rw())->*[=](cudaStream_t stream) {
detail::per_v_transform_reduce_e_high_degree<update_major, GraphViewType>
<<<update_grid.num_blocks, update_grid.block_size, 0, exec_stream>>>(
<<<update_grid.num_blocks, update_grid.block_size, 0, stream>>>(
edge_partition,
*segment_key_first,
*segment_key_first + (*key_segment_offsets)[1],
Expand All @@ -1318,6 +1347,7 @@ void per_v_transform_reduce_e_edge_partition(
major_identity_element,
reduce_op,
pred_op);
};
}
} else {
auto exec_stream = edge_partition_stream_pool_indices
Expand Down Expand Up @@ -1361,6 +1391,8 @@ void per_v_transform_reduce_e_edge_partition(
pred_op);
}
}

cudastf_ctx.finalize();
}

template <bool incoming, // iterate over incoming edges (incoming == true) or outgoing edges
Expand Down Expand Up @@ -3093,6 +3125,9 @@ void per_v_transform_reduce_e(raft::handle_t const& handle,
}
if (loop_stream_pool_indices) { handle.sync_stream_pool(*loop_stream_pool_indices); }

// TODO BEGIN
//stream_ctx stf_ctx(handle.get_stream());

for (size_t j = 0; j < loop_count; ++j) {
if (process_local_edges[j]) {
auto partition_idx = i + j;
Expand Down Expand Up @@ -3265,6 +3300,10 @@ void per_v_transform_reduce_e(raft::handle_t const& handle,
}
}
}

//stf_ctx.finalize();

// TODO END
if (stream_pool_indices) { handle.sync_stream_pool(*stream_pool_indices); }

if constexpr (GraphViewType::is_multi_gpu && update_major) {
Expand Down

0 comments on commit f960650

Please sign in to comment.