Update DistNeighborSampler for homo graphs [2/6] (pyg-team#8209)
**Changes made:**
- Added support for temporal sampling (see the first sketch below).
- Use `torch.Tensor`s instead of `numpy` arrays.
- Moved `_sample_one_hop()` from `NeighborSampler` to `DistNeighborSampler`.
- Do not go down the disjoint flow in the `_sample()` function - this is not needed, because `batch` is calculated afterwards (see the second sketch below).
- Added tests for node sampling and disjoint sampling (works without `DistNeighborLoader`).
- Added tests for temporal node sampling (works without `DistNeighborLoader`).
- Some minor changes, such as renaming variables.

This PR is based on pyg-team#8083, so both must be combined to pass the tests.

Other distributed PRs: pyg-team#8083, pyg-team#8080, pyg-team#8085

---------

Co-authored-by: Matthias Fey <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
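As a rough illustration of the temporal constraint the new tests exercise (a sketch, not the PR's implementation): a candidate neighbor is only eligible when its timestamp does not lie after the seed's time. The helper name `temporal_filter` and the tensors below are hypothetical; the real filtering happens inside `pyg-lib` / `DistNeighborSampler`:

    import torch

    def temporal_filter(candidates: torch.Tensor, node_time: torch.Tensor,
                        seed_time: int) -> torch.Tensor:
        # Keep only candidates whose timestamp does not exceed the seed time:
        return candidates[node_time[candidates] <= seed_time]

    # With the node times used by create_data() below ([5, 0, 1, 3, 3, ...]),
    # a seed time of 3 keeps nodes 1, 2, 3 and drops node 0 (time 5):
    node_time = torch.tensor([5, 0, 1, 3, 3, 4, 4, 4, 4, 4])
    print(temporal_filter(torch.tensor([0, 1, 2, 3]), node_time, 3))
    # tensor([1, 2, 3])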
1 parent 9ea2233 · commit f4ed34c
Showing 6 changed files with 442 additions and 150 deletions.
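The disjoint mode referenced in the changelog gives every seed node its own isolated subgraph, and the `batch` vector maps each sampled node back to the seed it was drawn for. A minimal sketch with hypothetical values (the tests below check the real vector via `out_dist.batch`):

    import torch

    # Hypothetical disjoint-mode output for two seeds:
    node = torch.tensor([1, 6, 0, 5])   # Sampled nodes: seeds first, then hops.
    batch = torch.tensor([0, 1, 0, 1])  # batch[i] = seed index of node[i].

    # Since `batch` can be derived from the per-seed sampling order after
    # sampling, _sample() itself no longer needs a separate disjoint path:
    for seed_idx in range(2):
        print(seed_idx, node[batch == seed_idx].tolist())
    # 0 [1, 0]
    # 1 [6, 5]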
@@ -0,0 +1,283 @@
import atexit
import socket
from typing import Optional

import pytest
import torch

from torch_geometric.data import Data
from torch_geometric.distributed import LocalFeatureStore, LocalGraphStore
from torch_geometric.distributed.dist_context import DistContext
from torch_geometric.distributed.dist_neighbor_sampler import (
    DistNeighborSampler,
    close_sampler,
)
from torch_geometric.distributed.rpc import init_rpc
from torch_geometric.sampler import NeighborSampler, NodeSamplerInput
from torch_geometric.sampler.neighbor_sampler import node_sample
from torch_geometric.testing import onlyLinux, withPackage


def create_data(rank: int, world_size: int, temporal: bool = False):
    if rank == 0:  # Partition 0:
        node_id = torch.tensor([0, 1, 2, 3, 4, 5, 9])
        edge_index = torch.tensor([  # Sorted by destination.
            [1, 2, 3, 4, 5, 0, 0],
            [0, 1, 2, 3, 4, 4, 9],
        ])
    else:  # Partition 1:
        node_id = torch.tensor([0, 4, 5, 6, 7, 8, 9])
        edge_index = torch.tensor([  # Sorted by destination.
            [5, 6, 7, 8, 9, 5, 0],
            [4, 5, 6, 7, 8, 9, 9],
        ])

    feature_store = LocalFeatureStore.from_data(node_id)
    graph_store = LocalGraphStore.from_data(
        edge_id=None,
        edge_index=edge_index,
        num_nodes=10,
        is_sorted=True,
    )

    graph_store.node_pb = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
    graph_store.meta.update({'num_parts': 2})
    graph_store.partition_idx = rank
    graph_store.num_partitions = world_size

    edge_index = torch.tensor([  # Create reference data:
        [1, 2, 3, 4, 5, 0, 5, 6, 7, 8, 9, 0],
        [0, 1, 2, 3, 4, 4, 9, 5, 6, 7, 8, 9],
    ])
    data = Data(x=None, y=None, edge_index=edge_index, num_nodes=10)

    if temporal:  # Create time data:
        data.time = torch.tensor([5, 0, 1, 3, 3, 4, 4, 4, 4, 4])
        feature_store.put_tensor(data.time, group_name=None, attr_name='time')

    return (feature_store, graph_store), data


def dist_neighbor_sampler(
    world_size: int,
    rank: int,
    master_port: int,
    disjoint: bool = False,
):
    dist_data, data = create_data(rank, world_size)

    current_ctx = DistContext(
        rank=rank,
        global_rank=rank,
        world_size=world_size,
        global_world_size=world_size,
        group_name='dist-sampler-test',
    )

    # Initialize the PyTorch training process group:
    torch.distributed.init_process_group(
        backend='gloo',
        rank=current_ctx.rank,
        world_size=current_ctx.world_size,
        init_method=f'tcp://localhost:{master_port}',
    )

    dist_sampler = DistNeighborSampler(
        data=dist_data,
        current_ctx=current_ctx,
        rpc_worker_names={},
        num_neighbors=[-1, -1],
        shuffle=False,
        disjoint=disjoint,
    )

    init_rpc(
        current_ctx=current_ctx,
        rpc_worker_names={},
        master_addr='localhost',
        master_port=master_port,
    )

    dist_sampler.register_sampler_rpc()
    dist_sampler.init_event_loop()

    # Close RPC & worker group at exit:
    atexit.register(close_sampler, 0, dist_sampler)
    torch.distributed.barrier()

    if rank == 0:  # Seed nodes:
        input_node = torch.tensor([1, 6])
    else:
        input_node = torch.tensor([4, 9])

    inputs = NodeSamplerInput(input_id=None, node=input_node)

    # Evaluate the distributed node sample function:
    out_dist = dist_sampler.event_loop.run_task(
        coro=dist_sampler.node_sample(inputs))

    torch.distributed.barrier()

    sampler = NeighborSampler(
        data=data,
        num_neighbors=[-1, -1],
        disjoint=disjoint,
    )

    # Evaluate the single-machine node sample function:
    out = node_sample(inputs, sampler._sample)

    # Compare distributed output with single-machine output:
    assert torch.equal(out_dist.node, out.node)
    assert torch.equal(out_dist.row, out.row)
    assert torch.equal(out_dist.col, out.col)
    if disjoint:
        assert torch.equal(out_dist.batch, out.batch)
    assert out_dist.num_sampled_nodes == out.num_sampled_nodes
    assert out_dist.num_sampled_edges == out.num_sampled_edges

    torch.distributed.barrier()


def dist_neighbor_sampler_temporal(
    world_size: int,
    rank: int,
    master_port: int,
    seed_time: Optional[torch.Tensor] = None,
    temporal_strategy: str = 'uniform',
):
    dist_data, data = create_data(rank, world_size, temporal=True)

    current_ctx = DistContext(
        rank=rank,
        global_rank=rank,
        world_size=world_size,
        global_world_size=world_size,
        group_name='dist-sampler-test',
    )

    # Initialize the PyTorch training process group:
    torch.distributed.init_process_group(
        backend='gloo',
        rank=current_ctx.rank,
        world_size=current_ctx.world_size,
        init_method=f'tcp://localhost:{master_port}',
    )

    num_neighbors = [-1, -1] if temporal_strategy == 'uniform' else [1, 1]
    dist_sampler = DistNeighborSampler(
        data=dist_data,
        current_ctx=current_ctx,
        rpc_worker_names={},
        num_neighbors=num_neighbors,
        shuffle=False,
        disjoint=True,
        temporal_strategy=temporal_strategy,
        time_attr='time',
    )

    init_rpc(
        current_ctx=current_ctx,
        rpc_worker_names={},
        master_addr='localhost',
        master_port=master_port,
    )

    dist_sampler.register_sampler_rpc()
    dist_sampler.init_event_loop()

    # Close RPC & worker group at exit:
    atexit.register(close_sampler, 0, dist_sampler)
    torch.distributed.barrier()

    if rank == 0:  # Seed nodes:
        input_node = torch.tensor([1, 6], dtype=torch.int64)
    else:
        input_node = torch.tensor([4, 9], dtype=torch.int64)

    inputs = NodeSamplerInput(
        input_id=None,
        node=input_node,
        time=seed_time,
    )

    # Evaluate the distributed node sample function:
    out_dist = dist_sampler.event_loop.run_task(
        coro=dist_sampler.node_sample(inputs))

    torch.distributed.barrier()

    sampler = NeighborSampler(
        data=data,
        num_neighbors=num_neighbors,
        disjoint=True,
        temporal_strategy=temporal_strategy,
        time_attr='time',
    )

    # Evaluate the single-machine node sample function:
    out = node_sample(inputs, sampler._sample)

    # Compare distributed output with single-machine output:
    assert torch.equal(out_dist.node, out.node)
    assert torch.equal(out_dist.row, out.row)
    assert torch.equal(out_dist.col, out.col)
    assert torch.equal(out_dist.batch, out.batch)
    assert out_dist.num_sampled_nodes == out.num_sampled_nodes
    assert out_dist.num_sampled_edges == out.num_sampled_edges

    torch.distributed.barrier()


@onlyLinux
@withPackage('pyg_lib')
@pytest.mark.parametrize('disjoint', [False, True])
def test_dist_neighbor_sampler(disjoint):
    mp_context = torch.multiprocessing.get_context('spawn')
    # Bind to port 0 so the OS picks a free port, then release it for the
    # workers to reuse as the master port:
    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    s.bind(('127.0.0.1', 0))
    port = s.getsockname()[1]
    s.close()

    world_size = 2
    w0 = mp_context.Process(
        target=dist_neighbor_sampler,
        args=(world_size, 0, port, disjoint),
    )

    w1 = mp_context.Process(
        target=dist_neighbor_sampler,
        args=(world_size, 1, port, disjoint),
    )

    w0.start()
    w1.start()
    w0.join()
    w1.join()


@onlyLinux
@withPackage('pyg_lib')
@pytest.mark.parametrize('seed_time', [None, torch.tensor([3, 6])])
@pytest.mark.parametrize('temporal_strategy', ['uniform', 'last'])
def test_dist_neighbor_sampler_temporal(seed_time, temporal_strategy):
    mp_context = torch.multiprocessing.get_context('spawn')
    # Bind to port 0 so the OS picks a free port, then release it for the
    # workers to reuse as the master port:
    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    s.bind(('127.0.0.1', 0))
    port = s.getsockname()[1]
    s.close()

    world_size = 2
    w0 = mp_context.Process(
        target=dist_neighbor_sampler_temporal,
        args=(world_size, 0, port, seed_time, temporal_strategy),
    )

    w1 = mp_context.Process(
        target=dist_neighbor_sampler_temporal,
        args=(world_size, 1, port, seed_time, temporal_strategy),
    )

    w0.start()
    w1.start()
    w0.join()
    w1.join()