From f4ed34c28be1a9c5859b3f10d2e2cfbeb6db38a7 Mon Sep 17 00:00:00 2001
From: Kinga Gajdamowicz
Date: Fri, 10 Nov 2023 11:49:53 +0100
Subject: [PATCH] Update `DistNeighborSampler` for homo graphs [2/6] (#8209)

**Changes made:**

- Added support for temporal sampling.
- Use `torch.Tensor`s instead of NumPy arrays.
- Moved `_sample_one_hop()` from `NeighborSampler` to `DistNeighborSampler`.
- Do not follow the disjoint flow in the `_sample()` function; it is not
  needed because the batch is calculated afterwards.
- Added tests for node sampling and disjoint sampling (they work without
  `DistNeighborLoader`).
- Added tests for temporal node sampling (they work without
  `DistNeighborLoader`).
- Some minor changes, such as renaming variables.

This PR is based on #8083, so both must be combined to pass the tests.

Other distributed PRs: #8083 #8080 #8085
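The temporal part mirrors the single-machine sampler: after each hop, the seed time of every source node is carried over to all of the neighbors sampled for it, and those times constrain the next hop. Below is a minimal standalone sketch of that carry-over step (not part of this patch; the tensor values are made up for illustration):

```python
import torch

# Seed times of the two source nodes of the current hop:
seed_time = torch.tensor([3, 6])

# Number of neighbors sampled for each of the two source nodes in this hop
# (derived from the cumulative counts returned by the one-hop sampler):
sampled_nbrs_per_node = [2, 3]

# Every sampled neighbor inherits the seed time of its source node, so the
# next hop is sampled with times [3, 3, 6, 6, 6]:
next_seed_time = torch.repeat_interleave(
    seed_time, torch.as_tensor(sampled_nbrs_per_node))
print(next_seed_time)  # tensor([3, 3, 6, 6, 6])
```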
---------

Co-authored-by: Matthias Fey
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 CHANGELOG.md                                    |   1 +
 .../distributed/test_dist_neighbor_sampler.py   | 283 ++++++++++++++++++
 test/distributed/test_rpc.py                    |   2 +-
 .../distributed/dist_neighbor_sampler.py        | 240 +++++++++------
 .../distributed/local_graph_store.py            |   1 +
 torch_geometric/sampler/neighbor_sampler.py     |  65 +---
 6 files changed, 442 insertions(+), 150 deletions(-)
 create mode 100644 test/distributed/test_dist_neighbor_sampler.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 912d0fe25be6..8de5e46b2d91 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -23,6 +23,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
+- Update `DistNeighborSampler` for homogeneous graphs ([#8209](https://github.com/pyg-team/pytorch_geometric/pull/8209))
 - Update `GraphStore` and `FeatureStore` to support distributed training ([#8083](https://github.com/pyg-team/pytorch_geometric/pull/8083))
 - Disallow the usage of `add_self_loops=True` in `GCNConv(normalize=False)` ([#8210](https://github.com/pyg-team/pytorch_geometric/pull/8210))
 - Disable device asserts during `torch_geometric.compile` ([#8220](https://github.com/pyg-team/pytorch_geometric/pull/8220))
diff --git a/test/distributed/test_dist_neighbor_sampler.py b/test/distributed/test_dist_neighbor_sampler.py
new file mode 100644
index 000000000000..77a63d34c912
--- /dev/null
+++ b/test/distributed/test_dist_neighbor_sampler.py
@@ -0,0 +1,283 @@
+import atexit
+import socket
+
+import pytest
+import torch
+
+from torch_geometric.data import Data
+from torch_geometric.distributed import LocalFeatureStore, LocalGraphStore
+from torch_geometric.distributed.dist_context import DistContext
+from torch_geometric.distributed.dist_neighbor_sampler import (
+    DistNeighborSampler,
+    close_sampler,
+)
+from torch_geometric.distributed.rpc import init_rpc
+from torch_geometric.sampler import NeighborSampler, NodeSamplerInput
+from torch_geometric.sampler.neighbor_sampler import node_sample
+from torch_geometric.testing import onlyLinux, withPackage
+
+
+def create_data(rank: int, world_size: int, temporal: bool = False):
+    if rank == 0:  # Partition 0:
+        node_id = torch.tensor([0, 1, 2, 3, 4, 5, 9])
+        edge_index = torch.tensor([  # Sorted by destination.
+            [1, 2, 3, 4, 5, 0, 0],
+            [0, 1, 2, 3, 4, 4, 9],
+        ])
+    else:  # Partition 1:
+        node_id = torch.tensor([0, 4, 5, 6, 7, 8, 9])
+        edge_index = torch.tensor([  # Sorted by destination.
+            [5, 6, 7, 8, 9, 5, 0],
+            [4, 5, 6, 7, 8, 9, 9],
+        ])
+
+    feature_store = LocalFeatureStore.from_data(node_id)
+    graph_store = LocalGraphStore.from_data(
+        edge_id=None,
+        edge_index=edge_index,
+        num_nodes=10,
+        is_sorted=True,
+    )
+
+    graph_store.node_pb = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
+    graph_store.meta.update({'num_parts': 2})
+    graph_store.partition_idx = rank
+    graph_store.num_partitions = world_size
+
+    edge_index = torch.tensor([  # Create reference data:
+        [1, 2, 3, 4, 5, 0, 5, 6, 7, 8, 9, 0],
+        [0, 1, 2, 3, 4, 4, 9, 5, 6, 7, 8, 9],
+    ])
+    data = Data(x=None, y=None, edge_index=edge_index, num_nodes=10)
+
+    if temporal:  # Create time data:
+        data.time = torch.tensor([5, 0, 1, 3, 3, 4, 4, 4, 4, 4])
+        feature_store.put_tensor(data.time, group_name=None, attr_name='time')
+
+    return (feature_store, graph_store), data
+
+
+def dist_neighbor_sampler(
+    world_size: int,
+    rank: int,
+    master_port: int,
+    disjoint: bool = False,
+):
+    dist_data, data = create_data(rank, world_size)
+
+    current_ctx = DistContext(
+        rank=rank,
+        global_rank=rank,
+        world_size=world_size,
+        global_world_size=world_size,
+        group_name='dist-sampler-test',
+    )
+
+    # Initialize training process group of PyTorch:
+    torch.distributed.init_process_group(
+        backend='gloo',
+        rank=current_ctx.rank,
+        world_size=current_ctx.world_size,
+        init_method=f'tcp://localhost:{master_port}',
+    )
+
+    dist_sampler = DistNeighborSampler(
+        data=dist_data,
+        current_ctx=current_ctx,
+        rpc_worker_names={},
+        num_neighbors=[-1, -1],
+        shuffle=False,
+        disjoint=disjoint,
+    )
+
+    init_rpc(
+        current_ctx=current_ctx,
+        rpc_worker_names={},
+        master_addr='localhost',
+        master_port=master_port,
+    )
+
+    dist_sampler.register_sampler_rpc()
+    dist_sampler.init_event_loop()
+
+    # Close RPC & worker group at exit:
+    atexit.register(close_sampler, 0, dist_sampler)
+    torch.distributed.barrier()
+
+    if rank == 0:  # Seed nodes:
+        input_node = torch.tensor([1, 6])
+    else:
+        input_node = torch.tensor([4, 9])
+
+    inputs = NodeSamplerInput(input_id=None, node=input_node)
+
+    # Evaluate distributed node sample function:
+    out_dist = dist_sampler.event_loop.run_task(
+        coro=dist_sampler.node_sample(inputs))
+
+    torch.distributed.barrier()
+
+    sampler = NeighborSampler(
+        data=data,
+        num_neighbors=[-1, -1],
+        disjoint=disjoint,
+    )
+
+    # Evaluate node sample function:
+    out = node_sample(inputs, sampler._sample)
+
+    # Compare distributed output with single machine output:
+    assert torch.equal(out_dist.node, out.node)
+    assert torch.equal(out_dist.row, out.row)
+    assert torch.equal(out_dist.col, out.col)
+    if disjoint:
+        assert torch.equal(out_dist.batch, out.batch)
+    assert out_dist.num_sampled_nodes == out.num_sampled_nodes
+    assert out_dist.num_sampled_edges == out.num_sampled_edges
+
+    torch.distributed.barrier()
+
+
+def dist_neighbor_sampler_temporal(
+    world_size: int,
+    rank: int,
+    master_port: int,
+    seed_time: torch.tensor = None,
+    temporal_strategy: str = 'uniform',
+):
+    dist_data, data = create_data(rank, world_size, temporal=True)
+
+    current_ctx = DistContext(
+        rank=rank,
+        global_rank=rank,
+        world_size=world_size,
+        global_world_size=world_size,
+        group_name='dist-sampler-test',
+    )
+
+    # Initialize training process group of PyTorch:
+    torch.distributed.init_process_group(
+        backend='gloo',
+        rank=current_ctx.rank,
+        world_size=current_ctx.world_size,
+        init_method=f'tcp://localhost:{master_port}',
+    )
+
+    num_neighbors = [-1, -1] if temporal_strategy == 'uniform' else [1, 1]
+    dist_sampler = DistNeighborSampler(
+        data=dist_data,
+        current_ctx=current_ctx,
+        rpc_worker_names={},
+        num_neighbors=num_neighbors,
+        shuffle=False,
+        disjoint=True,
+        temporal_strategy=temporal_strategy,
+        time_attr='time',
+    )
+
+    init_rpc(
+        current_ctx=current_ctx,
+        rpc_worker_names={},
+        master_addr='localhost',
+        master_port=master_port,
+    )
+
+    dist_sampler.register_sampler_rpc()
+    dist_sampler.init_event_loop()
+
+    # Close RPC & worker group at exit:
+    atexit.register(close_sampler, 0, dist_sampler)
+    torch.distributed.barrier()
+
+    if rank == 0:  # Seed nodes:
+        input_node = torch.tensor([1, 6], dtype=torch.int64)
+    else:
+        input_node = torch.tensor([4, 9], dtype=torch.int64)
+
+    inputs = NodeSamplerInput(
+        input_id=None,
+        node=input_node,
+        time=seed_time,
+    )
+
+    # Evaluate distributed node sample function:
+    out_dist = dist_sampler.event_loop.run_task(
+        coro=dist_sampler.node_sample(inputs))
+
+    torch.distributed.barrier()
+
+    sampler = NeighborSampler(
+        data=data,
+        num_neighbors=num_neighbors,
+        disjoint=True,
+        temporal_strategy=temporal_strategy,
+        time_attr='time',
+    )
+
+    # Evaluate node sample function:
+    out = node_sample(inputs, sampler._sample)
+
+    # Compare distributed output with single machine output:
+    assert torch.equal(out_dist.node, out.node)
+    assert torch.equal(out_dist.row, out.row)
+    assert torch.equal(out_dist.col, out.col)
+    assert torch.equal(out_dist.batch, out.batch)
+    assert out_dist.num_sampled_nodes == out.num_sampled_nodes
+    assert out_dist.num_sampled_edges == out.num_sampled_edges
+
+    torch.distributed.barrier()
+
+
+@onlyLinux
+@withPackage('pyg_lib')
+@pytest.mark.parametrize('disjoint', [False, True])
+def test_dist_neighbor_sampler(disjoint):
+    mp_context = torch.multiprocessing.get_context('spawn')
+    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+    s.bind(('127.0.0.1', 0))
+    port = s.getsockname()[1]
+    s.close()
+
+    world_size = 2
+    w0 = mp_context.Process(
+        target=dist_neighbor_sampler,
+        args=(world_size, 0, port, disjoint),
+    )
+
+    w1 = mp_context.Process(
+        target=dist_neighbor_sampler,
+        args=(world_size, 1, port, disjoint),
+    )
+
+    w0.start()
+    w1.start()
+    w0.join()
+    w1.join()
+
+
+@onlyLinux
+@withPackage('pyg_lib')
+@pytest.mark.parametrize('seed_time', [None, torch.tensor([3, 6])])
+@pytest.mark.parametrize('temporal_strategy', ['uniform', 'last'])
+def test_dist_neighbor_sampler_temporal(seed_time, temporal_strategy):
+    mp_context = torch.multiprocessing.get_context('spawn')
+    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+    s.bind(('127.0.0.1', 0))
+    port = s.getsockname()[1]
+    s.close()
+
+    world_size = 2
+    w0 = mp_context.Process(
+        target=dist_neighbor_sampler_temporal,
+        args=(world_size, 0, port, seed_time, temporal_strategy),
+    )
+
+    w1 = mp_context.Process(
+        target=dist_neighbor_sampler_temporal,
+        args=(world_size, 1, port, seed_time, temporal_strategy),
+    )
+
+    w0.start()
+    w1.start()
+    w0.join()
+    w1.join()
diff --git a/test/distributed/test_rpc.py b/test/distributed/test_rpc.py
index 924f2a5a7b68..729bc8e92c2a 100644
--- a/test/distributed/test_rpc.py
+++ b/test/distributed/test_rpc.py
@@ -58,7 +58,7 @@ def run_rpc_feature_test(
     feature.num_partitions = world_size
     feature.partition_idx = rank
-    feature.feature_pb = partition_book
+    feature.node_feat_pb = partition_book
     feature.meta = meta
 
     feature.local_only = False
     feature.set_rpc_router(rpc_router)
diff --git a/torch_geometric/distributed/dist_neighbor_sampler.py b/torch_geometric/distributed/dist_neighbor_sampler.py
index 84ae7dcfbdf6..7173a37f4583 100644
--- a/torch_geometric/distributed/dist_neighbor_sampler.py
+++ b/torch_geometric/distributed/dist_neighbor_sampler.py
@@ -1,5 +1,6 @@
 import itertools
 import logging
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -10,7 +11,7 @@
 from torch_geometric.distributed.dist_context import DistContext, DistRole
 from torch_geometric.distributed.event_loop import (
     ConcurrentEventLoop,
-    wrap_torch_future,
+    to_asyncio_future,
 )
 from torch_geometric.distributed.rpc import (
     RPCCallBase,
@@ -34,16 +35,7 @@
 )
 from torch_geometric.sampler.base import NumNeighbors, SubgraphType
 from torch_geometric.sampler.utils import remap_keys
-from torch_geometric.typing import (
-    Any,
-    Dict,
-    EdgeType,
-    List,
-    NodeType,
-    Optional,
-    Tuple,
-    Union,
-)
+from torch_geometric.typing import EdgeType, NodeType
 
 NumNeighborsType = Union[NumNeighbors, List[int], Dict[EdgeType, List[int]]]
 
@@ -68,51 +60,53 @@ class DistNeighborSampler:
     used by :class:`~torch_geometric.distributed.DistNeighborLoader`.
     """
     def __init__(
-            self,
-            current_ctx: DistContext,
-            rpc_worker_names: Dict[DistRole, List[str]],
-            data: Tuple[LocalGraphStore, LocalFeatureStore],
-            num_neighbors: NumNeighborsType,
-            channel: Optional[mp.Queue] = None,
-            replace: bool = False,
-            subgraph_type: Union[SubgraphType, str] = 'directional',
-            disjoint: bool = False,
-            temporal_strategy: str = 'uniform',
-            time_attr: Optional[str] = None,
-            concurrency: int = 1,
-            **kwargs,
+        self,
+        current_ctx: DistContext,
+        rpc_worker_names: Dict[DistRole, List[str]],
+        data: Tuple[LocalFeatureStore, LocalGraphStore],
+        num_neighbors: NumNeighborsType,
+        channel: Optional[mp.Queue] = None,
+        replace: bool = False,
+        subgraph_type: Union[SubgraphType, str] = 'directional',
+        disjoint: bool = False,
+        temporal_strategy: str = 'uniform',
+        time_attr: Optional[str] = None,
+        concurrency: int = 1,
+        device: Optional[torch.device] = torch.device('cpu'),
+        **kwargs,
     ):
         self.current_ctx = current_ctx
         self.rpc_worker_names = rpc_worker_names
 
         self.feature_store, self.graph_store = data
-        assert isinstance(self.dist_graph, LocalGraphStore)
-        assert isinstance(self.dist_feature_store, LocalFeatureStore)
-        self.is_hetero = self.dist_graph.meta['is_hetero']
+        assert isinstance(self.graph_store, LocalGraphStore)
+        assert isinstance(self.feature_store, LocalFeatureStore)
+        self.is_hetero = self.graph_store.meta['is_hetero']
 
         self.num_neighbors = num_neighbors
-        self.channel = channel or mp.Queue()
+        self.channel = channel
         self.concurrency = concurrency
+        self.device = device
         self.event_loop = None
         self.replace = replace
         self.subgraph_type = SubgraphType(subgraph_type)
         self.disjoint = disjoint
         self.temporal_strategy = temporal_strategy
         self.time_attr = time_attr
-        self.with_edge_attr = self.dist_feature.has_edge_attr()
+        self.with_edge_attr = self.feature_store.has_edge_attr()
         self.csc = True
 
     def register_sampler_rpc(self) -> None:
         partition2workers = rpc_partition_to_workers(
             current_ctx=self.current_ctx,
-            num_partitions=self.dist_graph.num_partitions,
-            current_partition_idx=self.dist_graph.partition_idx,
+            num_partitions=self.graph_store.num_partitions,
+            current_partition_idx=self.graph_store.partition_idx,
         )
         self.rpc_router = RPCRouter(partition2workers)
-        self.dist_feature.set_rpc_router(self.rpc_router)
+        self.feature_store.set_rpc_router(self.rpc_router)
 
         self._sampler = NeighborSampler(
-            data=(self.dist_feature_store, self.dist_graph_store),
+            data=(self.feature_store, self.graph_store),
             num_neighbors=self.num_neighbors,
             subgraph_type=self.subgraph_type,
             replace=self.replace,
@@ -120,9 +114,13 @@ def register_sampler_rpc(self) -> None:
             temporal_strategy=self.temporal_strategy,
             time_attr=self.time_attr,
         )
 
-        self.edge_permutation = self._sampler.perm
-        rpc_sample_callee = RPCSamplingCallee(self._sampler)
+        self.num_hops = self._sampler.num_neighbors.num_hops
+        self.node_types = self._sampler.node_types
+        self.edge_types = self._sampler.edge_types
+        self.node_time = self._sampler.node_time
+
+        rpc_sample_callee = RPCSamplingCallee(self)
         self.rpc_sample_callee_id = rpc_register(rpc_sample_callee)
 
     def init_event_loop(self) -> None:
@@ -178,7 +176,7 @@ async def node_sample(
         """
         input_type = inputs.input_type
         self.input_type = input_type
-        batch_size = inputs.input_id.size()[0]
+        batch_size = inputs.node.size()[0]
 
         seed_dict = None
         seed_time_dict = None
@@ -186,11 +184,13 @@ async def node_sample(
         if isinstance(inputs, NodeSamplerInput):
             seed = inputs.node.to(self.device)
-            seed_time = (inputs.time.to(self.device)
-                         if inputs.time is not None else None)
+            seed_time = None
+            if self.time_attr is not None:
+                if inputs.time is not None:
+                    seed_time = inputs.time.to(self.device)
+                else:
+                    seed_time = self.node_time[seed]
             src_batch = torch.arange(batch_size) if self.disjoint else None
-            seed_dict = {input_type: seed}
-            seed_time_dict: Dict[NodeType, Tensor] = {input_type: seed_time}
 
             metadata = (seed, seed_time)
 
         elif isinstance(inputs, EdgeSamplerInput) and self.is_hetero:
@@ -215,6 +215,9 @@ async def node_sample(
         if input_type is None:
             raise ValueError("Input type should be defined")
 
+        seed_dict: Dict[NodeType, Tensor] = {input_type: seed}
+        seed_time_dict: Dict[NodeType, Tensor] = {input_type: seed_time}
+
         node_dict = NodeDict()
         batch_dict = BatchDict(self.disjoint)
         edge_dict: Dict[EdgeType, Tensor] = {}
@@ -354,8 +357,8 @@ async def node_sample(
             )
         else:
             src = seed
-            node = src.numpy()
-            batch = src_batch.numpy() if self.disjoint else None
+            node = src
+            batch = src_batch if self.disjoint else None
 
             node_with_dupl = [torch.empty(0, dtype=torch.int64)]
             batch_with_dupl = [torch.empty(0, dtype=torch.int64)]
@@ -363,10 +366,10 @@
             sampled_nbrs_per_node = []
             num_sampled_nodes = [seed.numel()]
-            num_sampled_edges = [0]
+            num_sampled_edges = []
 
             # loop over the layers
-            for one_hop_num in self.num_neighbors:
+            for i, one_hop_num in enumerate(self.num_neighbors):
                 out = await self.sample_one_hop(src, one_hop_num,
                                                 seed_time, src_batch)
                 if out.node.numel() == 0:
@@ -383,9 +386,15 @@ async def node_sample(
                 if self.disjoint:
                     batch_with_dupl.append(out.batch)
 
+                if seed_time is not None and i < self.num_hops - 1:
+                    # Get the seed time for the next layer based on the
+                    # previous seed_time and sampled neighbors per node info:
+                    seed_time = torch.repeat_interleave(
+                        seed_time, torch.as_tensor(out.metadata[0]))
+
                 num_sampled_nodes.append(len(src))
                 num_sampled_edges.append(len(out.node))
-                sampled_nbrs_per_node += out.metadata
+                sampled_nbrs_per_node += out.metadata[0]
 
             row, col = torch.ops.pyg.relabel_neighborhood(
                 seed,
@@ -398,11 +407,11 @@ async def node_sample(
             )
 
             sampler_output = SamplerOutput(
-                node=torch.from_numpy(node),
+                node=node,
                 row=row,
                 col=col,
                 edge=torch.cat(edge),
-                batch=torch.from_numpy(batch) if self.disjoint else None,
+                batch=batch if self.disjoint else None,
                 num_sampled_nodes=num_sampled_nodes,
                 num_sampled_edges=num_sampled_edges,
                 metadata=metadata,
@@ -425,7 +434,7 @@ def get_sampler_output(
         were sampled by each src node based on the
         :obj:`cumsum_neighbors_per_node`.
 
         Returns updated sampler output.
         """
-        cumsum_neighbors_per_node = outputs[p_id].metadata
+        cumsum_neighbors_per_node = outputs[p_id].metadata[0]
 
         # do not include seed
         outputs[p_id].node = outputs[p_id].node[seed_size:]
@@ -435,7 +444,7 @@ def get_sampler_output(
 
         sampled_nbrs_per_node = list(np.subtract(begin, end))
 
-        outputs[p_id].metadata = sampled_nbrs_per_node
+        outputs[p_id].metadata = (sampled_nbrs_per_node, )
 
         if self.disjoint:
             batch = [[src_batch[i]] * nbrs_per_node
@@ -481,13 +490,13 @@ def merge_sampler_outputs(
             for o in outputs
         ]
         cumm_sampled_nbrs_per_node = [
-            o.metadata if o is not None else [] for o in outputs
+            o.metadata[0] if o is not None else [] for o in outputs
         ]
 
         partition_ids = partition_ids.tolist()
         partition_orders = partition_orders.tolist()
 
-        partitions_num = self.dist_graph.meta["num_parts"]
+        partitions_num = self.graph_store.meta["num_parts"]
 
         out = torch.ops.pyg.merge_sampler_outputs(
             sampled_nodes_with_dupl,
@@ -513,7 +522,7 @@ def merge_sampler_outputs(
             None,
             out_edge,
             out_batch if self.disjoint else None,
-            metadata=(out_sampled_nbrs_per_node),
+            metadata=(out_sampled_nbrs_per_node, ),
         )
 
     async def sample_one_hop(
@@ -522,7 +531,7 @@ async def sample_one_hop(
         one_hop_num: int,
         seed_time: Optional[Tensor] = None,
         src_batch: Optional[Tensor] = None,
-        etype: Optional[EdgeType] = None,
+        edge_type: Optional[EdgeType] = None,
     ) -> SamplerOutput:
         r"""Sample one-hop neighbors for a :obj:`srcs`. If src node is located
         on a local partition, evaluates the :obj:`_sample_one_hop` function on
@@ -531,19 +540,20 @@ async def sample_one_hop(
 
         Returns merged samplers outputs from local / remote machines.
         """
-        partition_ids = self.dist_graph.get_partition_ids_from_nids(srcs)
+        partition_ids = self.graph_store.get_partition_ids_from_nids(srcs)
         partition_orders = torch.zeros(len(partition_ids), dtype=torch.long)
 
-        p_outputs: List[SamplerOutput] = [None
-                                          ] * self.dist_graph.meta["num_parts"]
+        p_outputs: List[SamplerOutput] = [
+            None
+        ] * self.graph_store.meta["num_parts"]
         futs: List[torch.futures.Future] = []
 
         local_only = True
         single_partition = len(set(partition_ids.tolist())) == 1
 
-        for i in range(self.dist_graph.num_partitions):
-            p_id = (self.dist_graph.partition_idx +
-                    i) % self.dist_graph.num_partitions
+        for i in range(self.graph_store.num_partitions):
+            p_id = (self.graph_store.partition_idx +
+                    i) % self.graph_store.num_partitions
             p_mask = partition_ids == p_id
             p_srcs = torch.masked_select(srcs, p_mask)
             p_seed_time = (torch.masked_select(seed_time, p_mask)
@@ -553,36 +563,30 @@ async def sample_one_hop(
             partition_orders[p_mask] = p_indices
 
             if p_srcs.shape[0] > 0:
-                if p_id == self.dist_graph.partition_idx:
-                    # sample on local machine
-                    p_nbr_out = self._sampler._sample_one_hop(
-                        p_srcs, one_hop_num, p_seed_time, self.csc, etype)
+                if p_id == self.graph_store.partition_idx:
+                    # Sample for one hop on a local machine:
+                    p_nbr_out = self._sample_one_hop(p_srcs, one_hop_num,
+                                                     p_seed_time, edge_type)
                     p_outputs.pop(p_id)
                     p_outputs.insert(p_id, p_nbr_out)
-                else:
-                    # sample on remote machine
+
+                else:  # Sample on a remote machine:
                     local_only = False
                     to_worker = self.rpc_router.get_to_worker(p_id)
                     futs.append(
                         rpc_async(
                             to_worker,
                             self.rpc_sample_callee_id,
-                            args=(
-                                p_srcs,
-                                one_hop_num,
-                                p_seed_time,
-                                self.csc,
-                                etype,
-                            ),
+                            args=(p_srcs, one_hop_num, p_seed_time, edge_type),
                         ))
 
         if not local_only:
            # Src nodes are remote
-            res_fut_list = await wrap_torch_future(
+            res_fut_list = await to_asyncio_future(
                 torch.futures.collect_all(futs))
             for i, res_fut in enumerate(res_fut_list):
-                p_id = (self.dist_graph.partition_idx + i +
-                        1) % self.dist_graph.num_partitions
+                p_id = (self.graph_store.partition_idx + i +
+                        1) % self.graph_store.num_partitions
                 p_outputs.pop(p_id)
                 p_outputs.insert(p_id, res_fut.wait())
@@ -594,6 +598,56 @@ async def sample_one_hop(
         return self.merge_sampler_outputs(partition_ids, partition_orders,
                                           p_outputs, one_hop_num, src_batch)
 
+    def _sample_one_hop(
+        self,
+        input_nodes: Tensor,
+        num_neighbors: int,
+        seed_time: Optional[Tensor] = None,
+        edge_type: Optional[EdgeType] = None,
+    ) -> SamplerOutput:
+        r"""Implements one-hop neighbor sampling for a set of input nodes for
+        a specific edge type.
+        """
+        if not self.is_hetero:
+            colptr = self._sampler.colptr
+            row = self._sampler.row
+            node_time = self.node_time
+        else:
+            rel_type = '__'.join(edge_type)
+            colptr = self._sampler.colptr_dict[rel_type]
+            row = self._sampler.row_dict[rel_type]
+            node_time = self.node_time.get(edge_type[2],
+                                           None) if self.node_time else None
+
+        out = torch.ops.pyg.dist_neighbor_sample(
+            colptr,
+            row,
+            input_nodes.to(colptr.dtype),
+            num_neighbors,
+            node_time,
+            seed_time,
+            None,  # TODO: edge_weight
+            True,  # csc
+            self.replace,
+            self.subgraph_type != SubgraphType.induced,
+            self.disjoint and node_time is not None,
+            self.temporal_strategy,
+        )
+        node, edge, cumsum_neighbors_per_node = out
+
+        if self.disjoint and node_time is not None:
+            # We create a batch during the step of merging sampler outputs.
+            _, node = node.t().contiguous()
+
+        return SamplerOutput(
+            node=node,
+            row=None,
+            col=None,
+            edge=edge,
+            batch=None,
+            metadata=(cumsum_neighbors_per_node, ),
+        )
+
     async def _collate_fn(
             self, output: Union[SamplerOutput, HeteroSamplerOutput]
     ) -> Union[SamplerOutput, HeteroSamplerOutput]:
@@ -605,7 +659,7 @@ async def _collate_fn(
             nfeats = {}
             efeats = {}
             # Collect node labels of input node type.
-            node_labels = self.dist_feature.labels
+            node_labels = self.feature_store.labels
             if node_labels is not None:
                 nlabels = node_labels[output.node[self.input_type]]
             else:
@@ -614,48 +668,48 @@ async def _collate_fn(
             if output.node is not None:
                 for ntype in output.node.keys():
                     if output.node[ntype].numel() > 0:
-                        fut = self.dist_feature.lookup_features(
+                        fut = self.feature_store.lookup_features(
                             is_node_feat=True,
                             index=output.node[ntype],
                             input_type=ntype,
                         )
-                        nfeat = await wrap_torch_future(fut)
+                        nfeat = await to_asyncio_future(fut)
                         nfeat = nfeat.to(torch.device("cpu"))
                         nfeats[ntype] = nfeat
                    else:
                        nfeats[ntype] = None
            # Collect edge features
            if output.edge is not None and self.with_edge_attr:
-                for etype in output.edge.keys():
-                    if output.edge[etype].numel() > 0:
-                        fut = self.dist_feature.lookup_features(
+                for edge_type in output.edge.keys():
+                    if output.edge[edge_type].numel() > 0:
+                        fut = self.feature_store.lookup_features(
                             is_node_feat=False,
-                            index=output.edge[etype],
-                            input_type=etype,
+                            index=output.edge[edge_type],
+                            input_type=edge_type,
                         )
-                        efeat = await wrap_torch_future(fut)
+                        efeat = await to_asyncio_future(fut)
                         efeat = efeat.to(torch.device("cpu"))
-                        efeats[etype] = efeat
+                        efeats[edge_type] = efeat
                    else:
-                        efeats[etype] = None
+                        efeats[edge_type] = None
 
         else:  # Homo
             # Collect node labels.
-            nlabels = (self.dist_feature.labels[output.node] if
-                       (self.dist_feature.labels is not None) else None)
+            nlabels = (self.feature_store.labels[output.node] if
+                       (self.feature_store.labels is not None) else None)
             # Collect node features.
             if output.node is not None:
-                fut = self.dist_feature.lookup_features(
+                fut = self.feature_store.lookup_features(
                     is_node_feat=True, index=output.node)
-                nfeats = await wrap_torch_future(fut)
+                nfeats = await to_asyncio_future(fut)
                 nfeats = nfeats.to(torch.device("cpu"))
             # else: efeats = None
             # Collect edge features.
             if output.edge is not None and self.with_edge_attr:
-                fut = self.dist_feature.lookup_features(
+                fut = self.feature_store.lookup_features(
                     is_node_feat=False, index=output.edge)
-                efeats = await wrap_torch_future(fut)
+                efeats = await to_asyncio_future(fut)
                 efeats = efeats.to(torch.device("cpu"))
             else:
                 efeats = None
 
@@ -664,7 +718,7 @@ async def _collate_fn(
         return output
 
     def __repr__(self) -> str:
-        return f"{self.__class__.__name__}()-PID{mp.current_process().pid}"
+        return f'{self.__class__.__name__}(pid={mp.current_process().pid})'
 
 
 # Sampling Utilities ##########################################################
diff --git a/torch_geometric/distributed/local_graph_store.py b/torch_geometric/distributed/local_graph_store.py
index 01e215c2c912..336a474c6f41 100644
--- a/torch_geometric/distributed/local_graph_store.py
+++ b/torch_geometric/distributed/local_graph_store.py
@@ -98,6 +98,7 @@ def from_data(
             (CSC format). (default: :obj:`False`)
         """
         graph_store = cls()
+        graph_store.meta = {'is_hetero': False}
 
         if not is_sorted:
             edge_index, edge_id = sort_edge_index(
diff --git a/torch_geometric/sampler/neighbor_sampler.py b/torch_geometric/sampler/neighbor_sampler.py
index 52f6f533bea9..4ed97c9c3423 100644
--- a/torch_geometric/sampler/neighbor_sampler.py
+++ b/torch_geometric/sampler/neighbor_sampler.py
@@ -119,7 +119,10 @@ def __init__(
             feature_store, graph_store = data
 
             # Obtain graph metadata:
-            node_attrs = feature_store.get_all_tensor_attrs()
+            node_attrs = [
+                attr for attr in feature_store.get_all_tensor_attrs()
+                if isinstance(attr.group_name, NodeType)
+            ]
             self.node_types = list(set(attr.group_name for attr in node_attrs))
 
             edge_attrs = graph_store.get_all_edge_attrs()
@@ -158,10 +161,10 @@ def __init__(
             self.edge_weight: Optional[Tensor] = None
             self.node_time: Optional[Tensor] = None
 
-            if time_attr is not None and len(time_attrs) != 1:
-                raise ValueError("Temporal sampling specified but did "
-                                 "not find any temporal data")
-
+            if time_attr is not None:
+                if len(time_attrs) != 1:
+                    raise ValueError("Temporal sampling specified but did "
+                                     "not find any temporal data")
                time_attrs[0].index = None  # Reset index for full data.
                time_tensor = feature_store.get_tensor(time_attrs[0])
                self.node_time = time_tensor
@@ -224,7 +227,7 @@ def is_hetero(self) -> bool:
             return True
 
         # self.data_type == DataType.remote
-        return self.node_types != [None]
+        return self.edge_types != [None]
 
     @property
     def is_temporal(self) -> bool:
@@ -433,56 +436,6 @@ def _sample(
             num_sampled_edges=num_sampled_edges,
         )
 
-    def _sample_one_hop(
-        self,
-        input_nodes: Tensor,
-        num_neighbors: int,
-        seed_time: Optional[Tensor] = None,
-        edge_type: Optional[EdgeType] = None,
-    ) -> SamplerOutput:
-        r"""Implements one-hop neighbor sampling for a set of input nodes for a
-        specific edge type.
-        """
-        rel_type = '__'.join(edge_type) if self.is_hetero else None
-
-        if not self.is_hetero:
-            colptr = self.colptr
-            row = self.row
-            node_time = self.node_time
-        else:
-            rel_type = '__'.join(edge_type)
-            colptr = self.colptr_dict[rel_type]
-            row = self.row_dict[rel_type]
-            node_time = self.node_time.get(edge_type[2], None)
-
-        out = torch.ops.pyg.dist_neighbor_sample(
-            colptr,
-            row,
-            input_nodes.to(colptr.dtype),
-            num_neighbors,
-            node_time,
-            seed_time,
-            None,  # TODO: edge_weight
-            True,  # csc
-            self.replace,
-            self.subgraph_type != SubgraphType.induced,
-            self.disjoint and node_time is not None,
-            self.temporal_strategy,
-        )
-        node, edge, cumsum_neighbors_per_node = out
-
-        if self.disjoint:
-            batch, node = node.t().contiguous()
-
-        return SamplerOutput(
-            node=node,
-            row=None,
-            col=None,
-            edge=edge,
-            batch=batch,
-            metadata=(cumsum_neighbors_per_node, ),
-        )
-
 
 # Sampling Utilities ##########################################################
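A note on the one-hop output format used throughout the diff: the one-hop primitive reports, for every source node, a running cumulative count of the neighbors sampled so far, and the per-node counts that the merging step works with are simply the differences of consecutive entries. A standalone sketch with made-up values (not part of the patch):

```python
import torch

# Cumulative neighbor counts after each source node, as carried in the
# one-hop sampler metadata (illustrative values for four source nodes):
cumsum_neighbors_per_node = torch.tensor([0, 2, 5, 5, 9])

# Per-node neighbor counts are the consecutive differences:
sampled_nbrs_per_node = torch.diff(cumsum_neighbors_per_node)
print(sampled_nbrs_per_node.tolist())  # [2, 3, 0, 4]
```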