Skip to content

Commit

Permalink
working LSTM (bidirectional untested)
Browse files Browse the repository at this point in the history
  • Loading branch information
ChujieChen committed Apr 25, 2020
1 parent 1f46560 commit 3ce6f2a
Showing 1 changed file with 38 additions and 43 deletions.
81 changes: 38 additions & 43 deletions src/yews/models/polarity.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import torch
import torch.nn as nn

from .utils import load_state_dict_from_url

__all__ = ['PolarityV1', 'polarity_v1', 'PolarityCCJ', 'polarity_ccj']
__all__ = ['PolarityV1', 'polarity_v1', 'PolarityLSTM', 'polarity_lstm']

model_urls = {
'polarity_v1': 'https://www.dropbox.com/s/ckb4glf35agi9xa/polarity_v1_wenchuan-bdd92da2.pth?dl=1',
Expand Down Expand Up @@ -100,6 +101,9 @@ def __init__(self):
self.fc = nn.Linear(64 * 1, 3)

def forward(self, x):
print("##### input to forward #####")
print(x.shape)

out = self.layer1(x)
out = self.layer2(out)
out = self.layer3(out)
Expand All @@ -110,14 +114,8 @@ def forward(self, x):
out = self.layer8(out)
out = self.layer9(out)
out = self.layer10(out)
print("##### before view #####")
print(out.shape)
out = out.view(out.size(0), -1)
print("##### before fc #####")
print(out.shape)
out = self.fc(out)
print("##### after fc #####")
print(out.shape)
return out

def polarity_v1(pretrained=False, progress=True, **kwargs):
Expand All @@ -138,52 +136,49 @@ def polarity_v1(pretrained=False, progress=True, **kwargs):
return model


class PolarityLSTM(nn.Module):
    r"""LSTM classifier applied to a fixed window of a single-channel input.

    The input is cropped to ``[start, end)`` along its last axis, fed sample by
    sample to a single-layer ``nn.LSTM`` (``input_size=1``), and the final
    hidden state of every direction is passed to one linear layer.

    @author: Chujie Chen
    @Email: [email protected]
    @date: 04/24/2020

    Keyword Args:
        hidden_size (int): Size of the LSTM hidden state. Default: 64.
        bidirectional (bool): Use a bidirectional LSTM. Default: False.
        contains_unkown (bool): If True the linear layer has 3 outputs,
            otherwise 2. The key spelling is kept as-is because callers pass
            it by name. Default: False.
        start (int): First index of the window on the last axis. Default: 250.
        end (int): One past the last index of the window. Default: 350.

    Raises:
        ValueError: If ``end`` is not greater than ``start``.
    """

    def __init__(self, **kwargs):
        super().__init__()
        input_size = 1
        self.hidden_size = kwargs.get("hidden_size", 64)
        self.bidirectional = kwargs.get("bidirectional", False)
        self.contains_unkown = kwargs.get("contains_unkown", False)
        self.start = kwargs.get("start", 250)
        self.end = kwargs.get("end", 350)
        # A zero-length or negative window can never be fed to the LSTM, so
        # fail at construction time instead of on the first forward pass.
        if self.end <= self.start:
            raise ValueError(
                "end ({}) must be greater than start ({})".format(
                    self.end, self.start))

        num_directions = 2 if self.bidirectional else 1
        num_classes = 3 if self.contains_unkown else 2
        self.lstm = nn.LSTM(input_size, self.hidden_size,
                            bidirectional=self.bidirectional)
        self.fc = nn.Linear(self.hidden_size * num_directions, num_classes)

    def forward(self, x):
        r"""Classify a batch of inputs.

        Args:
            x (Tensor): Tensor of shape ``(batch, 1, seq_len)`` whose last
                axis is at least ``end`` long.

        Returns:
            Tensor: Scores of shape ``(batch, 3)`` when ``contains_unkown`` is
            True, ``(batch, 2)`` otherwise.
        """
        # Keep only the [start, end) window of the last axis.
        x = x.narrow(2, self.start, self.end - self.start)
        x = x.permute(2, 0, 1)  # seq_len, batch, input_size
        # If (h_0, c_0) is not provided, both h_0 and c_0 default to zero.
        _, (h_n, _) = self.lstm(x)
        # h_n is (num_directions, batch, hidden_size) and holds the state of
        # each direction after it has consumed the whole window. Reading the
        # last time step of the output sequence instead would give, for the
        # reverse direction, a state that has only seen the final sample.
        features = h_n.transpose(0, 1).reshape(h_n.size(1), -1)
        return self.fc(features)

    def initHidden(self):
        r"""Return a zero tensor of shape ``(1, hidden_size)``.

        Not used by :meth:`forward`, which lets the LSTM start from its
        default zero state.
        """
        return torch.zeros(1, self.hidden_size)

def polarity_lstm(**kwargs):
    r"""Build a :class:`PolarityLSTM` from defaults overridden by ``kwargs``.

    No pretrained weights are loaded by this function.

    Args:
        hidden_size (int): Size of the LSTM hidden state. Default: 64.
        start (int): First index of the input window. Default: 250.
        end (int): One past the last index of the input window. Default: 350.
        bidirectional (bool): Use a bidirectional LSTM. Default: False.
        contains_unkown (bool): If True the model has 3 outputs, otherwise 2.
            Default: False.
        **kwargs: Any keyword other than the five above is ignored.

    Returns:
        PolarityLSTM: The constructed model.

    Raises:
        ValueError: If ``end`` is not greater than ``start``.
    """
    default_kwargs = {"hidden_size": 64,
                      "start": 250,
                      "end": 350,
                      "bidirectional": False,
                      "contains_unkown": False}
    # Only recognised keys override the defaults; everything else is dropped.
    default_kwargs.update(
        (key, value) for key, value in kwargs.items() if key in default_kwargs)
    print("#### model parameters ####\n")
    print(default_kwargs)
    print("\n##########################")
    # end == start would select a zero-length window, which the LSTM cannot
    # process, so it is rejected together with end < start.
    if default_kwargs['end'] <= default_kwargs['start']:
        raise ValueError('<-- end must be largger than start -->')
    return PolarityLSTM(**default_kwargs)

0 comments on commit 3ce6f2a

Please sign in to comment.