diff --git a/torch_geometric/data/graph_store.py b/torch_geometric/data/graph_store.py index 7d5b3b5e3c57..a0cc6beda5ba 100644 --- a/torch_geometric/data/graph_store.py +++ b/torch_geometric/data/graph_store.py @@ -285,16 +285,7 @@ def _edge_to_layout( row = ptr2index(row) if attr.layout != EdgeLayout.CSC: # COO->CSC - try: - if self.meta['is_hetero']: - # Hotfix for LocalGraphStore, where in hetero graph edge indices for different edge types have continuous indices, not starting at 0 - num_cols = int(col.max()) + 1 - else: - num_cols = attr.size[1] if attr.size else int( - col.max()) + 1 - except AttributeError: - num_cols = attr.size[1] if attr.size else int( - col.max()) + 1 + num_cols = attr.size[1] if attr.size else int(col.max()) + 1 if not attr.is_sorted: # Not sorted by destination. col, perm = index_sort(col, max_value=num_cols) row = row[perm] @@ -316,59 +307,38 @@ def _edges_to_layout( store: bool = False, ) -> ConversionOutputType: - try: - is_hetero = self.meta["is_hetero"] - except AttributeError: - # assume default is_hetero = True - is_hetero = True - - if not is_hetero: # Homo - edge_attrs: List[EdgeAttr] = [] - for attr in self.get_all_edge_attrs(): - edge_attrs.append(attr) - - # Convert layout from its most favorable original layout: - row, col, perm = [], [], [] - # for attrs in edge_attrs: - row, col, perm = (self._edge_to_layout(edge_attrs[0], layout, - store)) - - return row, col, perm - else: - # Obtain all edge attributes, grouped by type: - edge_type_attrs: Dict[EdgeType, List[EdgeAttr]] = defaultdict(list) - for attr in self.get_all_edge_attrs(): - edge_type_attrs[attr.edge_type].append(attr) - - # Check that requested edge types exist and filter: - if edge_types is not None: - for edge_type in edge_types: - if edge_type not in edge_type_attrs: - raise ValueError( - f"The 'edge_index' of type '{edge_type}' " - f"was not found in the graph store.") - - edge_type_attrs = { - key: attr - for key, attr in edge_type_attrs.items() - if key in edge_types - } - - # Convert layout from its most favorable original layout: - row_dict, col_dict, perm_dict = {}, {}, {} - for edge_type, attrs in edge_type_attrs.items(): - layouts = [attr.layout for attr in attrs] - - if layout in layouts: # No conversion needed. - attr = attrs[layouts.index(layout)] - elif EdgeLayout.COO in layouts: # Prefer COO for conversion. - attr = attrs[layouts.index(EdgeLayout.COO)] - elif EdgeLayout.CSC in layouts: - attr = attrs[layouts.index(EdgeLayout.CSC)] - elif EdgeLayout.CSR in layouts: - attr = attrs[layouts.index(EdgeLayout.CSR)] - - row_dict[edge_type], col_dict[edge_type], perm_dict[ - edge_type] = (self._edge_to_layout(attr, layout, store)) - - return row_dict, col_dict, perm_dict + # Obtain all edge attributes, grouped by type: + edge_type_attrs: Dict[EdgeType, List[EdgeAttr]] = defaultdict(list) + for attr in self.get_all_edge_attrs(): + edge_type_attrs[attr.edge_type].append(attr) + + # Check that requested edge types exist and filter: + if edge_types is not None: + for edge_type in edge_types: + if edge_type not in edge_type_attrs: + raise ValueError(f"The 'edge_index' of type '{edge_type}' " + f"was not found in the graph store.") + + edge_type_attrs = { + key: attr + for key, attr in edge_type_attrs.items() if key in edge_types + } + + # Convert layout from its most favorable original layout: + row_dict, col_dict, perm_dict = {}, {}, {} + for edge_type, attrs in edge_type_attrs.items(): + layouts = [attr.layout for attr in attrs] + + if layout in layouts: # No conversion needed. + attr = attrs[layouts.index(layout)] + elif EdgeLayout.COO in layouts: # Prefer COO for conversion. + attr = attrs[layouts.index(EdgeLayout.COO)] + elif EdgeLayout.CSC in layouts: + attr = attrs[layouts.index(EdgeLayout.CSC)] + elif EdgeLayout.CSR in layouts: + attr = attrs[layouts.index(EdgeLayout.CSR)] + + row_dict[edge_type], col_dict[edge_type], perm_dict[edge_type] = ( + self._edge_to_layout(attr, layout, store)) + + return row_dict, col_dict, perm_dict diff --git a/torch_geometric/distributed/local_graph_store.py b/torch_geometric/distributed/local_graph_store.py index 5b8dfebbaad0..5127f0634ebc 100644 --- a/torch_geometric/distributed/local_graph_store.py +++ b/torch_geometric/distributed/local_graph_store.py @@ -89,7 +89,6 @@ def from_data( edge_id: Tensor, edge_index: Tensor, num_nodes: int, - is_sorted: Optional[bool] = False, ) -> 'LocalGraphStore': r"""Creates a local graph store from a homogeneous :pyg:`PyG` graph. @@ -102,7 +101,6 @@ def from_data( edge_type=None, layout='coo', size=(num_nodes, num_nodes), - is_sorted=is_sorted, ) graph_store = cls() @@ -116,7 +114,6 @@ def from_hetero_data( edge_id_dict: Dict[EdgeType, Tensor], edge_index_dict: Dict[EdgeType, Tensor], num_nodes_dict: Dict[NodeType, int], - is_sorted_dict: Dict[NodeType, bool], ) -> 'LocalGraphStore': r"""Creates a local graph store from a heterogeneous :pyg:`PyG` graph. @@ -127,8 +124,6 @@ def from_hetero_data( indices of every edge type. num_nodes_dict (Dict[NodeType, int]): The number of nodes in the local graph of every node type. - is_sorted_dict (Dict[NodeType, bool]): Implicit information on node - order of every node type. """ attr_dict = {} for edge_type in edge_index_dict.keys(): @@ -137,7 +132,6 @@ def from_hetero_data( edge_type=edge_type, layout='coo', size=(num_nodes_dict[src], num_nodes_dict[dst]), - is_sorted=is_sorted_dict[dst], ) graph_store = cls()