From 35a338f29066e4b1a5d7f46217f09ebceaf13106 Mon Sep 17 00:00:00 2001
From: Yukuo Cen
Date: Tue, 24 Aug 2021 15:21:30 +0800
Subject: [PATCH] [Bugfix] Fix MoEGCN & actnn import (#271)

---
 cogdl/models/nn/moe_gcn.py | 94 +++++---------------------------------
 cogdl/operators/spmm.py    |  6 ++-
 2 files changed, 16 insertions(+), 84 deletions(-)

diff --git a/cogdl/models/nn/moe_gcn.py b/cogdl/models/nn/moe_gcn.py
index 6e524c84..060b7a65 100644
--- a/cogdl/models/nn/moe_gcn.py
+++ b/cogdl/models/nn/moe_gcn.py
@@ -1,15 +1,10 @@
-import math
-
-import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.nn.parameter import Parameter
+from cogdl.layers import GCNLayer
+from cogdl.utils import get_activation
+from fmoe import FMoETransformerMLP
 
 from .. import BaseModel, register_model
-from cogdl.utils import spmm, get_activation
-
-
-from fmoe import FMoETransformerMLP
 
 
 class CustomizedMoEPositionwiseFF(FMoETransformerMLP):
@@ -33,74 +28,11 @@ def forward(self, inp):
 
         return output
 
 
-class GraphConv(nn.Module):
-    """
-    Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
-    """
-
-    def __init__(self, in_features, out_features, dropout=0.0, residual=False, activation=None, norm=None, bias=True):
-        super(GraphConv, self).__init__()
-        self.in_features = in_features
-        self.out_features = out_features
-        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
-        self.act = get_activation(activation)
-        if dropout > 0:
-            self.dropout = nn.Dropout(dropout)
-        else:
-            self.dropout = None
-        if residual:
-            self.residual = nn.Linear(in_features, out_features)
-        else:
-            self.residual = None
-
-        if norm is not None:
-            if norm == "batchnorm":
-                self.norm = nn.BatchNorm1d(out_features)
-            elif norm == "layernorm":
-                self.norm = nn.LayerNorm(out_features)
-            else:
-                raise NotImplementedError
-        else:
-            self.norm = None
-
-        if bias:
-            self.bias = Parameter(torch.FloatTensor(out_features))
-        else:
-            self.register_parameter("bias", None)
-        self.reset_parameters()
-
-    def reset_parameters(self):
-        stdv = 1.0 / math.sqrt(self.weight.size(1))
-        self.weight.data.uniform_(-stdv, stdv)
-        if self.bias is not None:
-            self.bias.data.zero_()
-
-    def forward(self, graph, x):
-        support = torch.mm(x, self.weight)
-        out = spmm(graph, support)
-        if self.bias is not None:
-            out = out + self.bias
-
-        if self.residual is not None:
-            res = self.residual(x)
-            if self.act is not None:
-                res = self.act(res)
-            out = out + res
-
-        if self.dropout is not None:
-            out = self.dropout(out)
-
-        if self.norm is not None:
-            out = self.norm(out)
-
-        return out
-
-
 class GraphConvBlock(nn.Module):
-    def __init__(self, conv_func, conv_params, out_feats, dropout=0.0, in_feats=None, residual=False):
+    def __init__(self, conv_func, conv_params, in_feats, out_feats, dropout=0.0, residual=False):
         super(GraphConvBlock, self).__init__()
-        self.graph_conv = conv_func(**conv_params)
+        self.graph_conv = conv_func(**conv_params, in_features=in_feats, out_features=out_feats)
         self.pos_ff = CustomizedMoEPositionwiseFF(out_feats, out_feats * 2, dropout, moe_num_expert=64, moe_top_k=2)
         self.dropout = dropout
         if residual is True:
@@ -170,11 +102,9 @@ def __init__(
         self, in_feats, hidden_size, out_feats, num_layers, dropout, activation="relu", residual=True, norm=None
     ):
         super(MoEGCN, self).__init__()
-        shapes = [in_feats] + [hidden_size] * (num_layers - 1) + [out_feats]
-        conv_func = GraphConv
+        shapes = [in_feats] + [hidden_size] * num_layers
+        conv_func = GCNLayer
         conv_params = {
-            "in_features": in_feats,
-            "out_features": out_feats,
             "dropout": dropout,
             "norm": norm,
             "residual": residual,
@@ -185,8 +115,9 @@ def __init__(
                 GraphConvBlock(
                     conv_func,
                     conv_params,
+                    shapes[i],
                     shapes[i + 1],
-                    dropout=dropout if i != num_layers - 1 else 0,
+                    dropout=dropout,
                 )
                 for i in range(num_layers)
             ]
@@ -194,13 +125,12 @@ def __init__(
         self.num_layers = num_layers
         self.dropout = dropout
         self.act = get_activation(activation)
+        self.final_cls = nn.Linear(hidden_size, out_feats)
 
     def get_embeddings(self, graph):
         graph.sym_norm()
-
         h = graph.x
         for i in range(self.num_layers - 1):
-            h = F.dropout(h, self.dropout, training=self.training)
             h = self.layers[i](graph, h)
         return h
 
@@ -209,9 +139,7 @@ def forward(self, graph):
         h = graph.x
         for i in range(self.num_layers):
             h = self.layers[i](graph, h)
-            if i != self.num_layers - 1:
-                h = self.act(h)
-                h = F.dropout(h, self.dropout, training=self.training)
+        h = self.final_cls(h)
         return h
 
     def predict(self, data):
diff --git a/cogdl/operators/spmm.py b/cogdl/operators/spmm.py
index 7ad0a7b7..948622c5 100644
--- a/cogdl/operators/spmm.py
+++ b/cogdl/operators/spmm.py
@@ -2,7 +2,6 @@
 import numpy as np
 import torch
 from torch.utils.cpp_extension import load
-from actnn.ops import quantize_activation, dequantize_activation
 
 path = os.path.join(os.path.dirname(__file__))
 
@@ -65,6 +64,11 @@ def backward(ctx, grad_out):
         return None, None, grad_feat, grad_edge_weight, None
 
 
+try:
+    from actnn.ops import quantize_activation, dequantize_activation
+except Exception:
+    pass
+
 class ActSPMMFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, rowptr, colind, feat, edge_weight_csr=None, sym=False):