forked from yewsg/yews
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
working LSTM (bidirectional untested)
- Loading branch information
1 parent
1f46560
commit 3ce6f2a
Showing
1 changed file
with
38 additions
and
43 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,9 @@ | ||
import torch | ||
import torch.nn as nn | ||
|
||
from .utils import load_state_dict_from_url | ||
|
||
__all__ = ['PolarityV1', 'polarity_v1', 'PolarityCCJ', 'polarity_ccj'] | ||
__all__ = ['PolarityV1', 'polarity_v1', 'PolarityLSTM', 'polarity_lstm'] | ||
|
||
model_urls = { | ||
'polarity_v1': 'https://www.dropbox.com/s/ckb4glf35agi9xa/polarity_v1_wenchuan-bdd92da2.pth?dl=1', | ||
|
@@ -100,6 +101,9 @@ def __init__(self): | |
self.fc = nn.Linear(64 * 1, 3) | ||
|
||
def forward(self, x): | ||
print("##### input to forward #####") | ||
print(x.shape) | ||
|
||
out = self.layer1(x) | ||
out = self.layer2(out) | ||
out = self.layer3(out) | ||
|
@@ -110,14 +114,8 @@ def forward(self, x): | |
out = self.layer8(out) | ||
out = self.layer9(out) | ||
out = self.layer10(out) | ||
print("##### before view #####") | ||
print(out.shape) | ||
out = out.view(out.size(0), -1) | ||
print("##### before fc #####") | ||
print(out.shape) | ||
out = self.fc(out) | ||
print("##### after fc #####") | ||
print(out.shape) | ||
return out | ||
|
||
def polarity_v1(pretrained=False, progress=True, **kwargs): | ||
|
@@ -138,52 +136,49 @@ def polarity_v1(pretrained=False, progress=True, **kwargs): | |
return model | ||
|
||
|
||
class PolarityCCJ(nn.Module): | ||
r"""a simple recurrent neural network from | ||
<https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html#creating-the-network> | ||
class PolarityLSTM(nn.Module): | ||
r"""an LSTM neural network | ||
@author: Chujie Chen | ||
@Email: [email protected] | ||
@date: 04/24/2020 | ||
""" | ||
def __init__(self, **kwargs): | ||
super().__init__() | ||
input_size = 1 | ||
if hidden_size in kwargs: | ||
hidden_size = 20 | ||
num_layers = 100 | ||
self.lstm = nn.LSTM(input_size, hidden_size, num_layers) | ||
self.fc = nn.Linear(hidden_size, 3) | ||
self.hidden_size = kwargs["hidden_size"] | ||
self.bidirectional = kwargs["bidirectional"] | ||
self.contains_unkown = kwargs["contains_unkown"] | ||
self.start = kwargs['start'] | ||
self.end = kwargs['end'] | ||
self.lstm = nn.LSTM(input_size, self.hidden_size,bidirectional=self.bidirectional) | ||
self.fc = nn.Linear(self.hidden_size * (2 if self.bidirectional else 1), 3 if self.contains_unkown else 2) | ||
|
||
def forward(self, x): | ||
batch, input_size, seq_len = x.shape | ||
x = x.narrow(2,self.start, self.end-self.start) | ||
x = x.permute(2, 0, 1) # seq_len, batch, input_size | ||
# If (h_0, c_0) is not provided, both h_0 and c_0 default to zero | ||
out, (hidden, cell_state) = self.lstm(x) | ||
# hidden has size [input_size,batch, hidden_size] | ||
print("###### hidden #####") | ||
print(hidden.shape) | ||
x = hidden.permute(1,2,0) | ||
x = x.view(x.size(0), -1) | ||
print("#############") | ||
out = self.fc(x) | ||
print("#### out ######") | ||
print(out.shape) | ||
print("#############") | ||
output, (h_n, c_n) = self.lstm(x, None) | ||
output = output[-1:, :, :] | ||
output = output.view(output.size(1), -1) | ||
out = self.fc(output) | ||
return out | ||
|
||
def initHidden(self): | ||
return torch.zeros(1, self.hidden_size) | ||
|
||
def polarity_ccj(pretrained=False, progress=True, **kwargs): | ||
r"""Original CPIC model architecture from the | ||
`"Deep learning for ..." <https://arxiv.org/abs/1901.06396>`_ paper. The | ||
pretrained model is trained on 60,000 Wenchuan aftershock dataset | ||
demonstrated in the paper. | ||
def polarity_lstm(**kwargs): | ||
r"""An LSTM-based model. | ||
Args: | ||
pretrained (bool): If True, returns a model pre-trained on Wenchuan | ||
progress (bool): If True, displays a progress bar of the download to stderr | ||
""" | ||
model = PolarityCCJ(**kwargs) | ||
# if pretrained: | ||
# state_dict = load_state_dict_from_url(model_urls['polarity_v1'], | ||
# progress=progress) | ||
# model.load_state_dict(state_dict) | ||
default_kwargs = {"hidden_size":64, | ||
"start": 250, | ||
"end": 350, | ||
"bidirectional":False, | ||
"contains_unkown":False} | ||
for k,v in kwargs.items(): | ||
if k in default_kwargs: | ||
default_kwargs[k] = v | ||
print("#### model parameters ####\n") | ||
print(default_kwargs) | ||
print("\n##########################") | ||
if(default_kwargs['end'] < default_kwargs['start']): | ||
raise ValueError('<-- end must be largger than start -->') | ||
model = PolarityLSTM(**default_kwargs) | ||
return model |