From f7ef720f01482399dcb230fe60e62eb5cb9cf942 Mon Sep 17 00:00:00 2001 From: kgajdamo Date: Mon, 21 Aug 2023 11:51:07 +0000 Subject: [PATCH 1/7] add support for dist sampler --- .../csrc/sampler/cpu/dist_neighbor_kernel.cpp | 261 +++++++++++ .../csrc/sampler/cpu/dist_neighbor_kernel.h | 31 ++ pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp | 170 +++++-- pyg_lib/csrc/sampler/cpu/neighbor_kernel.h | 8 +- pyg_lib/csrc/sampler/dist_neighbor.cpp | 86 ++++ pyg_lib/csrc/sampler/dist_neighbor.h | 42 ++ pyg_lib/csrc/sampler/neighbor.cpp | 28 +- pyg_lib/csrc/sampler/neighbor.h | 8 +- pyg_lib/sampler/__init__.py | 116 ++++- test/csrc/sampler/test_dist_neighbor.cpp | 434 ++++++++++++++++++ test/csrc/sampler/test_neighbor.cpp | 15 +- 11 files changed, 1111 insertions(+), 88 deletions(-) create mode 100644 pyg_lib/csrc/sampler/cpu/dist_neighbor_kernel.cpp create mode 100644 pyg_lib/csrc/sampler/cpu/dist_neighbor_kernel.h create mode 100644 pyg_lib/csrc/sampler/dist_neighbor.cpp create mode 100644 pyg_lib/csrc/sampler/dist_neighbor.h create mode 100644 test/csrc/sampler/test_dist_neighbor.cpp diff --git a/pyg_lib/csrc/sampler/cpu/dist_neighbor_kernel.cpp b/pyg_lib/csrc/sampler/cpu/dist_neighbor_kernel.cpp new file mode 100644 index 00000000..61f30e2d --- /dev/null +++ b/pyg_lib/csrc/sampler/cpu/dist_neighbor_kernel.cpp @@ -0,0 +1,261 @@ +#include +#include +#include + +#include "parallel_hashmap/phmap.h" + +#include "pyg_lib/csrc/sampler/cpu/mapper.h" +#include "pyg_lib/csrc/utils/cpu/convert.h" +#include "pyg_lib/csrc/utils/types.h" + +namespace pyg { +namespace sampler { + +namespace { + +template +std::tuple get_sampled_edges( + std::vector sampled_rows, + std::vector sampled_cols, + const bool csc = false) { + const auto row = pyg::utils::from_vector(sampled_rows); + const auto col = pyg::utils::from_vector(sampled_cols); + + if (!csc) { + return std::make_tuple(row, col); + } else { + return std::make_tuple(col, row); + } +} + +template +std::tuple relabel( + const at::Tensor& 
seed, + const at::Tensor& sampled_nodes_with_dupl, + const std::vector& sampled_nbrs_per_node, + const int64_t num_nodes, + const c10::optional& batch, + const bool csc) { + if (disjoint) { + TORCH_CHECK(batch.has_value(), + "Batch needs to be specified to create disjoint subgraphs"); + TORCH_CHECK(batch.value().is_contiguous(), "Non-contiguous 'batch'"); + TORCH_CHECK(batch.value().numel() == sampled_nodes_with_dupl.numel(), + "Each node must belong to a subgraph.'"); + } + TORCH_CHECK(seed.is_contiguous(), "Non-contiguous 'seed'"); + TORCH_CHECK(sampled_nodes_with_dupl.is_contiguous(), + "Non-contiguous 'sampled_nodes_with_dupl'"); + + at::Tensor out_row, out_col; + + AT_DISPATCH_INTEGRAL_TYPES( + seed.scalar_type(), "relabel_neighborhood_kernel", [&] { + typedef std::pair pair_scalar_t; + typedef std::conditional_t node_t; + + const auto sampled_nodes_data = + sampled_nodes_with_dupl.data_ptr(); + const auto batch_data = + !disjoint ? nullptr : batch.value().data_ptr(); + + std::vector sampled_rows; + std::vector sampled_cols; + auto mapper = Mapper(num_nodes); + + const auto seed_data = seed.data_ptr(); + if constexpr (!disjoint) { + mapper.fill(seed); + } else { + for (size_t i = 0; i < seed.numel(); ++i) { + mapper.insert({i, seed_data[i]}); + } + } + size_t begin = 0, end = 0; + for (auto i = 0; i < sampled_nbrs_per_node.size(); i++) { + end += sampled_nbrs_per_node[i]; + + for (auto j = begin; j < end; j++) { + std::pair res; + if constexpr (!disjoint) + res = mapper.insert(sampled_nodes_data[j]); + else + res = mapper.insert({batch_data[j], sampled_nodes_data[j]}); + sampled_rows.push_back(i); + sampled_cols.push_back(res.first); + } + + begin = end; + } + + std::tie(out_row, out_col) = + get_sampled_edges(sampled_rows, sampled_cols, csc); + }); + + return std::make_tuple(out_row, out_col); +} + +template +std::tuple, c10::Dict> +relabel(const std::vector& node_types, + const std::vector& edge_types, + const c10::Dict& seed_dict, + const c10::Dict& 
sampled_nodes_with_dupl_dict, + const c10::Dict>& + sampled_nbrs_per_node_dict, + const c10::Dict num_nodes_dict, + const c10::optional>& batch_dict, + const bool csc) { + if (disjoint) { + TORCH_CHECK(batch_dict.has_value(), + "Batch needs to be specified to create disjoint subgraphs"); + for (const auto& kv : batch_dict.value()) { + const at::Tensor& batch = kv.value(); + const at::Tensor& sampled_nodes_with_dupl = kv.value(); + TORCH_CHECK(batch.is_contiguous(), "Non-contiguous 'batch'"); + TORCH_CHECK(batch.numel() == sampled_nodes_with_dupl.numel(), + "Each node must belong to a subgraph.'"); + } + } + for (const auto& kv : seed_dict) { + const at::Tensor& seed = kv.value(); + TORCH_CHECK(seed.is_contiguous(), "Non-contiguous 'seed'"); + } + for (const auto& kv : sampled_nodes_with_dupl_dict) { + const at::Tensor& sampled_nodes_with_dupl = kv.value(); + TORCH_CHECK(sampled_nodes_with_dupl.is_contiguous(), + "Non-contiguous 'sampled_nodes_with_dupl'"); + } + + c10::Dict out_row_dict, out_col_dict; + + AT_DISPATCH_INTEGRAL_TYPES( + seed_dict.begin()->value().scalar_type(), + "hetero_relabel_neighborhood_kernel", [&] { + typedef std::pair pair_scalar_t; + typedef std::conditional_t node_t; + + phmap::flat_hash_map sampled_nodes_data_dict; + phmap::flat_hash_map batch_data_dict; + phmap::flat_hash_map> + sampled_rows_dict; + phmap::flat_hash_map> + sampled_cols_dict; + + phmap::flat_hash_map> mapper_dict; + phmap::flat_hash_map> slice_dict; + + for (const auto& k : edge_types) { + // Initialize empty vectors. + sampled_rows_dict[k]; + sampled_cols_dict[k]; + } + for (const auto& k : node_types) { + sampled_nodes_data_dict.insert( + {k, sampled_nodes_with_dupl_dict.at(k).data_ptr()}); + const auto N = num_nodes_dict.at(k) > 0 ? 
num_nodes_dict.at(k) : 0; + mapper_dict.insert({k, Mapper(N)}); + slice_dict[k] = {0, 0}; + if constexpr (disjoint) { + batch_data_dict.insert( + {k, batch_dict.value().at(k).data_ptr()}); + } + } + for (const auto& kv : seed_dict) { + const at::Tensor& seed = kv.value(); + if constexpr (!disjoint) { + mapper_dict.at(kv.key()).fill(seed); + } else { + auto& mapper = mapper_dict.at(kv.key()); + const auto seed_data = seed.data_ptr(); + for (size_t i = 0; i < seed.numel(); ++i) { + mapper.insert({i, seed_data[i]}); + } + } + } + for (const auto& k : edge_types) { + const auto src = !csc ? std::get<0>(k) : std::get<2>(k); + const auto dst = !csc ? std::get<2>(k) : std::get<0>(k); + for (auto i = 0; i < sampled_nbrs_per_node_dict.at(dst).size(); i++) { + auto& dst_mapper = mapper_dict.at(dst); + auto& dst_sampled_nodes_data = sampled_nodes_data_dict.at(dst); + slice_dict.at(dst).second += sampled_nbrs_per_node_dict.at(dst)[i]; + size_t begin, end; + std::tie(begin, end) = slice_dict.at(dst); + + for (auto j = begin; j < end; j++) { + std::pair res; + if constexpr (!disjoint) { + res = dst_mapper.insert(dst_sampled_nodes_data[j]); + } else { + res = dst_mapper.insert( + {batch_data_dict.at(dst)[j], dst_sampled_nodes_data[j]}); + } + sampled_rows_dict.at(k).push_back(i); + sampled_cols_dict.at(k).push_back(res.first); + } + slice_dict.at(dst).first = end; + } + } + + for (const auto& k : edge_types) { + const auto edges = get_sampled_edges( + sampled_rows_dict.at(k), sampled_cols_dict.at(k), csc); + out_row_dict.insert(to_rel_type(k), std::get<0>(edges)); + out_col_dict.insert(to_rel_type(k), std::get<1>(edges)); + } + }); + + return std::make_tuple(out_row_dict, out_col_dict); +} + +#define DISPATCH_RELABEL(disjoint, ...) 
\ + if (disjoint) \ + return relabel(__VA_ARGS__); \ + if (!disjoint) \ + return relabel(__VA_ARGS__); + +} // namespace + +std::tuple relabel_neighborhood_kernel( + const at::Tensor& seed, + const at::Tensor& sampled_nodes_with_dupl, + const std::vector& sampled_nbrs_per_node, + const int64_t num_nodes, + const c10::optional& batch, + bool csc, + bool disjoint) { + DISPATCH_RELABEL(disjoint, seed, sampled_nodes_with_dupl, + sampled_nbrs_per_node, num_nodes, batch, csc); +} + +std::tuple, c10::Dict> +hetero_relabel_neighborhood_kernel( + const std::vector& node_types, + const std::vector& edge_types, + const c10::Dict& seed_dict, + const c10::Dict& sampled_nodes_with_dupl_dict, + const c10::Dict>& + sampled_nbrs_per_node_dict, + const c10::Dict num_nodes_dict, + const c10::optional>& batch_dict, + bool csc, + bool disjoint) { + c10::Dict out_row_dict, out_col_dict; + DISPATCH_RELABEL(disjoint, node_types, edge_types, seed_dict, + sampled_nodes_with_dupl_dict, sampled_nbrs_per_node_dict, + num_nodes_dict, batch_dict, csc); +} + +TORCH_LIBRARY_IMPL(pyg, CPU, m) { + m.impl(TORCH_SELECTIVE_NAME("pyg::relabel_neighborhood"), + TORCH_FN(relabel_neighborhood_kernel)); +} + +TORCH_LIBRARY_IMPL(pyg, BackendSelect, m) { + m.impl(TORCH_SELECTIVE_NAME("pyg::hetero_relabel_neighborhood"), + TORCH_FN(hetero_relabel_neighborhood_kernel)); +} + +} // namespace sampler +} // namespace pyg diff --git a/pyg_lib/csrc/sampler/cpu/dist_neighbor_kernel.h b/pyg_lib/csrc/sampler/cpu/dist_neighbor_kernel.h new file mode 100644 index 00000000..9df244fd --- /dev/null +++ b/pyg_lib/csrc/sampler/cpu/dist_neighbor_kernel.h @@ -0,0 +1,31 @@ +#include +#include +#include "pyg_lib/csrc/utils/types.h" + +namespace pyg { +namespace sampler { + +std::tuple relabel_neighborhood_kernel( + const at::Tensor& seed, + const at::Tensor& sampled_nodes_with_dupl, + const std::vector& sampled_nbrs_per_node, + const int64_t num_nodes, + const c10::optional& batch, + bool csc, + bool disjoint); + +std::tuple, 
c10::Dict> +hetero_relabel_neighborhood_kernel( + const std::vector& node_types, + const std::vector& edge_types, + const c10::Dict& seed_dict, + const c10::Dict& sampled_nodes_with_dupl_dict, + const c10::Dict>& + sampled_nbrs_per_node_dict, + const c10::Dict num_nodes_dict, + const c10::optional>& batch_dict, + bool csc, + bool disjoint); + +} // namespace sampler +} // namespace pyg diff --git a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp index 9e5cde4e..2571e9bd 100644 --- a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp +++ b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp @@ -25,7 +25,8 @@ template + bool save_edge_ids, + bool distributed> class NeighborSampler { public: NeighborSampler(const scalar_t* rowptr, @@ -189,6 +190,15 @@ class NeighborSampler { const auto global_dst_node_value = col_[edge_id]; const auto global_dst_node = to_node_t(global_dst_node_value, global_src_node); + + if constexpr (distributed) { + out_global_dst_nodes.push_back(global_dst_node); + if (save_edge_ids) { + sampled_edge_ids_.push_back(edge_id); + } + return; + } + const auto res = dst_mapper.insert(global_dst_node); if (res.second) { // not yet sampled. 
out_global_dst_nodes.push_back(global_dst_node); @@ -216,12 +226,17 @@ class NeighborSampler { // Homogeneous neighbor sampling /////////////////////////////////////////////// -template +template std::tuple, std::vector, + std::vector, std::vector> sample(const at::Tensor& rowptr, const at::Tensor& col, @@ -229,6 +244,7 @@ sample(const at::Tensor& rowptr, const std::vector& num_neighbors, const c10::optional& time, const c10::optional& seed_time, + const c10::optional& batch, const bool csc, const std::string temporal_strategy) { TORCH_CHECK(!time.has_value() || disjoint, @@ -249,6 +265,9 @@ sample(const at::Tensor& rowptr, c10::optional out_edge_id = c10::nullopt; std::vector num_sampled_nodes_per_hop; std::vector num_sampled_edges_per_hop; + std::vector cumm_sum_sampled_nbrs_per_node = + distributed ? std::vector(1, seed.size(0)) + : std::vector(); AT_DISPATCH_INTEGRAL_TYPES(seed.scalar_type(), "sample_kernel", [&] { typedef std::pair pair_scalar_t; @@ -256,7 +275,7 @@ sample(const at::Tensor& rowptr, // TODO(zeyuan): Do not force int64_t for time type. 
typedef int64_t temporal_t; typedef NeighborSampler + return_edge_id, distributed> NeighborSamplerImpl; pyg::random::RandintEngine generator; @@ -273,9 +292,17 @@ sample(const at::Tensor& rowptr, sampled_nodes = pyg::utils::to_vector(seed); mapper.fill(seed); } else { - for (size_t i = 0; i < seed.numel(); ++i) { - sampled_nodes.push_back({i, seed_data[i]}); - mapper.insert({i, seed_data[i]}); + if (batch.has_value()) { + const auto batch_data = batch.value().data_ptr(); + for (size_t i = 0; i < seed.numel(); ++i) { + sampled_nodes.push_back({batch_data[i], seed_data[i]}); + mapper.insert({batch_data[i], seed_data[i]}); + } + } else { + for (size_t i = 0; i < seed.numel(); ++i) { + sampled_nodes.push_back({i, seed_data[i]}); + mapper.insert({i, seed_data[i]}); + } } if (seed_time.has_value()) { const auto seed_time_data = seed_time.value().data_ptr(); @@ -301,6 +328,8 @@ sample(const at::Tensor& rowptr, sampler.uniform_sample(/*global_src_node=*/sampled_nodes[i], /*local_src_node=*/i, count, mapper, generator, /*out_global_dst_nodes=*/sampled_nodes); + if constexpr (distributed) + cumm_sum_sampled_nbrs_per_node.push_back(sampled_nodes.size()); } } else if constexpr (!std::is_scalar::value) { // Temporal: const auto time_data = time.value().data_ptr(); @@ -311,6 +340,8 @@ sample(const at::Tensor& rowptr, seed_times[batch_idx], time_data, mapper, generator, /*out_global_dst_nodes=*/sampled_nodes); + if constexpr (distributed) + cumm_sum_sampled_nbrs_per_node.push_back(sampled_nodes.size()); } } begin = end, end = sampled_nodes.size(); @@ -329,12 +360,17 @@ sample(const at::Tensor& rowptr, }); return std::make_tuple(out_row, out_col, out_node_id, out_edge_id, - num_sampled_nodes_per_hop, num_sampled_edges_per_hop); + num_sampled_nodes_per_hop, num_sampled_edges_per_hop, + cumm_sum_sampled_nbrs_per_node); } // Heterogeneous neighbor sampling ///////////////////////////////////////////// -template +template std::tuple, c10::Dict, c10::Dict, @@ -398,7 +434,7 @@ 
sample(const std::vector& node_types, typedef std::conditional_t node_t; typedef int64_t temporal_t; typedef NeighborSampler + return_edge_id, distributed> NeighborSamplerImpl; pyg::random::RandintEngine generator; @@ -596,39 +632,72 @@ sample(const std::vector& node_types, // Dispatcher ////////////////////////////////////////////////////////////////// -#define DISPATCH_SAMPLE(replace, directed, disjount, return_edge_id, ...) \ - if (replace && directed && disjoint && return_edge_id) \ - return sample(__VA_ARGS__); \ - if (replace && directed && disjoint && !return_edge_id) \ - return sample(__VA_ARGS__); \ - if (replace && directed && !disjoint && return_edge_id) \ - return sample(__VA_ARGS__); \ - if (replace && directed && !disjoint && !return_edge_id) \ - return sample(__VA_ARGS__); \ - if (replace && !directed && disjoint && return_edge_id) \ - return sample(__VA_ARGS__); \ - if (replace && !directed && disjoint && !return_edge_id) \ - return sample(__VA_ARGS__); \ - if (replace && !directed && !disjoint && return_edge_id) \ - return sample(__VA_ARGS__); \ - if (replace && !directed && !disjoint && !return_edge_id) \ - return sample(__VA_ARGS__); \ - if (!replace && directed && disjoint && return_edge_id) \ - return sample(__VA_ARGS__); \ - if (!replace && directed && disjoint && !return_edge_id) \ - return sample(__VA_ARGS__); \ - if (!replace && directed && !disjoint && return_edge_id) \ - return sample(__VA_ARGS__); \ - if (!replace && directed && !disjoint && !return_edge_id) \ - return sample(__VA_ARGS__); \ - if (!replace && !directed && disjoint && return_edge_id) \ - return sample(__VA_ARGS__); \ - if (!replace && !directed && disjoint && !return_edge_id) \ - return sample(__VA_ARGS__); \ - if (!replace && !directed && !disjoint && return_edge_id) \ - return sample(__VA_ARGS__); \ - if (!replace && !directed && !disjoint && !return_edge_id) \ - return sample(__VA_ARGS__); +#define DISPATCH_SAMPLE(replace, directed, disjoint, return_edge_id, \ + 
distributed, ...) \ + if (replace && directed && disjoint && return_edge_id && distributed) \ + return sample(__VA_ARGS__); \ + if (replace && directed && disjoint && !return_edge_id && distributed) \ + return sample(__VA_ARGS__); \ + if (replace && directed && !disjoint && return_edge_id && distributed) \ + return sample(__VA_ARGS__); \ + if (replace && directed && !disjoint && !return_edge_id && distributed) \ + return sample(__VA_ARGS__); \ + if (replace && !directed && disjoint && return_edge_id && distributed) \ + return sample(__VA_ARGS__); \ + if (replace && !directed && disjoint && !return_edge_id && distributed) \ + return sample(__VA_ARGS__); \ + if (replace && !directed && !disjoint && return_edge_id && distributed) \ + return sample(__VA_ARGS__); \ + if (replace && !directed && !disjoint && !return_edge_id && distributed) \ + return sample(__VA_ARGS__); \ + if (!replace && directed && disjoint && return_edge_id && distributed) \ + return sample(__VA_ARGS__); \ + if (!replace && directed && disjoint && !return_edge_id && distributed) \ + return sample(__VA_ARGS__); \ + if (!replace && directed && !disjoint && return_edge_id && distributed) \ + return sample(__VA_ARGS__); \ + if (!replace && directed && !disjoint && !return_edge_id && distributed) \ + return sample(__VA_ARGS__); \ + if (!replace && !directed && disjoint && return_edge_id && distributed) \ + return sample(__VA_ARGS__); \ + if (!replace && !directed && disjoint && !return_edge_id && distributed) \ + return sample(__VA_ARGS__); \ + if (!replace && !directed && !disjoint && return_edge_id && distributed) \ + return sample(__VA_ARGS__); \ + if (!replace && !directed && !disjoint && !return_edge_id && distributed) \ + return sample(__VA_ARGS__); \ + if (replace && directed && disjoint && return_edge_id && !distributed) \ + return sample(__VA_ARGS__); \ + if (replace && directed && disjoint && !return_edge_id && !distributed) \ + return sample(__VA_ARGS__); \ + if (replace && directed && 
!disjoint && return_edge_id && !distributed) \ + return sample(__VA_ARGS__); \ + if (replace && directed && !disjoint && !return_edge_id && !distributed) \ + return sample(__VA_ARGS__); \ + if (replace && !directed && disjoint && return_edge_id && !distributed) \ + return sample(__VA_ARGS__); \ + if (replace && !directed && disjoint && !return_edge_id && !distributed) \ + return sample(__VA_ARGS__); \ + if (replace && !directed && !disjoint && return_edge_id && !distributed) \ + return sample(__VA_ARGS__); \ + if (replace && !directed && !disjoint && !return_edge_id && !distributed) \ + return sample(__VA_ARGS__); \ + if (!replace && directed && disjoint && return_edge_id && !distributed) \ + return sample(__VA_ARGS__); \ + if (!replace && directed && disjoint && !return_edge_id && !distributed) \ + return sample(__VA_ARGS__); \ + if (!replace && directed && !disjoint && return_edge_id && !distributed) \ + return sample(__VA_ARGS__); \ + if (!replace && directed && !disjoint && !return_edge_id && !distributed) \ + return sample(__VA_ARGS__); \ + if (!replace && !directed && disjoint && return_edge_id && !distributed) \ + return sample(__VA_ARGS__); \ + if (!replace && !directed && disjoint && !return_edge_id && !distributed) \ + return sample(__VA_ARGS__); \ + if (!replace && !directed && !disjoint && return_edge_id && !distributed) \ + return sample(__VA_ARGS__); \ + if (!replace && !directed && !disjoint && !return_edge_id && !distributed) \ + return sample(__VA_ARGS__); } // namespace @@ -637,6 +706,7 @@ std::tuple, std::vector, + std::vector, std::vector> neighbor_sample_kernel(const at::Tensor& rowptr, const at::Tensor& col, @@ -644,14 +714,17 @@ neighbor_sample_kernel(const at::Tensor& rowptr, const std::vector& num_neighbors, const c10::optional& time, const c10::optional& seed_time, + const c10::optional& batch, bool csc, bool replace, bool directed, bool disjoint, std::string temporal_strategy, - bool return_edge_id) { - DISPATCH_SAMPLE(replace, directed, 
disjoint, return_edge_id, rowptr, col, - seed, num_neighbors, time, seed_time, csc, temporal_strategy); + bool return_edge_id, + bool distributed) { + DISPATCH_SAMPLE(replace, directed, disjoint, return_edge_id, distributed, + rowptr, col, seed, num_neighbors, time, seed_time, batch, csc, + temporal_strategy); } std::tuple, @@ -674,9 +747,10 @@ hetero_neighbor_sample_kernel( bool directed, bool disjoint, std::string temporal_strategy, - bool return_edge_id) { - DISPATCH_SAMPLE(replace, directed, disjoint, return_edge_id, node_types, - edge_types, rowptr_dict, col_dict, seed_dict, + bool return_edge_id, + bool distributed) { + DISPATCH_SAMPLE(replace, directed, disjoint, return_edge_id, distributed, + node_types, edge_types, rowptr_dict, col_dict, seed_dict, num_neighbors_dict, time_dict, seed_time_dict, csc, temporal_strategy); } diff --git a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.h b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.h index 0e0a532f..ef0b743a 100644 --- a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.h +++ b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.h @@ -10,6 +10,7 @@ std::tuple, std::vector, + std::vector, std::vector> neighbor_sample_kernel(const at::Tensor& rowptr, const at::Tensor& col, @@ -17,12 +18,14 @@ neighbor_sample_kernel(const at::Tensor& rowptr, const std::vector& num_neighbors, const c10::optional& time, const c10::optional& seed_time, + const c10::optional& batch, bool csc, bool replace, bool directed, bool disjoint, std::string temporal_strategy, - bool return_edge_id); + bool return_edge_id, + bool distributed); std::tuple, c10::Dict, @@ -44,7 +47,8 @@ hetero_neighbor_sample_kernel( bool directed, bool disjoint, std::string temporal_strategy, - bool return_edge_id); + bool return_edge_id, + bool distributed); } // namespace sampler } // namespace pyg diff --git a/pyg_lib/csrc/sampler/dist_neighbor.cpp b/pyg_lib/csrc/sampler/dist_neighbor.cpp new file mode 100644 index 00000000..68ddee65 --- /dev/null +++ 
b/pyg_lib/csrc/sampler/dist_neighbor.cpp @@ -0,0 +1,86 @@ +#include "dist_neighbor.h" + +#include +#include + +#include "pyg_lib/csrc/utils/check.h" + +namespace pyg { +namespace sampler { + +std::tuple relabel_neighborhood( + const at::Tensor& seed, + const at::Tensor& sampled_nodes_with_dupl, + const std::vector& sampled_nbrs_per_node, + const int64_t num_nodes, + const c10::optional& batch, + bool csc, + bool disjoint) { + at::TensorArg seed_t{seed, "seed", 1}; + at::TensorArg sampled_nodes_with_dupl_t{sampled_nodes_with_dupl, + "sampled_nodes_with_dupl", 1}; + + at::CheckedFrom c = "relabel_neighborhood"; + at::checkAllDefined(c, {sampled_nodes_with_dupl_t, seed_t}); + at::checkAllSameType(c, {sampled_nodes_with_dupl_t, seed_t}); + + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("pyg::relabel_neighborhood", "") + .typed(); + return op.call(seed, sampled_nodes_with_dupl, sampled_nbrs_per_node, + num_nodes, batch, csc, disjoint); +} + +std::tuple, c10::Dict> +hetero_relabel_neighborhood( + const std::vector& node_types, + const std::vector& edge_types, + const c10::Dict& seed_dict, + const c10::Dict& sampled_nodes_with_dupl_dict, + const c10::Dict>& + sampled_nbrs_per_node_dict, + const c10::Dict num_nodes_dict, + const c10::optional>& batch_dict, + bool csc, + bool disjoint) { + std::vector seed_dict_args; + std::vector sampled_nodes_with_dupl_dict_args; + pyg::utils::fill_tensor_args(seed_dict_args, seed_dict, "seed_dict", 0); + pyg::utils::fill_tensor_args(sampled_nodes_with_dupl_dict_args, + sampled_nodes_with_dupl_dict, + "sampled_nodes_with_dupl_dict", 0); + at::CheckedFrom c{"hetero_relabel_neighborhood"}; + + at::checkAllDefined(c, seed_dict_args); + at::checkAllDefined(c, sampled_nodes_with_dupl_dict_args); + at::checkSameType(c, seed_dict_args[0], sampled_nodes_with_dupl_dict_args[0]); + + static auto op = + c10::Dispatcher::singleton() + .findSchemaOrThrow("pyg::hetero_relabel_neighborhood", "") + .typed(); + return 
op.call(node_types, edge_types, seed_dict, + sampled_nodes_with_dupl_dict, sampled_nbrs_per_node_dict, + num_nodes_dict, batch_dict, csc, disjoint); +} + +TORCH_LIBRARY_FRAGMENT(pyg, m) { + m.def( + TORCH_SELECTIVE_SCHEMA("pyg::relabel_neighborhood(Tensor seed, Tensor " + "sampled_nodes_with_dupl, int[] " + "sampled_nbrs_per_node, int num_nodes, Tensor? " + "batch = None, bool csc = False, bool " + "disjoint = False) " + "-> (Tensor, Tensor)")); + m.def(TORCH_SELECTIVE_SCHEMA( + "pyg::hetero_relabel_neighborhood(str[] node_types, (str, str, str)[] " + "edge_types, Dict(str, Tensor) seed_dict, Dict(str, Tensor) " + "sampled_nodes_with_dupl_dict, Dict(str, int[]) " + "sampled_nbrs_per_node_dict, Dict(str, int) num_nodes_dict, Dict(str, " + "Tensor)? batch_dict = None, bool csc = False, bool " + "disjoint = False) " + "-> (Dict(str, Tensor), Dict(str, Tensor))")); +} + +} // namespace sampler +} // namespace pyg diff --git a/pyg_lib/csrc/sampler/dist_neighbor.h b/pyg_lib/csrc/sampler/dist_neighbor.h new file mode 100644 index 00000000..51e520cd --- /dev/null +++ b/pyg_lib/csrc/sampler/dist_neighbor.h @@ -0,0 +1,42 @@ +#pragma once + +#include +#include "pyg_lib/csrc/macros.h" +#include "pyg_lib/csrc/utils/types.h" + +namespace pyg { +namespace sampler { + +// Relabel global indices of the `sampled_nodes_with_dupl` to the local +// subtree/subgraph indices. +// Returns (row, col). +PYG_API +std::tuple relabel_neighborhood( + const at::Tensor& seed, + const at::Tensor& sampled_nodes_with_dupl, + const std::vector& sampled_nbrs_per_node, + const int64_t num_nodes, + const c10::optional& batch = c10::nullopt, + bool csc = false, + bool disjoint = false); + +// Relabel global indices of the `sampled_nodes_with_dupl` to the local +// subtree/subgraph indices in the heterogeneous graph. +// Returns src and dst indices for a given edge type as a (row_dict, col_dict). 
+PYG_API +std::tuple, c10::Dict> +hetero_relabel_neighborhood( + const std::vector& node_types, + const std::vector& edge_types, + const c10::Dict& seed_dict, + const c10::Dict& sampled_nodes_with_dupl_dict, + const c10::Dict>& + sampled_nbrs_per_node_dict, + const c10::Dict num_nodes_dict, + const c10::optional>& batch_dict = + c10::nullopt, + bool csc = false, + bool disjoint = false); + +} // namespace sampler +} // namespace pyg diff --git a/pyg_lib/csrc/sampler/neighbor.cpp b/pyg_lib/csrc/sampler/neighbor.cpp index f0550b78..ad3a4895 100644 --- a/pyg_lib/csrc/sampler/neighbor.cpp +++ b/pyg_lib/csrc/sampler/neighbor.cpp @@ -13,6 +13,7 @@ std::tuple, std::vector, + std::vector, std::vector> neighbor_sample(const at::Tensor& rowptr, const at::Tensor& col, @@ -20,12 +21,14 @@ neighbor_sample(const at::Tensor& rowptr, const std::vector& num_neighbors, const c10::optional& time, const c10::optional& seed_time, + const c10::optional& batch, bool csc, bool replace, bool directed, bool disjoint, std::string temporal_strategy, - bool return_edge_id) { + bool return_edge_id, + bool distributed) { at::TensorArg rowptr_t{rowptr, "rowtpr", 1}; at::TensorArg col_t{col, "col", 1}; at::TensorArg seed_t{seed, "seed", 1}; @@ -37,9 +40,9 @@ neighbor_sample(const at::Tensor& rowptr, static auto op = c10::Dispatcher::singleton() .findSchemaOrThrow("pyg::neighbor_sample", "") .typed(); - return op.call(rowptr, col, seed, num_neighbors, time, seed_time, csc, - replace, directed, disjoint, temporal_strategy, - return_edge_id); + return op.call(rowptr, col, seed, num_neighbors, time, seed_time, batch, csc, + replace, directed, disjoint, temporal_strategy, return_edge_id, + distributed); } std::tuple, @@ -62,7 +65,8 @@ hetero_neighbor_sample( bool directed, bool disjoint, std::string temporal_strategy, - bool return_edge_id) { + bool return_edge_id, + bool distributed) { TORCH_CHECK(rowptr_dict.size() == col_dict.size(), "Number of edge types in 'rowptr_dict' and 'col_dict' must match") 
@@ -88,16 +92,19 @@ hetero_neighbor_sample( .typed(); return op.call(node_types, edge_types, rowptr_dict, col_dict, seed_dict, num_neighbors_dict, time_dict, seed_time_dict, csc, replace, - directed, disjoint, temporal_strategy, return_edge_id); + directed, disjoint, temporal_strategy, return_edge_id, + distributed); } TORCH_LIBRARY_FRAGMENT(pyg, m) { m.def(TORCH_SELECTIVE_SCHEMA( "pyg::neighbor_sample(Tensor rowptr, Tensor col, Tensor seed, int[] " - "num_neighbors, Tensor? time = None, Tensor? seed_time = None, bool csc " + "num_neighbors, Tensor? time = None, Tensor? seed_time = None, Tensor? " + "batch = None, bool csc " "= False, bool replace = False, bool directed = True, bool disjoint = " - "False, str temporal_strategy = 'uniform', bool return_edge_id = True) " - "-> (Tensor, Tensor, Tensor, Tensor?, int[], int[])")); + "False, str temporal_strategy = 'uniform', bool return_edge_id = True, " + "bool distributed = False) " + "-> (Tensor, Tensor, Tensor, Tensor?, int[], int[], int[])")); m.def(TORCH_SELECTIVE_SCHEMA( "pyg::hetero_neighbor_sample(str[] node_types, (str, str, str)[] " "edge_types, Dict(str, Tensor) rowptr_dict, Dict(str, Tensor) col_dict, " @@ -105,7 +112,8 @@ TORCH_LIBRARY_FRAGMENT(pyg, m) { "Dict(str, Tensor)? time_dict = None, Dict(str, Tensor)? 
seed_time_dict " "= None, bool csc = False, bool replace = False, bool directed = True, " "bool disjoint = False, str temporal_strategy = 'uniform', bool " - "return_edge_id = True) -> (Dict(str, Tensor), Dict(str, Tensor), " + "return_edge_id = True, bool distributed = False) -> (Dict(str, Tensor), " + "Dict(str, Tensor), " "Dict(str, Tensor), Dict(str, Tensor)?, Dict(str, int[]), " "Dict(str, int[]))")); } diff --git a/pyg_lib/csrc/sampler/neighbor.h b/pyg_lib/csrc/sampler/neighbor.h index 55114450..7e794878 100644 --- a/pyg_lib/csrc/sampler/neighbor.h +++ b/pyg_lib/csrc/sampler/neighbor.h @@ -16,6 +16,7 @@ std::tuple, std::vector, + std::vector, std::vector> neighbor_sample(const at::Tensor& rowptr, const at::Tensor& col, @@ -23,12 +24,14 @@ neighbor_sample(const at::Tensor& rowptr, const std::vector& num_neighbors, const c10::optional& time = c10::nullopt, const c10::optional& seed_time = c10::nullopt, + const c10::optional& batch = c10::nullopt, bool csc = false, bool replace = false, bool directed = true, bool disjoint = false, std::string strategy = "uniform", - bool return_edge_id = true); + bool return_edge_id = true, + bool distributed = false); // Recursively samples neighbors from all node indices in `seed_dict` // in the heterogeneous graph given by `(rowptr_dict, col_dict)`. 
@@ -56,7 +59,8 @@ hetero_neighbor_sample( bool directed = true, bool disjoint = false, std::string strategy = "uniform", - bool return_edge_id = true); + bool return_edge_id = true, + bool distributed = false); } // namespace sampler } // namespace pyg diff --git a/pyg_lib/sampler/__init__.py b/pyg_lib/sampler/__init__.py index 2cfea286..7a816c1a 100644 --- a/pyg_lib/sampler/__init__.py +++ b/pyg_lib/sampler/__init__.py @@ -15,13 +15,16 @@ def neighbor_sample( num_neighbors: List[int], time: Optional[Tensor] = None, seed_time: Optional[Tensor] = None, + batch: Optional[Tensor] = None, csc: bool = False, replace: bool = False, directed: bool = True, disjoint: bool = False, temporal_strategy: str = 'uniform', return_edge_id: bool = True, -) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor], List[int], List[int]]: + distributed: bool = False, +) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor], List[int], List[int], + List[int]]: r"""Recursively samples neighbors from all node indices in :obj:`seed` in the graph given by :obj:`(rowptr, col)`. @@ -48,6 +51,9 @@ def neighbor_sample( seed_time (torch.Tensor, optional): Optional values to override the timestamp for seed nodes. If not set, will use timestamps in :obj:`time` as default for seed nodes. (default: :obj:`None`) + batch (torch.Tensor, optional): Optional values to specify the + initial subgraph indices for seed nodes. If not set, will use + incremental values starting from 0. (default: :obj:`None`) csc (bool, optional): If set to :obj:`True`, assumes that the graph is given in CSC format :obj:`(colptr, row)`. (default: :obj:`False`) replace (bool, optional): If set to :obj:`True`, will sample with @@ -62,20 +68,25 @@ def neighbor_sample( return_edge_id (bool, optional): If set to :obj:`False`, will not return the indices of edges of the original graph. 
(default: :obj: `True`) + distributed (bool, optional): If set to :obj:`True`, will sample nodes + with duplicates, save information about the number of sampled + neighbors per node and will not return rows and cols. + This argument was added for the purpose of distributed training. Returns: (torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], - List[int], List[int]): + List[int], List[int], List[int]): Row indices, col indices of the returned subtree/subgraph, as well as original node indices for all nodes sampled. In addition, may return the indices of edges of the original graph. Lastly, returns information about the sampled amount of nodes and edges - per hop. + per hop and if `distributed` will return cumulative sum of the sampled + neighbors per node. """ return torch.ops.pyg.neighbor_sample(rowptr, col, seed, num_neighbors, - time, seed_time, csc, replace, + time, seed_time, batch, csc, replace, directed, disjoint, temporal_strategy, - return_edge_id) + return_edge_id, distributed) def hetero_neighbor_sample( @@ -91,9 +102,10 @@ def hetero_neighbor_sample( disjoint: bool = False, temporal_strategy: str = 'uniform', return_edge_id: bool = True, + distributed: bool = False, ) -> Tuple[Dict[EdgeType, Tensor], Dict[EdgeType, Tensor], Dict[ NodeType, Tensor], Optional[Dict[EdgeType, Tensor]], Dict[ - NodeType, List[int]], Dict[NodeType, List[int]]]: + NodeType, List[int]], Dict[EdgeType, List[int]]]: r"""Recursively samples neighbors from all node indices in :obj:`seed_dict` in the heterogeneous graph given by :obj:`(rowptr_dict, col_dict)`. 
@@ -121,21 +133,9 @@ def hetero_neighbor_sample( } out = torch.ops.pyg.hetero_neighbor_sample( - node_types, - edge_types, - rowptr_dict, - col_dict, - seed_dict, - num_neighbors_dict, - time_dict, - seed_time_dict, - csc, - replace, - directed, - disjoint, - temporal_strategy, - return_edge_id, - ) + node_types, edge_types, rowptr_dict, col_dict, seed_dict, + num_neighbors_dict, time_dict, seed_time_dict, csc, replace, directed, + disjoint, temporal_strategy, return_edge_id, distributed) (row_dict, col_dict, node_id_dict, edge_id_dict, num_nodes_per_hop_dict, num_edges_per_hop_dict) = out @@ -205,9 +205,83 @@ def random_walk(rowptr: Tensor, col: Tensor, seed: Tensor, walk_length: int, return torch.ops.pyg.random_walk(rowptr, col, seed, walk_length, p, q) +def relabel_neighborhood( + seed: Tensor, + sampled_nodes_with_dupl: Tensor, + sampled_nbrs_per_node: List[int], + num_nodes: int, + batch: Optional[Tensor] = None, + csc: bool = False, + disjoint: bool = False, +) -> Tuple[Tensor, Tensor]: + r"""Relabel global indices of the :obj:`sampled_nodes_with_dupl` to the + local subtree/subgraph indices. + + .. note:: + + For :obj:`disjoint`, the :obj:`batch` needs to be specified + and each node from :obj:`sampled_nodes_with_dupl` must be assigned + to a subgraph. + + Args: + seed (torch.Tensor): The seed node indices. + sampled_nodes_with_dupl (torch.Tensor): Sampled nodes with duplicates. + Should not include seed nodes. + sampled_nbrs_per_node (List[int]): The number of neighbors sampled by + each node from :obj:`sampled_nodes_with_dupl`. + num_nodes (int): Number of all nodes in a graph. + batch (torch.Tensor, optional): Stores information about which subgraph + the node from :obj:`sampled_nodes_with_dupl` belongs to. + Must be specified when :obj:`disjoint`. (default: :obj:`None`) + csc (bool, optional): If set to :obj:`True`, assumes that the graph is + given in CSC format :obj:`(colptr, row)`. 
(default: :obj:`False`) + disjoint (bool, optional): If set to :obj:`True` , will create disjoint + subgraphs for every seed node. (default: :obj:`False`) + + Returns: + (torch.Tensor, torch.Tensor): + Row indices, col indices of the returned subtree/subgraph. + """ + return torch.ops.pyg.relabel_neighborhood(seed, sampled_nodes_with_dupl, + sampled_nbrs_per_node, num_nodes, + batch, csc, disjoint) + + +def hetero_relabel_neighborhood( + edge_types: List[EdgeType], seed_dict: Dict[NodeType, Tensor], + sampled_nodes_with_dupl_dict: Dict[NodeType, Tensor], + sampled_nbrs_per_node_dict: Dict[NodeType, + List[int]], num_nodes_dict: Dict[NodeType, + int], + batch_dict: Optional[Dict[NodeType, Tensor]] = None, csc: bool = False, + disjoint: bool = False +) -> Tuple[Dict[EdgeType, Tensor], Dict[EdgeType, Tensor]]: + r"""Relabel global indices of the :obj:`sampled_nodes_with_dupl` to the + local subtree/subgraph indices in the heterogeneous graph. + + .. note :: + Similar to :meth:`relabel_neighborhood`, but expects a dictionary of + node types (:obj:`str`) and edge types (:obj:`Tuple[str, str, str]`) + for each non-boolean argument. + + Args: + kwargs: Arguments of :meth:`relabel_neighborhood`. 
+ """ + + src_node_types = {k[0] for k in sampled_nodes_with_dupl_dict.keys()} + dst_node_types = {k[-1] for k in sampled_nodes_with_dupl_dict.keys()} + node_types = list(src_node_types | dst_node_types) + + return torch.ops.pyg.hetero_relabel_neighborhood( + node_types, edge_types, seed_dict, sampled_nodes_with_dupl_dict, + sampled_nbrs_per_node_dict, num_nodes_dict, batch_dict, csc, disjoint) + + __all__ = [ 'neighbor_sample', 'hetero_neighbor_sample', 'subgraph', 'random_walk', + 'relabel_neighborhood', + 'hetero_relabel_neighborhood', ] diff --git a/test/csrc/sampler/test_dist_neighbor.cpp b/test/csrc/sampler/test_dist_neighbor.cpp new file mode 100644 index 00000000..72a47b01 --- /dev/null +++ b/test/csrc/sampler/test_dist_neighbor.cpp @@ -0,0 +1,434 @@ +#include +#include + +#include "pyg_lib/csrc/sampler/dist_neighbor.h" +#include "pyg_lib/csrc/sampler/neighbor.h" +#include "pyg_lib/csrc/utils/types.h" +#include "test/csrc/graph.h" + +TEST(FullDistNeighborTest, BasicAssertions) { + auto options = at::TensorOptions().dtype(at::kLong); + + int num_nodes = 6; + auto graph = cycle_graph(num_nodes, options); + auto seed = at::arange(2, 4, options); + std::vector num_neighbors = {-1}; + + auto out = pyg::sampler::neighbor_sample( + /*rowptr=*/std::get<0>(graph), + /*col=*/std::get<1>(graph), seed, num_neighbors, + /*time=*/c10::nullopt, /*seed_time=*/c10::nullopt, + /*batch=*/c10::nullopt, /*csc*/ false, /*replace=*/false, + /*directed=*/true, /*disjoint=*/false, + /*temporal_strategy=*/"uniform", /*return_edge_id=*/true, + /*distributed=*/true); + + // do not sample rows and cols + EXPECT_EQ(std::get<0>(out).numel(), 0); + EXPECT_EQ(std::get<1>(out).numel(), 0); + + // sample nodes with duplicates + auto expected_nodes = at::tensor({2, 3, 1, 3, 2, 4}, options); + EXPECT_TRUE(at::equal(std::get<2>(out), expected_nodes)); + + auto expected_edges = at::tensor({4, 5, 6, 7}, options); + EXPECT_TRUE(at::equal(std::get<3>(out).value(), expected_edges)); + + std::vector 
expected_cumm_sum_nbrs_per_node = {2, 4, 6}; + EXPECT_EQ(std::get<6>(out), expected_cumm_sum_nbrs_per_node); + + std::vector sampled_nbrs_per_node = {2, 2}; + // without seed nodes + auto sampled_nodes_with_dupl = at::tensor({1, 3, 2, 4}, options); + + // get rows and cols + auto relabel_out = pyg::sampler::relabel_neighborhood( + seed, sampled_nodes_with_dupl, sampled_nbrs_per_node, num_nodes); + + auto expected_row = at::tensor({0, 0, 1, 1}, options); + EXPECT_TRUE(at::equal(std::get<0>(relabel_out), expected_row)); + auto expected_col = at::tensor({2, 1, 0, 3}, options); + EXPECT_TRUE(at::equal(std::get<1>(relabel_out), expected_col)); + + // check if rows and cols are the same as for the classic sampling + auto non_dist_out = pyg::sampler::neighbor_sample( + /*rowptr=*/std::get<0>(graph), + /*col=*/std::get<1>(graph), seed, num_neighbors, + /*time=*/c10::nullopt, /*seed_time=*/c10::nullopt, + /*batch=*/c10::nullopt, /*csc*/ false, /*replace=*/false, + /*directed=*/true, /*disjoint=*/false, + /*temporal_strategy=*/"uniform", + /*return_edge_id=*/true, /*distributed=*/false); + + EXPECT_TRUE(at::equal(std::get<0>(relabel_out), std::get<0>(non_dist_out))); + EXPECT_TRUE(at::equal(std::get<1>(relabel_out), std::get<1>(non_dist_out))); +} + +TEST(WithoutReplacementNeighborTest, BasicAssertions) { + auto options = at::TensorOptions().dtype(at::kLong); + + int num_nodes = 6; + auto graph = cycle_graph(num_nodes, options); + auto seed = at::arange(2, 4, options); + std::vector num_neighbors = {1}; + + at::manual_seed(123456); + auto out = pyg::sampler::neighbor_sample( + /*rowptr=*/std::get<0>(graph), + /*col=*/std::get<1>(graph), seed, num_neighbors, + /*time=*/c10::nullopt, /*seed_time=*/c10::nullopt, + /*batch=*/c10::nullopt, /*csc*/ false, /*replace=*/false, + /*directed=*/true, /*disjoint=*/false, + /*temporal_strategy=*/"uniform", /*return_edge_id=*/true, + /*distributed=*/true); + + // do not sample rows and cols + EXPECT_EQ(std::get<0>(out).numel(), 0); + 
EXPECT_EQ(std::get<1>(out).numel(), 0); + + // sample nodes with duplicates + auto expected_nodes = at::tensor({2, 3, 1, 4}, options); + EXPECT_TRUE(at::equal(std::get<2>(out), expected_nodes)); + + auto expected_edges = at::tensor({4, 7}, options); + EXPECT_TRUE(at::equal(std::get<3>(out).value(), expected_edges)); + + std::vector expected_cumm_sum_nbrs_per_node = {2, 3, 4}; + EXPECT_EQ(std::get<6>(out), expected_cumm_sum_nbrs_per_node); + + std::vector sampled_nbrs_per_node = {1, 1}; + // without seed nodes + auto sampled_nodes_with_dupl = at::tensor({1, 4}, options); + + // get rows and cols + auto relabel_out = pyg::sampler::relabel_neighborhood( + seed, sampled_nodes_with_dupl, sampled_nbrs_per_node, num_nodes); + + auto expected_row = at::tensor({0, 1}, options); + EXPECT_TRUE(at::equal(std::get<0>(relabel_out), expected_row)); + auto expected_col = at::tensor({2, 3}, options); + EXPECT_TRUE(at::equal(std::get<1>(relabel_out), expected_col)); +} + +TEST(WithReplacementNeighborTest, BasicAssertions) { + auto options = at::TensorOptions().dtype(at::kLong); + + int num_nodes = 6; + auto graph = cycle_graph(num_nodes, options); + auto seed = at::arange(2, 4, options); + std::vector num_neighbors = {2}; + + at::manual_seed(123456); + auto out = pyg::sampler::neighbor_sample( + /*rowptr=*/std::get<0>(graph), /*col=*/std::get<1>(graph), seed, + num_neighbors, /*time=*/c10::nullopt, + /*seed_time=*/c10::nullopt, /*batch=*/c10::nullopt, + /*csc*/ false, /*replace=*/true, /*directed=*/true, + /*disjoint=*/false, /*temporal_strategy=*/"uniform", + /*return_edge_id=*/true, /*distributed=*/true); + + // do not sample rows and cols + EXPECT_EQ(std::get<0>(out).numel(), 0); + EXPECT_EQ(std::get<1>(out).numel(), 0); + + // sample nodes with duplicates + auto expected_nodes = at::tensor({2, 3, 1, 3, 4, 4}, options); + EXPECT_TRUE(at::equal(std::get<2>(out), expected_nodes)); + + auto expected_edges = at::tensor({4, 5, 7, 7}, options); + 
EXPECT_TRUE(at::equal(std::get<3>(out).value(), expected_edges)); + + std::vector expected_cumm_sum_nbrs_per_node = {2, 4, 6}; + EXPECT_EQ(std::get<6>(out), expected_cumm_sum_nbrs_per_node); + + std::vector sampled_nbrs_per_node = {2, 2}; + // without seed nodes + auto sampled_nodes_with_dupl = at::tensor({1, 3, 4, 4}, options); + + // get rows and cols + auto relabel_out = pyg::sampler::relabel_neighborhood( + seed, sampled_nodes_with_dupl, sampled_nbrs_per_node, num_nodes); + + auto expected_row = at::tensor({0, 0, 1, 1}, options); + EXPECT_TRUE(at::equal(std::get<0>(relabel_out), expected_row)); + auto expected_col = at::tensor({2, 1, 3, 3}, options); + EXPECT_TRUE(at::equal(std::get<1>(relabel_out), expected_col)); +} + +TEST(DistDisjointNeighborTest, BasicAssertions) { + auto options = at::TensorOptions().dtype(at::kLong); + + int num_nodes = 6; + auto graph = cycle_graph(num_nodes, options); + auto seed = at::arange(2, 4, options); + std::vector num_neighbors = {2}; + auto batch = at::tensor({0, 1}, options); + + auto out = pyg::sampler::neighbor_sample( + /*rowptr=*/std::get<0>(graph), /*col=*/std::get<1>(graph), seed, + num_neighbors, /*time=*/c10::nullopt, + /*seed_time=*/c10::nullopt, batch, /*csc*/ false, + /*replace=*/false, /*directed=*/true, /*disjoint=*/true, + /*temporal_strategy=*/"uniform", /*return_edge_id=*/true, + /*distributed=*/true); + + // do not sample rows and cols + EXPECT_EQ(std::get<0>(out).numel(), 0); + EXPECT_EQ(std::get<1>(out).numel(), 0); + + // sample nodes with duplicates + auto expected_nodes = + at::tensor({0, 2, 1, 3, 0, 1, 0, 3, 1, 2, 1, 4}, options); + EXPECT_TRUE(at::equal(std::get<2>(out), expected_nodes.view({-1, 2}))); + + auto expected_edges = at::tensor({4, 5, 6, 7}, options); + EXPECT_TRUE(at::equal(std::get<3>(out).value(), expected_edges)); + + std::vector expected_cumm_sum_nbrs_per_node = {2, 4, 6}; + EXPECT_EQ(std::get<6>(out), expected_cumm_sum_nbrs_per_node); + + std::vector sampled_nbrs_per_node = {2, 2}; + 
// without seed nodes + auto sampled_nodes_with_dupl = at::tensor({1, 3, 2, 4}, options); + auto sampled_batch = at::tensor({0, 0, 1, 1}, options); + + // get rows and cols + auto relabel_out = pyg::sampler::relabel_neighborhood( + seed, sampled_nodes_with_dupl, sampled_nbrs_per_node, num_nodes, + sampled_batch, /*csc=*/false, /*disjoint=*/true); + + auto expected_row = at::tensor({0, 0, 1, 1}, options); + EXPECT_TRUE(at::equal(std::get<0>(relabel_out), expected_row)); + auto expected_col = at::tensor({2, 3, 4, 5}, options); + EXPECT_TRUE(at::equal(std::get<1>(relabel_out), expected_col)); + + // check if rows and cols are the same as for the classic sampling + auto non_dist_out = pyg::sampler::neighbor_sample( + /*rowptr=*/std::get<0>(graph), + /*col=*/std::get<1>(graph), seed, num_neighbors, + /*time=*/c10::nullopt, /*seed_time=*/c10::nullopt, batch, /*csc*/ false, + /*replace=*/false, + /*directed=*/true, /*disjoint=*/true, + /*temporal_strategy=*/"uniform", + /*return_edge_id=*/true, /*distributed=*/false); + + EXPECT_TRUE(at::equal(std::get<0>(relabel_out), std::get<0>(non_dist_out))); + EXPECT_TRUE(at::equal(std::get<1>(relabel_out), std::get<1>(non_dist_out))); +} + +TEST(DistTemporalNeighborTest, BasicAssertions) { + auto options = at::TensorOptions().dtype(at::kLong); + + int num_nodes = 6; + auto graph = cycle_graph(num_nodes, options); + auto rowptr = std::get<0>(graph); + auto col = std::get<1>(graph); + + auto seed = at::arange(2, 4, options); + std::vector num_neighbors = {2}; + + // Time is equal to node ID ... + auto time = at::arange(6, options); + // ... 
so we need to sort the column vector by time/node ID: + col = std::get<0>(at::sort(col.view({-1, 2}), /*dim=*/1)).flatten(); + + auto out = pyg::sampler::neighbor_sample( + rowptr, col, seed, num_neighbors, time, + /*seed_time=*/c10::nullopt, /*batch=*/c10::nullopt, + /*csc*/ false, /*replace=*/false, /*directed=*/true, + /*disjoint=*/true, /*temporal_strategy=*/"uniform", + /*return_edge_id=*/true, /*distributed=*/true); + + // do not sample rows and cols + EXPECT_EQ(std::get<0>(out).numel(), 0); + EXPECT_EQ(std::get<1>(out).numel(), 0); + + // sample nodes with duplicates + auto expected_nodes = at::tensor({0, 2, 1, 3, 0, 1, 1, 2}, options); + EXPECT_TRUE(at::equal(std::get<2>(out), expected_nodes.view({-1, 2}))); + + auto expected_edges = at::tensor({4, 6}, options); + EXPECT_TRUE(at::equal(std::get<3>(out).value(), expected_edges)); + + std::vector expected_cumm_sum_nbrs_per_node = {2, 3, 4}; + EXPECT_EQ(std::get<6>(out), expected_cumm_sum_nbrs_per_node); + + std::vector sampled_nbrs_per_node = {1, 1}; + // without seed nodes + auto sampled_nodes_with_dupl = at::tensor({1, 2}, options); + auto sampled_batch = at::tensor({0, 1}, options); + + // get rows and cols + auto relabel_out = pyg::sampler::relabel_neighborhood( + seed, sampled_nodes_with_dupl, sampled_nbrs_per_node, num_nodes, + sampled_batch, /*csc=*/false, /*disjoint=*/true); + + auto expected_row = at::tensor({0, 1}, options); + EXPECT_TRUE(at::equal(std::get<0>(relabel_out), expected_row)); + auto expected_col = at::tensor({2, 3}, options); + EXPECT_TRUE(at::equal(std::get<1>(relabel_out), expected_col)); + + // check if rows and cols are the same as for the classic sampling + auto non_dist_out = pyg::sampler::neighbor_sample( + rowptr, col, seed, num_neighbors, time, /*seed_time=*/c10::nullopt, + /*batch=*/c10::nullopt, /*csc*/ false, + /*replace=*/false, /*directed=*/true, + /*disjoint=*/true, /*temporal_strategy=*/"uniform", + /*return_edge_id=*/true, /*distributed=*/false); + + 
EXPECT_TRUE(at::equal(std::get<0>(relabel_out), std::get<0>(non_dist_out))); + EXPECT_TRUE(at::equal(std::get<1>(relabel_out), std::get<1>(non_dist_out))); +} + +TEST(DistHeteroNeighborTest, BasicAssertions) { + auto options = at::TensorOptions().dtype(at::kLong); + + int num_nodes = 6; + auto graph = cycle_graph(num_nodes, options); + const auto node_key = "paper"; + const auto edge_key = std::make_tuple("paper", "to", "paper"); + const auto rel_key = "paper__to__paper"; + std::vector node_types = {node_key}; + std::vector edge_types = {edge_key}; + c10::Dict rowptr_dict; + rowptr_dict.insert(rel_key, std::get<0>(graph)); + c10::Dict col_dict; + col_dict.insert(rel_key, std::get<1>(graph)); + c10::Dict seed_dict; + seed_dict.insert(node_key, at::arange(2, 4, options)); + std::vector num_neighbors = {2}; + c10::Dict> num_neighbors_dict; + num_neighbors_dict.insert(rel_key, num_neighbors); + c10::Dict num_nodes_dict; + num_nodes_dict.insert(node_key, num_nodes); + + c10::Dict sampled_nodes_with_dupl_dict; + c10::Dict> sampled_nbrs_per_node_dict; + sampled_nodes_with_dupl_dict.insert(node_key, + at::tensor({1, 3, 2, 4}, options)); + sampled_nbrs_per_node_dict.insert(node_key, std::vector(2, 2)); + // get rows and cols + auto relabel_out = pyg::sampler::hetero_relabel_neighborhood( + node_types, edge_types, seed_dict, sampled_nodes_with_dupl_dict, + sampled_nbrs_per_node_dict, num_nodes_dict, + /*batch_dict=*/c10::nullopt, /*csc=*/false, /*disjoint=*/false); + + auto expected_row = at::tensor({0, 0, 1, 1}, options); + EXPECT_TRUE(at::equal(std::get<0>(relabel_out).at(rel_key), expected_row)); + auto expected_col = at::tensor({2, 1, 0, 3}, options); + EXPECT_TRUE(at::equal(std::get<1>(relabel_out).at(rel_key), expected_col)); + + // check if rows and cols are the same as for the classic sampling + auto non_dist_out = pyg::sampler::hetero_neighbor_sample( + node_types, edge_types, rowptr_dict, col_dict, seed_dict, + num_neighbors_dict); + + 
EXPECT_TRUE(at::equal(std::get<0>(relabel_out).at(rel_key), + std::get<0>(non_dist_out).at(rel_key))); + EXPECT_TRUE(at::equal(std::get<1>(relabel_out).at(rel_key), + std::get<1>(non_dist_out).at(rel_key))); +} + +TEST(DistHeteroCscNeighborTest, BasicAssertions) { + auto options = at::TensorOptions().dtype(at::kLong); + + int num_nodes = 6; + auto graph = cycle_graph(num_nodes, options); + const auto node_key = "paper"; + const auto edge_key = std::make_tuple("paper", "to", "paper"); + const auto rel_key = "paper__to__paper"; + std::vector node_types = {node_key}; + std::vector edge_types = {edge_key}; + c10::Dict rowptr_dict; + rowptr_dict.insert(rel_key, std::get<0>(graph)); + c10::Dict col_dict; + col_dict.insert(rel_key, std::get<1>(graph)); + c10::Dict seed_dict; + seed_dict.insert(node_key, at::arange(2, 4, options)); + std::vector num_neighbors = {2}; + c10::Dict> num_neighbors_dict; + num_neighbors_dict.insert(rel_key, num_neighbors); + c10::Dict num_nodes_dict; + num_nodes_dict.insert(node_key, num_nodes); + + c10::Dict sampled_nodes_with_dupl_dict; + c10::Dict> sampled_nbrs_per_node_dict; + sampled_nodes_with_dupl_dict.insert(node_key, + at::tensor({1, 3, 2, 4}, options)); + sampled_nbrs_per_node_dict.insert(node_key, std::vector(2, 2)); + // get rows and cols + auto relabel_out = pyg::sampler::hetero_relabel_neighborhood( + node_types, edge_types, seed_dict, sampled_nodes_with_dupl_dict, + sampled_nbrs_per_node_dict, num_nodes_dict, + /*batch_dict=*/c10::nullopt, /*csc=*/true, /*disjoint=*/false); + + auto expected_row = at::tensor({2, 1, 0, 3}, options); + EXPECT_TRUE(at::equal(std::get<0>(relabel_out).at(rel_key), expected_row)); + auto expected_col = at::tensor({0, 0, 1, 1}, options); + EXPECT_TRUE(at::equal(std::get<1>(relabel_out).at(rel_key), expected_col)); + + // check if rows and cols are the same as for the classic sampling + auto non_dist_out = pyg::sampler::hetero_neighbor_sample( + node_types, edge_types, rowptr_dict, col_dict, seed_dict, + 
num_neighbors_dict, + /*time_dict=*/c10::nullopt, + /*seed_time_dict=*/c10::nullopt, /*csc*/ true); + + EXPECT_TRUE(at::equal(std::get<0>(relabel_out).at(rel_key), + std::get<0>(non_dist_out).at(rel_key))); + EXPECT_TRUE(at::equal(std::get<1>(relabel_out).at(rel_key), + std::get<1>(non_dist_out).at(rel_key))); +} + +TEST(DistHeteroDisjointNeighborTest, BasicAssertions) { + auto options = at::TensorOptions().dtype(at::kLong); + + int num_nodes = 6; + auto graph = cycle_graph(num_nodes, options); + const auto node_key = "paper"; + const auto edge_key = std::make_tuple("paper", "to", "paper"); + const auto rel_key = "paper__to__paper"; + std::vector node_types = {node_key}; + std::vector edge_types = {edge_key}; + c10::Dict rowptr_dict; + rowptr_dict.insert(rel_key, std::get<0>(graph)); + c10::Dict col_dict; + col_dict.insert(rel_key, std::get<1>(graph)); + c10::Dict seed_dict; + seed_dict.insert(node_key, at::arange(2, 4, options)); + std::vector num_neighbors = {2}; + c10::Dict> num_neighbors_dict; + num_neighbors_dict.insert(rel_key, num_neighbors); + c10::Dict num_nodes_dict; + num_nodes_dict.insert(node_key, num_nodes); + + c10::Dict sampled_nodes_with_dupl_dict; + c10::Dict> sampled_nbrs_per_node_dict; + c10::Dict batch_dict; + sampled_nodes_with_dupl_dict.insert(node_key, + at::tensor({1, 3, 2, 4}, options)); + sampled_nbrs_per_node_dict.insert(node_key, std::vector(2, 2)); + batch_dict.insert(node_key, at::tensor({0, 0, 1, 1}, options)); + // get rows and cols + auto relabel_out = pyg::sampler::hetero_relabel_neighborhood( + node_types, edge_types, seed_dict, sampled_nodes_with_dupl_dict, + sampled_nbrs_per_node_dict, num_nodes_dict, batch_dict, + /*csc=*/false, /*disjoint=*/true); + + auto expected_row = at::tensor({0, 0, 1, 1}, options); + EXPECT_TRUE(at::equal(std::get<0>(relabel_out).at(rel_key), expected_row)); + auto expected_col = at::tensor({2, 3, 4, 5}, options); + EXPECT_TRUE(at::equal(std::get<1>(relabel_out).at(rel_key), expected_col)); + + // 
check if rows and cols are the same as for the classic sampling + auto non_dist_out = pyg::sampler::hetero_neighbor_sample( + node_types, edge_types, rowptr_dict, col_dict, seed_dict, + num_neighbors_dict, /*time_dict=*/c10::nullopt, + /*seed_time_dict=*/c10::nullopt, /*csc=*/false, /*replace=*/false, + /*directed=*/true, /*disjoint=*/true); + + EXPECT_TRUE(at::equal(std::get<0>(relabel_out).at(rel_key), + std::get<0>(non_dist_out).at(rel_key))); + EXPECT_TRUE(at::equal(std::get<1>(relabel_out).at(rel_key), + std::get<1>(non_dist_out).at(rel_key))); +} diff --git a/test/csrc/sampler/test_neighbor.cpp b/test/csrc/sampler/test_neighbor.cpp index 2c3c7570..eb260cc1 100644 --- a/test/csrc/sampler/test_neighbor.cpp +++ b/test/csrc/sampler/test_neighbor.cpp @@ -41,7 +41,8 @@ TEST(WithoutReplacementNeighborTest, BasicAssertions) { auto out = pyg::sampler::neighbor_sample( /*rowptr=*/std::get<0>(graph), /*col=*/std::get<1>(graph), seed, num_neighbors, /*time=*/c10::nullopt, - /*seed_time=*/c10::nullopt, /*csc=*/false, /*replace=*/false); + /*seed_time=*/c10::nullopt, /*batch=*/c10::nullopt, /*csc=*/false, + /*replace=*/false); auto expected_row = at::tensor({0, 1, 2, 3}, options); EXPECT_TRUE(at::equal(std::get<0>(out), expected_row)); @@ -64,7 +65,8 @@ TEST(WithReplacementNeighborTest, BasicAssertions) { auto out = pyg::sampler::neighbor_sample( /*rowptr=*/std::get<0>(graph), /*col=*/std::get<1>(graph), seed, num_neighbors, /*time=*/c10::nullopt, - /*seed_time=*/c10::nullopt, /*csc=*/false, /*replace=*/true); + /*seed_time=*/c10::nullopt, /*batch=*/c10::nullopt, /*csc=*/false, + /*replace=*/true); auto expected_row = at::tensor({0, 1, 2, 3}, options); EXPECT_TRUE(at::equal(std::get<0>(out), expected_row)); @@ -86,7 +88,8 @@ TEST(DisjointNeighborTest, BasicAssertions) { auto out = pyg::sampler::neighbor_sample( /*rowptr=*/std::get<0>(graph), /*col=*/std::get<1>(graph), seed, num_neighbors, /*time=*/c10::nullopt, - /*seed_time=*/c10::nullopt, /*csc=*/false, 
/*replace=*/false, + /*seed_time=*/c10::nullopt, /*batch=*/c10::nullopt, /*csc=*/false, + /*replace=*/false, /*directed=*/true, /*disjoint=*/true); auto expected_row = at::tensor({0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5}, options); @@ -116,7 +119,8 @@ TEST(TemporalNeighborTest, BasicAssertions) { auto out1 = pyg::sampler::neighbor_sample( rowptr, col, seed, /*num_neighbors=*/{2, 2}, /*time=*/time, - /*seed_time=*/c10::nullopt, /*csc=*/false, /*replace=*/false, + /*seed_time=*/c10::nullopt, /*batch=*/c10::nullopt, /*csc=*/false, + /*replace=*/false, /*directed=*/true, /*disjoint=*/true); // Expect only the earlier neighbors or the same node to be sampled: @@ -132,7 +136,8 @@ TEST(TemporalNeighborTest, BasicAssertions) { auto out2 = pyg::sampler::neighbor_sample( rowptr, col, seed, /*num_neighbors=*/{1, 2}, /*time=*/time, - /*seed_time=*/c10::nullopt, /*csc=*/false, /*replace=*/false, + /*seed_time=*/c10::nullopt, /*batch=*/c10::nullopt, /*csc=*/false, + /*replace=*/false, /*directed=*/true, /*disjoint=*/true, /*temporal_strategy=*/"last"); EXPECT_TRUE(at::equal(std::get<0>(out1), std::get<0>(out2))); From 83fb5c5b787fed4245bc6a839ee73b910b6bab0d Mon Sep 17 00:00:00 2001 From: kgajdamo Date: Mon, 21 Aug 2023 14:06:02 +0000 Subject: [PATCH 2/7] update CHANGELOG.md --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0c784c0e..8c13c1da 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
## [0.3.0] - 2023-MM-DD ### Added +- Added low level support for distributed sampler ([#246](https://github.com/pyg-team/pyg-lib/pull/246)) - Added dispatch for XPU device in `index_sort` ([#243](https://github.com/pyg-team/pyg-lib/pull/243)) - Added `metis` partitioning ([#229](https://github.com/pyg-team/pyg-lib/pull/229)) - Enable `hetero_neighbor_samplee` to work in parallel ([#211](https://github.com/pyg-team/pyg-lib/pull/211)) From e02312afc53f3ab3c51ed5a924bf30212bb9bb02 Mon Sep 17 00:00:00 2001 From: kgajdamo Date: Tue, 22 Aug 2023 11:43:53 +0000 Subject: [PATCH 3/7] apply Damian's comments --- .../csrc/sampler/cpu/dist_neighbor_kernel.cpp | 59 +++++-------------- .../csrc/sampler/cpu/dist_neighbor_kernel.h | 2 +- pyg_lib/csrc/sampler/dist_neighbor.cpp | 47 +++++++++++++-- pyg_lib/csrc/sampler/dist_neighbor.h | 2 +- 4 files changed, 57 insertions(+), 53 deletions(-) diff --git a/pyg_lib/csrc/sampler/cpu/dist_neighbor_kernel.cpp b/pyg_lib/csrc/sampler/cpu/dist_neighbor_kernel.cpp index 61f30e2d..b9afd6a7 100644 --- a/pyg_lib/csrc/sampler/cpu/dist_neighbor_kernel.cpp +++ b/pyg_lib/csrc/sampler/cpu/dist_neighbor_kernel.cpp @@ -15,8 +15,8 @@ namespace { template std::tuple get_sampled_edges( - std::vector sampled_rows, - std::vector sampled_cols, + const std::vector& sampled_rows, + const std::vector& sampled_cols, const bool csc = false) { const auto row = pyg::utils::from_vector(sampled_rows); const auto col = pyg::utils::from_vector(sampled_cols); @@ -36,23 +36,12 @@ std::tuple relabel( const int64_t num_nodes, const c10::optional& batch, const bool csc) { - if (disjoint) { - TORCH_CHECK(batch.has_value(), - "Batch needs to be specified to create disjoint subgraphs"); - TORCH_CHECK(batch.value().is_contiguous(), "Non-contiguous 'batch'"); - TORCH_CHECK(batch.value().numel() == sampled_nodes_with_dupl.numel(), - "Each node must belong to a subgraph.'"); - } - TORCH_CHECK(seed.is_contiguous(), "Non-contiguous 'seed'"); - 
TORCH_CHECK(sampled_nodes_with_dupl.is_contiguous(), - "Non-contiguous 'sampled_nodes_with_dupl'"); - at::Tensor out_row, out_col; AT_DISPATCH_INTEGRAL_TYPES( seed.scalar_type(), "relabel_neighborhood_kernel", [&] { - typedef std::pair pair_scalar_t; - typedef std::conditional_t node_t; + using pair_scalar_t = std::pair; + using node_t = std::conditional_t; const auto sampled_nodes_data = sampled_nodes_with_dupl.data_ptr(); @@ -67,11 +56,12 @@ std::tuple relabel( if constexpr (!disjoint) { mapper.fill(seed); } else { - for (size_t i = 0; i < seed.numel(); ++i) { + for (size_t i = 0; i < seed.numel(); i++) { mapper.insert({i, seed_data[i]}); } } - size_t begin = 0, end = 0; + size_t begin = 0; + size_t end = 0; for (auto i = 0; i < sampled_nbrs_per_node.size(); i++) { end += sampled_nbrs_per_node[i]; @@ -103,37 +93,16 @@ relabel(const std::vector& node_types, const c10::Dict& sampled_nodes_with_dupl_dict, const c10::Dict>& sampled_nbrs_per_node_dict, - const c10::Dict num_nodes_dict, + const c10::Dict& num_nodes_dict, const c10::optional>& batch_dict, const bool csc) { - if (disjoint) { - TORCH_CHECK(batch_dict.has_value(), - "Batch needs to be specified to create disjoint subgraphs"); - for (const auto& kv : batch_dict.value()) { - const at::Tensor& batch = kv.value(); - const at::Tensor& sampled_nodes_with_dupl = kv.value(); - TORCH_CHECK(batch.is_contiguous(), "Non-contiguous 'batch'"); - TORCH_CHECK(batch.numel() == sampled_nodes_with_dupl.numel(), - "Each node must belong to a subgraph.'"); - } - } - for (const auto& kv : seed_dict) { - const at::Tensor& seed = kv.value(); - TORCH_CHECK(seed.is_contiguous(), "Non-contiguous 'seed'"); - } - for (const auto& kv : sampled_nodes_with_dupl_dict) { - const at::Tensor& sampled_nodes_with_dupl = kv.value(); - TORCH_CHECK(sampled_nodes_with_dupl.is_contiguous(), - "Non-contiguous 'sampled_nodes_with_dupl'"); - } - c10::Dict out_row_dict, out_col_dict; AT_DISPATCH_INTEGRAL_TYPES( seed_dict.begin()->value().scalar_type(), 
"hetero_relabel_neighborhood_kernel", [&] { - typedef std::pair pair_scalar_t; - typedef std::conditional_t node_t; + using pair_scalar_t = std::pair; + using node_t = std::conditional_t; phmap::flat_hash_map sampled_nodes_data_dict; phmap::flat_hash_map batch_data_dict; @@ -168,7 +137,7 @@ relabel(const std::vector& node_types, } else { auto& mapper = mapper_dict.at(kv.key()); const auto seed_data = seed.data_ptr(); - for (size_t i = 0; i < seed.numel(); ++i) { + for (size_t i = 0; i < seed.numel(); i++) { mapper.insert({i, seed_data[i]}); } } @@ -179,9 +148,9 @@ relabel(const std::vector& node_types, for (auto i = 0; i < sampled_nbrs_per_node_dict.at(dst).size(); i++) { auto& dst_mapper = mapper_dict.at(dst); auto& dst_sampled_nodes_data = sampled_nodes_data_dict.at(dst); + slice_dict.at(dst).second += sampled_nbrs_per_node_dict.at(dst)[i]; - size_t begin, end; - std::tie(begin, end) = slice_dict.at(dst); + auto [begin, end] = slice_dict.at(dst); for (auto j = begin; j < end; j++) { std::pair res; @@ -237,7 +206,7 @@ hetero_relabel_neighborhood_kernel( const c10::Dict& sampled_nodes_with_dupl_dict, const c10::Dict>& sampled_nbrs_per_node_dict, - const c10::Dict num_nodes_dict, + const c10::Dict& num_nodes_dict, const c10::optional>& batch_dict, bool csc, bool disjoint) { diff --git a/pyg_lib/csrc/sampler/cpu/dist_neighbor_kernel.h b/pyg_lib/csrc/sampler/cpu/dist_neighbor_kernel.h index 9df244fd..1cd82a07 100644 --- a/pyg_lib/csrc/sampler/cpu/dist_neighbor_kernel.h +++ b/pyg_lib/csrc/sampler/cpu/dist_neighbor_kernel.h @@ -22,7 +22,7 @@ hetero_relabel_neighborhood_kernel( const c10::Dict& sampled_nodes_with_dupl_dict, const c10::Dict>& sampled_nbrs_per_node_dict, - const c10::Dict num_nodes_dict, + const c10::Dict& num_nodes_dict, const c10::optional>& batch_dict, bool csc, bool disjoint); diff --git a/pyg_lib/csrc/sampler/dist_neighbor.cpp b/pyg_lib/csrc/sampler/dist_neighbor.cpp index 68ddee65..446a9ecb 100644 --- a/pyg_lib/csrc/sampler/dist_neighbor.cpp +++ 
b/pyg_lib/csrc/sampler/dist_neighbor.cpp @@ -16,13 +16,25 @@ std::tuple relabel_neighborhood( const c10::optional& batch, bool csc, bool disjoint) { - at::TensorArg seed_t{seed, "seed", 1}; - at::TensorArg sampled_nodes_with_dupl_t{sampled_nodes_with_dupl, - "sampled_nodes_with_dupl", 1}; + at::TensorArg seed_args{seed, "seed", 1}; + at::TensorArg sampled_nodes_with_dupl_args{sampled_nodes_with_dupl, + "sampled_nodes_with_dupl", 1}; at::CheckedFrom c = "relabel_neighborhood"; - at::checkAllDefined(c, {sampled_nodes_with_dupl_t, seed_t}); - at::checkAllSameType(c, {sampled_nodes_with_dupl_t, seed_t}); + at::checkAllDefined(c, {sampled_nodes_with_dupl_args, seed_args}); + at::checkAllSameType(c, {sampled_nodes_with_dupl_args, seed_args}); + + TORCH_CHECK(seed.is_contiguous(), "Non-contiguous 'seed'"); + TORCH_CHECK(sampled_nodes_with_dupl.is_contiguous(), + "Non-contiguous 'sampled_nodes_with_dupl'"); + + if (disjoint) { + TORCH_CHECK(batch.has_value(), + "Batch needs to be specified to create disjoint subgraphs"); + TORCH_CHECK(batch.value().is_contiguous(), "Non-contiguous 'batch'"); + TORCH_CHECK(batch.value().numel() == sampled_nodes_with_dupl.numel(), + "Each node must belong to a subgraph.'"); + } static auto op = c10::Dispatcher::singleton() .findSchemaOrThrow("pyg::relabel_neighborhood", "") @@ -39,7 +51,7 @@ hetero_relabel_neighborhood( const c10::Dict& sampled_nodes_with_dupl_dict, const c10::Dict>& sampled_nbrs_per_node_dict, - const c10::Dict num_nodes_dict, + const c10::Dict& num_nodes_dict, const c10::optional>& batch_dict, bool csc, bool disjoint) { @@ -55,6 +67,29 @@ hetero_relabel_neighborhood( at::checkAllDefined(c, sampled_nodes_with_dupl_dict_args); at::checkSameType(c, seed_dict_args[0], sampled_nodes_with_dupl_dict_args[0]); + for (const auto& kv : seed_dict) { + const at::Tensor& seed = kv.value(); + TORCH_CHECK(seed.is_contiguous(), "Non-contiguous 'seed'"); + } + for (const auto& kv : sampled_nodes_with_dupl_dict) { + const at::Tensor& 
sampled_nodes_with_dupl = kv.value(); + TORCH_CHECK(sampled_nodes_with_dupl.is_contiguous(), + "Non-contiguous 'sampled_nodes_with_dupl'"); + } + + if (disjoint) { + TORCH_CHECK(batch_dict.has_value(), + "Batch needs to be specified to create disjoint subgraphs"); + for (const auto& kv : batch_dict.value()) { + const at::Tensor& batch = kv.value(); + const at::Tensor& sampled_nodes_with_dupl = + sampled_nodes_with_dupl_dict.at(kv.key()); + TORCH_CHECK(batch.is_contiguous(), "Non-contiguous 'batch'"); + TORCH_CHECK(batch.numel() == sampled_nodes_with_dupl.numel(), + "Each node must belong to a subgraph.'"); + } + } + static auto op = c10::Dispatcher::singleton() .findSchemaOrThrow("pyg::hetero_relabel_neighborhood", "") diff --git a/pyg_lib/csrc/sampler/dist_neighbor.h b/pyg_lib/csrc/sampler/dist_neighbor.h index 51e520cd..4102ed25 100644 --- a/pyg_lib/csrc/sampler/dist_neighbor.h +++ b/pyg_lib/csrc/sampler/dist_neighbor.h @@ -32,7 +32,7 @@ hetero_relabel_neighborhood( const c10::Dict& sampled_nodes_with_dupl_dict, const c10::Dict>& sampled_nbrs_per_node_dict, - const c10::Dict num_nodes_dict, + const c10::Dict& num_nodes_dict, const c10::optional>& batch_dict = c10::nullopt, bool csc = false, From 90e4b8c26df5e3fc4e6c319121ea45c0187c14e8 Mon Sep 17 00:00:00 2001 From: kgajdamo Date: Sat, 2 Sep 2023 10:35:27 +0000 Subject: [PATCH 4/7] parallel hetero + minor changes --- .../csrc/sampler/cpu/dist_neighbor_kernel.cpp | 108 ++++++++++++------ .../csrc/sampler/cpu/dist_neighbor_kernel.h | 3 +- pyg_lib/csrc/sampler/dist_neighbor.cpp | 3 +- pyg_lib/csrc/sampler/dist_neighbor.h | 3 +- pyg_lib/sampler/__init__.py | 25 +++- test/csrc/sampler/test_dist_neighbor.cpp | 12 +- 6 files changed, 105 insertions(+), 49 deletions(-) diff --git a/pyg_lib/csrc/sampler/cpu/dist_neighbor_kernel.cpp b/pyg_lib/csrc/sampler/cpu/dist_neighbor_kernel.cpp index b9afd6a7..14cde385 100644 --- a/pyg_lib/csrc/sampler/cpu/dist_neighbor_kernel.cpp +++ 
b/pyg_lib/csrc/sampler/cpu/dist_neighbor_kernel.cpp @@ -32,7 +32,7 @@ template std::tuple relabel( const at::Tensor& seed, const at::Tensor& sampled_nodes_with_dupl, - const std::vector& sampled_nbrs_per_node, + const std::vector& num_sampled_nbrs_per_node, const int64_t num_nodes, const c10::optional& batch, const bool csc) { @@ -62,8 +62,8 @@ std::tuple relabel( } size_t begin = 0; size_t end = 0; - for (auto i = 0; i < sampled_nbrs_per_node.size(); i++) { - end += sampled_nbrs_per_node[i]; + for (auto i = 0; i < num_sampled_nbrs_per_node.size(); i++) { + end += num_sampled_nbrs_per_node[i]; for (auto j = begin; j < end; j++) { std::pair res; @@ -91,8 +91,8 @@ relabel(const std::vector& node_types, const std::vector& edge_types, const c10::Dict& seed_dict, const c10::Dict& sampled_nodes_with_dupl_dict, - const c10::Dict>& - sampled_nbrs_per_node_dict, + const c10::Dict>& + num_sampled_nbrs_per_node_dict, const c10::Dict& num_nodes_dict, const c10::optional>& batch_dict, const bool csc) { @@ -114,11 +114,37 @@ relabel(const std::vector& node_types, phmap::flat_hash_map> mapper_dict; phmap::flat_hash_map> slice_dict; + const bool parallel = + at::get_num_threads() > 1 && edge_types.size() > 1; + std::vector> threads_edge_types; + for (const auto& k : edge_types) { // Initialize empty vectors. sampled_rows_dict[k]; sampled_cols_dict[k]; + + if (parallel) { + // Each thread is assigned edge types that have the same dst node + // type. Thanks to this, each thread will operate on a separate + // mapper and separate sampler. + bool added = false; + const auto dst = !csc ? std::get<2>(k) : std::get<0>(k); + for (auto& e : threads_edge_types) { + if ((!csc ? std::get<2>(e[0]) : std::get<0>(e[0])) == dst) { + e.push_back(k); + added = true; + break; + } + } + if (!added) + threads_edge_types.push_back({k}); + } + } + if (!parallel) { + // If not parallel then one thread handles all edge types. 
+ threads_edge_types.push_back({edge_types}); } + for (const auto& k : node_types) { sampled_nodes_data_dict.insert( {k, sampled_nodes_with_dupl_dict.at(k).data_ptr()}); @@ -142,30 +168,46 @@ relabel(const std::vector& node_types, } } } - for (const auto& k : edge_types) { - const auto src = !csc ? std::get<0>(k) : std::get<2>(k); - const auto dst = !csc ? std::get<2>(k) : std::get<0>(k); - for (auto i = 0; i < sampled_nbrs_per_node_dict.at(dst).size(); i++) { - auto& dst_mapper = mapper_dict.at(dst); - auto& dst_sampled_nodes_data = sampled_nodes_data_dict.at(dst); - - slice_dict.at(dst).second += sampled_nbrs_per_node_dict.at(dst)[i]; - auto [begin, end] = slice_dict.at(dst); - - for (auto j = begin; j < end; j++) { - std::pair res; - if constexpr (!disjoint) { - res = dst_mapper.insert(dst_sampled_nodes_data[j]); - } else { - res = dst_mapper.insert( - {batch_data_dict.at(dst)[j], dst_sampled_nodes_data[j]}); + at::parallel_for( + 0, threads_edge_types.size(), 1, [&](size_t _s, size_t _e) { + for (auto j = _s; j < _e; j++) { + for (const auto& k : threads_edge_types[j]) { + const auto src = !csc ? std::get<0>(k) : std::get<2>(k); + const auto dst = !csc ? 
std::get<2>(k) : std::get<0>(k); + + if (num_sampled_nbrs_per_node_dict.at(to_rel_type(k)) + .size() == 0) { + continue; + } + + for (auto i = 0; + i < + num_sampled_nbrs_per_node_dict.at(to_rel_type(k)).size(); + i++) { + auto& dst_mapper = mapper_dict.at(dst); + auto& dst_sampled_nodes_data = + sampled_nodes_data_dict.at(dst); + + slice_dict.at(dst).second += + num_sampled_nbrs_per_node_dict.at(to_rel_type(k))[i]; + auto [begin, end] = slice_dict.at(dst); + + for (auto j = begin; j < end; j++) { + std::pair res; + if constexpr (!disjoint) { + res = dst_mapper.insert(dst_sampled_nodes_data[j]); + } else { + res = dst_mapper.insert({batch_data_dict.at(dst)[j], + dst_sampled_nodes_data[j]}); + } + sampled_rows_dict.at(k).push_back(i); + sampled_cols_dict.at(k).push_back(res.first); + } + slice_dict.at(dst).first = end; + } + } } - sampled_rows_dict.at(k).push_back(i); - sampled_cols_dict.at(k).push_back(res.first); - } - slice_dict.at(dst).first = end; - } - } + }); for (const auto& k : edge_types) { const auto edges = get_sampled_edges( @@ -189,13 +231,13 @@ relabel(const std::vector& node_types, std::tuple relabel_neighborhood_kernel( const at::Tensor& seed, const at::Tensor& sampled_nodes_with_dupl, - const std::vector& sampled_nbrs_per_node, + const std::vector& num_sampled_nbrs_per_node, const int64_t num_nodes, const c10::optional& batch, bool csc, bool disjoint) { DISPATCH_RELABEL(disjoint, seed, sampled_nodes_with_dupl, - sampled_nbrs_per_node, num_nodes, batch, csc); + num_sampled_nbrs_per_node, num_nodes, batch, csc); } std::tuple, c10::Dict> @@ -204,15 +246,15 @@ hetero_relabel_neighborhood_kernel( const std::vector& edge_types, const c10::Dict& seed_dict, const c10::Dict& sampled_nodes_with_dupl_dict, - const c10::Dict>& - sampled_nbrs_per_node_dict, + const c10::Dict>& + num_sampled_nbrs_per_node_dict, const c10::Dict& num_nodes_dict, const c10::optional>& batch_dict, bool csc, bool disjoint) { c10::Dict out_row_dict, out_col_dict; 
DISPATCH_RELABEL(disjoint, node_types, edge_types, seed_dict, - sampled_nodes_with_dupl_dict, sampled_nbrs_per_node_dict, + sampled_nodes_with_dupl_dict, num_sampled_nbrs_per_node_dict, num_nodes_dict, batch_dict, csc); } diff --git a/pyg_lib/csrc/sampler/cpu/dist_neighbor_kernel.h b/pyg_lib/csrc/sampler/cpu/dist_neighbor_kernel.h index 1cd82a07..b6d563d5 100644 --- a/pyg_lib/csrc/sampler/cpu/dist_neighbor_kernel.h +++ b/pyg_lib/csrc/sampler/cpu/dist_neighbor_kernel.h @@ -20,8 +20,7 @@ hetero_relabel_neighborhood_kernel( const std::vector& edge_types, const c10::Dict& seed_dict, const c10::Dict& sampled_nodes_with_dupl_dict, - const c10::Dict>& - sampled_nbrs_per_node_dict, + const c10::Dict>& sampled_nbrs_per_node_dict, const c10::Dict& num_nodes_dict, const c10::optional>& batch_dict, bool csc, diff --git a/pyg_lib/csrc/sampler/dist_neighbor.cpp b/pyg_lib/csrc/sampler/dist_neighbor.cpp index 446a9ecb..b7689c01 100644 --- a/pyg_lib/csrc/sampler/dist_neighbor.cpp +++ b/pyg_lib/csrc/sampler/dist_neighbor.cpp @@ -49,8 +49,7 @@ hetero_relabel_neighborhood( const std::vector& edge_types, const c10::Dict& seed_dict, const c10::Dict& sampled_nodes_with_dupl_dict, - const c10::Dict>& - sampled_nbrs_per_node_dict, + const c10::Dict>& sampled_nbrs_per_node_dict, const c10::Dict& num_nodes_dict, const c10::optional>& batch_dict, bool csc, diff --git a/pyg_lib/csrc/sampler/dist_neighbor.h b/pyg_lib/csrc/sampler/dist_neighbor.h index 4102ed25..134fa96f 100644 --- a/pyg_lib/csrc/sampler/dist_neighbor.h +++ b/pyg_lib/csrc/sampler/dist_neighbor.h @@ -30,8 +30,7 @@ hetero_relabel_neighborhood( const std::vector& edge_types, const c10::Dict& seed_dict, const c10::Dict& sampled_nodes_with_dupl_dict, - const c10::Dict>& - sampled_nbrs_per_node_dict, + const c10::Dict>& sampled_nbrs_per_node_dict, const c10::Dict& num_nodes_dict, const c10::optional>& batch_dict = c10::nullopt, diff --git a/pyg_lib/sampler/__init__.py b/pyg_lib/sampler/__init__.py index 7a816c1a..92290fd2 100644 --- 
a/pyg_lib/sampler/__init__.py +++ b/pyg_lib/sampler/__init__.py @@ -248,9 +248,10 @@ def relabel_neighborhood( def hetero_relabel_neighborhood( - edge_types: List[EdgeType], seed_dict: Dict[NodeType, Tensor], - sampled_nodes_with_dupl_dict: Dict[NodeType, Tensor], - sampled_nbrs_per_node_dict: Dict[NodeType, + seed_dict: Dict[NodeType, + Tensor], sampled_nodes_with_dupl_dict: Dict[NodeType, + Tensor], + sampled_nbrs_per_node_dict: Dict[EdgeType, List[int]], num_nodes_dict: Dict[NodeType, int], batch_dict: Optional[Dict[NodeType, Tensor]] = None, csc: bool = False, @@ -271,11 +272,27 @@ def hetero_relabel_neighborhood( src_node_types = {k[0] for k in sampled_nodes_with_dupl_dict.keys()} dst_node_types = {k[-1] for k in sampled_nodes_with_dupl_dict.keys()} node_types = list(src_node_types | dst_node_types) + edge_types = list(sampled_nbrs_per_node_dict.keys()) - return torch.ops.pyg.hetero_relabel_neighborhood( + TO_REL_TYPE = {key: '__'.join(key) for key in edge_types} + TO_EDGE_TYPE = {'__'.join(key): key for key in edge_types} + + sampled_nbrs_per_node_dict = { + TO_REL_TYPE[k]: v + for k, v in sampled_nbrs_per_node_dict.items() + } + + out = torch.ops.pyg.hetero_relabel_neighborhood( node_types, edge_types, seed_dict, sampled_nodes_with_dupl_dict, sampled_nbrs_per_node_dict, num_nodes_dict, batch_dict, csc, disjoint) + (row_dict, col_dict) = out + + row_dict = {TO_EDGE_TYPE[k]: v for k, v in row_dict.items()} + col_dict = {TO_EDGE_TYPE[k]: v for k, v in col_dict.items()} + + return (row_dict, col_dict) + __all__ = [ 'neighbor_sample', diff --git a/test/csrc/sampler/test_dist_neighbor.cpp b/test/csrc/sampler/test_dist_neighbor.cpp index 72a47b01..c22f5c18 100644 --- a/test/csrc/sampler/test_dist_neighbor.cpp +++ b/test/csrc/sampler/test_dist_neighbor.cpp @@ -303,10 +303,10 @@ TEST(DistHeteroNeighborTest, BasicAssertions) { num_nodes_dict.insert(node_key, num_nodes); c10::Dict sampled_nodes_with_dupl_dict; - c10::Dict> sampled_nbrs_per_node_dict; + c10::Dict> 
sampled_nbrs_per_node_dict; sampled_nodes_with_dupl_dict.insert(node_key, at::tensor({1, 3, 2, 4}, options)); - sampled_nbrs_per_node_dict.insert(node_key, std::vector(2, 2)); + sampled_nbrs_per_node_dict.insert(rel_key, std::vector(2, 2)); // get rows and cols auto relabel_out = pyg::sampler::hetero_relabel_neighborhood( node_types, edge_types, seed_dict, sampled_nodes_with_dupl_dict, @@ -352,10 +352,10 @@ TEST(DistHeteroCscNeighborTest, BasicAssertions) { num_nodes_dict.insert(node_key, num_nodes); c10::Dict sampled_nodes_with_dupl_dict; - c10::Dict> sampled_nbrs_per_node_dict; + c10::Dict> sampled_nbrs_per_node_dict; sampled_nodes_with_dupl_dict.insert(node_key, at::tensor({1, 3, 2, 4}, options)); - sampled_nbrs_per_node_dict.insert(node_key, std::vector(2, 2)); + sampled_nbrs_per_node_dict.insert(rel_key, std::vector(2, 2)); // get rows and cols auto relabel_out = pyg::sampler::hetero_relabel_neighborhood( node_types, edge_types, seed_dict, sampled_nodes_with_dupl_dict, @@ -403,11 +403,11 @@ TEST(DistHeteroDisjointNeighborTest, BasicAssertions) { num_nodes_dict.insert(node_key, num_nodes); c10::Dict sampled_nodes_with_dupl_dict; - c10::Dict> sampled_nbrs_per_node_dict; + c10::Dict> sampled_nbrs_per_node_dict; c10::Dict batch_dict; sampled_nodes_with_dupl_dict.insert(node_key, at::tensor({1, 3, 2, 4}, options)); - sampled_nbrs_per_node_dict.insert(node_key, std::vector(2, 2)); + sampled_nbrs_per_node_dict.insert(rel_key, std::vector(2, 2)); batch_dict.insert(node_key, at::tensor({0, 0, 1, 1}, options)); // get rows and cols auto relabel_out = pyg::sampler::hetero_relabel_neighborhood( From 99d971a53256d77e8a1125b0b4f734a12993d078 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Tue, 5 Sep 2023 14:11:45 +0000 Subject: [PATCH 5/7] update --- CHANGELOG.md | 2 +- .../csrc/sampler/cpu/dist_neighbor_kernel.cpp | 272 ----------- .../csrc/sampler/cpu/dist_neighbor_kernel.h | 30 -- pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp | 238 ++++++---- 
pyg_lib/csrc/sampler/cpu/neighbor_kernel.h | 50 +- pyg_lib/csrc/sampler/dist_neighbor.cpp | 120 ----- pyg_lib/csrc/sampler/dist_neighbor.h | 41 -- pyg_lib/csrc/sampler/neighbor.cpp | 129 +++++- pyg_lib/csrc/sampler/neighbor.h | 54 ++- pyg_lib/sampler/__init__.py | 128 +----- test/csrc/sampler/test_dist_neighbor.cpp | 434 ------------------ test/csrc/sampler/test_neighbor.cpp | 15 +- 12 files changed, 374 insertions(+), 1139 deletions(-) delete mode 100644 pyg_lib/csrc/sampler/cpu/dist_neighbor_kernel.cpp delete mode 100644 pyg_lib/csrc/sampler/cpu/dist_neighbor_kernel.h delete mode 100644 pyg_lib/csrc/sampler/dist_neighbor.cpp delete mode 100644 pyg_lib/csrc/sampler/dist_neighbor.h delete mode 100644 test/csrc/sampler/test_dist_neighbor.cpp diff --git a/CHANGELOG.md b/CHANGELOG.md index 8c13c1da..3829abbb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,7 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## [0.3.0] - 2023-MM-DD ### Added -- Added low level support for distributed sampler ([#246](https://github.com/pyg-team/pyg-lib/pull/246)) +- Added low-level support for distributed neighborhood sampling ([#246](https://github.com/pyg-team/pyg-lib/pull/246)) - Added dispatch for XPU device in `index_sort` ([#243](https://github.com/pyg-team/pyg-lib/pull/243)) - Added `metis` partitioning ([#229](https://github.com/pyg-team/pyg-lib/pull/229)) - Enable `hetero_neighbor_samplee` to work in parallel ([#211](https://github.com/pyg-team/pyg-lib/pull/211)) diff --git a/pyg_lib/csrc/sampler/cpu/dist_neighbor_kernel.cpp b/pyg_lib/csrc/sampler/cpu/dist_neighbor_kernel.cpp deleted file mode 100644 index 14cde385..00000000 --- a/pyg_lib/csrc/sampler/cpu/dist_neighbor_kernel.cpp +++ /dev/null @@ -1,272 +0,0 @@ -#include -#include -#include - -#include "parallel_hashmap/phmap.h" - -#include "pyg_lib/csrc/sampler/cpu/mapper.h" -#include "pyg_lib/csrc/utils/cpu/convert.h" -#include "pyg_lib/csrc/utils/types.h" - -namespace pyg { -namespace 
sampler { - -namespace { - -template -std::tuple get_sampled_edges( - const std::vector& sampled_rows, - const std::vector& sampled_cols, - const bool csc = false) { - const auto row = pyg::utils::from_vector(sampled_rows); - const auto col = pyg::utils::from_vector(sampled_cols); - - if (!csc) { - return std::make_tuple(row, col); - } else { - return std::make_tuple(col, row); - } -} - -template -std::tuple relabel( - const at::Tensor& seed, - const at::Tensor& sampled_nodes_with_dupl, - const std::vector& num_sampled_nbrs_per_node, - const int64_t num_nodes, - const c10::optional& batch, - const bool csc) { - at::Tensor out_row, out_col; - - AT_DISPATCH_INTEGRAL_TYPES( - seed.scalar_type(), "relabel_neighborhood_kernel", [&] { - using pair_scalar_t = std::pair; - using node_t = std::conditional_t; - - const auto sampled_nodes_data = - sampled_nodes_with_dupl.data_ptr(); - const auto batch_data = - !disjoint ? nullptr : batch.value().data_ptr(); - - std::vector sampled_rows; - std::vector sampled_cols; - auto mapper = Mapper(num_nodes); - - const auto seed_data = seed.data_ptr(); - if constexpr (!disjoint) { - mapper.fill(seed); - } else { - for (size_t i = 0; i < seed.numel(); i++) { - mapper.insert({i, seed_data[i]}); - } - } - size_t begin = 0; - size_t end = 0; - for (auto i = 0; i < num_sampled_nbrs_per_node.size(); i++) { - end += num_sampled_nbrs_per_node[i]; - - for (auto j = begin; j < end; j++) { - std::pair res; - if constexpr (!disjoint) - res = mapper.insert(sampled_nodes_data[j]); - else - res = mapper.insert({batch_data[j], sampled_nodes_data[j]}); - sampled_rows.push_back(i); - sampled_cols.push_back(res.first); - } - - begin = end; - } - - std::tie(out_row, out_col) = - get_sampled_edges(sampled_rows, sampled_cols, csc); - }); - - return std::make_tuple(out_row, out_col); -} - -template -std::tuple, c10::Dict> -relabel(const std::vector& node_types, - const std::vector& edge_types, - const c10::Dict& seed_dict, - const c10::Dict& 
sampled_nodes_with_dupl_dict, - const c10::Dict>& - num_sampled_nbrs_per_node_dict, - const c10::Dict& num_nodes_dict, - const c10::optional>& batch_dict, - const bool csc) { - c10::Dict out_row_dict, out_col_dict; - - AT_DISPATCH_INTEGRAL_TYPES( - seed_dict.begin()->value().scalar_type(), - "hetero_relabel_neighborhood_kernel", [&] { - using pair_scalar_t = std::pair; - using node_t = std::conditional_t; - - phmap::flat_hash_map sampled_nodes_data_dict; - phmap::flat_hash_map batch_data_dict; - phmap::flat_hash_map> - sampled_rows_dict; - phmap::flat_hash_map> - sampled_cols_dict; - - phmap::flat_hash_map> mapper_dict; - phmap::flat_hash_map> slice_dict; - - const bool parallel = - at::get_num_threads() > 1 && edge_types.size() > 1; - std::vector> threads_edge_types; - - for (const auto& k : edge_types) { - // Initialize empty vectors. - sampled_rows_dict[k]; - sampled_cols_dict[k]; - - if (parallel) { - // Each thread is assigned edge types that have the same dst node - // type. Thanks to this, each thread will operate on a separate - // mapper and separate sampler. - bool added = false; - const auto dst = !csc ? std::get<2>(k) : std::get<0>(k); - for (auto& e : threads_edge_types) { - if ((!csc ? std::get<2>(e[0]) : std::get<0>(e[0])) == dst) { - e.push_back(k); - added = true; - break; - } - } - if (!added) - threads_edge_types.push_back({k}); - } - } - if (!parallel) { - // If not parallel then one thread handles all edge types. - threads_edge_types.push_back({edge_types}); - } - - for (const auto& k : node_types) { - sampled_nodes_data_dict.insert( - {k, sampled_nodes_with_dupl_dict.at(k).data_ptr()}); - const auto N = num_nodes_dict.at(k) > 0 ? 
num_nodes_dict.at(k) : 0; - mapper_dict.insert({k, Mapper(N)}); - slice_dict[k] = {0, 0}; - if constexpr (disjoint) { - batch_data_dict.insert( - {k, batch_dict.value().at(k).data_ptr()}); - } - } - for (const auto& kv : seed_dict) { - const at::Tensor& seed = kv.value(); - if constexpr (!disjoint) { - mapper_dict.at(kv.key()).fill(seed); - } else { - auto& mapper = mapper_dict.at(kv.key()); - const auto seed_data = seed.data_ptr(); - for (size_t i = 0; i < seed.numel(); i++) { - mapper.insert({i, seed_data[i]}); - } - } - } - at::parallel_for( - 0, threads_edge_types.size(), 1, [&](size_t _s, size_t _e) { - for (auto j = _s; j < _e; j++) { - for (const auto& k : threads_edge_types[j]) { - const auto src = !csc ? std::get<0>(k) : std::get<2>(k); - const auto dst = !csc ? std::get<2>(k) : std::get<0>(k); - - if (num_sampled_nbrs_per_node_dict.at(to_rel_type(k)) - .size() == 0) { - continue; - } - - for (auto i = 0; - i < - num_sampled_nbrs_per_node_dict.at(to_rel_type(k)).size(); - i++) { - auto& dst_mapper = mapper_dict.at(dst); - auto& dst_sampled_nodes_data = - sampled_nodes_data_dict.at(dst); - - slice_dict.at(dst).second += - num_sampled_nbrs_per_node_dict.at(to_rel_type(k))[i]; - auto [begin, end] = slice_dict.at(dst); - - for (auto j = begin; j < end; j++) { - std::pair res; - if constexpr (!disjoint) { - res = dst_mapper.insert(dst_sampled_nodes_data[j]); - } else { - res = dst_mapper.insert({batch_data_dict.at(dst)[j], - dst_sampled_nodes_data[j]}); - } - sampled_rows_dict.at(k).push_back(i); - sampled_cols_dict.at(k).push_back(res.first); - } - slice_dict.at(dst).first = end; - } - } - } - }); - - for (const auto& k : edge_types) { - const auto edges = get_sampled_edges( - sampled_rows_dict.at(k), sampled_cols_dict.at(k), csc); - out_row_dict.insert(to_rel_type(k), std::get<0>(edges)); - out_col_dict.insert(to_rel_type(k), std::get<1>(edges)); - } - }); - - return std::make_tuple(out_row_dict, out_col_dict); -} - -#define DISPATCH_RELABEL(disjoint, ...) 
\ - if (disjoint) \ - return relabel(__VA_ARGS__); \ - if (!disjoint) \ - return relabel(__VA_ARGS__); - -} // namespace - -std::tuple relabel_neighborhood_kernel( - const at::Tensor& seed, - const at::Tensor& sampled_nodes_with_dupl, - const std::vector& num_sampled_nbrs_per_node, - const int64_t num_nodes, - const c10::optional& batch, - bool csc, - bool disjoint) { - DISPATCH_RELABEL(disjoint, seed, sampled_nodes_with_dupl, - num_sampled_nbrs_per_node, num_nodes, batch, csc); -} - -std::tuple, c10::Dict> -hetero_relabel_neighborhood_kernel( - const std::vector& node_types, - const std::vector& edge_types, - const c10::Dict& seed_dict, - const c10::Dict& sampled_nodes_with_dupl_dict, - const c10::Dict>& - num_sampled_nbrs_per_node_dict, - const c10::Dict& num_nodes_dict, - const c10::optional>& batch_dict, - bool csc, - bool disjoint) { - c10::Dict out_row_dict, out_col_dict; - DISPATCH_RELABEL(disjoint, node_types, edge_types, seed_dict, - sampled_nodes_with_dupl_dict, num_sampled_nbrs_per_node_dict, - num_nodes_dict, batch_dict, csc); -} - -TORCH_LIBRARY_IMPL(pyg, CPU, m) { - m.impl(TORCH_SELECTIVE_NAME("pyg::relabel_neighborhood"), - TORCH_FN(relabel_neighborhood_kernel)); -} - -TORCH_LIBRARY_IMPL(pyg, BackendSelect, m) { - m.impl(TORCH_SELECTIVE_NAME("pyg::hetero_relabel_neighborhood"), - TORCH_FN(hetero_relabel_neighborhood_kernel)); -} - -} // namespace sampler -} // namespace pyg diff --git a/pyg_lib/csrc/sampler/cpu/dist_neighbor_kernel.h b/pyg_lib/csrc/sampler/cpu/dist_neighbor_kernel.h deleted file mode 100644 index b6d563d5..00000000 --- a/pyg_lib/csrc/sampler/cpu/dist_neighbor_kernel.h +++ /dev/null @@ -1,30 +0,0 @@ -#include -#include -#include "pyg_lib/csrc/utils/types.h" - -namespace pyg { -namespace sampler { - -std::tuple relabel_neighborhood_kernel( - const at::Tensor& seed, - const at::Tensor& sampled_nodes_with_dupl, - const std::vector& sampled_nbrs_per_node, - const int64_t num_nodes, - const c10::optional& batch, - bool csc, - bool 
disjoint); - -std::tuple, c10::Dict> -hetero_relabel_neighborhood_kernel( - const std::vector& node_types, - const std::vector& edge_types, - const c10::Dict& seed_dict, - const c10::Dict& sampled_nodes_with_dupl_dict, - const c10::Dict>& sampled_nbrs_per_node_dict, - const c10::Dict& num_nodes_dict, - const c10::optional>& batch_dict, - bool csc, - bool disjoint); - -} // namespace sampler -} // namespace pyg diff --git a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp index 2571e9bd..15f12f9a 100644 --- a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp +++ b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp @@ -191,6 +191,7 @@ class NeighborSampler { const auto global_dst_node = to_node_t(global_dst_node_value, global_src_node); + // In the distributed sampling case, we do not perform any mapping: if constexpr (distributed) { out_global_dst_nodes.push_back(global_dst_node); if (save_edge_ids) { @@ -244,7 +245,6 @@ sample(const at::Tensor& rowptr, const std::vector& num_neighbors, const c10::optional& time, const c10::optional& seed_time, - const c10::optional& batch, const bool csc, const std::string temporal_strategy) { TORCH_CHECK(!time.has_value() || disjoint, @@ -265,7 +265,7 @@ sample(const at::Tensor& rowptr, c10::optional out_edge_id = c10::nullopt; std::vector num_sampled_nodes_per_hop; std::vector num_sampled_edges_per_hop; - std::vector cumm_sum_sampled_nbrs_per_node = + std::vector cumsum_neighbors_per_node = distributed ? 
std::vector(1, seed.size(0)) : std::vector(); @@ -292,17 +292,9 @@ sample(const at::Tensor& rowptr, sampled_nodes = pyg::utils::to_vector(seed); mapper.fill(seed); } else { - if (batch.has_value()) { - const auto batch_data = batch.value().data_ptr(); - for (size_t i = 0; i < seed.numel(); ++i) { - sampled_nodes.push_back({batch_data[i], seed_data[i]}); - mapper.insert({batch_data[i], seed_data[i]}); - } - } else { - for (size_t i = 0; i < seed.numel(); ++i) { - sampled_nodes.push_back({i, seed_data[i]}); - mapper.insert({i, seed_data[i]}); - } + for (size_t i = 0; i < seed.numel(); ++i) { + sampled_nodes.push_back({i, seed_data[i]}); + mapper.insert({i, seed_data[i]}); } if (seed_time.has_value()) { const auto seed_time_data = seed_time.value().data_ptr(); @@ -329,7 +321,7 @@ sample(const at::Tensor& rowptr, /*local_src_node=*/i, count, mapper, generator, /*out_global_dst_nodes=*/sampled_nodes); if constexpr (distributed) - cumm_sum_sampled_nbrs_per_node.push_back(sampled_nodes.size()); + cumsum_neighbors_per_node.push_back(sampled_nodes.size()); } } else if constexpr (!std::is_scalar::value) { // Temporal: const auto time_data = time.value().data_ptr(); @@ -341,7 +333,7 @@ sample(const at::Tensor& rowptr, generator, /*out_global_dst_nodes=*/sampled_nodes); if constexpr (distributed) - cumm_sum_sampled_nbrs_per_node.push_back(sampled_nodes.size()); + cumsum_neighbors_per_node.push_back(sampled_nodes.size()); } } begin = end, end = sampled_nodes.size(); @@ -361,7 +353,7 @@ sample(const at::Tensor& rowptr, return std::make_tuple(out_row, out_col, out_node_id, out_edge_id, num_sampled_nodes_per_hop, num_sampled_edges_per_hop, - cumm_sum_sampled_nbrs_per_node); + cumsum_neighbors_per_node); } // Heterogeneous neighbor sampling ///////////////////////////////////////////// @@ -632,73 +624,74 @@ sample(const std::vector& node_types, // Dispatcher ////////////////////////////////////////////////////////////////// -#define DISPATCH_SAMPLE(replace, directed, disjoint, 
return_edge_id, \ - distributed, ...) \ - if (replace && directed && disjoint && return_edge_id && distributed) \ - return sample(__VA_ARGS__); \ - if (replace && directed && disjoint && !return_edge_id && distributed) \ - return sample(__VA_ARGS__); \ - if (replace && directed && !disjoint && return_edge_id && distributed) \ - return sample(__VA_ARGS__); \ - if (replace && directed && !disjoint && !return_edge_id && distributed) \ - return sample(__VA_ARGS__); \ - if (replace && !directed && disjoint && return_edge_id && distributed) \ - return sample(__VA_ARGS__); \ - if (replace && !directed && disjoint && !return_edge_id && distributed) \ - return sample(__VA_ARGS__); \ - if (replace && !directed && !disjoint && return_edge_id && distributed) \ - return sample(__VA_ARGS__); \ - if (replace && !directed && !disjoint && !return_edge_id && distributed) \ - return sample(__VA_ARGS__); \ - if (!replace && directed && disjoint && return_edge_id && distributed) \ - return sample(__VA_ARGS__); \ - if (!replace && directed && disjoint && !return_edge_id && distributed) \ - return sample(__VA_ARGS__); \ - if (!replace && directed && !disjoint && return_edge_id && distributed) \ - return sample(__VA_ARGS__); \ - if (!replace && directed && !disjoint && !return_edge_id && distributed) \ - return sample(__VA_ARGS__); \ - if (!replace && !directed && disjoint && return_edge_id && distributed) \ - return sample(__VA_ARGS__); \ - if (!replace && !directed && disjoint && !return_edge_id && distributed) \ - return sample(__VA_ARGS__); \ - if (!replace && !directed && !disjoint && return_edge_id && distributed) \ - return sample(__VA_ARGS__); \ - if (!replace && !directed && !disjoint && !return_edge_id && distributed) \ - return sample(__VA_ARGS__); \ - if (replace && directed && disjoint && return_edge_id && !distributed) \ - return sample(__VA_ARGS__); \ - if (replace && directed && disjoint && !return_edge_id && !distributed) \ - return sample(__VA_ARGS__); \ - if (replace && 
directed && !disjoint && return_edge_id && !distributed) \ - return sample(__VA_ARGS__); \ - if (replace && directed && !disjoint && !return_edge_id && !distributed) \ - return sample(__VA_ARGS__); \ - if (replace && !directed && disjoint && return_edge_id && !distributed) \ - return sample(__VA_ARGS__); \ - if (replace && !directed && disjoint && !return_edge_id && !distributed) \ - return sample(__VA_ARGS__); \ - if (replace && !directed && !disjoint && return_edge_id && !distributed) \ - return sample(__VA_ARGS__); \ - if (replace && !directed && !disjoint && !return_edge_id && !distributed) \ - return sample(__VA_ARGS__); \ - if (!replace && directed && disjoint && return_edge_id && !distributed) \ - return sample(__VA_ARGS__); \ - if (!replace && directed && disjoint && !return_edge_id && !distributed) \ - return sample(__VA_ARGS__); \ - if (!replace && directed && !disjoint && return_edge_id && !distributed) \ - return sample(__VA_ARGS__); \ - if (!replace && directed && !disjoint && !return_edge_id && !distributed) \ - return sample(__VA_ARGS__); \ - if (!replace && !directed && disjoint && return_edge_id && !distributed) \ - return sample(__VA_ARGS__); \ - if (!replace && !directed && disjoint && !return_edge_id && !distributed) \ - return sample(__VA_ARGS__); \ - if (!replace && !directed && !disjoint && return_edge_id && !distributed) \ - return sample(__VA_ARGS__); \ - if (!replace && !directed && !disjoint && !return_edge_id && !distributed) \ +#define DISPATCH_SAMPLE(replace, directed, disjoint, return_edge_id, ...) 
\ + if (replace && directed && disjoint && return_edge_id) \ + return sample(__VA_ARGS__); \ + if (replace && directed && disjoint && !return_edge_id) \ + return sample(__VA_ARGS__); \ + if (replace && directed && !disjoint && return_edge_id) \ + return sample(__VA_ARGS__); \ + if (replace && directed && !disjoint && !return_edge_id) \ + return sample(__VA_ARGS__); \ + if (replace && !directed && disjoint && return_edge_id) \ + return sample(__VA_ARGS__); \ + if (replace && !directed && disjoint && !return_edge_id) \ + return sample(__VA_ARGS__); \ + if (replace && !directed && !disjoint && return_edge_id) \ + return sample(__VA_ARGS__); \ + if (replace && !directed && !disjoint && !return_edge_id) \ + return sample(__VA_ARGS__); \ + if (!replace && directed && disjoint && return_edge_id) \ + return sample(__VA_ARGS__); \ + if (!replace && directed && disjoint && !return_edge_id) \ + return sample(__VA_ARGS__); \ + if (!replace && directed && !disjoint && return_edge_id) \ + return sample(__VA_ARGS__); \ + if (!replace && directed && !disjoint && !return_edge_id) \ + return sample(__VA_ARGS__); \ + if (!replace && !directed && disjoint && return_edge_id) \ + return sample(__VA_ARGS__); \ + if (!replace && !directed && disjoint && !return_edge_id) \ + return sample(__VA_ARGS__); \ + if (!replace && !directed && !disjoint && return_edge_id) \ + return sample(__VA_ARGS__); \ + if (!replace && !directed && !disjoint && !return_edge_id) \ return sample(__VA_ARGS__); +#define DISPATCH_DIST_SAMPLE(replace, directed, disjoint, return_edge_id, ...) 
\ + if (replace && directed && disjoint && return_edge_id) \ + return sample(__VA_ARGS__); \ + if (replace && directed && disjoint && !return_edge_id) \ + return sample(__VA_ARGS__); \ + if (replace && directed && !disjoint && return_edge_id) \ + return sample(__VA_ARGS__); \ + if (replace && directed && !disjoint && !return_edge_id) \ + return sample(__VA_ARGS__); \ + if (replace && !directed && disjoint && return_edge_id) \ + return sample(__VA_ARGS__); \ + if (replace && !directed && disjoint && !return_edge_id) \ + return sample(__VA_ARGS__); \ + if (replace && !directed && !disjoint && return_edge_id) \ + return sample(__VA_ARGS__); \ + if (replace && !directed && !disjoint && !return_edge_id) \ + return sample(__VA_ARGS__); \ + if (!replace && directed && disjoint && return_edge_id) \ + return sample(__VA_ARGS__); \ + if (!replace && directed && disjoint && !return_edge_id) \ + return sample(__VA_ARGS__); \ + if (!replace && directed && !disjoint && return_edge_id) \ + return sample(__VA_ARGS__); \ + if (!replace && directed && !disjoint && !return_edge_id) \ + return sample(__VA_ARGS__); \ + if (!replace && !directed && disjoint && return_edge_id) \ + return sample(__VA_ARGS__); \ + if (!replace && !directed && disjoint && !return_edge_id) \ + return sample(__VA_ARGS__); \ + if (!replace && !directed && !disjoint && return_edge_id) \ + return sample(__VA_ARGS__); \ + if (!replace && !directed && !disjoint && !return_edge_id) \ + return sample(__VA_ARGS__); + } // namespace std::tuple, std::vector, - std::vector, std::vector> neighbor_sample_kernel(const at::Tensor& rowptr, const at::Tensor& col, @@ -714,17 +706,19 @@ neighbor_sample_kernel(const at::Tensor& rowptr, const std::vector& num_neighbors, const c10::optional& time, const c10::optional& seed_time, - const c10::optional& batch, bool csc, bool replace, bool directed, bool disjoint, std::string temporal_strategy, - bool return_edge_id, - bool distributed) { - DISPATCH_SAMPLE(replace, directed, 
disjoint, return_edge_id, distributed, - rowptr, col, seed, num_neighbors, time, seed_time, batch, csc, - temporal_strategy); + bool return_edge_id) { + const auto out = [&] { + DISPATCH_SAMPLE(replace, directed, disjoint, return_edge_id, rowptr, col, + seed, num_neighbors, time, seed_time, csc, + temporal_strategy); + }(); + return std::make_tuple(std::get<0>(out), std::get<1>(out), std::get<2>(out), + std::get<3>(out), std::get<4>(out), std::get<5>(out)); } std::tuple, @@ -747,14 +741,64 @@ hetero_neighbor_sample_kernel( bool directed, bool disjoint, std::string temporal_strategy, - bool return_edge_id, - bool distributed) { - DISPATCH_SAMPLE(replace, directed, disjoint, return_edge_id, distributed, - node_types, edge_types, rowptr_dict, col_dict, seed_dict, + bool return_edge_id) { + DISPATCH_SAMPLE(replace, directed, disjoint, return_edge_id, node_types, + edge_types, rowptr_dict, col_dict, seed_dict, num_neighbors_dict, time_dict, seed_time_dict, csc, temporal_strategy); } +std::tuple, + std::vector, + std::vector, + std::vector> +dist_neighbor_sample_kernel(const at::Tensor& rowptr, + const at::Tensor& col, + const at::Tensor& seed, + const std::vector& num_neighbors, + const c10::optional& time, + const c10::optional& seed_time, + bool csc, + bool replace, + bool directed, + bool disjoint, + std::string temporal_strategy, + bool return_edge_id) { + DISPATCH_DIST_SAMPLE(replace, directed, disjoint, return_edge_id, rowptr, col, + seed, num_neighbors, time, seed_time, csc, + temporal_strategy); +} + +std::tuple, + c10::Dict, + c10::Dict, + c10::optional>, + c10::Dict>, + c10::Dict>> +dist_hetero_neighbor_sample_kernel( + const std::vector& node_types, + const std::vector& edge_types, + const c10::Dict& rowptr_dict, + const c10::Dict& col_dict, + const c10::Dict& seed_dict, + const c10::Dict>& num_neighbors_dict, + const c10::optional>& time_dict, + const c10::optional>& seed_time_dict, + bool csc, + bool replace, + bool directed, + bool disjoint, + std::string 
temporal_strategy, + bool return_edge_id) { + DISPATCH_DIST_SAMPLE(replace, directed, disjoint, return_edge_id, node_types, + edge_types, rowptr_dict, col_dict, seed_dict, + num_neighbors_dict, time_dict, seed_time_dict, csc, + temporal_strategy); +} + TORCH_LIBRARY_IMPL(pyg, CPU, m) { m.impl(TORCH_SELECTIVE_NAME("pyg::neighbor_sample"), TORCH_FN(neighbor_sample_kernel)); @@ -768,5 +812,15 @@ TORCH_LIBRARY_IMPL(pyg, BackendSelect, m) { TORCH_FN(hetero_neighbor_sample_kernel)); } +TORCH_LIBRARY_IMPL(pyg, CPU, m) { + m.impl(TORCH_SELECTIVE_NAME("pyg::dist_neighbor_sample"), + TORCH_FN(dist_neighbor_sample_kernel)); +} + +TORCH_LIBRARY_IMPL(pyg, BackendSelect, m) { + m.impl(TORCH_SELECTIVE_NAME("pyg::dist_hetero_neighbor_sample"), + TORCH_FN(dist_hetero_neighbor_sample_kernel)); +} + } // namespace sampler } // namespace pyg diff --git a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.h b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.h index ef0b743a..f32ab9fb 100644 --- a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.h +++ b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.h @@ -10,7 +10,6 @@ std::tuple, std::vector, - std::vector, std::vector> neighbor_sample_kernel(const at::Tensor& rowptr, const at::Tensor& col, @@ -18,14 +17,12 @@ neighbor_sample_kernel(const at::Tensor& rowptr, const std::vector& num_neighbors, const c10::optional& time, const c10::optional& seed_time, - const c10::optional& batch, bool csc, bool replace, bool directed, bool disjoint, std::string temporal_strategy, - bool return_edge_id, - bool distributed); + bool return_edge_id); std::tuple, c10::Dict, @@ -47,8 +44,49 @@ hetero_neighbor_sample_kernel( bool directed, bool disjoint, std::string temporal_strategy, - bool return_edge_id, - bool distributed); + bool return_edge_id); + +std::tuple, + std::vector, + std::vector, + std::vector> +dist_neighbor_sample_kernel(const at::Tensor& rowptr, + const at::Tensor& col, + const at::Tensor& seed, + const std::vector& num_neighbors, + const c10::optional& time, + const 
c10::optional& seed_time, + bool csc, + bool replace, + bool directed, + bool disjoint, + std::string temporal_strategy, + bool return_edge_id); + +std::tuple, + c10::Dict, + c10::Dict, + c10::optional>, + c10::Dict>, + c10::Dict>> +dist_hetero_neighbor_sample_kernel( + const std::vector& node_types, + const std::vector& edge_types, + const c10::Dict& rowptr_dict, + const c10::Dict& col_dict, + const c10::Dict& seed_dict, + const c10::Dict>& num_neighbors_dict, + const c10::optional>& time_dict, + const c10::optional>& seed_time_dict, + bool csc, + bool replace, + bool directed, + bool disjoint, + std::string temporal_strategy, + bool return_edge_id); } // namespace sampler } // namespace pyg diff --git a/pyg_lib/csrc/sampler/dist_neighbor.cpp b/pyg_lib/csrc/sampler/dist_neighbor.cpp deleted file mode 100644 index b7689c01..00000000 --- a/pyg_lib/csrc/sampler/dist_neighbor.cpp +++ /dev/null @@ -1,120 +0,0 @@ -#include "dist_neighbor.h" - -#include -#include - -#include "pyg_lib/csrc/utils/check.h" - -namespace pyg { -namespace sampler { - -std::tuple relabel_neighborhood( - const at::Tensor& seed, - const at::Tensor& sampled_nodes_with_dupl, - const std::vector& sampled_nbrs_per_node, - const int64_t num_nodes, - const c10::optional& batch, - bool csc, - bool disjoint) { - at::TensorArg seed_args{seed, "seed", 1}; - at::TensorArg sampled_nodes_with_dupl_args{sampled_nodes_with_dupl, - "sampled_nodes_with_dupl", 1}; - - at::CheckedFrom c = "relabel_neighborhood"; - at::checkAllDefined(c, {sampled_nodes_with_dupl_args, seed_args}); - at::checkAllSameType(c, {sampled_nodes_with_dupl_args, seed_args}); - - TORCH_CHECK(seed.is_contiguous(), "Non-contiguous 'seed'"); - TORCH_CHECK(sampled_nodes_with_dupl.is_contiguous(), - "Non-contiguous 'sampled_nodes_with_dupl'"); - - if (disjoint) { - TORCH_CHECK(batch.has_value(), - "Batch needs to be specified to create disjoint subgraphs"); - TORCH_CHECK(batch.value().is_contiguous(), "Non-contiguous 'batch'"); - 
TORCH_CHECK(batch.value().numel() == sampled_nodes_with_dupl.numel(), - "Each node must belong to a subgraph.'"); - } - - static auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("pyg::relabel_neighborhood", "") - .typed(); - return op.call(seed, sampled_nodes_with_dupl, sampled_nbrs_per_node, - num_nodes, batch, csc, disjoint); -} - -std::tuple, c10::Dict> -hetero_relabel_neighborhood( - const std::vector& node_types, - const std::vector& edge_types, - const c10::Dict& seed_dict, - const c10::Dict& sampled_nodes_with_dupl_dict, - const c10::Dict>& sampled_nbrs_per_node_dict, - const c10::Dict& num_nodes_dict, - const c10::optional>& batch_dict, - bool csc, - bool disjoint) { - std::vector seed_dict_args; - std::vector sampled_nodes_with_dupl_dict_args; - pyg::utils::fill_tensor_args(seed_dict_args, seed_dict, "seed_dict", 0); - pyg::utils::fill_tensor_args(sampled_nodes_with_dupl_dict_args, - sampled_nodes_with_dupl_dict, - "sampled_nodes_with_dupl_dict", 0); - at::CheckedFrom c{"hetero_relabel_neighborhood"}; - - at::checkAllDefined(c, seed_dict_args); - at::checkAllDefined(c, sampled_nodes_with_dupl_dict_args); - at::checkSameType(c, seed_dict_args[0], sampled_nodes_with_dupl_dict_args[0]); - - for (const auto& kv : seed_dict) { - const at::Tensor& seed = kv.value(); - TORCH_CHECK(seed.is_contiguous(), "Non-contiguous 'seed'"); - } - for (const auto& kv : sampled_nodes_with_dupl_dict) { - const at::Tensor& sampled_nodes_with_dupl = kv.value(); - TORCH_CHECK(sampled_nodes_with_dupl.is_contiguous(), - "Non-contiguous 'sampled_nodes_with_dupl'"); - } - - if (disjoint) { - TORCH_CHECK(batch_dict.has_value(), - "Batch needs to be specified to create disjoint subgraphs"); - for (const auto& kv : batch_dict.value()) { - const at::Tensor& batch = kv.value(); - const at::Tensor& sampled_nodes_with_dupl = - sampled_nodes_with_dupl_dict.at(kv.key()); - TORCH_CHECK(batch.is_contiguous(), "Non-contiguous 'batch'"); - TORCH_CHECK(batch.numel() == 
sampled_nodes_with_dupl.numel(), - "Each node must belong to a subgraph.'"); - } - } - - static auto op = - c10::Dispatcher::singleton() - .findSchemaOrThrow("pyg::hetero_relabel_neighborhood", "") - .typed(); - return op.call(node_types, edge_types, seed_dict, - sampled_nodes_with_dupl_dict, sampled_nbrs_per_node_dict, - num_nodes_dict, batch_dict, csc, disjoint); -} - -TORCH_LIBRARY_FRAGMENT(pyg, m) { - m.def( - TORCH_SELECTIVE_SCHEMA("pyg::relabel_neighborhood(Tensor seed, Tensor " - "sampled_nodes_with_dupl, int[] " - "sampled_nbrs_per_node, int num_nodes, Tensor? " - "batch = None, bool csc = False, bool " - "disjoint = False) " - "-> (Tensor, Tensor)")); - m.def(TORCH_SELECTIVE_SCHEMA( - "pyg::hetero_relabel_neighborhood(str[] node_types, (str, str, str)[] " - "edge_types, Dict(str, Tensor) seed_dict, Dict(str, Tensor) " - "sampled_nodes_with_dupl_dict, Dict(str, int[]) " - "sampled_nbrs_per_node_dict, Dict(str, int) num_nodes_dict, Dict(str, " - "Tensor)? batch_dict = None, bool csc = False, bool " - "disjoint = False) " - "-> (Dict(str, Tensor), Dict(str, Tensor))")); -} - -} // namespace sampler -} // namespace pyg diff --git a/pyg_lib/csrc/sampler/dist_neighbor.h b/pyg_lib/csrc/sampler/dist_neighbor.h deleted file mode 100644 index 134fa96f..00000000 --- a/pyg_lib/csrc/sampler/dist_neighbor.h +++ /dev/null @@ -1,41 +0,0 @@ -#pragma once - -#include -#include "pyg_lib/csrc/macros.h" -#include "pyg_lib/csrc/utils/types.h" - -namespace pyg { -namespace sampler { - -// Relabel global indices of the `sampled_nodes_with_dupl` to the local -// subtree/subgraph indices. -// Returns (row, col). 
-PYG_API -std::tuple relabel_neighborhood( - const at::Tensor& seed, - const at::Tensor& sampled_nodes_with_dupl, - const std::vector& sampled_nbrs_per_node, - const int64_t num_nodes, - const c10::optional& batch = c10::nullopt, - bool csc = false, - bool disjoint = false); - -// Relabel global indices of the `sampled_nodes_with_dupl` to the local -// subtree/subgraph indices in the heterogeneous graph. -// Returns src and dst indices for a given edge type as a (row_dict, col_dict). -PYG_API -std::tuple, c10::Dict> -hetero_relabel_neighborhood( - const std::vector& node_types, - const std::vector& edge_types, - const c10::Dict& seed_dict, - const c10::Dict& sampled_nodes_with_dupl_dict, - const c10::Dict>& sampled_nbrs_per_node_dict, - const c10::Dict& num_nodes_dict, - const c10::optional>& batch_dict = - c10::nullopt, - bool csc = false, - bool disjoint = false); - -} // namespace sampler -} // namespace pyg diff --git a/pyg_lib/csrc/sampler/neighbor.cpp b/pyg_lib/csrc/sampler/neighbor.cpp index ad3a4895..3502d843 100644 --- a/pyg_lib/csrc/sampler/neighbor.cpp +++ b/pyg_lib/csrc/sampler/neighbor.cpp @@ -13,7 +13,6 @@ std::tuple, std::vector, - std::vector, std::vector> neighbor_sample(const at::Tensor& rowptr, const at::Tensor& col, @@ -21,14 +20,12 @@ neighbor_sample(const at::Tensor& rowptr, const std::vector& num_neighbors, const c10::optional& time, const c10::optional& seed_time, - const c10::optional& batch, bool csc, bool replace, bool directed, bool disjoint, std::string temporal_strategy, - bool return_edge_id, - bool distributed) { + bool return_edge_id) { at::TensorArg rowptr_t{rowptr, "rowtpr", 1}; at::TensorArg col_t{col, "col", 1}; at::TensorArg seed_t{seed, "seed", 1}; @@ -40,9 +37,9 @@ neighbor_sample(const at::Tensor& rowptr, static auto op = c10::Dispatcher::singleton() .findSchemaOrThrow("pyg::neighbor_sample", "") .typed(); - return op.call(rowptr, col, seed, num_neighbors, time, seed_time, batch, csc, - replace, directed, disjoint, 
temporal_strategy, return_edge_id, - distributed); + return op.call(rowptr, col, seed, num_neighbors, time, seed_time, csc, + replace, directed, disjoint, temporal_strategy, + return_edge_id); } std::tuple, @@ -65,8 +62,7 @@ hetero_neighbor_sample( bool directed, bool disjoint, std::string temporal_strategy, - bool return_edge_id, - bool distributed) { + bool return_edge_id) { TORCH_CHECK(rowptr_dict.size() == col_dict.size(), "Number of edge types in 'rowptr_dict' and 'col_dict' must match") @@ -92,19 +88,101 @@ hetero_neighbor_sample( .typed(); return op.call(node_types, edge_types, rowptr_dict, col_dict, seed_dict, num_neighbors_dict, time_dict, seed_time_dict, csc, replace, - directed, disjoint, temporal_strategy, return_edge_id, - distributed); + directed, disjoint, temporal_strategy, return_edge_id); +} + +std::tuple, + std::vector, + std::vector, + std::vector> +dist_neighbor_sample(const at::Tensor& rowptr, + const at::Tensor& col, + const at::Tensor& seed, + const std::vector& num_neighbors, + const c10::optional& time, + const c10::optional& seed_time, + bool csc, + bool replace, + bool directed, + bool disjoint, + std::string temporal_strategy, + bool return_edge_id) { + at::TensorArg rowptr_t{rowptr, "rowtpr", 1}; + at::TensorArg col_t{col, "col", 1}; + at::TensorArg seed_t{seed, "seed", 1}; + + at::CheckedFrom c = "dist_neighbor_sample"; + at::checkAllDefined(c, {rowptr_t, col_t, seed_t}); + at::checkAllSameType(c, {rowptr_t, col_t, seed_t}); + + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("pyg::dist_neighbor_sample", "") + .typed(); + return op.call(rowptr, col, seed, num_neighbors, time, seed_time, csc, + replace, directed, disjoint, temporal_strategy, + return_edge_id); +} + +std::tuple, + c10::Dict, + c10::Dict, + c10::optional>, + c10::Dict>, + c10::Dict>> +dist_hetero_neighbor_sample( + const std::vector& node_types, + const std::vector& edge_types, + const c10::Dict& rowptr_dict, + const c10::Dict& col_dict, + const 
c10::Dict& seed_dict, + const c10::Dict>& num_neighbors_dict, + const c10::optional>& time_dict, + const c10::optional>& seed_time_dict, + bool csc, + bool replace, + bool directed, + bool disjoint, + std::string temporal_strategy, + bool return_edge_id) { + TORCH_CHECK(rowptr_dict.size() == col_dict.size(), + "Number of edge types in 'rowptr_dict' and 'col_dict' must match") + + std::vector rowptr_dict_args; + std::vector col_dict_args; + std::vector seed_dict_args; + pyg::utils::fill_tensor_args(rowptr_dict_args, rowptr_dict, "rowptr_dict", 0); + pyg::utils::fill_tensor_args(col_dict_args, col_dict, "col_dict", 0); + pyg::utils::fill_tensor_args(seed_dict_args, seed_dict, "seed_dict", 0); + at::CheckedFrom c{"dist_hetero_neighbor_sample"}; + + at::checkAllDefined(c, rowptr_dict_args); + at::checkAllDefined(c, col_dict_args); + at::checkAllDefined(c, seed_dict_args); + at::checkAllSameType(c, rowptr_dict_args); + at::checkAllSameType(c, col_dict_args); + at::checkAllSameType(c, seed_dict_args); + at::checkSameType(c, rowptr_dict_args[0], col_dict_args[0]); + at::checkSameType(c, rowptr_dict_args[0], seed_dict_args[0]); + + static auto op = + c10::Dispatcher::singleton() + .findSchemaOrThrow("pyg::dist_hetero_neighbor_sample", "") + .typed(); + return op.call(node_types, edge_types, rowptr_dict, col_dict, seed_dict, + num_neighbors_dict, time_dict, seed_time_dict, csc, replace, + directed, disjoint, temporal_strategy, return_edge_id); } TORCH_LIBRARY_FRAGMENT(pyg, m) { m.def(TORCH_SELECTIVE_SCHEMA( "pyg::neighbor_sample(Tensor rowptr, Tensor col, Tensor seed, int[] " - "num_neighbors, Tensor? time = None, Tensor? seed_time = None, Tensor? " - "batch = None, bool csc " + "num_neighbors, Tensor? time = None, Tensor? 
seed_time = None, bool csc " "= False, bool replace = False, bool directed = True, bool disjoint = " - "False, str temporal_strategy = 'uniform', bool return_edge_id = True, " - "bool distributed = False) " - "-> (Tensor, Tensor, Tensor, Tensor?, int[], int[], int[])")); + "False, str temporal_strategy = 'uniform', bool return_edge_id = True) " + "-> (Tensor, Tensor, Tensor, Tensor?, int[], int[])")); m.def(TORCH_SELECTIVE_SCHEMA( "pyg::hetero_neighbor_sample(str[] node_types, (str, str, str)[] " "edge_types, Dict(str, Tensor) rowptr_dict, Dict(str, Tensor) col_dict, " @@ -112,8 +190,23 @@ TORCH_LIBRARY_FRAGMENT(pyg, m) { "Dict(str, Tensor)? time_dict = None, Dict(str, Tensor)? seed_time_dict " "= None, bool csc = False, bool replace = False, bool directed = True, " "bool disjoint = False, str temporal_strategy = 'uniform', bool " - "return_edge_id = True, bool distributed = False) -> (Dict(str, Tensor), " - "Dict(str, Tensor), " + "return_edge_id = True) -> (Dict(str, Tensor), Dict(str, Tensor), " + "Dict(str, Tensor), Dict(str, Tensor)?, Dict(str, int[]), " + "Dict(str, int[]))")); + m.def(TORCH_SELECTIVE_SCHEMA( + "pyg::dist_neighbor_sample(Tensor rowptr, Tensor col, Tensor seed, int[] " + "num_neighbors, Tensor? time = None, Tensor? seed_time = None, bool csc " + "= False, bool replace = False, bool directed = True, bool disjoint = " + "False, str temporal_strategy = 'uniform', bool return_edge_id = True) " + "-> (Tensor, Tensor, Tensor, Tensor?, int[], int[], int[])")); + m.def(TORCH_SELECTIVE_SCHEMA( + "pyg::dist_hetero_neighbor_sample(str[] node_types, (str, str, str)[] " + "edge_types, Dict(str, Tensor) rowptr_dict, Dict(str, Tensor) col_dict, " + "Dict(str, Tensor) seed_dict, Dict(str, int[]) num_neighbors_dict, " + "Dict(str, Tensor)? time_dict = None, Dict(str, Tensor)? 
seed_time_dict " + "= None, bool csc = False, bool replace = False, bool directed = True, " + "bool disjoint = False, str temporal_strategy = 'uniform', bool " + "return_edge_id = True) -> (Dict(str, Tensor), Dict(str, Tensor), " "Dict(str, Tensor), Dict(str, Tensor)?, Dict(str, int[]), " "Dict(str, int[]))")); } diff --git a/pyg_lib/csrc/sampler/neighbor.h b/pyg_lib/csrc/sampler/neighbor.h index 7e794878..85263d44 100644 --- a/pyg_lib/csrc/sampler/neighbor.h +++ b/pyg_lib/csrc/sampler/neighbor.h @@ -16,7 +16,6 @@ std::tuple, std::vector, - std::vector, std::vector> neighbor_sample(const at::Tensor& rowptr, const at::Tensor& col, @@ -24,14 +23,12 @@ neighbor_sample(const at::Tensor& rowptr, const std::vector& num_neighbors, const c10::optional& time = c10::nullopt, const c10::optional& seed_time = c10::nullopt, - const c10::optional& batch = c10::nullopt, bool csc = false, bool replace = false, bool directed = true, bool disjoint = false, std::string strategy = "uniform", - bool return_edge_id = true, - bool distributed = false); + bool return_edge_id = true); // Recursively samples neighbors from all node indices in `seed_dict` // in the heterogeneous graph given by `(rowptr_dict, col_dict)`. 
@@ -59,8 +56,53 @@ hetero_neighbor_sample( bool directed = true, bool disjoint = false, std::string strategy = "uniform", - bool return_edge_id = true, - bool distributed = false); + bool return_edge_id = true); + +PYG_API +std::tuple, + std::vector, + std::vector, + std::vector> +dist_neighbor_sample(const at::Tensor& rowptr, + const at::Tensor& col, + const at::Tensor& seed, + const std::vector& num_neighbors, + const c10::optional& time = c10::nullopt, + const c10::optional& seed_time = c10::nullopt, + bool csc = false, + bool replace = false, + bool directed = true, + bool disjoint = false, + std::string strategy = "uniform", + bool return_edge_id = true); + +PYG_API +std::tuple, + c10::Dict, + c10::Dict, + c10::optional>, + c10::Dict>, + c10::Dict>> +dist_hetero_neighbor_sample( + const std::vector& node_types, + const std::vector& edge_types, + const c10::Dict& rowptr_dict, + const c10::Dict& col_dict, + const c10::Dict& seed_dict, + const c10::Dict>& num_neighbors_dict, + const c10::optional>& time_dict = + c10::nullopt, + const c10::optional>& seed_time_dict = + c10::nullopt, + bool csc = false, + bool replace = false, + bool directed = true, + bool disjoint = false, + std::string strategy = "uniform", + bool return_edge_id = true); } // namespace sampler } // namespace pyg diff --git a/pyg_lib/sampler/__init__.py b/pyg_lib/sampler/__init__.py index 92290fd2..91ac6f83 100644 --- a/pyg_lib/sampler/__init__.py +++ b/pyg_lib/sampler/__init__.py @@ -15,16 +15,13 @@ def neighbor_sample( num_neighbors: List[int], time: Optional[Tensor] = None, seed_time: Optional[Tensor] = None, - batch: Optional[Tensor] = None, csc: bool = False, replace: bool = False, directed: bool = True, disjoint: bool = False, temporal_strategy: str = 'uniform', return_edge_id: bool = True, - distributed: bool = False, -) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor], List[int], List[int], - List[int]]: +) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor], List[int], List[int]]: 
r"""Recursively samples neighbors from all node indices in :obj:`seed` in the graph given by :obj:`(rowptr, col)`. @@ -51,9 +48,6 @@ def neighbor_sample( seed_time (torch.Tensor, optional): Optional values to override the timestamp for seed nodes. If not set, will use timestamps in :obj:`time` as default for seed nodes. (default: :obj:`None`) - batch (torch.Tensor, optional): Optional values to specify the - initial subgraph indices for seed nodes. If not set, will use - incremental values starting from 0. (default: :obj:`None`) csc (bool, optional): If set to :obj:`True`, assumes that the graph is given in CSC format :obj:`(colptr, row)`. (default: :obj:`False`) replace (bool, optional): If set to :obj:`True`, will sample with @@ -68,14 +62,10 @@ def neighbor_sample( return_edge_id (bool, optional): If set to :obj:`False`, will not return the indices of edges of the original graph. (default: :obj: `True`) - distributed (bool, optional): If set to :obj:`True`, will sample nodes - with duplicates, save information about the number of sampled - neighbors per node and will not return rows and cols. - This argument was added for the purpose of a distributed training. Returns: (torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], - List[int], List[int], List[int]): + List[int], List[int]): Row indices, col indices of the returned subtree/subgraph, as well as original node indices for all nodes sampled. In addition, may return the indices of edges of the original graph. @@ -84,9 +74,9 @@ def neighbor_sample( neighbors per node. 
""" return torch.ops.pyg.neighbor_sample(rowptr, col, seed, num_neighbors, - time, seed_time, batch, csc, replace, + time, seed_time, csc, replace, directed, disjoint, temporal_strategy, - return_edge_id, distributed) + return_edge_id) def hetero_neighbor_sample( @@ -102,7 +92,6 @@ def hetero_neighbor_sample( disjoint: bool = False, temporal_strategy: str = 'uniform', return_edge_id: bool = True, - distributed: bool = False, ) -> Tuple[Dict[EdgeType, Tensor], Dict[EdgeType, Tensor], Dict[ NodeType, Tensor], Optional[Dict[EdgeType, Tensor]], Dict[ NodeType, List[int]], Dict[EdgeType, List[int]]]: @@ -133,9 +122,21 @@ def hetero_neighbor_sample( } out = torch.ops.pyg.hetero_neighbor_sample( - node_types, edge_types, rowptr_dict, col_dict, seed_dict, - num_neighbors_dict, time_dict, seed_time_dict, csc, replace, directed, - disjoint, temporal_strategy, return_edge_id, distributed) + node_types, + edge_types, + rowptr_dict, + col_dict, + seed_dict, + num_neighbors_dict, + time_dict, + seed_time_dict, + csc, + replace, + directed, + disjoint, + temporal_strategy, + return_edge_id, + ) (row_dict, col_dict, node_id_dict, edge_id_dict, num_nodes_per_hop_dict, num_edges_per_hop_dict) = out @@ -205,100 +206,9 @@ def random_walk(rowptr: Tensor, col: Tensor, seed: Tensor, walk_length: int, return torch.ops.pyg.random_walk(rowptr, col, seed, walk_length, p, q) -def relabel_neighborhood( - seed: Tensor, - sampled_nodes_with_dupl: Tensor, - sampled_nbrs_per_node: List[int], - num_nodes: int, - batch: Optional[Tensor] = None, - csc: bool = False, - disjoint: bool = False, -) -> Tuple[Tensor, Tensor]: - r"""Relabel global indices of the :obj:`sampled_nodes_with_dupl` to the - local subtree/subgraph indices. - - .. note:: - - For :obj:`disjoint`, the :obj:`batch` needs to be specified - and each node from :obj:`sampled_nodes_with_dupl` must be assigned - to a subgraph. - - Args: - seed (torch.Tensor): The seed node indices. 
- sampled_nodes_with_dupl (torch.Tensor): Sampled nodes with duplicates. - Should not include seed nodes. - sampled_nbrs_per_node (List[int]): The number of neighbors sampled by - each node from :obj:`sampled_nodes_with_dupl`. - num_nodes (int): Number of all nodes in a graph. - batch (torch.Tensor, optional): Stores information about which subgraph - the node from :obj:`sampled_nodes_with_dupl` belongs to. - Must be specified when :obj:`disjoint`. (default: :obj:`None`) - csc (bool, optional): If set to :obj:`True`, assumes that the graph is - given in CSC format :obj:`(colptr, row)`. (default: :obj:`False`) - disjoint (bool, optional): If set to :obj:`True` , will create disjoint - subgraphs for every seed node. (default: :obj:`False`) - - Returns: - (torch.Tensor, torch.Tensor): - Row indices, col indices of the returned subtree/subgraph. - """ - return torch.ops.pyg.relabel_neighborhood(seed, sampled_nodes_with_dupl, - sampled_nbrs_per_node, num_nodes, - batch, csc, disjoint) - - -def hetero_relabel_neighborhood( - seed_dict: Dict[NodeType, - Tensor], sampled_nodes_with_dupl_dict: Dict[NodeType, - Tensor], - sampled_nbrs_per_node_dict: Dict[EdgeType, - List[int]], num_nodes_dict: Dict[NodeType, - int], - batch_dict: Optional[Dict[NodeType, Tensor]] = None, csc: bool = False, - disjoint: bool = False -) -> Tuple[Dict[EdgeType, Tensor], Dict[EdgeType, Tensor]]: - r"""Relabel global indices of the :obj:`sampled_nodes_with_dupl` to the - local subtree/subgraph indices in the heterogeneous graph. - - .. note :: - Similar to :meth:`relabel_neighborhood`, but expects a dictionary of - node types (:obj:`str`) and edge types (:obj:`Tuple[str, str, str]`) - for each non-boolean argument. - - Args: - kwargs: Arguments of :meth:`relabel_neighborhood`. 
- """ - - src_node_types = {k[0] for k in sampled_nodes_with_dupl_dict.keys()} - dst_node_types = {k[-1] for k in sampled_nodes_with_dupl_dict.keys()} - node_types = list(src_node_types | dst_node_types) - edge_types = list(sampled_nbrs_per_node_dict.keys()) - - TO_REL_TYPE = {key: '__'.join(key) for key in edge_types} - TO_EDGE_TYPE = {'__'.join(key): key for key in edge_types} - - sampled_nbrs_per_node_dict = { - TO_REL_TYPE[k]: v - for k, v in sampled_nbrs_per_node_dict.items() - } - - out = torch.ops.pyg.hetero_relabel_neighborhood( - node_types, edge_types, seed_dict, sampled_nodes_with_dupl_dict, - sampled_nbrs_per_node_dict, num_nodes_dict, batch_dict, csc, disjoint) - - (row_dict, col_dict) = out - - row_dict = {TO_EDGE_TYPE[k]: v for k, v in row_dict.items()} - col_dict = {TO_EDGE_TYPE[k]: v for k, v in col_dict.items()} - - return (row_dict, col_dict) - - __all__ = [ 'neighbor_sample', 'hetero_neighbor_sample', 'subgraph', 'random_walk', - 'relabel_neighborhood', - 'hetero_relabel_neighborhood', ] diff --git a/test/csrc/sampler/test_dist_neighbor.cpp b/test/csrc/sampler/test_dist_neighbor.cpp deleted file mode 100644 index c22f5c18..00000000 --- a/test/csrc/sampler/test_dist_neighbor.cpp +++ /dev/null @@ -1,434 +0,0 @@ -#include -#include - -#include "pyg_lib/csrc/sampler/dist_neighbor.h" -#include "pyg_lib/csrc/sampler/neighbor.h" -#include "pyg_lib/csrc/utils/types.h" -#include "test/csrc/graph.h" - -TEST(FullDistNeighborTest, BasicAssertions) { - auto options = at::TensorOptions().dtype(at::kLong); - - int num_nodes = 6; - auto graph = cycle_graph(num_nodes, options); - auto seed = at::arange(2, 4, options); - std::vector num_neighbors = {-1}; - - auto out = pyg::sampler::neighbor_sample( - /*rowptr=*/std::get<0>(graph), - /*col=*/std::get<1>(graph), seed, num_neighbors, - /*time=*/c10::nullopt, /*seed_time=*/c10::nullopt, - /*batch=*/c10::nullopt, /*csc*/ false, /*replace=*/false, - /*directed=*/true, /*disjoint=*/false, - 
/*temporal_strategy=*/"uniform", /*return_edge_id=*/true, - /*distributed=*/true); - - // do not sample rows and cols - EXPECT_EQ(std::get<0>(out).numel(), 0); - EXPECT_EQ(std::get<1>(out).numel(), 0); - - // sample nodes with duplicates - auto expected_nodes = at::tensor({2, 3, 1, 3, 2, 4}, options); - EXPECT_TRUE(at::equal(std::get<2>(out), expected_nodes)); - - auto expected_edges = at::tensor({4, 5, 6, 7}, options); - EXPECT_TRUE(at::equal(std::get<3>(out).value(), expected_edges)); - - std::vector expected_cumm_sum_nbrs_per_node = {2, 4, 6}; - EXPECT_EQ(std::get<6>(out), expected_cumm_sum_nbrs_per_node); - - std::vector sampled_nbrs_per_node = {2, 2}; - // without seed nodes - auto sampled_nodes_with_dupl = at::tensor({1, 3, 2, 4}, options); - - // get rows and cols - auto relabel_out = pyg::sampler::relabel_neighborhood( - seed, sampled_nodes_with_dupl, sampled_nbrs_per_node, num_nodes); - - auto expected_row = at::tensor({0, 0, 1, 1}, options); - EXPECT_TRUE(at::equal(std::get<0>(relabel_out), expected_row)); - auto expected_col = at::tensor({2, 1, 0, 3}, options); - EXPECT_TRUE(at::equal(std::get<1>(relabel_out), expected_col)); - - // check if rows and cols are the same as for the classic sampling - auto non_dist_out = pyg::sampler::neighbor_sample( - /*rowptr=*/std::get<0>(graph), - /*col=*/std::get<1>(graph), seed, num_neighbors, - /*time=*/c10::nullopt, /*seed_time=*/c10::nullopt, - /*batch=*/c10::nullopt, /*csc*/ false, /*replace=*/false, - /*directed=*/true, /*disjoint=*/false, - /*temporal_strategy=*/"uniform", - /*return_edge_id=*/true, /*distributed=*/false); - - EXPECT_TRUE(at::equal(std::get<0>(relabel_out), std::get<0>(non_dist_out))); - EXPECT_TRUE(at::equal(std::get<1>(relabel_out), std::get<1>(non_dist_out))); -} - -TEST(WithoutReplacementNeighborTest, BasicAssertions) { - auto options = at::TensorOptions().dtype(at::kLong); - - int num_nodes = 6; - auto graph = cycle_graph(num_nodes, options); - auto seed = at::arange(2, 4, options); - 
std::vector num_neighbors = {1}; - - at::manual_seed(123456); - auto out = pyg::sampler::neighbor_sample( - /*rowptr=*/std::get<0>(graph), - /*col=*/std::get<1>(graph), seed, num_neighbors, - /*time=*/c10::nullopt, /*seed_time=*/c10::nullopt, - /*batch=*/c10::nullopt, /*csc*/ false, /*replace=*/false, - /*directed=*/true, /*disjoint=*/false, - /*temporal_strategy=*/"uniform", /*return_edge_id=*/true, - /*distributed=*/true); - - // do not sample rows and cols - EXPECT_EQ(std::get<0>(out).numel(), 0); - EXPECT_EQ(std::get<1>(out).numel(), 0); - - // sample nodes with duplicates - auto expected_nodes = at::tensor({2, 3, 1, 4}, options); - EXPECT_TRUE(at::equal(std::get<2>(out), expected_nodes)); - - auto expected_edges = at::tensor({4, 7}, options); - EXPECT_TRUE(at::equal(std::get<3>(out).value(), expected_edges)); - - std::vector expected_cumm_sum_nbrs_per_node = {2, 3, 4}; - EXPECT_EQ(std::get<6>(out), expected_cumm_sum_nbrs_per_node); - - std::vector sampled_nbrs_per_node = {1, 1}; - // without seed nodes - auto sampled_nodes_with_dupl = at::tensor({1, 4}, options); - - // get rows and cols - auto relabel_out = pyg::sampler::relabel_neighborhood( - seed, sampled_nodes_with_dupl, sampled_nbrs_per_node, num_nodes); - - auto expected_row = at::tensor({0, 1}, options); - EXPECT_TRUE(at::equal(std::get<0>(relabel_out), expected_row)); - auto expected_col = at::tensor({2, 3}, options); - EXPECT_TRUE(at::equal(std::get<1>(relabel_out), expected_col)); -} - -TEST(WithReplacementNeighborTest, BasicAssertions) { - auto options = at::TensorOptions().dtype(at::kLong); - - int num_nodes = 6; - auto graph = cycle_graph(num_nodes, options); - auto seed = at::arange(2, 4, options); - std::vector num_neighbors = {2}; - - at::manual_seed(123456); - auto out = pyg::sampler::neighbor_sample( - /*rowptr=*/std::get<0>(graph), /*col=*/std::get<1>(graph), seed, - num_neighbors, /*time=*/c10::nullopt, - /*seed_time=*/c10::nullopt, /*batch=*/c10::nullopt, - /*csc*/ false, 
/*replace=*/true, /*directed=*/true, - /*disjoint=*/false, /*temporal_strategy=*/"uniform", - /*return_edge_id=*/true, /*distributed=*/true); - - // do not sample rows and cols - EXPECT_EQ(std::get<0>(out).numel(), 0); - EXPECT_EQ(std::get<1>(out).numel(), 0); - - // sample nodes with duplicates - auto expected_nodes = at::tensor({2, 3, 1, 3, 4, 4}, options); - EXPECT_TRUE(at::equal(std::get<2>(out), expected_nodes)); - - auto expected_edges = at::tensor({4, 5, 7, 7}, options); - EXPECT_TRUE(at::equal(std::get<3>(out).value(), expected_edges)); - - std::vector expected_cumm_sum_nbrs_per_node = {2, 4, 6}; - EXPECT_EQ(std::get<6>(out), expected_cumm_sum_nbrs_per_node); - - std::vector sampled_nbrs_per_node = {2, 2}; - // without seed nodes - auto sampled_nodes_with_dupl = at::tensor({1, 3, 4, 4}, options); - - // get rows and cols - auto relabel_out = pyg::sampler::relabel_neighborhood( - seed, sampled_nodes_with_dupl, sampled_nbrs_per_node, num_nodes); - - auto expected_row = at::tensor({0, 0, 1, 1}, options); - EXPECT_TRUE(at::equal(std::get<0>(relabel_out), expected_row)); - auto expected_col = at::tensor({2, 1, 3, 3}, options); - EXPECT_TRUE(at::equal(std::get<1>(relabel_out), expected_col)); -} - -TEST(DistDisjointNeighborTest, BasicAssertions) { - auto options = at::TensorOptions().dtype(at::kLong); - - int num_nodes = 6; - auto graph = cycle_graph(num_nodes, options); - auto seed = at::arange(2, 4, options); - std::vector num_neighbors = {2}; - auto batch = at::tensor({0, 1}, options); - - auto out = pyg::sampler::neighbor_sample( - /*rowptr=*/std::get<0>(graph), /*col=*/std::get<1>(graph), seed, - num_neighbors, /*time=*/c10::nullopt, - /*seed_time=*/c10::nullopt, batch, /*csc*/ false, - /*replace=*/false, /*directed=*/true, /*disjoint=*/true, - /*temporal_strategy=*/"uniform", /*return_edge_id=*/true, - /*distributed=*/true); - - // do not sample rows and cols - EXPECT_EQ(std::get<0>(out).numel(), 0); - EXPECT_EQ(std::get<1>(out).numel(), 0); - - // sample 
nodes with duplicates - auto expected_nodes = - at::tensor({0, 2, 1, 3, 0, 1, 0, 3, 1, 2, 1, 4}, options); - EXPECT_TRUE(at::equal(std::get<2>(out), expected_nodes.view({-1, 2}))); - - auto expected_edges = at::tensor({4, 5, 6, 7}, options); - EXPECT_TRUE(at::equal(std::get<3>(out).value(), expected_edges)); - - std::vector expected_cumm_sum_nbrs_per_node = {2, 4, 6}; - EXPECT_EQ(std::get<6>(out), expected_cumm_sum_nbrs_per_node); - - std::vector sampled_nbrs_per_node = {2, 2}; - // without seed nodes - auto sampled_nodes_with_dupl = at::tensor({1, 3, 2, 4}, options); - auto sampled_batch = at::tensor({0, 0, 1, 1}, options); - - // get rows and cols - auto relabel_out = pyg::sampler::relabel_neighborhood( - seed, sampled_nodes_with_dupl, sampled_nbrs_per_node, num_nodes, - sampled_batch, /*csc=*/false, /*disjoint=*/true); - - auto expected_row = at::tensor({0, 0, 1, 1}, options); - EXPECT_TRUE(at::equal(std::get<0>(relabel_out), expected_row)); - auto expected_col = at::tensor({2, 3, 4, 5}, options); - EXPECT_TRUE(at::equal(std::get<1>(relabel_out), expected_col)); - - // check if rows and cols are the same as for the classic sampling - auto non_dist_out = pyg::sampler::neighbor_sample( - /*rowptr=*/std::get<0>(graph), - /*col=*/std::get<1>(graph), seed, num_neighbors, - /*time=*/c10::nullopt, /*seed_time=*/c10::nullopt, batch, /*csc*/ false, - /*replace=*/false, - /*directed=*/true, /*disjoint=*/true, - /*temporal_strategy=*/"uniform", - /*return_edge_id=*/true, /*distributed=*/false); - - EXPECT_TRUE(at::equal(std::get<0>(relabel_out), std::get<0>(non_dist_out))); - EXPECT_TRUE(at::equal(std::get<1>(relabel_out), std::get<1>(non_dist_out))); -} - -TEST(DistTemporalNeighborTest, BasicAssertions) { - auto options = at::TensorOptions().dtype(at::kLong); - - int num_nodes = 6; - auto graph = cycle_graph(num_nodes, options); - auto rowptr = std::get<0>(graph); - auto col = std::get<1>(graph); - - auto seed = at::arange(2, 4, options); - std::vector num_neighbors = 
{2}; - - // Time is equal to node ID ... - auto time = at::arange(6, options); - // ... so we need to sort the column vector by time/node ID: - col = std::get<0>(at::sort(col.view({-1, 2}), /*dim=*/1)).flatten(); - - auto out = pyg::sampler::neighbor_sample( - rowptr, col, seed, num_neighbors, time, - /*seed_time=*/c10::nullopt, /*batch=*/c10::nullopt, - /*csc*/ false, /*replace=*/false, /*directed=*/true, - /*disjoint=*/true, /*temporal_strategy=*/"uniform", - /*return_edge_id=*/true, /*distributed=*/true); - - // do not sample rows and cols - EXPECT_EQ(std::get<0>(out).numel(), 0); - EXPECT_EQ(std::get<1>(out).numel(), 0); - - // sample nodes with duplicates - auto expected_nodes = at::tensor({0, 2, 1, 3, 0, 1, 1, 2}, options); - EXPECT_TRUE(at::equal(std::get<2>(out), expected_nodes.view({-1, 2}))); - - auto expected_edges = at::tensor({4, 6}, options); - EXPECT_TRUE(at::equal(std::get<3>(out).value(), expected_edges)); - - std::vector expected_cumm_sum_nbrs_per_node = {2, 3, 4}; - EXPECT_EQ(std::get<6>(out), expected_cumm_sum_nbrs_per_node); - - std::vector sampled_nbrs_per_node = {1, 1}; - // without seed nodes - auto sampled_nodes_with_dupl = at::tensor({1, 2}, options); - auto sampled_batch = at::tensor({0, 1}, options); - - // get rows and cols - auto relabel_out = pyg::sampler::relabel_neighborhood( - seed, sampled_nodes_with_dupl, sampled_nbrs_per_node, num_nodes, - sampled_batch, /*csc=*/false, /*disjoint=*/true); - - auto expected_row = at::tensor({0, 1}, options); - EXPECT_TRUE(at::equal(std::get<0>(relabel_out), expected_row)); - auto expected_col = at::tensor({2, 3}, options); - EXPECT_TRUE(at::equal(std::get<1>(relabel_out), expected_col)); - - // check if rows and cols are the same as for the classic sampling - auto non_dist_out = pyg::sampler::neighbor_sample( - rowptr, col, seed, num_neighbors, time, /*seed_time=*/c10::nullopt, - /*batch=*/c10::nullopt, /*csc*/ false, - /*replace=*/false, /*directed=*/true, - /*disjoint=*/true, 
/*temporal_strategy=*/"uniform", - /*return_edge_id=*/true, /*distributed=*/false); - - EXPECT_TRUE(at::equal(std::get<0>(relabel_out), std::get<0>(non_dist_out))); - EXPECT_TRUE(at::equal(std::get<1>(relabel_out), std::get<1>(non_dist_out))); -} - -TEST(DistHeteroNeighborTest, BasicAssertions) { - auto options = at::TensorOptions().dtype(at::kLong); - - int num_nodes = 6; - auto graph = cycle_graph(num_nodes, options); - const auto node_key = "paper"; - const auto edge_key = std::make_tuple("paper", "to", "paper"); - const auto rel_key = "paper__to__paper"; - std::vector node_types = {node_key}; - std::vector edge_types = {edge_key}; - c10::Dict rowptr_dict; - rowptr_dict.insert(rel_key, std::get<0>(graph)); - c10::Dict col_dict; - col_dict.insert(rel_key, std::get<1>(graph)); - c10::Dict seed_dict; - seed_dict.insert(node_key, at::arange(2, 4, options)); - std::vector num_neighbors = {2}; - c10::Dict> num_neighbors_dict; - num_neighbors_dict.insert(rel_key, num_neighbors); - c10::Dict num_nodes_dict; - num_nodes_dict.insert(node_key, num_nodes); - - c10::Dict sampled_nodes_with_dupl_dict; - c10::Dict> sampled_nbrs_per_node_dict; - sampled_nodes_with_dupl_dict.insert(node_key, - at::tensor({1, 3, 2, 4}, options)); - sampled_nbrs_per_node_dict.insert(rel_key, std::vector(2, 2)); - // get rows and cols - auto relabel_out = pyg::sampler::hetero_relabel_neighborhood( - node_types, edge_types, seed_dict, sampled_nodes_with_dupl_dict, - sampled_nbrs_per_node_dict, num_nodes_dict, - /*batch_dict=*/c10::nullopt, /*csc=*/false, /*disjoint=*/false); - - auto expected_row = at::tensor({0, 0, 1, 1}, options); - EXPECT_TRUE(at::equal(std::get<0>(relabel_out).at(rel_key), expected_row)); - auto expected_col = at::tensor({2, 1, 0, 3}, options); - EXPECT_TRUE(at::equal(std::get<1>(relabel_out).at(rel_key), expected_col)); - - // check if rows and cols are the same as for the classic sampling - auto non_dist_out = pyg::sampler::hetero_neighbor_sample( - node_types, edge_types, 
rowptr_dict, col_dict, seed_dict, - num_neighbors_dict); - - EXPECT_TRUE(at::equal(std::get<0>(relabel_out).at(rel_key), - std::get<0>(non_dist_out).at(rel_key))); - EXPECT_TRUE(at::equal(std::get<1>(relabel_out).at(rel_key), - std::get<1>(non_dist_out).at(rel_key))); -} - -TEST(DistHeteroCscNeighborTest, BasicAssertions) { - auto options = at::TensorOptions().dtype(at::kLong); - - int num_nodes = 6; - auto graph = cycle_graph(num_nodes, options); - const auto node_key = "paper"; - const auto edge_key = std::make_tuple("paper", "to", "paper"); - const auto rel_key = "paper__to__paper"; - std::vector node_types = {node_key}; - std::vector edge_types = {edge_key}; - c10::Dict rowptr_dict; - rowptr_dict.insert(rel_key, std::get<0>(graph)); - c10::Dict col_dict; - col_dict.insert(rel_key, std::get<1>(graph)); - c10::Dict seed_dict; - seed_dict.insert(node_key, at::arange(2, 4, options)); - std::vector num_neighbors = {2}; - c10::Dict> num_neighbors_dict; - num_neighbors_dict.insert(rel_key, num_neighbors); - c10::Dict num_nodes_dict; - num_nodes_dict.insert(node_key, num_nodes); - - c10::Dict sampled_nodes_with_dupl_dict; - c10::Dict> sampled_nbrs_per_node_dict; - sampled_nodes_with_dupl_dict.insert(node_key, - at::tensor({1, 3, 2, 4}, options)); - sampled_nbrs_per_node_dict.insert(rel_key, std::vector(2, 2)); - // get rows and cols - auto relabel_out = pyg::sampler::hetero_relabel_neighborhood( - node_types, edge_types, seed_dict, sampled_nodes_with_dupl_dict, - sampled_nbrs_per_node_dict, num_nodes_dict, - /*batch_dict=*/c10::nullopt, /*csc=*/true, /*disjoint=*/false); - - auto expected_row = at::tensor({2, 1, 0, 3}, options); - EXPECT_TRUE(at::equal(std::get<0>(relabel_out).at(rel_key), expected_row)); - auto expected_col = at::tensor({0, 0, 1, 1}, options); - EXPECT_TRUE(at::equal(std::get<1>(relabel_out).at(rel_key), expected_col)); - - // check if rows and cols are the same as for the classic sampling - auto non_dist_out = pyg::sampler::hetero_neighbor_sample( - 
node_types, edge_types, rowptr_dict, col_dict, seed_dict, - num_neighbors_dict, - /*time_dict=*/c10::nullopt, - /*seed_time_dict=*/c10::nullopt, /*csc*/ true); - - EXPECT_TRUE(at::equal(std::get<0>(relabel_out).at(rel_key), - std::get<0>(non_dist_out).at(rel_key))); - EXPECT_TRUE(at::equal(std::get<1>(relabel_out).at(rel_key), - std::get<1>(non_dist_out).at(rel_key))); -} - -TEST(DistHeteroDisjointNeighborTest, BasicAssertions) { - auto options = at::TensorOptions().dtype(at::kLong); - - int num_nodes = 6; - auto graph = cycle_graph(num_nodes, options); - const auto node_key = "paper"; - const auto edge_key = std::make_tuple("paper", "to", "paper"); - const auto rel_key = "paper__to__paper"; - std::vector node_types = {node_key}; - std::vector edge_types = {edge_key}; - c10::Dict rowptr_dict; - rowptr_dict.insert(rel_key, std::get<0>(graph)); - c10::Dict col_dict; - col_dict.insert(rel_key, std::get<1>(graph)); - c10::Dict seed_dict; - seed_dict.insert(node_key, at::arange(2, 4, options)); - std::vector num_neighbors = {2}; - c10::Dict> num_neighbors_dict; - num_neighbors_dict.insert(rel_key, num_neighbors); - c10::Dict num_nodes_dict; - num_nodes_dict.insert(node_key, num_nodes); - - c10::Dict sampled_nodes_with_dupl_dict; - c10::Dict> sampled_nbrs_per_node_dict; - c10::Dict batch_dict; - sampled_nodes_with_dupl_dict.insert(node_key, - at::tensor({1, 3, 2, 4}, options)); - sampled_nbrs_per_node_dict.insert(rel_key, std::vector(2, 2)); - batch_dict.insert(node_key, at::tensor({0, 0, 1, 1}, options)); - // get rows and cols - auto relabel_out = pyg::sampler::hetero_relabel_neighborhood( - node_types, edge_types, seed_dict, sampled_nodes_with_dupl_dict, - sampled_nbrs_per_node_dict, num_nodes_dict, batch_dict, - /*csc=*/false, /*disjoint=*/true); - - auto expected_row = at::tensor({0, 0, 1, 1}, options); - EXPECT_TRUE(at::equal(std::get<0>(relabel_out).at(rel_key), expected_row)); - auto expected_col = at::tensor({2, 3, 4, 5}, options); - 
EXPECT_TRUE(at::equal(std::get<1>(relabel_out).at(rel_key), expected_col)); - - // check if rows and cols are the same as for the classic sampling - auto non_dist_out = pyg::sampler::hetero_neighbor_sample( - node_types, edge_types, rowptr_dict, col_dict, seed_dict, - num_neighbors_dict, /*time_dict=*/c10::nullopt, - /*seed_time_dict=*/c10::nullopt, /*csc=*/false, /*replace=*/false, - /*directed=*/true, /*disjoint=*/true); - - EXPECT_TRUE(at::equal(std::get<0>(relabel_out).at(rel_key), - std::get<0>(non_dist_out).at(rel_key))); - EXPECT_TRUE(at::equal(std::get<1>(relabel_out).at(rel_key), - std::get<1>(non_dist_out).at(rel_key))); -} diff --git a/test/csrc/sampler/test_neighbor.cpp b/test/csrc/sampler/test_neighbor.cpp index eb260cc1..2c3c7570 100644 --- a/test/csrc/sampler/test_neighbor.cpp +++ b/test/csrc/sampler/test_neighbor.cpp @@ -41,8 +41,7 @@ TEST(WithoutReplacementNeighborTest, BasicAssertions) { auto out = pyg::sampler::neighbor_sample( /*rowptr=*/std::get<0>(graph), /*col=*/std::get<1>(graph), seed, num_neighbors, /*time=*/c10::nullopt, - /*seed_time=*/c10::nullopt, /*batch=*/c10::nullopt, /*csc=*/false, - /*replace=*/false); + /*seed_time=*/c10::nullopt, /*csc=*/false, /*replace=*/false); auto expected_row = at::tensor({0, 1, 2, 3}, options); EXPECT_TRUE(at::equal(std::get<0>(out), expected_row)); @@ -65,8 +64,7 @@ TEST(WithReplacementNeighborTest, BasicAssertions) { auto out = pyg::sampler::neighbor_sample( /*rowptr=*/std::get<0>(graph), /*col=*/std::get<1>(graph), seed, num_neighbors, /*time=*/c10::nullopt, - /*seed_time=*/c10::nullopt, /*batch=*/c10::nullopt, /*csc=*/false, - /*replace=*/true); + /*seed_time=*/c10::nullopt, /*csc=*/false, /*replace=*/true); auto expected_row = at::tensor({0, 1, 2, 3}, options); EXPECT_TRUE(at::equal(std::get<0>(out), expected_row)); @@ -88,8 +86,7 @@ TEST(DisjointNeighborTest, BasicAssertions) { auto out = pyg::sampler::neighbor_sample( /*rowptr=*/std::get<0>(graph), /*col=*/std::get<1>(graph), seed, num_neighbors, 
/*time=*/c10::nullopt, - /*seed_time=*/c10::nullopt, /*batch=*/c10::nullopt, /*csc=*/false, - /*replace=*/false, + /*seed_time=*/c10::nullopt, /*csc=*/false, /*replace=*/false, /*directed=*/true, /*disjoint=*/true); auto expected_row = at::tensor({0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5}, options); @@ -119,8 +116,7 @@ TEST(TemporalNeighborTest, BasicAssertions) { auto out1 = pyg::sampler::neighbor_sample( rowptr, col, seed, /*num_neighbors=*/{2, 2}, /*time=*/time, - /*seed_time=*/c10::nullopt, /*batch=*/c10::nullopt, /*csc=*/false, - /*replace=*/false, + /*seed_time=*/c10::nullopt, /*csc=*/false, /*replace=*/false, /*directed=*/true, /*disjoint=*/true); // Expect only the earlier neighbors or the same node to be sampled: @@ -136,8 +132,7 @@ TEST(TemporalNeighborTest, BasicAssertions) { auto out2 = pyg::sampler::neighbor_sample( rowptr, col, seed, /*num_neighbors=*/{1, 2}, /*time=*/time, - /*seed_time=*/c10::nullopt, /*batch=*/c10::nullopt, /*csc=*/false, - /*replace=*/false, + /*seed_time=*/c10::nullopt, /*csc=*/false, /*replace=*/false, /*directed=*/true, /*disjoint=*/true, /*temporal_strategy=*/"last"); EXPECT_TRUE(at::equal(std::get<0>(out1), std::get<0>(out2))); From 61b2cb3188fe0e57f2b6452a82894a3cd8233f56 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Tue, 5 Sep 2023 14:18:08 +0000 Subject: [PATCH 6/7] update --- pyg_lib/sampler/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pyg_lib/sampler/__init__.py b/pyg_lib/sampler/__init__.py index 91ac6f83..4884b2e9 100644 --- a/pyg_lib/sampler/__init__.py +++ b/pyg_lib/sampler/__init__.py @@ -70,8 +70,7 @@ def neighbor_sample( original node indices for all nodes sampled. In addition, may return the indices of edges of the original graph. Lastly, returns information about the sampled amount of nodes and edges - per hop and if `distributed` will return cummulative sum of the sampled - neighbors per node. + per hop. 
""" return torch.ops.pyg.neighbor_sample(rowptr, col, seed, num_neighbors, time, seed_time, csc, replace, From b45cc0e1bbadbf6545f1aa1c87929560d0924713 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Tue, 5 Sep 2023 14:26:51 +0000 Subject: [PATCH 7/7] update --- pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp | 8 +++--- pyg_lib/csrc/sampler/cpu/neighbor_kernel.h | 2 ++ pyg_lib/csrc/sampler/neighbor.cpp | 28 +++++++++++--------- pyg_lib/csrc/sampler/neighbor.h | 28 +++++++++++--------- 4 files changed, 39 insertions(+), 27 deletions(-) diff --git a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp index d76ca7fe..71ea6de6 100644 --- a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp +++ b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp @@ -860,6 +860,7 @@ dist_neighbor_sample_kernel(const at::Tensor& rowptr, const std::vector& num_neighbors, const c10::optional& time, const c10::optional& seed_time, + const c10::optional& edge_weight, bool csc, bool replace, bool directed, @@ -867,7 +868,7 @@ dist_neighbor_sample_kernel(const at::Tensor& rowptr, std::string temporal_strategy, bool return_edge_id) { DISPATCH_DIST_SAMPLE(replace, directed, disjoint, return_edge_id, rowptr, col, - seed, num_neighbors, time, seed_time, csc, + seed, num_neighbors, time, seed_time, edge_weight, csc, temporal_strategy); } @@ -886,6 +887,7 @@ dist_hetero_neighbor_sample_kernel( const c10::Dict>& num_neighbors_dict, const c10::optional>& time_dict, const c10::optional>& seed_time_dict, + const c10::optional>& edge_weight_dict, bool csc, bool replace, bool directed, @@ -894,8 +896,8 @@ dist_hetero_neighbor_sample_kernel( bool return_edge_id) { DISPATCH_DIST_SAMPLE(replace, directed, disjoint, return_edge_id, node_types, edge_types, rowptr_dict, col_dict, seed_dict, - num_neighbors_dict, time_dict, seed_time_dict, csc, - temporal_strategy); + num_neighbors_dict, time_dict, seed_time_dict, + edge_weight_dict, csc, temporal_strategy); } TORCH_LIBRARY_IMPL(pyg, CPU, 
m) { diff --git a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.h b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.h index 46c80608..9c6e3aa5 100644 --- a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.h +++ b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.h @@ -61,6 +61,7 @@ dist_neighbor_sample_kernel(const at::Tensor& rowptr, const std::vector& num_neighbors, const c10::optional& time, const c10::optional& seed_time, + const c10::optional& edge_weight, bool csc, bool replace, bool directed, @@ -83,6 +84,7 @@ dist_hetero_neighbor_sample_kernel( const c10::Dict>& num_neighbors_dict, const c10::optional>& time_dict, const c10::optional>& seed_time_dict, + const c10::optional>& edge_weight_dict, bool csc, bool replace, bool directed, diff --git a/pyg_lib/csrc/sampler/neighbor.cpp b/pyg_lib/csrc/sampler/neighbor.cpp index 0c246ecb..6907ab5d 100644 --- a/pyg_lib/csrc/sampler/neighbor.cpp +++ b/pyg_lib/csrc/sampler/neighbor.cpp @@ -107,6 +107,7 @@ dist_neighbor_sample(const at::Tensor& rowptr, const std::vector& num_neighbors, const c10::optional& time, const c10::optional& seed_time, + const c10::optional& edge_weight, bool csc, bool replace, bool directed, @@ -124,8 +125,8 @@ dist_neighbor_sample(const at::Tensor& rowptr, static auto op = c10::Dispatcher::singleton() .findSchemaOrThrow("pyg::dist_neighbor_sample", "") .typed(); - return op.call(rowptr, col, seed, num_neighbors, time, seed_time, csc, - replace, directed, disjoint, temporal_strategy, + return op.call(rowptr, col, seed, num_neighbors, time, seed_time, edge_weight, + csc, replace, directed, disjoint, temporal_strategy, return_edge_id); } @@ -144,6 +145,7 @@ dist_hetero_neighbor_sample( const c10::Dict>& num_neighbors_dict, const c10::optional>& time_dict, const c10::optional>& seed_time_dict, + const c10::optional>& edge_weight_dict, bool csc, bool replace, bool directed, @@ -175,8 +177,9 @@ dist_hetero_neighbor_sample( .findSchemaOrThrow("pyg::dist_hetero_neighbor_sample", "") .typed(); return op.call(node_types, edge_types, 
rowptr_dict, col_dict, seed_dict, - num_neighbors_dict, time_dict, seed_time_dict, csc, replace, - directed, disjoint, temporal_strategy, return_edge_id); + num_neighbors_dict, time_dict, seed_time_dict, + edge_weight_dict, csc, replace, directed, disjoint, + temporal_strategy, return_edge_id); } TORCH_LIBRARY_FRAGMENT(pyg, m) { @@ -199,20 +202,21 @@ TORCH_LIBRARY_FRAGMENT(pyg, m) { "Dict(str, Tensor)?, Dict(str, int[]), Dict(str, int[]))")); m.def(TORCH_SELECTIVE_SCHEMA( "pyg::dist_neighbor_sample(Tensor rowptr, Tensor col, Tensor seed, int[] " - "num_neighbors, Tensor? time = None, Tensor? seed_time = None, bool csc " - "= False, bool replace = False, bool directed = True, bool disjoint = " - "False, str temporal_strategy = 'uniform', bool return_edge_id = True) " + "num_neighbors, Tensor? time = None, Tensor? seed_time = None, Tensor? " + "edge_weight = None, bool csc = False, bool replace = False, bool " + "directed = True, bool disjoint = False, str temporal_strategy = " + "'uniform', bool return_edge_id = True) " "-> (Tensor, Tensor, Tensor, Tensor?, int[], int[], int[])")); m.def(TORCH_SELECTIVE_SCHEMA( "pyg::dist_hetero_neighbor_sample(str[] node_types, (str, str, str)[] " "edge_types, Dict(str, Tensor) rowptr_dict, Dict(str, Tensor) col_dict, " "Dict(str, Tensor) seed_dict, Dict(str, int[]) num_neighbors_dict, " "Dict(str, Tensor)? time_dict = None, Dict(str, Tensor)? seed_time_dict " - "= None, bool csc = False, bool replace = False, bool directed = True, " - "bool disjoint = False, str temporal_strategy = 'uniform', bool " - "return_edge_id = True) -> (Dict(str, Tensor), Dict(str, Tensor), " - "Dict(str, Tensor), Dict(str, Tensor)?, Dict(str, int[]), " - "Dict(str, int[]))")); + "= None, Dict(str, Tensor)? 
edge_weight_dict = None, bool csc = False, " + "bool replace = False, bool directed = True, bool disjoint = False, " + "str temporal_strategy = 'uniform', bool return_edge_id = True) -> " + "(Dict(str, Tensor), Dict(str, Tensor), Dict(str, Tensor), " + "Dict(str, Tensor)?, Dict(str, int[]), Dict(str, int[]))")); } } // namespace sampler diff --git a/pyg_lib/csrc/sampler/neighbor.h b/pyg_lib/csrc/sampler/neighbor.h index b38fdc3c..8abcd42b 100644 --- a/pyg_lib/csrc/sampler/neighbor.h +++ b/pyg_lib/csrc/sampler/neighbor.h @@ -69,18 +69,20 @@ std::tuple, std::vector, std::vector> -dist_neighbor_sample(const at::Tensor& rowptr, - const at::Tensor& col, - const at::Tensor& seed, - const std::vector& num_neighbors, - const c10::optional& time = c10::nullopt, - const c10::optional& seed_time = c10::nullopt, - bool csc = false, - bool replace = false, - bool directed = true, - bool disjoint = false, - std::string strategy = "uniform", - bool return_edge_id = true); +dist_neighbor_sample( + const at::Tensor& rowptr, + const at::Tensor& col, + const at::Tensor& seed, + const std::vector& num_neighbors, + const c10::optional& time = c10::nullopt, + const c10::optional& seed_time = c10::nullopt, + const c10::optional& edge_weight = c10::nullopt, + bool csc = false, + bool replace = false, + bool directed = true, + bool disjoint = false, + std::string strategy = "uniform", + bool return_edge_id = true); PYG_API std::tuple, @@ -100,6 +102,8 @@ dist_hetero_neighbor_sample( c10::nullopt, const c10::optional>& seed_time_dict = c10::nullopt, + const c10::optional>& edge_weight_dict = + c10::nullopt, bool csc = false, bool replace = false, bool directed = true,