# layer.py
import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class GraphConvolution(nn.Module):
    """Graph convolution over a fixed adjacency matrix.

    Expects `x` of shape (batch, time, num_vetex, input_dim) and a square
    adjacency matrix `adj` of shape (num_vetex, num_vetex). `num_vetex` is
    accepted for interface compatibility and is not used directly.
    """

    def __init__(self, input_dim, output_dim, num_vetex, act=F.relu, dropout=0.5, bias=True):
        super(GraphConvolution, self).__init__()

        self.alpha = 1.
        self.act = act
        self.dropout = nn.Dropout(dropout)
        # Create the parameters directly on the target device so they stay
        # registered with the module (calling .to(device) on an nn.Parameter
        # returns a plain, unregistered tensor when it actually moves devices).
        self.weight = nn.Parameter(torch.randn(input_dim, output_dim, device=device))
        if bias:
            self.bias = nn.Parameter(torch.randn(output_dim, device=device))
        else:
            self.bias = None

        for w in [self.weight]:
            nn.init.xavier_normal_(w)

    def normalize(self, m):
        # Symmetric normalization D^{-1/2} M D^{-1/2}, with D the diagonal
        # degree matrix of m.
        rowsum = torch.sum(m, 0)
        r_inv = torch.pow(rowsum, -0.5)
        r_mat_inv = torch.diag(r_inv).float()
        m_norm = torch.mm(r_mat_inv, m)
        m_norm = torch.mm(m_norm, r_mat_inv)

        return m_norm

    def forward(self, adj, x):
        x = self.dropout(x)

        # K-ordered Chebyshev polynomial: mix the normalized adjacency with
        # its normalized square (alpha = 1. keeps only the first-order term).
        adj_norm = self.normalize(adj)
        sqr_norm = self.normalize(torch.mm(adj, adj))
        m_norm = self.alpha * adj_norm + (1. - self.alpha) * sqr_norm

        x_tmp = torch.einsum('abcd,de->abce', x, self.weight)
        x_out = torch.einsum('ij,abid->abjd', m_norm, x_tmp)
        if self.bias is not None:
            x_out += self.bias

        x_out = self.act(x_out)

        return x_out
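
# A minimal shape sketch for GraphConvolution (the sizes below are
# illustrative assumptions, not taken from the original training pipeline):
#
#   gc = GraphConvolution(input_dim=3, output_dim=64, num_vetex=15)
#   adj = torch.rand(15, 15, device=device)       # (V, V) adjacency
#   x = torch.rand(8, 32, 15, 3, device=device)   # (batch, time, V, input_dim)
#   out = gc(adj, x)                              # -> (8, 32, 15, 64)
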
class StandConvolution(nn.Module):
    """Stack of strided 2-D convolutions followed by a linear classifier."""

    def __init__(self, dims, num_classes, dropout):
        super(StandConvolution, self).__init__()

        self.dropout = nn.Dropout(dropout)
        self.conv = nn.Sequential(
            nn.Conv2d(dims[0], dims[1], kernel_size=5, stride=2),
            nn.InstanceNorm2d(dims[1]),
            nn.ReLU(inplace=True),
            # nn.AvgPool2d(3, stride=2),
            nn.Conv2d(dims[1], dims[2], kernel_size=5, stride=2),
            nn.InstanceNorm2d(dims[2]),
            nn.ReLU(inplace=True),
            # nn.AvgPool2d(3, stride=2),
            nn.Conv2d(dims[2], dims[3], kernel_size=5, stride=2),
            nn.InstanceNorm2d(dims[3]),
            nn.ReLU(inplace=True),
            # nn.AvgPool2d(3, stride=2)
        ).to(device)
        # The classifier assumes the flattened conv output has exactly
        # 3 spatial positions per channel (hence dims[3] * 3 input features).
        self.fc = nn.Linear(dims[3] * 3, num_classes).to(device)

    def forward(self, x):
        # x arrives channels-last: (batch, height, width, channels).
        x = self.dropout(x.permute(0, 3, 1, 2))
        x_tmp = self.conv(x)
        x_out = self.fc(x_tmp.view(x.size(0), -1))

        return x_out
class StandRecurrent(nn.Module):
    """Single-layer LSTM over per-frame features, classified from the last step."""

    def __init__(self, dims, num_classes, dropout):
        super(StandRecurrent, self).__init__()

        self.lstm = nn.LSTM(dims[0] * 45, dims[1], batch_first=True,
                            dropout=0).to(device)
        self.fc = nn.Linear(dims[1], num_classes).to(device)

    def forward(self, x):
        # Flatten each time step to a single feature vector before the LSTM.
        x_tmp, _ = self.lstm(x.contiguous().view(x.size(0), x.size(1), -1))
        x_out = self.fc(x_tmp[:, -1])

        return x_out
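

if __name__ == "__main__":
    # Quick smoke test with illustrative shapes; the batch size, frame count,
    # and vertex counts below are assumptions for demonstration only.
    B, T, V = 8, 32, 15

    gc = GraphConvolution(input_dim=3, output_dim=64, num_vetex=V)
    adj = torch.rand(V, V, device=device)
    feats = torch.rand(B, T, V, 3, device=device)
    print(gc(adj, feats).shape)   # torch.Size([8, 32, 15, 64])

    # StandRecurrent flattens each frame to dims[0] * 45 features, so this
    # example uses 45 vertices with dims[0] = 3 channels per vertex.
    rec = StandRecurrent(dims=[3, 64], num_classes=10, dropout=0.5)
    seq = torch.rand(B, T, 45, 3, device=device)
    print(rec(seq).shape)         # torch.Size([8, 10])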