**[1/3] Distributed Loaders PRs**

This PR includes the base class `DistLoader`, which handles the RPC connection and serves requests from `DistNeighborSampler` processes. It also includes the basic `DistNeighborSampler` functions used by the loader.

1. #8079
2. #8080
3. #8085

Other PRs related to this module:

- DistSampler: #7974
- GraphStore/FeatureStore: #8083

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: rusty1s <[email protected]>
1 parent be380c3 · commit 13b3243 · 5 changed files with 304 additions and 8 deletions
@@ -0,0 +1,125 @@
import atexit
import logging
import os
from typing import Any, Dict, List, Optional, Union

import torch.multiprocessing as mp

from torch_geometric.distributed.dist_context import DistContext, DistRole
from torch_geometric.distributed.dist_neighbor_sampler import close_sampler
from torch_geometric.distributed.rpc import global_barrier, init_rpc


class DistLoader:
    r"""A base class for creating distributed data loading routines.

    Args:
        current_ctx (DistContext): Distributed context info of the current
            process.
        rpc_worker_names (Dict[DistRole, List[str]]): RPC worker identifiers.
        master_addr (str, optional): RPC address for distributed loader
            communication, *i.e.* the IP address of the master node.
            (default: :obj:`None`)
        master_port (int or str, optional): The open port for RPC
            communication with the master node. (default: :obj:`None`)
        channel (mp.Queue, optional): A communication channel for messages.
            (default: :obj:`None`)
        num_rpc_threads (int, optional): The number of threads in the
            thread-pool used by
            :class:`~torch.distributed.rpc.TensorPipeAgent` to execute
            requests. (default: :obj:`16`)
        rpc_timeout (int, optional): The default timeout in seconds for RPC
            requests. If the RPC has not completed in this timeframe, an
            exception will be raised. Callers can override this timeout for
            individual RPCs in :meth:`~torch.distributed.rpc.rpc_sync` and
            :meth:`~torch.distributed.rpc.rpc_async` if necessary.
            (default: :obj:`180`)
    """
    def __init__(
        self,
        current_ctx: DistContext,
        rpc_worker_names: Dict[DistRole, List[str]],
        master_addr: Optional[str] = None,
        master_port: Optional[Union[int, str]] = None,
        channel: Optional[mp.Queue] = None,
        num_rpc_threads: int = 16,
        rpc_timeout: int = 180,
        **kwargs,
    ):
        if master_addr is None and os.environ.get('MASTER_ADDR') is not None:
            master_addr = os.environ['MASTER_ADDR']
        if master_addr is None:
            raise ValueError(f"Missing master address for RPC communication "
                             f"in '{self.__class__.__name__}'. Try to provide "
                             f"it or set it via the 'MASTER_ADDR' environment "
                             f"variable.")

        if master_port is None and os.environ.get('MASTER_PORT') is not None:
            master_port = int(os.environ['MASTER_PORT'])
        if master_port is None:
            raise ValueError(f"Missing master port for RPC communication in "
                             f"'{self.__class__.__name__}'. Try to provide it "
                             f"or set it via the 'MASTER_PORT' environment "
                             f"variable.")

        assert num_rpc_threads > 0
        assert rpc_timeout > 0

        self.current_ctx = current_ctx
        self.rpc_worker_names = rpc_worker_names
        self.master_addr = master_addr
        self.master_port = master_port
        self.channel = channel or mp.Queue()
        self.pid = mp.current_process().pid
        self.num_rpc_threads = num_rpc_threads
        self.rpc_timeout = rpc_timeout
        self.num_workers = kwargs.get('num_workers', 0)

        logging.info(f"[{self}] MASTER_ADDR={master_addr}, "
                     f"MASTER_PORT={master_port}")

        if self.num_workers == 0:  # Initialize RPC in main process:
            self.worker_init_fn(0)

    def channel_get(self, out: Any) -> Any:
        # If a message channel is set, block until the next message arrives
        # and return it in place of the given fallback value:
        if self.channel is not None:
            out = self.channel.get()
            logging.debug(f"[{self}] Retrieved message")
        return out

    def worker_init_fn(self, worker_id: int):
        try:
            num_sampler_proc = self.num_workers if self.num_workers > 0 else 1
            self.current_ctx_worker = DistContext(
                world_size=self.current_ctx.world_size * num_sampler_proc,
                rank=self.current_ctx.rank * num_sampler_proc + worker_id,
                global_world_size=self.current_ctx.world_size *
                num_sampler_proc,
                global_rank=self.current_ctx.rank * num_sampler_proc +
                worker_id,
                group_name='mp_sampling_worker',
            )

            init_rpc(
                current_ctx=self.current_ctx_worker,
                rpc_worker_names={},
                master_addr=self.master_addr,
                master_port=self.master_port,
                num_rpc_threads=self.num_rpc_threads,
                rpc_timeout=self.rpc_timeout,
            )
            assert hasattr(self, 'neighbor_sampler')
            self.neighbor_sampler.register_sampler_rpc()
            self.neighbor_sampler.init_event_loop()
            # Close RPC and worker group at exit:
            atexit.register(close_sampler, worker_id, self.neighbor_sampler)
            global_barrier(timeout=10)  # Wait for all workers to initialize.

        except RuntimeError as e:
            raise RuntimeError(f"`{self}.worker_init_fn()` could not "
                               f"initialize the worker loop of the neighbor "
                               f"sampler") from e

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(pid={self.pid})'
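For context, here is a minimal sketch (not part of this diff) of how a concrete loader is expected to plug into `DistLoader`: a subclass attaches a `DistNeighborSampler` as `self.neighbor_sampler` *before* calling `DistLoader.__init__()`, since `worker_init_fn()` asserts its presence, and forwards `worker_init_fn` to the underlying `torch.utils.data.DataLoader` workers. The subclass name and import paths are assumptions:

```python
from torch_geometric.distributed.dist_context import DistContext
from torch_geometric.distributed.dist_loader import DistLoader  # Assumed module path.


class MyDistLoader(DistLoader):
    # Hypothetical subclass; `sampler` is a `DistNeighborSampler` instance.
    def __init__(self, current_ctx: DistContext, sampler, **kwargs):
        # `worker_init_fn()` asserts that `self.neighbor_sampler` exists,
        # so it must be set before `DistLoader.__init__()` runs:
        self.neighbor_sampler = sampler
        super().__init__(current_ctx=current_ctx, rpc_worker_names={},
                         **kwargs)


# With `num_workers > 0`, RPC setup is deferred to each sampling worker by
# handing `worker_init_fn` to the `torch.utils.data.DataLoader`, e.g.:
#   loader = MyDistLoader(ctx, sampler, master_addr='127.0.0.1',
#                         master_port=11111, num_workers=2)
#   torch.utils.data.DataLoader(dataset, num_workers=2,
#                               worker_init_fn=loader.worker_init_fn)
```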
@@ -0,0 +1,112 @@
import logging
from typing import Any, Dict, List, Optional, Tuple, Union

import torch.multiprocessing as mp

from torch_geometric.distributed import LocalFeatureStore, LocalGraphStore
from torch_geometric.distributed.dist_context import DistContext, DistRole
from torch_geometric.distributed.event_loop import ConcurrentEventLoop
from torch_geometric.distributed.rpc import (
    RPCCallBase,
    RPCRouter,
    rpc_partition_to_workers,
    rpc_register,
    shutdown_rpc,
)
from torch_geometric.sampler import NeighborSampler
from torch_geometric.sampler.base import NumNeighbors, SubgraphType
from torch_geometric.typing import EdgeType

NumNeighborsType = Union[NumNeighbors, List[int], Dict[EdgeType, List[int]]]


class RPCSamplingCallee(RPCCallBase):
    r"""A wrapper for an RPC callee that performs one-hop sampling requests
    issued by remote processes."""
    def __init__(self, sampler: NeighborSampler):
        super().__init__()
        self.sampler = sampler

    def rpc_async(self, *args, **kwargs) -> Any:
        return self.sampler._sample_one_hop(*args, **kwargs)

    def rpc_sync(self, *args, **kwargs) -> Any:
        pass


class DistNeighborSampler:
    r"""An implementation of a distributed and asynchronous neighbor sampler
    used by :class:`~torch_geometric.distributed.DistNeighborLoader`."""
    def __init__(
        self,
        current_ctx: DistContext,
        rpc_worker_names: Dict[DistRole, List[str]],
        data: Tuple[LocalGraphStore, LocalFeatureStore],
        num_neighbors: NumNeighborsType,
        channel: Optional[mp.Queue] = None,
        replace: bool = False,
        subgraph_type: Union[SubgraphType, str] = 'directional',
        disjoint: bool = False,
        temporal_strategy: str = 'uniform',
        time_attr: Optional[str] = None,
        concurrency: int = 1,
        **kwargs,
    ):
        self.current_ctx = current_ctx
        self.rpc_worker_names = rpc_worker_names

        self.graph_store, self.feature_store = data
        assert isinstance(self.graph_store, LocalGraphStore)
        assert isinstance(self.feature_store, LocalFeatureStore)
        self.is_hetero = self.graph_store.meta['is_hetero']

        self.num_neighbors = num_neighbors
        self.channel = channel or mp.Queue()
        self.concurrency = concurrency
        self.event_loop = None
        self.replace = replace
        self.subgraph_type = SubgraphType(subgraph_type)
        self.disjoint = disjoint
        self.temporal_strategy = temporal_strategy
        self.time_attr = time_attr
        self.with_edge_attr = self.feature_store.has_edge_attr()
        self.edge_permutation = None  # TODO: Debug `edge_perm` for `LinkLoader`

    def register_sampler_rpc(self) -> None:
        partition2workers = rpc_partition_to_workers(
            current_ctx=self.current_ctx,
            num_partitions=self.graph_store.num_partitions,
            current_partition_idx=self.graph_store.partition_idx,
        )
        self.rpc_router = RPCRouter(partition2workers)
        self.feature_store.set_rpc_router(self.rpc_router)

        self._sampler = NeighborSampler(
            data=(self.feature_store, self.graph_store),
            num_neighbors=self.num_neighbors,
            subgraph_type=self.subgraph_type,
            replace=self.replace,
            disjoint=self.disjoint,
            temporal_strategy=self.temporal_strategy,
            time_attr=self.time_attr,
        )
        rpc_sample_callee = RPCSamplingCallee(self._sampler)
        self.rpc_sample_callee_id = rpc_register(rpc_sample_callee)

    def init_event_loop(self) -> None:
        self.event_loop = ConcurrentEventLoop(self.concurrency)
        self.event_loop.start_loop()


# Sampling Utilities ##########################################################


def close_sampler(worker_id: int, sampler: DistNeighborSampler):
    # Make sure that the mp.Queue is empty at exit and RAM is cleared:
    try:
        logging.info(f"Closing event loop for worker ID {worker_id}")
        sampler.event_loop.shutdown_loop()
    except AttributeError:
        pass
    logging.info(f"Closing RPC for worker ID {worker_id}")
    shutdown_rpc(graceful=True)
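To illustrate the lifecycle above, here is a rough sketch of how a loader process might construct and initialize the sampler. The partition-loading helpers and paths are assumptions (the local stores come from the related GraphStore/FeatureStore PR, #8083), and `register_sampler_rpc()` requires an already-initialized RPC agent, as set up by `DistLoader.worker_init_fn()`:

```python
from torch_geometric.distributed import LocalFeatureStore, LocalGraphStore
from torch_geometric.distributed.dist_context import DistContext

# Assumed partition layout produced by the partitioner (hypothetical path):
graph_store = LocalGraphStore.from_partition('./partitions/ogbn', pid=0)
feature_store = LocalFeatureStore.from_partition('./partitions/ogbn', pid=0)

sampler = DistNeighborSampler(
    current_ctx=DistContext(world_size=2, rank=0, global_world_size=2,
                            global_rank=0, group_name='sampling-workers'),
    rpc_worker_names={},
    data=(graph_store, feature_store),
    num_neighbors=[10, 5],  # Sample up to 10 and 5 neighbors per hop.
    concurrency=4,          # Allow up to four concurrent sampling tasks.
)
sampler.register_sampler_rpc()  # Requires an initialized RPC agent.
sampler.init_event_loop()       # Spins up the asynchronous event loop.
```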
@@ -0,0 +1,58 @@
from typing import Any, Dict, List, Optional

import torch
from torch import Tensor

from torch_geometric.data import HeteroData
from torch_geometric.distributed import LocalFeatureStore, LocalGraphStore


def filter_dist_store(
    feature_store: LocalFeatureStore,
    graph_store: LocalGraphStore,
    node_dict: Dict[str, Tensor],
    row_dict: Dict[str, Tensor],
    col_dict: Dict[str, Tensor],
    edge_dict: Dict[str, Optional[Tensor]],
    custom_cls: Optional[HeteroData] = None,
    meta: Optional[List[Any]] = None,
) -> HeteroData:
    r"""Constructs a :class:`HeteroData` object from a feature store that
    only holds nodes in `node` and edges in `edge` for each node and edge
    type, respectively. Sorted attribute values are provided as metadata
    from :class:`DistNeighborSampler`."""
    # Construct a new `HeteroData` object:
    data = custom_cls() if custom_cls is not None else HeteroData()
    nfeats, nlabels, efeats = meta[-3:]

    # Filter edge storage:
    required_edge_attrs = []
    for attr in graph_store.get_all_edge_attrs():
        key = attr.edge_type
        if key in row_dict and key in col_dict:
            required_edge_attrs.append(attr)
            edge_index = torch.stack([row_dict[key], col_dict[key]], dim=0)
            data[attr.edge_type].edge_index = edge_index

    # Filter node storage:
    required_node_attrs = []
    for attr in feature_store.get_all_tensor_attrs():
        if attr.group_name in node_dict:
            attr.index = node_dict[attr.group_name]
            required_node_attrs.append(attr)
            data[attr.group_name].num_nodes = attr.index.size(0)

    if nfeats is not None:
        for attr in required_node_attrs:
            if nfeats[attr.group_name] is not None:
                data[attr.group_name][attr.attr_name] = nfeats[attr.group_name]

    if efeats is not None:
        for attr in required_edge_attrs:
            if efeats[attr.edge_type] is not None:
                data[attr.edge_type].edge_attr = efeats[attr.edge_type]

    if nlabels is not None:
        for label in nlabels:
            data[label].y = nlabels[label]

    return data
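An illustrative call (all stores, types, and tensors hypothetical), showing the expected shape of the arguments: dictionaries keyed by node and edge type, plus a `meta` sequence whose last three entries carry node features, node labels, and edge features, as unpacked by `meta[-3:]` above:

```python
import torch

edge_type = ('paper', 'cites', 'paper')  # Hypothetical edge type.

data = filter_dist_store(
    feature_store=feature_store,  # A populated `LocalFeatureStore`.
    graph_store=graph_store,      # A populated `LocalGraphStore`.
    node_dict={'paper': torch.tensor([0, 1, 2])},    # Sampled node IDs.
    row_dict={edge_type: torch.tensor([0, 0, 1])},   # Local source rows.
    col_dict={edge_type: torch.tensor([1, 2, 2])},   # Local target cols.
    edge_dict={edge_type: torch.tensor([5, 7, 9])},  # Sampled edge IDs.
    meta=[
        {'paper': torch.randn(3, 16)},       # `nfeats`: node features.
        {'paper': torch.tensor([0, 1, 0])},  # `nlabels`: node labels.
        {edge_type: None},                   # `efeats`: no edge features.
    ],
)
print(data['paper'].num_nodes)  # 3
```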