This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

Commit

Correct model based on paper and comments
vanewu authored and hhexiy committed Jun 30, 2019
1 parent 1313fc2 commit 812865e
Showing 1 changed file with 42 additions and 19 deletions.
61 changes: 42 additions & 19 deletions scripts/esim/esim.py
@@ -29,6 +29,7 @@
class ESIM(nn.HybridBlock):
""""Enhanced LSTM for Natural Language Inference" Qian Chen,
Xiaodan Zhu, Zhenhua Ling, Si Wei, Hui Jiang, Diana Inkpen. ACL (2017)
https://arxiv.org/pdf/1609.06038.pdf
Parameters
----------
@@ -46,38 +47,50 @@ class ESIM(nn.HybridBlock):
Dropout prob
"""

def __init__(self, nwords, nword_dims, nhiddens, ndense_units,
nclasses, drop_out, **kwargs):
def __init__(self, vocab_size, nword_dims, nhidden_units, ndense_units,
nclasses, dropout=0.0, **kwargs):
super(ESIM, self).__init__(**kwargs)
with self.name_scope():
self.embedding_encoder = nn.Embedding(nwords, nword_dims)
self.embedding_encoder = nn.Embedding(vocab_size, nword_dims)
self.batch_norm = nn.BatchNorm(axis=-1)
self.lstm_encoder1 = rnn.LSTM(nhiddens, bidirectional=True)
self.lstm_encoder2 = rnn.LSTM(nhiddens, bidirectional=True)
self.lstm_encoder1 = rnn.LSTM(nhidden_units, bidirectional=True)

self.projection = nn.HybridSequential()
self.projection.add(nn.BatchNorm(axis=-1),
nn.Dropout(dropout),
nn.Dense(nhidden_units, activation='relu', flatten=False))

self.lstm_encoder2 = rnn.LSTM(nhidden_units, bidirectional=True)

self.fc_encoder = nn.HybridSequential()
self.fc_encoder.add(nn.BatchNorm(axis=-1),
nn.Dropout(dropout),
nn.Dense(ndense_units),
nn.ELU(),
nn.BatchNorm(axis=-1),
nn.Dropout(drop_out),
nn.Dense(ndense_units),
nn.ELU(),
nn.BatchNorm(axis=-1),
nn.Dropout(drop_out),
nn.Dropout(dropout),
nn.Dense(nclasses))

self.avg_pool = nn.GlobalAvgPool1D()
self.max_pool = nn.GlobalMaxPool1D()

def _soft_attention_align(self, F, x1, x2, mask1, mask2):
# attention shape: (batch, seq_len, seq_len)
# x1 shape: (batch, x1_seq_len, nhidden_units*2)
# x2 shape: (batch, x2_seq_len, nhidden_units*2)
# mask1 shape: (batch, x1_seq_len)
# mask2 shape: (batch, x2_seq_len)
# attention shape: (batch, x1_seq_len, x2_seq_len)
attention = F.batch_dot(x1, x2, transpose_b=True)

# weight1 shape: (batch, x1_seq_len, x2_seq_len)
weight1 = F.softmax(attention + F.expand_dims(mask2, axis=1), axis=-1)
# x1_align shape: (batch, x1_seq_len, nhidden_units*2)
x1_align = F.batch_dot(weight1, x2)
weight2 = F.softmax(attention + F.expand_dims(mask1, axis=1), axis=-1)
x2_align = F.batch_dot(weight2, x1)

# weight2 shape: (batch, x1_seq_len, x2_seq_len)
weight2 = F.softmax(attention + F.expand_dims(mask1, axis=2), axis=1)
# x2_align shape: (batch, x2_seq_len, nhidden_units*2)
x2_align = F.batch_dot(weight2, x1, transpose_a=True)

return x1_align, x2_align
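
The masked soft alignment above can be exercised on its own. The following sketch (illustrative, not part of this file) reproduces the same shape flow imperatively with mxnet.ndarray; the additive log-mask convention (0 for real tokens, a large negative value for padding) is an assumption inferred from how the masks are added before the softmax, and the toy dimensions are arbitrary.

from mxnet import nd

batch, len1, len2, nhid2 = 2, 3, 4, 6                # nhid2 stands in for nhidden_units * 2
x1 = nd.random.normal(shape=(batch, len1, nhid2))    # encoded premise
x2 = nd.random.normal(shape=(batch, len2, nhid2))    # encoded hypothesis
mask1 = nd.zeros((batch, len1))                      # 0 = real token; padding would get a large negative value
mask2 = nd.zeros((batch, len2))

# attention shape: (batch, len1, len2)
attention = nd.batch_dot(x1, x2, transpose_b=True)

# broadcast_add mirrors `attention + F.expand_dims(mask2, axis=1)` in the diff
weight1 = nd.softmax(nd.broadcast_add(attention, nd.expand_dims(mask2, axis=1)), axis=-1)
x1_align = nd.batch_dot(weight1, x2)                     # (batch, len1, nhid2)

weight2 = nd.softmax(nd.broadcast_add(attention, nd.expand_dims(mask1, axis=2)), axis=1)
x2_align = nd.batch_dot(weight2, x1, transpose_a=True)   # (batch, len2, nhid2)

print(x1_align.shape, x2_align.shape)                    # (2, 3, 6) (2, 4, 6)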

@@ -87,14 +100,19 @@ def _submul(self, F, x1, x2):

return F.concat(mul, sub, dim=-1)
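
The _submul helper implements ESIM's local-inference enhancement: each encoded sequence is compared with its aligned counterpart by element-wise product and difference, and the result is concatenated with the originals (see x2_combined in hybrid_forward below). A self-contained sketch of that combination step, with toy values that are not part of the commit:

from mxnet import nd

x = nd.random.normal(shape=(2, 3, 6))          # encoded sequence
x_align = nd.random.normal(shape=(2, 3, 6))    # its soft-aligned counterpart
mul, sub = x * x_align, x - x_align
combined = nd.concat(x, x_align, mul, sub, dim=-1)   # (2, 3, 24)
print(combined.shape)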

def _apply_multiple(self, F, x):
def _pooling(self, F, x):
# x layout: NCW, where C is the input channel dim and W is seq_len
# p1, p2 shape: (batch, input channels)
p1 = F.squeeze(self.avg_pool(x), axis=-1)
p2 = F.squeeze(self.max_pool(x), axis=-1)

return F.concat(p1, p2, dim=-1)
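
The renamed _pooling helper expects channels-first (NCW) input, which is why hybrid_forward transposes the composition LSTM output before calling it. A small sketch of that transpose-then-pool step with the same Gluon pooling layers (toy shapes, not part of the commit):

from mxnet import nd
from mxnet.gluon import nn

batch, seq_len, nhid2 = 2, 5, 6
x = nd.random.normal(shape=(batch, seq_len, nhid2))  # NWC, as the composition LSTM emits it
x_ncw = nd.transpose(x, axes=(0, 2, 1))              # NCW, as the pooling layers expect

avg_pool, max_pool = nn.GlobalAvgPool1D(), nn.GlobalMaxPool1D()
p1 = nd.squeeze(avg_pool(x_ncw), axis=-1)            # (batch, nhid2)
p2 = nd.squeeze(max_pool(x_ncw), axis=-1)            # (batch, nhid2)
print(nd.concat(p1, p2, dim=-1).shape)               # (2, 12)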

def hybrid_forward(self, F, x1, x2, mask1, mask2): # pylint: disable=arguments-differ
# x1_embed x2_embed shape: (batch, seq_len, nword_dims)
# x1, x2 shape: (batch, x1_seq_len), (batch, x2_seq_len)
# mask1, mask2 shape: (batch, x1_seq_len), (batch, x2_seq_len)
# x1_embed shape: (batch, x1_seq_len, nword_dims)
# x2_embed shape: (batch, x2_seq_len, nword_dims)
x1_embed = self.batch_norm(self.embedding_encoder(x1))
x2_embed = self.batch_norm(self.embedding_encoder(x2))

@@ -111,12 +129,17 @@ def hybrid_forward(self, F, x1, x2, mask1, mask2): # pylint: disable=arguments-
x2_combined = F.concat(x2_lstm_encode, x2_algin,
self._submul(F, x2_lstm_encode, x2_algin), dim=-1)

x1_compose = self.lstm_encoder2(x1_combined)
x2_compose = self.lstm_encoder2(x2_combined)
# x1_compose shape: (batch, x1_seq_len, nhidden_units*2)
# x2_compose shape: (batch, x2_seq_len, nhidden_units*2)
x1_compose = self.lstm_encoder2(self.projection(x1_combined))
x2_compose = self.lstm_encoder2(self.projection(x2_combined))

# aggregate
x1_agg = self._apply_multiple(F, x1_compose)
x2_agg = self._apply_multiple(F, x2_compose)
# NWC ------> NCW
x1_compose = F.transpose(x1_compose, axes=(0, 2, 1))
x2_compose = F.transpose(x2_compose, axes=(0, 2, 1))
x1_agg = self._pooling(F, x1_compose)
x2_agg = self._pooling(F, x2_compose)

# fully-connected classifier
output = self.fc_encoder(F.concat(x1_agg, x2_agg, dim=-1))
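
With the corrected constructor signature (vocab_size, nword_dims, nhidden_units, ndense_units, nclasses, dropout), the model can be smoke-tested on random inputs. This is a hedged usage sketch, not part of the commit: the hyperparameter values are illustrative, and the additive log-mask convention for mask1/mask2 is inferred from _soft_attention_align rather than stated in the diff.

from mxnet import nd

model = ESIM(vocab_size=100, nword_dims=16, nhidden_units=32,
             ndense_units=64, nclasses=3, dropout=0.2)   # illustrative hyperparameters
model.initialize()

batch, len1, len2 = 4, 7, 9
x1 = nd.random.randint(0, 100, shape=(batch, len1)).astype('float32')
x2 = nd.random.randint(0, 100, shape=(batch, len2)).astype('float32')
mask1 = nd.zeros((batch, len1))   # all tokens treated as valid in this toy batch
mask2 = nd.zeros((batch, len2))

logits = model(x1, x2, mask1, mask2)
print(logits.shape)               # expected: (batch, nclasses) = (4, 3)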
