diff --git a/examples/bgrl/README.md b/examples/bgrl/README.md new file mode 100644 index 00000000..69e143f8 --- /dev/null +++ b/examples/bgrl/README.md @@ -0,0 +1,46 @@ +# Large-Scale Representation Learning on Graphs via Bootstrapping (BGRL) with CogDL +This example implements BGRL with CogDL for graph representation learning. The authors' implementation can be found [here](https://github.com/nerdslab/bgrl). Another implementation from [Namkyeong](https://github.com/Namkyeong/BGRL_Pytorch) can also be used as a reference. + +## Hyperparameters +The following optional hyperparameters can be passed to the training script. + +`layers`: the output dimension of each GNN layer. + +`pred_hid`: the hidden dimension of the predictor module. + +`aug_params`: the corruption ratios (p_f1, p_f2, p_e1, p_e2) used for graph augmentation, i.e. node feature masking and edge dropping. + +## Usage +Download the datasets [here](https://pan.baidu.com/s/15RyvXD2G-xwGM9jrT7IDLQ?pwd=85vv) and place them under `./data`. Experiments with the given hyperparameters can be reproduced with the following commands. + +### Wiki-CS +``` +python train.py --name WikiCS --aug_params 0.2 0.1 0.2 0.3 --layers 512 256 --pred_hid 512 --lr 0.0001 --epochs 10000 -cs 250 +``` +### Amazon Computers +``` +python train.py --name computers --aug_params 0.2 0.1 0.5 0.4 --layers 256 128 --pred_hid 512 --lr 0.0005 --epochs 10000 -cs 250 +``` +### Amazon Photo +``` +python train.py --name photo --aug_params 0.1 0.2 0.4 0.1 --layers 512 256 --pred_hid 512 --lr 0.0001 --epochs 10000 -cs 250 +``` +### Coauthor CS +``` +python train.py --name cs --aug_params 0.3 0.4 0.3 0.2 --layers 512 256 --pred_hid 512 --lr 0.00001 --epochs 10000 -cs 250 +``` +### Coauthor Physics +``` +python train.py --name physics --aug_params 0.1 0.4 0.4 0.1 --layers 256 128 --pred_hid 512 --lr 0.00001 --epochs 10000 -cs 250 +``` + +## Performance +The results on the five datasets are shown in the table below.
+ +| |Wiki-CS|Computers|Photo |CS |Physics| |------ |------ |---------|---------|-----|-------| |Paper |79.98 |90.34 |93.17 |93.31|95.73 | |Namkyeong |79.50 |88.21 |92.76 |92.49|94.89 | |CogDL |79.76 |88.06 |92.91 |93.05|95.46 | * Hyperparameters are taken from the original paper. + diff --git a/examples/bgrl/data.py b/examples/bgrl/data.py new file mode 100644 index 00000000..9d428e15 --- /dev/null +++ b/examples/bgrl/data.py @@ -0,0 +1,100 @@ + +import sys +import os +import torch +import torch.nn.functional as F +import numpy as np +import scipy.sparse as sp +from itertools import chain +from cogdl.data import Graph +from cogdl.utils.graph_utils import to_undirected, remove_self_loops + +import utils +import json + + +def process_npz(path): + with np.load(path) as f: + x = sp.csr_matrix((f['attr_data'], f['attr_indices'], f['attr_indptr']), f['attr_shape']).todense() + x = torch.from_numpy(x).to(torch.float) + x[x > 0] = 1 + + adj = sp.csr_matrix((f['adj_data'], f['adj_indices'], f['adj_indptr']), + f['adj_shape']).tocoo() + row = torch.from_numpy(adj.row).to(torch.long) + col = torch.from_numpy(adj.col).to(torch.long) + edge_index = torch.stack([row, col], dim=0) + edge_index, _ = remove_self_loops(edge_index) + edge_index = to_undirected(edge_index, num_nodes=x.size(0)) + + y = torch.from_numpy(f['labels']).to(torch.long) + + return Graph(x=x, edge_index=edge_index, y=y) + + +def process_json(path): + with open(path, 'r') as f: + data = json.load(f) + + x = torch.tensor(data['features'], dtype=torch.float) + y = torch.tensor(data['labels'], dtype=torch.long) + + edges = [[(i, j) for j in js] for i, js in enumerate(data['links'])] + edges = list(chain(*edges)) + edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous() + edge_index = to_undirected(edge_index, num_nodes=x.size(0)) + + train_mask = torch.tensor(data['train_masks'], dtype=torch.bool) + train_mask = train_mask.t().contiguous() + + val_mask = torch.tensor(data['val_masks'], dtype=torch.bool) + val_mask = val_mask.t().contiguous() + + test_mask = torch.tensor(data['test_mask'], dtype=torch.bool) + + stopping_mask = torch.tensor(data['stopping_masks'], dtype=torch.bool) + stopping_mask = stopping_mask.t().contiguous() + + return Graph( + x=x, + y=y, + edge_index=edge_index, + train_mask=train_mask, + val_mask=val_mask, + test_mask=test_mask, + stopping_mask=stopping_mask + ) + + +def normalize_feature(data): + feature = data.x + feature = feature - feature.min() + data.x = feature / feature.sum(dim=-1, keepdim=True).clamp_(min=1.)
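+# A quick worked example of normalize_feature() (illustrative values): the global minimum is
+# subtracted so features are non-negative, then each row is divided by its sum. For
+# x = [[0., 2., 2.], [1., 0., 3.]] the row sums are 4 and 4, giving
+# [[0., .5, .5], [.25, 0., .75]]; the clamp_(min=1.) keeps all-zero rows from being divided by 0.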
+ + +def get_data(dataset): + dataset_filepath = { + "photo": "./data/Photo/raw/amazon_electronics_photo.npz", + "computers": "./data/Computers/raw/amazon_electronics_computers.npz", + "cs": "./data/CS/raw/ms_academic_cs.npz", + "physics": "./data/Physics/raw/ms_academic_phy.npz", + "WikiCS": "./data/WikiCS/raw/data.json" + } + assert dataset in dataset_filepath + filepath = dataset_filepath[dataset] + if dataset in ['WikiCS']: + data = process_json(filepath) + normalize_feature(data) + std, mean = torch.std_mean(data.x, dim=0, unbiased=False) + data.x = (data.x - mean) / std + data.edge_index = to_undirected(data.edge_index) + else: + data = process_npz(filepath) + normalize_feature(data) + + data.add_remaining_self_loops() + data.sym_norm() + + data = utils.create_masks(data=data) + return data + diff --git a/examples/bgrl/models.py b/examples/bgrl/models.py new file mode 100644 index 00000000..7816f168 --- /dev/null +++ b/examples/bgrl/models.py @@ -0,0 +1,130 @@ +from cogdl.layers import GCNLayer + +import torch.nn.functional as F +import torch.nn as nn +import torch + +import numpy as np + +import copy + +""" +The following code is borrowed from BYOL, SelfGNN +and slightly modified for BGRL +""" + + +class EMA: + def __init__(self, beta, epochs): + super().__init__() + self.beta = beta + self.step = 0 + self.total_steps = epochs + + def update_average(self, old, new): + if old is None: + return new + beta = 1 - (1 - self.beta) * (np.cos(np.pi * self.step / self.total_steps) + 1) / 2.0 + self.step += 1 + return old * beta + (1 - beta) * new + + +def loss_fn(x, y): + x = F.normalize(x, dim=-1, p=2) + y = F.normalize(y, dim=-1, p=2) + return 2 - 2 * (x * y).sum(dim=-1) + + +def update_moving_average(ema_updater, ma_model, current_model): + for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()): + old_weight, up_weight = ma_params.data, current_params.data + ma_params.data = ema_updater.update_average(old_weight, up_weight) + + +def set_requires_grad(model, val): + for p in model.parameters(): + p.requires_grad = val + + +class Encoder(nn.Module): + + def __init__(self, layer_config, dropout=None, project=False, **kwargs): + super().__init__() + + self.conv1 = GCNLayer(layer_config[0], layer_config[1], bias=False, norm=None) + self.bn1 = nn.BatchNorm1d(layer_config[1], momentum=0.99) + self.prelu1 = nn.PReLU() + self.conv2 = GCNLayer(layer_config[1], layer_config[2], bias=False, norm=None) + self.bn2 = nn.BatchNorm1d(layer_config[2], momentum=0.99) + self.prelu2 = nn.PReLU() + + def forward(self, x, graph, edge_weight=None): + + # x = self.conv1(x, edge_index, edge_weight=edge_weight) + x = self.conv1(graph, x) + x = self.prelu1(self.bn1(x)) + # x = self.conv2(x, edge_index, edge_weight=edge_weight) + x = self.conv2(graph, x) + x = self.prelu2(self.bn2(x)) + + return x + + +def init_weights(m): + if type(m) == nn.Linear: + torch.nn.init.xavier_uniform_(m.weight) + m.bias.data.fill_(0.01) + + +class BGRL(nn.Module): + + def __init__(self, layer_config, pred_hid, dropout=0.0, moving_average_decay=0.99, epochs=1000, **kwargs): + super().__init__() + self.student_encoder = Encoder(layer_config=layer_config, dropout=dropout, **kwargs) + self.teacher_encoder = copy.deepcopy(self.student_encoder) + set_requires_grad(self.teacher_encoder, False) + self.teacher_ema_updater = EMA(moving_average_decay, epochs) + rep_dim = layer_config[-1] + self.student_predictor = nn.Sequential(nn.Linear(rep_dim, pred_hid), nn.PReLU(), nn.Linear(pred_hid, rep_dim)) + 
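+        # The predictor acts only on the online (student) branch; the teacher encoder is kept
+        # frozen (requires_grad=False), is updated only through the EMA in update_moving_average(),
+        # and its outputs are detached in forward(). This BYOL-style asymmetry is what keeps the
+        # bootstrapped objective from collapsing to a constant representation.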
self.student_predictor.apply(init_weights) + + def reset_moving_average(self): + del self.teacher_encoder + self.teacher_encoder = None + + def update_moving_average(self): + assert self.teacher_encoder is not None, 'teacher encoder has not been created yet' + update_moving_average(self.teacher_ema_updater, self.teacher_encoder, self.student_encoder) + + def forward(self, x1, x2, graph_v1, graph_v2, edge_weight_v1=None, edge_weight_v2=None): + v1_student = self.student_encoder(x=x1, graph=graph_v1, edge_weight=edge_weight_v1) + v2_student = self.student_encoder(x=x2, graph=graph_v2, edge_weight=edge_weight_v2) + + v1_pred = self.student_predictor(v1_student) + v2_pred = self.student_predictor(v2_student) + + with torch.no_grad(): + v1_teacher = self.teacher_encoder(x=x1, graph=graph_v1, edge_weight=edge_weight_v1) + v2_teacher = self.teacher_encoder(x=x2, graph=graph_v2, edge_weight=edge_weight_v2) + + loss1 = loss_fn(v1_pred, v2_teacher.detach()) + loss2 = loss_fn(v2_pred, v1_teacher.detach()) + + loss = loss1 + loss2 + return v1_student, v2_student, loss.mean() + + +class LogisticRegression(nn.Module): + def __init__(self, num_dim, num_class): + super().__init__() + self.linear = nn.Linear(num_dim, num_class) + torch.nn.init.xavier_uniform_(self.linear.weight.data) + self.linear.bias.data.fill_(0.0) + self.cross_entropy = nn.CrossEntropyLoss() + + def forward(self, x, y): + + logits = self.linear(x) + loss = self.cross_entropy(logits, y) + + return logits, loss \ No newline at end of file diff --git a/examples/bgrl/train.py b/examples/bgrl/train.py new file mode 100644 index 00000000..404aa09b --- /dev/null +++ b/examples/bgrl/train.py @@ -0,0 +1,185 @@ +''' +This code is borrowed from https://github.com/Namkyeong/BGRL_Pytorch +''' + +import numpy as np + +import torch + + +import models +import utils +import data +import os +import sys
 +import warnings + +from torch import optim +from tensorboardX import SummaryWriter + +from sklearn import metrics +from sklearn.linear_model import LogisticRegression +from sklearn.model_selection import GridSearchCV, ShuffleSplit, train_test_split +from sklearn.multiclass import OneVsRestClassifier +from sklearn.preprocessing import OneHotEncoder, normalize + +warnings.filterwarnings("ignore") +torch.manual_seed(0) + + +class ModelTrainer: + + def __init__(self, args): + self._args = args + + self._init() + self.writer = SummaryWriter(log_dir="saved/BGRL_dataset({})".format(args.name)) + + def _init(self): + args = self._args + os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device) + # CUDA_VISIBLE_DEVICES above already restricts visibility to the chosen GPU, so it is always cuda:0 here + self._device = "cuda:0" if torch.cuda.is_available() else "cpu" + # self._dataset = data.Dataset(root=args.root, name=args.name)[0] + self._dataset = data.get_data(args.name) + print(f"Data: {self._dataset}") + hidden_layers = [int(dim) for dim in args.layers] + layers = [self._dataset.x.shape[1]] + hidden_layers + self._model = models.BGRL(layer_config=layers, pred_hid=args.pred_hid, dropout=args.dropout, epochs=args.epochs).to(self._device) + print(self._model) + + self._optimizer = optim.AdamW(params=self._model.parameters(), lr=args.lr, weight_decay=1e-5) + + # learning rate schedule: linear warmup followed by cosine decay + def lr_scheduler(epoch): + if epoch <= args.warmup_epochs: + return epoch / args.warmup_epochs + else: + return (1 + np.cos((epoch - args.warmup_epochs) * np.pi / (self._args.epochs - args.warmup_epochs))) * 0.5 + # lr_scheduler = lambda epoch: epoch / args.warmup_epochs if epoch <= args.warmup_epochs \ + # else ( 1 + np.cos((epoch - args.warmup_epochs) * np.pi / (self._args.epochs
- args.warmup_epochs))) * 0.5 + + self._scheduler = optim.lr_scheduler.LambdaLR(self._optimizer, lr_lambda=lr_scheduler) + + def train(self): + # get initial test results + print(self._args) + print("start training!") + + print("Initial Evaluation...") + self.infer_embeddings() + test_best, test_std_best = self.evaluate() + print("test: {:.4f}".format(test_best)) + + # start training + self._model.train() + for epoch in range(self._args.epochs): + + self._dataset.to(self._device) + + augmentation = utils.Augmentation(float(self._args.aug_params[0]), float(self._args.aug_params[1]), float(self._args.aug_params[2]), float(self._args.aug_params[3])) + view1, view2 = augmentation._feature_masking(self._dataset, self._device) + + v1_output, v2_output, loss = self._model( + x1=view1.x, x2=view2.x, graph_v1=view1, graph_v2=view2, + edge_weight_v1=view1.edge_attr, edge_weight_v2=view2.edge_attr) + + self._optimizer.zero_grad() + loss.backward() + self._optimizer.step() + self._scheduler.step() + self._model.update_moving_average() + sys.stdout.write('\rEpoch {}/{}, loss {:.4f}, lr {}'.format(epoch + 1, self._args.epochs, loss.data, self._optimizer.param_groups[0]['lr'])) + sys.stdout.flush() + + if (epoch + 1) % self._args.cache_step == 0: + print("") + print("\nEvaluating {}th epoch..".format(epoch + 1)) + + self.infer_embeddings() + test_acc, test_std = self.evaluate() + + self.writer.add_scalar("stats/learning_rate", self._optimizer.param_groups[0]["lr"] , epoch + 1) + self.writer.add_scalar("accs/test_acc", test_acc, epoch + 1) + print("test: {:.4f} \n".format(test_acc)) + + print() + print("Training Done!") + + def infer_embeddings(self): + + self._model.train(False) + self._embeddings = self._labels = None + + self._dataset.to(self._device) + v1_output, v2_output, _ = self._model( + x1=self._dataset.x, x2=self._dataset.x, + graph_v1=self._dataset, + graph_v2=self._dataset, + edge_weight_v1=self._dataset.edge_attr, + edge_weight_v2=self._dataset.edge_attr) + emb = v1_output.detach() + y = self._dataset.y.detach() + if self._embeddings is None: + self._embeddings, self._labels = emb, y + else: + self._embeddings = torch.cat([self._embeddings, emb]) + self._labels = torch.cat([self._labels, y]) + + def evaluate(self): + """ + Used for producing the results of Experiment 3.2 in the BGRL paper. 
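+        Embeddings are l2-normalized and classified with a one-vs-rest logistic regression;
+        for each of 20 random splits the regularization strength C is grid-searched over
+        2^-10 ... 2^10 on the validation mask, and the mean and standard deviation of the
+        resulting test accuracies are returned.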
+ """ + test_accs = [] + + self._embeddings = self._embeddings.cpu().numpy() + self._labels = self._labels.cpu().numpy() + self._dataset.to(torch.device("cpu")) + + one_hot_encoder = OneHotEncoder(categories='auto', sparse=False) + self._labels = one_hot_encoder.fit_transform(self._labels.reshape(-1, 1)).astype(np.bool) + + self._embeddings = normalize(self._embeddings, norm='l2') + + for i in range(20): + + self._train_mask = self._dataset.train_mask[i] + self._dev_mask = self._dataset.val_mask[i] + if self._args.name in ["WikiCS"]: + self._test_mask = self._dataset.test_mask + else: + self._test_mask = self._dataset.test_mask[i] + + # grid search with one-vs-rest classifiers + best_test_acc, best_acc = 0, 0 + + for c in 2.0 ** np.arange(-10, 11): + clf = OneVsRestClassifier(LogisticRegression(solver='liblinear', C=c)) + clf.fit(self._embeddings[self._train_mask], self._labels[self._train_mask]) + + y_pred = clf.predict_proba(self._embeddings[self._dev_mask]) + y_pred = np.argmax(y_pred, axis=1) + y_pred = one_hot_encoder.transform(y_pred.reshape(-1, 1)).astype(np.bool) + val_acc = metrics.accuracy_score(self._labels[self._dev_mask], y_pred) + if val_acc > best_acc: + best_acc = val_acc + y_pred = clf.predict_proba(self._embeddings[self._test_mask]) + y_pred = np.argmax(y_pred, axis=1) + y_pred = one_hot_encoder.transform(y_pred.reshape(-1, 1)).astype(np.bool) + best_test_acc = metrics.accuracy_score(self._labels[self._test_mask], y_pred) + test_accs.append(best_test_acc) + return np.mean(test_accs), np.std(test_accs) + + +def train_eval(args): + trainer = ModelTrainer(args) + trainer.train() + trainer.writer.close() + + +def main(): + args = utils.parse_args() + train_eval(args) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/bgrl/utils.py b/examples/bgrl/utils.py new file mode 100644 index 00000000..014ca224 --- /dev/null +++ b/examples/bgrl/utils.py @@ -0,0 +1,135 @@ +from cogdl.utils import dropout_adj + +import os.path as osp +import os + +import argparse + +import numpy as np + +import torch + +""" +The Following code is borrowed from SelfGNN +""" + + +class Augmentation: + + def __init__(self, p_f1=0.2, p_f2=0.1, p_e1=0.2, p_e2=0.3): + """ + two simple graph augmentation functions --> "Node feature masking" and "Edge masking" + Random binary node feature mask following Bernoulli distribution with parameter p_f + Random binary edge mask following Bernoulli distribution with parameter p_e + """ + self.p_f1 = p_f1 + self.p_f2 = p_f2 + self.p_e1 = p_e1 + self.p_e2 = p_e2 + self.method = "BGRL" + + def _feature_masking(self, data, device): + feat_mask1 = torch.FloatTensor(data.x.shape[1]).uniform_() > self.p_f1 + feat_mask2 = torch.FloatTensor(data.x.shape[1]).uniform_() > self.p_f2 + feat_mask1, feat_mask2 = feat_mask1.to(device), feat_mask2.to(device) + x1, x2 = data.x.clone(), data.x.clone() + x1, x2 = x1 * feat_mask1, x2 * feat_mask2 + + edge_index1, edge_attr1 = dropout_adj(data.edge_index, data.edge_attr, drop_rate=self.p_e1) + edge_index2, edge_attr2 = dropout_adj(data.edge_index, data.edge_attr, drop_rate=self.p_e2) + + new_data1, new_data2 = data.clone(), data.clone() + new_data1.x, new_data2.x = x1, x2 + new_data1.edge_index, new_data2.edge_index = edge_index1, edge_index2 + new_data1.edge_attr , new_data2.edge_attr = edge_attr1, edge_attr2 + + return new_data1, new_data2 + + def __call__(self, data): + + return self._feature_masking(data) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--root", 
"-r", type=str, default="data", + help="Path to data directory, where all the datasets will be placed. Default is 'data'") + parser.add_argument("--name", "-n", type=str, default="WikiCS", + help="Name of the dataset. Supported names are: cora, citeseer, pubmed, photo, computers, cs, WikiCS, and physics") + parser.add_argument("--layers", "-l", nargs="+", default=[ + 512, 256], help="The number of units of each layer of the GNN. Default is [512, 128]") + parser.add_argument("--pred_hid", '-ph', type=int, + default=512, help="The number of hidden units of layer of the predictor. Default is 512") + parser.add_argument("--init-parts", "-ip", type=int, default=1, + help="The number of initial partitions. Default is 1. Applicable for ClusterSelfGNN") + parser.add_argument("--final-parts", "-fp", type=int, default=1, + help="The number of final partitions. Default is 1. Applicable for ClusterSelfGNN") + parser.add_argument("--aug_params", "-p", nargs="+", default=[ + 0.1, 0.2, 0.4, 0.1], help="Hyperparameters for augmentation (p_f1, p_f2, p_e1, p_e2). Default is [0.2, 0.1, 0.2, 0.3]") + parser.add_argument("--lr", '-lr', type=float, default=0.00001, + help="Learning rate. Default is 0.0001.") + parser.add_argument("--warmup_epochs", '-we', type=int, default=1000, + help="Warmup epochs. Default is 1000.") + parser.add_argument("--dropout", "-do", type=float, + default=0.0, help="Dropout rate. Default is 0.2") + parser.add_argument("--cache-step", '-cs', type=int, default=10, + help="The step size to cache the model, that is, every cache_step the model is persisted. Default is 100.") + parser.add_argument("--epochs", '-e', type=int, + default=20, help="The number of epochs") + parser.add_argument("--device", '-d', type=int, + default=3, help="GPU to use") + return parser.parse_args() + + +def create_dirs(dirs): + for dir_tree in dirs: + sub_dirs = dir_tree.split("/") + path = "" + for sub_dir in sub_dirs: + path = osp.join(path, sub_dir) + os.makedirs(path, exist_ok=True) + + +def create_masks(data): + """ + Splits data into training, validation, and test splits in a stratified manner if + it is not already splitted. Each split is associated with a mask vector, which + specifies the indices for that split. The data will be modified in-place + :param data: Data object + :return: The modified data + """ + if not hasattr(data, "val_mask"): + + data.train_mask = data.dev_mask = data.test_mask = None + + for i in range(20): + labels = data.y.numpy() + dev_size = int(labels.shape[0] * 0.1) + test_size = int(labels.shape[0] * 0.8) + + perm = np.random.permutation(labels.shape[0]) + test_index = perm[:test_size] + dev_index = perm[test_size:test_size + dev_size] + + data_index = np.arange(labels.shape[0]) + test_mask = torch.tensor(np.in1d(data_index, test_index), dtype=torch.bool) + dev_mask = torch.tensor(np.in1d(data_index, dev_index), dtype=torch.bool) + train_mask = ~(dev_mask + test_mask) + test_mask = test_mask.reshape(1, -1) + dev_mask = dev_mask.reshape(1, -1) + train_mask = train_mask.reshape(1, -1) + + if hasattr(data, "train_mask") and data.train_mask is not None: + data.train_mask = torch.cat((data.train_mask, train_mask), dim=0) + data.val_mask = torch.cat((data.val_mask, dev_mask), dim=0) + data.test_mask = torch.cat((data.test_mask, test_mask), dim=0) + else: + data.train_mask = train_mask + data.val_mask = dev_mask + data.test_mask = test_mask + + else : + data.train_mask = data.train_mask.T + data.val_mask = data.val_mask.T + + return data \ No newline at end of file