-
Notifications
You must be signed in to change notification settings - Fork 19
/
Copy pathMyConvLSTMCell.py
58 lines (48 loc) · 2.84 KB
/
MyConvLSTMCell.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
class MyConvLSTMCell(nn.Module):
    """Single convolutional LSTM cell.

    Every gate is computed from two 2-D convolutions: one over the input
    ``x`` (with bias) and one over the previous hidden state (without bias).
    With ``*`` denoting convolution, one step computes::

        i  = sigmoid(conv_i_xx * x + conv_i_hh * h_prev)
        f  = sigmoid(conv_f_xx * x + conv_f_hh * h_prev)
        c~ = tanh(conv_c_xx * x + conv_c_hh * h_prev)
        c  = c~ * i + c_prev * f
        o  = sigmoid(conv_o_xx * x + conv_o_hh * h_prev)
        h  = o * tanh(c)

    Args:
        input_size: number of channels of the input ``x``.
        hidden_size: number of channels of the hidden and cell states.
        kernel_size: kernel size shared by all eight convolutions.
        stride: stride shared by all eight convolutions.
        padding: padding shared by all eight convolutions.
    """

    def __init__(self, input_size, hidden_size, kernel_size=3, stride=1, padding=1):
        super(MyConvLSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

        conv_kwargs = dict(kernel_size=kernel_size, stride=stride, padding=padding)

        # Attribute names and creation order are kept stable so that existing
        # checkpoints (state_dict keys) and seeded initialisation still match.
        # Input gate.
        self.conv_i_xx = nn.Conv2d(input_size, hidden_size, **conv_kwargs)
        self.conv_i_hh = nn.Conv2d(hidden_size, hidden_size, bias=False, **conv_kwargs)
        # Forget gate.
        self.conv_f_xx = nn.Conv2d(input_size, hidden_size, **conv_kwargs)
        self.conv_f_hh = nn.Conv2d(hidden_size, hidden_size, bias=False, **conv_kwargs)
        # Candidate cell state.
        self.conv_c_xx = nn.Conv2d(input_size, hidden_size, **conv_kwargs)
        self.conv_c_hh = nn.Conv2d(hidden_size, hidden_size, bias=False, **conv_kwargs)
        # Output gate.
        self.conv_o_xx = nn.Conv2d(input_size, hidden_size, **conv_kwargs)
        self.conv_o_hh = nn.Conv2d(hidden_size, hidden_size, bias=False, **conv_kwargs)

        self._init_gate(self.conv_i_xx, self.conv_i_hh)
        self._init_gate(self.conv_f_xx, self.conv_f_hh)
        self._init_gate(self.conv_c_xx, self.conv_c_hh)
        self._init_gate(self.conv_o_xx, self.conv_o_hh)

    @staticmethod
    def _init_gate(conv_xx, conv_hh):
        """Xavier-normal init for both weights of a gate; zero the input-conv bias."""
        # The trailing-underscore (in-place) initialisers replace the deprecated
        # ``xavier_normal`` / ``constant`` spellings.
        torch.nn.init.xavier_normal_(conv_xx.weight)
        torch.nn.init.constant_(conv_xx.bias, 0)
        torch.nn.init.xavier_normal_(conv_hh.weight)

    def _random_state(self, x):
        """Draw a standard-normal (hidden, cell) state pair matching ``x``.

        The state has ``hidden_size`` channels (what the ``conv_*_hh`` layers
        consume) and lives on the same device / dtype as ``x`` so the cell
        works on CPU as well as on any GPU.  The spatial size is copied from
        ``x``, which is only consistent when the convolutions preserve the
        spatial size (as the default kernel_size=3, stride=1, padding=1 do).
        """
        shape = (x.size(0), self.hidden_size, x.size(2), x.size(3))
        return (torch.randn(shape, dtype=x.dtype, device=x.device),
                torch.randn(shape, dtype=x.dtype, device=x.device))

    def forward(self, x, state=None):
        """Run one ConvLSTM step.

        Args:
            x: input tensor; indexed as 4-D (sizes 0..3 are read) and fed to
                ``nn.Conv2d`` layers expecting ``input_size`` channels.
            state: ``(h_prev, c_prev)`` tuple, or ``None`` to start from a
                randomly drawn (standard normal) state.

        Returns:
            Tuple ``(ht, ct)`` with the new hidden state and cell state.
        """
        if state is None:
            state = self._random_state(x)
        ht_1, ct_1 = state

        it = torch.sigmoid(self.conv_i_xx(x) + self.conv_i_hh(ht_1))
        ft = torch.sigmoid(self.conv_f_xx(x) + self.conv_f_hh(ht_1))
        ct_tilde = torch.tanh(self.conv_c_xx(x) + self.conv_c_hh(ht_1))
        ct = (ct_tilde * it) + (ct_1 * ft)
        ot = torch.sigmoid(self.conv_o_xx(x) + self.conv_o_hh(ht_1))
        ht = ot * torch.tanh(ct)
        return ht, ct