From d9345d8d8c9982658669442d9c42b53ed289a2a6 Mon Sep 17 00:00:00 2001 From: ChujieChen Date: Sat, 25 Apr 2020 16:51:50 +0000 Subject: [PATCH] finished LSTM for polarity --- src/yews/models/polarity.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/yews/models/polarity.py b/src/yews/models/polarity.py index 6b6f12f..22406f9 100644 --- a/src/yews/models/polarity.py +++ b/src/yews/models/polarity.py @@ -150,7 +150,9 @@ def __init__(self, **kwargs): self.contains_unkown = kwargs["contains_unkown"] self.start = kwargs['start'] self.end = kwargs['end'] - self.lstm = nn.LSTM(input_size, self.hidden_size,bidirectional=self.bidirectional) + self.num_layers = kwargs['num_layers'] + self.lstm = nn.LSTM(input_size, self.hidden_size, self.num_layers, + bidirectional=self.bidirectional) self.fc = nn.Linear(self.hidden_size * (2 if self.bidirectional else 1), 3 if self.contains_unkown else 2) def forward(self, x): @@ -163,11 +165,16 @@ def forward(self, x): return out def polarity_lstm(**kwargs): r"""A LSTM based model. - Args: - pretrained (bool): If True, returns a model pre-trained on Wenchuan) - progress (bool): If True, displays a progress bar of the download to stderr + Kwargs (form like a dict and should be pass like **kwargs): + hidden_size (default 64): recommended to be similar as the length of trimmed subsequence + num_layers (default 2): layers are stacked and results are from the final layer + start (default 250): start index of the subsequence + end (default 350): end index of the subsequence + bidirectional (default False): run lstm from left to right and from right to left + contains_unkown (default False): True if targets have 0,1,2 """ default_kwargs = {"hidden_size":64, + "num_layers":2, "start": 250, "end": 350, "bidirectional":False, @@ -179,6 +186,6 @@ def polarity_lstm(**kwargs): print(default_kwargs) print("\n##########################") if(default_kwargs['end'] < default_kwargs['start']): - raise ValueError('<-- end must be largger than start -->') + raise ValueError('<-- end cannot be smaller than start -->') model = PolarityLSTM(**default_kwargs) return model