-
Notifications
You must be signed in to change notification settings - Fork 310
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Some MTMG code cleanup and small optimizations #3894
Some MTMG code cleanup and small optimizations #3894
Conversation
std::vector<std::tuple<vertex_t*, vertex_t const*, size_t>> dst_copies; | ||
std::vector<std::tuple<weight_t*, weight_t const*, size_t>> wgt_copies; | ||
std::vector<std::tuple<edge_t*, edge_t const*, size_t>> edge_id_copies; | ||
std::vector<std::tuple<edge_type_t*, edge_type_t const*, size_t>> edge_type_copies; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we maintain 5 variables or one variable storing (input_start_offset, output_start_offset, size) triplets will be sufficient?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good suggestion, I'll look into that for next push.
while (count > 0) { | ||
size_t copy_count = std::min(count, (src_.back().size() - current_pos_)); | ||
|
||
src_copies.push_back( | ||
std::make_tuple(src_.back().begin() + current_pos_, src.begin() + pos, copy_count)); | ||
dst_copies.push_back( | ||
std::make_tuple(dst_.back().begin() + current_pos_, dst.begin() + pos, copy_count)); | ||
if (wgt) | ||
wgt_copies.push_back( | ||
std::make_tuple(wgt_->back().begin() + current_pos_, wgt->begin() + pos, copy_count)); | ||
if (edge_id) | ||
edge_id_copies.push_back(std::make_tuple( | ||
edge_id_->back().begin() + current_pos_, edge_id->begin() + pos, copy_count)); | ||
if (edge_type) | ||
edge_type_copies.push_back(std::make_tuple( | ||
edge_type_->back().begin() + current_pos_, edge_type->begin() + pos, copy_count)); | ||
|
||
count -= copy_count; | ||
pos += copy_count; | ||
current_pos_ += copy_count; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What happens if count
= 1000, src_.back().size()
= 100, and current_pos_
= 0?
At the end of the first loop, copy_count
= 100, count
= 900, pos
=100, current_pos_
=100. From the second loop, copy_count
=0 and this loop won't finish or am I missing something?
Shouldn't we allocate additional buffers and reset current_pos_
to 0 for this loop to finish?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes... not sure how I missed that, the original code had that logic, I imagine I accidentally deleted that. I'll add that back in.
|
||
handle.raft_handle().sync_stream(); | ||
handle.raft_handle().sync_stream(handle.get_stream()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we add get_stream() to mtmg::handle, what about adding sync_stream to mtmg::handle as well?
@@ -153,11 +154,12 @@ class resource_manager_t { | |||
auto pos = local_rank_map_.find(rank); | |||
RAFT_CUDA_TRY(cudaSetDevice(pos->second.value())); | |||
|
|||
raft::handle_t tmp_handle; | |||
|
|||
size_t n_streams{16}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why 16?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I needed that many for one of the tests I ran :-)
I'll make that a parameter. Any suggestion on a good default?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe # of GPUs? (assuming that 1 stream per thread and # threads == # GPUs)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Each GPU will have its own pool of streams. The pool so far is used by different thread ranks copying data to the GPU independently.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've added it as a parameter.
cpp/src/mtmg/vertex_result.cu
Outdated
@@ -97,7 +97,7 @@ rmm::device_uvector<result_t> vertex_result_view_t<result_t>::gather( | |||
return vertex_partition.local_vertex_partition_offset_from_vertex_nocheck(v); | |||
}); | |||
|
|||
thrust::gather(handle.raft_handle().get_thrust_policy(), | |||
thrust::gather(rmm::exec_policy(handle.get_stream()), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What about adding (mtmg::)handle.get_thrust_policy()
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me (besides the reason behind 4 in the code, some documentation will be helpful).
cpp/tests/mtmg/threaded_test.cu
Outdated
@@ -107,7 +107,7 @@ class Tests_Multithreaded | |||
ncclGetUniqueId(&instance_manager_id); | |||
|
|||
auto instance_manager = resource_manager.create_instance_manager( | |||
resource_manager.registered_ranks(), instance_manager_id); | |||
resource_manager.registered_ranks(), instance_manager_id, 4); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is 4 here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Made this a constant, added a comment describing why it's 4 in the latest push.
/merge |
Added some missing documentation.
A couple of optimizations:
append
logic to keep the mutex lock only long enough to compute what needs to be copied and where.