BGRL with CogDL #408

Merged

merged 3 commits on Jan 9, 2023
46 changes: 46 additions & 0 deletions examples/bgrl/README.md
@@ -0,0 +1,46 @@
# Large-Scale Representation Learning on Graphs via Bootstrapping (BGRL) with CogDL
This example implements BGRL with CogDL for graph representation learning. The authors' implementation can be found [here](https://github.com/nerdslab/bgrl); another PyTorch implementation by [Namkyeong](https://github.com/Namkyeong/BGRL_Pytorch) can also be used as a reference.

## Hyperparameters
The following optional parameters can be passed to the training script; a sketch of the corresponding command-line declarations follows the list.

`layers`: the output dimension of each GNN layer.

`pred_hid`: the hidden dimension of the predictor module.

`aug_params`: the corruption ratios used for graph augmentation (feature masking and edge dropping for each of the two views).
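
A minimal sketch of how these options might be declared with `argparse` (hypothetical declarations mirroring the commands below; the actual definitions live in `train.py`):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--name", type=str, default="WikiCS")
parser.add_argument("--aug_params", type=float, nargs=4,
                    help="corruption ratios for the two augmented views")
parser.add_argument("--layers", type=int, nargs="+", default=[512, 256],
                    help="output dimension of each GNN layer")
parser.add_argument("--pred_hid", type=int, default=512,
                    help="hidden dimension of the predictor module")
parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--epochs", type=int, default=10000)
parser.add_argument("-cs", type=int, default=250)
args = parser.parse_args()
```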

## Usage
You can download the datasets [here](https://pan.baidu.com/s/15RyvXD2G-xwGM9jrT7IDLQ?pwd=85vv) and put them under `./data`. The experiments below use the given hyperparameters and can be reproduced with the following commands.

### Wiki-CS
```
python train.py --name WikiCS --aug_params 0.2 0.1 0.2 0.3 --layers 512 256 --pred_hid 512 --lr 0.0001 --epochs 10000 -cs 250
```
### Amazon Computers
```
python train.py --name computers --aug_params 0.2 0.1 0.5 0.4 --layers 256 128 --pred_hid 512 --lr 0.0005 --epochs 10000 -cs 250
```
### Amazon Photo
```
python train.py --name photo --aug_params 0.1 0.2 0.4 0.1 --layers 512 256 --pred_hid 512 --lr 0.0001 --epochs 10000 -cs 250
```
### Coauthor CS
```
python train.py --name cs --aug_params 0.3 0.4 0.3 0.2 --layers 512 256 --pred_hid 512 --lr 0.00001 --epochs 10000 -cs 250
```
### Coauthor Physics
```
python train.py --name physics --aug_params 0.1 0.4 0.4 0.1 --layers 256 128 --pred_hid 512 --lr 0.00001 --epochs 10000 -cs 250
```

## Performance
The results on the five datasets are shown in the table below.

| |Wiki-CS|Computers|Photo |CS |Physics|
|------ |------ |---------|---------|-----|-------|
|Paper |79.98 |90.34 |93.17 |93.31|95.73 |
|Namkyeong |79.50 |88.21 |92.76 |92.49|94.89 |
|CogDL |79.76 |88.06 |92.91 |93.05|95.46 |
* Hyperparameters are taken from the original paper.

100 changes: 100 additions & 0 deletions examples/bgrl/data.py
@@ -0,0 +1,100 @@

import json
from itertools import chain

import numpy as np
import scipy.sparse as sp
import torch

from cogdl.data import Graph
from cogdl.utils.graph_utils import to_undirected, remove_self_loops

import utils


def process_npz(path):
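    """Load a graph stored as a SciPy sparse .npz archive (Amazon/Coauthor datasets)."""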
with np.load(path) as f:
x = sp.csr_matrix((f['attr_data'], f['attr_indices'], f['attr_indptr']), f['attr_shape']).todense()
x = torch.from_numpy(x).to(torch.float)
x[x > 0] = 1

adj = sp.csr_matrix((f['adj_data'], f['adj_indices'], f['adj_indptr']),
f['adj_shape']).tocoo()
row = torch.from_numpy(adj.row).to(torch.long)
col = torch.from_numpy(adj.col).to(torch.long)
edge_index = torch.stack([row, col], dim=0)
edge_index, _ = remove_self_loops(edge_index)
edge_index = to_undirected(edge_index, num_nodes=x.size(0))

y = torch.from_numpy(f['labels']).to(torch.long)

return Graph(x=x, edge_index=edge_index, y=y)


def process_json(path):
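    """Load the WikiCS graph from its raw JSON file, including the split masks."""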
with open(path, 'r') as f:
data = json.load(f)

x = torch.tensor(data['features'], dtype=torch.float)
y = torch.tensor(data['labels'], dtype=torch.long)

edges = [[(i, j) for j in js] for i, js in enumerate(data['links'])]
edges = list(chain(*edges))
edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
edge_index = to_undirected(edge_index, num_nodes=x.size(0))

train_mask = torch.tensor(data['train_masks'], dtype=torch.bool)
train_mask = train_mask.t().contiguous()

val_mask = torch.tensor(data['val_masks'], dtype=torch.bool)
val_mask = val_mask.t().contiguous()

test_mask = torch.tensor(data['test_mask'], dtype=torch.bool)

stopping_mask = torch.tensor(data['stopping_masks'], dtype=torch.bool)
stopping_mask = stopping_mask.t().contiguous()

return Graph(
x=x,
y=y,
edge_index=edge_index,
train_mask=train_mask,
val_mask=val_mask,
test_mask=test_mask,
stopping_mask=stopping_mask
)


def normalize_feature(data):
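    # Shift features to be non-negative, then L1-normalize each row.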
feature = data.x
feature = feature - feature.min()
data.x = feature / feature.sum(dim=-1, keepdim=True).clamp_(min=1.)


def get_data(dataset):
dataset_filepath = {
"photo": "./data/Photo/raw/amazon_electronics_photo.npz",
"computers": "./data/Computers/raw/amazon_electronics_computers.npz",
"cs": "./data/CS/raw/ms_academic_cs.npz",
"physics": "./data/Physics/raw/ms_academic_phy.npz",
"WikiCS": "./data/WikiCS/raw/data.json"
}
assert dataset in dataset_filepath
filepath = dataset_filepath[dataset]
if dataset in ['WikiCS']:
data = process_json(filepath)
normalize_feature(data)
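        # Standardize WikiCS features to zero mean and unit variance.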
std, mean = torch.std_mean(data.x, dim=0, unbiased=False)
data.x = (data.x - mean) / std
data.edge_index = to_undirected(data.edge_index)
else:
data = process_npz(filepath)
normalize_feature(data)

data.add_remaining_self_loops()
data.sym_norm()

data = utils.create_masks(data=data)
return data
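
# Minimal usage sketch (assumes the raw dataset files are already under ./data):
if __name__ == "__main__":
    data = get_data("WikiCS")
    print(data)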

130 changes: 130 additions & 0 deletions examples/bgrl/models.py
@@ -0,0 +1,130 @@
import copy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from cogdl.layers import GCNLayer

"""
The following code is borrowed from BYOL, SelfGNN
and slightly modified for BGRL
"""


class EMA:
def __init__(self, beta, epochs):
super().__init__()
self.beta = beta
self.step = 0
self.total_steps = epochs

def update_average(self, old, new):
if old is None:
return new
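        # Anneal beta from its base value toward 1 with a cosine schedule over
        # training, as in BYOL.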
beta = 1 - (1 - self.beta) * (np.cos(np.pi * self.step / self.total_steps) + 1) / 2.0
self.step += 1
return old * beta + (1 - beta) * new


def loss_fn(x, y):
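    # BYOL-style loss: 2 - 2 * cosine_similarity(x, y), i.e. the squared L2
    # distance between the L2-normalized embeddings.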
x = F.normalize(x, dim=-1, p=2)
y = F.normalize(y, dim=-1, p=2)
return 2 - 2 * (x * y).sum(dim=-1)


def update_moving_average(ema_updater, ma_model, current_model):
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
old_weight, up_weight = ma_params.data, current_params.data
ma_params.data = ema_updater.update_average(old_weight, up_weight)


def set_requires_grad(model, val):
for p in model.parameters():
p.requires_grad = val


class Encoder(nn.Module):

def __init__(self, layer_config, dropout=None, project=False, **kwargs):
super().__init__()

self.conv1 = GCNLayer(layer_config[0], layer_config[1], bias=False, norm=None)
self.bn1 = nn.BatchNorm1d(layer_config[1], momentum=0.99)
self.prelu1 = nn.PReLU()
self.conv2 = GCNLayer(layer_config[1], layer_config[2], bias=False, norm=None)
self.bn2 = nn.BatchNorm1d(layer_config[2], momentum=0.99)
self.prelu2 = nn.PReLU()

def forward(self, x, graph, edge_weight=None):

        # CogDL's GCNLayer takes (graph, x) rather than (x, edge_index, edge_weight).
        x = self.conv1(graph, x)
        x = self.prelu1(self.bn1(x))
        x = self.conv2(graph, x)
x = self.prelu2(self.bn2(x))

return x


def init_weights(m):
if type(m) == nn.Linear:
torch.nn.init.xavier_uniform_(m.weight)
m.bias.data.fill_(0.01)


class BGRL(nn.Module):

def __init__(self, layer_config, pred_hid, dropout=0.0, moving_average_decay=0.99, epochs=1000, **kwargs):
super().__init__()
self.student_encoder = Encoder(layer_config=layer_config, dropout=dropout, **kwargs)
self.teacher_encoder = copy.deepcopy(self.student_encoder)
set_requires_grad(self.teacher_encoder, False)
self.teacher_ema_updater = EMA(moving_average_decay, epochs)
rep_dim = layer_config[-1]
self.student_predictor = nn.Sequential(nn.Linear(rep_dim, pred_hid), nn.PReLU(), nn.Linear(pred_hid, rep_dim))
self.student_predictor.apply(init_weights)

def reset_moving_average(self):
del self.teacher_encoder
self.teacher_encoder = None

def update_moving_average(self):
assert self.teacher_encoder is not None, 'teacher encoder has not been created yet'
update_moving_average(self.teacher_ema_updater, self.teacher_encoder, self.student_encoder)

def forward(self, x1, x2, graph_v1, graph_v2, edge_weight_v1=None, edge_weight_v2=None):
v1_student = self.student_encoder(x=x1, graph=graph_v1, edge_weight=edge_weight_v1)
v2_student = self.student_encoder(x=x2, graph=graph_v2, edge_weight=edge_weight_v2)

v1_pred = self.student_predictor(v1_student)
v2_pred = self.student_predictor(v2_student)

with torch.no_grad():
v1_teacher = self.teacher_encoder(x=x1, graph=graph_v1, edge_weight=edge_weight_v1)
v2_teacher = self.teacher_encoder(x=x2, graph=graph_v2, edge_weight=edge_weight_v2)

loss1 = loss_fn(v1_pred, v2_teacher.detach())
loss2 = loss_fn(v2_pred, v1_teacher.detach())

loss = loss1 + loss2
return v1_student, v2_student, loss.mean()


class LogisticRegression(nn.Module):
def __init__(self, num_dim, num_class):
super().__init__()
self.linear = nn.Linear(num_dim, num_class)
torch.nn.init.xavier_uniform_(self.linear.weight.data)
self.linear.bias.data.fill_(0.0)
self.cross_entropy = nn.CrossEntropyLoss()

def forward(self, x, y):

logits = self.linear(x)
loss = self.cross_entropy(logits, y)

return logits, loss
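
# Minimal training-step sketch: assumes `graph_v1` and `graph_v2` are two
# augmented views of the same graph (the actual training loop lives in train.py).
def train_step(model, optimizer, graph_v1, graph_v2):
    model.train()
    optimizer.zero_grad()
    _, _, loss = model(graph_v1.x, graph_v2.x, graph_v1, graph_v2)
    loss.backward()
    optimizer.step()
    # Update the teacher encoder by exponential moving average after each step.
    model.update_moving_average()
    return loss.item()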