
Update GraphStore and FeatureStore [1/6] #8083

Merged: 23 commits, Nov 6, 2023

Commits:
- 4d69c3a: update graph store (kgajdamo, Sep 27, 2023)
- bbb3e9e: update local feat store & test for graph store (JakubPietrakIntel, Sep 27, 2023)
- ad98c6d: lgs/lfs refactor labels (JakubPietrakIntel, Sep 28, 2023)
- c011ffc: [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 28, 2023)
- c7a834e: fix pb for hetero (JakubPietrakIntel, Oct 2, 2023)
- 076e9dc: fix edge\node feat pb (JakubPietrakIntel, Oct 2, 2023)
- 49a82a9: [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Oct 2, 2023)
- 309c7ce: add `load_partition_info` (JakubPietrakIntel, Oct 2, 2023)
- e70f440: move edge utils to `distributed.utils` (JakubPietrakIntel, Oct 2, 2023)
- d582b8c: [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Oct 2, 2023)
- 77651fe: [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Oct 9, 2023)
- 4cabd98: remove unused import (kgajdamo, Oct 17, 2023)
- d2b4f1d: update CHANGELOG.md (kgajdamo, Oct 17, 2023)
- 434c743: set 'is_hetero' info when creating lgs from data (kgajdamo, Oct 17, 2023)
- 4f10373: force sort lgs on dst node to fix edge[perm] (JakubPietrakIntel, Oct 17, 2023)
- ab04bfd: [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Oct 17, 2023)
- 5985892: clean sampler (JakubPietrakIntel, Oct 17, 2023)
- b60df3a: fix partition test (JakubPietrakIntel, Oct 17, 2023)
- ac743e8: [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Oct 17, 2023)
- 70a5594: format code (kgajdamo, Oct 26, 2023)
- e037c37: Merge branch 'master' into intel/dist-gfs (rusty1s, Nov 6, 2023)
- 3dd6564: [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 6, 2023)
- 088ef71: update (rusty1s, Nov 6, 2023)
CHANGELOG.md (1 change: 1 addition & 0 deletions)

@@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Update `GraphStore` and `FeatureStore` to support distributed training ([#8083](https://github.com/pyg-team/pytorch_geometric/pull/8083))
- Disallow the usage of `add_self_loops=True` in `GCNConv(normalize=False)` ([#8210](https://github.com/pyg-team/pytorch_geometric/pull/8210))
- Disable device asserts during `torch_geometric.compile` ([#8220](https://github.com/pyg-team/pytorch_geometric/pull/8220))

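To give a quick sense of the headline change, here is a minimal sketch distilled from the test changes below, assuming `LocalGraphStore` is importable from `torch_geometric.distributed` and `get_random_edge_index` from `torch_geometric.testing` as in this PR's test file. The new `is_sorted` flag tells the store that the edge index is already sorted by destination node:

```python
import torch

from torch_geometric.distributed import LocalGraphStore
from torch_geometric.testing import get_random_edge_index

# 300 random edges over 100 nodes, with explicit global edge IDs:
edge_index = get_random_edge_index(100, 100, 300)
edge_id = torch.randperm(300)

# Pre-sort by destination node so that `is_sorted=True` holds:
edge_index[1] = torch.sort(edge_index[1])[0]

graph_store = LocalGraphStore.from_data(
    edge_id,
    edge_index,
    num_nodes=100,
    is_sorted=True,
)

edge_attr = graph_store.get_all_edge_attrs()[0]
assert edge_attr.is_sorted
assert torch.equal(graph_store.get_edge_index(edge_attr), edge_index)
```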
test/distributed/test_local_graph_store.py (82 changes: 72 additions & 10 deletions)

@@ -10,17 +10,25 @@ def test_local_graph_store():
edge_index = get_random_edge_index(100, 100, 300)
edge_id = torch.tensor([1, 2, 3, 5, 8, 4])

graph_store.put_edge_index(edge_index, edge_type=None, layout='coo',
size=(100, 100))
graph_store.put_edge_index(
edge_index,
edge_type=None,
layout='coo',
size=(100, 100),
)

graph_store.put_edge_id(edge_id, edge_type=None, layout='coo',
size=(100, 100))
graph_store.put_edge_id(
edge_id,
edge_type=None,
layout='coo',
size=(100, 100),
)

assert len(graph_store.get_all_edge_attrs()) == 1
edge_attr = graph_store.get_all_edge_attrs()[0]
assert torch.equal(graph_store.get_edge_index(edge_attr), edge_index)
assert torch.equal(graph_store.get_edge_id(edge_attr), edge_id)

assert not graph_store.is_sorted
graph_store.remove_edge_index(edge_attr)
graph_store.remove_edge_id(edge_attr)
assert len(graph_store.get_all_edge_attrs()) == 0
@@ -29,14 +37,20 @@ def test_local_graph_store():
def test_homogeneous_graph_store():
edge_id = torch.randperm(300)
edge_index = get_random_edge_index(100, 100, 300)
edge_index[1] = torch.sort(edge_index[1])[0]

graph_store = LocalGraphStore.from_data(edge_id, edge_index, num_nodes=100)
graph_store = LocalGraphStore.from_data(
edge_id,
edge_index,
num_nodes=100,
is_sorted=True,
)

assert len(graph_store.get_all_edge_attrs()) == 1
edge_attr = graph_store.get_all_edge_attrs()[0]
assert edge_attr.edge_type is None
assert edge_attr.layout.value == 'coo'
assert not edge_attr.is_sorted
assert edge_attr.is_sorted
assert edge_attr.size == (100, 100)

assert torch.equal(
@@ -52,16 +66,22 @@ def test_homogeneous_graph_store():
def test_heterogeneous_graph_store():
edge_type = ('paper', 'to', 'paper')
edge_id_dict = {edge_type: torch.randperm(300)}
edge_index_dict = {edge_type: get_random_edge_index(100, 100, 300)}
edge_index = get_random_edge_index(100, 100, 300)
edge_index[1] = torch.sort(edge_index[1])[0]
edge_index_dict = {edge_type: edge_index}

graph_store = LocalGraphStore.from_hetero_data(
edge_id_dict, edge_index_dict, num_nodes_dict={'paper': 100})
edge_id_dict,
edge_index_dict,
num_nodes_dict={'paper': 100},
is_sorted=True,
)

assert len(graph_store.get_all_edge_attrs()) == 1
edge_attr = graph_store.get_all_edge_attrs()[0]
assert edge_attr.edge_type == edge_type
assert edge_attr.layout.value == 'coo'
assert not edge_attr.is_sorted
assert edge_attr.is_sorted
assert edge_attr.size == (100, 100)

assert torch.equal(
@@ -72,3 +92,45 @@ def test_heterogeneous_graph_store():
graph_store.get_edge_index(edge_type, layout='coo'),
edge_index_dict[edge_type],
)


def test_sorted_graph_store():
edge_index_sorted = torch.tensor([[1, 7, 5, 6, 1], [0, 0, 1, 1, 2]])
edge_id_sorted = torch.tensor([0, 1, 2, 3, 4])

edge_index = torch.tensor([[1, 5, 7, 1, 6], [0, 1, 0, 2, 1]])
edge_id = torch.tensor([0, 2, 1, 4, 3])

graph_store = LocalGraphStore.from_data(
edge_id,
edge_index,
num_nodes=8,
is_sorted=False,
)
assert torch.equal(
graph_store.get_edge_index(edge_type=None, layout='coo'),
edge_index_sorted,
)
assert torch.equal(
graph_store.get_edge_id(edge_type=None, layout='coo'),
edge_id_sorted,
)

edge_type = ('paper', 'to', 'paper')
edge_index_dict = {edge_type: edge_index}
edge_id_dict = {edge_type: edge_id}

graph_store = LocalGraphStore.from_hetero_data(
edge_id_dict,
edge_index_dict,
num_nodes_dict={'paper': 8},
is_sorted=False,
)
assert torch.equal(
graph_store.get_edge_index(edge_type, layout='coo'),
edge_index_sorted,
)
assert torch.equal(
graph_store.get_edge_id(edge_type, layout='coo'),
edge_id_sorted,
)
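The expected tensors in `test_sorted_graph_store` can be reproduced with plain `torch` ops. A minimal sketch of what the store does when `is_sorted=False` (a stable `torch.sort` stands in here for PyG's internal `index_sort` helper): edges are sorted by destination node, and the same permutation is applied to the edge IDs.

```python
import torch

# Same unsorted inputs as in `test_sorted_graph_store` above:
edge_index = torch.tensor([[1, 5, 7, 1, 6], [0, 1, 0, 2, 1]])
edge_id = torch.tensor([0, 2, 1, 4, 3])

# Sort by destination node; `stable=True` keeps the relative order of
# edges that share a destination:
col, perm = torch.sort(edge_index[1], stable=True)
row = edge_index[0][perm]

assert torch.equal(torch.stack([row, col]),
                   torch.tensor([[1, 7, 5, 6, 1], [0, 0, 1, 1, 2]]))
assert torch.equal(edge_id[perm], torch.tensor([0, 1, 2, 3, 4]))
```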
torch_geometric/data/graph_store.py (22 changes: 21 additions & 1 deletion)

@@ -278,7 +278,16 @@ def _edge_to_layout(
row = ptr2index(row)

if attr.layout != EdgeLayout.CSC: # COO->CSC
num_cols = attr.size[1] if attr.size else int(col.max()) + 1
if hasattr(self, 'meta') and self.meta.get('is_hetero', False):
# Hotfix for `LocalGraphStore`, where in heterogeneous
# graphs, edge indices for different edge types have
# continuous indices not starting at 0.
num_cols = int(col.max()) + 1
elif attr.size is not None:
num_cols = attr.size[1]
else:
num_cols = int(col.max()) + 1

if not attr.is_sorted: # Not sorted by destination.
col, perm = index_sort(col, max_value=num_cols)
row = row[perm]
@@ -300,6 +309,17 @@ def _edges_to_layout(
store: bool = False,
) -> ConversionOutputType:

is_hetero = True # Default.
if hasattr(self, 'meta'): # `LocalGraphStore` hack.
is_hetero = self.meta.get('is_hetero', False)

if not is_hetero:
edge_attrs: List[EdgeAttr] = []
for attr in self.get_all_edge_attrs():
edge_attrs.append(attr)

return self._edge_to_layout(edge_attrs[0], layout, store)

# Obtain all edge attributes, grouped by type:
edge_type_attrs: Dict[EdgeType, List[EdgeAttr]] = defaultdict(list)
for attr in self.get_all_edge_attrs():
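For context on the `num_cols` hotfix above: a COO-to-CSC conversion needs the number of destination nodes to size the column pointer vector, and in a distributed heterogeneous partition the global IDs need not start at 0, so `int(col.max()) + 1` is the safe bound when `attr.size` cannot be trusted. A minimal, self-contained sketch of the conversion follows; it is illustrative only, not the library's implementation:

```python
import torch

def coo_to_csc(row, col, num_cols):
    # Sort edges by destination node (stable sort keeps ties in order):
    col, perm = torch.sort(col, stable=True)
    row = row[perm]
    # Compress the sorted columns into a pointer vector of length
    # `num_cols + 1`, where `colptr[j]:colptr[j + 1]` spans column j:
    counts = torch.bincount(col, minlength=num_cols)
    colptr = torch.cat([torch.zeros(1, dtype=torch.long), counts.cumsum(0)])
    return row, colptr, perm

row = torch.tensor([1, 5, 7, 1, 6])
col = torch.tensor([0, 1, 0, 2, 1])
num_cols = int(col.max()) + 1  # Size from the data, as in the hotfix.
row, colptr, perm = coo_to_csc(row, col, num_cols)
# row == [1, 7, 5, 6, 1], colptr == [0, 2, 4, 5], perm == [0, 2, 1, 4, 3]
```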
torch_geometric/distributed/dist_neighbor_sampler.py (6 changes: 2 additions & 4 deletions)

@@ -100,7 +100,6 @@ def __init__(
self.temporal_strategy = temporal_strategy
self.time_attr = time_attr
self.with_edge_attr = self.dist_feature.has_edge_attr()
_, _, self.edge_permutation = self.dist_graph.csc()
self.csc = True

def register_sampler_rpc(self) -> None:
@@ -121,6 +120,8 @@ def register_sampler_rpc(self) -> None:
temporal_strategy=self.temporal_strategy,
time_attr=self.time_attr,
)
self.edge_permutation = self._sampler.perm

rpc_sample_callee = RPCSamplingCallee(self._sampler)
self.rpc_sample_callee_id = rpc_register(rpc_sample_callee)

@@ -660,9 +661,6 @@ async def _collate_fn(
efeats = None

output.metadata = (*output.metadata, nfeats, nlabels, efeats)
if self.is_hetero:
output.row = remap_keys(output.row, self._sampler.to_edge_type)
output.col = remap_keys(output.col, self._sampler.to_edge_type)
return output

def __repr__(self) -> str:
torch_geometric/distributed/local_feature_store.py (60 changes: 33 additions & 27 deletions)

@@ -49,20 +49,21 @@ class LocalFeatureStore(FeatureStore):
"""
def __init__(self):
super().__init__(tensor_attr_cls=LocalTensorAttr)

self._feat: Dict[Tuple[Union[NodeType, EdgeType], str], Tensor] = {}

# Save the global node/edge IDs:
self._global_id: Dict[Union[NodeType, EdgeType], Tensor] = {}

# Save the mapping from global node/edge IDs to indices in `_feat`:
self._global_id_to_index: Dict[Union[NodeType, EdgeType], Tensor] = {}

# For partition/rpc information related to distribute features:
self.num_partitions = 1
self.partition_idx = 0
self.feature_pb: Union[Tensor, Dict[NodeOrEdgeType, Tensor]]
self.local_only = False
# For partition/RPC info related to distributed features:
self.num_partitions: int = 1
self.partition_idx: int = 0
# Mapping between node ID and partition ID:
self.node_feat_pb: Union[Tensor, Dict[NodeType, Tensor]]
# Mapping between edge ID and partition ID:
self.edge_feat_pb: Union[Tensor, Dict[EdgeType, Tensor]]
self.labels: Optional[Tensor] = None # Node labels.

self.local_only: bool = False
self.rpc_router: Optional[RPCRouter] = None
self.meta: Optional[Dict] = None
self.rpc_call_id: Optional[int] = None
@@ -147,6 +148,16 @@ def set_rpc_router(self, rpc_router: RPCRouter):
else:
self.rpc_call_id = None

def has_edge_attr(self) -> bool:
has_edge_attr = False
for k in [key for key in self._feat.keys() if 'edge_attr' in key]:
try:
self.get_tensor(k[0], 'edge_attr')
has_edge_attr = True
except KeyError:
pass
return has_edge_attr

def lookup_features(
self,
index: Tensor,
@@ -164,8 +175,11 @@ def when_finish(*_):
try:
remote_feature_list = remote_fut.wait()
# combine the feature from remote and local
result = torch.zeros(index.size(0), local_feature[0].size(1),
dtype=local_feature[0].dtype)
result = torch.zeros(
index.size(0),
local_feature[0].size(1),
dtype=local_feature[0].dtype,
)
result[local_feature[1]] = local_feature[0]
for remote in remote_feature_list:
result[remote[1]] = remote[0]
@@ -184,12 +198,7 @@ def _local_lookup_features(
input_type: Optional[Union[NodeType, EdgeType]] = None,
) -> Tuple[Tensor, Tensor]:
r"""Lookup the features in local nodes based on node/edge IDs."""
if self.meta['is_hetero']:
feat = self
pb = self.feature_pb[input_type]
else:
feat = self
pb = self.feature_pb
pb = self.node_feat_pb if is_node_feat else self.edge_feat_pb

input_order = torch.arange(index.size(0), dtype=torch.long)
partition_ids = pb[index]
@@ -198,23 +207,23 @@
local_ids = torch.masked_select(index, local_mask)
local_index = torch.masked_select(input_order, local_mask)

if self.meta["is_hetero"]:
if self.meta['is_hetero']:
if is_node_feat:
kwargs = dict(group_name=input_type, attr_name='x')
ret_feat = feat.get_tensor_from_global_id(
ret_feat = self.get_tensor_from_global_id(
index=local_ids, **kwargs)
else:
kwargs = dict(group_name=input_type, attr_name='edge_attr')
ret_feat = feat.get_tensor_from_global_id(
ret_feat = self.get_tensor_from_global_id(
index=local_ids, **kwargs)
else:
if is_node_feat:
kwargs = dict(group_name=None, attr_name='x')
ret_feat = feat.get_tensor_from_global_id(
ret_feat = self.get_tensor_from_global_id(
index=local_ids, **kwargs)
else:
kwargs = dict(group_name=(None, None), attr_name='edge_attr')
ret_feat = feat.get_tensor_from_global_id(
ret_feat = self.get_tensor_from_global_id(
index=local_ids, **kwargs)

return ret_feat, local_index
@@ -226,18 +235,15 @@ def _remote_lookup_features(
input_type: Optional[Union[NodeType, EdgeType]] = None,
) -> torch.futures.Future:
r"""Fetch the remote features with the remote node/edge IDs."""
if self.meta["is_hetero"]:
pb = self.feature_pb[input_type]
else:
pb = self.feature_pb
pb = self.node_feat_pb if is_node_feat else self.edge_feat_pb

input_order = torch.arange(index.size(0), dtype=torch.long)
partition_ids = pb[index]
futs, indexes = [], []
for pidx in range(0, self.num_partitions):
if pidx == self.partition_idx:
continue
remote_mask = (partition_ids == pidx)
remote_mask = partition_ids == pidx
remote_ids = index[remote_mask]
if remote_ids.shape[0] > 0:
to_worker = self.rpc_router.get_to_worker(pidx)
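The refactor above splits the old `feature_pb` into `node_feat_pb` and `edge_feat_pb`, each a "partition book": a tensor mapping a global ID to the partition that owns it. Below is a minimal sketch of the local/remote split that `_local_lookup_features` and `_remote_lookup_features` perform, using two partitions and hand-made IDs; the variable names are illustrative, not the library API:

```python
import torch

num_partitions = 2
partition_idx = 0  # This worker owns partition 0.

# Partition book: global node ID -> owning partition ID.
node_feat_pb = torch.tensor([0, 0, 1, 1, 0, 1])

# A feature lookup for a batch of global node IDs:
index = torch.tensor([2, 0, 5, 4])
input_order = torch.arange(index.size(0), dtype=torch.long)
partition_ids = node_feat_pb[index]  # -> [1, 0, 1, 0]

# Local IDs are resolved directly against the local feature tensors:
local_mask = partition_ids == partition_idx
local_ids = torch.masked_select(index, local_mask)          # [0, 4]
local_index = torch.masked_select(input_order, local_mask)  # [1, 3]

# Remote IDs are grouped per partition and fetched via RPC:
for pidx in range(num_partitions):
    if pidx == partition_idx:
        continue
    remote_mask = partition_ids == pidx
    remote_ids = index[remote_mask]  # [2, 5] -> sent to partition 1.

# Finally, the results are scattered back into a single output tensor
# using the saved positions (`local_index` and the per-partition
# masks), as done in `lookup_features`.
```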