Skip to content

Commit

Permalink
remove GraphStore changes
Browse files Browse the repository at this point in the history
this code was moved to another PR #8083
  • Loading branch information
kgajdamo committed Sep 27, 2023
1 parent 029d664 commit a097de5
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 72 deletions.
102 changes: 36 additions & 66 deletions torch_geometric/data/graph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,16 +285,7 @@ def _edge_to_layout(
row = ptr2index(row)

if attr.layout != EdgeLayout.CSC: # COO->CSC
try:
if self.meta['is_hetero']:
# Hotfix for LocalGraphStore, where in hetero graph edge indices for different edge types have continuous indices, not starting at 0
num_cols = int(col.max()) + 1
else:
num_cols = attr.size[1] if attr.size else int(
col.max()) + 1
except AttributeError:
num_cols = attr.size[1] if attr.size else int(
col.max()) + 1
num_cols = attr.size[1] if attr.size else int(col.max()) + 1
if not attr.is_sorted: # Not sorted by destination.
col, perm = index_sort(col, max_value=num_cols)
row = row[perm]
Expand All @@ -316,59 +307,38 @@ def _edges_to_layout(
store: bool = False,
) -> ConversionOutputType:

try:
is_hetero = self.meta["is_hetero"]
except AttributeError:
# assume default is_hetero = True
is_hetero = True

if not is_hetero: # Homo
edge_attrs: List[EdgeAttr] = []
for attr in self.get_all_edge_attrs():
edge_attrs.append(attr)

# Convert layout from its most favorable original layout:
row, col, perm = [], [], []
# for attrs in edge_attrs:
row, col, perm = (self._edge_to_layout(edge_attrs[0], layout,
store))

return row, col, perm
else:
# Obtain all edge attributes, grouped by type:
edge_type_attrs: Dict[EdgeType, List[EdgeAttr]] = defaultdict(list)
for attr in self.get_all_edge_attrs():
edge_type_attrs[attr.edge_type].append(attr)

# Check that requested edge types exist and filter:
if edge_types is not None:
for edge_type in edge_types:
if edge_type not in edge_type_attrs:
raise ValueError(
f"The 'edge_index' of type '{edge_type}' "
f"was not found in the graph store.")

edge_type_attrs = {
key: attr
for key, attr in edge_type_attrs.items()
if key in edge_types
}

# Convert layout from its most favorable original layout:
row_dict, col_dict, perm_dict = {}, {}, {}
for edge_type, attrs in edge_type_attrs.items():
layouts = [attr.layout for attr in attrs]

if layout in layouts: # No conversion needed.
attr = attrs[layouts.index(layout)]
elif EdgeLayout.COO in layouts: # Prefer COO for conversion.
attr = attrs[layouts.index(EdgeLayout.COO)]
elif EdgeLayout.CSC in layouts:
attr = attrs[layouts.index(EdgeLayout.CSC)]
elif EdgeLayout.CSR in layouts:
attr = attrs[layouts.index(EdgeLayout.CSR)]

row_dict[edge_type], col_dict[edge_type], perm_dict[
edge_type] = (self._edge_to_layout(attr, layout, store))

return row_dict, col_dict, perm_dict
# Obtain all edge attributes, grouped by type:
edge_type_attrs: Dict[EdgeType, List[EdgeAttr]] = defaultdict(list)
for attr in self.get_all_edge_attrs():
edge_type_attrs[attr.edge_type].append(attr)

# Check that requested edge types exist and filter:
if edge_types is not None:
for edge_type in edge_types:
if edge_type not in edge_type_attrs:
raise ValueError(f"The 'edge_index' of type '{edge_type}' "
f"was not found in the graph store.")

edge_type_attrs = {
key: attr
for key, attr in edge_type_attrs.items() if key in edge_types
}

# Convert layout from its most favorable original layout:
row_dict, col_dict, perm_dict = {}, {}, {}
for edge_type, attrs in edge_type_attrs.items():
layouts = [attr.layout for attr in attrs]

if layout in layouts: # No conversion needed.
attr = attrs[layouts.index(layout)]
elif EdgeLayout.COO in layouts: # Prefer COO for conversion.
attr = attrs[layouts.index(EdgeLayout.COO)]
elif EdgeLayout.CSC in layouts:
attr = attrs[layouts.index(EdgeLayout.CSC)]
elif EdgeLayout.CSR in layouts:
attr = attrs[layouts.index(EdgeLayout.CSR)]

row_dict[edge_type], col_dict[edge_type], perm_dict[edge_type] = (
self._edge_to_layout(attr, layout, store))

return row_dict, col_dict, perm_dict
6 changes: 0 additions & 6 deletions torch_geometric/distributed/local_graph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ def from_data(
edge_id: Tensor,
edge_index: Tensor,
num_nodes: int,
is_sorted: Optional[bool] = False,
) -> 'LocalGraphStore':
r"""Creates a local graph store from a homogeneous :pyg:`PyG` graph.
Expand All @@ -102,7 +101,6 @@ def from_data(
edge_type=None,
layout='coo',
size=(num_nodes, num_nodes),
is_sorted=is_sorted,
)

graph_store = cls()
Expand All @@ -116,7 +114,6 @@ def from_hetero_data(
edge_id_dict: Dict[EdgeType, Tensor],
edge_index_dict: Dict[EdgeType, Tensor],
num_nodes_dict: Dict[NodeType, int],
is_sorted_dict: Dict[NodeType, bool],
) -> 'LocalGraphStore':
r"""Creates a local graph store from a heterogeneous :pyg:`PyG` graph.
Expand All @@ -127,8 +124,6 @@ def from_hetero_data(
indices of every edge type.
num_nodes_dict (Dict[NodeType, int]): The number of nodes in the
local graph of every node type.
is_sorted_dict (Dict[NodeType, bool]): Implicit information on node
order of every node type.
"""
attr_dict = {}
for edge_type in edge_index_dict.keys():
Expand All @@ -137,7 +132,6 @@ def from_hetero_data(
edge_type=edge_type,
layout='coo',
size=(num_nodes_dict[src], num_nodes_dict[dst]),
is_sorted=is_sorted_dict[dst],
)

graph_store = cls()
Expand Down

0 comments on commit a097de5

Please sign in to comment.