[Feature] Remove pyg dependency of datasets (#169)
Shiguang-Guo authored Jan 21, 2021
1 parent 91cb677 commit adc7281
Showing 3 changed files with 264 additions and 3 deletions.
2 changes: 2 additions & 0 deletions cogdl/data/data.py
@@ -123,6 +123,8 @@ def num_edges(self):
    @property
    def num_features(self):
        r"""Returns the number of features per node in the graph."""
        if self.x is None:
            return 0
        return 1 if self.x.dim() == 1 else self.x.size(1)

    @property
261 changes: 260 additions & 1 deletion cogdl/datasets/pyg.py
@@ -1,8 +1,18 @@
import glob
import os
import os.path as osp
import shutil
import zipfile
from itertools import repeat

import numpy as np
import torch
import torch.nn.functional as F

from torch_geometric.datasets import Planetoid, Reddit, TUDataset, QM9
from cogdl.data.dataset import Dataset
from . import register_dataset
from ..data import Data
from ..utils import download_url


def normalize_feature(data):
@@ -14,6 +24,255 @@ def normalize_feature(data):
    return data


def parse_txt_array(src, sep=None, start=0, end=None, dtype=None, device=None):
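    # Convert each line into floats (columns start:end, split on sep) and stack them into a tensor.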
    src = [[float(x) for x in line.split(sep)[start:end]] for line in src]
    src = torch.tensor(src, dtype=dtype).squeeze()
    return src


def read_txt_array(path, sep=None, start=0, end=None, dtype=None, device=None):
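    # Read a text file line by line and parse it with parse_txt_array.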
    with open(path, 'r') as f:
        src = f.read().split('\n')[:-1]
    return parse_txt_array(src, sep, start, end, dtype, device)


def read_file(folder, prefix, name, dtype=None):
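    # Load "<prefix>_<name>.txt" from the dataset folder as a comma-separated array.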
    path = osp.join(folder, '{}_{}.txt'.format(prefix, name))
    return read_txt_array(path, sep=',', dtype=dtype)


def cat(seq):
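    # Concatenate all non-None tensors in seq along the last dimension.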
    seq = [item for item in seq if item is not None]
    seq = [item.unsqueeze(-1) if item.dim() == 1 else item for item in seq]
    return torch.cat(seq, dim=-1) if len(seq) > 0 else None


def split(data, batch):
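    # Compute per-graph node and edge offsets from the graph-indicator vector so that
    # individual graphs can later be sliced out of the concatenated data object.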
    node_slice = torch.cumsum(torch.from_numpy(np.bincount(batch)), 0)
    node_slice = torch.cat([torch.tensor([0]), node_slice])

    row, _ = data.edge_index
    edge_slice = torch.cumsum(torch.from_numpy(np.bincount(batch[row])), 0)
    edge_slice = torch.cat([torch.tensor([0]), edge_slice])

    # Edge indices should start at zero for every graph.
    data.edge_index -= node_slice[batch[row]].unsqueeze(0)
    data.__num_nodes__ = torch.bincount(batch).tolist()

    slices = {'edge_index': edge_slice}
    if data.x is not None:
        slices['x'] = node_slice
    if data.edge_attr is not None:
        slices['edge_attr'] = edge_slice
    if data.y is not None:
        if data.y.size(0) == batch.size(0):
            slices['y'] = node_slice
        else:
            slices['y'] = torch.arange(0, batch[-1] + 2, dtype=torch.long)

    return data, slices


def segment(src, indptr):
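    # Segment-sum src over the ranges delimited by indptr (used by coalesce to merge duplicate edge values).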
    out_list = []
    for i in range(indptr.size(-1) - 1):
        indexptr = torch.arange(indptr[..., i].item(), indptr[..., i + 1].item(), dtype=torch.int64)
        src_data = src.index_select(indptr.dim() - 1, indexptr)
        out = torch.sum(src_data, dim=indptr.dim() - 1, keepdim=True)
        out_list.append(out)
    return torch.cat(out_list, dim=indptr.dim() - 1)


def coalesce(index, value, m, n):
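    # Sort edge indices into row-major order, drop duplicate edges, and sum the
    # values of duplicates, so the m x n sparse graph has unique, ordered entries.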
    row = index[0]
    col = index[1]

    idx = col.new_zeros(col.numel() + 1)
    idx[1:] = row
    idx[1:] *= n
    idx[1:] += col
    if (idx[1:] < idx[:-1]).any():
        perm = idx[1:].argsort()
        row = row[perm]
        col = col[perm]
        if value is not None:
            value = value[perm]

    idx = col.new_full((col.numel() + 1,), -1)
    idx[1:] = n * row + col
    mask = idx[1:] > idx[:-1]

    if mask.all():  # Skip if indices are already coalesced.
        return torch.stack([row, col], dim=0), value

    row = row[mask]
    col = col[mask]

    if value is not None:
        ptr = mask.nonzero().flatten()
        ptr = torch.cat([ptr, ptr.new_full((1,), value.size(0))])
        value = segment(value, ptr)
        value = value[0] if isinstance(value, tuple) else value

    return torch.stack([row, col], dim=0), value


def read_tu_data(folder, prefix):
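    # Parse the raw TU-format text files (adjacency list, graph indicator, optional
    # node/edge attributes and labels) into one concatenated Data object plus
    # per-graph slice indices.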
    files = glob.glob(osp.join(folder, '{}_*.txt'.format(prefix)))
    names = [f.split(os.sep)[-1][len(prefix) + 1:-4] for f in files]

    edge_index = read_file(folder, prefix, 'A', torch.long).t() - 1
    batch = read_file(folder, prefix, 'graph_indicator', torch.long) - 1

    node_attributes = node_labels = None
    if 'node_attributes' in names:
        node_attributes = read_file(folder, prefix, 'node_attributes')
    if 'node_labels' in names:
        node_labels = read_file(folder, prefix, 'node_labels', torch.long)
        if node_labels.dim() == 1:
            node_labels = node_labels.unsqueeze(-1)
        node_labels = node_labels - node_labels.min(dim=0)[0]
        node_labels = node_labels.unbind(dim=-1)
        node_labels = [F.one_hot(x, num_classes=-1) for x in node_labels]
        node_labels = torch.cat(node_labels, dim=-1).to(torch.float)
    x = cat([node_attributes, node_labels])

    edge_attributes, edge_labels = None, None
    if 'edge_attributes' in names:
        edge_attributes = read_file(folder, prefix, 'edge_attributes')
    if 'edge_labels' in names:
        edge_labels = read_file(folder, prefix, 'edge_labels', torch.long)
        if edge_labels.dim() == 1:
            edge_labels = edge_labels.unsqueeze(-1)
        edge_labels = edge_labels - edge_labels.min(dim=0)[0]
        edge_labels = edge_labels.unbind(dim=-1)
        edge_labels = [F.one_hot(e, num_classes=-1) for e in edge_labels]
        edge_labels = torch.cat(edge_labels, dim=-1).to(torch.float)
    edge_attr = cat([edge_attributes, edge_labels])

    y = None
    if 'graph_attributes' in names:  # Regression problem.
        y = read_file(folder, prefix, 'graph_attributes')
    elif 'graph_labels' in names:  # Classification problem.
        y = read_file(folder, prefix, 'graph_labels', torch.long)
        _, y = y.unique(sorted=True, return_inverse=True)

    num_nodes = edge_index.max().item() + 1 if x is None else x.size(0)

    # Remove self-loops before coalescing duplicate edges.
    mask = edge_index[0] != edge_index[1]
    edge_index = edge_index[:, mask]
    if edge_attr is not None:
        edge_attr = edge_attr[mask]

    edge_index, edge_attr = coalesce(edge_index, edge_attr, num_nodes, num_nodes)

    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)
    data, slices = split(data, batch)

    return data, slices


class TUDataset(Dataset):
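    # Self-contained loader for the TU graph-kernel benchmarks, so datasets no
    # longer need torch_geometric's TUDataset implementation.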
    url = 'https://www.chrsmrrs.com/graphkerneldatasets'

    def __init__(self, root, name):
        self.name = name
        super(TUDataset, self).__init__(root)
        self.data, self.slices = torch.load(self.processed_paths[0])
        # Keep only the one-hot label part of the features; the leading
        # continuous-attribute columns produced by read_tu_data are dropped.
        if self.data.x is not None:
            num_node_attributes = self.num_node_attributes
            self.data.x = self.data.x[:, num_node_attributes:]
        if self.data.edge_attr is not None:
            num_edge_attributes = self.num_edge_attributes
            self.data.edge_attr = self.data.edge_attr[:, num_edge_attributes:]

    @property
    def raw_file_names(self):
        names = ['A', 'graph_indicator']
        return ['{}_{}.txt'.format(self.name, name) for name in names]

    @property
    def processed_file_names(self):
        return 'data.pt'

    def download(self):
        # Fetch the zip archive, extract it, and move the extracted folder to raw_dir.
        url = self.url
        folder = osp.join(self.root)
        path = download_url('{}/{}.zip'.format(url, self.name), folder)
        with zipfile.ZipFile(path, 'r') as f:
            f.extractall(folder)
        os.unlink(path)
        shutil.rmtree(self.raw_dir)
        os.rename(osp.join(folder, self.name), self.raw_dir)

    def process(self):
        # Parse the raw files once and cache the result as processed data.pt.
        self.data = read_tu_data(self.raw_dir, self.name)
        torch.save(self.data, self.processed_paths[0])

    @property
    def num_node_labels(self):
        # x stores continuous attributes first and one-hot labels last; scan for
        # the first column from which every row is a valid one-hot vector.
        if self.data.x is None:
            return 0
        for i in range(self.data.x.size(1)):
            x = self.data.x[:, i:]
            if ((x == 0) | (x == 1)).all() and (x.sum(dim=1) == 1).all():
                return self.data.x.size(1) - i
        return 0

    @property
    def num_node_attributes(self):
        if self.data.x is None:
            return 0
        return self.data.x.size(1) - self.num_node_labels

    @property
    def num_edge_labels(self):
        if self.data.edge_attr is None:
            return 0
        for i in range(self.data.edge_attr.size(1)):
            if self.data.edge_attr[:, i:].sum() == self.data.edge_attr.size(0):
                return self.data.edge_attr.size(1) - i
        return 0

    @property
    def num_edge_attributes(self):
        if self.data.edge_attr is None:
            return 0
        return self.data.edge_attr.size(1) - self.num_edge_labels

    @property
    def num_classes(self):
        r"""The number of classes in the dataset."""
        y = self.data.y
        return y.max().item() + 1 if y.dim() == 1 else y.size(1)

    def __len__(self):
        for item in self.slices.values():
            return len(item) - 1
        return 0

    def get(self, idx):
        # Rebuild the idx-th graph by slicing every stored attribute with the
        # per-graph offsets computed in split().
        data = self.data.__class__()
        if hasattr(self.data, '__num_nodes__'):
            data.num_nodes = self.data.__num_nodes__[idx]

        for key in self.data.keys:
            item, slices = self.data[key], self.slices[key]
            start, end = slices[idx].item(), slices[idx + 1].item()
            if torch.is_tensor(item):
                s = list(repeat(slice(None), item.dim()))
                s[self.data.cat_dim(key, item)] = slice(start, end)
            elif start + 1 == end:
                s = slices[start]
            else:
                s = slice(start, end)
            data[key] = item[s]

        return data


@register_dataset("mutag")
class MUTAGDataset(TUDataset):
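    # MUTAG is loaded through the local TUDataset implementation above.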
    def __init__(self, args=None):
4 changes: 2 additions & 2 deletions cogdl/tasks/graph_classification.py
@@ -223,8 +223,8 @@ def generate_data(self, dataset, args):
            return {"train": train_set, "test": test_set}
        else:
            datalist = []
            if isinstance(dataset[0], Data):
                return dataset
            # if isinstance(dataset[0], Data):
            #     return dataset
            for idata in dataset:
                data = Data()
                for key in idata.keys: