Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add LPFormer model and example #9956

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added support for the `LPFormer` model ([#9956](https://github.com/pyg-team/pytorch_geometric/pull/9956))
- Added support for weighted `LinkPredRecall` metric ([#9947](https://github.com/pyg-team/pytorch_geometric/pull/9947))
- Added support for weighted `LinkPredNDCG` metric ([#9945](https://github.com/pyg-team/pytorch_geometric/pull/9945))
- Added `LinkPredMetricCollection` ([#9941](https://github.com/pyg-team/pytorch_geometric/pull/9941))
Expand Down
195 changes: 195 additions & 0 deletions examples/lpformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
import random
from argparse import ArgumentParser
from collections import defaultdict

import numpy as np
import torch
from ogb.linkproppred import Evaluator, PygLinkPropPredDataset
from torch.utils.data import DataLoader
from torch_sparse import SparseTensor
from tqdm import tqdm

from torch_geometric.nn.models import LPFormer

# Command-line configuration. Defaults follow an ogbl-ppa setup.
parser = ArgumentParser()
parser.add_argument('--data_name', type=str, default='ogbl-ppa')
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--runs', help="# random seeds to run over", type=int,
                    default=5)
parser.add_argument('--batch_size', type=int, default=32768)
parser.add_argument('--hidden_channels', type=int, default=64)
parser.add_argument('--gnn_layers', type=int, default=3)
parser.add_argument('--dropout', help="Applies to GNN and Transformer",
                    type=float, default=0.1)
parser.add_argument('--device', type=str, default='cuda')
parser.add_argument('--eps', help="PPR precision", type=float, default=5e-5)
parser.add_argument('--thresholds',
                    help="List of cn, 1-hop, >1-hop (in that order)",
                    nargs="+", default=[0, 1e-4, 1e-2])
args = parser.parse_args()

device = torch.device(args.device)

# Load the OGB link-prediction dataset and move graph + features to `device`:
dataset = PygLinkPropPredDataset(name=args.data_name)
data = dataset[0].to(device)
data.edge_index = data.edge_index.to(device)

# Node features may be absent for some OGB datasets; cast to float when given.
if hasattr(data, 'x') and data.x is not None:
    data.x = data.x.to(device).to(torch.float)

# Positive/negative edge splits provided by OGB (train has no negatives;
# negatives for training are sampled on the fly in `train_epoch`).
split_edge = dataset.get_edge_split()
split_data = {
    "train_pos": split_edge['train']['edge'].to(device),
    "valid_pos": split_edge['valid']['edge'].to(device),
    "valid_neg": split_edge['valid']['edge_neg'].to(device),
    "test_pos": split_edge['test']['edge'].to(device),
    "test_neg": split_edge['test']['edge_neg'].to(device)
}

# Fall back to unit edge weights when the dataset provides none.
if hasattr(data, 'edge_weight') and data.edge_weight is not None:
    edge_weight = data.edge_weight.to(torch.float)
    data.edge_weight = data.edge_weight.view(-1).to(torch.float)
else:
    edge_weight = torch.ones(data.edge_index.size(1)).to(device).float()

# Convert edge_index to SparseTensor for efficiency
adj_prop = SparseTensor.from_edge_index(
    data.edge_index, edge_weight.squeeze(-1),
    [data.num_nodes, data.num_nodes]).to(device)

evaluator_hit = Evaluator(name=args.data_name)

model = LPFormer(data.x.size(-1), args.hidden_channels,
                 num_gnn_layers=args.gnn_layers,
                 ppr_thresholds=args.thresholds, gnn_dropout=args.dropout,
                 transformer_dropout=args.dropout, gcn_cache=True).to(device)

# Get PPR matrix in sparse format
ppr_matrix = model.calc_sparse_ppr(data.edge_index, data.num_nodes,
                                   eps=args.eps).to(device)


def train_epoch():
    """Run one training epoch over all positive training edges.

    Positive edges of the current mini-batch are removed from the
    message-passing adjacency (a common strategy to avoid label leakage),
    and negative edges are sampled uniformly at random.

    Returns:
        float: The mean loss per positive example over the epoch.
    """
    model.train()
    train_pos = split_data['train_pos'].to(device)
    # Boolean mask over all positive training edges; entries belonging to
    # the current batch are zeroed out below so those edges are excluded
    # from message passing.
    adjt_mask = torch.ones(train_pos.size(0), dtype=torch.bool, device=device)

    total_loss = total_examples = 0
    loader = DataLoader(range(train_pos.size(0)), args.batch_size,
                        shuffle=True)

    for perm in tqdm(loader, "Epoch"):
        edges = train_pos[perm].t()

        # Mask positive input samples - Common strategy during training.
        adjt_mask[perm] = 0
        edge2keep = train_pos[adjt_mask, :]
        # NOTE: use attribute access (`data.num_nodes`) for consistency with
        # the adjacency construction above; item access can diverge for
        # PyG `Data` objects.
        masked_adj_prop = SparseTensor.from_edge_index(
            edge2keep.t(), sparse_sizes=(data.num_nodes,
                                         data.num_nodes)).to_device(device)
        masked_adj_prop = masked_adj_prop.to_symmetric()
        # Restore the mask for the next batch:
        adjt_mask[perm] = 1

        pos_out = model(edges, data.x, masked_adj_prop, ppr_matrix)
        # Small epsilon inside the log for numerical stability:
        pos_loss = -torch.log(torch.sigmoid(pos_out) + 1e-6).mean()

        # Trivial random sampling of negative edges (same shape as `edges`):
        neg_edges = torch.randint(0, data.num_nodes,
                                  (edges.size(0), edges.size(1)),
                                  dtype=torch.long, device=edges.device)

        neg_out = model(neg_edges, data.x, adj_prop, ppr_matrix)
        neg_loss = -torch.log(1 - torch.sigmoid(neg_out) + 1e-6).mean()

        loss = pos_loss + neg_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        num_examples = pos_out.size(0)
        total_loss += loss.item() * num_examples
        total_examples += num_examples

    return total_loss / total_examples


@torch.no_grad()
def test():
    """Score validation and test splits and return Hits@K for both."""
    # NOTE: Evaluation for ogbl-citation2 is different. See `train.py` in
    # https://github.com/HarryShomer/LPFormer/ for details, and for how to
    # evaluate under the HeaRT setting.
    model.eval()
    all_preds = {}

    for split_key, split_vals in split_data.items():
        if "train" in split_key:
            continue
        batch_preds = []
        loader = DataLoader(range(split_vals.size(0)), args.batch_size)
        for perm in loader:
            edges = split_vals[perm].t()
            logits = model(edges, data.x, adj_prop, ppr_matrix)
            batch_preds.append(torch.sigmoid(logits).cpu())
        all_preds[split_key] = torch.cat(batch_preds, dim=0)

    metric = f'hits@{evaluator_hit.K}'
    val_hits = evaluator_hit.eval({
        'y_pred_pos': all_preds['valid_pos'],
        'y_pred_neg': all_preds['valid_neg']
    })[metric]
    test_hits = evaluator_hit.eval({
        'y_pred_pos': all_preds['test_pos'],
        'y_pred_neg': all_preds['test_neg']
    })[metric]

    return val_hits, test_hits


def set_seeds(seed):
    """Seed the Python, NumPy and PyTorch RNGs for reproducibility.

    Args:
        seed (int): The random seed.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current one (no-op without CUDA):
    torch.cuda.manual_seed_all(seed)

# Train over args.runs seeds and average results.
# Best result for each run is chosen via validation performance.
val_perf_runs = []
test_perf_runs = []
for run in range(1, args.runs + 1):
    print("=" * 75)
    print(f"RUNNING run={run}")
    print("=" * 75)

    # Re-seed and re-initialize the model per run so runs are independent.
    set_seeds(run)
    model.reset_parameters()
    optimizer = torch.optim.Adam(list(model.parameters()), lr=args.lr)

    best_valid = 0
    best_valid_test = 0

    for epoch in range(1, 1 + args.epochs):
        loss = train_epoch()
        print(f"Epoch {epoch} Loss: {loss:.4f}\n")

        # Evaluate every 5 epochs; keep the test score obtained at the best
        # validation score (model selection on the validation split).
        if epoch % 5 == 0:
            print("Evaluating model...\n", flush=True)
            eval_val, eval_test = test()

            print(f"Valid Hits@{evaluator_hit.K} = {eval_val}")
            print(f"Test Hits@{evaluator_hit.K} = {eval_test}")

            if eval_val > best_valid:
                best_valid = eval_val
                best_valid_test = eval_test

    print(
        f"\nBest Performance:\n Valid={best_valid}\n Test={best_valid_test}")
    val_perf_runs.append(best_valid)
    test_perf_runs.append(best_valid_test)

if args.runs > 1:
    print("\n\n")
    print(f"Results over {args.runs} runs:")
    print(f" Valid = {np.mean(val_perf_runs)} +/- {np.std(val_perf_runs)}")
    print(f" Test = {np.mean(test_perf_runs)} +/- {np.std(test_perf_runs)}")
38 changes: 38 additions & 0 deletions test/nn/models/test_lpformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import torch

import torch_geometric.typing
from torch_geometric.nn import LPFormer
from torch_geometric.testing import withPackage
from torch_geometric.typing import SparseTensor
from torch_geometric.utils import to_undirected


@withPackage('numba')  # For ppr calculation
def test_lpformer():
    """Smoke-test :class:`LPFormer` with both dense and sparse adjacency."""
    model = LPFormer(16, 32, num_gnn_layers=2, num_transformer_layers=1)
    assert str(
        model
    ) == 'LPFormer(16, 32, num_gnn_layers=2, num_transformer_layers=1)'

    num_nodes = 20
    x = torch.randn(num_nodes, 16)
    # `torch.randint`'s upper bound is exclusive, so `num_nodes` (not
    # `num_nodes - 1`) lets every node index 0..num_nodes-1 be sampled:
    edges = torch.randint(0, num_nodes, (2, 110))
    edge_index, test_edges = edges[:, :100], edges[:, 100:]
    edge_index = to_undirected(edge_index)

    ppr_matrix = model.calc_sparse_ppr(edge_index, num_nodes, eps=1e-4)

    assert ppr_matrix.is_sparse
    assert ppr_matrix.size() == (num_nodes, num_nodes)
    assert ppr_matrix.sum().item() > 0

    # Test with dense edge_index:
    out = model(test_edges, x, edge_index, ppr_matrix)
    assert out.size() == (10, )

    # Test with sparse edge_index:
    if torch_geometric.typing.WITH_TORCH_SPARSE:
        adj = SparseTensor.from_edge_index(edge_index,
                                           sparse_sizes=(num_nodes,
                                                         num_nodes))
        out2 = model(test_edges, x, adj, ppr_matrix)
        assert out2.size() == (10, )
2 changes: 2 additions & 0 deletions torch_geometric/nn/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from .git_mol import GITMol
from .molecule_gpt import MoleculeGPT
from .glem import GLEM
from .lpformer import LPFormer
# Deprecated:
from torch_geometric.explain.algorithm.captum import (to_captum_input,
captum_output_to_dicts)
Expand Down Expand Up @@ -82,4 +83,5 @@
'GITMol',
'MoleculeGPT',
'GLEM',
'LPFormer',
]
Loading
Loading