diff --git a/cogdl/configs.py b/cogdl/configs.py index 84f8f4b4..f9df112f 100644 --- a/cogdl/configs.py +++ b/cogdl/configs.py @@ -156,6 +156,7 @@ }, }, "ppnp": { + "general": {}, "flickr": { "lr": 0.005, "weight_decay": 0.001, @@ -209,6 +210,57 @@ # 96.26 }, }, + "correct_smooth_mlp": { + "general": {}, + "ogbn_arxiv": { + "correct_norm": "row", + "smooth_norm": "col", + "correct_alpha": 0.9791632871592579, + "smooth_alpha": 0.7564990804200602, + "num_correct_prop": 50, + "num_smooth_prop": 50, + "autoscale": True, + "norm": "batchnorm", + }, + "ogbn_products": { + "correct_norm": "sym", + "smooth_norm": "row", + "correct_alpha": 1.0, + "smooth_alpha": 0.8, + "num_correct_prop": 50, + "num_smooth_prop": 50, + "autoscale": False, + "scale": 10.0, + "norm": "batchnorm", + "act_first": True, + }, + }, + "sagn": { + "general": { + "data_gpu": True, + "lr": 0.001, + "hidden-size": 512, + "attn-drop": 0.0, + "dropout": 0.7, + }, + "flickr": { + "threshold": 0.5, + "label-hop": 2, + "weight-decay": 3e-6, + "nstage": [50, 50, 50], + "nhop": 2, + "batch-size": 256, + }, + "reddit": { + "threshold": 0.9, + "lr": 0.0001, + "batch-size": 1000, + "nhop": 2, + "label-nhop": 4, + "weight-decay": 0.0, + "nstage": [500, 500, 500], + }, + }, }, "unsupervised_node_classification": { "deepwalk": { diff --git a/cogdl/data/data.py b/cogdl/data/data.py index 3930ba11..152a91d6 100644 --- a/cogdl/data/data.py +++ b/cogdl/data/data.py @@ -196,6 +196,13 @@ def row_norm(self): self.normalize_adj("row") self.__symmetric__ = False + def col_norm(self): + if self.row is None: + self.generate_normalization("col") + else: + self.normalize_adj("col") + self.__symmetric__ = False + def generate_normalization(self, norm="sym"): if self.__normed__: return @@ -209,6 +216,9 @@ def generate_normalization(self, norm="sym"): edge_norm[torch.isinf(edge_norm)] = 0 self.__out_norm__ = None self.__in_norm__ = edge_norm.view(-1, 1) + elif norm == "col": + self.row, _, _ = csr2coo(self.row_ptr, self.col, self.weight) + self.weight = row_normalization(self.num_nodes, self.col, self.row, self.weight) else: raise NotImplementedError self.__normed__ = norm @@ -223,6 +233,8 @@ def normalize_adj(self, norm="sym"): self.weight = symmetric_normalization(self.num_nodes, self.row, self.col, self.weight) elif norm == "row": self.weight = row_normalization(self.num_nodes, self.row, self.col, self.weight) + elif norm == "col": + self.weight = row_normalization(self.num_nodes, self.col, self.row, self.weight) else: raise NotImplementedError self.__normed__ = norm @@ -450,9 +462,16 @@ def remove_self_loops(self): def row_norm(self): self._adj.row_norm() + def col_norm(self): + self._adj.col_norm() + def sym_norm(self): self._adj.sym_norm() + def normalize(self, key="sym"): + assert key in ["row", "sym", "col"], "Support row/col/sym normalization" + getattr(self, f"{key}_norm")() + def is_symmetric(self): return self._adj.is_symmetric() @@ -462,6 +481,28 @@ def set_symmetric(self): def set_asymmetric(self): self._adj.set_symmetric(False) + def is_inductive(self): + return self._adj_train is not None + + def mask2nid(self, split): + mask = getattr(self, f"{split}_mask") + if mask is not None: + if mask.dtype is torch.bool: + return torch.where(mask)[0] + return mask + + @property + def train_nid(self): + return self.mask2nid("train") + + @property + def val_nid(self): + return self.mask2nid("val") + + @property + def test_nid(self): + return self.mask2nid("test") + @contextmanager def local_graph(self, key=None): 
self.__temp_adj_stack__.append(self._adj) diff --git a/cogdl/datasets/ogb.py b/cogdl/datasets/ogb.py index e86d76ad..e1848ccd 100644 --- a/cogdl/datasets/ogb.py +++ b/cogdl/datasets/ogb.py @@ -84,6 +84,12 @@ class OGBArxivDataset(OGBNDataset): def __init__(self, data_path="data"): dataset = "ogbn-arxiv" super(OGBArxivDataset, self).__init__(data_path, dataset) + self.preprocessing() + + def preprocessing(self): + row, col = self.data.edge_index + edge_index = to_undirected(torch.stack([row, col])) + self.data.edge_index = edge_index def get_evaluator(self): evaluator = NodeEvaluator(name="ogbn-arxiv") diff --git a/cogdl/match.yml b/cogdl/match.yml index 4d99a2cc..f98eca78 100644 --- a/cogdl/match.yml +++ b/cogdl/match.yml @@ -34,6 +34,8 @@ node_classification: - mvgrl - grace - self_auxiliary_task + - correct_smooth_mlp + - sagn dataset: - cora - citeseer diff --git a/cogdl/models/__init__.py b/cogdl/models/__init__.py index 5816f97f..86452e37 100644 --- a/cogdl/models/__init__.py +++ b/cogdl/models/__init__.py @@ -130,4 +130,7 @@ def build_model(args): "self_auxiliary_task": "cogdl.models.nn.self_auxiliary_task", "moe_gcn": "cogdl.models.nn.moe_gcn", "lightgcn": "cogdl.models.nn.lightgcn", + "correct_smooth": "cogdl.models.nn.correct_smooth", + "correct_smooth_mlp": "cogdl.models.nn.correct_smooth", + "sagn": "cogdl.models.nn.sagn", } diff --git a/cogdl/models/base_model.py b/cogdl/models/base_model.py index d1bb0a99..7fa3c581 100644 --- a/cogdl/models/base_model.py +++ b/cogdl/models/base_model.py @@ -42,7 +42,7 @@ def graph_classification_loss(self, batch): return self.loss_fn(pred, batch.y) @staticmethod - def get_trainer(task: Any, args: Any) -> Optional[Type[BaseTrainer]]: + def get_trainer(args=None) -> Optional[Type[BaseTrainer]]: return None def set_device(self, device): diff --git a/cogdl/models/nn/agc.py b/cogdl/models/nn/agc.py index 4bd1df75..26c015cb 100644 --- a/cogdl/models/nn/agc.py +++ b/cogdl/models/nn/agc.py @@ -28,7 +28,8 @@ def __init__(self, num_clusters, max_iter): self.k = 0 self.features_matrix = None - def get_trainer(self, task, args): + @staticmethod + def get_trainer(args): return AGCTrainer def get_features(self, data): diff --git a/cogdl/models/nn/correct_smooth.py b/cogdl/models/nn/correct_smooth.py new file mode 100644 index 00000000..3f9f023f --- /dev/null +++ b/cogdl/models/nn/correct_smooth.py @@ -0,0 +1,217 @@ +from functools import partial + +import torch +import torch.nn.functional as F + +from .. 
import BaseModel, register_model +from .mlp import MLP +from cogdl.data import Graph +from cogdl.utils import spmm + + +def autoscale_post(x, lower, upper): + return torch.clamp(x, lower, upper) + + +def fixed_post(x, y, nid): + x[nid] = y[nid] + return x + + +def pre_residual_correlation(preds, labels, split_idx): + labels[labels.isnan()] = 0 + labels = labels.long() + nclass = labels.max().item() + 1 + nnode = preds.shape[0] + err = torch.zeros((nnode, nclass), device=preds.device) + err[split_idx] = F.one_hot(labels[split_idx], nclass).float().squeeze(1) - preds[split_idx] + return err + + +def pre_outcome_correlation(preds, labels, label_nid): + """Generates the initial labels used for outcome correlation""" + c = labels.max() + 1 + y = preds.clone() + if len(label_nid) > 0: + y[label_nid] = F.one_hot(labels[label_nid], c).float().squeeze(1) + return y + + +def outcome_correlation(g, labels, alpha, nprop, post_step, alpha_term=True): + result = labels.clone() + for _ in range(nprop): + result = alpha * spmm(g, result) + if alpha_term: + result += (1 - alpha) * labels + else: + result += labels + result = post_step(result) + return result + + +def correlation_autoscale(preds, y, resid, residual_nid, scale=1.0): + orig_diff = y[residual_nid].abs().sum() / residual_nid.shape[0] + resid_scale = orig_diff / resid.abs().sum(dim=1, keepdim=True) + resid_scale[resid_scale.isinf()] = 1.0 + cur_idxs = resid_scale > 1000 + resid_scale[cur_idxs] = 1.0 + res_result = preds + resid_scale * resid + res_result[res_result.isnan()] = preds[res_result.isnan()] + return res_result + + +def correlation_fixed(preds, y, resid, residual_nid, scale=1.0): + return preds + scale * resid + + +def diffusion(g, x, nhtop, p=1, alpha=0.5): + x = x ** p + for _ in range(nhtop): + x = (1 - alpha) * x + alpha * spmm(g, x) + x = x ** p + return x + + +@register_model("correct_smooth") +class CorrectSmooth(BaseModel): + @staticmethod + def add_args(parser): + parser.add_argument("--correct-alpha", type=float, default=1.0) + parser.add_argument("--smooth-alpha", type=float, default=0.8) + parser.add_argument("--num-correct-prop", type=int, default=50) + parser.add_argument("--num-smooth-prop", type=int, default=50) + parser.add_argument("--autoscale", action="store_true") + parser.add_argument("--correct-norm", type=str, default="sym") + parser.add_argument("--smooth-norm", type=str, default="row") + parser.add_argument("--scale", type=float, default=1.0) + + @classmethod + def build_model_from_args(cls, args): + return cls( + args.correct_alpha, + args.smooth_alpha, + args.num_correct_prop, + args.num_smooth_prop, + args.autoscale, + args.correct_norm, + args.smooth_norm, + args.scale, + ) + + def __init__( + self, + correct_alpha, + smooth_alpha, + num_correct_prop, + num_smooth_prop, + autoscale=False, + correct_norm="row", + smooth_norm="col", + scale=1.0, + ): + super(CorrectSmooth, self).__init__() + self.op_dict = { + "correct_g": correct_norm, + "smooth_g": smooth_norm, + "num_correct_prop": num_correct_prop, + "num_smooth_prop": num_smooth_prop, + "correct_alpha": correct_alpha, + "smooth_alpha": smooth_alpha, + "autoscale": autoscale, + "scale": scale, + } + + def __call__(self, graph, x, train_only=True): + g1 = graph + g2 = Graph(edge_index=g1.edge_index) + + g1.normalize(self.op_dict["correct_g"]) + g2.normalize(self.op_dict["smooth_g"]) + + train_nid, valid_nid, _ = g1.train_nid, g1.val_nid, g1.test_nid + y = g1.y + + if train_only: + label_nid = train_nid + residual_nid = train_nid + else: + label_nid = 
torch.cat((train_nid, valid_nid)) + residual_nid = train_nid + + # Correct + y = pre_residual_correlation(x, y, residual_nid) + + if self.op_dict["autoscale"]: + post_func = partial(autoscale_post, lower=-1.0, upper=1.0) + scale_func = correlation_autoscale + else: + post_func = partial(fixed_post, y=y, nid=residual_nid) + scale_func = correlation_fixed + + resid = outcome_correlation( + g1, y, self.op_dict["correct_alpha"], nprop=self.op_dict["num_correct_prop"], post_step=post_func + ) + res_result = scale_func(x, y, resid, residual_nid, self.op_dict["scale"]) + + # Smooth + y = pre_outcome_correlation(res_result, g1.y, label_nid) + result = outcome_correlation( + g2, + y, + self.op_dict["smooth_alpha"], + nprop=self.op_dict["num_smooth_prop"], + post_step=partial(autoscale_post, lower=0, upper=1), + ) + return result + + +@register_model("correct_smooth_mlp") +class CorrectSmoothMLP(BaseModel): + @staticmethod + def add_args(parser): + CorrectSmooth.add_args(parser) + MLP.add_args(parser) + parser.add_argument("--use-embeddings", action="store_true") + + @classmethod + def build_model_from_args(cls, args): + return cls(args) + + def __init__(self, args): + super(CorrectSmoothMLP, self).__init__() + if args.use_embeddings: + args.num_features = args.num_features * 2 + args.act_first = True if args.dataset == "ogbn-products" else False + self.use_embeddings = args.use_embeddings + self.mlp = MLP.build_model_from_args(args) + self.c_s = CorrectSmooth.build_model_from_args(args) + self.rescale_feats = args.rescale_feats if hasattr(args, "rescale_feats") else args.dataset == "ogbn-arxiv" + self.cache_x = None + + def forward(self, graph): + if self.cache_x is not None: + x = self.cache_x + elif self.use_embeddings: + _x = graph.x.contiguous() + _x = diffusion(graph, _x, nhtop=10) + x = torch.cat([graph.x, _x], dim=1) + if self.rescale_feats: + x = (x - x.mean(0)) / x.std(0) + self.cache_x = x + else: + x = graph.x + out = self.mlp(x) + return out + + def predict(self, data): + out = self.forward(data) + return out + + def postprocess(self, data, out): + print("Correct and Smoothing...") + if len(data.y.shape) == 1: + out = F.softmax(out, dim=-1) + # else: + # out = torch.sigmoid(out) + out = self.c_s(data, out) + return out diff --git a/cogdl/models/nn/daegc.py b/cogdl/models/nn/daegc.py index 82ef8823..6d1fbcd4 100644 --- a/cogdl/models/nn/daegc.py +++ b/cogdl/models/nn/daegc.py @@ -51,7 +51,8 @@ def __init__(self, num_features, hidden_size, embedding_size, num_heads, dropout self.att2 = GATLayer(hidden_size * num_heads, embedding_size, dropout=dropout, alpha=0.2, nhead=1, concat=False) self.cluster_center = None - def get_trainer(self, task, args): + @staticmethod + def get_trainer(args=None): return DAEGCTrainer def forward(self, graph): diff --git a/cogdl/models/nn/deepergcn.py b/cogdl/models/nn/deepergcn.py index 5497a370..bb824639 100644 --- a/cogdl/models/nn/deepergcn.py +++ b/cogdl/models/nn/deepergcn.py @@ -120,5 +120,5 @@ def predict(self, graph): return self.forward(graph) @staticmethod - def get_trainer(taskType: Any, args): + def get_trainer(args): return RandomClusterTrainer diff --git a/cogdl/models/nn/dgi.py b/cogdl/models/nn/dgi.py index cdc2308f..f2b01cb6 100644 --- a/cogdl/models/nn/dgi.py +++ b/cogdl/models/nn/dgi.py @@ -168,5 +168,5 @@ def embed(self, data, msk=None): return h_1.detach() # , c.detach() @staticmethod - def get_trainer(task, args): + def get_trainer(args): return SelfSupervisedPretrainer diff --git a/cogdl/models/nn/dgl_jknet.py 
b/cogdl/models/nn/dgl_jknet.py index 0b6c96c2..34dee3f2 100644 --- a/cogdl/models/nn/dgl_jknet.py +++ b/cogdl/models/nn/dgl_jknet.py @@ -264,5 +264,5 @@ def set_graph(self, graph): self.graph = graph @staticmethod - def get_trainer(taskType, args): + def get_trainer(args): return JKNetTrainer diff --git a/cogdl/models/nn/gae.py b/cogdl/models/nn/gae.py index 62edd5e1..b403efb2 100644 --- a/cogdl/models/nn/gae.py +++ b/cogdl/models/nn/gae.py @@ -26,7 +26,8 @@ def make_loss(self, data, adj): def get_features(self, data): return self.embed(data).detach() - def get_trainer(self, task, args): + @staticmethod + def get_trainer(args=None): return GAETrainer @@ -88,5 +89,6 @@ def make_loss(self, data, adj): print("recon_loss = %.3f, kl_loss = %.3f" % (recon_loss, kl_loss)) return recon_loss + kl_loss - def get_trainer(self, task, args): + @staticmethod + def get_trainer(args): return GAETrainer diff --git a/cogdl/models/nn/grace.py b/cogdl/models/nn/grace.py index 2d0f957c..600ad509 100644 --- a/cogdl/models/nn/grace.py +++ b/cogdl/models/nn/grace.py @@ -183,5 +183,5 @@ def drop_feature(self, x: torch.Tensor, droprate: float): return x @staticmethod - def get_trainer(task, args): + def get_trainer(args): return SelfSupervisedPretrainer diff --git a/cogdl/models/nn/graphsage.py b/cogdl/models/nn/graphsage.py index 1dc8fa6e..657cde5b 100644 --- a/cogdl/models/nn/graphsage.py +++ b/cogdl/models/nn/graphsage.py @@ -138,7 +138,7 @@ def inference(self, x_all, data_loader): return x_all @staticmethod - def get_trainer(task: Any, args: Any): + def get_trainer(args): if args.dataset not in ["cora", "citeseer", "pubmed"]: return NeighborSamplingTrainer if hasattr(args, "use_trainer"): diff --git a/cogdl/models/nn/graphsaint.py b/cogdl/models/nn/graphsaint.py index 1cf070bb..eb5a79f5 100644 --- a/cogdl/models/nn/graphsaint.py +++ b/cogdl/models/nn/graphsaint.py @@ -170,5 +170,5 @@ def predict(self, data): return self.forward(data) @staticmethod - def get_trainer(task, args): + def get_trainer(args): return SAINTTrainer diff --git a/cogdl/models/nn/m3s.py b/cogdl/models/nn/m3s.py index 6e8b1a0b..0e0ec6fe 100644 --- a/cogdl/models/nn/m3s.py +++ b/cogdl/models/nn/m3s.py @@ -58,5 +58,5 @@ def predict(self, data): return self.forward(data) @staticmethod - def get_trainer(taskType, args): + def get_trainer(args): return M3STrainer diff --git a/cogdl/models/nn/mlp.py b/cogdl/models/nn/mlp.py index fe78d892..e9c1980b 100644 --- a/cogdl/models/nn/mlp.py +++ b/cogdl/models/nn/mlp.py @@ -34,6 +34,8 @@ def add_args(parser): parser.add_argument("--hidden-size", type=int, default=16) parser.add_argument("--num-layers", type=int, default=2) parser.add_argument("--dropout", type=float, default=0.5) + parser.add_argument("--norm", type=str, default=None) + parser.add_argument("--activation", type=str, default="relu") # fmt: on @classmethod @@ -44,15 +46,32 @@ def build_model_from_args(cls, args): args.hidden_size, args.num_layers, args.dropout, + args.activation, + args.norm, + args.act_first if hasattr(args, "act_first") else False, ) - def __init__(self, in_feats, out_feats, hidden_size, num_layers, dropout=0.0, activation="relu", norm=None): + def __init__( + self, + in_feats, + out_feats, + hidden_size, + num_layers, + dropout=0.0, + activation="relu", + norm=None, + act_first=False, + bias=True, + ): super(MLP, self).__init__() self.norm = norm self.activation = get_activation(activation) + self.act_first = act_first self.dropout = dropout shapes = [in_feats] + [hidden_size] * (num_layers - 1) + [out_feats] - self.mlp = 
nn.ModuleList([nn.Linear(shapes[layer], shapes[layer + 1]) for layer in range(num_layers)]) + self.mlp = nn.ModuleList( + [nn.Linear(shapes[layer], shapes[layer + 1], bias=bias) for layer in range(num_layers)] + ) if norm is not None and num_layers > 1: if norm == "layernorm": self.norm_list = nn.ModuleList(nn.LayerNorm(x) for x in shapes[1:-1]) @@ -60,15 +79,27 @@ def __init__(self, in_feats, out_feats, hidden_size, num_layers, dropout=0.0, ac self.norm_list = nn.ModuleList(nn.BatchNorm1d(x) for x in shapes[1:-1]) else: raise NotImplementedError(f"{norm} is not implemented in CogDL.") + self.reset_parameters() + + def reset_parameters(self): + for layer in self.mlp: + layer.reset_parameters() + if hasattr(self, "norm_list"): + for n in self.norm_list: + n.reset_parameters() def forward(self, x, *args, **kwargs): if isinstance(x, Graph): x = x.x for i, fc in enumerate(self.mlp[:-1]): x = fc(x) + if self.act_first: + x = self.activation(x) if self.norm: x = self.norm_list[i](x) - x = self.activation(x) + + if not self.act_first: + x = self.activation(x) x = F.dropout(x, p=self.dropout, training=self.training) x = self.mlp[-1](x) return x diff --git a/cogdl/models/nn/moe_gcn.py b/cogdl/models/nn/moe_gcn.py index 54484b5c..6e524c84 100644 --- a/cogdl/models/nn/moe_gcn.py +++ b/cogdl/models/nn/moe_gcn.py @@ -38,11 +38,12 @@ class GraphConv(nn.Module): Simple GCN layer, similar to https://arxiv.org/abs/1609.02907 """ - def __init__(self, in_features, out_features, dropout=0.0, residual=False, norm=None, bias=True): + def __init__(self, in_features, out_features, dropout=0.0, residual=False, activation=None, norm=None, bias=True): super(GraphConv, self).__init__() self.in_features = in_features self.out_features = out_features self.weight = Parameter(torch.FloatTensor(in_features, out_features)) + self.act = get_activation(activation) if dropout > 0: self.dropout = nn.Dropout(dropout) else: @@ -111,7 +112,8 @@ def __init__(self, conv_func, conv_params, out_feats, dropout=0.0, in_feats=None def reset_parameters(self): """Reinitialize model parameters.""" # self.graph_conv.reset_parameters() - self.res_connection.reset_parameters() + if self.res_connection is not None: + self.res_connection.reset_parameters() def forward(self, graph, feats): new_feats = self.graph_conv(graph, feats) @@ -121,7 +123,6 @@ def forward(self, graph, feats): new_feats = F.dropout(new_feats, p=self.dropout, training=self.training) new_feats = self.pos_ff(new_feats) - new_feats = self.act(new_feats) return new_feats diff --git a/cogdl/models/nn/mvgrl.py b/cogdl/models/nn/mvgrl.py index 8976d73a..8443e840 100644 --- a/cogdl/models/nn/mvgrl.py +++ b/cogdl/models/nn/mvgrl.py @@ -194,5 +194,5 @@ def embed(self, data, msk=None): return (h_1 + h_2).detach() # , c.detach() @staticmethod - def get_trainer(taskType, args): + def get_trainer(args): return SelfSupervisedPretrainer diff --git a/cogdl/models/nn/pprgo.py b/cogdl/models/nn/pprgo.py index b4566908..63629716 100644 --- a/cogdl/models/nn/pprgo.py +++ b/cogdl/models/nn/pprgo.py @@ -86,5 +86,5 @@ def predict(self, graph, batch_size, norm): return predictions @staticmethod - def get_trainer(taskType: Any, args: Any): + def get_trainer(args: Any): return PPRGoTrainer diff --git a/cogdl/models/nn/pyg_gcn.py b/cogdl/models/nn/pyg_gcn.py index b1b4b695..cbd0888f 100644 --- a/cogdl/models/nn/pyg_gcn.py +++ b/cogdl/models/nn/pyg_gcn.py @@ -29,9 +29,6 @@ def build_model_from_args(cls, args): args.dropout, ) - def get_trainer(self, task, args): - return None - def __init__(self, 
num_features, num_classes, hidden_size, num_layers, dropout): super(GCN, self).__init__() diff --git a/cogdl/models/nn/pyg_gpt_gnn.py b/cogdl/models/nn/pyg_gpt_gnn.py index 6ba4723f..68e1f5cc 100644 --- a/cogdl/models/nn/pyg_gpt_gnn.py +++ b/cogdl/models/nn/pyg_gpt_gnn.py @@ -200,9 +200,7 @@ def evaluate(self, data: Any, nodes: Any, targets: Any) -> Any: pass @staticmethod - def get_trainer( - taskType: Any, args - ) -> Optional[Type[Union[GPT_GNNHomogeneousTrainer, GPT_GNNHeterogeneousTrainer]]]: + def get_trainer(args) -> Optional[Type[Union[GPT_GNNHomogeneousTrainer, GPT_GNNHeterogeneousTrainer]]]: # if taskType == NodeClassification: return GPT_GNNHomogeneousTrainer # elif taskType == HeterogeneousNodeClassification: diff --git a/cogdl/models/nn/pyg_supergat.py b/cogdl/models/nn/pyg_supergat.py index bc0d3878..f7fb6bf2 100644 --- a/cogdl/models/nn/pyg_supergat.py +++ b/cogdl/models/nn/pyg_supergat.py @@ -471,7 +471,7 @@ def modules(self) -> List[SuperGATLayer]: return [self.conv1, self.conv2] @staticmethod - def get_trainer(task, args): + def get_trainer(args): return SuperGATTrainer @@ -579,5 +579,5 @@ def modules(self) -> List[SuperGATLayer]: return self.conv_list @staticmethod - def get_trainer(task, args): + def get_trainer(args): return SuperGATTrainer diff --git a/cogdl/models/nn/sagn.py b/cogdl/models/nn/sagn.py new file mode 100644 index 00000000..c8e8cfb2 --- /dev/null +++ b/cogdl/models/nn/sagn.py @@ -0,0 +1,396 @@ +import copy +import os +import os.path as osp +import numpy as np +from tqdm import tqdm + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .. import BaseModel, register_model +from .mlp import MLP +from cogdl.utils import spmm +from cogdl.trainers import BaseTrainer, register_trainer + + +def average_neighbor_features(graph, feats, nhop, norm="sym", style="all"): + results = [] + if norm == "sym": + graph.sym_norm() + elif norm == "row": + graph.row_norm() + else: + raise NotImplementedError + + x = feats + results.append(x) + for i in range(nhop): + x = spmm(graph, x) + if style == "all": + results.append(x) + if style != "all": + results = x + return results + + +def entropy(probs): + eps = 1e-9 + res = -probs * torch.log(probs + eps) - (1 - probs) * torch.log(1 - probs + eps) + return res + + +def prepare_feats(dataset, nhop, norm="sym"): + print("Preprocessing features...") + dataset_name = dataset.__class__.__name__ + data = dataset.data + if not osp.exists("./sagn"): + os.makedirs("sagn", exist_ok=True) + + feat_emb_path = f"./sagn/feats_{dataset_name}_hop_{nhop}.pt" + is_inductive = data.is_inductive() + train_nid = data.train_nid + + if is_inductive: + if osp.exists(feat_emb_path): + print("Loading existing features") + feats = torch.load(feat_emb_path) + else: + data.train() + feats_train = average_neighbor_features(data, data.x, nhop, norm=norm, style="all") + data.eval() + feats = average_neighbor_features(data, data.x, nhop, norm=norm, style="all") + + # train_nid, val_nid, test_nid = data.train_nid, data.val_nid, data.test_nid + for i in range(len(feats)): + feats[i][train_nid] = feats_train[i][train_nid] + # feats_train, feats_val, feats_test = feats[train_nid], feats[val_nid], feats[test_nid] + # feats[i] = torch.cat([feats_train, feats_val, feats_test], dim=0) + torch.save(feats, feat_emb_path) + else: + if osp.exists(feat_emb_path): + feats = torch.load(feat_emb_path) + else: + feats = average_neighbor_features(data, data.x, nhop, norm=norm, style="all") + print("Preprocessing features done...") + return feats + + 
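A minimal, runnable sketch of what average_neighbor_features above computes, using only plain torch sparse tensors instead of CogDL's Graph/spmm; the toy chain graph and the helper name hop_features are illustrative assumptions, not part of this patch.

import torch

def hop_features(adj, x, nhop):
    # adj: (N, N) normalized sparse adjacency; x: (N, F) node features.
    # Returns [x, Ax, A^2 x, ...], one feature view per receptive field,
    # mirroring the style="all" branch of average_neighbor_features.
    feats = [x]
    for _ in range(nhop):
        x = torch.sparse.mm(adj, x)
        feats.append(x)
    return feats

if __name__ == "__main__":
    # row-normalized adjacency of the chain graph 0 - 1 - 2
    idx = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
    val = torch.tensor([1.0, 0.5, 0.5, 1.0])
    adj = torch.sparse_coo_tensor(idx, val, (3, 3))
    feats = hop_features(adj, torch.eye(3), nhop=2)
    assert len(feats) == 3 and feats[1].shape == (3, 3)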
+def prepare_labels(dataset, stage, nhop, threshold, probs=None, norm="row", load_emb=False): + dataset_name = dataset.__class__.__name__ + data = dataset.data + is_inductive = data.is_inductive() + multi_label = len(data.y.shape) > 1 + + device = data.x.device + num_classes = data.num_classes + train_nid = data.train_nid + val_nid = data.val_nid + test_nid = data.test_nid + + if not osp.exists("./sagn"): + os.makedirs("sagn", exist_ok=True) + label_emb_path = f"./sagn/label_emb_{dataset_name}.pt" + # teacher_prob_path = f"./sagn/teacher_prob_{dataset_name}.pt" + teacher_probs = probs + + if stage > 0 and probs is not None: + # teacher_probs = torch.load(teacher_prob_path) + node_idx = torch.cat([train_nid, val_nid, test_nid], dim=0) + if multi_label: + threshold = -threshold * np.log(threshold) - (1 - threshold) * np.log(1 - threshold) + entropy_distribution = entropy(teacher_probs) + confident_nid = torch.arange(len(teacher_probs))[(entropy_distribution.mean(1) <= threshold)] + else: + confident_nid = torch.arange(len(teacher_probs))[teacher_probs.max(1)[0] > threshold] + extra_confident_nid = confident_nid[confident_nid >= len(train_nid)] + confident_nid = node_idx[confident_nid] + extra_confident_nid = node_idx[extra_confident_nid] + + if multi_label: + pseudo_labels = teacher_probs + pseudo_labels[pseudo_labels >= 0.5] = 1 + pseudo_labels[pseudo_labels < 0.5] = 0 + labels_with_pseudos = torch.ones_like(data.y) + else: + pseudo_labels = torch.argmax(teacher_probs, dim=1) + labels_with_pseudos = torch.zeros_like(data.y) + train_nid_with_pseudos = np.union1d(train_nid.cpu().numpy(), confident_nid.cpu().numpy()) + train_nid_with_pseudos = torch.from_numpy(train_nid_with_pseudos).to(device) + labels_with_pseudos[train_nid] = data.y[train_nid] + labels_with_pseudos[extra_confident_nid] = pseudo_labels[extra_confident_nid] + else: + # confident_nid = train_nid + train_nid_with_pseudos = train_nid + labels_with_pseudos = data.y.clone() + # teacher_probs = None + # pseudo_labels = None + + if (not is_inductive) or stage > 0: + if multi_label: + label_emb = 0.5 * torch.ones((data.num_nodes, num_classes), device=device) + label_emb[train_nid_with_pseudos] = labels_with_pseudos.float()[train_nid_with_pseudos] + else: + label_emb = torch.zeros((data.num_nodes, num_classes), device=device) + label_emb[train_nid_with_pseudos] = F.one_hot( + labels_with_pseudos[train_nid_with_pseudos], num_classes=num_classes + ).float() + else: + label_emb = None + + if is_inductive: + if osp.exists(label_emb_path) and load_emb: + label_emb = torch.load(label_emb_path) + elif label_emb is not None: + data.train() + label_emb_train = average_neighbor_features(data, label_emb, nhop, norm=norm, style="last") + data.eval() + label_emb = average_neighbor_features(data, label_emb, nhop, norm=norm, style="last") + label_emb[train_nid] = label_emb_train[train_nid] + if load_emb: + torch.save(label_emb, label_emb_path) + else: + if osp.exists(label_emb_path) and load_emb: + label_emb = torch.load(label_emb_path) + elif label_emb is not None: + if label_emb is not None: + label_emb = average_neighbor_features(data, label_emb, nhop, norm=norm, style="last") + if stage == 0 and load_emb: + torch.save(label_emb, label_emb_path) + + return label_emb, labels_with_pseudos, train_nid_with_pseudos + + +@register_model("sagn") +class SAGN(BaseModel): + @staticmethod + def add_args(parser): + parser.add_argument("--hidden-size", type=int, default=512) + parser.add_argument("--negative-slope", type=float, default=0.2) + 
parser.add_argument("--dropout", type=float, default=0.5) + parser.add_argument("--input-drop", type=float, default=0.0) + parser.add_argument("--attn-drop", type=float, default=0.4) + parser.add_argument("--nhead", type=int, default=2) + parser.add_argument("--mlp-layer", type=int, default=4) + parser.add_argument("--use-labels", action="store_true") + parser.add_argument("--nhop", type=int, default=4) + + @classmethod + def build_model_from_args(cls, args): + return cls( + args.num_features, + args.num_classes, + args.hidden_size, + args.nhop, + args.mlp_layer, + args.nhead, + args.dropout, + args.input_drop, + args.attn_drop, + args.negative_slope, + args.use_labels, + ) + + def __init__( + self, + in_feats, + out_feats, + hidden_size, + nhop, + mlp_layer, + nhead, + dropout=0.5, + input_drop=0.0, + attn_drop=0.0, + negative_slope=0.2, + use_labels=False, + ): + super(SAGN, self).__init__() + self.dropout = dropout + self.nhead = nhead + self.hidden_size = hidden_size + self.attn_dropout = attn_drop + self.input_dropout = input_drop + self.use_labels = use_labels + self.negative_slope = negative_slope + + self.norm = nn.BatchNorm1d(hidden_size) + self.layers = nn.ModuleList( + [ + MLP(in_feats, hidden_size * nhead, hidden_size, mlp_layer, norm="batchnorm", dropout=dropout) + for _ in range(nhop + 1) + ] + ) + + self.mlp = MLP(hidden_size, out_feats, hidden_size, mlp_layer, norm="batchnorm", dropout=dropout) + self.res_conn = nn.Linear(in_feats, hidden_size * nhead, bias=False) + if use_labels: + self.label_mlp = MLP(out_feats, out_feats, hidden_size, 2 * mlp_layer, norm="batchnorm", dropout=dropout) + + self.attn_l = nn.Parameter(torch.FloatTensor(size=(1, nhead, hidden_size))) + self.attn_r = nn.Parameter(torch.FloatTensor(size=(1, nhead, hidden_size))) + + def reset_parameters(self): + gain = nn.init.calculate_gain("relu") + for layer in self.layers: + layer.reset_parameters() + nn.init.xavier_normal_(self.attn_r, gain=gain) + nn.init.xavier_normal_(self.attn_l, gain=gain) + if self.use_labels: + self.label_mlp.reset_parameters() + self.norm.reset_parameters() + self.mlp.reset_parameters() + + def forward(self, features, y_emb=None): + out = 0 + features = [F.dropout(x, p=self.input_dropout, training=self.training) for x in features] + hidden = [self.layers[i](features[i]).view(-1, self.nhead, self.hidden_size) for i in range(len(features))] + a_r = (hidden[0] * self.attn_r).sum(dim=-1).unsqueeze(-1) + a_ls = [(h * self.attn_l).sum(dim=-1).unsqueeze(-1) for h in hidden] + a = torch.cat([(a_l + a_r).unsqueeze(-1) for a_l in a_ls], dim=-1) + a = F.leaky_relu(a, negative_slope=self.negative_slope) + a = F.softmax(a, dim=-1) + a = F.dropout(a, p=self.attn_dropout, training=self.training) + + for i in range(a.shape[-1]): + out += hidden[i] * a[:, :, :, i] + out += self.res_conn(features[0]).view(-1, self.nhead, self.hidden_size) + out = out.mean(1) + out = F.relu(self.norm(out)) + out = F.dropout(out, p=self.dropout, training=self.training) + out = self.mlp(out) + + if self.use_labels and y_emb is not None: + out += self.label_mlp(y_emb) + return out + + @staticmethod + def get_trainer(args=None): + return SAGNTrainer + + +# @register_trainer("sagn_trainer") +class SAGNTrainer(BaseTrainer): + @staticmethod + def add_args(parser): + parser.add_argument("--nstage", type=int, nargs="+", default=[1000, 500, 500]) + parser.add_argument("--batch-size", type=int, default=2000) + parser.add_argument( + "--threshold", type=float, default=0.9, help="threshold used to generate pseudo hard labels" + ) 
+ parser.add_argument("--label-nhop", type=int, default=4) + parser.add_argument("--data-gpu", action="store_true") + + @classmethod + def build_trainer_from_args(cls, args): + return cls(args) + + def __init__(self, args): + super(SAGNTrainer, self).__init__(args) + self.batch_size = args.batch_size + self.nstage = args.nstage + self.nhop = args.nhop + self.threshold = args.threshold + self.data_device = self.device if args.data_gpu else "cpu" + self.label_nhop = args.label_nhop if args.label_nhop > -1 else args.nhop + + def fit(self, model, dataset): + data = dataset.data + self.model = model.to(self.device) + self.optimizer = torch.optim.Adam(model.parameters(), lr=self.lr, weight_decay=self.weight_decay) + self.loss_fn = dataset.get_loss_fn() + self.evaluator = dataset.get_evaluator() + + data.to(self.data_device) + feats = prepare_feats(dataset, self.nhop) + + train_nid, val_nid, test_nid = data.train_nid, data.val_nid, data.test_nid + all_nid = torch.cat([train_nid, val_nid, test_nid]) + + val_loader = torch.utils.data.DataLoader(val_nid, batch_size=self.batch_size, shuffle=False) + test_loader = torch.utils.data.DataLoader(test_nid, batch_size=self.batch_size, shuffle=False) + all_loader = torch.utils.data.DataLoader(all_nid, batch_size=self.batch_size, shuffle=False) + patience = 0 + best_val = 0 + best_model = None + probs = None + + test_metric_list = [] + for stage in range(len(self.nstage)): + print(f"In stage {stage}..") + with torch.no_grad(): + (label_emb, labels_with_pseudos, train_nid_with_pseudos) = prepare_labels( + dataset, stage, self.label_nhop, self.threshold, probs=probs + ) + + labels_with_pseudos = labels_with_pseudos.to(self.data_device) + if label_emb is not None: + label_emb = label_emb.to(self.data_device) + + epoch_iter = tqdm(range(self.nstage[stage])) + for epoch in epoch_iter: + train_loader = torch.utils.data.DataLoader( + train_nid_with_pseudos.cpu(), batch_size=self.batch_size, shuffle=True + ) + self.train_step(train_loader, feats, label_emb, labels_with_pseudos) + val_loss, val_metric = self.test_step(val_loader, feats, label_emb, data.y[val_nid]) + if val_metric > best_val: + best_val = val_metric + best_model = copy.deepcopy(model) + patience = 0 + else: + patience += 1 + if patience > self.patience: + epoch_iter.close() + break + epoch_iter.set_description(f"Epoch: {epoch: 03d}, ValLoss: {val_loss: .4f}, ValAcc: {val_metric: .4f}") + temp_model = self.model + self.model = best_model + test_loss, test_acc = self.test_step(test_loader, feats, label_emb, data.y[test_nid]) + test_metric_list.append(round(test_acc, 4)) + + self.model = temp_model + probs = self.test_step(all_loader, feats, label_emb, data.y[all_nid], return_probs=True) + test_metric = ", ".join([str(x) for x in test_metric_list]) + print(test_metric) + + return dict(Acc=test_metric_list[-1]) + + def train_step(self, train_loader, feats, label_emb, y): + device = next(self.model.parameters()).device + self.model.train() + for batch in train_loader: + self.optimizer.zero_grad() + batch = batch.to(device) + batch_x = [x[batch].to(device) for x in feats] + + if label_emb is not None: + batch_y_emb = label_emb[batch].to(device) + else: + batch_y_emb = None + pred = self.model(batch_x, batch_y_emb) + loss = self.loss_fn(pred, y[batch].to(device)) + loss.backward() + self.optimizer.step() + + def test_step(self, eval_loader, feats, label_emb, y, return_probs=False): + self.model.eval() + preds = [] + + device = next(self.model.parameters()).device + with torch.no_grad(): + for batch in 
eval_loader: + batch = batch.to(device) + batch_x = [x[batch].to(device) for x in feats] + if label_emb is not None: + batch_y_emb = label_emb[batch].to(device) + else: + batch_y_emb = None + pred = self.model(batch_x, batch_y_emb) + preds.append(pred.to(self.data_device)) + preds = torch.cat(preds, dim=0) + if return_probs: + return preds + loss = self.loss_fn(preds, y) + metric = self.evaluator(preds, y) + return loss, metric diff --git a/cogdl/models/nn/self_auxiliary_task.py b/cogdl/models/nn/self_auxiliary_task.py index 31648acb..b42ce762 100644 --- a/cogdl/models/nn/self_auxiliary_task.py +++ b/cogdl/models/nn/self_auxiliary_task.py @@ -437,5 +437,5 @@ def get_parameters(self): return list(self.gcn.parameters()) + list(self.agent.linear.parameters()) @staticmethod - def get_trainer(task, args): + def get_trainer(args): return SelfSupervisedJointTrainer diff --git a/cogdl/models/nn/sign.py b/cogdl/models/nn/sign.py index ef11c256..f999fc72 100644 --- a/cogdl/models/nn/sign.py +++ b/cogdl/models/nn/sign.py @@ -1,60 +1,68 @@ +import os + import torch -import torch.nn.functional as F from .. import BaseModel, register_model -from cogdl.utils import ( - spmm, - dropout_adj, -) +from .mlp import MLP +from cogdl.utils import spmm, dropout_adj, to_undirected -def get_adj(graph, asymm_norm=False, set_diag=True, remove_diag=False): - if set_diag: - graph.add_remaining_self_loops() - elif remove_diag: +def get_adj(graph, remove_diag=False): + if remove_diag: graph.remove_self_loops() - if asymm_norm: - graph.row_norm() else: - graph.sym_norm() + graph.add_remaining_self_loops() return graph +def multi_hop_sgc(graph, x, nhop): + results = [] + for _ in range(nhop): + x = spmm(graph, x) + results.append(x) + return results + + +def multi_hop_ppr_diffusion(graph, x, nhop, alpha=0.5): + results = [] + for _ in range(nhop): + x = (1 - alpha) * x + spmm(graph, x) + results.append(x) + return results + + @register_model("sign") -class MLP(BaseModel): +class SIGN(BaseModel): @staticmethod def add_args(parser): """Add model-specific arguments to the parser.""" # fmt: off - parser.add_argument('--num-features', type=int) - parser.add_argument("--num-classes", type=int) - parser.add_argument('--hidden-size', type=int, default=512) - parser.add_argument('--num-layers', type=int, default=3) - parser.add_argument('--dropout', type=float, default=0.3) - parser.add_argument('--dropedge-rate', type=float, default=0.2) - - parser.add_argument('--directed', action='store_true') - parser.add_argument('--num-propagations', type=int, default=1) - parser.add_argument('--asymm-norm', action='store_true') - parser.add_argument('--set-diag', action='store_true') - parser.add_argument('--remove-diag', action='store_true') - + MLP.add_args(parser) + parser.add_argument("--dropedge-rate", type=float, default=0.2) + parser.add_argument("--directed", action="store_true") + parser.add_argument("--nhop", type=int, default=3) + parser.add_argument("--adj-norm", type=str, default=["sym"], nargs="+") + parser.add_argument("--remove-diag", action="store_true") + parser.add_argument("--diffusion", type=str, default="ppr") # fmt: on @classmethod def build_model_from_args(cls, args): + cls.dataset_name = args.dataset if hasattr(args, "dataset") else None return cls( args.num_features, args.hidden_size, args.num_classes, args.num_layers, args.dropout, - args.directed, args.dropedge_rate, - args.num_propagations, - args.asymm_norm, - args.set_diag, + args.nhop, + args.adj_norm, + args.diffusion, args.remove_diag, + not args.directed, + 
args.norm, + args.activation, ) def __init__( @@ -65,76 +73,94 @@ def __init__( num_layers, dropout, dropedge_rate, - undirected, - num_propagations, - asymm_norm, - set_diag, - remove_diag, + nhop, + adj_norm, + diffusion="ppr", + remove_diag=False, + undirected=True, + norm="batchnorm", + activation="relu", ): - - super(MLP, self).__init__() - - self.dropout = dropout + super(SIGN, self).__init__() self.dropedge_rate = dropedge_rate self.undirected = undirected - self.num_propagations = num_propagations - self.asymm_norm = asymm_norm - self.set_diag = set_diag + self.num_propagations = nhop + self.adj_norm = adj_norm self.remove_diag = remove_diag - - self.lins = torch.nn.ModuleList() - self.lins.append(torch.nn.Linear((1 + 2 * self.num_propagations) * num_features, hidden_size)) - self.bns = torch.nn.ModuleList() - self.bns.append(torch.nn.BatchNorm1d(hidden_size)) - for _ in range(num_layers - 2): - self.lins.append(torch.nn.Linear(hidden_size, hidden_size)) - self.bns.append(torch.nn.BatchNorm1d(hidden_size)) - self.lins.append(torch.nn.Linear(hidden_size, num_classes)) + self.diffusion = diffusion + + num_features = num_features * (1 + nhop * len(adj_norm)) + self.mlp = MLP( + in_feats=num_features, + out_feats=num_classes, + hidden_size=hidden_size, + num_layers=num_layers, + dropout=dropout, + activation=activation, + norm=norm, + ) self.cache_x = None - def reset_parameters(self): - for lin in self.lins: - lin.reset_parameters() - for bn in self.bns: - bn.reset_parameters() - - def _preprocessing(self, graph, x): - op_embedding = [] - op_embedding.append(x) - - edge_index = graph.edge_index + def _preprocessing(self, graph, x, drop_edge=False): + device = x.device + graph.to("cpu") + x = x.to("cpu") - # Convert to numpy arrays on cpu - edge_index, _ = dropout_adj(edge_index, drop_rate=self.dropedge_rate) + graph.eval() - # if self.undirected: - # edge_index = to_undirected(edge_index, num_nodes) + op_embedding = [x] - graph = get_adj(graph, asymm_norm=self.asymm_norm, set_diag=self.set_diag, remove_diag=self.remove_diag) - - with graph.local_graph(): - graph.edge_index = edge_index - for _ in range(self.num_propagations): - x = spmm(graph, x) - op_embedding.append(x) - - for _ in range(self.num_propagations): - nx = spmm(graph, x) - op_embedding.append(nx) - - return torch.cat(op_embedding, dim=1) + edge_index = graph.edge_index + if self.undirected: + edge_index = to_undirected(edge_index) + + if drop_edge: + edge_index, _ = dropout_adj(edge_index, drop_rate=self.dropedge_rate) + + graph = get_adj(graph, remove_diag=self.remove_diag) + + for norm in self.adj_norm: + with graph.local_graph(): + graph.edge_index = edge_index + graph.normalize(norm) + if self.diffusion == "ppr": + results = multi_hop_ppr_diffusion(graph, graph.x, self.num_propagations) + else: + results = multi_hop_sgc(graph, graph.x, self.num_propagations) + op_embedding.extend(results) + + graph.to(device) + return torch.cat(op_embedding, dim=1).to(device) + + def preprocessing(self, graph, x): + print("Preprocessing...") + dataset_name = None + if self.dataset_name is not None: + adj_norm = ",".join(self.adj_norm) + dataset_name = f"{self.dataset_name}_{self.num_propagations}_{self.diffusion}_{adj_norm}.pt" + if os.path.exists(dataset_name): + return torch.load(dataset_name).to(x.device) + if graph.is_inductive(): + graph.train() + x_train = self._preprocessing(graph, x, drop_edge=True) + graph.eval() + x_all = self._preprocessing(graph, x, drop_edge=False) + train_nid = graph.train_nid + x_all[train_nid] = 
x_train[train_nid] + else: + x_all = self._preprocessing(graph, x, drop_edge=False) + + if dataset_name is not None: + torch.save(x_all.cpu(), dataset_name) + print("Preprocessing Done...") + return x_all def forward(self, graph): if self.cache_x is None: - x = graph.x - self.cache_x = self._preprocessing(graph, x) + x = graph.x.contiguous() + self.cache_x = self.preprocessing(graph, x) x = self.cache_x - for i, lin in enumerate(self.lins[:-1]): - x = lin(x) - x = self.bns[i](x) - x = F.relu(x) - x = F.dropout(x, p=self.dropout, training=self.training) - x = self.lins[-1](x) + x = self.mlp(x) return x diff --git a/cogdl/models/nn/unsup_graphsage.py b/cogdl/models/nn/unsup_graphsage.py index dd2c2657..8604470e 100644 --- a/cogdl/models/nn/unsup_graphsage.py +++ b/cogdl/models/nn/unsup_graphsage.py @@ -133,5 +133,5 @@ def sampling(self, edge_index, num_sample): return sage_sampler(self.adjlist, edge_index, num_sample) @staticmethod - def get_trainer(taskType, args): + def get_trainer(args): return SelfSupervisedPretrainer diff --git a/cogdl/models/self_supervised_model.py b/cogdl/models/self_supervised_model.py index 30ea6430..7c118c30 100644 --- a/cogdl/models/self_supervised_model.py +++ b/cogdl/models/self_supervised_model.py @@ -8,7 +8,7 @@ def self_supervised_loss(self, data): raise NotImplementedError @staticmethod - def get_trainer(task, args): + def get_trainer(args): return None diff --git a/cogdl/models/supervised_model.py b/cogdl/models/supervised_model.py index 3714c280..1a1b0ca6 100644 --- a/cogdl/models/supervised_model.py +++ b/cogdl/models/supervised_model.py @@ -27,7 +27,7 @@ def evaluate(self, data: Any, nodes: Any, targets: Any) -> Any: raise NotImplementedError @staticmethod - def get_trainer(taskType: Any, args: Any) -> "Optional[Type[SupervisedHeterogeneousNodeClassificationTrainer]]": + def get_trainer(args: Any = None) -> "Optional[Type[SupervisedHeterogeneousNodeClassificationTrainer]]": return None @@ -41,8 +41,5 @@ def predict(self, data: Any) -> Any: raise NotImplementedError @staticmethod - def get_trainer( - taskType: Any, - args: Any, - ) -> "Optional[Type[SupervisedHomogeneousNodeClassificationTrainer]]": + def get_trainer(args: Any = None) -> "Optional[Type[SupervisedHomogeneousNodeClassificationTrainer]]": return None diff --git a/cogdl/options.py b/cogdl/options.py index ff0c7cb6..04ed2c05 100644 --- a/cogdl/options.py +++ b/cogdl/options.py @@ -129,7 +129,7 @@ def parse_args_and_arch(parser, args): TRAINER_REGISTRY[args.trainer].add_args(parser) else: for model in args.model: - tr = MODEL_REGISTRY[model].get_trainer(None, None) + tr = MODEL_REGISTRY[model].get_trainer(args) if tr is not None: tr.add_args(parser) # Parse a second time. 
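The recurring change in this diff is the get_trainer signature: the unused task argument is dropped and args is passed instead, so callers such as options.py above invoke MODEL_REGISTRY[model].get_trainer(args) directly. A small sketch of how a caller now resolves and builds a trainer; ExampleModel and ExampleTrainer are hypothetical stand-ins for registry entries, not code from this patch.

from typing import Optional, Type

class ExampleTrainer:                      # hypothetical stand-in for a CogDL trainer
    def __init__(self, args):
        self.args = args

    @classmethod
    def build_trainer_from_args(cls, args):
        return cls(args)

class ExampleModel:                        # hypothetical stand-in for a registered model
    @staticmethod
    def get_trainer(args=None) -> Optional[Type[ExampleTrainer]]:
        # models may inspect args (e.g. args.dataset) before choosing a trainer
        return ExampleTrainer

def resolve_trainer(model_cls, args):
    trainer_cls = model_cls.get_trainer(args)   # new single-argument form
    return trainer_cls(args) if trainer_cls is not None else None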
diff --git a/cogdl/pipelines.py b/cogdl/pipelines.py index 73857ab9..1c3b09d8 100644 --- a/cogdl/pipelines.py +++ b/cogdl/pipelines.py @@ -174,7 +174,7 @@ def __init__(self, app: str, model: str, **kwargs): args.model = args.model[0] self.model = build_model(args) - self.trainer = self.model.get_trainer(self.model, args) + self.trainer = self.model.get_trainer(args) if self.trainer is not None: self.trainer = self.trainer(args) diff --git a/cogdl/tasks/attributed_graph_clustering.py b/cogdl/tasks/attributed_graph_clustering.py index c8eb7c67..3041d24e 100644 --- a/cogdl/tasks/attributed_graph_clustering.py +++ b/cogdl/tasks/attributed_graph_clustering.py @@ -95,7 +95,7 @@ def train(self) -> Dict[str, float]: features_matrix = torch.tensor(features_matrix) features_matrix = F.normalize(features_matrix, p=2, dim=1) else: - trainer = self.model.get_trainer(AttributedGraphClustering, self.args)(self.args) + trainer = self.model.get_trainer(self.args)(self.args) self.model = trainer.fit(self.model, self.data) features_matrix = self.model.get_features(self.data) diff --git a/cogdl/tasks/base_task.py b/cogdl/tasks/base_task.py index 0a91a1a6..2b4de0a2 100644 --- a/cogdl/tasks/base_task.py +++ b/cogdl/tasks/base_task.py @@ -63,12 +63,12 @@ def set_loss_fn(self, dataset): def set_evaluator(self, dataset): self.evaluator = dataset.get_evaluator() - def get_trainer(self, model, args): + def get_trainer(self, args): if hasattr(args, "trainer") and args.trainer is not None: - if "self_auxiliary_task" in args.trainer and not hasattr(model, "embed"): + if "self_auxiliary_task" in args.trainer and not hasattr(self.model, "embed"): raise ValueError("Model ({}) must implement embed method".format(args.model)) return build_trainer(args) - elif model.get_trainer(None, args) is not None: - return model.get_trainer(None, args)(args) + elif self.model.get_trainer(args) is not None: + return self.model.get_trainer(args)(args) else: return None diff --git a/cogdl/tasks/heterogeneous_node_classification.py b/cogdl/tasks/heterogeneous_node_classification.py index 66d685f6..d8c933b2 100644 --- a/cogdl/tasks/heterogeneous_node_classification.py +++ b/cogdl/tasks/heterogeneous_node_classification.py @@ -41,11 +41,7 @@ def __init__(self, args, dataset=None, model=None): model = build_model(args) if model is None else model self.model: SupervisedHeterogeneousNodeClassificationModel = model.to(self.device) - self.trainer = ( - self.model.get_trainer(HeterogeneousNodeClassification, args)(self.args) - if self.model.get_trainer(HeterogeneousNodeClassification, args) - else None - ) + self.trainer = self.model.get_trainer(args)(args) if self.model.get_trainer(args) else None self.patience = args.patience self.max_epoch = args.max_epoch diff --git a/cogdl/tasks/node_classification.py b/cogdl/tasks/node_classification.py index 30eb3f69..50e6c539 100644 --- a/cogdl/tasks/node_classification.py +++ b/cogdl/tasks/node_classification.py @@ -51,7 +51,7 @@ def __init__( self.set_loss_fn(dataset) self.set_evaluator(dataset) - self.trainer = self.get_trainer(self.model, self.args) + self.trainer = self.get_trainer(self.args) if not self.trainer: self.optimizer = ( torch.optim.Adam(self.model.parameters(), lr=args.lr, weight_decay=args.weight_decay) @@ -110,8 +110,9 @@ def train(self): break print(f"Valid accurracy = {best_score: .4f}") self.model = best_model - test_acc, _ = self._test_step(split="test") - val_acc, _ = self._test_step(split="val") + acc, _ = self._test_step(post=True) + val_acc, test_acc = acc["val"], acc["test"] + 
print(f"Test accuracy = {test_acc:.4f}") return dict(Acc=test_acc, ValAcc=val_acc) @@ -123,11 +124,13 @@ def _train_step(self): torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5) self.optimizer.step() - def _test_step(self, split=None, logits=None): + def _test_step(self, split=None, post=False): self.data.eval() self.model.eval() with torch.no_grad(): - logits = logits if logits else self.model.predict(self.data) + logits = self.model.predict(self.data) + if post and hasattr(self.model, "postprocess"): + logits = self.model.postprocess(self.data, logits) if split == "train": mask = self.data.train_mask elif split == "val": @@ -142,7 +145,6 @@ def _test_step(self, split=None, logits=None): metric = self.evaluator(logits[mask], self.data.y[mask]) return metric, loss else: - masks = {x: self.data[x + "_mask"] for x in ["train", "val", "test"]} metrics = {key: self.evaluator(logits[mask], self.data.y[mask]) for key, mask in masks.items()} losses = {key: self.loss_fn(logits[mask], self.data.y[mask]) for key, mask in masks.items()} diff --git a/cogdl/tasks/unsupervised_node_classification.py b/cogdl/tasks/unsupervised_node_classification.py index 831b55c5..2478b73c 100644 --- a/cogdl/tasks/unsupervised_node_classification.py +++ b/cogdl/tasks/unsupervised_node_classification.py @@ -67,7 +67,7 @@ def __init__(self, args, dataset=None, model=None): self.is_weighted = self.data.edge_attr is not None self.device = "cpu" if not torch.cuda.is_available() or args.cpu else args.device_id[0] - self.trainer = self.get_trainer(self.model, args) + self.trainer = self.get_trainer(args) def enhance_emb(self, G, embs): A = sp.csr_matrix(nx.adjacency_matrix(G)) diff --git a/cogdl/trainers/base_trainer.py b/cogdl/trainers/base_trainer.py index 886cd504..fe4e3060 100644 --- a/cogdl/trainers/base_trainer.py +++ b/cogdl/trainers/base_trainer.py @@ -1,7 +1,22 @@ from abc import ABC, abstractmethod +import torch class BaseTrainer(ABC): + def __init__(self, args=None): + if args is not None: + device_id = args.device_id if hasattr(args, "device_id") else [0] + self.device = ( + "cpu" if not torch.cuda.is_available() or (hasattr(args, "cpu") and args.cpu) else device_id[0] + ) + self.patience = args.patience if hasattr(args, "patience") else 10 + self.max_epoch = args.max_epoch if hasattr(args, "max_epoch") else 100 + self.lr = args.lr + self.weight_decay = args.weight_decay + self.loss_fn, self.evaluator = None, None + self.data, self.train_loader, self.optimizer = None, None, None + self.num_workers = args.num_workers if hasattr(args, "num_workers") else 0 + @classmethod @abstractmethod def build_trainer_from_args(cls, args): diff --git a/cogdl/trainers/sampled_trainer.py b/cogdl/trainers/sampled_trainer.py index 10403666..90e3ec99 100644 --- a/cogdl/trainers/sampled_trainer.py +++ b/cogdl/trainers/sampled_trainer.py @@ -39,6 +39,7 @@ def _test_step(self, split="val"): raise NotImplementedError def __init__(self, args): + super(SampledTrainer, self).__init__(args) self.device = "cpu" if not torch.cuda.is_available() or args.cpu else args.device_id[0] self.patience = args.patience self.max_epoch = args.max_epoch diff --git a/cogdl/utils/graph_utils.py b/cogdl/utils/graph_utils.py index 92ab1213..558c9e47 100644 --- a/cogdl/utils/graph_utils.py +++ b/cogdl/utils/graph_utils.py @@ -72,7 +72,7 @@ def add_remaining_self_loops(edge_index, edge_weight=None, fill_value=1, num_nod def row_normalization(num_nodes, row, col, val=None): device = row.device if val is None: - val = torch.ones(row.shape[0]).to(device) + 
val = torch.ones(row.shape[0], device=device) row_sum = get_degrees(row, col, num_nodes) row_sum_inv = row_sum.pow(-1).view(-1) row_sum_inv[torch.isinf(row_sum_inv)] = 0 diff --git a/tests/datasets/test_ogb.py b/tests/datasets/test_ogb.py index 51a3997c..5a60fd56 100644 --- a/tests/datasets/test_ogb.py +++ b/tests/datasets/test_ogb.py @@ -8,7 +8,6 @@ def test_ogbn_arxiv(): dataset = build_dataset(args) data = dataset.data assert data.num_nodes == 169343 - assert data.num_edges == 1136420 def test_ogbg_molhiv(): diff --git a/tests/tasks/test_node_classification.py b/tests/tasks/test_node_classification.py index 5e6d7658..2a805d39 100644 --- a/tests/tasks/test_node_classification.py +++ b/tests/tasks/test_node_classification.py @@ -122,7 +122,7 @@ def test_graphsage_cora(): args.batch_size = 128 args.num_layers = 2 args.patience = 1 - args.max_epoch = 5 + args.max_epoch = 2 args.hidden_size = [32, 32] args.sample_size = [3, 5] args.num_workers = 1 @@ -283,7 +283,7 @@ def test_graph_mix(): args.task = "node_classification" args.dataset = "cora" args.model = "gcnmix" - args.max_epoch = 10 + args.max_epoch = 2 args.rampup_starts = 1 args.rampup_ends = 100 args.mixup_consistency = 5.0 @@ -362,7 +362,7 @@ def test_deepergcn_cora(): args.num_layers = 2 args.connection = "res+" args.cluster_number = 3 - args.max_epoch = 10 + args.max_epoch = 2 args.patience = 1 args.learn_beta = True args.learn_msg_scale = True @@ -436,16 +436,22 @@ def test_sign_cora(): args.lr = 0.00005 args.hidden_size = 2048 args.num_layers = 3 - args.num_propagations = 3 + args.nhop = 3 args.dropout = 0.3 args.directed = False args.dropedge_rate = 0.2 - args.asymm_norm = False - args.set_diag = False + args.adj_norm = [ + "sym", + ] args.remove_diag = False + args.diffusion = "ppr" + task = build_task(args) ret = task.train() assert 0 < ret["Acc"] < 1 + args.diffusion = "sgc" + ret = task.train() + assert 0 < ret["Acc"] < 1 def test_jknet_jknet_cora(): @@ -633,6 +639,52 @@ def test_gcn_ppi(): assert 0 <= task.train()["Acc"] <= 1 +def build_custom_dataset(): + args = get_default_args() + args.dataset = "cora" + dataset = build_dataset(args) + dataset.data._adj_train = dataset.data._adj_full + return dataset + + +def test_sagn_cora(): + args = get_default_args() + dataset = build_custom_dataset() + args.model = "sagn" + args.nhop = args.label_nhop = 2 + args.threshold = 0.5 + args.use_labels = True + args.nstage = [2, 2] + args.batch_size = 32 + args.data_gpu = False + args.attn_drop = 0.0 + args.input_drop = 0.0 + args.nhead = 2 + args.negative_slope = 0.2 + args.mlp_layer = 2 + task = build_task(args, dataset=dataset) + assert 0 <= task.train()["Acc"] <= 1 + + +def test_c_s_cora(): + args = get_default_args() + args.use_embeddings = True + args.correct_alpha = 0.5 + args.smooth_alpha = 0.5 + args.num_correct_prop = 2 + args.num_smooth_prop = 2 + args.correct_norm = "sym" + args.smooth_norm = "sym" + args.scale = 1.0 + args.autoscale = True + args.dataset = "cora" + args.model = "correct_smooth_mlp" + task = build_task(args) + assert 0 <= task.train()["Acc"] <= 1 + args.autoscale = False + assert 0 <= task.train()["Acc"] <= 1 + + if __name__ == "__main__": test_gdc_gcn_cora() test_gcn_cora()
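For reference alongside the new tests, a dense, self-contained sketch of the propagation loop used by outcome_correlation in cogdl/models/nn/correct_smooth.py; the toy adjacency and seed labels below are assumptions for illustration only.

import torch

def propagate(adj, labels, alpha, nprop):
    # result <- alpha * A @ result + (1 - alpha) * labels, clamped each step,
    # matching outcome_correlation with the autoscale_post(lower=0, upper=1) post step.
    result = labels.clone()
    for _ in range(nprop):
        result = alpha * adj @ result + (1 - alpha) * labels
        result = result.clamp(0.0, 1.0)
    return result

if __name__ == "__main__":
    adj = torch.tensor([[0.0, 1.0, 0.0],
                        [0.5, 0.0, 0.5],
                        [0.0, 1.0, 0.0]])        # row-normalized chain graph
    seeds = torch.tensor([[1.0, 0.0],
                          [0.0, 0.0],
                          [0.0, 1.0]])           # one-hot labels on nodes 0 and 2
    print(propagate(adj, seeds, alpha=0.8, nprop=10))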