Add base class DistLoader (#8079)
**[1/3] Distributed Loaders PRs** 
This PR adds the `DistLoader` base class, which manages the RPC connection and
the handling of requests from `DistNeighborSampler` processes.

It also includes the basic `DistNeighborSampler` functions used by the loader; a sketch of how the two classes fit together follows the PR lists below.

1.  #8079
2.  #8080
3.  #8085

Other PRs related to this module:
DistSampler: #7974
GraphStore/FeatureStore: #8083
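
As a quick orientation, here is a hedged sketch (not part of this PR's code) of how the two classes added here are meant to compose: a concrete loader attaches a `DistNeighborSampler` as its `neighbor_sampler` attribute before `DistLoader.__init__` runs, so that `worker_init_fn()` can later register the sampler over RPC in every worker. `ToyDistLoader`, `data`, `ctx`, and `num_neighbors` are illustrative placeholders.

```python
import torch.multiprocessing as mp

from torch_geometric.distributed.dist_loader import DistLoader
from torch_geometric.distributed.dist_neighbor_sampler import DistNeighborSampler


class ToyDistLoader(DistLoader):  # hypothetical subclass, for illustration only
    def __init__(self, data, ctx, num_neighbors, **kwargs):
        channel = mp.Queue()
        # `DistLoader.worker_init_fn()` expects this attribute to exist:
        self.neighbor_sampler = DistNeighborSampler(
            current_ctx=ctx,
            rpc_worker_names={},
            data=data,  # (LocalFeatureStore, LocalGraphStore) partition pair
            num_neighbors=num_neighbors,
            channel=channel,
        )
        super().__init__(current_ctx=ctx, rpc_worker_names={},
                         channel=channel, **kwargs)
```

With `num_workers=0`, `DistLoader` initializes RPC directly in the main process; with `num_workers > 0`, it defers to `worker_init_fn` being invoked once per sampling worker (see the sketch after `dist_loader.py` below).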

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: rusty1s <[email protected]>
3 people authored Oct 2, 2023
1 parent be380c3 commit 13b3243
Showing 5 changed files with 304 additions and 8 deletions.
9 changes: 5 additions & 4 deletions CHANGELOG.md
@@ -7,12 +7,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

+- Added the `DistLoader` base class ([#8079](https://github.com/pyg-team/pytorch_geometric/pull/8079))
- Added `HyperGraphData` to support hypergraphs ([#7611](https://github.com/pyg-team/pytorch_geometric/pull/7611))
-- Added the `PCQM4Mv2` dataset as a reference implementation for `OnDiskDataset` ([#8102](https://github.com/pyg-team/pytorch_geometric/pull/8102)
-- Added `module_headers` property to `nn.Sequential` models ([#8093](https://github.com/pyg-team/pytorch_geometric/pull/8093)
+- Added the `PCQM4Mv2` dataset as a reference implementation for `OnDiskDataset` ([#8102](https://github.com/pyg-team/pytorch_geometric/pull/8102))
+- Added `module_headers` property to `nn.Sequential` models ([#8093](https://github.com/pyg-team/pytorch_geometric/pull/8093))
- Added `OnDiskDataset` interface with data loader support ([#8066](https://github.com/pyg-team/pytorch_geometric/pull/8066), [#8088](https://github.com/pyg-team/pytorch_geometric/pull/8088), [#8092](https://github.com/pyg-team/pytorch_geometric/pull/8092), [#8106](https://github.com/pyg-team/pytorch_geometric/pull/8106))
-- Added a tutorial for `Node2Vec` and `MetaPath2Vec` usage ([#7938](https://github.com/pyg-team/pytorch_geometric/pull/7938)
-- Added a tutorial for multi-GPU training with pure PyTorch ([#7894](https://github.com/pyg-team/pytorch_geometric/pull/7894)
+- Added a tutorial for `Node2Vec` and `MetaPath2Vec` usage ([#7938](https://github.com/pyg-team/pytorch_geometric/pull/7938))
+- Added a tutorial for multi-GPU training with pure PyTorch ([#7894](https://github.com/pyg-team/pytorch_geometric/pull/7894))
- Added `edge_attr` support to `ResGatedGraphConv` ([#8048](https://github.com/pyg-team/pytorch_geometric/pull/8048))
- Added a `Database` interface and `SQLiteDatabase`/`RocksDatabase` implementations ([#8028](https://github.com/pyg-team/pytorch_geometric/pull/8028), [#8044](https://github.com/pyg-team/pytorch_geometric/pull/8044), [#8046](https://github.com/pyg-team/pytorch_geometric/pull/8046), [#8051](https://github.com/pyg-team/pytorch_geometric/pull/8051), [#8052](https://github.com/pyg-team/pytorch_geometric/pull/8052), [#8054](https://github.com/pyg-team/pytorch_geometric/pull/8054), [#8057](https://github.com/pyg-team/pytorch_geometric/pull/8057), [#8058](https://github.com/pyg-team/pytorch_geometric/pull/8058))
- Added support for weighted/biased sampling in `NeighborLoader`/`LinkNeighborLoader` ([#8038](https://github.com/pyg-team/pytorch_geometric/pull/8038))
125 changes: 125 additions & 0 deletions torch_geometric/distributed/dist_loader.py
@@ -0,0 +1,125 @@
import atexit
import logging
import os
from typing import Any, Dict, List, Optional, Union

import torch.multiprocessing as mp

from torch_geometric.distributed.dist_context import DistContext, DistRole
from torch_geometric.distributed.dist_neighbor_sampler import close_sampler
from torch_geometric.distributed.rpc import global_barrier, init_rpc


class DistLoader:
r"""A base class for creating distributed data loading routines.
Args:
current_ctx (DistContext): Distributed context info of the current
process.
rpc_worker_names (Dict[DistRole, List[str]]): RPC workers identifiers.
master_addr (str, optional): RPC address for distributed loader
communication.
Refers to the IP address of the master node. (default: :obj:`None`)
master_port (int or str, optional): The open port for RPC communication
with the master node. (default: :obj:`None`)
channel (mp.Queue, optional): A communication channel for messages.
(default: :obj:`None`)
num_rpc_threads (int, optional): The number of threads in the
thread-pool used by
:class:`~torch.distributed.rpc.TensorPipeAgent` to execute
requests. (default: :obj:`16`)
rpc_timeout (int, optional): The default timeout in seconds for RPC
requests.
If the RPC has not completed in this timeframe, an exception will
be raised.
Callers can override this timeout for
individual RPCs in :meth:`~torch.distributed.rpc.rpc_sync` and
:meth:`~torch.distributed.rpc.rpc_async` if necessary.
(default: :obj:`180`)
"""
def __init__(
self,
current_ctx: DistContext,
rpc_worker_names: Dict[DistRole, List[str]],
master_addr: Optional[str] = None,
master_port: Optional[Union[int, str]] = None,
channel: Optional[mp.Queue] = None,
num_rpc_threads: int = 16,
rpc_timeout: int = 180,
**kwargs,
):
if master_addr is None and os.environ.get('MASTER_ADDR') is not None:
master_addr = os.environ['MASTER_ADDR']
if master_addr is None:
raise ValueError(f"Missing master address for RPC communication "
f"in '{self.__class__.__name__}'. Try to provide "
f"it or set it via the 'MASTER_ADDR' environment "
f"variable.")

if master_port is None and os.environ.get('MASTER_PORT') is not None:
master_port = int(os.environ['MASTER_PORT'])
if master_port is None:
raise ValueError(f"Missing master port for RPC communication in "
f"'{self.__class__.__name__}'. Try to provide it "
f"or set it via the 'MASTER_ADDR' environment "
f"variable.")

assert num_rpc_threads > 0
assert rpc_timeout > 0

self.current_ctx = current_ctx
self.rpc_worker_names = rpc_worker_names
self.master_addr = master_addr
self.master_port = master_port
self.channel = channel or mp.Queue()
self.pid = mp.current_process().pid
self.num_rpc_threads = num_rpc_threads
self.rpc_timeout = rpc_timeout
self.num_workers = kwargs.get('num_workers', 0)

logging.info(f"[{self}] MASTER_ADDR={master_addr}, "
f"MASTER_PORT={master_port}")

if self.num_workers == 0: # Initialize RPC in main process:
self.worker_init_fn(0)

def channel_get(self, out: Any) -> Any:
if self.channel is not None:
out = self.channel.get()
logging.debug(f"[{self}] Retrieved message")
return out

def worker_init_fn(self, worker_id: int):
try:
num_sampler_proc = self.num_workers if self.num_workers > 0 else 1
self.current_ctx_worker = DistContext(
world_size=self.current_ctx.world_size * num_sampler_proc,
rank=self.current_ctx.rank * num_sampler_proc + worker_id,
global_world_size=self.current_ctx.world_size *
num_sampler_proc,
global_rank=self.current_ctx.rank * num_sampler_proc +
worker_id,
group_name='mp_sampling_worker',
)

init_rpc(
current_ctx=self.current_ctx_worker,
rpc_worker_names={},
master_addr=self.master_addr,
master_port=self.master_port,
num_rpc_threads=self.num_rpc_threads,
rpc_timeout=self.rpc_timeout,
)
assert hasattr(self, 'neighbor_sampler')
self.neighbor_sampler.register_sampler_rpc()
self.neighbor_sampler.init_event_loop()
# close RPC & worker group at exit:
atexit.register(close_sampler, worker_id, self.neighbor_sampler)
global_barrier(timeout=10) # Wait for all workers to initialize.

        except RuntimeError as e:
            raise RuntimeError(f"`{self}.worker_init_fn()` could not "
                               f"initialize the worker loop of the "
                               f"neighbor sampler") from e

def __repr__(self) -> str:
return f'{self.__class__.__name__}(pid={self.pid})'
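
A hedged usage sketch, not part of this diff: with `num_workers > 0`, `DistLoader` skips RPC initialization in the main process and expects `worker_init_fn` to run once per worker, which is exactly what happens when it is forwarded to a standard `torch.utils.data.DataLoader`. The function and argument names below are illustrative; `dist_loader` is assumed to be an instance of a `DistLoader` subclass that already carries a `neighbor_sampler` attribute.

```python
import torch


def make_pytorch_loader(dist_loader, dataset):
    # `dist_loader`: a DistLoader subclass instance holding a `neighbor_sampler`
    # attribute (e.g. the ToyDistLoader sketch in the commit message above).
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=128,
        num_workers=4,                              # > 0: RPC init per worker
        worker_init_fn=dist_loader.worker_init_fn,  # one RPC context per worker
    )
```

Each worker then builds its own `DistContext`, calls `init_rpc`, registers the attached sampler, and waits at `global_barrier`, mirroring the body of `worker_init_fn` above.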
112 changes: 112 additions & 0 deletions torch_geometric/distributed/dist_neighbor_sampler.py
@@ -0,0 +1,112 @@
import logging
from typing import Any, Dict, List, Optional, Tuple, Union

import torch.multiprocessing as mp

from torch_geometric.distributed import LocalFeatureStore, LocalGraphStore
from torch_geometric.distributed.dist_context import DistContext, DistRole
from torch_geometric.distributed.event_loop import ConcurrentEventLoop
from torch_geometric.distributed.rpc import (
RPCCallBase,
RPCRouter,
rpc_partition_to_workers,
rpc_register,
shutdown_rpc,
)
from torch_geometric.sampler import NeighborSampler
from torch_geometric.sampler.base import NumNeighbors, SubgraphType
from torch_geometric.typing import EdgeType

NumNeighborsType = Union[NumNeighbors, List[int], Dict[EdgeType, List[int]]]


class RPCSamplingCallee(RPCCallBase):
r"""A wrapper for RPC callee that will perform RPC sampling from remote
processes."""
def __init__(self, sampler: NeighborSampler):
super().__init__()
self.sampler = sampler

def rpc_async(self, *args, **kwargs) -> Any:
return self.sampler._sample_one_hop(*args, **kwargs)

def rpc_sync(self, *args, **kwargs) -> Any:
pass


class DistNeighborSampler:
r"""An implementation of a distributed and asynchronised neighbor sampler
used by :class:`~torch_geometric.distributed.DistNeighborLoader`."""
def __init__(
self,
current_ctx: DistContext,
rpc_worker_names: Dict[DistRole, List[str]],
        data: Tuple[LocalFeatureStore, LocalGraphStore],
num_neighbors: NumNeighborsType,
channel: Optional[mp.Queue] = None,
replace: bool = False,
subgraph_type: Union[SubgraphType, str] = 'directional',
disjoint: bool = False,
temporal_strategy: str = 'uniform',
time_attr: Optional[str] = None,
concurrency: int = 1,
**kwargs,
):
self.current_ctx = current_ctx
self.rpc_worker_names = rpc_worker_names

        self.feature_store, self.graph_store = data
        assert isinstance(self.feature_store, LocalFeatureStore)
        assert isinstance(self.graph_store, LocalGraphStore)
        self.is_hetero = self.graph_store.meta['is_hetero']

self.num_neighbors = num_neighbors
self.channel = channel or mp.Queue()
self.concurrency = concurrency
self.event_loop = None
self.replace = replace
self.subgraph_type = SubgraphType(subgraph_type)
self.disjoint = disjoint
self.temporal_strategy = temporal_strategy
self.time_attr = time_attr
        self.with_edge_attr = self.feature_store.has_edge_attr()
self.edge_permutation = None # TODO: Debug edge_perm for LinkLoader

def register_sampler_rpc(self) -> None:
partition2workers = rpc_partition_to_workers(
current_ctx=self.current_ctx,
            num_partitions=self.graph_store.num_partitions,
            current_partition_idx=self.graph_store.partition_idx,
)
self.rpc_router = RPCRouter(partition2workers)
        self.feature_store.set_rpc_router(self.rpc_router)

self._sampler = NeighborSampler(
            data=(self.feature_store, self.graph_store),
num_neighbors=self.num_neighbors,
subgraph_type=self.subgraph_type,
replace=self.replace,
disjoint=self.disjoint,
temporal_strategy=self.temporal_strategy,
time_attr=self.time_attr,
)
rpc_sample_callee = RPCSamplingCallee(self._sampler)
self.rpc_sample_callee_id = rpc_register(rpc_sample_callee)

def init_event_loop(self) -> None:
self.event_loop = ConcurrentEventLoop(self.concurrency)
self.event_loop.start_loop()


# Sampling Utilities ##########################################################


def close_sampler(worker_id: int, sampler: DistNeighborSampler):
# Make sure that mp.Queue is empty at exit and RAM is cleared:
try:
logging.info(f"Closing event loop for worker ID {worker_id}")
sampler.event_loop.shutdown_loop()
except AttributeError:
pass
logging.info(f"Closing RPC for worker ID {worker_id}")
shutdown_rpc(graceful=True)
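
A hedged sketch, not part of this diff, of the callee pattern that `RPCSamplingCallee` follows: any object implementing `rpc_async`/`rpc_sync` can be handed to `rpc_register`, and remote workers then address it by the returned id. `EchoCallee` and its behavior are purely illustrative, and registration assumes `init_rpc` has already run in this process.

```python
from torch_geometric.distributed.rpc import RPCCallBase, rpc_register


class EchoCallee(RPCCallBase):
    r"""A purely illustrative callee that simply echoes its arguments."""
    def rpc_async(self, *args, **kwargs):
        return args  # a real callee would launch local sampling work here

    def rpc_sync(self, *args, **kwargs):
        return args


# Assumes `init_rpc(...)` has already been called in this process:
callee_id = rpc_register(EchoCallee())
```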
8 changes: 4 additions & 4 deletions torch_geometric/distributed/rpc.py
@@ -2,7 +2,7 @@
import logging
import threading
from abc import ABC, abstractmethod
-from typing import Callable, Dict, List
+from typing import Callable, Dict, List, Optional

from torch.distributed import rpc

@@ -23,15 +23,15 @@ def rpc_require_initialized(func: Callable) -> Callable:


@rpc_require_initialized
-def global_all_gather(obj, timeout=None):
+def global_all_gather(obj, timeout: Optional[int] = None):
r"""Gathers objects from all groups in a list."""
if timeout is None:
return rpc.api._all_gather(obj)
return rpc.api._all_gather(obj, timeout=timeout)


@rpc_require_initialized
-def global_barrier(timeout=None):
+def global_barrier(timeout: Optional[int] = None):
r""" Block until all local and remote RPC processes."""
try:
global_all_gather(obj=None, timeout=timeout)
@@ -45,7 +45,7 @@ def init_rpc(
master_addr: str,
master_port: int,
num_rpc_threads: int = 16,
-    rpc_timeout: float = 240,
+    rpc_timeout: int = 240,
):
with _rpc_init_lock:
if rpc_is_initialized():
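
A hedged sketch, not part of this diff, of how the typed helpers above are typically driven, mirroring the calls made in `DistLoader.worker_init_fn`; the context, address, and port values are illustrative placeholders.

```python
from torch_geometric.distributed.dist_context import DistContext
from torch_geometric.distributed.rpc import global_barrier, init_rpc

# Illustrative context for a two-process group; all values are placeholders:
ctx = DistContext(
    world_size=2,
    rank=0,
    global_world_size=2,
    global_rank=0,
    group_name='example_sampling_workers',
)

init_rpc(
    current_ctx=ctx,
    rpc_worker_names={},
    master_addr='127.0.0.1',  # placeholder address
    master_port=11111,        # placeholder port
    num_rpc_threads=16,
    rpc_timeout=240,          # now an `int`, matching the change above
)
global_barrier(timeout=10)  # block until every peer finished `init_rpc`
```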
58 changes: 58 additions & 0 deletions torch_geometric/distributed/utils.py
@@ -0,0 +1,58 @@
from typing import Dict, Optional

import torch
from torch import Tensor

from torch_geometric.data import HeteroData
from torch_geometric.distributed import LocalFeatureStore, LocalGraphStore


def filter_dist_store(
feature_store: LocalFeatureStore,
graph_store: LocalGraphStore,
node_dict: Dict[str, Tensor],
row_dict: Dict[str, Tensor],
col_dict: Dict[str, Tensor],
edge_dict: Dict[str, Optional[Tensor]],
custom_cls: Optional[HeteroData] = None,
meta: Optional[Dict[str, Tensor]] = None,
) -> HeteroData:
r"""Constructs a :class:`HeteroData` object from a feature store that only
    holds nodes in `node` and edges in `edge` for each node and edge type,
respectively. Sorted attribute values are provided as metadata from
:class:`DistNeighborSampler`."""
# Construct a new `HeteroData` object:
data = custom_cls() if custom_cls is not None else HeteroData()
nfeats, nlabels, efeats = meta[-3:]

# Filter edge storage:
required_edge_attrs = []
for attr in graph_store.get_all_edge_attrs():
key = attr.edge_type
if key in row_dict and key in col_dict:
required_edge_attrs.append(attr)
edge_index = torch.stack([row_dict[key], col_dict[key]], dim=0)
data[attr.edge_type].edge_index = edge_index

# Filter node storage:
required_node_attrs = []
for attr in feature_store.get_all_tensor_attrs():
if attr.group_name in node_dict:
attr.index = node_dict[attr.group_name]
required_node_attrs.append(attr)
data[attr.group_name].num_nodes = attr.index.size(0)

if nfeats is not None:
for attr in required_node_attrs:
if nfeats[attr.group_name] is not None:
data[attr.group_name][attr.attr_name] = nfeats[attr.group_name]

if efeats is not None:
for attr in required_edge_attrs:
if efeats[attr.edge_type] is not None:
data[attr.edge_type].edge_attr = efeats[attr.edge_type]

for label in nlabels:
data[label].y = nlabels[label]

return data
