Update DistNeighborSampler for homo graphs [2/6] (pyg-team#8209)
**Changes made:**
- Added support for temporal sampling.
- Use `torch.Tensor`s instead of `numpy` arrays.
- Moved `_sample_one_hop()` from `NeighborSampler` to `DistNeighborSampler` (see the sketch after this list).
- Removed the disjoint flow from the `_sample()` function; it is not needed because the batch is computed afterwards.
- Added tests for node sampling and disjoint sampling (they work without `DistNeighborLoader`).
- Added tests for temporal node sampling (they work without `DistNeighborLoader`).
- Minor changes such as renaming variables.
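
For context, a minimal sketch of the per-worker flow that the new tests exercise (condensed from `test_dist_neighbor_sampler.py` below; the `run_worker` name, the `(feature_store, graph_store)` arguments, and the seed nodes are illustrative only, not part of this PR):

```python
import torch

from torch_geometric.distributed import LocalFeatureStore, LocalGraphStore
from torch_geometric.distributed.dist_context import DistContext
from torch_geometric.distributed.dist_neighbor_sampler import DistNeighborSampler
from torch_geometric.distributed.rpc import init_rpc
from torch_geometric.sampler import NodeSamplerInput


def run_worker(rank: int, world_size: int, port: int,
               feature_store: LocalFeatureStore,
               graph_store: LocalGraphStore):
    ctx = DistContext(rank=rank, global_rank=rank, world_size=world_size,
                      global_world_size=world_size, group_name='sketch')
    torch.distributed.init_process_group(
        backend='gloo', rank=rank, world_size=world_size,
        init_method=f'tcp://localhost:{port}')

    # The sampler now operates on torch.Tensors and owns `_sample_one_hop()`:
    sampler = DistNeighborSampler(
        data=(feature_store, graph_store), current_ctx=ctx,
        rpc_worker_names={}, num_neighbors=[-1, -1])
    init_rpc(current_ctx=ctx, rpc_worker_names={},
             master_addr='localhost', master_port=port)
    sampler.register_sampler_rpc()
    sampler.init_event_loop()

    # Asynchronous node sampling via the sampler's event loop:
    inputs = NodeSamplerInput(input_id=None, node=torch.tensor([1, 6]))
    return sampler.event_loop.run_task(coro=sampler.node_sample(inputs))
```

The result can then be compared field by field (`node`, `row`, `col`, `batch`, ...) against the single-machine `node_sample()` output, which is exactly what the tests below do.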

This PR is based on pyg-team#8083, so both must be combined to pass the tests.

Other distributed PRs:
pyg-team#8083 
pyg-team#8080 
pyg-team#8085

---------

Co-authored-by: Matthias Fey <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
3 people authored Nov 10, 2023
1 parent 9ea2233 commit f4ed34c
Showing 6 changed files with 442 additions and 150 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -23,6 +23,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Update `DistNeighborSampler` for homogeneous graphs ([#8209](https://github.com/pyg-team/pytorch_geometric/pull/8209))
- Update `GraphStore` and `FeatureStore` to support distributed training ([#8083](https://github.com/pyg-team/pytorch_geometric/pull/8083))
- Disallow the usage of `add_self_loops=True` in `GCNConv(normalize=False)` ([#8210](https://github.com/pyg-team/pytorch_geometric/pull/8210))
- Disable device asserts during `torch_geometric.compile` ([#8220](https://github.com/pyg-team/pytorch_geometric/pull/8220))
283 changes: 283 additions & 0 deletions test/distributed/test_dist_neighbor_sampler.py
@@ -0,0 +1,283 @@
import atexit
import socket

import pytest
import torch

from torch_geometric.data import Data
from torch_geometric.distributed import LocalFeatureStore, LocalGraphStore
from torch_geometric.distributed.dist_context import DistContext
from torch_geometric.distributed.dist_neighbor_sampler import (
    DistNeighborSampler,
    close_sampler,
)
from torch_geometric.distributed.rpc import init_rpc
from torch_geometric.sampler import NeighborSampler, NodeSamplerInput
from torch_geometric.sampler.neighbor_sampler import node_sample
from torch_geometric.testing import onlyLinux, withPackage


def create_data(rank: int, world_size: int, temporal: bool = False):
    if rank == 0:  # Partition 0:
        node_id = torch.tensor([0, 1, 2, 3, 4, 5, 9])
        edge_index = torch.tensor([  # Sorted by destination.
            [1, 2, 3, 4, 5, 0, 0],
            [0, 1, 2, 3, 4, 4, 9],
        ])
    else:  # Partition 1:
        node_id = torch.tensor([0, 4, 5, 6, 7, 8, 9])
        edge_index = torch.tensor([  # Sorted by destination.
            [5, 6, 7, 8, 9, 5, 0],
            [4, 5, 6, 7, 8, 9, 9],
        ])

    feature_store = LocalFeatureStore.from_data(node_id)
    graph_store = LocalGraphStore.from_data(
        edge_id=None,
        edge_index=edge_index,
        num_nodes=10,
        is_sorted=True,
    )

    graph_store.node_pb = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
    graph_store.meta.update({'num_parts': 2})
    graph_store.partition_idx = rank
    graph_store.num_partitions = world_size

    edge_index = torch.tensor([  # Create reference data:
        [1, 2, 3, 4, 5, 0, 5, 6, 7, 8, 9, 0],
        [0, 1, 2, 3, 4, 4, 9, 5, 6, 7, 8, 9],
    ])
    data = Data(x=None, y=None, edge_index=edge_index, num_nodes=10)

    if temporal:  # Create time data:
        data.time = torch.tensor([5, 0, 1, 3, 3, 4, 4, 4, 4, 4])
        feature_store.put_tensor(data.time, group_name=None, attr_name='time')

    return (feature_store, graph_store), data


def dist_neighbor_sampler(
    world_size: int,
    rank: int,
    master_port: int,
    disjoint: bool = False,
):
    dist_data, data = create_data(rank, world_size)

    current_ctx = DistContext(
        rank=rank,
        global_rank=rank,
        world_size=world_size,
        global_world_size=world_size,
        group_name='dist-sampler-test',
    )

    # Initialize training process group of PyTorch:
    torch.distributed.init_process_group(
        backend='gloo',
        rank=current_ctx.rank,
        world_size=current_ctx.world_size,
        init_method=f'tcp://localhost:{master_port}',
    )

    dist_sampler = DistNeighborSampler(
        data=dist_data,
        current_ctx=current_ctx,
        rpc_worker_names={},
        num_neighbors=[-1, -1],
        shuffle=False,
        disjoint=disjoint,
    )

    init_rpc(
        current_ctx=current_ctx,
        rpc_worker_names={},
        master_addr='localhost',
        master_port=master_port,
    )

    dist_sampler.register_sampler_rpc()
    dist_sampler.init_event_loop()

    # Close RPC & worker group at exit:
    atexit.register(close_sampler, 0, dist_sampler)
    torch.distributed.barrier()

    if rank == 0:  # Seed nodes:
        input_node = torch.tensor([1, 6])
    else:
        input_node = torch.tensor([4, 9])

    inputs = NodeSamplerInput(input_id=None, node=input_node)

    # Evaluate distributed node sample function:
    out_dist = dist_sampler.event_loop.run_task(
        coro=dist_sampler.node_sample(inputs))

    torch.distributed.barrier()

    sampler = NeighborSampler(
        data=data,
        num_neighbors=[-1, -1],
        disjoint=disjoint,
    )

    # Evaluate node sample function:
    out = node_sample(inputs, sampler._sample)

    # Compare distributed output with single machine output:
    assert torch.equal(out_dist.node, out.node)
    assert torch.equal(out_dist.row, out.row)
    assert torch.equal(out_dist.col, out.col)
    if disjoint:
        assert torch.equal(out_dist.batch, out.batch)
    assert out_dist.num_sampled_nodes == out.num_sampled_nodes
    assert out_dist.num_sampled_edges == out.num_sampled_edges

    torch.distributed.barrier()


def dist_neighbor_sampler_temporal(
    world_size: int,
    rank: int,
    master_port: int,
    seed_time: torch.Tensor = None,
    temporal_strategy: str = 'uniform',
):
    dist_data, data = create_data(rank, world_size, temporal=True)

    current_ctx = DistContext(
        rank=rank,
        global_rank=rank,
        world_size=world_size,
        global_world_size=world_size,
        group_name='dist-sampler-test',
    )

    # Initialize training process group of PyTorch:
    torch.distributed.init_process_group(
        backend='gloo',
        rank=current_ctx.rank,
        world_size=current_ctx.world_size,
        init_method=f'tcp://localhost:{master_port}',
    )

    num_neighbors = [-1, -1] if temporal_strategy == 'uniform' else [1, 1]
    dist_sampler = DistNeighborSampler(
        data=dist_data,
        current_ctx=current_ctx,
        rpc_worker_names={},
        num_neighbors=num_neighbors,
        shuffle=False,
        disjoint=True,
        temporal_strategy=temporal_strategy,
        time_attr='time',
    )

    init_rpc(
        current_ctx=current_ctx,
        rpc_worker_names={},
        master_addr='localhost',
        master_port=master_port,
    )

    dist_sampler.register_sampler_rpc()
    dist_sampler.init_event_loop()

    # Close RPC & worker group at exit:
    atexit.register(close_sampler, 0, dist_sampler)
    torch.distributed.barrier()

    if rank == 0:  # Seed nodes:
        input_node = torch.tensor([1, 6], dtype=torch.int64)
    else:
        input_node = torch.tensor([4, 9], dtype=torch.int64)

    inputs = NodeSamplerInput(
        input_id=None,
        node=input_node,
        time=seed_time,
    )

    # Evaluate distributed node sample function:
    out_dist = dist_sampler.event_loop.run_task(
        coro=dist_sampler.node_sample(inputs))

    torch.distributed.barrier()

    sampler = NeighborSampler(
        data=data,
        num_neighbors=num_neighbors,
        disjoint=True,
        temporal_strategy=temporal_strategy,
        time_attr='time',
    )

    # Evaluate node sample function:
    out = node_sample(inputs, sampler._sample)

    # Compare distributed output with single machine output:
    assert torch.equal(out_dist.node, out.node)
    assert torch.equal(out_dist.row, out.row)
    assert torch.equal(out_dist.col, out.col)
    assert torch.equal(out_dist.batch, out.batch)
    assert out_dist.num_sampled_nodes == out.num_sampled_nodes
    assert out_dist.num_sampled_edges == out.num_sampled_edges

    torch.distributed.barrier()


@onlyLinux
@withPackage('pyg_lib')
@pytest.mark.parametrize('disjoint', [False, True])
def test_dist_neighbor_sampler(disjoint):
    mp_context = torch.multiprocessing.get_context('spawn')
    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    s.bind(('127.0.0.1', 0))
    port = s.getsockname()[1]
    s.close()

    world_size = 2
    w0 = mp_context.Process(
        target=dist_neighbor_sampler,
        args=(world_size, 0, port, disjoint),
    )

    w1 = mp_context.Process(
        target=dist_neighbor_sampler,
        args=(world_size, 1, port, disjoint),
    )

    w0.start()
    w1.start()
    w0.join()
    w1.join()


@onlyLinux
@withPackage('pyg_lib')
@pytest.mark.parametrize('seed_time', [None, torch.tensor([3, 6])])
@pytest.mark.parametrize('temporal_strategy', ['uniform', 'last'])
def test_dist_neighbor_sampler_temporal(seed_time, temporal_strategy):
    mp_context = torch.multiprocessing.get_context('spawn')
    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    s.bind(('127.0.0.1', 0))
    port = s.getsockname()[1]
    s.close()

    world_size = 2
    w0 = mp_context.Process(
        target=dist_neighbor_sampler_temporal,
        args=(world_size, 0, port, seed_time, temporal_strategy),
    )

    w1 = mp_context.Process(
        target=dist_neighbor_sampler_temporal,
        args=(world_size, 1, port, seed_time, temporal_strategy),
    )

    w0.start()
    w1.start()
    w0.join()
    w1.join()
2 changes: 1 addition & 1 deletion test/distributed/test_rpc.py
@@ -58,7 +58,7 @@ def run_rpc_feature_test(

     feature.num_partitions = world_size
     feature.partition_idx = rank
-    feature.feature_pb = partition_book
+    feature.node_feat_pb = partition_book
     feature.meta = meta
     feature.local_only = False
     feature.set_rpc_router(rpc_router)