Add support for distributed sampling #246
Conversation
Codecov Report
```diff
@@            Coverage Diff             @@
##           master     #246      +/-   ##
==========================================
- Coverage   82.83%   79.61%   -3.22%
==========================================
  Files          28       28
  Lines         938      996      +58
==========================================
+ Hits          777      793      +16
- Misses        161      203      +42
```
Thanks for the PR. A few minor high-level comments about the neighbor sampling part. Overall, I think it would be easier to get this PR in if we added `merge_sampler_outputs` and `relabel_neighborhood` separately.
```cpp
           std::vector<int64_t>>
sample(const at::Tensor& rowptr,
       const at::Tensor& col,
       const at::Tensor& seed,
       const std::vector<int64_t>& num_neighbors,
       const c10::optional<at::Tensor>& time,
       const c10::optional<at::Tensor>& seed_time,
       const c10::optional<at::Tensor>& batch,
```
How does `batch` behave in case of `disjoint=False`? If distributed sampling requires `disjoint=True` anyway, I am not totally sure I understand why we need this new argument here.
Distributed sampling does not require `disjoint=True`; it can work with `disjoint=False` as well. `batch` is used only when `disjoint=True`, otherwise it is not relevant.

Why `batch` is needed: during distributed sampling we sample one hop at a time in C++ and then leave the `sample()` function. So if we sample more than one layer, the information about which subgraph a given node belonged to would be lost. Thanks to the `batch` variable, we can assign the initial subgraph indices.
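To illustrate the point, here is a minimal plain-PyTorch sketch; all IDs and shapes are made up, and `nbrs_per_node` is a hypothetical name rather than the sampler's actual output. When each hop is a separate call, the subgraph assignment of the current frontier has to be passed back in so that newly sampled nodes can inherit it:

```python
import torch

# Two seed nodes, each starting its own disjoint subgraph.
seed = torch.tensor([10, 42])           # global node IDs (made up)
batch = torch.arange(seed.numel())      # seed 10 -> subgraph 0, seed 42 -> subgraph 1

# Pretend one hop was sampled elsewhere and returned two neighbors per seed.
sampled = torch.tensor([11, 12, 40, 41])
nbrs_per_node = torch.tensor([2, 2])

# The new frontier inherits the subgraph ID of the node it was sampled from.
# Without carrying `batch` across the call boundary, this assignment is lost.
next_batch = batch.repeat_interleave(nbrs_per_node)
# next_batch == tensor([0, 0, 1, 1])
```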
Yes, this is clear now, but I am still not sure why we need it. In the end, we can just do `batch[out.batch]` outside of sampling to reconstruct the correct batch information.
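For completeness, a sketch of that re-indexing trick in plain PyTorch (the names `out_batch` and `batch` are illustrative, not the sampler's actual output fields):

```python
import torch

# A disjoint sampler numbers subgraphs locally as 0..k-1; `out_batch[i]` is the
# local subgraph index of the i-th sampled node.
out_batch = torch.tensor([0, 0, 1, 1, 0])

# If the caller remembers the original batch value of each seed, the global
# assignment can be reconstructed after sampling by simple indexing.
batch = torch.tensor([7, 3])        # original batch values of the two seeds
global_batch = batch[out_batch]     # tensor([7, 7, 3, 3, 7])
```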
Thank you for the comments. As you suggested, I opened a new PR for …
Thanks @kgajdamo. To decrease the size of the PR, I removed the `relabel` function. Can you check it in again in a separate PR? Please be patient with me :(

In this PR, I made the following changes:
- I added `dist_neighbor_sample` and `dist_hetero_neighbor_sample`. Please call these on the PyG side, since I don't want to pollute the general `neighbor_sample` interface. For `dist_neighbor_sample` we can also consider cleaning up the interface, e.g., `num_neighbors` should only be an `int` rather than a `list`, but it's not a must.
- For now, I removed the `batch` argument since I still don't see why we need it. Please bring it back if you see no other way to resolve this.

Can we also add a test for these new functions?
Thanks for the updates. It is a good idea to have a …
This code belongs to the part of the whole distributed training for PyG. This PR is complementary to [#246](#246) and introduces some updates.

What has been changed:
* Removed the unneeded `dist_hetero_neighbor_sample` function. Distributed sampling has a loop over the layers in Python, so in the hetero case only one edge type is present at the moment `neighbor_sample` is called. The call is effectively homogeneous, so `dist_hetero_neighbor_sample` is not needed and `dist_neighbor_sample` can be used instead.
* Removed all unused outputs and left only the following: `node`, `edge_ids`, `cummsum_sampled_nbrs_per_node`.
* Changed the `std::vector<int64_t> num_neighbors` input list into `int64_t one_hop_num`.

Added:
* Unit tests

Co-authored-by: rusty1s <[email protected]>
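To make the changed interface concrete, here is a toy stand-in in plain PyTorch on a CSR graph; `one_hop_sample` is a hypothetical placeholder, not the actual pyg-lib binding. It mirrors the described inputs (`one_hop_num` instead of a list) and outputs (`node`, `edge_ids`, `cummsum_sampled_nbrs_per_node`):

```python
import torch

def one_hop_sample(rowptr, col, seed, one_hop_num):
    # Toy stand-in: take up to `one_hop_num` neighbors of each seed from a CSR graph.
    nodes, edge_ids, cumsum = [], [], [0]
    for s in seed.tolist():
        lo, hi = int(rowptr[s]), int(rowptr[s + 1])
        take = min(one_hop_num, hi - lo)
        nodes += col[lo:lo + take].tolist()
        edge_ids += list(range(lo, lo + take))
        cumsum.append(cumsum[-1] + take)
    return torch.tensor(nodes), torch.tensor(edge_ids), torch.tensor(cumsum)

# Tiny CSR graph: node 0 -> {1, 2}, node 1 -> {2}, node 2 -> {}
rowptr = torch.tensor([0, 2, 3, 3])
col = torch.tensor([1, 2, 2])
node, edge_ids, cumsum = one_hop_sample(rowptr, col, torch.tensor([0, 1]), 2)
# node = [1, 2, 2], edge_ids = [0, 1, 2], cumsum = [0, 2, 3]
```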
This code belongs to the part of the whole distributed training for PyG.

## Description
Distributed training requires merging results between machines after each layer. For later algorithms, the results must be sorted according to the sampling order. This PR introduces a function whose purpose is to handle the merge and sort operations in parallel.

**Other distributed PRs:**
* pytorch_geometric DistLoader: [#7869](pyg-team/pytorch_geometric#7869)
* pytorch_geometric DistSampler: [#7974](pyg-team/pytorch_geometric#7974)
* pyg-lib: [#246](#246)

Co-authored-by: rusty1s <[email protected]>
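A rough sketch of the merge-and-sort idea in plain PyTorch; the names and data layout here are assumptions for illustration, and the actual pyg-lib kernel works differently (and in parallel). Each machine returns the neighbors it sampled for the seeds it owns, and the outputs have to be stitched back together in the original sampling order of the seeds:

```python
import torch

# Per-machine outputs: the positions (in the original seed order) of the seeds
# each machine handled, a flat list of sampled neighbors, and the number of
# neighbors per handled seed. All values are made up.
outputs = [
    {"seed_pos": torch.tensor([0, 2]), "nodes": torch.tensor([10, 11, 30]),
     "nbrs_per_seed": torch.tensor([2, 1])},
    {"seed_pos": torch.tensor([1]), "nodes": torch.tensor([20, 21]),
     "nbrs_per_seed": torch.tensor([2])},
]

num_seeds = 3
per_seed = [None] * num_seeds
for out in outputs:
    # Split each machine's flat output into per-seed chunks ...
    splits = torch.split(out["nodes"], out["nbrs_per_seed"].tolist())
    # ... and place each chunk at the position of the seed it belongs to.
    for pos, nbrs in zip(out["seed_pos"].tolist(), splits):
        per_seed[pos] = nbrs

# Merged result in the sampling order of the seeds: [10, 11, 20, 21, 30]
merged = torch.cat(per_seed)
```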
(#254) This code belongs to the part of the whole distributed training for PyG. This PR is complementary to [#246](#246).

## Description
Perform global-to-local mappings using the mapper and create (row, col) based on `sampled_nodes_with_duplicates` and `sampled_nbrs_per_node`.

**Other distributed PRs:**
* pytorch_geometric DistLoader: [#7869](pyg-team/pytorch_geometric#7869)
* pytorch_geometric DistSampler: [#7974](pyg-team/pytorch_geometric#7974)
* pyg-lib [MERGED]: [#246](#246)
* pyg-lib: [#252](#252)
* pyg-lib: [#253](#253)

Co-authored-by: Matthias Fey <[email protected]>
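A minimal sketch of the global-to-local relabeling described above, in plain PyTorch; the mapper here is just a dict, the input layout is an assumption, and the row/col orientation used by pyg-lib is not asserted:

```python
import torch

# Seeds and the neighbors sampled for them (global IDs, duplicates allowed),
# together with the number of neighbors sampled per source node.
seeds = torch.tensor([100, 200])
sampled_nodes_with_duplicates = torch.tensor([101, 102, 100, 201])
sampled_nbrs_per_node = torch.tensor([3, 1])   # node 100 got 3 nbrs, node 200 got 1

# "Mapper": global ID -> local index, assigned in first-seen sampling order.
mapper = {}
for n in torch.cat([seeds, sampled_nodes_with_duplicates]).tolist():
    mapper.setdefault(n, len(mapper))

# Local (row, col): connect each source node to the neighbors sampled for it,
# expressed in local indices.
src = seeds.repeat_interleave(sampled_nbrs_per_node)
row = torch.tensor([mapper[n.item()] for n in sampled_nodes_with_duplicates])
col = torch.tensor([mapper[s.item()] for s in src])
# row = [2, 3, 0, 4], col = [0, 0, 0, 1]
```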
This code belongs to the part of the whole distributed training for PyG. `DistNeighborSampler` leverages the `NeighborSampler` class from `pytorch_geometric` and the `neighbor_sample` function from `pyg-lib`. However, because distributed training requires synchronising the results between machines after each layer, the part of the code responsible for sampling was implemented in Python.

Added support for the following sampling methods:
- node, edge, negative, disjoint, temporal

**TODOs:**
- [x] finish hetero part
- [x] subgraph sampling

**This PR should be merged together with other distributed PRs:**
* pyg-lib: [#246](pyg-team/pyg-lib#246), [#252](pyg-team/pyg-lib#252)
* GraphStore\FeatureStore: #8083
* DistLoaders:
  1. #8079
  2. #8080
  3. #8085

Co-authored-by: JakubPietrakIntel <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: ZhengHongming888 <[email protected]>
Co-authored-by: Jakub Pietrak <[email protected]>
Co-authored-by: Matthias Fey <[email protected]>
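A condensed sketch of that Python-side layer loop; everything here is illustrative, and `sample_one_hop` and `synchronize` are hypothetical placeholders for the local pyg-lib call and the inter-machine exchange, not actual APIs:

```python
def sample_one_hop(frontier, num):
    # Placeholder: would call the one-hop sampler on the local graph partition.
    raise NotImplementedError

def synchronize(local_out):
    # Placeholder: would exchange results with the other machines (e.g. via RPC)
    # and merge them back into sampling order.
    raise NotImplementedError

def dist_sample(seed, num_neighbors):
    frontier, layers = seed, []
    for num in num_neighbors:                 # one sampling call per layer ...
        local_out = sample_one_hop(frontier, num)
        merged = synchronize(local_out)       # ... then synchronise between machines
        layers.append(merged)
        frontier = merged                     # merged nodes become the next frontier
    return layers
```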
This code belongs to the part of the whole distributed training for PyG.
Description
Distributed training neighbor sampling differs from the sampling currently implemented in pyg-lib. During distributed training, nodes from one batch can be sampled by different machines (and therefore different samplers), which results in incorrect subtree/subgraph node indexing.
To achieve correct results, it is necessary to sample one hop at a time and then synchronise the outputs between machines.
Proposed algorithm:
1. Sample the nodes (`sampled_nodes`) with duplicates in `neighbor_sample`.
2. Return the cumulative sum of the sampled neighbors per node (`cumm_sum_sampled_nbrs_per_node`).
3. Create the local (row, col) based on `sampled_nodes_with_duplicates` and `sampled_nbrs_per_node`.

Step 3. was implemented in pytorch_geometric.
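As an illustration of step 2 (plain PyTorch, made-up numbers, not the actual pyg-lib output), the cumulative sum lets the caller recover which slice of the flat output belongs to which input node; whether a leading zero is included is an assumption here:

```python
import torch

# Flat one-hop output for 3 input nodes (made-up global IDs):
sampled_nodes = torch.tensor([5, 8, 8, 2, 9, 9, 9])
# Cumulative sum of sampled neighbors per input node:
# node 0 got 2, node 1 got 1, node 2 got 4.
cumm_sum = torch.tensor([0, 2, 3, 7])

# Slice the flat output back into per-node neighborhoods:
per_node = [sampled_nodes[cumm_sum[i]:cumm_sum[i + 1]]
            for i in range(len(cumm_sum) - 1)]
# -> [tensor([5, 8]), tensor([8]), tensor([2, 9, 9, 9])]
```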
Added
- new argument `distributed` to the `neighbor_sample` function to enable the algorithm described above
- new argument `batch` to the `neighbor_sample` function that allows specifying the initial subgraph indices for seed nodes (used with disjoint)
- new return value `cumm_sum_sampled_nbrs_per_node` to the `neighbor_sample` function to return the cumulative sum of the sampled neighbors per each node
- new function `relabel_neighborhood` that is used after sampling all layers; its purpose is to relabel the global indices of the sampled nodes to the local subtree/subgraph indices (row, col)
- new function `hetero_relabel_neighborhood` (same as `relabel_neighborhood` but for heterogeneous graphs); returns `row_dict` and `col_dict`
- unit tests