forked from pyg-team/pytorch_geometric
Add `DistributedNeighborLoader` [3/6] (pyg-team#8080)
**[2/3] Distributed Loaders PRs**

This PR adds `DistributedNeighborLoader`, used for processing node sampler output in a distributed training setup.

1. pyg-team#8079
2. pyg-team#8080
3. pyg-team#8085

Other PRs related to this module:
DistSampler: pyg-team#7974
GraphStore/FeatureStore: pyg-team#8083

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Matthias Fey <[email protected]>
1 parent f4ed34c · commit 40cc3b1
Showing 5 changed files with 423 additions and 7 deletions.
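For orientation before the diff, this is roughly how the new loader is driven in the worker functions of the test file below. It is a sketch only: `part_data` is the `(LocalFeatureStore, LocalGraphStore)` tuple assembled from a partition directory, and `input_nodes`, `current_ctx` and `master_port` are set up per rank exactly as in the test.

loader = DistNeighborLoader(
    part_data,                   # (LocalFeatureStore, LocalGraphStore) for this rank
    num_neighbors=[10, 10],      # fan-out per hop
    batch_size=10,
    input_nodes=input_nodes,     # global IDs of the seed nodes owned by this partition
    master_addr='localhost',
    master_port=master_port,     # any free port, shared by all ranks
    current_ctx=current_ctx,     # DistContext describing this rank and the world size
    rpc_worker_names={},
    async_sampling=True,
)
for batch in loader:             # batches arrive as Data/HeteroData objects
    ...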
@@ -0,0 +1,264 @@
import socket

import pytest
import torch
import torch.multiprocessing as mp

from torch_geometric.data import Data, HeteroData
from torch_geometric.datasets import FakeDataset, FakeHeteroDataset
from torch_geometric.distributed import (
    DistContext,
    DistNeighborLoader,
    DistNeighborSampler,
    LocalFeatureStore,
    LocalGraphStore,
    Partitioner,
)
from torch_geometric.distributed.partition import load_partition_info
from torch_geometric.testing import onlyLinux, withPackage

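# Load one graph partition and assemble the (feature store, graph store)
# tuple consumed by `DistNeighborLoader`: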
def create_dist_data(tmp_path, rank):
    graph_store = LocalGraphStore.from_partition(tmp_path, pid=rank)
    feat_store = LocalFeatureStore.from_partition(tmp_path, pid=rank)
    (
        meta,
        num_partitions,
        partition_idx,
        node_pb,
        edge_pb,
    ) = load_partition_info(tmp_path, rank)
    if meta['is_hetero']:
        node_pb = torch.cat(list(node_pb.values()))
        edge_pb = torch.cat(list(edge_pb.values()))
    else:
        feat_store.labels = torch.arange(node_pb.size()[0])

    graph_store.partition_idx = partition_idx
    graph_store.num_partitions = num_partitions
    graph_store.node_pb = node_pb
    graph_store.edge_pb = edge_pb
    graph_store.meta = meta

    feat_store.partition_idx = partition_idx
    feat_store.num_partitions = num_partitions
    feat_store.node_feat_pb = node_pb
    feat_store.edge_feat_pb = edge_pb
    feat_store.meta = meta

    return feat_store, graph_store

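# Per-rank worker: builds a homogeneous `DistNeighborLoader` over its
# partition and validates every sampled mini-batch: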
def dist_neighbor_loader_homo(
    tmp_path: str,
    world_size: int,
    rank: int,
    master_addr: str,
    master_port: int,
    num_workers: int,
    async_sampling: bool,
    device=torch.device('cpu'),
):
    part_data = create_dist_data(tmp_path, rank)
    input_nodes = part_data[0].get_global_id(None)
    current_ctx = DistContext(
        rank=rank,
        global_rank=rank,
        world_size=world_size,
        global_world_size=world_size,
        group_name='dist-loader-test',
    )

    loader = DistNeighborLoader(
        part_data,
        num_neighbors=[1],
        batch_size=10,
        num_workers=num_workers,
        input_nodes=input_nodes,
        master_addr=master_addr,
        master_port=master_port,
        current_ctx=current_ctx,
        rpc_worker_names={},
        concurrency=10,
        device=device,
        drop_last=True,
        async_sampling=async_sampling,
    )

    edge_index = part_data[1]._edge_index[(None, 'coo')]

    assert str(loader).startswith('DistNeighborLoader')
    assert str(mp.current_process().pid) in str(loader)
    assert isinstance(loader.neighbor_sampler, DistNeighborSampler)
    assert not part_data[0].meta['is_hetero']

    for batch in loader:
        assert isinstance(batch, Data)
        assert batch.n_id.size() == (batch.num_nodes, )
        assert batch.input_id.numel() == batch.batch_size == 10
        assert batch.edge_index.device == device
        assert batch.edge_index.min() >= 0
        assert batch.edge_index.max() < batch.num_nodes
        assert torch.equal(
            batch.n_id[batch.edge_index],
            edge_index[:, batch.e_id],
        )

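# Heterogeneous counterpart of the worker above: additionally checks node and
# edge types and the global edge-index mapping per edge type: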
def dist_neighbor_loader_hetero(
    tmp_path: str,
    world_size: int,
    rank: int,
    master_addr: str,
    master_port: int,
    num_workers: int,
    async_sampling: bool,
    device=torch.device('cpu'),
):
    part_data = create_dist_data(tmp_path, rank)
    input_nodes = ('v0', part_data[0].get_global_id('v0'))
    current_ctx = DistContext(
        rank=rank,
        global_rank=rank,
        world_size=world_size,
        global_world_size=world_size,
        group_name='dist-loader-test',
    )

    loader = DistNeighborLoader(
        part_data,
        num_neighbors=[10, 10],
        batch_size=10,
        num_workers=num_workers,
        input_nodes=input_nodes,
        master_addr=master_addr,
        master_port=master_port,
        current_ctx=current_ctx,
        rpc_worker_names={},
        concurrency=10,
        device=device,
        drop_last=True,
        async_sampling=async_sampling,
    )

    assert str(loader).startswith('DistNeighborLoader')
    assert str(mp.current_process().pid) in str(loader)
    assert isinstance(loader.neighbor_sampler, DistNeighborSampler)
    assert part_data[0].meta['is_hetero']

    for batch in loader:
        assert isinstance(batch, HeteroData)
        assert batch['v0'].input_id.numel() == batch['v0'].batch_size == 10

        assert len(batch.node_types) == 2
        for node_type in batch.node_types:
            assert torch.equal(batch[node_type].x, batch.x_dict[node_type])
            assert batch.x_dict[node_type].device == device
            assert batch.x_dict[node_type].size(0) >= 0
            assert batch[node_type].n_id.size(0) == batch[node_type].num_nodes

        assert len(batch.edge_types) == 4
        for edge_type in batch.edge_types:
            assert batch[edge_type].edge_index.device == device
            assert batch[edge_type].edge_attr.device == device
            assert (batch[edge_type].edge_attr.size(0) ==
                    batch[edge_type].edge_index.size(1))

            if batch[edge_type].edge_index.numel() > 0:  # Test edge mapping:
                src, _, dst = edge_type
                edge_index = part_data[1]._edge_index[(edge_type, "coo")]
                global_edge_index_1 = torch.stack([
                    batch[src].n_id[batch[edge_type].edge_index[0]],
                    batch[dst].n_id[batch[edge_type].edge_index[1]],
                ], dim=0)
                global_edge_index_2 = edge_index[:, batch[edge_type].e_id]
                assert torch.equal(global_edge_index_1, global_edge_index_2)

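# Partition a fake homogeneous graph into two parts, then spawn one worker
# process per partition and run them against each other: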
@onlyLinux
@withPackage('pyg_lib')
@pytest.mark.parametrize('num_parts', [2])
@pytest.mark.parametrize('num_workers', [0])
@pytest.mark.parametrize('async_sampling', [True])
def test_dist_neighbor_loader_homo(
    tmp_path,
    num_parts,
    num_workers,
    async_sampling,
):
    mp_context = torch.multiprocessing.get_context('spawn')
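    # Bind to port 0 so the OS picks a free port to pass as `master_port`: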
    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    s.bind(('127.0.0.1', 0))
    port = s.getsockname()[1]
    s.close()
    addr = 'localhost'

    data = FakeDataset(
        num_graphs=1,
        avg_num_nodes=100,
        avg_degree=3,
        edge_dim=2,
    )[0]
    partitioner = Partitioner(data, num_parts, tmp_path)
    partitioner.generate_partition()

    w0 = mp_context.Process(
        target=dist_neighbor_loader_homo,
        args=(tmp_path, num_parts, 0, addr, port, num_workers, async_sampling),
    )

    w1 = mp_context.Process(
        target=dist_neighbor_loader_homo,
        args=(tmp_path, num_parts, 1, addr, port, num_workers, async_sampling),
    )

    w0.start()
    w1.start()
    w0.join()
    w1.join()

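# Same two-partition setup as above, now over a fake heterogeneous graph
# with two node types and four edge types: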
@onlyLinux
@withPackage('pyg_lib')
@pytest.mark.parametrize('num_parts', [2])
@pytest.mark.parametrize('num_workers', [0])
@pytest.mark.parametrize('async_sampling', [True])
@pytest.mark.skip(reason="Breaks with no attribute 'num_hops'")
def test_dist_neighbor_loader_hetero(
    tmp_path,
    num_parts,
    num_workers,
    async_sampling,
):
    mp_context = torch.multiprocessing.get_context('spawn')
    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    s.bind(('127.0.0.1', 0))
    port = s.getsockname()[1]
    s.close()
    addr = 'localhost'

    data = FakeHeteroDataset(
        num_graphs=1,
        avg_num_nodes=100,
        avg_degree=3,
        num_node_types=2,
        num_edge_types=4,
        edge_dim=2,
    )[0]
    partitioner = Partitioner(data, num_parts, tmp_path)
    partitioner.generate_partition()

    w0 = mp_context.Process(
        target=dist_neighbor_loader_hetero,
        args=(tmp_path, num_parts, 0, addr, port, num_workers, async_sampling),
    )

    w1 = mp_context.Process(
        target=dist_neighbor_loader_hetero,
        args=(tmp_path, num_parts, 1, addr, port, num_workers, async_sampling),
    )

    w0.start()
    w1.start()
    w0.join()
    w1.join()
@@ -1,9 +1,17 @@
from .dist_context import DistContext
from .local_feature_store import LocalFeatureStore
from .local_graph_store import LocalGraphStore
from .partition import Partitioner
from .dist_neighbor_sampler import DistNeighborSampler
from .dist_loader import DistLoader
from .dist_neighbor_loader import DistNeighborLoader

__all__ = classes = [
    'DistContext',
    'LocalFeatureStore',
    'LocalGraphStore',
    'Partitioner',
    'DistNeighborSampler',
    'DistLoader',
    'DistNeighborLoader',
]