Skip to content

Commit

Permalink
[Bugfix] Fix MoEGCN & actnn import (#271)
Browse files Browse the repository at this point in the history
  • Loading branch information
cenyk1230 authored Aug 24, 2021
1 parent 481c08f commit 35a338f
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 84 deletions.
94 changes: 11 additions & 83 deletions cogdl/models/nn/moe_gcn.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from cogdl.layers import GCNLayer
from cogdl.utils import get_activation
from fmoe import FMoETransformerMLP

from .. import BaseModel, register_model
from cogdl.utils import spmm, get_activation


from fmoe import FMoETransformerMLP


class CustomizedMoEPositionwiseFF(FMoETransformerMLP):
Expand All @@ -33,74 +28,11 @@ def forward(self, inp):
return output


class GraphConv(nn.Module):
"""
Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
"""

def __init__(self, in_features, out_features, dropout=0.0, residual=False, activation=None, norm=None, bias=True):
super(GraphConv, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = Parameter(torch.FloatTensor(in_features, out_features))
self.act = get_activation(activation)
if dropout > 0:
self.dropout = nn.Dropout(dropout)
else:
self.dropout = None
if residual:
self.residual = nn.Linear(in_features, out_features)
else:
self.residual = None

if norm is not None:
if norm == "batchnorm":
self.norm = nn.BatchNorm1d(out_features)
elif norm == "layernorm":
self.norm = nn.LayerNorm(out_features)
else:
raise NotImplementedError
else:
self.norm = None

if bias:
self.bias = Parameter(torch.FloatTensor(out_features))
else:
self.register_parameter("bias", None)
self.reset_parameters()

def reset_parameters(self):
stdv = 1.0 / math.sqrt(self.weight.size(1))
self.weight.data.uniform_(-stdv, stdv)
if self.bias is not None:
self.bias.data.zero_()

def forward(self, graph, x):
support = torch.mm(x, self.weight)
out = spmm(graph, support)
if self.bias is not None:
out = out + self.bias

if self.residual is not None:
res = self.residual(x)
if self.act is not None:
res = self.act(res)
out = out + res

if self.dropout is not None:
out = self.dropout(out)

if self.norm is not None:
out = self.norm(out)

return out


class GraphConvBlock(nn.Module):
def __init__(self, conv_func, conv_params, out_feats, dropout=0.0, in_feats=None, residual=False):
def __init__(self, conv_func, conv_params, in_feats, out_feats, dropout=0.0, residual=False):
super(GraphConvBlock, self).__init__()

self.graph_conv = conv_func(**conv_params)
self.graph_conv = conv_func(**conv_params, in_features=in_feats, out_features=out_feats)
self.pos_ff = CustomizedMoEPositionwiseFF(out_feats, out_feats * 2, dropout, moe_num_expert=64, moe_top_k=2)
self.dropout = dropout
if residual is True:
Expand Down Expand Up @@ -170,11 +102,9 @@ def __init__(
self, in_feats, hidden_size, out_feats, num_layers, dropout, activation="relu", residual=True, norm=None
):
super(MoEGCN, self).__init__()
shapes = [in_feats] + [hidden_size] * (num_layers - 1) + [out_feats]
conv_func = GraphConv
shapes = [in_feats] + [hidden_size] * num_layers
conv_func = GCNLayer
conv_params = {
"in_features": in_feats,
"out_features": out_feats,
"dropout": dropout,
"norm": norm,
"residual": residual,
Expand All @@ -185,22 +115,22 @@ def __init__(
GraphConvBlock(
conv_func,
conv_params,
shapes[i],
shapes[i + 1],
dropout=dropout if i != num_layers - 1 else 0,
dropout=dropout,
)
for i in range(num_layers)
]
)
self.num_layers = num_layers
self.dropout = dropout
self.act = get_activation(activation)
self.final_cls = nn.Linear(hidden_size, out_feats)

def get_embeddings(self, graph):
graph.sym_norm()

h = graph.x
for i in range(self.num_layers - 1):
h = F.dropout(h, self.dropout, training=self.training)
h = self.layers[i](graph, h)
return h

Expand All @@ -209,9 +139,7 @@ def forward(self, graph):
h = graph.x
for i in range(self.num_layers):
h = self.layers[i](graph, h)
if i != self.num_layers - 1:
h = self.act(h)
h = F.dropout(h, self.dropout, training=self.training)
h = self.final_cls(h)
return h

def predict(self, data):
Expand Down
6 changes: 5 additions & 1 deletion cogdl/operators/spmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import numpy as np
import torch
from torch.utils.cpp_extension import load
from actnn.ops import quantize_activation, dequantize_activation

path = os.path.join(os.path.dirname(__file__))

Expand Down Expand Up @@ -65,6 +64,11 @@ def backward(ctx, grad_out):
return None, None, grad_feat, grad_edge_weight, None


try:
from actnn.ops import quantize_activation, dequantize_activation
except Exception:
pass

class ActSPMMFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, rowptr, colind, feat, edge_weight_csr=None, sym=False):
Expand Down

0 comments on commit 35a338f

Please sign in to comment.