Skip to content

Commit

Permalink
Merge pull request #3 from ChujieChen/develop
Browse files Browse the repository at this point in the history
finished LSTM for polarity
  • Loading branch information
ChujieChen authored Apr 25, 2020
2 parents d270d44 + d9345d8 commit d276211
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions src/yews/models/polarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,9 @@ def __init__(self, **kwargs):
self.contains_unkown = kwargs["contains_unkown"]
self.start = kwargs['start']
self.end = kwargs['end']
self.lstm = nn.LSTM(input_size, self.hidden_size,bidirectional=self.bidirectional)
self.num_layers = kwargs['num_layers']
self.lstm = nn.LSTM(input_size, self.hidden_size, self.num_layers,
bidirectional=self.bidirectional)
self.fc = nn.Linear(self.hidden_size * (2 if self.bidirectional else 1), 3 if self.contains_unkown else 2)

def forward(self, x):
Expand All @@ -163,11 +165,16 @@ def forward(self, x):
return out
def polarity_lstm(**kwargs):
r"""A LSTM based model.
Args:
pretrained (bool): If True, returns a model pre-trained on Wenchuan)
progress (bool): If True, displays a progress bar of the download to stderr
Kwargs (form like a dict and should be pass like **kwargs):
hidden_size (default 64): recommended to be similar as the length of trimmed subsequence
num_layers (default 2): layers are stacked and results are from the final layer
start (default 250): start index of the subsequence
end (default 350): end index of the subsequence
bidirectional (default False): run lstm from left to right and from right to left
contains_unkown (default False): True if targets have 0,1,2
"""
default_kwargs = {"hidden_size":64,
"num_layers":2,
"start": 250,
"end": 350,
"bidirectional":False,
Expand All @@ -179,6 +186,6 @@ def polarity_lstm(**kwargs):
print(default_kwargs)
print("\n##########################")
if(default_kwargs['end'] < default_kwargs['start']):
raise ValueError('<-- end must be largger than start -->')
raise ValueError('<-- end cannot be smaller than start -->')
model = PolarityLSTM(**default_kwargs)
return model

0 comments on commit d276211

Please sign in to comment.