[Feature] Update Unsup_Graphsage (#379)
QingFei1 authored Aug 22, 2022
1 parent 1be3266 commit c7f35aa
Showing 10 changed files with 219 additions and 120 deletions.
98 changes: 95 additions & 3 deletions cogdl/data/sampler.py
@@ -1,13 +1,16 @@

from typing import List
import os
import random
import numpy as np
import scipy.sparse as sp
import torch
import torch.utils.data

from cogdl.utils import remove_self_loops, row_normalization
from cogdl.data import Graph, DataLoader
from cogdl.utils import RandomWalker


class NeighborSampler(DataLoader):
@@ -16,8 +19,7 @@ def __init__(self, dataset, sizes: List[int], mask=None, **kwargs):
batch_size = kwargs["batch_size"]
else:
batch_size = 8

if isinstance(dataset.data, Graph):
self.dataset = NeighborSamplerDataset(dataset, sizes, batch_size, mask)
else:
self.dataset = dataset
@@ -34,6 +36,29 @@ def shuffle(self):
self.dataset.shuffle()


class UnsupNeighborSampler(DataLoader):
def __init__(self, dataset, sizes: List[int], mask=None, **kwargs):
if "batch_size" in kwargs:
batch_size = kwargs["batch_size"]
else:
batch_size = 8

if isinstance(dataset.data, Graph):
self.dataset = UnsupNeighborSamplerDataset(dataset, sizes, batch_size, mask)
else:
self.dataset = dataset
kwargs["batch_size"] = 1
kwargs["shuffle"] = False
kwargs["collate_fn"] = UnsupNeighborSampler.collate_fn
super(UnsupNeighborSampler, self).__init__(dataset=self.dataset, **kwargs)

@staticmethod
def collate_fn(data):
return data[0]

def shuffle(self):
self.dataset.shuffle()
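
As with NeighborSampler above, the outer DataLoader batch size is pinned to 1 and collate_fn simply unwraps the single element, because each item produced by the sampler dataset is already a full mini-batch. A minimal standalone sketch of that pattern (the toy PrebatchedDataset below is an illustration, not part of CogDL):

import torch
from torch.utils.data import DataLoader, Dataset

class PrebatchedDataset(Dataset):
    """Toy dataset whose every item is already a complete mini-batch."""
    def __len__(self):
        return 3

    def __getitem__(self, idx):
        return torch.arange(idx * 4, (idx + 1) * 4)

# batch_size=1 plus an unwrapping collate_fn hands each prebuilt batch through untouched.
loader = DataLoader(PrebatchedDataset(), batch_size=1, shuffle=False, collate_fn=lambda data: data[0])
for batch in loader:
    print(batch)  # tensor([0, 1, 2, 3]), tensor([4, 5, 6, 7]), ...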

class NeighborSamplerDataset(torch.utils.data.Dataset):
def __init__(self, dataset, sizes: List[int], batch_size: int, mask=None):
super(NeighborSamplerDataset, self).__init__()
@@ -51,6 +76,65 @@ def shuffle(self):
idx = torch.randperm(self.num_nodes)
self.node_idx = self.node_idx[idx]

def __len__(self):
return (self.num_nodes - 1) // self.batch_size + 1

def __getitem__(self, idx):
"""
Sample a subgraph with neighborhood sampling
Args:
idx: torch.Tensor / np.array
Target nodes
Returns:
if `sizes` is `[-1]`,
    (
        source_nodes_id: Tensor,
        sampled_edges: Tensor,
        (number_of_source_nodes, number_of_target_nodes): Tuple[int, int]
    )
otherwise,
    (
        target_nodes_id: Tensor,
        all_sampled_nodes_id: Tensor,
        sampled_adjs: List[Tuple[Tensor, Tensor, Tuple[int, int]]]
    )
"""
batch = self.node_idx[idx * self.batch_size : (idx + 1) * self.batch_size]
node_id = batch
adj_list = []
for size in self.sizes:
src_id, graph = self.data.sample_adj(node_id, size, replace=False)
size = (len(src_id), len(node_id))
adj_list.append((src_id, graph, size)) # src_id, graph, (src_size, target_size)
node_id = src_id

if self.sizes == [-1]:
src_id, graph, _ = adj_list[0]
size = (len(src_id), len(batch))
return src_id, graph, size
else:
return batch, node_id, adj_list[::-1]
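
For orientation, the (target_nodes_id, all_sampled_nodes_id, sampled_adjs) output above is typically consumed layer by layer in the forward pass. The sketch below relies on two assumptions that are not part of this diff: `convs` is a list of SAGE-style convolution layers, and each sampled id list places the target nodes first (the usual neighbor-sampler convention).

# Sketch only: `x` is the full node feature matrix and `convs` a list of
# SAGE-style convolution layers (both assumptions, not defined in this diff).
target_id, node_id, adjs = batch_from_sampler
h = x[node_id]  # features of every node sampled for this batch
for conv, (src_id, graph, (num_src, num_dst)) in zip(convs, adjs):
    h_dst = h[:num_dst]          # assumed convention: target nodes listed first
    h = conv((h, h_dst), graph)  # aggregate sampled neighbours into the targets
# `h` now holds embeddings for the original `target_id` nodes.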


class UnsupNeighborSamplerDataset(torch.utils.data.Dataset):
def __init__(self, dataset, sizes: List[int], batch_size: int, mask=None):
super(UnsupNeighborSamplerDataset, self).__init__()
self.data = dataset.data
self.x = self.data.x
self.edge_index = self.data.edge_index
self.sizes = sizes
self.batch_size = batch_size
self.node_idx = torch.arange(0, self.data.x.shape[0], dtype=torch.long)
self.total_num_nodes = self.num_nodes = self.node_idx.shape[0]
if mask is not None:
self.node_idx = self.node_idx[mask]
self.num_nodes = self.node_idx.shape[0]
self.random_walker = RandomWalker()

def shuffle(self):
idx = torch.randperm(self.num_nodes)
self.node_idx = self.node_idx[idx]

def __len__(self):
return (self.num_nodes - 1) // self.batch_size + 1

Expand All @@ -75,6 +159,13 @@ def __getitem__(self, idx):
)
"""
batch = self.node_idx[idx * self.batch_size : (idx + 1) * self.batch_size]
self.random_walker.build_up(self.edge_index, self.total_num_nodes)
walk_res = self.random_walker.walk_one(batch, length=1, p=0.0)

neg_batch = torch.randint(0, self.total_num_nodes, (batch.numel(),), dtype=torch.int64)
pos_batch = torch.tensor(walk_res)
batch = torch.cat([batch, pos_batch, neg_batch], dim=0)
node_id = batch
adj_list = []
for size in self.sizes:
@@ -197,3 +288,4 @@ def __len__(self):

def shuffle(self):
self.parts = torch.randint(0, self.n_cluster, size=(self.num_nodes,))
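
The batch built in UnsupNeighborSamplerDataset.__getitem__ above concatenates anchors, random-walk positives, and uniformly drawn negatives in equal thirds. A hedged sketch of the standard unsupervised GraphSAGE objective over output embeddings laid out that way (the loss helper below is an illustration, not code from this commit):

import torch
import torch.nn.functional as F

def unsup_sage_loss(out: torch.Tensor) -> torch.Tensor:
    # `out`: embeddings for the concatenated batch [anchors | positives | negatives].
    third = out.shape[0] // 3
    anchor, pos, neg = out[:third], out[third:2 * third], out[2 * third:]
    pos_loss = -F.logsigmoid((anchor * pos).sum(dim=-1)).mean()
    neg_loss = -F.logsigmoid(-(anchor * neg).sum(dim=-1)).mean()
    return pos_loss + neg_loss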

2 changes: 1 addition & 1 deletion cogdl/models/__init__.py
@@ -104,7 +104,7 @@ def build_model(args):
"sortpool": "cogdl.models.nn.sortpool.SortPool",
"srgcn": "cogdl.models.nn.srgcn.SRGCN",
"gcc": "cogdl.models.nn.gcc_model.GCCModel",
"unsup_graphsage": "cogdl.models.nn.unsup_graphsage.SAGE",
"unsup_graphsage": "cogdl.models.nn.graphsage.Graphsage",
"graphsaint": "cogdl.models.nn.graphsaint.GraphSAINT",
"m3s": "cogdl.models.nn.m3s.M3S",
"moe_gcn": "cogdl.models.nn.moe_gcn.MoEGCN",
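The entry now points "unsup_graphsage" at the shared Graphsage implementation in cogdl.models.nn.graphsage. Registries like this are usually resolved lazily from the dotted path; a minimal sketch of that lookup (build_model's real body is collapsed in this view, so this is an assumption about the general pattern, not a copy of it):

import importlib

def resolve(dotted_path: str):
    """Turn 'package.module.ClassName' into the class object it names."""
    module_path, cls_name = dotted_path.rsplit(".", 1)
    return getattr(importlib.import_module(module_path), cls_name)

# e.g. resolve("cogdl.models.nn.graphsage.Graphsage") would import the module
# and return the Graphsage class registered for "unsup_graphsage".
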
82 changes: 0 additions & 82 deletions cogdl/models/nn/unsup_graphsage.py

This file was deleted.

22 changes: 22 additions & 0 deletions cogdl/utils/sampling.py
@@ -85,3 +85,25 @@ def walk(self, start, walk_length, restart_p=0.0):
result = random_walk(start, walk_length, self.indptr, self.indices, restart_p)
result = np.array(result, dtype=np.int64)
return result

def walk_one(self, start, length, p):
    """For each start node, run a `length`-step random walk and return one node
    sampled uniformly from that walk (used as a positive example).

    With probability `p` a step records the starting node again instead of the
    sampled neighbour."""
    walk_res = [np.zeros(length, dtype=np.int32)] * len(start)
    for i in range(len(start)):
        node = start[i]
        result = [np.int32(0)] * length
        index = np.int32(0)
        _node = node
        while index < length:
            # Pick a random neighbour of the current node from the CSR arrays.
            start1 = self.indptr[node]
            end1 = self.indptr[node + 1]
            sample1 = random.randint(start1, end1 - 1)
            node = self.indices[sample1]
            if np.random.uniform(0, 1) > p:
                result[index] = node
            else:
                result[index] = _node
            index += 1
        # Keep one uniformly chosen node from the walk for this start node.
        k = int(np.floor(np.random.rand() * len(result)))
        walk_res[i] = result[k]
    return walk_res
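
A small usage sketch of walk_one, mirroring the call made by UnsupNeighborSamplerDataset above. The (row, col) edge_index format is an assumption about what build_up accepts and may differ from the actual Graph.edge_index representation:

import torch
from cogdl.utils import RandomWalker

# Toy 4-node ring graph; build_up(edge_index, num_nodes) mirrors the sampler's call.
row = torch.tensor([0, 1, 2, 3, 1, 2, 3, 0])
col = torch.tensor([1, 2, 3, 0, 0, 1, 2, 3])
walker = RandomWalker()
walker.build_up((row, col), 4)

starts = torch.tensor([0, 2])
# One positive node per start, drawn uniformly from a length-1 random walk.
positives = walker.walk_one(starts, length=1, p=0.0)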
1 change: 1 addition & 0 deletions cogdl/wrappers/data_wrapper/__init__.py
@@ -34,6 +36,7 @@ def fetch_data_wrapper(name):
"triple_link_prediction_dw": "cogdl.wrappers.data_wrapper.link_prediction.TripleDataWrapper",
"cluster_dw": "cogdl.wrappers.data_wrapper.node_classification.ClusterWrapper",
"graphsage_dw": "cogdl.wrappers.data_wrapper.node_classification.GraphSAGEDataWrapper",
"unsup_graphsage_dw": "cogdl.wrappers.data_wrapper.node_classification.UnsupGraphSAGEDataWrapper",
"m3s_dw": "cogdl.wrappers.data_wrapper.node_classification.M3SDataWrapper",
"network_embedding_dw": "cogdl.wrappers.data_wrapper.node_classification.NetworkEmbeddingDataWrapper",
"node_classification_dw": "cogdl.wrappers.data_wrapper.node_classification.FullBatchNodeClfDataWrapper",
cogdl/wrappers/data_wrapper/node_classification/__init__.py
@@ -1,5 +1,6 @@
from .cluster_dw import ClusterWrapper
from .graphsage_dw import GraphSAGEDataWrapper
from .unsup_graphsage_dw import UnsupGraphSAGEDataWrapper
from .m3s_dw import M3SDataWrapper
from .network_embedding_dw import NetworkEmbeddingDataWrapper
from .node_classification_dw import FullBatchNodeClfDataWrapper
cogdl/wrappers/data_wrapper/node_classification/unsup_graphsage_dw.py
@@ -0,0 +1,70 @@
from .. import DataWrapper
from cogdl.data.sampler import UnsupNeighborSamplerDataset, UnsupNeighborSampler


class UnsupGraphSAGEDataWrapper(DataWrapper):
@staticmethod
def add_args(parser):
# fmt: off
parser.add_argument("--batch-size", type=int, default=128)
parser.add_argument("--sample-size", type=int, nargs='+', default=[10, 10])
# fmt: on

def __init__(self, dataset, batch_size: int, sample_size: list):
super(UnsupGraphSAGEDataWrapper, self).__init__(dataset)
self.dataset = dataset
self.train_dataset = UnsupNeighborSamplerDataset(
dataset, sizes=sample_size, batch_size=batch_size, mask=dataset.data.train_mask
)
self.val_dataset = UnsupNeighborSamplerDataset(
dataset, sizes=sample_size, batch_size=batch_size * 2, mask=dataset.data.val_mask
)
self.test_dataset = UnsupNeighborSamplerDataset(
dataset=self.dataset,
mask=None,
sizes=[-1],
batch_size=batch_size * 2,
)
self.x = self.dataset.data.x
self.y = self.dataset.data.y
self.batch_size = batch_size
self.sample_size = sample_size

def train_wrapper(self):
self.dataset.data.train()
return UnsupNeighborSampler(
dataset=self.train_dataset,
mask=self.dataset.data.train_mask,
sizes=self.sample_size,
num_workers=4,
shuffle=False,
batch_size=self.batch_size,
)

def test_wrapper(self):
return (
self.dataset,
UnsupNeighborSampler(
dataset=self.test_dataset,
mask=None,
sizes=[-1],
batch_size=self.batch_size * 2,
shuffle=False,
num_workers=4,
),
)

def train_transform(self, batch):
target_id, n_id, adjs = batch
x_src = self.x[n_id]

return x_src, adjs


def get_train_dataset(self):
return self.train_dataset

def pre_transform(self):
self.dataset.data.add_remaining_self_loops()
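
A hedged sketch of driving this wrapper by hand; in practice CogDL's trainer constructs and calls it, and build_dataset_from_name is only an assumed convenience for obtaining a dataset:

# Assumed manual usage; normally the CogDL trainer wires this up.
from cogdl.datasets import build_dataset_from_name  # assumed helper

dataset = build_dataset_from_name("cora")
dw = UnsupGraphSAGEDataWrapper(dataset, batch_size=128, sample_size=[10, 10])
dw.pre_transform()           # add remaining self-loops before sampling
loader = dw.train_wrapper()  # UnsupNeighborSampler over the training mask
for batch in loader:
    x_src, adjs = dw.train_transform(batch)
    # x_src: features of all sampled nodes (anchors, positives, negatives and
    # their neighbourhoods); adjs: the per-layer sampled adjacency tuples.
    break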


1 change: 1 addition & 0 deletions cogdl/wrappers/default_match.py
@@ -102,6 +102,7 @@ def set_default_wrapper_config():

node_classification_wrappers["m3s"]["dw"] = "m3s_dw"
node_classification_wrappers["graphsage"]["dw"] = "graphsage_dw"
node_classification_wrappers["unsup_graphsage"]["dw"] = "unsup_graphsage_dw"
node_classification_wrappers["pprgo"]["dw"] = "pprgo_dw"
node_classification_wrappers["sagn"]["dw"] = "sagn_dw"
