diff --git a/CHANGELOG.md b/CHANGELOG.md index 2307c9b1..9df34317 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,7 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## [0.3.0] - 2023-MM-DD ### Added -- Added low-level support for distributed neighborhood sampling ([#246](https://github.com/pyg-team/pyg-lib/pull/246), [#253](https://github.com/pyg-team/pyg-lib/pull/253)) +- Added low-level support for distributed neighborhood sampling ([#246](https://github.com/pyg-team/pyg-lib/pull/246), [#252](https://github.com/pyg-team/pyg-lib/pull/252), [#253](https://github.com/pyg-team/pyg-lib/pull/253)) - Added support for homogeneous and heterogeneous biased neighborhood sampling ([#247](https://github.com/pyg-team/pyg-lib/pull/247), [#251](https://github.com/pyg-team/pyg-lib/pull/251)) - Added dispatch for XPU device in `index_sort` ([#243](https://github.com/pyg-team/pyg-lib/pull/243)) - Added `metis` partitioning ([#229](https://github.com/pyg-team/pyg-lib/pull/229)) diff --git a/pyg_lib/csrc/sampler/cpu/dist_merge_outputs_kernel.cpp b/pyg_lib/csrc/sampler/cpu/dist_merge_outputs_kernel.cpp new file mode 100644 index 00000000..856bfb17 --- /dev/null +++ b/pyg_lib/csrc/sampler/cpu/dist_merge_outputs_kernel.cpp @@ -0,0 +1,170 @@ +#include +#include +#include + +#include "parallel_hashmap/phmap.h" + +#include "pyg_lib/csrc/sampler/cpu/mapper.h" +#include "pyg_lib/csrc/utils/cpu/convert.h" +#include "pyg_lib/csrc/utils/types.h" + +namespace pyg { +namespace sampler { + +namespace { + +template +std::tuple, + std::vector> +merge_outputs( + const std::vector& node_ids, + const std::vector& edge_ids, + const std::vector>& cumsum_neighbors_per_node, + const std::vector& partition_ids, + const std::vector& partition_orders, + const int64_t num_partitions, + const int64_t num_neighbors, + const c10::optional& batch) { + at::Tensor out_node_id; + at::Tensor out_edge_id; + c10::optional out_batch = c10::nullopt; + + auto offset = num_neighbors; + + if (num_neighbors < 0) { + // find maximum population + std::vector> population(num_partitions); + std::vector max_populations(num_partitions); + + at::parallel_for(0, num_partitions, 1, [&](size_t _s, size_t _e) { + for (auto p_id = _s; p_id < _e; p_id++) { + auto cummsum1 = + std::vector(cumsum_neighbors_per_node[p_id].begin() + 1, + cumsum_neighbors_per_node[p_id].end()); + auto cummsum2 = + std::vector(cumsum_neighbors_per_node[p_id].begin(), + cumsum_neighbors_per_node[p_id].end() - 1); + std::transform(cummsum1.begin(), cummsum1.end(), cummsum2.begin(), + std::back_inserter(population[p_id]), + [](int64_t a, int64_t b) { return std::abs(a - b); }); + auto max = + *max_element(population[p_id].begin(), population[p_id].end()); + max_populations[p_id] = max; + } + }); + offset = *max_element(max_populations.begin(), max_populations.end()); + } + + const auto p_size = partition_ids.size(); + std::vector sampled_neighbors_per_node(p_size); + + const auto scalar_type = node_ids[0].scalar_type(); + AT_DISPATCH_INTEGRAL_TYPES(scalar_type, "merge_outputs_kernel", [&] { + std::vector sampled_node_ids(p_size * offset, -1); + std::vector sampled_edge_ids(p_size * offset, -1); + std::vector> sampled_node_ids_vec(p_size); + std::vector> sampled_edge_ids_vec(p_size); + + std::vector sampled_batch; + if constexpr (disjoint) { + sampled_batch = std::vector(p_size * offset, -1); + } + const auto batch_data = + disjoint ? batch.value().data_ptr() : nullptr; + + for (auto p_id = 0; p_id < num_partitions; p_id++) { + sampled_node_ids_vec[p_id] = + pyg::utils::to_vector(node_ids[p_id]); + sampled_edge_ids_vec[p_id] = + pyg::utils::to_vector(edge_ids[p_id]); + } + at::parallel_for(0, p_size, 1, [&](size_t _s, size_t _e) { + for (auto j = _s; j < _e; j++) { + auto p_id = partition_ids[j]; + auto p_order = partition_orders[j]; + + // When it comes to node and batch, we omit seed nodes. + // In the case of edges, we take into account all sampled edge ids. + auto begin_node = cumsum_neighbors_per_node[p_id][p_order]; + auto begin_edge = begin_node - cumsum_neighbors_per_node[p_id][0]; + + auto end_node = cumsum_neighbors_per_node[p_id][p_order + 1]; + auto end_edge = end_node - cumsum_neighbors_per_node[p_id][0]; + + std::copy(sampled_node_ids_vec[p_id].begin() + begin_node, + sampled_node_ids_vec[p_id].begin() + end_node, + sampled_node_ids.begin() + j * offset); + std::copy(sampled_edge_ids_vec[p_id].begin() + begin_edge, + sampled_edge_ids_vec[p_id].begin() + end_edge, + sampled_edge_ids.begin() + j * offset); + + if constexpr (disjoint) { + std::fill(sampled_batch.begin() + j * offset, + sampled_batch.begin() + j * offset + end_node - begin_node, + batch_data[j]); + } + + sampled_neighbors_per_node[j] = end_node - begin_node; + } + }); + + // Remove auxilary -1 numbers: + auto neg = + std::remove(sampled_node_ids.begin(), sampled_node_ids.end(), -1); + sampled_node_ids.erase(neg, sampled_node_ids.end()); + out_node_id = pyg::utils::from_vector(sampled_node_ids); + + neg = std::remove(sampled_edge_ids.begin(), sampled_edge_ids.end(), -1); + sampled_edge_ids.erase(neg, sampled_edge_ids.end()); + out_edge_id = pyg::utils::from_vector(sampled_edge_ids); + + if constexpr (disjoint) { + neg = std::remove(sampled_batch.begin(), sampled_batch.end(), -1); + sampled_batch.erase(neg, sampled_batch.end()); + out_batch = pyg::utils::from_vector(sampled_batch); + } + }); + + return std::make_tuple(out_node_id, out_edge_id, out_batch, + sampled_neighbors_per_node); +} + +#define DISPATCH_MERGE_OUTPUTS(disjoint, ...) \ + if (disjoint) \ + return merge_outputs(__VA_ARGS__); \ + if (!disjoint) \ + return merge_outputs(__VA_ARGS__); + +} // namespace + +std::tuple, + std::vector> +merge_sampler_outputs_kernel( + const std::vector& node_ids, + const std::vector& edge_ids, + const std::vector>& cumsum_neighbors_per_node, + const std::vector& partition_ids, + const std::vector& partition_orders, + const int64_t num_partitions, + const int64_t num_neighbors, + const c10::optional& batch, + bool disjoint) { + DISPATCH_MERGE_OUTPUTS( + disjoint, node_ids, edge_ids, cumsum_neighbors_per_node, partition_ids, + partition_orders, num_partitions, num_neighbors, batch); +} + +// We use `BackendSelect` as a fallback to the dispatcher logic as automatic +// dispatching of std::vector is not yet supported by PyTorch. +// See: pytorch/aten/src/ATen/templates/RegisterBackendSelect.cpp. +TORCH_LIBRARY_IMPL(pyg, BackendSelect, m) { + m.impl(TORCH_SELECTIVE_NAME("pyg::merge_sampler_outputs"), + TORCH_FN(merge_sampler_outputs_kernel)); +} + +} // namespace sampler +} // namespace pyg diff --git a/pyg_lib/csrc/sampler/dist_merge_outputs.cpp b/pyg_lib/csrc/sampler/dist_merge_outputs.cpp new file mode 100644 index 00000000..4caa06a6 --- /dev/null +++ b/pyg_lib/csrc/sampler/dist_merge_outputs.cpp @@ -0,0 +1,59 @@ +#include "dist_merge_outputs.h" + +#include +#include + +#include "pyg_lib/csrc/utils/check.h" + +namespace pyg { +namespace sampler { + +std::tuple, + std::vector> +merge_sampler_outputs( + const std::vector& node_ids, + const std::vector& edge_ids, + const std::vector>& cumsum_neighbors_per_node, + const std::vector& partition_ids, + const std::vector& partition_orders, + const int64_t num_partitions, + const int64_t num_neighbors, + const c10::optional& batch, + bool disjoint) { + std::vector node_ids_args; + std::vector edge_ids_args; + pyg::utils::fill_tensor_args(node_ids_args, node_ids, "node_ids", 0); + pyg::utils::fill_tensor_args(edge_ids_args, edge_ids, "edge_ids", 0); + + at::CheckedFrom c{"merge_sampler_outputs"}; + at::checkAllDefined(c, {node_ids_args}); + at::checkAllDefined(c, {edge_ids_args}); + + TORCH_CHECK(partition_ids.size() == partition_orders.size(), + "Every partition ID must be assigned a sampling order"); + + if (disjoint) { + TORCH_CHECK(batch.has_value(), + "Disjoint sampling requires 'batch' to be specified"); + } + + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("pyg::merge_sampler_outputs", "") + .typed(); + return op.call(node_ids, edge_ids, cumsum_neighbors_per_node, partition_ids, + partition_orders, num_partitions, num_neighbors, batch, + disjoint); +} + +TORCH_LIBRARY_FRAGMENT(pyg, m) { + m.def(TORCH_SELECTIVE_SCHEMA( + "pyg::merge_sampler_outputs(Tensor[] node_ids, Tensor[] edge_ids, " + "int[][] cumsum_neighbors_per_node, int[] partition_ids, int[] " + "partition_orders, int num_partitions, int num_neighbors, Tensor? " + "batch, bool disjoint) -> (Tensor, Tensor, Tensor?, int[])")); +} + +} // namespace sampler +} // namespace pyg diff --git a/pyg_lib/csrc/sampler/dist_merge_outputs.h b/pyg_lib/csrc/sampler/dist_merge_outputs.h new file mode 100644 index 00000000..f68fa0bd --- /dev/null +++ b/pyg_lib/csrc/sampler/dist_merge_outputs.h @@ -0,0 +1,34 @@ +#pragma once + +#include +#include "pyg_lib/csrc/macros.h" +#include "pyg_lib/csrc/utils/types.h" + +namespace pyg { +namespace sampler { + +// For distributed training purposes. Merges sampler outputs from different +// partitions, so that they are sorted according to the sampling order. +// Removes seed nodes from sampled nodes and calculates how many neighbors +// were sampled by each source node based on the cummulative sum of sampled +// neighbors for each input node. +// Returns the unified node, edge and batch indices as well as the merged +// cummulative sum of sampled neighbors. +PYG_API +std::tuple, + std::vector> +merge_sampler_outputs( + const std::vector& node_ids, + const std::vector& edge_ids, + const std::vector>& cumsum_neighbors_per_node, + const std::vector& partition_ids, + const std::vector& partition_orders, + const int64_t num_partitions, + const int64_t num_neighbors, + const c10::optional& batch = c10::nullopt, + bool disjoint = false); + +} // namespace sampler +} // namespace pyg diff --git a/pyg_lib/csrc/sampler/neighbor.h b/pyg_lib/csrc/sampler/neighbor.h index b589a0a1..66282e38 100644 --- a/pyg_lib/csrc/sampler/neighbor.h +++ b/pyg_lib/csrc/sampler/neighbor.h @@ -61,6 +61,11 @@ hetero_neighbor_sample( std::string strategy = "uniform", bool return_edge_id = true); +// For distributed sampling purposes. Leverages the `neighbor_sample` function +// internally. Samples one-hop neighborhoods with duplicates from all node +// indices in `seed` in the graph given by `(rowptr, col)`. +// Returns the original node and edge indices for all sampled nodes and edges. +// Lastly, returns the cummulative sum of sampled neighbors for each input node. PYG_API std::tuple> dist_neighbor_sample( const at::Tensor& rowptr, diff --git a/pyg_lib/sampler/__init__.py b/pyg_lib/sampler/__init__.py index 2087628e..dd0cc1a0 100644 --- a/pyg_lib/sampler/__init__.py +++ b/pyg_lib/sampler/__init__.py @@ -165,50 +165,6 @@ def hetero_neighbor_sample( num_nodes_per_hop_dict, num_edges_per_hop_dict) -def dist_neighbor_sample( - rowptr: Tensor, - col: Tensor, - seed: Tensor, - num_neighbors: int, - time: Optional[Tensor] = None, - seed_time: Optional[Tensor] = None, - edge_weight: Optional[Tensor] = None, - csc: bool = False, - replace: bool = False, - directed: bool = True, - disjoint: bool = False, - temporal_strategy: str = 'uniform', -) -> Tuple[Tensor, Tensor, List[int]]: - r"""For distributed sampling purpose. Leverages the - :meth:`neighbor_sample`. Samples one hop neighborhood with duplicates from - all node indices in :obj:`seed` in the graph given by :obj:`(rowptr, col)`. - - Args: - num_neighbors (int): Maximum number of neighbors to sample in the - current layer. - kwargs: Arguments of :meth:`neighbor_sample`. - - Returns: - (torch.Tensor, torch.Tensor, List[int]): Returns the original node and - edge indices for all sampled nodes and edges. Lastly, returns the - cummulative sum of the amount of sampled neighbors for each input node. - """ - return torch.ops.pyg.dist_neighbor_sample( - rowptr, - col, - seed, - num_neighbors, - time, - seed_time, - edge_weight, - csc, - replace, - directed, - disjoint, - temporal_strategy, - ) - - def subgraph( rowptr: Tensor, col: Tensor, @@ -262,7 +218,6 @@ def random_walk(rowptr: Tensor, col: Tensor, seed: Tensor, walk_length: int, __all__ = [ 'neighbor_sample', 'hetero_neighbor_sample', - 'dist_neighbor_sample', 'subgraph', 'random_walk', ] diff --git a/test/csrc/sampler/test_dist_merge_outputs.cpp b/test/csrc/sampler/test_dist_merge_outputs.cpp new file mode 100644 index 00000000..33c093d6 --- /dev/null +++ b/test/csrc/sampler/test_dist_merge_outputs.cpp @@ -0,0 +1,129 @@ +#include +#include + +#include "pyg_lib/csrc/sampler/dist_merge_outputs.h" +#include "pyg_lib/csrc/utils/types.h" + +TEST(DistMergeOutputsTest, BasicAssertions) { + auto options = at::TensorOptions().dtype(at::kLong); + + // seed = {0, 1, 2, 3} + const std::vector node_ids = { + at::tensor({2, 7, 8}, options), + at::tensor({0, 1, 4, 5, 6}, options), + at::tensor({3, 9, 10}, options), + }; + const std::vector edge_ids = { + at::tensor({17, 18}, options), + at::tensor({14, 15, 16}, options), + at::tensor({19, 20}, options), + }; + + const std::vector> cumsum_neighbors_per_node = { + {1, 3}, {2, 4, 5}, {1, 3}}; + const std::vector partition_ids = {1, 1, 0, 2}; + const std::vector partition_orders = {0, 1, 0, 0}; + + auto out = pyg::sampler::merge_sampler_outputs( + /*node_ids=*/node_ids, + /*edge_ids=*/edge_ids, + /*cumsum_neighbors_per_node=*/cumsum_neighbors_per_node, + /*partition_ids=*/partition_ids, + /*partition_orders=*/partition_orders, + /*num_partitions=*/3, + /*num_neighbors=*/2, + /*batch=*/c10::nullopt, + /*disjoint=*/false); + + auto expected_nodes = at::tensor({4, 5, 6, 7, 8, 9, 10}, options); + EXPECT_TRUE(at::equal(std::get<0>(out), expected_nodes)); + + auto expected_edges = at::tensor({14, 15, 16, 17, 18, 19, 20}, options); + EXPECT_TRUE(at::equal(std::get<1>(out), expected_edges)); + + const std::vector expected_sampled_neighbors_per_node = {2, 1, 2, 2}; + EXPECT_EQ(std::get<3>(out), expected_sampled_neighbors_per_node); +} + +TEST(DistMergeOutputsAllNeighborsTest, BasicAssertions) { + auto options = at::TensorOptions().dtype(at::kLong); + + // seed = {0, 1, 2, 3} + const std::vector node_ids = { + at::tensor({2, 7, 8}, options), + at::tensor({0, 1, 4, 5, 6}, options), + at::tensor({3, 9, 10, 11}, options), + }; + const std::vector edge_ids = { + at::tensor({17, 18}, options), + at::tensor({14, 15, 16}, options), + at::tensor({19, 20, 21}, options), + }; + + const std::vector> cumsum_neighbors_per_node = { + {1, 3}, {2, 4, 5}, {1, 4}}; + const std::vector partition_ids = {1, 1, 0, 2}; + const std::vector partition_orders = {0, 1, 0, 0}; + + auto out = pyg::sampler::merge_sampler_outputs( + /*node_ids=*/node_ids, + /*edge_ids=*/edge_ids, + /*cumsum_neighbors_per_node=*/cumsum_neighbors_per_node, + /*partition_ids=*/partition_ids, + /*partition_orders=*/partition_orders, + /*num_partitions=*/3, + /*num_neighbors=*/-1, + /*batch=*/c10::nullopt, + /*disjoint=*/false); + + auto expected_nodes = at::tensor({4, 5, 6, 7, 8, 9, 10, 11}, options); + EXPECT_TRUE(at::equal(std::get<0>(out), expected_nodes)); + + auto expected_edges = at::tensor({14, 15, 16, 17, 18, 19, 20, 21}, options); + EXPECT_TRUE(at::equal(std::get<1>(out), expected_edges)); + + const std::vector expected_sampled_neighbors_per_node = {2, 1, 2, 3}; + EXPECT_EQ(std::get<3>(out), expected_sampled_neighbors_per_node); +} + +TEST(DistDisjointMergeOutputsTest, BasicAssertions) { + auto options = at::TensorOptions().dtype(at::kLong); + + // seed = {0, 1, 2, 3} + const std::vector node_ids = { + at::tensor({2, 7, 8}, options), + at::tensor({0, 1, 4, 5, 6}, options), + at::tensor({3, 9, 10}, options), + }; + const std::vector edge_ids = { + at::tensor({17, 18}, options), + at::tensor({14, 15, 16}, options), + at::tensor({19, 20}, options), + }; + const auto batch = at::tensor({0, 1, 2, 3}, options); + + const std::vector> cumsum_neighbors_per_node = { + {1, 3}, {2, 4, 5}, {1, 3}}; + const std::vector partition_ids = {1, 1, 0, 2}; + const std::vector partition_orders = {0, 1, 0, 0}; + + auto out = pyg::sampler::merge_sampler_outputs( + /*node_ids=*/node_ids, + /*edge_ids=*/edge_ids, + /*cumsum_neighbors_per_node=*/cumsum_neighbors_per_node, + /*partition_ids=*/partition_ids, + /*partition_orders=*/partition_orders, + /*num_partitions=*/3, + /*num_neighbors=*/2, + /*batch=*/batch, + /*disjoint=*/true); + + auto expected_nodes = at::tensor({4, 5, 6, 7, 8, 9, 10}, options); + EXPECT_TRUE(at::equal(std::get<0>(out), expected_nodes)); + + auto expected_batch = at::tensor({0, 0, 1, 2, 2, 3, 3}, options); + EXPECT_TRUE(at::equal(std::get<2>(out).value(), expected_batch)); + + const std::vector expected_sampled_neighbors_per_node = {2, 1, 2, 2}; + EXPECT_EQ(std::get<3>(out), expected_sampled_neighbors_per_node); +}